DCGAN代码分析

来源:互联网 发布:行楷字帖谁的好 知乎 编辑:程序博客网 时间:2024/06/06 00:34

生成网络

生成网络的输入为噪声向量 z 和类别向量 y,输出为生成图像。

以手写字体图像为例,z 为 100 维的向量;由于手写字体分为 10 类,因此类别向量 y 为 10 维的 one-hot 向量:对应类别的索引位置取值为 1,其余位置取值为 0。

代码如下:

def generator(self, z, y=None):
    """Build the generator graph inside the "generator" variable scope.

    The architecture is selected by ``self.y_dim``:

    * falsy ``self.y_dim`` -- ``z`` goes through one ``linear`` projection,
      is reshaped into a feature map and then passes through four
      ``deconv2d`` layers (batch norm + ReLU after each of the first
      three and after the projection); the result is squashed with
      ``tanh``.
    * truthy ``self.y_dim`` -- ``y`` is concatenated onto ``z`` and onto
      every intermediate activation; two ``linear`` layers are followed
      by two ``deconv2d`` layers and the result is squashed with
      ``sigmoid``.

    Args:
        z: tensor fed to the first ``linear`` layer (the noise vector,
            according to the accompanying article).
        y: class vector; only read when ``self.y_dim`` is truthy, where
            it is also reshaped to ``[batch_size, 1, 1, y_dim]``.

    Returns:
        The activated output of the last ``deconv2d`` call, which is
        asked for the shape
        ``[batch_size, output_height, output_width, c_dim]``.

    Side effects:
        The unconditional branch stores ``self.z_``, ``self.h0``,
        ``self.h1`` and the weights/biases ``self.h0_w`` .. ``self.h4_w``
        and ``self.h0_b`` .. ``self.h4_b`` returned by the layer helpers.
    """
    with tf.variable_scope("generator"):
        if self.y_dim:
            # ---- conditional generator ----
            height, width = self.output_height, self.output_width
            half_h, quarter_h = int(height/2), int(height/4)
            half_w, quarter_w = int(width/2), int(width/4)

            # Label reshaped so it can be tiled over spatial feature maps.
            label_map = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
            z_with_label = concat([z, y], 1)

            fc0 = linear(z_with_label, self.gfc_dim, 'g_h0_lin')
            fc0 = tf.nn.relu(self.g_bn0(fc0))
            fc0 = concat([fc0, y], 1)

            fc1 = linear(fc0, self.gf_dim*2*quarter_h*quarter_w, 'g_h1_lin')
            fc1 = tf.nn.relu(self.g_bn1(fc1))
            fmap = tf.reshape(
                fc1, [self.batch_size, quarter_h, quarter_w, self.gf_dim * 2])
            fmap = conv_cond_concat(fmap, label_map)

            upsampled = deconv2d(
                fmap,
                [self.batch_size, half_h, half_w, self.gf_dim * 2],
                name='g_h2')
            upsampled = tf.nn.relu(self.g_bn2(upsampled))
            upsampled = conv_cond_concat(upsampled, label_map)

            out = deconv2d(
                upsampled,
                [self.batch_size, height, width, self.c_dim],
                name='g_h3')
            return tf.nn.sigmoid(out)

        # ---- unconditional generator ----
        # Spatial sizes at full, 1/2, 1/4, 1/8 and 1/16 resolution, each
        # derived from the previous one with conv_out_size_same(size, 2).
        sizes = [(self.output_height, self.output_width)]
        for _ in range(4):
            prev_h, prev_w = sizes[-1]
            sizes.append((conv_out_size_same(prev_h, 2),
                          conv_out_size_same(prev_w, 2)))
        (h_1, w_1), (h_2, w_2), (h_4, w_4), (h_8, w_8), (h_16, w_16) = sizes

        # Project `z` and reshape it into the smallest feature map.
        self.z_, self.h0_w, self.h0_b = linear(
            z, self.gf_dim*8*h_16*w_16, 'g_h0_lin', with_w=True)
        self.h0 = tf.reshape(self.z_, [-1, h_16, w_16, self.gf_dim * 8])
        act0 = tf.nn.relu(self.g_bn0(self.h0))

        self.h1, self.h1_w, self.h1_b = deconv2d(
            act0, [self.batch_size, h_8, w_8, self.gf_dim*4],
            name='g_h1', with_w=True)
        act1 = tf.nn.relu(self.g_bn1(self.h1))

        pre2, self.h2_w, self.h2_b = deconv2d(
            act1, [self.batch_size, h_4, w_4, self.gf_dim*2],
            name='g_h2', with_w=True)
        act2 = tf.nn.relu(self.g_bn2(pre2))

        pre3, self.h3_w, self.h3_b = deconv2d(
            act2, [self.batch_size, h_2, w_2, self.gf_dim*1],
            name='g_h3', with_w=True)
        act3 = tf.nn.relu(self.g_bn3(pre3))

        pre4, self.h4_w, self.h4_b = deconv2d(
            act3, [self.batch_size, h_1, w_1, self.c_dim],
            name='g_h4', with_w=True)

        return tf.nn.tanh(pre4)

判别网络

判别网络有两组输入:一组是真实图像和类别向量 y,另一组是生成图像 G 和类别向量 y。其作用是判别输入图像是真实图像(real)还是生成图像(fake)。

判别网络代码为:

def discriminator(self, image, y=None, reuse=False):
    """Build the discriminator graph inside the "discriminator" scope.

    The architecture is selected by ``self.y_dim``:

    * falsy ``self.y_dim`` -- four ``conv2d`` layers with leaky ReLU
      (batch norm on all but the first), flattened and mapped to a single
      logit by ``linear``.
    * truthy ``self.y_dim`` -- ``y`` is concatenated onto the image and
      onto every intermediate activation; two ``conv2d`` layers are
      followed by two ``linear`` layers, the last producing one logit.

    Args:
        image: input tensor to classify.
        y: class vector; only read when ``self.y_dim`` is truthy.
        reuse: when true, ``scope.reuse_variables()`` is called so the
            layers share variables with an earlier call.

    Returns:
        A pair ``(probability, logit)`` where ``probability`` is
        ``tf.nn.sigmoid(logit)``.
    """
    with tf.variable_scope("discriminator") as scope:
        if reuse:
            scope.reuse_variables()

        if self.y_dim:
            # ---- conditional discriminator ----
            label_map = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
            conditioned = conv_cond_concat(image, label_map)

            feat = conv2d(
                conditioned, self.c_dim + self.y_dim, name='d_h0_conv')
            feat = conv_cond_concat(lrelu(feat), label_map)

            feat = conv2d(feat, self.df_dim + self.y_dim, name='d_h1_conv')
            feat = lrelu(self.d_bn1(feat))
            flat = tf.reshape(feat, [self.batch_size, -1])
            flat = concat([flat, y], 1)

            hidden = linear(flat, self.dfc_dim, 'd_h2_lin')
            hidden = lrelu(self.d_bn2(hidden))
            hidden = concat([hidden, y], 1)

            logit = linear(hidden, 1, 'd_h3_lin')
            return tf.nn.sigmoid(logit), logit

        # ---- unconditional discriminator ----
        feat = lrelu(conv2d(image, self.df_dim, name='d_h0_conv'))

        # (batch-norm layer, channel multiplier, layer name) per conv stage.
        conv_stages = (
            (self.d_bn1, 2, 'd_h1_conv'),
            (self.d_bn2, 4, 'd_h2_conv'),
            (self.d_bn3, 8, 'd_h3_conv'),
        )
        for batch_norm, multiplier, layer_name in conv_stages:
            feat = conv2d(feat, self.df_dim*multiplier, name=layer_name)
            feat = lrelu(batch_norm(feat))

        flat = tf.reshape(feat, [self.batch_size, -1])
        logit = linear(flat, 1, 'd_h4_lin')
        return tf.nn.sigmoid(logit), logit

目标函数

# Wire the generator and discriminator together and define the losses.
# The discriminator is built twice on shared variables: once on `inputs`
# and once (reuse=True) on the generator output.
self.G = self.generator(self.z, self.y)
self.D, self.D_logits = self.discriminator(inputs, self.y, reuse=False)
# NOTE(review): this rebinds `self.sampler` from the method to its result,
# so the method cannot be called again on this instance afterwards.
self.sampler = self.sampler(self.z, self.y)
self.D_, self.D_logits_ = self.discriminator(self.G, self.y, reuse=True)

# Summaries of the discriminator outputs and of the generated images.
self.d_sum = histogram_summary("d", self.D)
self.d__sum = histogram_summary("d_", self.D_)
self.G_sum = image_summary("G", self.G)


def sigmoid_cross_entropy_with_logits(x, y):
    """Call ``tf.nn.sigmoid_cross_entropy_with_logits`` across TF versions.

    Args:
        x: logits tensor.
        y: target tensor, passed as ``labels`` or, on releases that do not
            accept that keyword, as ``targets``.

    Returns:
        Whatever ``tf.nn.sigmoid_cross_entropy_with_logits`` returns.
    """
    try:
        return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, labels=y)
    except TypeError:
        # Only an unknown `labels` keyword (older TensorFlow signature)
        # triggers the fallback; any other error propagates unchanged
        # instead of being masked by a second call.
        return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, targets=y)


# Discriminator loss on real inputs: logits are scored against all-ones.
self.d_loss_real = tf.reduce_mean(
    sigmoid_cross_entropy_with_logits(self.D_logits, tf.ones_like(self.D)))
# Discriminator loss on generated inputs: logits are scored against zeros.
self.d_loss_fake = tf.reduce_mean(
    sigmoid_cross_entropy_with_logits(self.D_logits_, tf.zeros_like(self.D_)))
# Generator loss: the same generated-input logits scored against ones.
self.g_loss = tf.reduce_mean(
    sigmoid_cross_entropy_with_logits(self.D_logits_, tf.ones_like(self.D_)))
原创粉丝点击