文章目录
整体框架网络模型数据集训练整体框架
SR,即super resolution,即超分辨率。CNN相对来说比较著名,就是卷积神经网络了。从名字可以看出,SRCNN是首个应用于超分辨领域的卷积神经网络,事实上也的确如此。
所谓超分辨率,就是把低分辨率(LR, Low Resolution)图片放大为高分辨率(HR, High Resolution)的过程。由于是开山之作,SRCNN相对比较简单,总共分三步
输入LR图像XXX,经双三次(bicubic)插值,被放大成目标尺寸,得到YYY通过三层卷积网络拟合非线性映射输出HR图像结果F(Y)F(Y)F(Y)
训练的目标损失是最小化SR图像F(Y;θ)F(Y;\theta)F(Y;θ)和原高分辨率图像XXX像素差的均方误差
L(θ)=1n∑i=1n∥F(Yi;θ)−Xi∥2L(\theta)=\frac{1}{n}\sum^n_{i=1}\Vert F(Y_i;\theta)-X_i\Vert^2 L(θ)=n1i=1∑n∥F(Yi;θ)−Xi∥2
其中,nnn为训练样本数量,参数更新公式为
Δi+1=0.9Δi+η∂L∂Wil,Wi+1l=Wil+Δi+1\Delta_{i+1}=0.9\Delta_i+\eta\frac{\partial L}{\partial W^l_i},\quad W^l_{i+1}=W^l_i+\Delta_{i+1} Δi+1=0.9Δi+η∂Wil∂L,Wi+1l=Wil+Δi+1
网络模型
其网络结构如下
诚如前文所述,网络分为三个卷积层
维度是1×9×9×641\times9\times9\times641×9×9×64,表示输入图像通道数为1,进行卷积运算的核尺寸为9×99\times99×9,输出深度为64。维度是64×5×5×3264\times5\times5\times3264×5×5×32,64即上一层输出,32为下一层输出。维度是32×5×5×132\times5\times5\times132×5×5×1。其输出为单通道图像,与输入相同。
所以这个模型实现起来毫无难度
# models.pyclass SRCNN(nn.Module):def __init__(self, nChannel=1):super(SRCNN,self).__init__()self.conv1 = nn.Conv2d(nChannel, 64,kernel_size=9, padding=9//2)self.conv2 = nn.Conv2d(64, 32,kernel_size=5, padding=5//2)self.conv3 = nn.Conv2d(32, nChannel, kernel_size=5, padding=5//2)self.relu = nn.ReLU(inplace=True)def forward(self,x):x = self.relu(self.conv1(x))x = self.relu(self.conv2(x))x = self.conv3(x)return x
数据集
训练数据集可手动生成,设放大倍数为scale
,考虑到原始数据未必会被scale
整除,所以要重新规划一下图像尺寸,所以训练数据集的生成分为三步:
将原始图像通过双三次插值重设尺寸,使之可被scale
整除,作为高分辨图像数据HR将HR通过双三次插值压缩scale倍,为低分辨图像的原始数据将低分辨图像通过双三次插值放大scale倍,与HR图像维度相等,作为低分辨图像数据LR
最后,可通过h5py
将训练数据分块并打包,其生成代码为
import h5pyimport PIL.Image as pImgdef rgb2gray(img):return 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.# imgPath为图像路径;h5Path为存储路径;scale为放大倍数# pSize为patch尺寸; pStride为步长def setTrianData(imgPath, h5Path, scale=3, pSize=33, pStride=14):h5_file = h5py.File(h5Path, 'w')lrPatches, hrPatches = [], [] #用于存储低分辨率和高分辨率的patchfor p in sorted(glob.glob(f'{imgPath}/*')):hr = pImg.open(p).convert('RGB')lrWidth, lrHeight = hr.width // scale, hr.height // scale# width, height为可被scale整除的训练数据尺寸width, height = lrWidth*scale, lrHeight*scalehr = hr.resize((width, height), resample=pImg.BICUBIC)lr = hr.resize((lrWidth, lrHeight), resample=pImg.BICUBIC)lr = lr.resize((width, height), resample=pImg.BICUBIC)hr = np.array(hr).astype(np.float32)lr = np.array(lr).astype(np.float32)hr = rgb2gray(hr)lr = rgb2gray(lr)# 将数据分割for i in range(0, height - pSize + 1, pStride):for j in range(0, width - pSize + 1, pStride):lrPatches.append(lr[i:i + pSize, j:j + pSize])hrPatches.append(hr[i:i + pSize, j:j + pSize])h5_file.create_dataset('lr', data=np.array(lrPatches))h5_file.create_dataset('hr', data=np.array(hrPatches))h5_file.close()
以比较常见的T91数据集为例,通过上面的方法,可以得到一个181M的h5文件。
对于预测数据,也做同样处理。
在做好训练数据之后,需要为这些数据创建一个读取类,以便torch
中的DataLoader
调用,而DataLoader
中的内容则是Dataset
,所以新建的读取类需要继承Dataset
,并实现其__getitem__
和__len__
这两个成员方法。
这两个方法只是看上去吓人,但对Python稍有一点深入了解,就会知道__getitem__
是字典索引的方法,而__len__
则设定了len
函数的返回值。
import h5pyimport numpy as npfrom torch.utils.data import Datasetclass DataSet(Dataset):def __init__(self, h5_file):super(Dataset, self).__init__()self.h5_file = h5_filedef __getitem__(self, idx):with h5py.File(self.h5_file, 'r') as f:return np.expand_dims(f['lr'][idx] / 255., 0), np.expand_dims(f['hr'][idx] / 255., 0)def __len__(self):with h5py.File(self.h5_file, 'r') as f:return len(f['lr'])
训练
首先,训练需要一点准备工作,比如数据集准备好,相关的文件夹需要建好,建好模型之后,需要采用什么样的优化方式。训练设备是用cpu
还是cuda
,然后将数据集和模型装载到设备上。
数据准备
import osimport copyimport torchfrom torch import nnimport torch.optim as optimimport torch.backends.cudnn as cudnnfrom torch.utils.data.dataloader import DataLoaderfrom models import SRCNNtrainFile = "91-image.h5"evalFile = "Set5.h5"cudnn.benchmark = True# 设置训练设备 是CPU还是cudadevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')# 装载训练数据trainData = Dataset(trainFile)trainLoader = DataLoader(dataset=trainData,bSize=bSize,shuffle=True,# 表示打乱样本num_workers=nWorker, # 线程数pin_memory=True, # 方便载入CUDAdrop_last=True)# 装载预测数据evalDatas = Dataset(evalFile)evalLoader = DataLoader(dataset=evalDatas, bSize=1)
模型准备
# 模型和设备lr = 1e-4 #学习率torch.manual_seed(seed)#设置随机数种子model = SRCNN().to(device) #将模型载入设备criterion = nn.MSELoss() #设置损失函数optimizer = optim.Adam([{'params': model.conv1.parameters()},{'params': model.conv2.parameters()},{'params': model.conv3.parameters(), 'lr': lr * 0.1}], lr=lr)
训练
outPath = "outputs"scale = 3bSize = 16nEpoch = 400nWorker = 8#线程数seed = 42 #随机数种子def initPSNR():return {'avg':0, 'sum':0, 'count':0}def updatePSNR(psnr, val, n=1):s = psnr['sum'] + val*nc = psnr['count'] + nreturn {'avg':s/c, 'sum':s, 'count':c}bestWeights = copy.deepcopy(model.state_dict()) #最佳模型bestEpoch = 0 #最佳训练结果bestPSNR = 0.0 #最佳psnr# 训练主循环for epoch in range(nEpoch):model.train()epochLosses = initPSNR()for data in trainLoader:inputs, labels = datainputs = inputs.to(device)labels = labels.to(device)preds = model(inputs)loss = criterion(preds, labels)epochLosses = updatePSNR(epochLosses,loss.item(), len(inputs))optimizer.zero_grad() #清空梯度loss.backward() #反向传播optimizer.step() #根据梯度更新网络参数print(f'{epochLosses['avg']:.6f}')torch.save(model.state_dict(), os.path.join(outPath, f'epoch_{epoch}.pth'))model.eval() #取消dropoutpsnr = AverageMeter()for data in evalLoader:inputs, labels = datainputs = inputs.to(device)labels = labels.to(device)# 令reqires_grad自动设为False,关闭自动求导# clamp将inputs归一化为0到1区间with torch.no_grad():preds = model(inputs).clamp(0.0, 1.0)tmp_psnr = 10. * torch.log10(1. / torch.mean((preds - labels) ** 2))psnr = updatePSNR(psnr, tmp_psnr, len(inputs))print(f'eval psnr: {psnr.avg:.2f}')if psnr['avg'] > bestPSNR:bestEpoch = epochbestPSNR = psnr['avg']bestWeights = copy.deepcopy(model.state_dict())print(f'best epoch: {bestEpoch}, psnr: {bestPSNR:.2f}')torch.save(bestWeights, os.path.join(outPath, 'best.pth'))
最终的结果为