1000字范文,内容丰富有趣,学习的好帮手!
1000字范文 > 【飞桨模型复现计划】SRCNN网络-超分辨率重建

【飞桨模型复现计划】SRCNN网络-超分辨率重建

时间:2023-09-26 11:31:55

相关推荐

【飞桨模型复现计划】SRCNN网络-超分辨率重建

项目简介

本项目是paperweekly paddlepaddle复现活动的第23篇论文《Single Image Super-Resolution Using Deep Learning》的复现代码。 论文主要介绍了图像超分辨率重建的DNN方法,SRCNN网络。本代码复现了基础模型、修改网络大小的结果和3通道RGB模型的结果,并展示经过超分辨率重建后的图片。

In[1]

# 查看当前挂载的数据集目录 !ls /home/aistudio/data/

data863 data904

In[2]

# 查看个人持久化工作区文件!rm -rf /home/aistudio/work/*!ls /home/aistudio/work/

解压数据集

In[3]

# 这个cell的代码只有第一次运行时需要执行import osos.mkdir('/home/aistudio/work/model') # 用于存储训练好的模型文件os.mkdir('/home/aistudio/work/model_3d') # 用于存储训练好的3d模型文件os.mkdir('/home/aistudio/work/model_fz') # 用于存储训练好的fz模型文件os.mkdir('/home/aistudio/work/dataset') # 用于存放所有的数据集os.mkdir('/home/aistudio/work/dataset/timofte') # 91图训练集# 解压训练集图片os.chdir('/home/aistudio/work/dataset/timofte')os.system('tar -xvf /home/aistudio/data/data863/input_images.tar')# 解压测试集图片os.chdir('/home/aistudio/work/dataset')os.system('tar -xvf /home/aistudio/data/data904/val_dataset.tar')# 回到工作路径os.chdir('/home/aistudio')

In[4]

# import 要用的库import paddle.fluid as fluidimport paddle.v2 as paddleimport numpy as npimport cv2 import osfrom matplotlib import pyplot as plt%matplotlib inline

baseline model

In[6]

class SRCNN(object): # 只对Y通道做超分辨率重构def __init__(self, lr, lr_f, batch_size, iter_num):self.lr = lr # 学习率self.lr_f = lr_f # 最后一层学习率self.batch_size = batch_size self.iter_num = iter_num # 总共训练多少次def net(self, X, Y): # 搭建模型 conv1 = fluid.layers.conv2d(X, 64, 9,act='relu', name='conv1' , param_attr= fluid.ParamAttr(initializer=fluid.initializer.NormalInitializer(scale=0.001),name='conv1_w'),bias_attr=fluid.ParamAttr(initializer=fluid.initializer.ConstantInitializer(value=0.),name='conv1_b'))conv2 = fluid.layers.conv2d(conv1, 32, 1, act='relu', name='conv2' , param_attr= fluid.ParamAttr(initializer=fluid.initializer.NormalInitializer(scale=0.001),name='conv2_w'),bias_attr=fluid.ParamAttr(initializer=fluid.initializer.ConstantInitializer(value=0.),name='conv2_b'))pred = fluid.layers.conv2d(conv2, 1, 5, name='pred', param_attr= fluid.ParamAttr(initializer=fluid.initializer.NormalInitializer(scale=0.001),name='pred_w'),bias_attr=fluid.ParamAttr(initializer=fluid.initializer.ConstantInitializer(value=0.),name='pred_b')) loss = fluid.layers.reduce_mean(fluid.layers.square(pred - Y)) return pred, lossdef train(self):# 模型训练X_train = fluid.layers.data(shape=[1, 33, 33], dtype='float32', name='image')Y_train = fluid.layers.data(shape=[1, 21, 21], dtype='float32', name='gdt')y_predict, y_loss = (X_train, Y_train)Optimizer = fluid.optimizer.AdamOptimizer(learning_rate=self.lr)Optimizer_f = fluid.optimizer.AdamOptimizer(learning_rate=self.lr_f)Optimizer.minimize(y_loss, parameter_list=['conv1_w','conv1_b', 'conv2_w', 'conv2_b'])Optimizer_f.minimize(y_loss, parameter_list=['pred_w', 'pred_b'])# 读取训练集数据train_reader = paddle.batch(self.read_data('work/dataset/timofte'), batch_size=self.batch_size) # 定义执行器place = fluid.CPUPlace()exe = fluid.Executor(place)exe.run(fluid.default_startup_program())def train_loop(main_program):feeder = fluid.DataFeeder(place=place, feed_list=[X_train, Y_train])exe.run(fluid.default_startup_program())backprops_cnt = 0 # 论文中作图的横坐标self.backprops_cnts = [] # 绑定为类的一个属性,用于画图 self.psnr = []# psnr的值for epoch in range(self.iter_num):for batch_id, data in enumerate(train_reader()): loss = exe.run(fluid.framework.default_main_program(),feed=feeder.feed(data),fetch_list=[y_loss])if batch_id == 0: # 每个epoch算一下psnr,画图用的## 算psnr要在测试集上面fluid.io.save_inference_model('work/model/', ['image'], [y_predict], exe)val_loss, val_psnr = self.validation()self.backprops_cnts.append(backprops_cnt * self.batch_size)self.psnr.append(val_psnr)print("%i\tEpoch: %d \tCur Cost : %f\t Val Cost: %f\t PSNR :%f" % (backprops_cnt, epoch, np.array(loss[0])[0], val_loss, val_psnr))backprops_cnt += 1 fluid.io.save_inference_model('work/model/', ['image'], [y_predict], exe)train_loop(fluid.default_main_program())def validation(self):place = fluid.CPUPlace()exe = fluid.Executor(place)inference_scope = fluid.core.Scope() test_set = 'work/dataset/set5/'scale_factor = 3for img_name in os.listdir(test_set): img_val = cv2.imread(os.path.join(test_set, img_name))yuv = cv2.cvtColor(img_val, cv2.COLOR_BGR2YCrCb)img_y, img_u, img_v = cv2.split(yuv)img_h, img_w = img_y.shapeimg_blur = cv2.GaussianBlur(img_y, (5, 5), 0)img_subsample = cv2.resize(img_blur, (img_w/scale_factor, img_h/scale_factor))img_input = cv2.resize(img_blur, (img_w, img_h), interpolation=cv2.INTER_CUBIC)img_input = np.reshape(img_input, [1,1,img_h, img_w]).astype("float32") # h,w losses = []with fluid.scope_guard(inference_scope):[inference_program, feed_target_names, fetch_targets] = (fluid.io.load_inference_model('work/model/', exe)) results = exe.run(inference_program,feed={feed_target_names[0]: img_input},fetch_list=fetch_targets)[0]loss = np.mean(np.square(results[0,0]-img_y[6:-6, 6:-6]))losses.append(loss) avg_loss = np.sum(np.array(losses))/len(losses)psnr = 10 * np.log10(255*255/avg_loss)return avg_loss,psnrdef generate_reconstruct_img(self, img_name): place = fluid.CPUPlace()exe = fluid.Executor(place)inference_scope = fluid.core.Scope() img_test = cv2.imread('work/dataset/set5/%s' % img_name)yuv_test = cv2.cvtColor(img_test, cv2.COLOR_BGR2YCrCb) img_h, img_w, img_c = img_test.shapeprint "=====原始图片========="b,g,r = cv2.split(img_test) # AI studio 不支持cv2.imshow,所以用plt.imshow输出,两者rgb顺序不一样img_test = cv2.merge([r,g,b])plt.imshow(img_test)plt.show() # 图像模糊+cubic插值img_blur = cv2.GaussianBlur(yuv_test.copy(), (5, 5), 0)img_subsample = cv2.resize(img_blur, (img_w/3, img_h/3)) #这里注意cv2.resize里面的shape是w,h的顺序 img_cubic = cv2.resize(img_blur, (img_w, img_h), interpolation=cv2.INTER_CUBIC)img_y, img_u, img_v = cv2.split(img_cubic) img_input = np.reshape(img_y, [1,1,img_h, img_w]).astype("float32") # 把y通道作为输入with fluid.scope_guard(inference_scope):[inference_program, feed_target_names, fetch_targets] = (fluid.io.load_inference_model('work/model/', exe)) results = exe.run(inference_program,feed={feed_target_names[0]: img_input},fetch_list=fetch_targets)[0]result_img = np.array(results) result_img[result_img < 0] = 0result_img[result_img >255] = 255gap_y = (img_y.shape[0]-result_img.shape[2])/2gap_x = (img_y.shape[1]-result_img.shape[3])/2print "=====Y通道输入========="plt.imshow(img_y, cmap='gray')plt.show()img_y[gap_y: gap_y + result_img.shape[2],gap_x: gap_x + result_img.shape[3]]=result_imgimg_test_r = cv2.merge([img_y, img_u, img_v])img_test_r = cv2.cvtColor(img_test_r, cv2.COLOR_YCrCb2BGR)print "=====Y通道输出========="plt.imshow(img_y, cmap='gray') plt.show()print "=====彩图结果========="b,g,r = cv2.split(img_test_r) img_test_show = cv2.merge([r,g,b])plt.imshow(img_test_show)plt.show()def read_data(self, data_path):def data_reader():for image in os.listdir(data_path):if image.endswith('.bmp'):img = cv2.imread(os.path.join(data_path, image))yuv = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb)img_y, img_u, img_v = cv2.split(yuv)# 下面是切图的步骤j = 0count = 0while j+33 < len(img_y):i = 0while i+33 < len(img_y[0]):img_patch = img_y[j:j+33, i:i+33]img_gth = img_patch[6:27, 6:27].copy()img_blur = cv2.GaussianBlur(img_patch, (5, 5), 0)img_sumsample = cv2.resize(img_blur, (11, 11))img_input = cv2.resize(img_blur, (33, 33), interpolation=cv2.INTER_CUBIC)yield img_input, img_gthi+=14j+= 14return data_reader

In[7]

model = SRCNN(0.0001, 0.00001, 100, 150)model.train()

0Epoch: 0 Cur Cost : 18164.226562 Val Cost: 21233.191406 PSNR :4.860651219Epoch: 1 Cur Cost : 284.385284 Val Cost: 417.904175 PSNR :21.97438Epoch: 2 Cur Cost : 281.428558 Val Cost: 413.779419 PSNR :21.963115657Epoch: 3 Cur Cost : 277.070312 Val Cost: 406.204071 PSNR :22.043361876Epoch: 4 Cur Cost : 253.116455 Val Cost: 372.079315 PSNR :22.4244481095Epoch: 5 Cur Cost : 207.459930 Val Cost: 303.296204 PSNR :23.3121341314Epoch: 6 Cur Cost : 154.225296 Val Cost: 221.624435 PSNR :24.6746271533Epoch: 7 Cur Cost : 117.966713 Val Cost: 165.380890 PSNR :25.9459501752Epoch: 8 Cur Cost : 114.436462 Val Cost: 158.598419 PSNR :26.1278151971Epoch: 9 Cur Cost : 112.926453 Val Cost: 155.753204 PSNR :26.2064342190Epoch: 10 Cur Cost : 111.557121 Val Cost: 153.391846 PSNR :26.2727812409Epoch: 11 Cur Cost : 110.163071 Val Cost: 151.123138 PSNR :26.3374942628Epoch: 12 Cur Cost : 108.74 Val Cost: 148.844711 PSNR :26.4034702847Epoch: 13 Cur Cost : 107.220848 Val Cost: 146.517197 PSNR :26.4719183066Epoch: 14 Cur Cost : 105.661781 Val Cost: 144.118179 PSNR :26.5436163285Epoch: 15 Cur Cost : 104.040611 Val Cost: 141.630585 PSNR :26.6192333504Epoch: 16 Cur Cost : 102.351303 Val Cost: 139.039047 PSNR :26.6994363723Epoch: 17 Cur Cost : 100.585503 Val Cost: 136.326965 PSNR :26.7849863942Epoch: 18 Cur Cost : 98.731277 Val Cost: 133.479095 PSNR :26.8766714161Epoch: 19 Cur Cost : 96.780418 Val Cost: 130.485748 PSNR :26.9751734380Epoch: 20 Cur Cost : 94.736946 Val Cost: 127.345490 PSNR :27.0809684599Epoch: 21 Cur Cost : 92.607079 Val Cost: 124.054619 PSNR :27.1946744818Epoch: 22 Cur Cost : 90.392410 Val Cost: 120.612396 PSNR :27.3168845037Epoch: 23 Cur Cost : 88.097008 Val Cost: 117.033394 PSNR :27.4477065256Epoch: 24 Cur Cost : 85.727638 Val Cost: 113.339539 PSNR :27.5869895475Epoch: 25 Cur Cost : 83.291718 Val Cost: 109.558945 PSNR :27.7343255694Epoch: 26 Cur Cost : 80.803673 Val Cost: 105.716370 PSNR :27.8893815913Epoch: 27 Cur Cost : 78.289711 Val Cost: 101.814842 PSNR :28.0526936132Epoch: 28 Cur Cost : 75.769890 Val Cost: 97.854851 PSNR :28.2249806351Epoch: 29 Cur Cost : 73.263191 Val Cost: 93.865204 PSNR :28.4057576570Epoch: 30 Cur Cost : 70.80 Val Cost: 89.902550 PSNR :28.5930846789Epoch: 31 Cur Cost : 68.425529 Val Cost: 86.056076 PSNR :28.7829887008Epoch: 32 Cur Cost : 66.165245 Val Cost: 82.378616 PSNR :28.9726597227Epoch: 33 Cur Cost : 64.029266 Val Cost: 78.877708 PSNR :29.1612617446Epoch: 34 Cur Cost : 61.938168 Val Cost: 75.441162 PSNR :29.3547207665Epoch: 35 Cur Cost : 59.845867 Val Cost: 71.962860 PSNR :29.5597197884Epoch: 36 Cur Cost : 57.924450 Val Cost: 68.783943 PSNR :29.7559338103Epoch: 37 Cur Cost : 56.233124 Val Cost: 65.943611 PSNR :29.9390768322Epoch: 38 Cur Cost : 54.610748 Val Cost: 63.266205 PSNR :30.1190868541Epoch: 39 Cur Cost : 53.087059 Val Cost: 60.819660 PSNR :30.2903648760Epoch: 40 Cur Cost : 51.752823 Val Cost: 58.767235 PSNR :30.4394518979Epoch: 41 Cur Cost : 50.548962 Val Cost: 56.994938 PSNR :30.5724419198Epoch: 42 Cur Cost : 49.409630 Val Cost: 55.392437 PSNR :30.6962999417Epoch: 43 Cur Cost : 48.292534 Val Cost: 53.872486 PSNR :30.8171339636Epoch: 44 Cur Cost : 47.169697 Val Cost: 52.356304 PSNR :30.9411149855Epoch: 45 Cur Cost : 46.026646 Val Cost: 50.804356 PSNR :31.07179410074Epoch: 46 Cur Cost : 44.855438 Val Cost: 49.224430 PSNR :31.20899710293Epoch: 47 Cur Cost : 43.651192 Val Cost: 47.612514 PSNR :31.35359210512Epoch: 48 Cur Cost : 42.413403 Val Cost: 45.968498 PSNR :31.50620010731Epoch: 49 Cur Cost : 41.161961 Val Cost: 44.319637 PSNR :31.66484210950Epoch: 50 Cur Cost : 39.941639 Val Cost: 42.723743 PSNR :31.82411111169Epoch: 51 Cur Cost : 38.763580 Val Cost: 41.210186 PSNR :31.98075811388Epoch: 52 Cur Cost : 37.628029 Val Cost: 39.770855 PSNR :32.13515411607Epoch: 53 Cur Cost : 36.646980 Val Cost: 38.539745 PSNR :32.27171511826Epoch: 54 Cur Cost : 35.795269 Val Cost: 37.488255 PSNR :32.39185112045Epoch: 55 Cur Cost : 35.093742 Val Cost: 36.636829 PSNR :32.49162512264Epoch: 56 Cur Cost : 34.488441 Val Cost: 35.905704 PSNR :32.57916912483Epoch: 57 Cur Cost : 33.963333 Val Cost: 35.274933 PSNR :32.65614212702Epoch: 58 Cur Cost : 33.505180 Val Cost: 34.726208 PSNR :32.72423012921Epoch: 59 Cur Cost : 33.103012 Val Cost: 34.243057 PSNR :32.78507813140Epoch: 60 Cur Cost : 32.750072 Val Cost: 33.818142 PSNR :32.83930613359Epoch: 61 Cur Cost : 32.436031 Val Cost: 33.438747 PSNR :32.88830413578Epoch: 62 Cur Cost : 32.150639 Val Cost: 33.090755 PSNR :32.93373713797Epoch: 63 Cur Cost : 31.886650 Val Cost: 32.765667 PSNR :32.97661314016Epoch: 64 Cur Cost : 31.639585 Val Cost: 32.457993 PSNR :33.01758714235Epoch: 65 Cur Cost : 31.405382 Val Cost: 32.163200 PSNR :33.05721114454Epoch: 66 Cur Cost : 31.181067 Val Cost: 31.877956 PSNR :33.09589914673Epoch: 67 Cur Cost : 30.963821 Val Cost: 31.595957 PSNR :33.13448814892Epoch: 68 Cur Cost : 30.745806 Val Cost: 31.313644 PSNR :33.17346715111Epoch: 69 Cur Cost : 30.530462 Val Cost: 31.029354 PSNR :33.21307615330Epoch: 70 Cur Cost : 30.316664 Val Cost: 30.746445 PSNR :33.25285515549Epoch: 71 Cur Cost : 30.103735 Val Cost: 30.461548 PSNR :33.29328415768Epoch: 72 Cur Cost : 29.889498 Val Cost: 30.172472 PSNR :33.33469515987Epoch: 73 Cur Cost : 29.668657 Val Cost: 29.870338 PSNR :33.37840216206Epoch: 74 Cur Cost : 29.430288 Val Cost: 29.451651 PSNR :33.43970716425Epoch: 75 Cur Cost : 29.207491 Val Cost: 29.195040 PSNR :33.47771316644Epoch: 76 Cur Cost : 29.008448 Val Cost: 28.942610 PSNR :33.51542716863Epoch: 77 Cur Cost : 28.805315 Val Cost: 28.677612 PSNR :33.55537417082Epoch: 78 Cur Cost : 28.597673 Val Cost: 28.398453 PSNR :33.59785717301Epoch: 79 Cur Cost : 28.384197 Val Cost: 28.109222 PSNR :33.64231517520Epoch: 80 Cur Cost : 28.165312 Val Cost: 27.810051 PSNR :33.68878617739Epoch: 81 Cur Cost : 27.941040 Val Cost: 27.501369 PSNR :33.73726017958Epoch: 82 Cur Cost : 27.711397 Val Cost: 27.183748 PSNR :33.78771018177Epoch: 83 Cur Cost : 27.476480 Val Cost: 26.856071 PSNR :33.84037918396Epoch: 84 Cur Cost : 27.235754 Val Cost: 26.519581 PSNR :33.89513718615Epoch: 85 Cur Cost : 26.991680 Val Cost: 26.172489 PSNR :33.95235318834Epoch: 86 Cur Cost : 26.743433 Val Cost: 25.811661 PSNR :34.01264419053Epoch: 87 Cur Cost : 26.480677 Val Cost: 25.444807 PSNR :34.07481219272Epoch: 88 Cur Cost : 26.224283 Val Cost: 25.074034 PSNR :34.13856219491Epoch: 89 Cur Cost : 25.963892 Val Cost: 24.700274 PSNR :34.20378619710Epoch: 90 Cur Cost : 25.710884 Val Cost: 24.330662 PSNR :34.26926419929Epoch: 91 Cur Cost : 25.457993 Val Cost: 23.969437 PSNR :34.3342258Epoch: 92 Cur Cost : 25.210745 Val Cost: 23.615768 PSNR :34.39878320367Epoch: 93 Cur Cost : 24.963886 Val Cost: 23.267546 PSNR :34.46329820586Epoch: 94 Cur Cost : 24.722895 Val Cost: 22.925728 PSNR :34.52757220805Epoch: 95 Cur Cost : 24.486914 Val Cost: 22.588749 PSNR :34.59188221024Epoch: 96 Cur Cost : 24.253819 Val Cost: 22.257490 PSNR :34.65604221243Epoch: 97 Cur Cost : 24.023243 Val Cost: 21.936876 PSNR :34.71905621462Epoch: 98 Cur Cost : 23.806646 Val Cost: 21.630701 PSNR :34.78009821681Epoch: 99 Cur Cost : 23.600349 Val Cost: 21.338070 PSNR :34.83925221900Epoch: 100 Cur Cost : 23.404703 Val Cost: 21.059601 PSNR :34.89630222119Epoch: 101 Cur Cost : 23.220484 Val Cost: 20.796125 PSNR :34.95097922338Epoch: 102 Cur Cost : 23.047068 Val Cost: 20.546223 PSNR :35.00348422557Epoch: 103 Cur Cost : 22.884609 Val Cost: 20.312155 PSNR :35.05324422776Epoch: 104 Cur Cost : 22.732927 Val Cost: 20.091757 PSNR :35.10062422995Epoch: 105 Cur Cost : 22.590960 Val Cost: 19.883532 PSNR :35.14586823214Epoch: 106 Cur Cost : 22.457445 Val Cost: 19.684614 PSNR :35.18953523433Epoch: 107 Cur Cost : 22.331661 Val Cost: 19.496771 PSNR :35.231177

23652Epoch: 108 Cur Cost : 22.212849 Val Cost: 19.317987 PSNR :35.27118523871Epoch: 109 Cur Cost : 22.100451 Val Cost: 19.148111 PSNR :35.30954424090Epoch: 110 Cur Cost : 21.994497 Val Cost: 18.987167 PSNR :35.3464309Epoch: 111 Cur Cost : 21.894119 Val Cost: 18.834215 PSNR :35.38132824528Epoch: 112 Cur Cost : 21.798851 Val Cost: 18.687769 PSNR :35.41522924747Epoch: 113 Cur Cost : 21.708071 Val Cost: 18.547321 PSNR :35.44799224966Epoch: 114 Cur Cost : 21.621136 Val Cost: 18.413755 PSNR :35.47938025185Epoch: 115 Cur Cost : 21.537928 Val Cost: 18.285975 PSNR :35.50962325404Epoch: 116 Cur Cost : 21.458000 Val Cost: 18.164145 PSNR :35.53865425623Epoch: 117 Cur Cost : 21.381271 Val Cost: 18.046951 PSNR :35.56676525842Epoch: 118 Cur Cost : 21.307785 Val Cost: 17.934608 PSNR :35.59388526061Epoch: 119 Cur Cost : 21.236940 Val Cost: 17.827406 PSNR :35.61992226280Epoch: 120 Cur Cost : 21.168789 Val Cost: 17.724836 PSNR :35.64498126499Epoch: 121 Cur Cost : 21.103100 Val Cost: 17.627266 PSNR :35.66895426718Epoch: 122 Cur Cost : 21.039707 Val Cost: 17.533913 PSNR :35.6926937Epoch: 123 Cur Cost : 20.978670 Val Cost: 17.444506 PSNR :35.71421727156Epoch: 124 Cur Cost : 20.919706 Val Cost: 17.358051 PSNR :35.73579427375Epoch: 125 Cur Cost : 20.862633 Val Cost: 17.275480 PSNR :35.75650227594Epoch: 126 Cur Cost : 20.807596 Val Cost: 17.195841 PSNR :35.77656927813Epoch: 127 Cur Cost : 20.754372 Val Cost: 17.119400 PSNR :35.79591828032Epoch: 128 Cur Cost : 20.702871 Val Cost: 17.045691 PSNR :35.81465828251Epoch: 129 Cur Cost : 20.652927 Val Cost: 16.974653 PSNR :35.83279428470Epoch: 130 Cur Cost : 20.604521 Val Cost: 16.906082 PSNR :35.85037428689Epoch: 131 Cur Cost : 20.557560 Val Cost: 16.840302 PSNR :35.86730528908Epoch: 132 Cur Cost : 20.511902 Val Cost: 16.776218 PSNR :35.88386329127Epoch: 133 Cur Cost : 20.467619 Val Cost: 16.714396 PSNR :35.89989729346Epoch: 134 Cur Cost : 20.424440 Val Cost: 16.654495 PSNR :35.91548929565Epoch: 135 Cur Cost : 20.382488 Val Cost: 16.596294 PSNR :35.93069229784Epoch: 136 Cur Cost : 20.341431 Val Cost: 16.539652 PSNR :35.94554030003Epoch: 137 Cur Cost : 20.301138 Val Cost: 16.484528 PSNR :35.96003930222Epoch: 138 Cur Cost : 20.261866 Val Cost: 16.430748 PSNR :35.97423030441Epoch: 139 Cur Cost : 20.223341 Val Cost: 16.378759 PSNR :35.98799430660Epoch: 140 Cur Cost : 20.185614 Val Cost: 16.328671 PSNR :36.00129530879Epoch: 141 Cur Cost : 20.148609 Val Cost: 16.278326 PSNR :36.01470631098Epoch: 142 Cur Cost : 20.112509 Val Cost: 16.229719 PSNR :36.02769431317Epoch: 143 Cur Cost : 20.076233 Val Cost: 16.180443 PSNR :36.04090031536Epoch: 144 Cur Cost : 20.039700 Val Cost: 16.131332 PSNR :36.05410131755Epoch: 145 Cur Cost : 20.001612 Val Cost: 16.085133 PSNR :36.06655731974Epoch: 146 Cur Cost : 19.969093 Val Cost: 16.049982 PSNR :36.07605832193Epoch: 147 Cur Cost : 19.926191 Val Cost: 15.993620 PSNR :36.09133632412Epoch: 148 Cur Cost : 19.890646 Val Cost: 15.950057 PSNR :36.10318132631Epoch: 149 Cur Cost : 19.856085 Val Cost: 15.906347 PSNR :36.115099

In[8]

# 绘制训练曲线plt.plot(model.backprops_cnts,model.psnr)

[<matplotlib.lines.Line2D at 0x7f92346cf5d0>]

In[9]

model.generate_reconstruct_img('butterfly_GT.bmp')

=====原始图片=========

=====Y通道输入=========

=====Y通道输出=========

=====彩图结果=========

In[10]

model.generate_reconstruct_img('baby_GT.bmp')

=====原始图片=========

=====Y通道输入=========

=====Y通道输出=========

=====彩图结果=========

In[11]

class SRCNN_fz(object): # 只对Y通道做超分辨率重构def __init__(self, lr, lr_f, batch_size, iter_num):self.lr = lr # 学习率self.lr_f = lr_fself.batch_size = batch_size self.iter_num = iter_num # 总共训练多少次def net(self, X, Y): # 搭建模型 conv1 = fluid.layers.conv2d(X, 64, 9,act='relu', name='conv1' , param_attr= fluid.ParamAttr(initializer=fluid.initializer.NormalInitializer(scale=0.001),name='conv1_w'),bias_attr=fluid.ParamAttr(initializer=fluid.initializer.ConstantInitializer(value=0.),name='conv1_b'))conv2 = fluid.layers.conv2d(conv1, 32, 3, act='relu', name='conv2' , param_attr= fluid.ParamAttr(initializer=fluid.initializer.NormalInitializer(scale=0.001),name='conv2_w'),bias_attr=fluid.ParamAttr(initializer=fluid.initializer.ConstantInitializer(value=0.),name='conv2_b'))pred = fluid.layers.conv2d(conv2, 1, 5, name='pred', param_attr= fluid.ParamAttr(initializer=fluid.initializer.NormalInitializer(scale=0.001),name='pred_w'),bias_attr=fluid.ParamAttr(initializer=fluid.initializer.ConstantInitializer(value=0.),name='pred_b')) loss = fluid.layers.reduce_mean(fluid.layers.square(pred - Y)) return pred, lossdef train(self):# 模型训练X_train = fluid.layers.data(shape=[1, 33, 33], dtype='float32', name='image')Y_train = fluid.layers.data(shape=[1, 19, 19], dtype='float32', name='gdt')y_predict, y_loss = (X_train, Y_train)Optimizer = fluid.optimizer.AdamOptimizer(learning_rate=self.lr)Optimizer_f = fluid.optimizer.AdamOptimizer(learning_rate=self.lr_f)Optimizer.minimize(y_loss, parameter_list=['conv1_w','conv1_b', 'conv2_w', 'conv2_b'])Optimizer_f.minimize(y_loss, parameter_list=['pred_w', 'pred_b'])# 读取训练集数据train_reader = paddle.batch(self.read_data('work/dataset/timofte'), batch_size=self.batch_size) # 定义执行器place = fluid.CPUPlace()exe = fluid.Executor(place)exe.run(fluid.default_startup_program())def train_loop(main_program):feeder = fluid.DataFeeder(place=place, feed_list=[X_train, Y_train])exe.run(fluid.default_startup_program())backprops_cnt = 0 # 论文中作图的横坐标self.backprops_cnts = [] # 绑定为类的一个属性,用于画图 self.psnr = []# psnr的值for epoch in range(self.iter_num):for batch_id, data in enumerate(train_reader()): loss = exe.run(fluid.framework.default_main_program(),feed=feeder.feed(data),fetch_list=[y_loss]) if batch_id == 0: # 每个epoch算一下psnr,画图用的## 算psnr要在测试集上面fluid.io.save_inference_model('work/model_fz/', ['image'], [y_predict], exe)val_loss, val_psnr = self.validation()self.backprops_cnts.append(backprops_cnt * self.batch_size)self.psnr.append(val_psnr)print("%i\tEpoch: %d \tCur Cost : %f\t Val Cost: %f\t PSNR :%f" % (backprops_cnt, epoch, np.array(loss[0])[0], val_loss, val_psnr))backprops_cnt += 1 fluid.io.save_inference_model('work/model_fz/', ['image'], [y_predict], exe)train_loop(fluid.default_main_program())def validation(self):place = fluid.CPUPlace()exe = fluid.Executor(place)inference_scope = fluid.core.Scope() test_set = 'work/dataset/set5/'scale_factor = 3for img_name in os.listdir(test_set): img_val = cv2.imread(os.path.join(test_set, img_name))yuv = cv2.cvtColor(img_val, cv2.COLOR_BGR2YCrCb)img_y, img_u, img_v = cv2.split(yuv)img_h, img_w = img_y.shapeimg_blur = cv2.GaussianBlur(img_y, (5, 5), 0)img_subsample = cv2.resize(img_blur, (img_w/scale_factor, img_h/scale_factor))img_input = cv2.resize(img_blur, (img_w, img_h), interpolation=cv2.INTER_CUBIC)img_input = np.reshape(img_input, [1,1,img_h, img_w]).astype("float32") # h,w losses = []with fluid.scope_guard(inference_scope):[inference_program, feed_target_names, fetch_targets] = (fluid.io.load_inference_model('work/model_fz/', exe)) results = exe.run(inference_program,feed={feed_target_names[0]: img_input},fetch_list=fetch_targets)[0]loss = np.mean(np.square(results[0,0]-img_y[7:-7, 7:-7]))losses.append(loss) avg_loss = np.sum(np.array(losses))/len(losses)psnr = 10 * np.log10(255*255/avg_loss)return avg_loss,psnrdef generate_reconstruct_img(self, img_name): place = fluid.CPUPlace()exe = fluid.Executor(place)inference_scope = fluid.core.Scope() img_test = cv2.imread('work/dataset/set5/%s' % img_name)yuv_test = cv2.cvtColor(img_test, cv2.COLOR_BGR2YCrCb) img_h, img_w, img_c = img_test.shapeprint "=====原始图片========="b,g,r = cv2.split(img_test) # AI studio 不支持cv2.imshow,所以用plt.imshow输出,两者rgb顺序不一样img_test = cv2.merge([r,g,b])plt.imshow(img_test)plt.show() # 图像模糊+cubic插值img_blur = cv2.GaussianBlur(yuv_test.copy(), (5, 5), 0)img_subsample = cv2.resize(img_blur, (img_w/3, img_h/3)) #这里注意cv2.resize里面的shape是w,h的顺序 img_cubic = cv2.resize(img_blur, (img_w, img_h), interpolation=cv2.INTER_CUBIC)img_y, img_u, img_v = cv2.split(img_cubic) img_input = np.reshape(img_y, [1,1,img_h, img_w]).astype("float32") # 把y通道作为输入with fluid.scope_guard(inference_scope):[inference_program, feed_target_names, fetch_targets] = (fluid.io.load_inference_model('work/model_fz/', exe)) results = exe.run(inference_program,feed={feed_target_names[0]: img_input},fetch_list=fetch_targets)[0]result_img = np.array(results) result_img[result_img < 0] = 0result_img[result_img >255] = 255gap_y = (img_y.shape[0]-result_img.shape[2])/2gap_x = (img_y.shape[1]-result_img.shape[3])/2print "=====Y通道输入========="plt.imshow(img_y, cmap='gray')plt.show()img_y[gap_y: gap_y + result_img.shape[2],gap_x: gap_x + result_img.shape[3]]=result_imgimg_test_r = cv2.merge([img_y, img_u, img_v])img_test_r = cv2.cvtColor(img_test_r, cv2.COLOR_YCrCb2BGR)print "=====Y通道输出========="plt.imshow(img_y, cmap='gray') plt.show()print "=====彩图结果========="b,g,r = cv2.split(img_test_r) img_test_show = cv2.merge([r,g,b])plt.imshow(img_test_show)plt.show()def read_data(self, data_path):def data_reader():for image in os.listdir(data_path):if image.endswith('.bmp'):img = cv2.imread(os.path.join(data_path, image))yuv = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb)img_y, img_u, img_v = cv2.split(yuv)# 下面是切图的步骤j = 0count = 0while j+33 < len(img_y):i = 0while i+33 < len(img_y[0]):img_patch = img_y[j:j+33, i:i+33]img_gth = img_patch[7:-7, 7:-7].copy()img_blur = cv2.GaussianBlur(img_patch, (5, 5), 0)img_sumsample = cv2.resize(img_blur, (11, 11))img_input = cv2.resize(img_blur, (33, 33), interpolation=cv2.INTER_CUBIC)yield img_input, img_gthi+=14j+= 14return data_reader

改变filter_size的训练结果

需要重启一下kernel清空定义的模型

In[12]

model = SRCNN_fz(0.0001, 0.00001, 100, 150)model.train()

0Epoch: 0 Cur Cost : 18187.878906 Val Cost: 21415.037109 PSNR :4.823615219Epoch: 1 Cur Cost : 293.490234 Val Cost: 423.864105 PSNR :21.858537438Epoch: 2 Cur Cost : 286.713898 Val Cost: 413.923096 PSNR :21.961607657Epoch: 3 Cur Cost : 277.595795 Val Cost: 400.915894 PSNR :22.100271876Epoch: 4 Cur Cost : 266.524780 Val Cost: 385.093658 PSNR :22.2751401095Epoch: 5 Cur Cost : 253.052292 Val Cost: 365.947205 PSNR :22.4966191314Epoch: 6 Cur Cost : 236.651108 Val Cost: 342.572601 PSNR :22.7832771533Epoch: 7 Cur Cost : 216.273224 Val Cost: 312.513641 PSNR :23.1821141752Epoch: 8 Cur Cost : 189.500977 Val Cost: 272.160767 PSNR :23.7825481971Epoch: 9 Cur Cost : 157.191208 Val Cost: 224.049789 PSNR :24.6273582190Epoch: 10 Cur Cost : 129.589066 Val Cost: 182.939468 PSNR :25.5077292409Epoch: 11 Cur Cost : 113.882393 Val Cost: 155.549286 PSNR :26.2121232628Epoch: 12 Cur Cost : 109.084671 Val Cost: 146.027054 PSNR :26.4864702847Epoch: 13 Cur Cost : 104.128502 Val Cost: 137.899292 PSNR :26.7351833066Epoch: 14 Cur Cost : 97.342422 Val Cost: 127.754402 PSNR :27.0670453285Epoch: 15 Cur Cost : 91.355812 Val Cost: 118.736206 PSNR :27.3849723504Epoch: 16 Cur Cost : 85.422371 Val Cost: 109.795456 PSNR :27.7249603723Epoch: 17 Cur Cost : 79.140945 Val Cost: 100.095528 PSNR :28.1266573942Epoch: 18 Cur Cost : 73.169403 Val Cost: 90.881409 PSNR :28.5460534161Epoch: 19 Cur Cost : 67.917305 Val Cost: 82.804932 PSNR :28.9502424380Epoch: 20 Cur Cost : 63.625111 Val Cost: 76.324135 PSNR :29.3041854599Epoch: 21 Cur Cost : 59.751232 Val Cost: 70.921104 PSNR :29.6230494818Epoch: 22 Cur Cost : 55.927074 Val Cost: 65.637527 PSNR :29.9592815037Epoch: 23 Cur Cost : 52.361134 Val Cost: 60.628269 PSNR :30.3040525256Epoch: 24 Cur Cost : 49.219032 Val Cost: 56.010677 PSNR :30.6480955475Epoch: 25 Cur Cost : 46.496906 Val Cost: 51.900589 PSNR :30.9790815694Epoch: 26 Cur Cost : 44.209206 Val Cost: 48.567101 PSNR :31.2673825913Epoch: 27 Cur Cost : 42.444748 Val Cost: 46.098759 PSNR :31.4939116132Epoch: 28 Cur Cost : 40.951580 Val Cost: 44.224030 PSNR :31.6742206351Epoch: 29 Cur Cost : 39.668636 Val Cost: 42.721249 PSNR :31.8243646570Epoch: 30 Cur Cost : 38.468319 Val Cost: 41.372662 PSNR :31.9636696789Epoch: 31 Cur Cost : 37.305447 Val Cost: 40.067513 PSNR :32.1028807008Epoch: 32 Cur Cost : 36.248657 Val Cost: 38.869499 PSNR :32.2347147227Epoch: 33 Cur Cost : 35.280151 Val Cost: 37.761169 PSNR :32.3603497446Epoch: 34 Cur Cost : 34.358017 Val Cost: 36.653786 PSNR :32.4896157665Epoch: 35 Cur Cost : 33.456329 Val Cost: 35.482918 PSNR :32.6306107884Epoch: 36 Cur Cost : 32.594936 Val Cost: 34.364208 PSNR :32.7697408103Epoch: 37 Cur Cost : 31.807510 Val Cost: 33.297546 PSNR :32.9066818322Epoch: 38 Cur Cost : 31.098452 Val Cost: 32.322258 PSNR :33.0357878541Epoch: 39 Cur Cost : 30.458292 Val Cost: 31.449808 PSNR :33.1546248760Epoch: 40 Cur Cost : 29.896788 Val Cost: 30.646156 PSNR :33.2670438979Epoch: 41 Cur Cost : 29.400312 Val Cost: 29.952234 PSNR :33.3665119198Epoch: 42 Cur Cost : 28.963032 Val Cost: 29.316887 PSNR :33.4596259417Epoch: 43 Cur Cost : 28.554939 Val Cost: 28.730633 PSNR :33.5473529636Epoch: 44 Cur Cost : 28.187880 Val Cost: 28.178003 PSNR :33.6317019855Epoch: 45 Cur Cost : 27.815355 Val Cost: 27.621120 PSNR :33.71839110074Epoch: 46 Cur Cost : 27.450251 Val Cost: 27.054892 PSNR :33.80834610293Epoch: 47 Cur Cost : 27.095476 Val Cost: 26.506290 PSNR :33.89731410512Epoch: 48 Cur Cost : 26.726200 Val Cost: 25.993595 PSNR :33.98214010731Epoch: 49 Cur Cost : 26.418917 Val Cost: 25.480820 PSNR :34.06867010950Epoch: 50 Cur Cost : 26.102079 Val Cost: 25.046286 PSNR :34.14337011169Epoch: 51 Cur Cost : 25.814594 Val Cost: 24.607016 PSNR :34.2411388Epoch: 52 Cur Cost : 25.541672 Val Cost: 24.180517 PSNR :34.29614811607Epoch: 53 Cur Cost : 25.276190 Val Cost: 23.795362 PSNR :34.36588011826Epoch: 54 Cur Cost : 25.032814 Val Cost: 23.423958 PSNR :34.4342045Epoch: 55 Cur Cost : 24.795341 Val Cost: 23.076561 PSNR :34.49909312264Epoch: 56 Cur Cost : 24.567106 Val Cost: 22.729460 PSNR :34.56491212483Epoch: 57 Cur Cost : 24.350388 Val Cost: 22.400122 PSNR :34.62830012702Epoch: 58 Cur Cost : 24.139818 Val Cost: 22.082649 PSNR :34.69029212921Epoch: 59 Cur Cost : 23.937288 Val Cost: 21.775942 PSNR :34.75103413140Epoch: 60 Cur Cost : 23.740856 Val Cost: 21.478758 PSNR :34.81071213359Epoch: 61 Cur Cost : 23.550505 Val Cost: 21.190653 PSNR :34.86936013578Epoch: 62 Cur Cost : 23.365761 Val Cost: 20.912657 PSNR :34.92671213797Epoch: 63 Cur Cost : 23.186628 Val Cost: 20.643282 PSNR :34.98301614016Epoch: 64 Cur Cost : 23.012552 Val Cost: 20.381088 PSNR :35.03853014235Epoch: 65 Cur Cost : 22.843679 Val Cost: 20.126986 PSNR :35.09301614454Epoch: 66 Cur Cost : 22.680412 Val Cost: 19.880857 PSNR :35.14645214673Epoch: 67 Cur Cost : 22.522045 Val Cost: 19.643797 PSNR :35.19854914892Epoch: 68 Cur Cost : 22.368488 Val Cost: 19.414352 PSNR :35.24957515111Epoch: 69 Cur Cost : 22.227560 Val Cost: 19.191557 PSNR :35.29970215330Epoch: 70 Cur Cost : 22.084782 Val Cost: 18.978037 PSNR :35.34829115549Epoch: 71 Cur Cost : 21.936750 Val Cost: 18.766928 PSNR :35.39687215768Epoch: 72 Cur Cost : 21.809185 Val Cost: 18.569225 PSNR :35.44286615987Epoch: 73 Cur Cost : 21.701038 Val Cost: 18.389750 PSNR :35.48504516206Epoch: 74 Cur Cost : 21.573944 Val Cost: 18.225378 PSNR :35.52403816425Epoch: 75 Cur Cost : 21.466406 Val Cost: 18.062704 PSNR :35.56297616644Epoch: 76 Cur Cost : 21.362574 Val Cost: 17.919756 PSNR :35.59748316863Epoch: 77 Cur Cost : 21.265091 Val Cost: 17.759726 PSNR :35.63644117082Epoch: 78 Cur Cost : 21.167427 Val Cost: 17.617479 PSNR :35.67136617301Epoch: 79 Cur Cost : 21.079617 Val Cost: 17.467531 PSNR :35.70848817520Epoch: 80 Cur Cost : 20.975424 Val Cost: 17.328568 PSNR :35.74317717739Epoch: 81 Cur Cost : 20.885769 Val Cost: 17.211353 PSNR :35.77265317958Epoch: 82 Cur Cost : 20.804918 Val Cost: 17.072248 PSNR :35.80789618177Epoch: 83 Cur Cost : 20.729193 Val Cost: 16.936848 PSNR :35.84247818396Epoch: 84 Cur Cost : 20.651018 Val Cost: 16.827904 PSNR :35.87050318615Epoch: 85 Cur Cost : 20.577908 Val Cost: 16.708214 PSNR :35.90150318834Epoch: 86 Cur Cost : 20.506536 Val Cost: 16.616051 PSNR :35.92552619053Epoch: 87 Cur Cost : 20.444763 Val Cost: 16.530025 PSNR :35.94806819272Epoch: 88 Cur Cost : 20.384291 Val Cost: 16.441322 PSNR :35.97143619491Epoch: 89 Cur Cost : 20.331388 Val Cost: 16.359850 PSNR :35.99301019710Epoch: 90 Cur Cost : 20.274405 Val Cost: 16.283634 PSNR :36.01329019929Epoch: 91 Cur Cost : 20.223141 Val Cost: 16.14 PSNR :36.0350338Epoch: 92 Cur Cost : 20.172297 Val Cost: 16.124725 PSNR :36.05588020367Epoch: 93 Cur Cost : 20.120977 Val Cost: 16.051615 PSNR :36.07561620586Epoch: 94 Cur Cost : 20.073071 Val Cost: 15.968109 PSNR :36.09826920805Epoch: 95 Cur Cost : 20.027214 Val Cost: 15.915732 PSNR :36.11253721024Epoch: 96 Cur Cost : 19.984432 Val Cost: 15.839420 PSNR :36.13341121243Epoch: 97 Cur Cost : 19.940655 Val Cost: 15.775101 PSNR :36.15108221462Epoch: 98 Cur Cost : 19.899282 Val Cost: 15.708839 PSNR :36.16936321681Epoch: 99 Cur Cost : 19.859011 Val Cost: 15.640730 PSNR :36.18823321900Epoch: 100 Cur Cost : 19.821726 Val Cost: 15.585388 PSNR :36.20362722119Epoch: 101 Cur Cost : 19.786446 Val Cost: 15.522793 PSNR :36.22110522338Epoch: 102 Cur Cost : 19.748981 Val Cost: 15.472646 PSNR :36.23515822557Epoch: 103 Cur Cost : 19.715595 Val Cost: 15.406190 PSNR :36.25385122776Epoch: 104 Cur Cost : 19.679775 Val Cost: 15.351366 PSNR :36.26933322995Epoch: 105 Cur Cost : 19.644049 Val Cost: 15.300534 PSNR :36.28373823214Epoch: 106 Cur Cost : 19.611214 Val Cost: 15.245579 PSNR :36.29936423433Epoch: 107 Cur Cost : 19.583275 Val Cost: 15.178621 PSNR :36.31848023652Epoch: 108 Cur Cost : 19.554924 Val Cost: 15.142241 PSNR :36.328902

23871Epoch: 109 Cur Cost : 19.527620 Val Cost: 15.100776 PSNR :36.34081124090Epoch: 110 Cur Cost : 19.499495 Val Cost: 15.060397 PSNR :36.35243924309Epoch: 111 Cur Cost : 19.472237 Val Cost: 15.014300 PSNR :36.36575324528Epoch: 112 Cur Cost : 19.443647 Val Cost: 14.971251 PSNR :36.37822324747Epoch: 113 Cur Cost : 19.417927 Val Cost: 14.922546 PSNR :36.39237424966Epoch: 114 Cur Cost : 19.388918 Val Cost: 14.871901 PSNR :36.40713925185Epoch: 115 Cur Cost : 19.362228 Val Cost: 14.819825 PSNR :36.42237325404Epoch: 116 Cur Cost : 19.338102 Val Cost: 14.777211 PSNR :36.43487925623Epoch: 117 Cur Cost : 19.306536 Val Cost: 14.740740 PSNR :36.44561125842Epoch: 118 Cur Cost : 19.279057 Val Cost: 14.691047 PSNR :36.46027626061Epoch: 119 Cur Cost : 19.254765 Val Cost: 14.637246 PSNR :36.47621026280Epoch: 120 Cur Cost : 19.223206 Val Cost: 14.604916 PSNR :36.48581326499Epoch: 121 Cur Cost : 19.197382 Val Cost: 14.570190 PSNR :36.49615126718Epoch: 122 Cur Cost : 19.166422 Val Cost: 14.527482 PSNR :36.50890026937Epoch: 123 Cur Cost : 19.141092 Val Cost: 14.476220 PSNR :36.52425227156Epoch: 124 Cur Cost : 19.113115 Val Cost: 14.421227 PSNR :36.54078227375Epoch: 125 Cur Cost : 19.084183 Val Cost: 14.397281 PSNR :36.54799927594Epoch: 126 Cur Cost : 19.063801 Val Cost: 14.360329 PSNR :36.55916027813Epoch: 127 Cur Cost : 19.037209 Val Cost: 14.328735 PSNR :36.56872528032Epoch: 128 Cur Cost : 19.020245 Val Cost: 14.305547 PSNR :36.57575928251Epoch: 129 Cur Cost : 18.989275 Val Cost: 14.252780 PSNR :36.59180828470Epoch: 130 Cur Cost : 18.959911 Val Cost: 14.243009 PSNR :36.59478628689Epoch: 131 Cur Cost : 18.954180 Val Cost: 14.218270 PSNR :36.60233628908Epoch: 132 Cur Cost : 18.909533 Val Cost: 14.155800 PSNR :36.62145929127Epoch: 133 Cur Cost : 18.883560 Val Cost: 14.137563 PSNR :36.62705829346Epoch: 134 Cur Cost : 18.869696 Val Cost: 14.116048 PSNR :36.63367229565Epoch: 135 Cur Cost : 18.832462 Val Cost: 14.054341 PSNR :36.65269929784Epoch: 136 Cur Cost : 18.814505 Val Cost: 14.029723 PSNR :36.66031330003Epoch: 137 Cur Cost : 18.805397 Val Cost: 14.000203 PSNR :36.66946030222Epoch: 138 Cur Cost : 18.781073 Val Cost: 13.994488 PSNR :36.67123430441Epoch: 139 Cur Cost : 18.746599 Val Cost: 13.965572 PSNR :36.68021630660Epoch: 140 Cur Cost : 18.732073 Val Cost: 13.914907 PSNR :36.69600030879Epoch: 141 Cur Cost : 18.704937 Val Cost: 13.902081 PSNR :36.70000531098Epoch: 142 Cur Cost : 18.685591 Val Cost: 13.889931 PSNR :36.70380331317Epoch: 143 Cur Cost : 18.670761 Val Cost: 13.844592 PSNR :36.71800231536Epoch: 144 Cur Cost : 18.646767 Val Cost: 13.832062 PSNR :36.72193431755Epoch: 145 Cur Cost : 18.633806 Val Cost: 13.826068 PSNR :36.72381731974Epoch: 146 Cur Cost : 18.607412 Val Cost: 13.794584 PSNR :36.73371732193Epoch: 147 Cur Cost : 18.596941 Val Cost: 13.769987 PSNR :36.74146832412Epoch: 148 Cur Cost : 18.576069 Val Cost: 13.756820 PSNR :36.74562332631Epoch: 149 Cur Cost : 18.549088 Val Cost: 13.727606 PSNR :36.754856

In[ 13 ]

plt.plot(model.backprops_cnts,model.psnr)

In[ 14 ]

model.generate_reconstruct_img('butterfly_GT.bmp')

In[ 15 ]

model.generate_reconstruct_img('baby_GT.bmp')

3通道彩色模型的训练结果

需要重启kernel

In[ 16 ]

# 先读取单通道训练结果,作为pre_train结果def get_w(img_name): place = fluid.CPUPlace()exe = fluid.Executor(place)inference_scope = fluid.core.Scope() img_test = cv2.imread('work/dataset/set5/%s' % img_name)yuv_test = cv2.cvtColor(img_test, cv2.COLOR_BGR2YCrCb) img_h, img_w, img_c = img_test.shape# 图像模糊+cubic插值img_blur = cv2.GaussianBlur(yuv_test.copy(), (5, 5), 0)img_subsample = cv2.resize(img_blur, (img_w/3, img_h/3)) #这里注意cv2.resize里面的shape是w,h的顺序 img_cubic = cv2.resize(img_blur, (img_w, img_h), interpolation=cv2.INTER_CUBIC)img_y, img_u, img_v = cv2.split(img_cubic) img_input = np.reshape(img_y, [1,1,img_h, img_w]).astype("float32") # 把y通道作为输入with fluid.scope_guard(inference_scope):[inference_program, feed_target_names, fetch_targets] = (fluid.io.load_inference_model('work/model/', exe))results = exe.run(inference_program,feed={feed_target_names[0]: img_input},fetch_list=fetch_targets)[0]with fluid.program_guard(inference_program):conv1_w_v = fluid.fetch_var('conv1_w')conv1_b_v = fluid.fetch_var('conv1_b')conv2_w_v = fluid.fetch_var('conv2_w')conv2_b_v = fluid.fetch_var('conv2_b')pred_w_v = fluid.fetch_var('pred_w')pred_b_v = fluid.fetch_var('pred_b')return conv1_w_v, conv1_b_v, conv2_w_v, conv2_b_v, pred_w_v, pred_b_vconv1_w_v, conv1_b_v, conv2_w_v, conv2_b_v, pred_w_v, pred_b_v = get_w('butterfly_GT.bmp')conv1_w_v.dump('work/model/conv1_w_v')conv1_b_v.dump('work/model/conv1_b_v')conv2_w_v.dump('work/model/conv2_w_v')conv2_b_v.dump('work/model/conv2_b_v')pred_w_v.dump('work/model/pred_w_v')pred_b_v.dump('work/model/pred_b_v')

In[ 17 ]

class SRCNN_3dim(object): def __init__(self, lr, lr_f, batch_size, iter_num):self.lr = lr # 学习率self.lr_f = lr_f # 最后一层学习率self.batch_size = batch_size self.iter_num = iter_num # 总共训练多少次def net(self, X, Y): # 搭建模型 conv1 = fluid.layers.conv2d(X, 64, 9,act='relu', name='conv1' , param_attr= fluid.ParamAttr(initializer=fluid.initializer.NormalInitializer(scale=0.001),name='conv1_w'),bias_attr=fluid.ParamAttr(initializer=fluid.initializer.ConstantInitializer(value=0.),name='conv1_b'))conv2 = fluid.layers.conv2d(conv1, 32, 1, act='relu', name='conv2' , param_attr= fluid.ParamAttr(initializer=fluid.initializer.NormalInitializer(scale=0.001),name='conv2_w'),bias_attr=fluid.ParamAttr(initializer=fluid.initializer.ConstantInitializer(value=0.),name='conv2_b'))pred = fluid.layers.conv2d(conv2, 3, 5, name='pred', param_attr= fluid.ParamAttr(initializer=fluid.initializer.NormalInitializer(scale=0.001),name='pred_w'),bias_attr=fluid.ParamAttr(initializer=fluid.initializer.ConstantInitializer(value=0.),name='pred_b')) loss = fluid.layers.reduce_mean(fluid.layers.square(pred - Y)) return pred, lossdef train(self):# 模型训练X_train = fluid.layers.data(shape=[3, 33, 33], dtype='float32', name='image')Y_train = fluid.layers.data(shape=[3, 21, 21], dtype='float32', name='gdt')y_predict, y_loss = (X_train, Y_train)Optimizer = fluid.optimizer.AdamOptimizer(learning_rate=self.lr)Optimizer_f = fluid.optimizer.AdamOptimizer(learning_rate=self.lr_f)Optimizer.minimize(y_loss, parameter_list=['conv1_w','conv1_b', 'conv2_w', 'conv2_b'])Optimizer_f.minimize(y_loss, parameter_list=['pred_w', 'pred_b'])# 读取训练集数据train_reader = paddle.batch(self.read_data('work/dataset/timofte'), batch_size=self.batch_size) # 定义执行器place = fluid.CPUPlace()exe = fluid.Executor(place)exe.run(fluid.default_startup_program())def train_loop(main_program):feeder = fluid.DataFeeder(place=place, feed_list=[X_train, Y_train])exe.run(fluid.default_startup_program())# 用预训练的结果赋值conv1_w = fluid.global_scope().find_var('conv1_w').get_tensor()conv1_b = fluid.global_scope().find_var('conv1_b').get_tensor()conv2_w = fluid.global_scope().find_var('conv2_w').get_tensor()conv2_b = fluid.global_scope().find_var('conv2_b').get_tensor()pred_w = fluid.global_scope().find_var('pred_w').get_tensor()pred_b = fluid.global_scope().find_var('pred_b').get_tensor()conv1_w.set(np.tile(conv1_w_v, (1,3,1,1)),place) conv1_b.set(conv1_b_v,place) conv2_w.set(conv2_w_v,place) conv2_b.set(conv2_b_v,place) pred_w.set(np.tile(pred_w_v,(3,1,1,1)),place)pred_b.set(np.tile(pred_b_v,(3)),place)backprops_cnt = 0 # 论文中作图的横坐标self.backprops_cnts = [] # 绑定为类的一个属性,用于画图 self.psnr = []# psnr的值for epoch in range(self.iter_num):for batch_id, data in enumerate(train_reader()): loss = exe.run(fluid.framework.default_main_program(),feed=feeder.feed(data),fetch_list=[y_loss])if batch_id == 0: # 每个epoch算一下psnr,画图用的## 算psnr要在测试集上面fluid.io.save_inference_model('work/model_3d/', ['image'], [y_predict], exe)val_loss, val_psnr = self.validation()self.backprops_cnts.append(backprops_cnt * self.batch_size)self.psnr.append(val_psnr)print("%i\tEpoch: %d \tCur Cost : %f\t Val Cost: %f\t PSNR :%f" % (backprops_cnt, epoch, np.array(loss[0])[0], val_loss, val_psnr))backprops_cnt += 1 fluid.io.save_inference_model('work/model_3d/', ['image'], [y_predict], exe)train_loop(fluid.default_main_program())def validation(self):place = fluid.CPUPlace()exe = fluid.Executor(place)inference_scope = fluid.core.Scope() test_set = 'work/dataset/set5/'scale_factor = 3for img_name in os.listdir(test_set): img_val = cv2.imread(os.path.join(test_set, img_name))img_h, img_w, _ = img_val.shapeimg_blur = cv2.GaussianBlur(img_val, (5, 5), 0)img_subsample = cv2.resize(img_blur, (img_w/scale_factor, img_h/scale_factor))img_input = cv2.resize(img_blur, (img_w, img_h), interpolation=cv2.INTER_CUBIC)img_input = np.swapaxes(img_input, 1, 2) # HWC->CHWimg_input = np.swapaxes(img_input, 0, 1)img_input = np.reshape(img_input, [1,3, img_h, img_w]).astype("float32") # h,w losses = []with fluid.scope_guard(inference_scope):[inference_program, feed_target_names, fetch_targets] = (fluid.io.load_inference_model('work/model_3d/', exe)) results = exe.run(inference_program,feed={feed_target_names[0]: img_input},fetch_list=fetch_targets)[0]img_val = np.swapaxes(img_val, 1, 2) # HWC->CHWimg_val = np.swapaxes(img_val, 0, 1)loss = np.mean(np.square(results[0]-img_val[:,6:-6, 6:-6]))losses.append(loss) avg_loss = np.sum(np.array(losses))/len(losses)psnr = 10 * np.log10(255*255/avg_loss)return avg_loss,psnrdef generate_reconstruct_img(self, img_name): place = fluid.CPUPlace()exe = fluid.Executor(place)inference_scope = fluid.core.Scope() img_test = cv2.imread('work/dataset/set5/%s' % img_name)print "=====原始图片========="b,g,r = cv2.split(img_test) # AI studio 不支持cv2.imshow,所以用plt.imshow输出,两者rgb顺序不一样img_show = cv2.merge([r,g,b])plt.imshow(img_show)plt.show() # 图像模糊+cubic插值img_h, img_w, _ = img_test.shapeimg_blur = cv2.GaussianBlur(img_test, (5, 5), 0)img_subsample = cv2.resize(img_blur, (img_w/3, img_h/3)) #这里注意cv2.resize里面的shape是w,h的顺序 img_cubic = cv2.resize(img_blur, (img_w, img_h), interpolation=cv2.INTER_CUBIC) print "=====输入图片========="b,g,r = cv2.split(img_cubic) # AI studio 不支持cv2.imshow,所以用plt.imshow输出,两者rgb顺序不一样img_show = cv2.merge([r,g,b])plt.imshow(img_show)plt.show() img_cubic = np.swapaxes(img_cubic, 1, 2) # HWC->CHWimg_cubic = np.swapaxes(img_cubic, 0, 1)img_input = np.reshape(img_cubic, [1,3,img_h, img_w]).astype("float32") # 把RGB3通道作为输入with fluid.scope_guard(inference_scope):[inference_program, feed_target_names, fetch_targets] = (fluid.io.load_inference_model('work/model_3d/', exe)) results = exe.run(inference_program,feed={feed_target_names[0]: img_input},fetch_list=fetch_targets)[0] result_img = np.array(results) result_img[result_img < 0] = 0result_img[result_img >255] = 255gap_y = (img_test.shape[0]-result_img.shape[2])/2gap_x = (img_test.shape[1]-result_img.shape[3])/2 result = np.swapaxes(result_img[0].copy(), 0,1) # CHW_>HWCresult = np.swapaxes(result, 1,2)img_test[gap_y: gap_y + result_img.shape[2],gap_x: gap_x + result_img.shape[3]]=resultprint "=====彩图结果========="b,g,r = cv2.split(img_test) img_test_show = cv2.merge([r,g,b])plt.imshow(img_test_show)plt.show()def read_data(self, data_path):def data_reader():for image in os.listdir(data_path):if image.endswith('.bmp'):img = cv2.imread(os.path.join(data_path, image))# 下面是切图的步骤j = 0count = 0while j+33 < len(img):i = 0while i+33 < len(img[0]):img_patch = img[j:j+33, i:i+33, :]img_gth = img_patch[6:27, 6:27].copy()img_blur = cv2.GaussianBlur(img_patch, (5, 5), 0)img_sumsample = cv2.resize(img_blur, (11, 11))img_input = cv2.resize(img_blur, (33, 33), interpolation=cv2.INTER_CUBIC)img_input = np.swapaxes(img_input, 1, 2) # HWC->CHWimg_input = np.swapaxes(img_input, 0, 1) img_gth = np.swapaxes(img_gth, 1, 2) # HWC->CHWimg_gth = np.swapaxes(img_gth, 0, 1) yield img_input, img_gthi+=14j+= 14return data_reader

In[ 18 ]

conv1_w_v=np.load('work/model/conv1_w_v')conv1_b_v=np.load('work/model/conv1_b_v')conv2_w_v=np.load('work/model/conv2_w_v')conv2_b_v=np.load('work/model/conv2_b_v')pred_w_v=np.load('work/model/pred_w_v')pred_b_v=np.load('work/model/pred_b_v')

In[ 19 ]

model = SRCNN_3dim(0.0001, 0.00001, 100, 200)model.train()

In[ 20 ]

plt.plot(model.backprops_cnts,model.psnr)

In[ 21 ]

model.generate_reconstruct_img('butterfly_GT.bmp')

In[ 22 ]

model.generate_reconstruct_img('baby_GT.bmp')

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。