Dual GAN

来源:互联网 发布:淘宝蛋糕店排名 编辑:程序博客网 时间:2024/06/11 12:58

此篇博客仅为自己阅读Dual GAN的一些笔记。
原论文链接:https://arxiv.org/abs/1704.02510

在看Dual GAN 之前,听了学长学姐汇报了cycleGAN和discoGAN,感觉大体和DualGAN是相似的,于是我一直在思考这篇文章与另外两篇文章的区别,最终发现了这三篇文章虽然思想一致,但是由于提出的目的不同,在网络结构的设计上有了不同,整体来说Dual GAN 和cycleGAN还是比较相似的【1】。

为什么提出该论文?

现在的模型大部分都是需要成对的带标签的图像作为训练集,人为的做标记是很费事不科学的,并且还有时候我们很难找到一对一的图片,比如说照片及其对应的素描画(题外话:现在好像有不错的算法,生成一张图片对应的素描画,但要生成这样对应的训练集也是很麻烦就是了),我们往往是有一堆照片和一堆素描画,这篇论文想要做到的是,随机从两个数据集中各拿一张图片,可以将一张图片生成带有另一张图片风格的新图片。

文章的思想来源

这篇文章的灵感来源是Xia et al 提出的一篇做机器翻译的文章NIP Dual【2】。这篇文章的一个例子很好的解释了对偶的思想。首先假设我们有两个人A和B,A会英文不会法语,B会法语不会英文。A要和B进行交流,A写了一段话,通过翻译器GA翻译成法语,但A看不懂法语,只能将翻译的结果直接发给B;B收到这段话之后,用自己的理解整理了一下,通过翻译器GB翻译成英文直接发给A;A收到之后,要检查B是否真正理解自己说的话。如此往复几次,A和B都能确认对方理解了自己。在这个例子中,翻译器A和B分别是两个生成器,A和B分别担任了判别器的角色。

理论部分

接下来上原文中的图
这里写图片描述
再来看这个图是不是很好理解,好了,这里就不做重复解释了(其实是我懒- -)
NIP方法是依赖与训练来保证网络的正确性的,嗯,看到这里是不是感觉到不对劲了,对于语言,我们有很多对应的数据可以预训练,那对于图片呢,显然没有,原因见上,额,这下仿佛陷入了一个死循环了。别担心,除了NIP我们还有很多方法啊,比如说WGAN。

嗯,接下来让我们跳转到WGAN频道,本来想直接啃了这篇文章,发现太难了,网络结构和损失函数都参考WGAN,这让我怎么干啃T-T

WGAN

什么是WGAN?看了这篇文章,才算明白,为什么GAN提出了那么久没火起来,这两年突然间就火了呢,在这里我就姑且称它为GAN的再生之父吧。接下来的内容参考令人拍案叫绝的Wasserstein
GAN。请原谅一个数学不怎么好的人,一些数学证明这里就不复述了,需要的话请直接跳转原文,这里仅说明一些结论。
原始的GAN到底有什么问题呢?这时我们需要返回GAN的开山论文中《Generative Adversarial Nets》。
1、在《Generative Adversarial Nets》中,4.1中花了很大的篇幅在证明当判别器最优的时候,P_g=P_data是全局最优解。
这里写图片描述
这是文章该部分最终的出的结论。从这个式子我们可以得出什么结论呢?在最优判别器的条件下,原始GAN的生成器loss等价变换为最小化真实分布P_data与生成分布P_g之间的JS散度。我们越训练判别器,它就越接近最优,最小化生成器的loss也就会越近似于最小化P_r和P_g之间的JS散度。
这公式仿佛没什么问题,是的,问题不在公式上,在于一开始采用的度量方式就出错了。这说明了选择比努力更重要啊。。。。
为什么错呢?这里引进了一个概念,当P_r与P_g的支撑集是高维空间中的低维流形时,P_r与P_g重叠部分为0的概率为1。这个时候
这里写图片描述
值就固定为log2了。什么!!!我们的GAN一般是从低像素(低维)生成到高像素(高维中),这不就刚好是这种情况了,那我们还训练什么C(G)都不动了。。。。
这就是为什么我们在训练GAN的时候不能把判别器训练得太好了。。。。。。。这是个大问题。
2、我们知道生成器的loss被优化成下面这个式子这里写图片描述这个式子又可以转换成JS和KL表示,如下:这里写图片描述
至于怎么推导的,看原论文去T-T
从上面这个式子,我们可以看到我们要最小化生成器loss既要减小KL又要同时增大JS,听起来就很矛盾。这就是为什么会梯度不稳定的原因啦。还有一个,KL(P_g||p_data)和KL(P_data||P_g)这两个式子不是一个东西,第一种错误说的是生成器没生成正确的样本,惩罚比较小,而第二种错误指的是生成器生成不了正确的样本,惩罚很大。第一种错误对应的是缺乏多样性,第二种错误对应的是缺乏准确性。这一放一打之下,生成器宁可多生成一些重复但是很“安全”的样本,也不愿意去生成多样性的样本,因为那样一不小心就会产生第二种错误,得不偿失。这种现象就是大家常说的collapse mode。
知道了GAN的毛病,及产生这个毛病的问题之后,作者当然是要提出解决方案。既然是KL和JS度量出了问题,那我们就不用它。

作者提出了一种新的度量方式:Wasserstein距离
这里写图片描述
r服从Pr和Pg的联合分布,取出所有的(x,y)~r,计算||x-y||(x和y)的距离,这个式子就是要求这些距离的期望值。虽然这个式子很好理解,但实际上很难进行操作。于是作者又将这个式子变换成了如下:
这里写图片描述
怎么证明这个式子我们先不管它,我们来看它要怎么应用。 这里的||f||L<=K指的是什么?Lipschitz限制。用公式怎么表达呢?这里写图片描述
于是我们可以将上式W转换成
这里写图片描述
fw表示一个带参数w的神经网络。为了满足||f||L<=K,作者将参数w通过clipping操作,限制在一个范围内。在WGAN中的判别器fw做的是近似拟合Wasserstein距离,属于回归任务,所以要把最后一层的sigmoid拿掉。
文章用这里写图片描述L取最大值时来近似Wasserstein距离。

我们可以看出公式14中第一项和生成器没有关系。于是我们就要甩出我们的重头戏损失函数啦,
这里写图片描述
公式1表示判别器希望尽可能拉高真样本的分数,拉低假样本的分数,公式2表示生成器希望尽可能拉高假样本的分数。

接下来放出我们的大招:WGAN算法的流程
这里写图片描述
这里贴出了我们原本GAN的流程图仅供对比:
这里写图片描述
看出差别了嘛?
1、生成器和判别器的loss不一样
2、每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c
3、不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也行

================================================

呃呃呃,理想很美满,现实很骨感,在WGAN中,一系列实验都呈现出很好的效果,然,很多人在实验的时候发现了WGAN实际上没那么完美,反而存在着训练困难、收敛速度慢等问题。还记得嘛,在上文我们为了保证||f||L<=K,对网络f的参数w进行了clipping操作,直接将w的值限制在一个范围内,这就是为什么WGAN出现问题的原因了。下图是w属于[-0.01,0.01]范围的判别器数值分布图。
这里写图片描述
可以发现我们的分布都集中在w最大值和最小值上了,这个网络的映射太简单,判别器没能充分利用自身的模型能力,经过它回传给生成器的梯度也会跟着变差。
还有一个问题,我们对w进行clip,会出现梯度消失或者梯度爆炸。原因是如果我们w的值设得大一点,那么经过多层网络,会指数增长,出现梯度爆炸;w的值设得小一点嘛,又会出现指数衰减,怎么把握w的度呢?这就要看你的调参技术和人品了。
这个时候,我们GAN的再生之父就有紧接着提出了gradient penalty。
这里写图片描述
从上图我们可以明显的看出gradient penalty的梯度变化比较平缓。比起前文中的weight clip更适合用来限制网络f的参数w的范围。

上文我们说到Lipschitz限制是要求判别器的梯度不超过K,WGAN的后续工作中,作者指出可以用一个额外的loss项来满足这个限制(怎么感觉思想有点像来自之前给式子加拉格朗日限制的)。
训练过程中判别器希望尽可能拉大真假样本的分数差距,真的数据打分越高,假的数据打分越低。对应的就是希望梯度越大越好。所以判别器在充分训练之后,其梯度norm其实就会是在K附近。所以我们根据这个思想得出额外的loss项梯度norm离K越近越好。因此作者得出额外loss的表达式如下:
这里写图片描述
总结一下,到目前为止,我们判别器的loss变成下面这个形式:
这里写图片描述
作者对如何求解这个判别器的loss,给出了一个方案,我们在具体训练网络的过程中,都是通过采样的方式进行计算的(具体请回顾上面的伪代码)。我们可以看到上面这个式子原本是需要对在整个样本上进行采样的。但实际上我们没有必要对整个空间进行采样,只要重点抓住生成样本集中区域、真实样本集中区域以及夹在它们中间的区域就行了。
这里写图片描述
也就是说我们只需要拿到一对真假数据,然后对它们进行随机插值得到一个中间数,将这个中间数带入第三项即可。

好啦,讲到这里我们的题外话WGAN算是告一段落了,接下来该回到我们的主题Dual GAN中来了。

我们可以发现GAN其实大同小异,主要是网络的设计和损失函数的定义。接下来我们先来看看Dual GAN的损失函数是怎么设计的。
这里写图片描述
这里写图片描述
是不是觉得这些损失函数有点似曾相识呢?没错,就是我们WGAN中的损失函数,这里为了便于理解,再将WGAN的损失函数再贴出来一下
这里写图片描述
我们来单独对比一下判别器,Da(v)对应的就是WGAN中的Ex_pr[D(x)],Da(Ga(u,z))对应的是Ex_pg[D(x)]。再来看看生成器,第一项和第二项分别代表前向和反向的内容损失,忘记的话,翻到上面去看看结构图,第三项和第四项分别是两个生成器的损失。

好啦,我们的损失函数有了,接下来就来介绍一下我们的网络结构。在Dual中,生成器采用的是U-net网络。判别器使用的是patchGAN。
这里写图片描述

上Dual的伪代码:
这里写图片描述
为了方便对比,再上原始WGAN的伪代码(别嫌弃我啰嗦)
这里写图片描述
总结一下,其实Dual GAN用的是没有改进的WGAN的流程,然后加入对偶的思想,生成网络采用的是U-net,判别网络使用patchGAN的思想。这里如果要进行实现,个人建议采用改进后的WGAN-GP方法,效果可能会更好。理解了WGAN之后再来看Dual GAN是不是就很简单啊~~
附上DualGAN源代码链接https://github.com/duxingren14/DualGAN
WGAN源码链接https://github.com/igul222/improved_wgan_training

实验部分

这篇文章主要做了三个实验:照片和素描画的转换,带标签的图像翻译,图片风格化。使用了4个库(PHOTO-SKETCH,DAY-NIGHT, LABEL-FACADES , and AERIALMAPS)来做实验。这些库分别包括两个域对应的图片,所以我们把它用作GT,但是这些数据集都没办发保证特征在像素级别上有高的精确度。
这里写图片描述
上图是从label到facade的转换,我们可以看出来DualGAN对图像的结构复原的很好,其他的算法,对于图片中结构没有对其的部分复原效果比较差。
这里写图片描述
这是照片生成简笔画的结果,哪个效果好就不用我多说了吧~
这里写图片描述
国画和油画的转换~
这里写图片描述
物体材质转换~
这里写图片描述
地图到航拍照片

从视觉上,可以很明显的看出我们的DualGAN的效果杠杠的,至于对数据结果有兴趣的,就请移步原论文吧~

讲到这里我们DualGAN论文主要内容就讲完啦~此篇仅为个人笔记,如果有哪里错啦,还望指出~~

[1]乾貨 | 孿生三兄弟 CycleGAN, DiscoGAN, DualGAN 還有哪些散落天涯的遠親
[2]Y. Xia, D. He, T. Qin, L. Wang, N. Yu, T.-Y. Liu,
and W.-Y. Ma. Dual learning for machine translation.
arXiv preprint arXiv:1611.00179, 2016. 1, 2, 3, 4