1000字范文,内容丰富有趣,学习的好帮手!
1000字范文 > 超分辨网络SRCNN的Pytorch实现

超分辨网络SRCNN的Pytorch实现

时间:2021-09-17 08:18:56

相关推荐

超分辨网络SRCNN的Pytorch实现

文章目录

整体框架网络模型数据集训练

整体框架

SR,即super resolution,即超分辨率。CNN相对来说比较著名,就是卷积神经网络了。从名字可以看出,SRCNN是首个应用于超分辨领域的卷积神经网络,事实上也的确如此。

所谓超分辨率,就是把低分辨率(LR, Low Resolution)图片放大为高分辨率(HR, High Resolution)的过程。由于是开山之作,SRCNN相对比较简单,总共分三步

输入LR图像XXX,经双三次(bicubic)插值,被放大成目标尺寸,得到YYY通过三层卷积网络拟合非线性映射输出HR图像结果F(Y)F(Y)F(Y)

训练的目标损失是最小化SR图像F(Y;θ)F(Y;\theta)F(Y;θ)和原高分辨率图像XXX像素差的均方误差

L(θ)=1n∑i=1n∥F(Yi;θ)−Xi∥2L(\theta)=\frac{1}{n}\sum^n_{i=1}\Vert F(Y_i;\theta)-X_i\Vert^2 L(θ)=n1​i=1∑n​∥F(Yi​;θ)−Xi​∥2

其中,nnn为训练样本数量,参数更新公式为

Δi+1=0.9Δi+η∂L∂Wil,Wi+1l=Wil+Δi+1\Delta_{i+1}=0.9\Delta_i+\eta\frac{\partial L}{\partial W^l_i},\quad W^l_{i+1}=W^l_i+\Delta_{i+1} Δi+1​=0.9Δi​+η∂Wil​∂L​,Wi+1l​=Wil​+Δi+1​

网络模型

其网络结构如下

诚如前文所述,网络分为三个卷积层

维度是1×9×9×641\times9\times9\times641×9×9×64,表示输入图像通道数为1,进行卷积运算的核尺寸为9×99\times99×9,输出深度为64。维度是64×5×5×3264\times5\times5\times3264×5×5×32,64即上一层输出,32为下一层输出。维度是32×5×5×132\times5\times5\times132×5×5×1。其输出为单通道图像,与输入相同。

所以这个模型实现起来毫无难度

# models.pyclass SRCNN(nn.Module):def __init__(self, nChannel=1):super(SRCNN,self).__init__()self.conv1 = nn.Conv2d(nChannel, 64,kernel_size=9, padding=9//2)self.conv2 = nn.Conv2d(64, 32,kernel_size=5, padding=5//2)self.conv3 = nn.Conv2d(32, nChannel, kernel_size=5, padding=5//2)self.relu = nn.ReLU(inplace=True)def forward(self,x):x = self.relu(self.conv1(x))x = self.relu(self.conv2(x))x = self.conv3(x)return x

数据集

训练数据集可手动生成,设放大倍数为scale,考虑到原始数据未必会被scale整除,所以要重新规划一下图像尺寸,所以训练数据集的生成分为三步:

将原始图像通过双三次插值重设尺寸,使之可被scale整除,作为高分辨图像数据HR将HR通过双三次插值压缩scale倍,为低分辨图像的原始数据将低分辨图像通过双三次插值放大scale倍,与HR图像维度相等,作为低分辨图像数据LR

最后,可通过h5py将训练数据分块并打包,其生成代码为

import h5pyimport PIL.Image as pImgdef rgb2gray(img):return 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.# imgPath为图像路径;h5Path为存储路径;scale为放大倍数# pSize为patch尺寸; pStride为步长def setTrianData(imgPath, h5Path, scale=3, pSize=33, pStride=14):h5_file = h5py.File(h5Path, 'w')lrPatches, hrPatches = [], [] #用于存储低分辨率和高分辨率的patchfor p in sorted(glob.glob(f'{imgPath}/*')):hr = pImg.open(p).convert('RGB')lrWidth, lrHeight = hr.width // scale, hr.height // scale# width, height为可被scale整除的训练数据尺寸width, height = lrWidth*scale, lrHeight*scalehr = hr.resize((width, height), resample=pImg.BICUBIC)lr = hr.resize((lrWidth, lrHeight), resample=pImg.BICUBIC)lr = lr.resize((width, height), resample=pImg.BICUBIC)hr = np.array(hr).astype(np.float32)lr = np.array(lr).astype(np.float32)hr = rgb2gray(hr)lr = rgb2gray(lr)# 将数据分割for i in range(0, height - pSize + 1, pStride):for j in range(0, width - pSize + 1, pStride):lrPatches.append(lr[i:i + pSize, j:j + pSize])hrPatches.append(hr[i:i + pSize, j:j + pSize])h5_file.create_dataset('lr', data=np.array(lrPatches))h5_file.create_dataset('hr', data=np.array(hrPatches))h5_file.close()

以比较常见的T91数据集为例,通过上面的方法,可以得到一个181M的h5文件。

对于预测数据,也做同样处理。

在做好训练数据之后,需要为这些数据创建一个读取类,以便torch中的DataLoader调用,而DataLoader中的内容则是Dataset,所以新建的读取类需要继承Dataset,并实现其__getitem____len__这两个成员方法。

这两个方法只是看上去吓人,但对Python稍有一点深入了解,就会知道__getitem__是字典索引的方法,而__len__则设定了len函数的返回值。

import h5pyimport numpy as npfrom torch.utils.data import Datasetclass DataSet(Dataset):def __init__(self, h5_file):super(Dataset, self).__init__()self.h5_file = h5_filedef __getitem__(self, idx):with h5py.File(self.h5_file, 'r') as f:return np.expand_dims(f['lr'][idx] / 255., 0), np.expand_dims(f['hr'][idx] / 255., 0)def __len__(self):with h5py.File(self.h5_file, 'r') as f:return len(f['lr'])

训练

首先,训练需要一点准备工作,比如数据集准备好,相关的文件夹需要建好,建好模型之后,需要采用什么样的优化方式。训练设备是用cpu还是cuda,然后将数据集和模型装载到设备上。

数据准备

import osimport copyimport torchfrom torch import nnimport torch.optim as optimimport torch.backends.cudnn as cudnnfrom torch.utils.data.dataloader import DataLoaderfrom models import SRCNNtrainFile = "91-image.h5"evalFile = "Set5.h5"cudnn.benchmark = True# 设置训练设备 是CPU还是cudadevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')# 装载训练数据trainData = Dataset(trainFile)trainLoader = DataLoader(dataset=trainData,bSize=bSize,shuffle=True,# 表示打乱样本num_workers=nWorker, # 线程数pin_memory=True, # 方便载入CUDAdrop_last=True)# 装载预测数据evalDatas = Dataset(evalFile)evalLoader = DataLoader(dataset=evalDatas, bSize=1)

模型准备

# 模型和设备lr = 1e-4 #学习率torch.manual_seed(seed)#设置随机数种子model = SRCNN().to(device) #将模型载入设备criterion = nn.MSELoss() #设置损失函数optimizer = optim.Adam([{'params': model.conv1.parameters()},{'params': model.conv2.parameters()},{'params': model.conv3.parameters(), 'lr': lr * 0.1}], lr=lr)

训练

outPath = "outputs"scale = 3bSize = 16nEpoch = 400nWorker = 8#线程数seed = 42 #随机数种子def initPSNR():return {'avg':0, 'sum':0, 'count':0}def updatePSNR(psnr, val, n=1):s = psnr['sum'] + val*nc = psnr['count'] + nreturn {'avg':s/c, 'sum':s, 'count':c}bestWeights = copy.deepcopy(model.state_dict()) #最佳模型bestEpoch = 0 #最佳训练结果bestPSNR = 0.0 #最佳psnr# 训练主循环for epoch in range(nEpoch):model.train()epochLosses = initPSNR()for data in trainLoader:inputs, labels = datainputs = inputs.to(device)labels = labels.to(device)preds = model(inputs)loss = criterion(preds, labels)epochLosses = updatePSNR(epochLosses,loss.item(), len(inputs))optimizer.zero_grad() #清空梯度loss.backward() #反向传播optimizer.step() #根据梯度更新网络参数print(f'{epochLosses['avg']:.6f}')torch.save(model.state_dict(), os.path.join(outPath, f'epoch_{epoch}.pth'))model.eval() #取消dropoutpsnr = AverageMeter()for data in evalLoader:inputs, labels = datainputs = inputs.to(device)labels = labels.to(device)# 令reqires_grad自动设为False,关闭自动求导# clamp将inputs归一化为0到1区间with torch.no_grad():preds = model(inputs).clamp(0.0, 1.0)tmp_psnr = 10. * torch.log10(1. / torch.mean((preds - labels) ** 2))psnr = updatePSNR(psnr, tmp_psnr, len(inputs))print(f'eval psnr: {psnr.avg:.2f}')if psnr['avg'] > bestPSNR:bestEpoch = epochbestPSNR = psnr['avg']bestWeights = copy.deepcopy(model.state_dict())print(f'best epoch: {bestEpoch}, psnr: {bestPSNR:.2f}')torch.save(bestWeights, os.path.join(outPath, 'best.pth'))

最终的结果为

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。