
[Deep Learning] Practical Examples of Generative Adversarial Networks

Date: 2023-12-12 08:13:58


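Series contents

Deep Learning GAN (1): A Brief Introduction

Deep Learning GAN (2): DCGAN Example on the CIFAR10 Dataset

Deep Learning GAN (3): DCGAN Example on the MNIST Handwritten-Digit Dataset

Deep Learning GAN (4): cGAN (Conditional GAN) Example

Deep Learning GAN (5): PIX2PIX GAN Example

Deep Learning GAN (6): CycleGAN Example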


Pix2Pix GAN Example

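Contents
1. Introduction to Pix2Pix
2. Download the Satellite-to-Map Dataset
3. Data Preprocessing
4. Define the Discriminator
5. Define the Generator
6. Define the GAN Model
7. Load Real Images and Generate Fake Images
8. Generate Fake Images with the Generator Every Few Epochs to Check Progress
10. Training Process
11. Results After Training
12. Complete Code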

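1. Introduction to Pix2Pix

Pix2Pix is a generative adversarial network (GAN) model designed for general-purpose image-to-image translation.

The approach was proposed by Phillip Isola et al. in the paper "Image-to-Image Translation with Conditional Adversarial Networks", published at CVPR.

The GAN architecture consists of a generator model that outputs new, plausible synthetic images and a discriminator model that classifies images as real (from the dataset) or fake (generated). The discriminator model is updated directly, whereas the generator model is updated via the discriminator model. In this way the two models are trained simultaneously in an adversarial process: the generator tries to get better at fooling the discriminator, and the discriminator tries to get better at identifying fake images.

The Pix2Pix model is a conditional GAN (cGAN), in which the generation of the output image is conditioned on an input, in this case a source image. The discriminator is given both the source image and the target image and must decide whether the target is a plausible transformation of the source image.

The generator is trained via the adversarial loss, which encourages it to generate plausible images in the target domain. The generator is also updated via the L1 loss measured between the generated image and the expected output image. This additional loss encourages the generator to produce plausible translations of the source image.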
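For reference, the objective in the Pix2Pix paper combines the two losses described above. Here x is the source image, y the target image and z the noise input (the paper supplies noise only through dropout, which is also how the generator code in this post injects randomness, via `Dropout(0.5)(g, training=True)`). The weight λ corresponds to the 100 in `loss_weights=[1,100]` used when the composite model is compiled.

$$
\mathcal{L}_{cGAN}(G, D) = \mathbb{E}_{x,y}\left[\log D(x, y)\right] + \mathbb{E}_{x,z}\left[\log\left(1 - D(x, G(x, z))\right)\right]
$$

$$
\mathcal{L}_{L1}(G) = \mathbb{E}_{x,y,z}\left[\lVert y - G(x, z) \rVert_1\right]
$$

$$
G^{*} = \arg\min_{G}\max_{D}\ \mathcal{L}_{cGAN}(G, D) + \lambda\,\mathcal{L}_{L1}(G)
$$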

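Pix2Pix GAN has been demonstrated on a range of image-to-image translation tasks, such as converting maps to satellite photos, black-and-white photos to color, and product sketches to product photos.

Now that we are familiar with Pix2Pix GAN, let's prepare a dataset that can be used for image-to-image translation.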

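2. Download the Satellite-to-Map Dataset

This dataset consists of satellite images of New York and their corresponding Google Maps images. The translation problem is to convert satellite photos into Google Maps format or, conversely, Google Maps images into satellite photos.

The dataset is provided on the pix2pix website and can be downloaded as a 255 MB archive.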

Download Maps Dataset (maps.tar.gz)
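If you prefer to unpack the download from Python, here is a minimal sketch using the standard-library `tarfile` module. It assumes the archive was saved locally as `maps.tar.gz` and unpacks into a top-level `maps/` folder; the helper name `extract_maps_archive` is hypothetical.

```python
import tarfile
from pathlib import Path


def extract_maps_archive(archive_path="maps.tar.gz", dest_dir="."):
    """Extract the maps archive and return the sorted names inside maps/.

    Assumes the archive was already downloaded to archive_path and that it
    contains a top-level maps/ directory.
    """
    dest = Path(dest_dir)
    with tarfile.open(archive_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Python versions with extraction filters: reject unsafe member paths
            tar.extractall(path=dest, filter="data")
        else:
            tar.extractall(path=dest)
    return sorted(p.name for p in (dest / "maps").iterdir())
```

Calling `extract_maps_archive()` in the download folder lists the subfolders of `maps/`; the preprocessing code below reads from `train/`.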

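After extraction, the dataset has the following directory structure, with `train/` and `val/` subfolders:

Open either folder and view one of the images. Each file is a single wide image with the satellite photo on the left and the corresponding Google map on the right: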

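3. Data Preprocessing

To make loading faster during training, we read all downloaded training images once, split each into its satellite and map halves, and save the result with NumPy in a single compressed file, maps_256.npz.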

```python
from os import listdir

from numpy import asarray, savez_compressed
from tensorflow.keras.utils import img_to_array, load_img


# load all images in a directory into memory
def load_images(path, size=(256, 512)):
    src_list, tar_list = list(), list()
    # enumerate filenames in directory (sorted for a reproducible order), assume all are images
    for filename in sorted(listdir(path)):
        # load and resize the image; target_size is (height, width)
        pixels = load_img(path + filename, target_size=size)
        # convert to numpy array (float32 by default)
        pixels = img_to_array(pixels)
        # split into satellite (left half) and map (right half)
        sat_img, map_img = pixels[:, :256], pixels[:, 256:]
        src_list.append(sat_img)
        tar_list.append(map_img)
    return [asarray(src_list), asarray(tar_list)]


# dataset path (must end with a slash, since filenames are appended to it)
path = 'D:/ML/datasets/maps/train/'
# load dataset
[src_images, tar_images] = load_images(path)
print('Loaded: ', src_images.shape, tar_images.shape)
# save as compressed numpy array
filename = 'maps_256.npz'
savez_compressed(filename, src_images, tar_images)
print('Saved dataset: ', filename)
```

The output is:

```text
Loaded: (1096, 256, 256, 3) (1096, 256, 256, 3)
Saved dataset: maps_256.npz
```
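The arrays are read back later as `data['arr_0']` and `data['arr_1']` because NumPy names arrays passed positionally to `savez_compressed` `arr_0`, `arr_1`, and so on. The sketch below demonstrates this with an in-memory buffer standing in for maps_256.npz, and computes how much memory the loaded arrays need: `img_to_array` returns float32 by default, so each (1096, 256, 256, 3) array takes 1096 × 256 × 256 × 3 × 4 = 861,929,472 bytes (822 MiB). The helper names are illustrative, not part of the original code.

```python
import io

import numpy as np


def roundtrip_npz(src_images, tar_images):
    """Save two arrays like the preprocessing script does, then load them back."""
    buffer = io.BytesIO()  # stands in for the maps_256.npz file
    np.savez_compressed(buffer, src_images, tar_images)  # positional -> arr_0, arr_1
    buffer.seek(0)
    with np.load(buffer) as data:
        return sorted(data.files), data["arr_0"], data["arr_1"]


def array_nbytes(n_images, height=256, width=256, channels=3, dtype=np.float32):
    """Bytes needed to hold n_images images of the given shape and dtype."""
    return n_images * height * width * channels * np.dtype(dtype).itemsize


keys, _, _ = roundtrip_npz(np.zeros((2, 4, 4, 3)), np.ones((2, 4, 4, 3)))
print(keys)                # ['arr_0', 'arr_1']
print(array_nbytes(1096))  # 861929472
```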

Next, run the following code to check that the images load and display correctly.

```python
# load the prepared dataset
from numpy import load
from matplotlib import pyplot

# load the dataset
data = load('maps_256.npz')
src_images, tar_images = data['arr_0'], data['arr_1']
print('Loaded: ', src_images.shape, tar_images.shape)
# plot source images
n_samples = 3
for i in range(n_samples):
    pyplot.subplot(2, n_samples, 1 + i)
    pyplot.axis('off')
    pyplot.imshow(src_images[i].astype('uint8'))
# plot target image
for i in range(n_samples):
    pyplot.subplot(2, n_samples, 1 + n_samples + i)
    pyplot.axis('off')
    pyplot.imshow(tar_images[i].astype('uint8'))
pyplot.show()
```

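4. Define the Discriminator

This discriminator is implemented as a PatchGAN discriminator model.

Note that it takes two images as input: in_src_image is the satellite image and in_target_image is the Google map (real or generated).

The two are merged with Concatenate into a single 6-channel input, since each image has 3 channels (RGB).

LeakyReLU is used as the activation, and every convolutional layer except the first and the last is followed by BatchNormalization.

The output layer has shape (16,16,1).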
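The (16,16,1) output shape follows from the strides: with `padding='same'`, a Keras `Conv2D` produces ceil(n / stride) positions along each spatial dimension. The NumPy sketch below (the helper `same_padding_sizes` is illustrative, not part of the original code) traces the spatial size through the six convolutions of `define_discriminator` and shows the channel-wise concatenation that turns two RGB images into a 6-channel input.

```python
import numpy as np


def same_padding_sizes(input_size, strides):
    """Spatial size after each Conv2D with padding='same': ceil(size / stride)."""
    sizes = []
    size = input_size
    for s in strides:
        size = -(-size // s)  # ceiling division
        sizes.append(size)
    return sizes


# Strides of the six Conv2D layers in define_discriminator:
# C64, C128, C256, C512, the second-to-last 512 layer, and the 1-filter patch output.
DISC_STRIDES = [2, 2, 2, 2, 1, 1]

# Source and target images are concatenated along the last (channel) axis.
src = np.zeros((256, 256, 3), dtype=np.float32)
tgt = np.zeros((256, 256, 3), dtype=np.float32)
merged = np.concatenate([src, tgt], axis=-1)

print(merged.shape)                           # (256, 256, 6)
print(same_padding_sizes(256, DISC_STRIDES))  # [128, 64, 32, 16, 16, 16]
```

Each of the 16×16 output values therefore scores one region of the input pair rather than the whole image, which is the idea behind a PatchGAN.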

```python
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.layers import (Activation, BatchNormalization, Concatenate,
                                     Conv2D, Input, LeakyReLU)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam


# define the discriminator model
def define_discriminator(image_shape):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # source image input
    in_src_image = Input(shape=image_shape)
    # target image input
    in_target_image = Input(shape=image_shape)
    # concatenate images channel-wise
    merged = Concatenate()([in_src_image, in_target_image])
    # C64
    d = Conv2D(64, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(merged)
    d = LeakyReLU(alpha=0.2)(d)
    # C128
    d = Conv2D(128, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # C256
    d = Conv2D(256, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # C512
    d = Conv2D(512, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # second last output layer
    d = Conv2D(512, (4, 4), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # patch output
    d = Conv2D(1, (4, 4), padding='same', kernel_initializer=init)(d)
    patch_out = Activation('sigmoid')(d)
    # define model
    model = Model([in_src_image, in_target_image], patch_out)
    # compile model; loss_weights=[0.5] halves the discriminator loss so D learns more slowly than G
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, loss_weights=[0.5])
    return model


if __name__ == '__main__':
    d_model = define_discriminator((256, 256, 3))
    d_model.summary()
```

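Its structure is as follows:

5. Define the Generator

The generator is an encoder-decoder model based on the U-Net architecture. It takes a source image (e.g., a satellite photo) and generates a target image (e.g., a Google Maps image). It does this by first downsampling, or encoding, the input image down to a bottleneck layer, and then upsampling, or decoding, the bottleneck representation to the size of the output image. The U-Net architecture adds skip connections between each encoding layer and the corresponding decoding layer, which gives the network its U shape.

The figure below clearly shows the skip connections: the first layer of the encoder is connected to the last layer of the decoder, and so on.

The generator's encoder and decoder are built from standardized blocks of convolutional, batch normalization, dropout, and activation layers. This standardization means we can write helper functions that create each block of layers and call them repeatedly to build up the encoder and decoder parts of the model.

The define_generator() function below implements the U-Net encoder-decoder generator model. It uses the define_encoder_block() helper to create blocks of layers for the encoder and the decoder_block() function to create blocks for the decoder. The tanh activation function is used in the output layer, which means pixel values in the generated image will be in the range [-1,1].

The input is a satellite image; passing it through the encoder-decoder network produces a Google map.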

(256,256,3) ->Encoder-> (1,1,512) -> Decoder -> (256,256,3)
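To see where the shapes in the diagram come from, and why the decoder's channel counts double, the plain-Python sketch below traces (height, width, channels) through the U-Net defined next. It hard-codes the filter counts of `define_generator`; `unet_shape_trace` is a hypothetical helper, not part of the original code.

```python
def unet_shape_trace(size=256, channels=3):
    """Trace (height, width, channels) through the U-Net in define_generator.

    Seven stride-2 encoder blocks, a stride-2 bottleneck, then seven decoder blocks
    whose output is concatenated with the matching encoder output (skip connection).
    """
    encoder_filters = [64, 128, 256, 512, 512, 512, 512]
    decoder_filters = [512, 512, 512, 512, 256, 128, 64]

    trace = [("input", (size, size, channels))]
    skips = []
    for i, f in enumerate(encoder_filters, start=1):
        size //= 2
        skips.append((size, f))
        trace.append((f"e{i}", (size, size, f)))
    size //= 2
    trace.append(("bottleneck", (size, size, 512)))

    for i, f in enumerate(decoder_filters, start=1):
        size *= 2
        skip_size, skip_channels = skips[-i]  # d1 pairs with e7, ..., d7 with e1
        if skip_size != size:
            raise ValueError("skip connection does not match spatially")
        trace.append((f"d{i}", (size, size, f + skip_channels)))

    size *= 2
    trace.append(("output", (size, size, 3)))
    return trace


for name, shape in unet_shape_trace():
    print(name, shape)
```

For example, d1 has 1024 channels because the 512-filter transposed convolution is concatenated with the 512-channel output of e7; these values match the `concatenate_*` rows in the model summary printed in section 6.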

```python
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.layers import (Activation, BatchNormalization, Concatenate, Conv2D,
                                     Conv2DTranspose, Dropout, Input, LeakyReLU)
from tensorflow.keras.models import Model


# define an encoder block
def define_encoder_block(layer_in, n_filters, batchnorm=True):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # add downsampling layer
    g = Conv2D(n_filters, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(layer_in)
    # conditionally add batch normalization (training=True: always use batch statistics)
    if batchnorm:
        g = BatchNormalization()(g, training=True)
    # leaky relu activation
    g = LeakyReLU(alpha=0.2)(g)
    return g


# define a decoder block
def decoder_block(layer_in, skip_in, n_filters, dropout=True):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # add upsampling layer
    g = Conv2DTranspose(n_filters, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(layer_in)
    # add batch normalization
    g = BatchNormalization()(g, training=True)
    # conditionally add dropout (training=True: dropout stays active at inference too)
    if dropout:
        g = Dropout(0.5)(g, training=True)
    # merge with skip connection
    g = Concatenate()([g, skip_in])
    # relu activation
    g = Activation('relu')(g)
    return g


# define the standalone generator model
def define_generator(image_shape=(256, 256, 3)):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # image input
    in_image = Input(shape=image_shape)
    # encoder model
    e1 = define_encoder_block(in_image, 64, batchnorm=False)
    e2 = define_encoder_block(e1, 128)
    e3 = define_encoder_block(e2, 256)
    e4 = define_encoder_block(e3, 512)
    e5 = define_encoder_block(e4, 512)
    e6 = define_encoder_block(e5, 512)
    e7 = define_encoder_block(e6, 512)
    # bottleneck, no batch norm and relu
    b = Conv2D(512, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(e7)
    b = Activation('relu')(b)
    # decoder model
    d1 = decoder_block(b, e7, 512)
    d2 = decoder_block(d1, e6, 512)
    d3 = decoder_block(d2, e5, 512)
    d4 = decoder_block(d3, e4, 512, dropout=False)
    d5 = decoder_block(d4, e3, 256, dropout=False)
    d6 = decoder_block(d5, e2, 128, dropout=False)
    d7 = decoder_block(d6, e1, 64, dropout=False)
    # output
    g = Conv2DTranspose(3, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(d7)
    out_image = Activation('tanh')(g)
    # define model
    model = Model(in_image, out_image)
    return model


if __name__ == '__main__':
    g_model = define_generator((256, 256, 3))
    g_model.summary()
```

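Its structure is as follows:

6. Define the GAN Model

The composite GAN model is used only to update the generator, so inside it the discriminator's weights are marked as not trainable (d_model.trainable = False). The discriminator is still trained separately through its own train_on_batch calls in train().

The input is the satellite image, with shape (256,256,3). The model has two outputs:

dis_out = (16,16,1), the discriminator's patch classification

gen_out = (256,256,3), the generated map image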

```python
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam


# define the combined generator and discriminator model, for updating the generator
# (uses define_discriminator and define_generator from sections 4 and 5)
def define_gan(g_model, d_model, image_shape):
    # make weights in the discriminator not trainable
    d_model.trainable = False
    # define the source image
    in_src = Input(shape=image_shape)
    # connect the source image to the generator input
    gen_out = g_model(in_src)
    # connect the source input and generator output to the discriminator input
    dis_out = d_model([in_src, gen_out])
    # src image as input, generated image and classification output
    model = Model(in_src, [dis_out, gen_out])
    # compile model: adversarial loss (BCE) and L1 loss (MAE), weighted 1:100
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss=['binary_crossentropy', 'mae'], optimizer=opt, loss_weights=[1, 100])
    return model


if __name__ == '__main__':
    d_model = define_discriminator((256, 256, 3))
    g_model = define_generator((256, 256, 3))
    gan_model = define_gan(g_model, d_model, (256, 256, 3))
    # prints the generator's layer summary
    g_model.summary()
```

```text
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_3 (InputLayer)            [(None, 256, 256, 3) 0
conv2d_6 (Conv2D)               (None, 128, 128, 64) 3136        input_3[0][0]
leaky_re_lu_5 (LeakyReLU)       (None, 128, 128, 64) 0           conv2d_6[0][0]
conv2d_7 (Conv2D)               (None, 64, 64, 128)  131200      leaky_re_lu_5[0][0]
batch_normalization_4 (BatchNor (None, 64, 64, 128)  512         conv2d_7[0][0]
leaky_re_lu_6 (LeakyReLU)       (None, 64, 64, 128)  0           batch_normalization_4[0][0]
conv2d_8 (Conv2D)               (None, 32, 32, 256)  524544      leaky_re_lu_6[0][0]
batch_normalization_5 (BatchNor (None, 32, 32, 256)  1024        conv2d_8[0][0]
leaky_re_lu_7 (LeakyReLU)       (None, 32, 32, 256)  0           batch_normalization_5[0][0]
conv2d_9 (Conv2D)               (None, 16, 16, 512)  2097664     leaky_re_lu_7[0][0]
batch_normalization_6 (BatchNor (None, 16, 16, 512)  2048        conv2d_9[0][0]
leaky_re_lu_8 (LeakyReLU)       (None, 16, 16, 512)  0           batch_normalization_6[0][0]
conv2d_10 (Conv2D)              (None, 8, 8, 512)    4194816     leaky_re_lu_8[0][0]
batch_normalization_7 (BatchNor (None, 8, 8, 512)    2048        conv2d_10[0][0]
leaky_re_lu_9 (LeakyReLU)       (None, 8, 8, 512)    0           batch_normalization_7[0][0]
conv2d_11 (Conv2D)              (None, 4, 4, 512)    4194816     leaky_re_lu_9[0][0]
batch_normalization_8 (BatchNor (None, 4, 4, 512)    2048        conv2d_11[0][0]
leaky_re_lu_10 (LeakyReLU)      (None, 4, 4, 512)    0           batch_normalization_8[0][0]
conv2d_12 (Conv2D)              (None, 2, 2, 512)    4194816     leaky_re_lu_10[0][0]
batch_normalization_9 (BatchNor (None, 2, 2, 512)    2048        conv2d_12[0][0]
leaky_re_lu_11 (LeakyReLU)      (None, 2, 2, 512)    0           batch_normalization_9[0][0]
conv2d_13 (Conv2D)              (None, 1, 1, 512)    4194816     leaky_re_lu_11[0][0]
activation_1 (Activation)       (None, 1, 1, 512)    0           conv2d_13[0][0]
conv2d_transpose (Conv2DTranspo (None, 2, 2, 512)    4194816     activation_1[0][0]
batch_normalization_10 (BatchNo (None, 2, 2, 512)    2048        conv2d_transpose[0][0]
dropout (Dropout)               (None, 2, 2, 512)    0           batch_normalization_10[0][0]
concatenate_1 (Concatenate)     (None, 2, 2, 1024)   0           dropout[0][0]
                                                                 leaky_re_lu_11[0][0]
activation_2 (Activation)       (None, 2, 2, 1024)   0           concatenate_1[0][0]
conv2d_transpose_1 (Conv2DTrans (None, 4, 4, 512)    8389120     activation_2[0][0]
batch_normalization_11 (BatchNo (None, 4, 4, 512)    2048        conv2d_transpose_1[0][0]
dropout_1 (Dropout)             (None, 4, 4, 512)    0           batch_normalization_11[0][0]
concatenate_2 (Concatenate)     (None, 4, 4, 1024)   0           dropout_1[0][0]
                                                                 leaky_re_lu_10[0][0]
activation_3 (Activation)       (None, 4, 4, 1024)   0           concatenate_2[0][0]
conv2d_transpose_2 (Conv2DTrans (None, 8, 8, 512)    8389120     activation_3[0][0]
batch_normalization_12 (BatchNo (None, 8, 8, 512)    2048        conv2d_transpose_2[0][0]
dropout_2 (Dropout)             (None, 8, 8, 512)    0           batch_normalization_12[0][0]
concatenate_3 (Concatenate)     (None, 8, 8, 1024)   0           dropout_2[0][0]
                                                                 leaky_re_lu_9[0][0]
activation_4 (Activation)       (None, 8, 8, 1024)   0           concatenate_3[0][0]
conv2d_transpose_3 (Conv2DTrans (None, 16, 16, 512)  8389120     activation_4[0][0]
batch_normalization_13 (BatchNo (None, 16, 16, 512)  2048        conv2d_transpose_3[0][0]
concatenate_4 (Concatenate)     (None, 16, 16, 1024) 0           batch_normalization_13[0][0]
                                                                 leaky_re_lu_8[0][0]
activation_5 (Activation)       (None, 16, 16, 1024) 0           concatenate_4[0][0]
conv2d_transpose_4 (Conv2DTrans (None, 32, 32, 256)  4194560     activation_5[0][0]
batch_normalization_14 (BatchNo (None, 32, 32, 256)  1024        conv2d_transpose_4[0][0]
concatenate_5 (Concatenate)     (None, 32, 32, 512)  0           batch_normalization_14[0][0]
                                                                 leaky_re_lu_7[0][0]
activation_6 (Activation)       (None, 32, 32, 512)  0           concatenate_5[0][0]
conv2d_transpose_5 (Conv2DTrans (None, 64, 64, 128)  1048704     activation_6[0][0]
batch_normalization_15 (BatchNo (None, 64, 64, 128)  512         conv2d_transpose_5[0][0]
concatenate_6 (Concatenate)     (None, 64, 64, 256)  0           batch_normalization_15[0][0]
                                                                 leaky_re_lu_6[0][0]
activation_7 (Activation)       (None, 64, 64, 256)  0           concatenate_6[0][0]
conv2d_transpose_6 (Conv2DTrans (None, 128, 128, 64) 262208      activation_7[0][0]
batch_normalization_16 (BatchNo (None, 128, 128, 64) 256         conv2d_transpose_6[0][0]
concatenate_7 (Concatenate)     (None, 128, 128, 128 0           batch_normalization_16[0][0]
                                                                 leaky_re_lu_5[0][0]
activation_8 (Activation)       (None, 128, 128, 128 0           concatenate_7[0][0]
conv2d_transpose_7 (Conv2DTrans (None, 256, 256, 3)  6147        activation_8[0][0]
activation_9 (Activation)       (None, 256, 256, 3)  0           conv2d_transpose_7[0][0]
==================================================================================================
Total params: 54,429,315
Trainable params: 54,419,459
Non-trainable params: 9,856
```
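The composite model's compile call weights the two generator losses 1:100, so the L1 term dominates unless the adversarial term is large. As a rough NumPy illustration of that weighted sum (not the Keras internals, whose clipping and reduction details may differ slightly; the helper names are hypothetical):

```python
import numpy as np


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy, with predictions clipped away from 0 and 1."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return float(np.mean(-(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))))


def generator_loss(d_patch_on_fake, fake_image, real_image, adv_weight=1.0, l1_weight=100.0):
    """Weighted sum mirroring loss=['binary_crossentropy', 'mae'], loss_weights=[1, 100].

    The adversarial term uses 'real' (all-ones) patch labels, because the generator
    wants the discriminator to classify its output as real.
    """
    adv = binary_crossentropy(np.ones_like(d_patch_on_fake), d_patch_on_fake)
    l1 = float(np.mean(np.abs(real_image - fake_image)))
    return adv_weight * adv + l1_weight * l1, adv, l1


d_patch = np.full((1, 16, 16, 1), 0.5)  # discriminator undecided on every patch
fake = np.zeros((1, 256, 256, 3))
real = np.full((1, 256, 256, 3), 0.1)
total, adv, l1 = generator_loss(d_patch, fake, real)
print(round(adv, 4), round(l1, 4), round(total, 4))  # 0.6931 0.1 10.6931
```

Even a modest average pixel error of 0.1 (on the [-1,1] scale) contributes 10 to the total, far more than the adversarial term, which is why the generator is pushed strongly toward matching the target image.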

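7. Load Real Images and Generate Fake Images

load_real_samples loads the real images and scales their pixel values from [0,255] to [-1,1].

generate_real_samples selects a random batch of real image pairs. Every label is 1, and each sample's label has shape (16,16,1).

generate_fake_samples uses the generator to produce fake images. Every label is 0, and each sample's label has shape (16,16,1).

Note that the labels differ from the usual case: instead of a single number per sample, each label is a 3-D array of shape (16,16,1), with one value per patch.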
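Because the discriminator outputs a 16×16 grid of scores per image pair, each training example carries 256 labels rather than one, and the binary cross-entropy is averaged over all of them. The NumPy sketch below builds labels the same way as generate_real_samples/generate_fake_samples and, as an illustrative diagnostic that is not part of the original code, measures what fraction of patches a made-up discriminator output scores as real.

```python
import numpy as np


def patch_labels(n_samples, patch_shape, real=True):
    """Labels shaped like generate_real_samples (all 1s) / generate_fake_samples (all 0s)."""
    fill = np.ones if real else np.zeros
    return fill((n_samples, patch_shape, patch_shape, 1), dtype=np.float32)


def fraction_judged_real(patch_scores, threshold=0.5):
    """For each image, the share of its patches scored above the threshold."""
    flat = patch_scores.reshape(len(patch_scores), -1)
    return (flat > threshold).mean(axis=1)


y_real = patch_labels(1, 16)
print(y_real.shape, int(y_real.sum()))  # (1, 16, 16, 1) 256

# Made-up discriminator output: top half of the grid looks real, bottom half fake.
scores = np.full((1, 16, 16, 1), 0.1, dtype=np.float32)
scores[:, :8] = 0.9
print(fraction_judged_real(scores))  # [0.5]
```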

```python
from numpy import load, ones, zeros
from numpy.random import randint


# load and prepare training images
def load_real_samples(filename):
    # load compressed arrays
    data = load(filename)
    # unpack arrays
    X1, X2 = data['arr_0'], data['arr_1']
    # scale from [0,255] to [-1,1]
    X1 = (X1 - 127.5) / 127.5
    X2 = (X2 - 127.5) / 127.5
    return [X1, X2]


# select a batch of random samples, returns images and target
def generate_real_samples(dataset, n_samples, patch_shape):
    # unpack dataset
    trainA, trainB = dataset
    # choose random instances
    ix = randint(0, trainA.shape[0], n_samples)
    # retrieve selected images
    X1, X2 = trainA[ix], trainB[ix]
    # generate 'real' class labels (1)
    y = ones((n_samples, patch_shape, patch_shape, 1))
    return [X1, X2], y


# generate a batch of images, returns images and targets
def generate_fake_samples(g_model, samples, patch_shape):
    # generate fake instance
    X = g_model.predict(samples)
    # create 'fake' class labels (0)
    y = zeros((len(X), patch_shape, patch_shape, 1))
    return X, y
```
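The scaling matters because the generator's last layer is tanh, so real target images must live in the same [-1,1] range; when plotting, summarize_performance maps values back to [0,1], the range `pyplot.imshow` expects for float RGB data. A small NumPy illustration of the two conversions (the function names are illustrative):

```python
import numpy as np


def to_model_range(images):
    """[0, 255] -> [-1, 1], the same arithmetic as load_real_samples."""
    return (images.astype(np.float32) - 127.5) / 127.5


def to_display_range(images):
    """[-1, 1] (the generator's tanh range) -> [0, 1], as in summarize_performance."""
    return (images + 1) / 2.0


pixels = np.array([0, 127.5, 255], dtype=np.float32)
print(to_model_range(pixels))                    # [-1.  0.  1.]
print(to_display_range(to_model_range(pixels)))  # [0.  0.5 1. ]
```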

8. Generate Fake Images with the Generator Every Few Epochs to Check Progress

```python
from matplotlib import pyplot


# generate samples and save as a plot and save the model
def summarize_performance(step, g_model, dataset, n_samples=3):
    # select a sample of input images
    [X_realA, X_realB], _ = generate_real_samples(dataset, n_samples, 1)
    # generate a batch of fake samples
    X_fakeB, _ = generate_fake_samples(g_model, X_realA, 1)
    # scale all pixels from [-1,1] to [0,1]
    X_realA = (X_realA + 1) / 2.0
    X_realB = (X_realB + 1) / 2.0
    X_fakeB = (X_fakeB + 1) / 2.0
    # plot real source images (first row)
    for i in range(n_samples):
        pyplot.subplot(3, n_samples, 1 + i)
        pyplot.axis('off')
        pyplot.imshow(X_realA[i])
    # plot generated target image (second row)
    for i in range(n_samples):
        pyplot.subplot(3, n_samples, 1 + n_samples + i)
        pyplot.axis('off')
        pyplot.imshow(X_fakeB[i])
    # plot real target image (third row)
    for i in range(n_samples):
        pyplot.subplot(3, n_samples, 1 + n_samples * 2 + i)
        pyplot.axis('off')
        pyplot.imshow(X_realB[i])
    # save plot to file
    filename1 = 'pix2pix_plot_%06d.png' % (step + 1)
    pyplot.savefig(filename1)
    pyplot.close()
    # save the generator model
    filename2 = 'pix2pix_model_%06d.h5' % (step + 1)
    g_model.save(filename2)
    print('>Saved: %s and %s' % (filename1, filename2))
```

10. Training Process

```python
# train pix2pix models
def train(d_model, g_model, gan_model, dataset, n_epochs=100, n_batch=1):
    # determine the output square shape of the discriminator (16 for 256x256 inputs)
    n_patch = d_model.output_shape[1]
    # unpack dataset
    trainA, trainB = dataset
    # calculate the number of batches per training epoch
    bat_per_epo = int(len(trainA) / n_batch)
    # calculate the number of training iterations
    n_steps = bat_per_epo * n_epochs
    # manually enumerate epochs
    for i in range(n_steps):
        # select a batch of real samples
        [X_realA, X_realB], y_real = generate_real_samples(dataset, n_batch, n_patch)
        # generate a batch of fake samples
        X_fakeB, y_fake = generate_fake_samples(g_model, X_realA, n_patch)
        # update discriminator for real samples
        d_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real)
        # update discriminator for generated samples
        d_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake)
        # update the generator
        g_loss, _, _ = gan_model.train_on_batch(X_realA, [y_real, X_realB])
        # summarize performance
        print('>%d, d1[%.3f] d2[%.3f] g[%.3f]' % (i + 1, d_loss1, d_loss2, g_loss))
        # summarize model performance every 10 epochs
        if (i + 1) % (bat_per_epo * 10) == 0:
            summarize_performance(i, g_model, dataset)
```
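With the 1096 training pairs loaded earlier and the defaults n_epochs=100 and n_batch=1, the loop arithmetic works out as follows. The plain-Python sketch below (a hypothetical helper, not part of the original code) reproduces it, including the file names that summarize_performance writes.

```python
def training_schedule(n_images=1096, n_epochs=100, n_batch=1, every_n_epochs=10):
    """Reproduce the step arithmetic in train()."""
    bat_per_epo = int(n_images / n_batch)    # steps per epoch
    n_steps = bat_per_epo * n_epochs         # total training iterations
    interval = bat_per_epo * every_n_epochs  # steps between summaries
    # train() calls summarize_performance(i, ...) whenever (i + 1) % interval == 0
    checkpoints = [k * interval - 1 for k in range(1, n_steps // interval + 1)]
    files = ['pix2pix_model_%06d.h5' % (i + 1) for i in checkpoints]
    return bat_per_epo, n_steps, checkpoints, files


bat_per_epo, n_steps, checkpoints, files = training_schedule()
print(bat_per_epo, n_steps, len(checkpoints))  # 1096 109600 10
print(files[0], files[-1])  # pix2pix_model_010960.h5 pix2pix_model_109600.h5
```

A full run therefore performs 109,600 generator updates (and twice as many discriminator updates, one on a real pair and one on a fake pair per step), and saves 10 plots and 10 generator checkpoints.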

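11. Results After Training

After the first 10 epochs, the model produces map images that look plausible, although the street lines are not perfectly straight and the images are somewhat blurry. Nevertheless, large structures appear in the right places with mostly the right colors.

After about 50 training epochs the generated images begin to look very realistic, at least to the eye, and the quality appears to remain good for the rest of the training run.

Note the first generated image example below (right column, middle row), which contains more useful detail than the real Google Maps image.

12. Complete Code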

```python
# example of pix2pix gan for satellite to map image-to-image translation
from numpy import load
from numpy import zeros
from numpy import ones
from numpy.random import randint
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import LeakyReLU
from matplotlib import pyplot


# define the discriminator model
def define_discriminator(image_shape):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # source image input
    in_src_image = Input(shape=image_shape)
    # target image input
    in_target_image = Input(shape=image_shape)
    # concatenate images channel-wise
    merged = Concatenate()([in_src_image, in_target_image])
    # C64
    d = Conv2D(64, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(merged)
    d = LeakyReLU(alpha=0.2)(d)
    # C128
    d = Conv2D(128, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # C256
    d = Conv2D(256, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # C512
    d = Conv2D(512, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # second last output layer
    d = Conv2D(512, (4, 4), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # patch output
    d = Conv2D(1, (4, 4), padding='same', kernel_initializer=init)(d)
    patch_out = Activation('sigmoid')(d)
    # define model
    model = Model([in_src_image, in_target_image], patch_out)
    # compile model
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, loss_weights=[0.5])
    return model


# define an encoder block
def define_encoder_block(layer_in, n_filters, batchnorm=True):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # add downsampling layer
    g = Conv2D(n_filters, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(layer_in)
    # conditionally add batch normalization
    if batchnorm:
        g = BatchNormalization()(g, training=True)
    # leaky relu activation
    g = LeakyReLU(alpha=0.2)(g)
    return g


# define a decoder block
def decoder_block(layer_in, skip_in, n_filters, dropout=True):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # add upsampling layer
    g = Conv2DTranspose(n_filters, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(layer_in)
    # add batch normalization
    g = BatchNormalization()(g, training=True)
    # conditionally add dropout
    if dropout:
        g = Dropout(0.5)(g, training=True)
    # merge with skip connection
    g = Concatenate()([g, skip_in])
    # relu activation
    g = Activation('relu')(g)
    return g


# define the standalone generator model
def define_generator(image_shape=(256, 256, 3)):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # image input
    in_image = Input(shape=image_shape)
    # encoder model
    e1 = define_encoder_block(in_image, 64, batchnorm=False)
    e2 = define_encoder_block(e1, 128)
    e3 = define_encoder_block(e2, 256)
    e4 = define_encoder_block(e3, 512)
    e5 = define_encoder_block(e4, 512)
    e6 = define_encoder_block(e5, 512)
    e7 = define_encoder_block(e6, 512)
    # bottleneck, no batch norm and relu
    b = Conv2D(512, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(e7)
    b = Activation('relu')(b)
    # decoder model
    d1 = decoder_block(b, e7, 512)
    d2 = decoder_block(d1, e6, 512)
    d3 = decoder_block(d2, e5, 512)
    d4 = decoder_block(d3, e4, 512, dropout=False)
    d5 = decoder_block(d4, e3, 256, dropout=False)
    d6 = decoder_block(d5, e2, 128, dropout=False)
    d7 = decoder_block(d6, e1, 64, dropout=False)
    # output
    g = Conv2DTranspose(3, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(d7)
    out_image = Activation('tanh')(g)
    # define model
    model = Model(in_image, out_image)
    return model


# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model, image_shape):
    # make weights in the discriminator not trainable
    d_model.trainable = False
    # define the source image
    in_src = Input(shape=image_shape)
    # connect the source image to the generator input
    gen_out = g_model(in_src)
    # connect the source input and generator output to the discriminator input
    dis_out = d_model([in_src, gen_out])
    print(dis_out)
    # src image as input, generated image and classification output
    model = Model(in_src, [dis_out, gen_out])
    # compile model
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss=['binary_crossentropy', 'mae'], optimizer=opt, loss_weights=[1, 100])
    return model


# load and prepare training images
def load_real_samples(filename):
    # load compressed arrays
    data = load(filename)
    # unpack arrays
    X1, X2 = data['arr_0'], data['arr_1']
    # scale from [0,255] to [-1,1]
    X1 = (X1 - 127.5) / 127.5
    X2 = (X2 - 127.5) / 127.5
    return [X1, X2]


# select a batch of random samples, returns images and target
def generate_real_samples(dataset, n_samples, patch_shape):
    # unpack dataset
    trainA, trainB = dataset
    # choose random instances
    ix = randint(0, trainA.shape[0], n_samples)
    # retrieve selected images
    X1, X2 = trainA[ix], trainB[ix]
    # generate 'real' class labels (1)
    y = ones((n_samples, patch_shape, patch_shape, 1))
    return [X1, X2], y


# generate a batch of images, returns images and targets
def generate_fake_samples(g_model, samples, patch_shape):
    # generate fake instance
    X = g_model.predict(samples)
    # create 'fake' class labels (0)
    y = zeros((len(X), patch_shape, patch_shape, 1))
    return X, y


# generate samples and save as a plot and save the model
def summarize_performance(step, g_model, dataset, n_samples=3):
    # select a sample of input images
    [X_realA, X_realB], _ = generate_real_samples(dataset, n_samples, 1)
    # generate a batch of fake samples
    X_fakeB, _ = generate_fake_samples(g_model, X_realA, 1)
    # scale all pixels from [-1,1] to [0,1]
    X_realA = (X_realA + 1) / 2.0
    X_realB = (X_realB + 1) / 2.0
    X_fakeB = (X_fakeB + 1) / 2.0
    # plot real source images
    for i in range(n_samples):
        pyplot.subplot(3, n_samples, 1 + i)
        pyplot.axis('off')
        pyplot.imshow(X_realA[i])
    # plot generated target image
    for i in range(n_samples):
        pyplot.subplot(3, n_samples, 1 + n_samples + i)
        pyplot.axis('off')
        pyplot.imshow(X_fakeB[i])
    # plot real target image
    for i in range(n_samples):
        pyplot.subplot(3, n_samples, 1 + n_samples * 2 + i)
        pyplot.axis('off')
        pyplot.imshow(X_realB[i])
    # save plot to file
    filename1 = 'pix2pix_plot_%06d.png' % (step + 1)
    pyplot.savefig(filename1)
    pyplot.close()
    # save the generator model
    filename2 = 'pix2pix_model_%06d.h5' % (step + 1)
    g_model.save(filename2)
    print('>Saved: %s and %s' % (filename1, filename2))


# train pix2pix models
def train(d_model, g_model, gan_model, dataset, n_epochs=100, n_batch=1):
    # determine the output square shape of the discriminator
    n_patch = d_model.output_shape[1]
    # unpack dataset
    trainA, trainB = dataset
    # calculate the number of batches per training epoch
    bat_per_epo = int(len(trainA) / n_batch)
    # calculate the number of training iterations
    n_steps = bat_per_epo * n_epochs
    # manually enumerate epochs
    for i in range(n_steps):
        # select a batch of real samples
        [X_realA, X_realB], y_real = generate_real_samples(dataset, n_batch, n_patch)
        # generate a batch of fake samples
        X_fakeB, y_fake = generate_fake_samples(g_model, X_realA, n_patch)
        # update discriminator for real samples
        d_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real)
        # update discriminator for generated samples
        d_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake)
        # update the generator
        g_loss, _, _ = gan_model.train_on_batch(X_realA, [y_real, X_realB])
        # summarize performance
        print('>%d, d1[%.3f] d2[%.3f] g[%.3f]' % (i + 1, d_loss1, d_loss2, g_loss))
        # summarize model performance
        if (i + 1) % (bat_per_epo * 10) == 0:
            summarize_performance(i, g_model, dataset)


def start_train():
    # load image data
    dataset = load_real_samples('maps_256.npz')
    print('Loaded', dataset[0].shape, dataset[1].shape)
    # define input shape based on the loaded dataset
    image_shape = dataset[0].shape[1:]
    # define the models
    d_model = define_discriminator(image_shape)
    print(image_shape)
    d_model.summary()
    g_model = define_generator(image_shape)
    # define the composite model
    gan_model = define_gan(g_model, d_model, image_shape)
    # train model
    train(d_model, g_model, gan_model, dataset)


if __name__ == '__main__':
    # d_model = define_discriminator((256, 256, 3))
    # d_model.summary()
    # g_model = define_generator((256, 256, 3))
    # g_model.summary()
    # gan_model = define_gan(g_model, d_model, (256, 256, 3))
    start_train()
```
