生成对抗网络

来源:互联网 发布:宁武子 邦有道则知翻译 编辑:程序博客网 时间:2024/05/14 07:01

生成对抗网络

判别器使用多层感知机判断样本来自于生成器还是真实数据
生成网络用随机的噪声经过多层感知机生成样本以迷惑判别器,最大化判别器犯错的可能性

网络的损失函数是:
这里写图片描述
D(x)表示 x 属于真实数据而不是生成数据的概率
G(z)表示从噪声 z 生成数据
这是一个极小极大问题,固定G,优化D来最大化V,固定D,优化G来最小化V,两者形成对抗。

这里写图片描述
如图是训练过程,假设我们已经有一个接近真实分布的生成分布,如图(a),图中黑线表示真实分布,绿线表示生成分布,蓝线表示判别器,判别器的值从0-1波动。
图(a),固定生成器,优化判别器,来最大化V,假设运算足够,判别器可以达到一个最优状态D*(x)=这里写图片描述表示判别器判断一个x属于真实数据的概率依靠的是真实数据在真实数据和生成数据之间所占的比重。图(b),D达到一个最优值,可以看出在前面只有真实数据的时候,D的值为1,后面之后生成数据的时候,D为0,中间生成数据和分布数据有重叠区域,即对于同一个x,既有可能来自真实数据,也有可能来自生成数据,最优 D 按照概率来判断。图(c),固定判别器,优化生成器,将生成分布逼近于真实分布。图(d),固定生成器,优化判别器。如果生成器和训练器的能力都足够强大,训练足够充分,最后真实数据和生成数据的分布重合,判别器无法判断,D(x)=1/2.

伪代码
这里写图片描述

这里写图片描述
这里写图片描述

https://zhuanlan.zhihu.com/p/28853704
https://zhuanlan.zhihu.com/p/25071913

#####mnist训练######
这里写图片描述
这里写图片描述

fake_logit = critic(train)true_logit = critic(real_data)c_loss = tf.reduce_mean(fake_logit - true_logit)g_loss = tf.reduce_mean(-fake_logit)
原创粉丝点击