DCGAN Code Analysis
Source: Internet. Editor: 程序博客网. Time: 2024/06/06 00:34
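Generator network
The generator takes a noise vector z and a class vector y as input, and outputs a generated image.
Taking handwritten digit images as an example, z is a 100-dimensional vector. Because handwritten digits fall into 10 classes, the class vector y is a 10-dimensional one-hot vector: the entry at the class index is 1 and all other entries are 0.
The code is as follows: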
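To make the two inputs concrete, here is a minimal pure-Python sketch of building a one-hot class vector y and a 100-dimensional noise vector z for a single sample. The helper names `one_hot` and `sample_z` are hypothetical, and drawing z from a uniform distribution on [-1, 1] is an assumption; the text above only fixes the dimensions.

```python
import random


def one_hot(label, num_classes=10):
    """Return a list with 1.0 at index `label` and 0.0 everywhere else."""
    if not 0 <= label < num_classes:
        raise ValueError("label out of range")
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


def sample_z(z_dim=100, rng=random):
    """Draw one noise vector; the uniform(-1, 1) range is an assumption."""
    return [rng.uniform(-1.0, 1.0) for _ in range(z_dim)]


y = one_hot(3)            # class "3" of the ten digit classes
z = sample_z()            # 100-dimensional noise
generator_input = z + y   # mirrors concat([z, y], 1) for a single sample
print(len(z), len(y), len(generator_input))
```

Concatenating the two lists gives a 110-dimensional vector, which is what the first fully connected layer of the conditional generator branch would receive for each sample.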
def generator(self, z, y=None):
    with tf.variable_scope("generator") as scope:
        if not self.y_dim:
            s_h, s_w = self.output_height, self.output_width
            s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)
            s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)
            s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)
            s_h16, s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2)

            # project `z` and reshape
            self.z_, self.h0_w, self.h0_b = linear(
                z, self.gf_dim*8*s_h16*s_w16, 'g_h0_lin', with_w=True)

            self.h0 = tf.reshape(
                self.z_, [-1, s_h16, s_w16, self.gf_dim * 8])
            h0 = tf.nn.relu(self.g_bn0(self.h0))

            self.h1, self.h1_w, self.h1_b = deconv2d(
                h0, [self.batch_size, s_h8, s_w8, self.gf_dim*4], name='g_h1', with_w=True)
            h1 = tf.nn.relu(self.g_bn1(self.h1))

            h2, self.h2_w, self.h2_b = deconv2d(
                h1, [self.batch_size, s_h4, s_w4, self.gf_dim*2], name='g_h2', with_w=True)
            h2 = tf.nn.relu(self.g_bn2(h2))

            h3, self.h3_w, self.h3_b = deconv2d(
                h2, [self.batch_size, s_h2, s_w2, self.gf_dim*1], name='g_h3', with_w=True)
            h3 = tf.nn.relu(self.g_bn3(h3))

            h4, self.h4_w, self.h4_b = deconv2d(
                h3, [self.batch_size, s_h, s_w, self.c_dim], name='g_h4', with_w=True)

            return tf.nn.tanh(h4)
        else:
            s_h, s_w = self.output_height, self.output_width
            s_h2, s_h4 = int(s_h/2), int(s_h/4)
            s_w2, s_w4 = int(s_w/2), int(s_w/4)

            # yb = tf.expand_dims(tf.expand_dims(y, 1),2)
            yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
            z = concat([z, y], 1)

            h0 = tf.nn.relu(
                self.g_bn0(linear(z, self.gfc_dim, 'g_h0_lin')))
            h0 = concat([h0, y], 1)

            h1 = tf.nn.relu(self.g_bn1(
                linear(h0, self.gf_dim*2*s_h4*s_w4, 'g_h1_lin')))
            h1 = tf.reshape(h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2])

            h1 = conv_cond_concat(h1, yb)

            h2 = tf.nn.relu(self.g_bn2(deconv2d(h1,
                [self.batch_size, s_h2, s_w2, self.gf_dim * 2], name='g_h2')))
            h2 = conv_cond_concat(h2, yb)

            return tf.nn.sigmoid(
                deconv2d(h2, [self.batch_size, s_h, s_w, self.c_dim], name='g_h3'))
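The unconditional branch of the generator is easier to follow once the tensor shapes are written out. The sketch below traces the (height, width, channels) of each layer from the projected noise up to the output image. Two things here are assumptions, not taken from the listing: `conv_out_size_same` is not defined in this article, so it is taken to be ceiling division (the usual output size for 'SAME' padding), and the example values `gf_dim=64`, `c_dim=3` and a 64x64 output are illustrative. `unconditional_shapes` is a hypothetical helper name.

```python
import math


def conv_out_size_same(size, stride):
    # Assumed definition: ceiling division, as with 'SAME' padding.
    return int(math.ceil(float(size) / float(stride)))


def unconditional_shapes(out_h, out_w, gf_dim, c_dim):
    """Return (h, w, channels) for h0..h4 of the unconditional branch."""
    sizes = [(out_h, out_w)]
    for _ in range(4):  # halve four times: s_h2, s_h4, s_h8, s_h16
        h, w = sizes[-1]
        sizes.append((conv_out_size_same(h, 2), conv_out_size_same(w, 2)))
    sizes.reverse()     # smallest feature map (h0) first
    channels = [gf_dim * 8, gf_dim * 4, gf_dim * 2, gf_dim, c_dim]
    return [(h, w, c) for (h, w), c in zip(sizes, channels)]


for name, shape in zip(["h0", "h1", "h2", "h3", "h4"],
                       unconditional_shapes(64, 64, gf_dim=64, c_dim=3)):
    print(name, shape)
```

Under these assumptions each deconvolution doubles the spatial size and halves the channel count, going from 4x4x512 to the 64x64x3 image that is passed through tanh.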
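Discriminator network
The discriminator receives two kinds of input: a real image together with its class vector y, and a generated image G together with the class vector y. Its job is to decide whether the input image is real or generated (fake).
The discriminator code is: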
def discriminator(self, image, y=None, reuse=False):
    with tf.variable_scope("discriminator") as scope:
        if reuse:
            scope.reuse_variables()

        if not self.y_dim:
            h0 = lrelu(conv2d(image, self.df_dim, name='d_h0_conv'))
            h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim*2, name='d_h1_conv')))
            h2 = lrelu(self.d_bn2(conv2d(h1, self.df_dim*4, name='d_h2_conv')))
            h3 = lrelu(self.d_bn3(conv2d(h2, self.df_dim*8, name='d_h3_conv')))
            h4 = linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'd_h4_lin')

            return tf.nn.sigmoid(h4), h4
        else:
            yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
            x = conv_cond_concat(image, yb)

            h0 = lrelu(conv2d(x, self.c_dim + self.y_dim, name='d_h0_conv'))
            h0 = conv_cond_concat(h0, yb)

            h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim + self.y_dim, name='d_h1_conv')))
            h1 = tf.reshape(h1, [self.batch_size, -1])
            h1 = concat([h1, y], 1)

            h2 = lrelu(self.d_bn2(linear(h1, self.dfc_dim, 'd_h2_lin')))
            h2 = concat([h2, y], 1)

            h3 = linear(h2, 1, 'd_h3_lin')

            return tf.nn.sigmoid(h3), h3
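Both conditional branches rely on `conv_cond_concat`, whose body is not shown in this article. A plausible reading, given that `yb` has shape [batch, 1, 1, y_dim], is that it copies the class vector to every spatial position of the feature map and appends it along the channel axis. The pure-Python sketch below illustrates that idea on nested lists; it is an assumption about the helper's behaviour, not its actual implementation, and it takes the label as [batch][y_dim] for simplicity.

```python
def conv_cond_concat(x, y):
    """Append each sample's label vector to every pixel's channel list.

    x: nested lists indexed as [batch][height][width][channels]
    y: nested lists indexed as [batch][y_dim]
    """
    return [
        [[pixel + list(label) for pixel in row] for row in image]
        for image, label in zip(x, y)
    ]


# One 2x3 single-channel "image" and a 2-class one-hot label.
x = [[[[0.5] for _ in range(3)] for _ in range(2)]]
y = [[0.0, 1.0]]
out = conv_cond_concat(x, y)
print(len(out), len(out[0]), len(out[0][0]), len(out[0][0][0]))  # 1 2 3 3
print(out[0][0][0])  # [0.5, 0.0, 1.0]
```

With this reading, an input with `c_dim` channels becomes one with `c_dim + y_dim` channels, so every convolution filter sees the class label at each location.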
Objective function
self.G = self.generator(self.z, self.y)
self.D, self.D_logits = self.discriminator(inputs, self.y, reuse=False)
self.sampler = self.sampler(self.z, self.y)
self.D_, self.D_logits_ = self.discriminator(self.G, self.y, reuse=True)

self.d_sum = histogram_summary("d", self.D)
self.d__sum = histogram_summary("d_", self.D_)
self.G_sum = image_summary("G", self.G)

def sigmoid_cross_entropy_with_logits(x, y):
    try:
        return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, labels=y)
    except:
        return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, targets=y)

self.d_loss_real = tf.reduce_mean(
    sigmoid_cross_entropy_with_logits(self.D_logits, tf.ones_like(self.D)))
self.d_loss_fake = tf.reduce_mean(
    sigmoid_cross_entropy_with_logits(self.D_logits_, tf.zeros_like(self.D_)))
self.g_loss = tf.reduce_mean(
    sigmoid_cross_entropy_with_logits(self.D_logits_, tf.ones_like(self.D_)))
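The three losses above differ only in which logits and which target label they pair. The following framework-free sketch reproduces them on plain Python lists, using the numerically stable form max(x, 0) - x*z + log(1 + exp(-|x|)) of the sigmoid cross-entropy. `gan_losses` and `mean` are hypothetical helper names introduced for illustration; the pairing of logits and labels follows the listing.

```python
import math


def sigmoid_cross_entropy_with_logits(x, z):
    """Stable form of -z*log(sigmoid(x)) - (1 - z)*log(1 - sigmoid(x))."""
    return max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))


def mean(values):
    values = list(values)
    return sum(values) / len(values)


def gan_losses(d_logits_real, d_logits_fake):
    # Discriminator: real images should be labelled 1, generated images 0.
    d_loss_real = mean(sigmoid_cross_entropy_with_logits(x, 1.0) for x in d_logits_real)
    d_loss_fake = mean(sigmoid_cross_entropy_with_logits(x, 0.0) for x in d_logits_fake)
    # Generator: wants the discriminator to label generated images as 1.
    g_loss = mean(sigmoid_cross_entropy_with_logits(x, 1.0) for x in d_logits_fake)
    return d_loss_real, d_loss_fake, g_loss


# An undecided discriminator (all logits 0, i.e. D = 0.5) gives log(2) for each loss.
print(gan_losses([0.0, 0.0], [0.0, 0.0]))
```

Note that `g_loss` reuses the logits of the generated images but flips the target to 1, so the generator is rewarded when the discriminator is fooled, rather than directly minimising log(1 - D(G(z))).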