1000字范文,内容丰富有趣,学习的好帮手!
1000字范文 > 【超分辨率】【深度学习】SRCNN pytorch代码(附详细注释和数据集)

【超分辨率】【深度学习】SRCNN pytorch代码(附详细注释和数据集)

时间:2018-11-12 15:04:16

相关推荐

【超分辨率】【深度学习】SRCNN pytorch代码(附详细注释和数据集)

超分辨率

前言1 数据集预处理2 prepare.py 主要看注释(方便理解)3 train.py 主要看注释4 test.py5 结果对比

前言

主要改进:

断点恢复,可以恢复训练。注释掉原test.py的38行才是真正的超分辨率。

即image = image.resize((image.width // args.scale, image.height // args.scale), resample=pil_image.BICUBIC)

其中//代表整除的意思。model.py存在两个与原论文有出入,请仔细思考,如果想不出来可以联系我,但自己思考更有成就感。

关于第二点的注释可以知道,这份代码更注重于研究图像生成,改善的是图像细节而非分辨率。

这里主要是对代码进行讲解,对SRCNN不了解的同学可以先去参考其他博文。

原论文链接:Image quality assessment for determining efficacy and limitations of Super-Resolution Convolutional Neural Network (SRCNN)

有问题,不知道如何跑代码的同学联系:809267697@

代码转自:/yjn870/SRCNN-pytorch

对于新学深度学习代码的同学来说,推荐先阅读这一篇文章:

一个完整的Pytorch深度学习项目代码,项目结构是怎样的?

下面是这篇代码的步骤。

1 数据集预处理

首先准备好数据集,这里以img-91作为训练集,Set5作为测试集。

数据集:

/s/1Mmgh5xMsnYyDUpG6xbb9iw?pwd=bkac

运行prepare.py 将两个数据集转为h5格式。(测试集要在命令加上 --eval)

之后运行train.py

2 prepare.py 主要看注释(方便理解)

import argparseimport globimport h5pyimport numpy as npimport PIL.Image as pil_imagefrom utils import convert_rgb_to_ydef train(args):h5_file = h5py.File(args.output_path, 'w')lr_patches = []hr_patches = []for image_path in sorted(glob.glob('{}/*'.format(args.images_dir))):#将照片转换为RGB通道hr = pil_image.open(image_path).convert('RGB')#取放大倍数的倍数hr_width = (hr.width // args.scale) * args.scalehr_height = (hr.height // args.scale) * args.scale#图像大小调整hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)#低分辨率图像缩小lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)#低分辨率图像放大lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)#转换为浮点并取ycrcb中的y通道hr = np.array(hr).astype(np.float32)lr = np.array(lr).astype(np.float32)hr = convert_rgb_to_y(hr)lr = convert_rgb_to_y(lr)for i in range(0, lr.shape[0] - args.patch_size + 1, args.stride):for j in range(0, lr.shape[1] - args.patch_size + 1, args.stride):lr_patches.append(lr[i:i + args.patch_size, j:j + args.patch_size])hr_patches.append(hr[i:i + args.patch_size, j:j + args.patch_size])lr_patches = np.array(lr_patches)hr_patches = np.array(hr_patches)#创建数据集h5_file.create_dataset('lr', data=lr_patches)h5_file.create_dataset('hr', data=hr_patches)h5_file.close()#下同def eval(args):h5_file = h5py.File(args.output_path, 'w')lr_group = h5_file.create_group('lr')hr_group = h5_file.create_group('hr')for i, image_path in enumerate(sorted(glob.glob('{}/*'.format(args.images_dir)))):hr = pil_image.open(image_path).convert('RGB')hr_width = (hr.width // args.scale) * args.scalehr_height = (hr.height // args.scale) * args.scalehr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)hr = np.array(hr).astype(np.float32)lr = np.array(lr).astype(np.float32)hr = convert_rgb_to_y(hr)lr = convert_rgb_to_y(lr)lr_group.create_dataset(str(i), data=lr)hr_group.create_dataset(str(i), data=hr)h5_file.close()if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--images-dir', type=str, required=True)parser.add_argument('--output-path', type=str, required=True)parser.add_argument('--patch-size', type=int, default=32)parser.add_argument('--stride', type=int, default=14)parser.add_argument('--scale', type=int, default=4)parser.add_argument('--eval', action='store_true')args = parser.parse_args()if not args.eval:train(args)else:eval(args)

3 train.py 主要看注释

之后运行,看不懂注释可以先去其他博文了解SRCNN的网络结构和训练过程。

import argparseimport osimport copyimport numpy as npfrom torch import Tensorimport torchfrom torch import nnimport torch.optim as optim##gpu加速库import torch.backends.cudnn as cudnnfrom torch.utils.data.dataloader import DataLoader#进度条from tqdm import tqdmfrom model import SRCNNfrom datasets import TrainDataset, EvalDatasetfrom utils import AverageMeter, calc_psnr##需要修改的参数#epoch.pth#losslog#psnrlog#best.pthif __name__ == '__main__':#初始参数设定parser = argparse.ArgumentParser()parser.add_argument('--train-file', type=str, required=True)parser.add_argument('--eval-file', type=str, required=True)parser.add_argument('--outputs-dir', type=str, required=True)parser.add_argument('--scale', type=int, default=3)parser.add_argument('--lr', type=float, default=1e-4)parser.add_argument('--batch-size', type=int, default=16)parser.add_argument('--num-workers', type=int, default=0)parser.add_argument('--num-epochs', type=int, default=400)parser.add_argument('--seed', type=int, default=123)args = parser.parse_args()#输出放入固定文件夹里args.outputs_dir = os.path.join(args.outputs_dir, 'x{}'.format(args.scale))if not os.path.exists(args.outputs_dir):os.makedirs(args.outputs_dir)#benckmark模式,加速计算,但寻找最优配置,计算的前馈结果会有差异cudnn.benchmark = True#gpu模式device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')#每次程序运行生成的随机数固定torch.manual_seed(args.seed)#运算模式model = SRCNN().to(device)#恢复训练#model.load_state_dict(torch.load('outputs/x3/epoch_173.pth'))#代价函数MSEcriterion = nn.MSELoss()#优化函数Adam,lr代表学习率optimizer = optim.Adam([{'params': model.conv1.parameters()},{'params': model.conv2.parameters()},{'params': model.conv3.parameters(), 'lr': args.lr*0.1}], lr=args.lr)#预处理训练集train_dataset = TrainDataset(args.train_file)train_dataloader = DataLoader(#数据dataset=train_dataset,#分块batch_size=args.batch_size,#数据集数据洗牌,打乱后取batchshuffle=True,#工作进程,像是虚拟存储器中的页表机制num_workers=args.num_workers,#锁页内存,不换出内存pin_memory=True,#不取余,丢弃不足batchSize的图像drop_last=True)#预处理验证集eval_dataset = EvalDataset(args.eval_file)eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)#拷贝权重best_weights = copy.deepcopy(model.state_dict())best_epoch = 0best_psnr = 0.0#画图用lossLog=[]psnrLog=[]#恢复训练#for epoch in range(args.num_epochs):for epoch in range(1, args.num_epochs + 1):#for epoch in range(174, 400):#模型训练入口model.train()#变量更新,计算epoch平均损失epoch_losses = AverageMeter()#进度条,就是不要不足batchsize的部分with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size)) as t:#t.set_description('epoch:{}/{}'.format(epoch, args.num_epochs - 1))t.set_description('epoch:{}/{}'.format(epoch, args.num_epochs))#每个batch计算一次for data in train_dataloader:#对应datastes.py中的__getItem__,分别为lr,hr图像inputs, labels = datainputs = inputs.to(device)labels = labels.to(device)#训练preds = model(inputs)#获得损失loss = criterion(preds, labels)#显示损失值与长度epoch_losses.update(loss.item(), len(inputs))#梯度清零optimizer.zero_grad()#反向传播loss.backward()#更新参数optimizer.step()#进度条更新t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))t.update(len(inputs))#记录lossLog 方面画图lossLog.append(np.array(epoch_losses.avg))#可以在前面加上路径np.savetxt("lossLog.txt", lossLog)#保存模型torch.save(model.state_dict(), os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch)))#是否更新当前最好参数model.eval()epoch_psnr = AverageMeter()for data in eval_dataloader:inputs, labels = datainputs = inputs.to(device)labels = labels.to(device)#验证不用求导with torch.no_grad():preds = model(inputs).clamp(0.0, 1.0)epoch_psnr.update(calc_psnr(preds, labels), len(inputs))print('eval psnr: {:.2f}'.format(epoch_psnr.avg))#记录psnrpsnrLog.append(Tensor.cpu(epoch_psnr.avg)) np.savetxt('psnrLog.txt', psnrLog)if epoch_psnr.avg > best_psnr:best_epoch = epochbest_psnr = epoch_psnr.avgbest_weights = copy.deepcopy(model.state_dict())print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))

4 test.py

之后运行test.py就可以了,其中跟train.py差不多就不注释了。

test.py是放入图片、权重和倍数就行,会生成两张图片。

5 结果对比

(a)是原图 (b)是bicubic (c)是SRCNN

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。