Image-to-Image Translation with Conditional Adversarial Networks
https://arxiv.org/pdf/1611.07004.pdf
TensorFlow implementation on GitHub:
https://github.com/yenchenlin/pix2pix-tensorflow
Background:
U-Net: Convolutional Networks for Biomedical Image Segmentation
Generator
The generator maps a high-resolution input image to a high-resolution output image. Many previous methods use an encoder-decoder structure for the generator, shown below: the input is progressively downsampled to a bottleneck feature vector, and the process is then reversed, upsampling back to the output image.
The paper argues that input and output share the same underlying structure and differ only in surface appearance. In image colorization, for example, input and output have the same local edge structure. The paper therefore adopts a U-Net architecture, shown below: features from the downsampling path are concatenated into the upsampling path, i.e., the activations of layer i are concatenated with those of layer n−i, where n is the total number of layers.
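The skip connection is just a channel-wise concatenation of the two feature maps; a minimal NumPy sketch (the shapes here are illustrative, not the paper's exact layer sizes):

```python
import numpy as np

# Hypothetical activations for one image, shape (height, width, channels):
decoder_feat = np.zeros((16, 16, 512))  # upsampled features at decoder layer n-i
encoder_feat = np.zeros((16, 16, 512))  # saved features from encoder layer i

# U-Net skip connection: concatenate along the channel axis,
# doubling the channel count seen by the next decoder layer.
skip = np.concatenate([decoder_feat, encoder_feat], axis=-1)
print(skip.shape)  # (16, 16, 1024)
```

This channel doubling is exactly why the U-Net decoder listing below shows CD1024/C1024 where the plain decoder has CD512/C512.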
Architecture comparison
The encoder-decoder structure is:
encoder:
C64-C128-C256-C512-C512-C512-C512-C512
decoder:
CD512-CD512-CD512-C512-C512-C256-C128-C64
U-Net decoder:
CD512-CD1024-CD1024-C1024-C1024-C512-C256-C128
Ck denotes a Convolution-BatchNorm-ReLU layer with k filters; CDk denotes a Convolution-BatchNorm-Dropout-ReLU layer with a dropout rate of 0.5. After the last decoder layer, a convolution maps to 3 output channels (a color image), followed by a tanh activation. BatchNorm is not applied to the first layer C64. All ReLUs in the encoder are leaky with slope 0.2; the ReLUs in the decoder are not leaky.
The U-Net encoder is identical to the one above; the only difference is the skip connections between each encoder layer i and decoder layer n−i. Each skip connection concatenates the activations of layer i onto layer n−i, which changes the number of channels in the decoder, as the U-Net decoder listing above shows.
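Since every encoder layer is a 4×4 convolution with stride 2, each one halves the spatial size; a small pure-Python check of a 256×256 input flowing through the eight encoder layers listed above:

```python
# Channel counts from the encoder listing C64-C128-C256-C512-C512-C512-C512-C512.
channels = [64, 128, 256, 512, 512, 512, 512, 512]

size = 256  # input is 256 x 256
shapes = []
for c in channels:
    size //= 2  # a stride-2 convolution halves height and width
    shapes.append((size, size, c))

for i, s in enumerate(shapes, 1):
    print(f"e{i}: {s}")
# The bottleneck e8 is a 1 x 1 x 512 feature vector.
```

These shapes match the comments in the generator code below (with gf_dim = 64).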
Generator code:
def generator(self, image, y=None):
    with tf.variable_scope("generator") as scope:
        s = self.output_size
        s2, s4, s8, s16, s32, s64, s128 = int(s/2), int(s/4), int(s/8), \
            int(s/16), int(s/32), int(s/64), int(s/128)

        # image is (256 x 256 x input_c_dim)
        e1 = conv2d(image, self.gf_dim, name='g_e1_conv')
        # e1 is (128 x 128 x self.gf_dim)
        e2 = self.g_bn_e2(conv2d(lrelu(e1), self.gf_dim*2, name='g_e2_conv'))
        # e2 is (64 x 64 x self.gf_dim*2)
        e3 = self.g_bn_e3(conv2d(lrelu(e2), self.gf_dim*4, name='g_e3_conv'))
        # e3 is (32 x 32 x self.gf_dim*4)
        e4 = self.g_bn_e4(conv2d(lrelu(e3), self.gf_dim*8, name='g_e4_conv'))
        # e4 is (16 x 16 x self.gf_dim*8)
        e5 = self.g_bn_e5(conv2d(lrelu(e4), self.gf_dim*8, name='g_e5_conv'))
        # e5 is (8 x 8 x self.gf_dim*8)
        e6 = self.g_bn_e6(conv2d(lrelu(e5), self.gf_dim*8, name='g_e6_conv'))
        # e6 is (4 x 4 x self.gf_dim*8)
        e7 = self.g_bn_e7(conv2d(lrelu(e6), self.gf_dim*8, name='g_e7_conv'))
        # e7 is (2 x 2 x self.gf_dim*8)
        e8 = self.g_bn_e8(conv2d(lrelu(e7), self.gf_dim*8, name='g_e8_conv'))
        # e8 is (1 x 1 x self.gf_dim*8)

        self.d1, self.d1_w, self.d1_b = deconv2d(tf.nn.relu(e8),
            [self.batch_size, s128, s128, self.gf_dim*8], name='g_d1', with_w=True)
        d1 = tf.nn.dropout(self.g_bn_d1(self.d1), 0.5)
        d1 = tf.concat([d1, e7], 3)
        # d1 is (2 x 2 x self.gf_dim*8*2)
        self.d2, self.d2_w, self.d2_b = deconv2d(tf.nn.relu(d1),
            [self.batch_size, s64, s64, self.gf_dim*8], name='g_d2', with_w=True)
        d2 = tf.nn.dropout(self.g_bn_d2(self.d2), 0.5)
        d2 = tf.concat([d2, e6], 3)
        # d2 is (4 x 4 x self.gf_dim*8*2)
        self.d3, self.d3_w, self.d3_b = deconv2d(tf.nn.relu(d2),
            [self.batch_size, s32, s32, self.gf_dim*8], name='g_d3', with_w=True)
        d3 = tf.nn.dropout(self.g_bn_d3(self.d3), 0.5)
        d3 = tf.concat([d3, e5], 3)
        # d3 is (8 x 8 x self.gf_dim*8*2)
        self.d4, self.d4_w, self.d4_b = deconv2d(tf.nn.relu(d3),
            [self.batch_size, s16, s16, self.gf_dim*8], name='g_d4', with_w=True)
        d4 = self.g_bn_d4(self.d4)
        d4 = tf.concat([d4, e4], 3)
        # d4 is (16 x 16 x self.gf_dim*8*2)
        self.d5, self.d5_w, self.d5_b = deconv2d(tf.nn.relu(d4),
            [self.batch_size, s8, s8, self.gf_dim*4], name='g_d5', with_w=True)
        d5 = self.g_bn_d5(self.d5)
        d5 = tf.concat([d5, e3], 3)
        # d5 is (32 x 32 x self.gf_dim*4*2)
        self.d6, self.d6_w, self.d6_b = deconv2d(tf.nn.relu(d5),
            [self.batch_size, s4, s4, self.gf_dim*2], name='g_d6', with_w=True)
        d6 = self.g_bn_d6(self.d6)
        d6 = tf.concat([d6, e2], 3)
        # d6 is (64 x 64 x self.gf_dim*2*2)
        self.d7, self.d7_w, self.d7_b = deconv2d(tf.nn.relu(d6),
            [self.batch_size, s2, s2, self.gf_dim], name='g_d7', with_w=True)
        d7 = self.g_bn_d7(self.d7)
        d7 = tf.concat([d7, e1], 3)
        # d7 is (128 x 128 x self.gf_dim*1*2)
        self.d8, self.d8_w, self.d8_b = deconv2d(tf.nn.relu(d7),
            [self.batch_size, s, s, self.output_c_dim], name='g_d8', with_w=True)
        # d8 is (256 x 256 x output_c_dim)

        return tf.nn.tanh(self.d8)
Discriminator
PatchGAN
The L1 loss mainly corrects low-frequency content. To capture high-frequency detail, it suffices to restrict attention to the structure of local image patches. The paper therefore designs a discriminator that penalizes structure at the patch scale: it classifies each N×N patch of an image as real or fake, runs convolutionally across the whole image, and averages all patch responses to produce the final output.
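A PatchGAN discriminator is fully convolutional, so its raw output is a grid of per-patch logits rather than a single score; averaging the grid gives the image-level decision. A minimal NumPy sketch (the grid values here are made up):

```python
import numpy as np

def patch_average(patch_logits):
    """Average per-patch real/fake probabilities into one image-level score."""
    probs = 1.0 / (1.0 + np.exp(-patch_logits))  # sigmoid applied to each patch logit
    return probs.mean()

# Hypothetical 30 x 30 grid of patch logits, the kind of output a
# 70x70 PatchGAN produces for a 256 x 256 input.
logits = np.zeros((30, 30))
print(patch_average(logits))  # 0.5: every patch is maximally uncertain
```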
The authors compare different patch sizes N experimentally:
As the figure shows, with only the L1 loss, or with a 1×1 patch (PixelGAN), the generated images contain little texture; a 16×16 patch yields sharper images with more texture detail but exhibits tiling artifacts; 70×70 gives results comparable to the full 286×286 ImageGAN, with rich detail and no obvious tiling artifacts, while the 286×286 discriminator actually obtains a lower FCN score.
Unless otherwise noted, all experiments in the paper therefore use a 70×70 patch size.
The 70×70 discriminator architecture is:
C64-C128-C256-C512
After the last layer, a convolution maps to a 1-dimensional output, followed by a sigmoid. BatchNorm is not applied to the first layer C64; all ReLUs are leaky with slope 0.2.
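The "patch size" N is simply the receptive field of one output unit of this discriminator. It can be verified with the standard receptive-field recurrence, walking from the output back to the input (kernel size 4 everywhere; in the paper, strides are 2 for C64/C128/C256 and 1 for C512 and the final 1-channel convolution):

```python
def receptive_field(layers):
    """layers: list of (kernel, stride) tuples, ordered input -> output."""
    r = 1  # one output unit
    for kernel, stride in reversed(layers):
        # Each layer expands the receptive field: r_in = r_out * s + (k - s)
        r = r * stride + (kernel - stride)
    return r

# C64, C128, C256 with stride 2; C512 and the final convolution with stride 1.
layers = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]
print(receptive_field(layers))  # 70
```

This is where the "70×70 PatchGAN" name comes from.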
The 286×286 discriminator architecture is:
C64-C128-C256-C512-C512-C512
Discriminator code:
def discriminator(self, image, y=None, reuse=False):
    with tf.variable_scope("discriminator") as scope:
        # image is 256 x 256 x (input_c_dim + output_c_dim)
        if reuse:
            tf.get_variable_scope().reuse_variables()
        else:
            assert tf.get_variable_scope().reuse == False

        h0 = lrelu(conv2d(image, self.df_dim, name='d_h0_conv'))
        # h0 is (128 x 128 x self.df_dim)
        h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim*2, name='d_h1_conv')))
        # h1 is (64 x 64 x self.df_dim*2)
        h2 = lrelu(self.d_bn2(conv2d(h1, self.df_dim*4, name='d_h2_conv')))
        # h2 is (32 x 32 x self.df_dim*4)
        h3 = lrelu(self.d_bn3(conv2d(h2, self.df_dim*8, d_h=1, d_w=1, name='d_h3_conv')))
        # h3 is (32 x 32 x self.df_dim*8): stride 1 keeps the spatial size
        h4 = linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'd_h3_lin')

        return tf.nn.sigmoid(h4), h4
Loss Function
Earlier conditional-GAN work showed that mixing the adversarial objective with a traditional loss improves results; e.g., ref. [29] adds an L2 term. The discriminator's job is unchanged, but the generator must not only fool the discriminator, it must also produce outputs as close as possible to the ground-truth images. Following this idea, the paper adds an L1 rather than L2 term to the generator objective, because L1 encourages less blurring, keeping the generated image close to the target output. The formulas are as follows:
The cGAN objective is:
L_cGAN(G, D) = E_{x,y}[log D(x, y)] + E_{x,z}[log(1 − D(x, G(x, z)))]
The L1 term is:
L_L1(G) = E_{x,y,z}[ ||y − G(x, z)||_1 ]
The full objective is:
G* = arg min_G max_D L_cGAN(G, D) + λ L_L1(G)
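A toy NumPy version of the generator's side of this objective, combining the non-saturating GAN term with λ·L1 (the input numbers are made up):

```python
import numpy as np

def generator_loss(d_logits_fake, real_b, fake_b, l1_lambda=100.0):
    """GAN term: sigmoid cross-entropy pushing D's output on fakes toward
    the 'real' label 1, i.e. softplus(-logits); plus lambda times the
    mean absolute error between the target and the generated image."""
    gan_term = np.mean(np.log1p(np.exp(-d_logits_fake)))
    l1_term = np.mean(np.abs(real_b - fake_b))
    return gan_term + l1_lambda * l1_term

d_logits = np.array([0.0])        # D is undecided on the fake: term = log 2
real_b = np.array([1.0, 0.0])     # target pixels
fake_b = np.array([0.5, 0.5])     # generated pixels, mean |error| = 0.5
print(generator_loss(d_logits, real_b, fake_b))  # log(2) + 100*0.5 ≈ 50.693
```

With λ = 100 (the L1_lambda default in the code below), the L1 term dominates the numeric value of the loss, which matches the paper's intent of anchoring the output to the ground truth while the GAN term supplies high-frequency realism.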
The loss portion of the code:
self.real_data = tf.placeholder(tf.float32,
    [self.batch_size, self.image_size, self.image_size,
     self.input_c_dim + self.output_c_dim],
    name='real_A_and_B_images')
self.real_B = self.real_data[:, :, :, self.input_c_dim:self.input_c_dim + self.output_c_dim]
self.real_A = self.real_data[:, :, :, :self.input_c_dim]

self.fake_B = self.generator(self.real_A)
self.real_AB = tf.concat([self.real_A, self.real_B], 3)
self.fake_AB = tf.concat([self.real_A, self.fake_B], 3)

self.D, self.D_logits = self.discriminator(self.real_AB, reuse=False)
self.D_, self.D_logits_ = self.discriminator(self.fake_AB, reuse=True)

self.fake_B_sample = self.sampler(self.real_A)

self.d_sum = tf.summary.histogram("d", self.D)
self.d__sum = tf.summary.histogram("d_", self.D_)
self.fake_B_sum = tf.summary.image("fake_B", self.fake_B)

self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=self.D_logits, labels=tf.ones_like(self.D)))
self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=self.D_logits_, labels=tf.zeros_like(self.D_)))
self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=self.D_logits_, labels=tf.ones_like(self.D_))) \
    + self.L1_lambda * tf.reduce_mean(tf.abs(self.real_B - self.fake_B))