CNTK API文档翻译(21)——深度卷积GAN处理MSIST数据基础
来源:互联网 发布:荣天视 淘宝 编辑:程序博客网 时间:2024/06/01 17:38
完成本期教程需要完成本系列的第四篇教程。
介绍
生成模型在深度学习的半监督或者非监督学习领域引起了广泛的专注,这些领域传统上都是使用判别模型的。
概览
在上一个教程中我们介绍了Goodfellow等人在NIPS2014上提出来的原生GAN网络。这个开创新的网络现在已经被很好的扩展,并发表了很多技术。其中深度卷积生成对抗网络(Deep Convolutional Generative Adversarial Network,DCGAN)是一项社区的主流推荐项目。
在本教程中,我们实现一个已经被前人验证过的DCGAN结构,来提高GAN训练的稳定性:
- 我们在判别器中使用strided(不知道怎么翻译)卷积,在生成器中使用fractional-strided卷积。
- 我们在生成器和判别器中使用Batch Normalization
- 我们移除全连接隐藏层,替换成深度结构
- 我们在生成器中的每层使用ReLU激活函数,只有最后的输出层使用Tanh。
- 我们在判别器中的每层都使用LeakyReLU激活函数。
import matplotlib as mplimport matplotlib.pyplot as pltimport numpy as npimport osimport cntk as Cimport cntk.tests.test_utils# (only needed for our build system)cntk.tests.test_utils.set_device_from_pytest_env()# fix a random seed for CNTK componentsC.cntk_py.set_fixed_random_seed(1)
我们设定了两种运行模式:
- 快速模式:isFast变量设置成True。这是我们的默认模式,在这个模式下我们会训练更少的次数,也会使用更少的数据,这个模式保证功能的正确性,但训练的结果还远远达不到可用的要求。
- 慢速模式:我们建议学习者在学习的时候试试将isFast变量设置成False,这会让学习者更加了解本教程的内容。
-
注意如果isFast被设为False,在有GPU的机器上代码将运行几个小时。你可以试试通过吧num_minibatches设置成一个较小的数字比如20000,减少循环次数,不过带来的代价就是生成图像质量的降低。
isFast = True
数据读取(见上期,本期略)
模型创建
(原文章关于GAN的介绍见上期,本期略)
模型构成
我们为我们的模型构建计算图,一个给生成器一个给判别器。首先我们我们创建一些模型结构参数。
# architectural parametersimg_h, img_w = 28, 28kernel_h, kernel_w = 5, 5 stride_h, stride_w = 2, 2# Input / Output parameter of Generator and Discriminatorg_input_dim = 100g_output_dim = d_input_dim = img_h * img_w# We expect the kernel shapes to be square in this tutorial and# the strides to be of the same length along each data dimensionif kernel_h == kernel_w: gkernel = dkernel = kernel_helse: raise ValueError('This tutorial needs square shaped kernel') if stride_h == stride_w: gstride = dstride = stride_helse: raise ValueError('This tutorial needs same stride in all dims')# Helper functionsdef bn_with_relu(x, activation=C.relu): h = C.layers.BatchNormalization(map_rank=1)(x) return C.relu(h)# We use param-relu function to use a leak=0.2 since CNTK implementation # of Leaky ReLU is fixed to 0.01def bn_with_leaky_relu(x, leak=0.2): h = C.layers.BatchNormalization(map_rank=1)(x) r = C.param_relu(C.constant((np.ones(h.shape)*leak).astype(np.float32)), h) return r
生成器
生成器输入100维随机向量($z$
)输出一个784维的向量,对应28×28合成图像($x^*$
)的扁平状态。在本教程中,除了最后一层之外,我们使用ReLU激活函数的fractionally strided卷积,我们最后一层使用Tanh激活函数以保证生成器的输出结果在[-1,1]之间。使用ReLU和Tanh激活函数是使用fractionally strided卷积的关键。
def convolutional_generator(z): with C.layers.default_options(init=C.normal(scale=0.02)): print('Generator input shape: ', z.shape) s_h2, s_w2 = img_h//2, img_w//2 #Input shape (14,14) s_h4, s_w4 = img_h//4, img_w//4 # Input shape (7,7) gfc_dim = 1024 gf_dim = 64 h0 = C.layers.Dense(gfc_dim, activation=None)(z) h0 = bn_with_relu(h0) print('h0 shape', h0.shape) h1 = C.layers.Dense([gf_dim * 2, s_h4, s_w4], activation=None)(h0) h1 = bn_with_relu(h1) print('h1 shape', h1.shape) h2 = C.layers.ConvolutionTranspose2D(gkernel, num_filters=gf_dim*2, strides=gstride, pad=True, output_shape=(s_h2, s_w2), activation=None)(h1) h2 = bn_with_relu(h2) print('h2 shape', h2.shape) h3 = C.layers.ConvolutionTranspose2D(gkernel, num_filters=1, strides=gstride, pad=True, output_shape=(img_h, img_w), activation=C.sigmoid)(h2) print('h3 shape :', h3.shape) return C.reshape(h3, img_h * img_w)
判别器
判别器输入从生成器中输出的或者来自真实MNIST图像的784维向量($x^*$
),输出输入图像是真实MNIST图像的概率。除了最后一层,我们的网络用使用ReLU激活函数strided卷积,最后一层我们使用sigmoid激活函数保证判别器的输出值在[0,1]之间。
def convolutional_discriminator(x): with C.layers.default_options(init=C.normal(scale=0.02)): dfc_dim = 1024 df_dim = 64 print('Discriminator convolution input shape', x.shape) x = C.reshape(x, (1, img_h, img_w)) h0 = C.layers.Convolution2D(dkernel, 1, strides=dstride)(x) h0 = bn_with_leaky_relu(h0, leak=0.2) print('h0 shape :', h0.shape) h1 = C.layers.Convolution2D(dkernel, df_dim, strides=dstride)(h0) h1 = bn_with_leaky_relu(h1, leak=0.2) print('h1 shape :', h1.shape) h2 = C.layers.Dense(dfc_dim, activation=None)(h1) h2 = bn_with_leaky_relu(h2, leak=0.2) print('h2 shape :', h2.shape) h3 = C.layers.Dense(1, activation=C.sigmoid)(h2) print('h3 shape :', h3.shape) return h3
我们使用的取样包数大小是128,固定学习速率0.0002。如果使用快速模式我们只训练500轮以证明其功能正确性。
注意:在慢速模式下,训练结果看起来将会好很多,不过这需要数十分钟,具体根据您的硬件条件决定。一般来说,取样包训练的越多,生成的图像越逼真。
# training configminibatch_size = 128num_minibatches = 5000 if isFast else 10000lr = 0.0002momentum = 0.5 #equivalent to beta1
构建计算图
计算图的剩下部分主要用于协调训练算法和参数更新,这由于以下原因对GAN十分困难。
- 第一,判别器必须既用于真实MNIST图像,也用于生成器函数生成的模拟图像。一种在计算图上记录上诉状态的方法是创建一个判别器函数输出的克隆副本,但是用不同的输入。在副本函数中设置method=share确保不同方式使用的判别器使用一样的参数。
- 第二,我们需要对生成器和判别器使用不同的成本函数来更新模型参数。我们可以通过parameters属性获取计算图中函数对象的参数。然而,当更新模型参数时,更新只发生在两个子网络中的一个,另一个没有改变。换句话说,当更新生成器的参数时,我们只更新了G函数的参数,没有更新D函数的参数。
训练模型
训练GAN的代码与2014年神经信息处理系统大会(NIPS)上的一篇论文(链接:https://arxiv.org/pdf/1406.2661v1.pdf)提出的算法非常接近。在实现是,我们训练D来最大化给训练样本和G中生产的样本贴正确标签的概率。换句话说,D和G在玩一个双人针对函数
这个游戏的最优点,生成器将生成非常逼真的数据,判别器预测合成图片的概率将会变成0.5。上面提到的论文中提到的算法会在下面的代码中实现。
def build_graph(noise_shape, image_shape, generator, discriminator): input_dynamic_axes = [C.Axis.default_batch_axis()] Z = C.input_variable(noise_shape, dynamic_axes=input_dynamic_axes) X_real = C.input_variable(image_shape, dynamic_axes=input_dynamic_axes) X_real_scaled = X_real / 255.0 # Create the model function for the generator and discriminator models X_fake = generator(Z) D_real = discriminator(X_real_scaled) D_fake = D_real.clone( method = 'share', substitutions = {X_real_scaled.output: X_fake.output} ) # Create loss functions and configure optimazation algorithms G_loss = 1.0 - C.log(D_fake) D_loss = -(C.log(D_real) + C.log(1.0 - D_fake)) G_learner = C.adam( parameters = X_fake.parameters, lr = C.learning_rate_schedule(lr, C.UnitType.sample), momentum = C.momentum_schedule(momentum) ) D_learner = C.adam( parameters = D_real.parameters, lr = C.learning_rate_schedule(lr, C.UnitType.sample), momentum = C.momentum_schedule(momentum) ) # Instantiate the trainers G_trainer = C.Trainer(X_fake, (G_loss, None), G_learner) D_trainer = C.Trainer(D_real, (D_loss, None), D_learner) return X_real, X_fake, Z, G_trainer, D_trainer
随着定义值函数,我们开始对GAN模型进行间接训练。训练这个模型根据硬件状况将会话费很长时间特别是如果你把isFast设为False。
def train(reader_train, generator, discriminator): X_real, X_fake, Z, G_trainer, D_trainer = \ build_graph(g_input_dim, d_input_dim, generator, discriminator) # print out loss for each model for upto 25 times print_frequency_mbsize = num_minibatches // 25 print("First row is Generator loss, second row is Discriminator loss") pp_G = C.logging.ProgressPrinter(print_frequency_mbsize) pp_D = C.logging.ProgressPrinter(print_frequency_mbsize) k = 2 input_map = {X_real: reader_train.streams.features} for train_step in range(num_minibatches): # train the discriminator model for k steps for gen_train_step in range(k): Z_data = noise_sample(minibatch_size) X_data = reader_train.next_minibatch(minibatch_size, input_map) if X_data[X_real].num_samples == Z_data.shape[0]: batch_inputs = {X_real: X_data[X_real].data, Z: Z_data} D_trainer.train_minibatch(batch_inputs) # train the generator model for a single step Z_data = noise_sample(minibatch_size) batch_inputs = {Z: Z_data} G_trainer.train_minibatch(batch_inputs) G_trainer.train_minibatch(batch_inputs) pp_G.update_with_trainer(G_trainer) pp_D.update_with_trainer(D_trainer) G_trainer_loss = G_trainer.previous_minibatch_loss_average return Z, X_fake, G_trainer_lossreader_train = create_reader(train_file, True, d_input_dim, label_dim=10)# G_input, G_output, G_trainer_loss = train(reader_train, dense_generator, dense_discriminator)G_input, G_output, G_trainer_loss = train(reader_train, convolutional_generator, convolutional_discriminator)
生成合成图片(见上期,本期略)
欢迎扫码关注我的微信公众号获取最新文章
- CNTK API文档翻译(21)——深度卷积GAN处理MSIST数据基础
- CNTK API文档翻译(20)——GAN处理MSIST数据基础
- CNTK API文档翻译(7)——对MNIST数据使用卷积神经网络
- CNTK API文档翻译(16)——增强学习基础
- CNTK API文档翻译(17)——多对多神经网络处理文本数据(1)
- CNTK API文档翻译(18)——多对多神经网络处理文本数据(2)
- CNTK API文档翻译(17)——多对多神经网络处理文本数据(1)
- CNTK API文档翻译(4)——MNIST数据加载
- CNTK API文档翻译(12)——CNTK进阶
- CNTK API文档翻译(24)——使用深度迁移学习进行图像识别
- CNTK API文档翻译(5)——对MNIST数据使用逻辑回归
- CNTK API文档翻译(6)——对MNIST数据使用多层感知机
- CNTK API文档翻译(9)——使用自编码器压缩MNIST数据
- CNTK API文档翻译(10)——使用LSTM预测时间序列数据
- CNTK API文档翻译(13)——CIFAR-10数据准备
- CNTK API文档翻译(1)——使用数列
- CNTK API文档翻译(2)——逻辑回归
- CNTK API文档翻译(3)——前馈神经网络
- Hibernate--组件
- python爬虫实践----爬取京东图片
- 查分约束系统板子
- 自动人脸识别基本原理 --基于静态图像的识别算法(一)特征脸补充知识 PCA
- ext6.2 store如何更新数据刷新
- CNTK API文档翻译(21)——深度卷积GAN处理MSIST数据基础
- java switch语句的一个例子
- shell 脚本 read 提供默认值
- Error assembling WAR: webxml attribute is required (or pre-existing WEB-INF/web.xml if executing in
- 基于EasyPR的车牌识别android实现
- 使用Eclipse的Working Set管理项目
- 事物默认自动提交AUTOCOMMIT
- python协程
- 人脸识别必读的N篇文章