[TensorFlow]生成对抗网络(GAN)介绍与实践

来源:互联网 发布:python 打开html文件 编辑:程序博客网 时间:2024/05/21 00:56

主旨

本文简要介绍了生成对抗网络(GAN)的原理,接下来通过tensorflow开发程序实现生成对抗网络(GAN),并且通过实现的GAN完成对等差数列的生成和识别。通过对设计思路和实现方案的介绍,本文可以辅助读者理解GAN的工作原理,并掌握实现方法。有了这样的基础,在面对工作中实际问题时可以将GAN纳入考虑,选择最合适的算法。

代码和运行环境

代码位置:
https://github.com/wangyaobupt/GAN

TensorFlow版本

>>> tf.version
‘1.1.0-rc2’

背景知识

Generative Adversarial Nets[1][https://arxiv.org/pdf/1406.2661v1.pdf]是Ian J. Goodfellow等在2014年提出的一种训练模型的方法,此方法通过两个网络(生成网络G和分类网络D)对抗训练,得到符合预期目标的生成模型和分类模型。

要理解GAN的原理,上述论文是最好的教材。但考虑到原文首先是英文撰写,其次包含不少数学推导,新手上手并不容易。因此笔者这里班门弄斧,基于论文简单转述GAN的设计思想要点

GAN的目标,给定一个真实样本(本文也称之为ground truth)集合,训练出两个模型,一个能够从噪声信号生成尽可能像ground truth的样本;另一个能够判断给定样本是否是ground truth。两个模型详细介绍如下

  • 生成模型:论文中称为generative model,本文称为G网络或G模型。G网络的输入是噪声信号(例如均匀分布的随机数),输出为形状与真实样本ground truth一致。G网络的训练目标是,尽可能输出与ground truth相似的样本。这里“相似”定义为:如果G网络生成的一个样本骗过了D网络,使得D网络误以为这就是真实样本,则就是相似的,G网络获得奖励;反之,获得惩罚。
  • 分类模型:论文中称为discriminative model,本文称为D网络或D模型。D网络是一个2分类器,输入为ground truth或者G网络生成的样本,输出为TRUE或FALSE:TRUE表示D网络认为当前输入样本是ground truth,FALSE表示D网络认为当前输入样本是G网络生成的“伪造”样本。D网络的训练目标是尽可能正确的区分开ground truth和G网络生成的“伪造”样本。

从上述讨论可以看出,G网络和D网络是两个目标完全相反的网络,G网络尽其所能“伪造”出像真实样本的数据,D网络尽可能区分真实与伪造数据。GAN中所谓“对抗”的概念,即来源于此。

GAN的训练过程就是G和D两个网络互相对抗的过程,对抗的结果是G网络被训练到能够生成以假乱真的样本,即G网络从噪声输入得到了尽可能与真实样本相似的输出,或者说G学会了从噪声生成ground truth的方法;D网络也可以区分ground truth与其他样本,即D学会了区分ground truth与其他数据的方法。

参考文献
1. Goodfellow I J, Pougetabadie J, Mirza M, et al. Generative adversarial nets[C]. neural information processing systems, 2014: 2672-2680.

神经网络设计和实现

问题构造

在开始设计神经网络之前,我们首先构造出预期GAN解决的问题。前述GAN论文中提出了一个从噪声学习正态分布的经典问题,读者如果在网络上搜索GAN的案例,除了图像识别,基本上只有这么一个问题和方案实现。

本文重新设计了一个与论文中不同的问题。问题描述如下

  • Ground Truth定义:[1,2,3,4,5,6,7,8,9,10]构成的等差数列,为了适当降低学习难度,此数列每个元素与噪声相加,噪声为0均值正态分布随机变量,标准差取0.1, 0.03, 0等不同数值
  • 输入噪声定义: [-1,1]之间均匀分布的随机变量。

网络结构设计

G网络:参考论文资料,我们选择多层全连接神经网络

D网络:由于要分辨的是等差数列,我们选择RNN作为D网络。

网络结构如下(下图是tensorboard生成的计算图):图中”G_net”表示G网络,”D_net”/”D_net_1”表示D网络,虽然图中D网络被分成了两份,但是其RNN参数是共享的,即图中正下方”rnn”这个单元。
这里写图片描述

代码实现

G网络定义

    # generative network    # use multi-layer percepton to generate time sequence from random noise    # input tensor must be in shape of (batch_size, self.seq_len)    def generator(self, inputTensor):        with tf.name_scope('G_net'):            gInputTensor = tf.identity(inputTensor, name='input')            # Multilayer percepton implementation            numNodesInEachLayer = 10            numLayers = 3             previous_output_tensor = gInputTensor            for layerIdx in range(numLayers):                activation,z = self.fullConnectedLayer(previous_output_tensor, numNodesInEachLayer, layerIdx)                previous_output_tensor = activation            g_logit = z            g_logit = tf.identity(g_logit, 'g_logit')            return g_logit

G网络损失函数
下面代码片段中self.d_logit_fake是D网络对G网络生成数据的判定结果。由于G网络的目标是尽可能骗过D网路,如果D网络对于G网络生成数据全部判为1(即TRUE),则损失最小,反之,损失最大。

g_loss_d = tf.reduce_mean(                tf.nn.sigmoid_cross_entropy_with_logits(                    logits=self.d_logit_fake,                    labels=tf.ones(shape=[self.batch_size_t,1])                    ),                name='g_loss_d'                )

D网络的定义
RNN+全连接输出层,无论是RNN还是全连接层都必须在对ground truth和G生成样本之间共享同一套参数

 def discriminator(self, inputTensor,reuseParam):        with tf.name_scope('D_net'):            num_units_in_LSTMCell = 10            # RNN definition            with tf.variable_scope('d_rnn'):                lstmCell = tf.contrib.rnn.BasicLSTMCell(num_units_in_LSTMCell,reuse=reuseParam)                init_state = lstmCell.zero_state(self.batch_size_t, dtype=tf.float32)                raw_output, final_state = tf.nn.dynamic_rnn(lstmCell, inputTensor, initial_state=init_state)            rnn_output_list = tf.unstack(tf.transpose(raw_output, [1, 0, 2]), name='outList')            rnn_output_tensor = rnn_output_list[-1];            # Full connected network            numberOfInputDims = inputTensor.shape[1].value            numOfNodesInLayer = 1            if not reuseParam:                self.d_w = tf.Variable(initial_value=tf.random_normal([numberOfInputDims, numOfNodesInLayer]),                        name=('dnet_w_1'))                self.d_b = tf.Variable(tf.zeros([1, numOfNodesInLayer]), name='dnet_b_1')            self.d_z = tf.matmul(rnn_output_tensor,self.d_w) + self.d_b            self.d_z = tf.identity(self.d_z, name='dnet_z_1')            d_sigmoid = tf.nn.sigmoid(self.d_z, name='dnet_a_1')            d_logit = self.d_z            d_logit = tf.identity(d_logit, 'd_net_logit')            return d_logit

D网络损失函数
D网络使用同一套参数分辨两种输入,一种是ground truth,另一种是G网络的输出。对于ground truth,训练目标为尽可能判为1,对于G网络的输出,训练目标为尽可能判为0,因此Loss函数定义如下

# For D-network, jduge ground truth to TRUE, jduge G-network output to FALSE,making loss low            d_loss_ground_truth = tf.reduce_mean(                tf.nn.sigmoid_cross_entropy_with_logits(                    logits=self.d_logit_gnd_truth,                    labels=tf.ones(shape=[self.batch_size_t,1])                    ),                name='d_loss_gnd'                )            d_loss_fake = tf.reduce_mean(                tf.nn.sigmoid_cross_entropy_with_logits(                    logits=self.d_logit_fake,                    labels=tf.zeros(shape=[self.batch_size_t,1])                    ),                name='d_loss_fake'                )            d_loss = d_loss_ground_truth + d_loss_fake

对抗训练
对抗训练中,G网络Loss值只用来调整G网络参数,D网络Loss值只用来调整D网络参数

        g_net_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='G_net')        g_net_var_list = g_net_var_list +  tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='g_rnn')        self.train_g = tf.train.AdamOptimizer(self.lr_g).minimize(g_loss,var_list=g_net_var_list)        d_net_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='D_net')        d_net_var_list = d_net_var_list +  tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='d_rnn')        self.train_d = tf.train.AdamOptimizer(self.lr_d).minimize(d_loss,var_list=d_net_var_list)

训练效果

下图是训练过程中D网络对ground truth和G网络输出的分类正确率曲线
这里写图片描述
从图中可以看到3个阶段

  1. 训练开始后一秒左右:“D网络对ground truth的分类正确率”和“D网络对G网络输出的分类正确率”都快速上升到100%,即D网络经过训练可以完全正确的将真的判为真,假的判为假
  2. 训练后1-15s:D网络分类正确率保持全对
  3. 15s之后:“D网络对ground truth的分类正确率”和“D网络对G网络输出的分类正确率”出现震荡,表明在这个阶段G网络已经能够以假乱真,D网络将部分G网络输出判为真,同时也将部分ground truth判为假。

上述3个阶段就体现出对抗训练的特点,G网络和D网络互为对手,互相提高对方的训练难度,最终得到符合预期的模型。

接下来再从数据上给一个直观的认识

Ground truth: 在公差为1的等差数列上加入stddev=0.3, mean=0的正态分布噪声后,得到的一组Ground Truth数据如下

[ 1.1539436 ]
[ 2.08863655]
[ 2.78491645]
[ 3.93027817]
[ 4.75851967]
[ 5.88655699]
[ 7.10540526]
[ 7.43159023]
[ 9.19373617]
[ 10.08779359]

训练开始前G网络的数据

基本无规律,和输入噪声分布接近

[ 1.15080559]
[ 0.66351247]
[-0.39484465]
[-0.41690648]
[ 0.29061955]
[ 0.06131642]
[-2.46439648]
[-1.53692639]
[-0.30550677]
[-0.89200932]

迭代100次之后G网络的输出
出现等差数列的端倪

[ -0.53692651]
[ 0.86063552]
[ 2.47294378]
[ 5.24512053]
[ 7.7618413 ]
[ 9.57867622]
[ 9.15039253]
[ 9.86567402]
[ 10.62975025]
[ 10.24322414]

迭代500次之后G网络的输出
除了最后一个元素,前9个元素已经基本符合预期

[ 1.09549832]
[ 2.21490908]
[ 2.95311546]
[ 4.06684017]
[ 4.96308947]
[ 6.03393888]
[ 6.89026165]
[ 7.93375683]
[ 8.63552094]
[ 9.07077026]

迭代1500次之后G网络的输出
已经足以以假乱真

[ 0.07186054]
[ 1.08289695]
[ 2.55904818]
[ 4.07374573]
[ 5.14763832]
[ 6.07010031]
[ 6.79585028]
[ 8.17086124]
[ 8.81297684]
[ 10.38190079]

原创粉丝点击