GAN模型总结

来源:互联网 发布:淘宝买家福利晒图 编辑:程序博客网 时间:2024/06/06 03:40

GAN模型

一、什么是GAN模型

生成对抗网络(Generative Adversarial Network)由一个生成网络与一个判别网络组成。生成网络从潜在空间(latent space)中随机采样作为输入,其输出结果需要尽量模仿训练集中的真实样本。判别网络的输入则为真实样本或生成网络的输出,其目的是将生成网络的输出从真实样本中尽可能分辨出来。而生成网络则要尽可能地欺骗判别网络。两个网络相互对抗、不断调整参数,最终目的是使判别网络无法判断生成网络的输出结果是否真实。 ——维基百科

二、GAN模型的具体内容

(一)、模型结构

GAN模型最初由Ian J. Goodfellow教授及其团队于2014年在Generative Adversarial Nets中提出。GAN作为一种非监督学习方式,是由生成器G、分类器D这两套独立的神经网络所组成。生成器G用于生成假样本,分类器D用于分辨生成器G所生成样本是真实数据还是虚假数据。每一次判断的结果都会作为反向传播的输入到G、D之中,如果D判断正确,那就需要调整G的参数从而使得生成的假数据更为逼真;如果D判断错误,则需调节D的参数,避免下次类似判断出错。训练会一直持续到两者进入到一个均衡和谐的状态。。GAN模型的最终目的是得到一个质量较高的自动生成器和一个判断能力较强的分类器。
GAN模型

(二)、对抗网络

举个简单的例子,对抗网络就类似一场博弈,从最初始的数据库中映射出一个由多层感知器表示的可微函数G(z;θg),此后再引入一个与之对抗的多层感知器D(z;θg),两者之间展开一场“二人博弈”,分别达到最优化的目的。
假设D,G都是小的多层感知机,每层总共有稀薄的数个隐含单元。G的输入是一个噪音分布z∼uniform(0,1)中的单个样例。我们想使用G来将点z1,z2,…zM映射为x1,x2,…xM,这样映射的点xi=G(zi)在pdata(X)密集的地方会密集聚集。因此,在G中输入z将生成伪数据x′。
同时,判别器D,以x为输入,然后输出该输入属于pdata的可能性。令D1和D2为D的副本(它们共享参数,那么D1(x)=D2(x))。D1的输入是从合法的数据分布x∼pdata中得到的单个样例,所以当优化判别器时我们想使D1(x)最大化。D2以x′(G生成的伪数据)为输入,所以当优化D时,我们想使D2(x)最小化。
最终,优化函数为MAXD1MIND2{log(D1(x))+log(1−D2(G(z)))}

对抗网络训练过程
上图为具体的训练过程,蓝色虚线代表着分类器D的结果,它分辨的是由生成器G(绿色实现)生成的数据px(黑色虚线),最终,G与D会达到一个平衡阶段(图中(d)的阶段)、

三、缺点与改进

直接把GAN应用到NLP领域,有两方面的问题:1. 相较于图像生成,NLP领域更多的是生成离散的结果,因此,设计用来做图像识别的GAN无法用D分类器生成的离散结果来很好的训练生成器G。2.GAN只可以对已经生成的完整序列进行分辨,而对一部分生成的序列,如何判断它现在生成的一部分的质量和之后生成整个序列的质量也是一个问题。1
关于这些问题,已经有了很多的改进方式,如:SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient。这篇论文利用了强化学习的方式来解决以上问题。
PGM如图,针对第一个问题,首先是将D的输出作为Reward,然后用Policy Gradient Method来训练G。针对第二个问题,通过蒙特卡罗搜索,针对部分生成的序列,用一个Roll-Out Policy(一种LSTM)来取样完整的序列,再交给D进行分辨,最后对得到的Reward求平均值。

Written by lreaderl