DeepLearning & TensorFlow 学习笔记4：基于 MNIST 数据集的 DCGAN

来源:互联网 发布:ipcms录像软件 编辑:程序博客网 时间:2024/06/13 07:17

1. Introduction

利用 MNIST 数据集训练 DCGAN 网络，生成手写数字图像。

2. Source code

#encoding:utf-8""" Deep Convolutional Generative Adversarial Network (DCGAN).Using deep convolutional generative adversarial networks (DCGAN) to generatedigit images from a noise distribution.References:    - Unsupervised representation learning with deep convolutional generative    adversarial networks. A Radford, L Metz, S Chintala. arXiv:1511.06434.Links:    - [DCGAN Paper](https://arxiv.org/abs/1511.06434).    - [MNIST Dataset](http://yann.lecun.com/exdb/mnist/).Author: Aymeric DamienProject: https://github.com/aymericdamien/TensorFlow-Examples/"""from __future__ import division, print_function, absolute_importimport scipy.miscimport matplotlib.pyplot as pltimport numpy as npimport tensorflow as tfimport PIL.Image as Image# Import MNIST datafrom tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets("/tmp/data/", one_hot=True)# Training Paramsnum_steps = 100  #20000batch_size = 32# Network Paramsimage_dim = 784 # 28*28 pixels * 1 channelgen_hidden_dim = 256disc_hidden_dim = 256noise_dim = 200 # Noise data pointslog_dir = "mnist_logs"# Generator Network# Input: Noise, Output: Imagedef generator(x, reuse=False):    with tf.variable_scope('Generator', reuse=reuse):        # TensorFlow Layers automatically create variables and calculate their        # shape, based on the input.        
x = tf.layers.dense(x, units=6 * 6 * 128)  #全连接层  输出维度为units=4608        x = tf.nn.tanh(x)   #计算x的正切值        # Reshape to a 4-D array of images: (batch, height, width, channels)        # New shape: (batch, 6, 6, 128)        x = tf.reshape(x, shape=[-1, 6, 6, 128])        # Deconvolution, image shape: (batch, 14, 14, 64)        x = tf.layers.conv2d_transpose(x, 64, 4, strides=2)        # Deconvolution, image shape: (batch, 28, 28, 1)        x = tf.layers.conv2d_transpose(x, 1, 2, strides=2)        # Apply sigmoid to clip values between 0 and 1        x = tf.nn.sigmoid(x)        return x# Discriminator Network# Input: Image, Output: Prediction Real/Fake Imagedef discriminator(x, reuse=False):  # shape:[None, 28, 28, 1]    with tf.variable_scope('Discriminator', reuse=reuse):        # Typical convolutional neural network to classify images.        x = tf.layers.conv2d(x, 64, 5)  # shape:[None, 24, 24, 64]        x = tf.nn.tanh(x)   # shape:[None, 24, 24, 1]        x = tf.layers.average_pooling2d(x, 2, 2)   # shape:[None, 12, 12, 64]        x = tf.layers.conv2d(x, 128, 5)  # shape:[None, 8, 8, 128]        x = tf.nn.tanh(x)        x = tf.layers.average_pooling2d(x, 2, 2)  # shape:[None, 4, 4, 128]        x = tf.contrib.layers.flatten(x)  # shape:[None, 4096]   4*4*128=4096        x = tf.layers.dense(x, 1024)  #shape:  [None,1024]        x = tf.nn.tanh(x)    #shape: [None,1024]        # Output 2 classes: Real and Fake images        x = tf.layers.dense(x, 2)    #shape: [None,2]    return x# Build Networks# Network Inputsnoise_input = tf.placeholder(tf.float32, shape=[None, noise_dim])real_image_input = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])# Build Generator Networkgen_sample = generator(noise_input)   #shape: [None, 28, 28, 1]# Build 2 Discriminator Networks (one from noise input, one from generated samples)disc_real = discriminator(real_image_input)    #shape:[None,2]disc_fake = discriminator(gen_sample, reuse=True)   #shape: [None,2]disc_concat = 
tf.concat([disc_real, disc_fake], axis=0)   #shpae:  [2*None,2]# Build the stacked generator/discriminatorstacked_gan = discriminator(gen_sample, reuse=True)# Build Targets (real or fake images)disc_target = tf.placeholder(tf.int32, shape=[None])gen_target = tf.placeholder(tf.int32, shape=[None])# Build Lossdisc_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(    logits=disc_concat, labels=disc_target))gen_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(    logits=stacked_gan, labels=gen_target))# Build Optimizersoptimizer_gen = tf.train.AdamOptimizer(learning_rate=0.001)optimizer_disc = tf.train.AdamOptimizer(learning_rate=0.001)# Training Variables for each optimizer# By default in TensorFlow, all variables are updated by each optimizer, so we# need to precise for each one of them the specific variables to update.# Generator Network Variablesgen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Generator')# Discriminator Network Variablesdisc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Discriminator')# Create training operationstrain_gen = optimizer_gen.minimize(gen_loss, var_list=gen_vars)train_disc = optimizer_disc.minimize(disc_loss, var_list=disc_vars)# Initialize the variables (i.e. 
assign their default value)init = tf.global_variables_initializer()# Start trainingwith tf.Session() as sess:    train_writer = tf.summary.FileWriter(log_dir + '/train', sess.graph)    # Run the initializer    sess.run(init)    for i in range(1, num_steps+1):        # Prepare Input Data        # Get the next batch of MNIST data (only images are needed, not labels)        batch_x, _ = mnist.train.next_batch(batch_size)        batch_x = np.reshape(batch_x, newshape=[-1, 28, 28, 1])        # Generate noise to feed to the generator        z = np.random.uniform(-1., 1., size=[batch_size, noise_dim])        # Prepare Targets (Real image: 1, Fake image: 0)        # The first half of data fed to the generator are real images,        # the other half are fake images (coming from the generator).        batch_disc_y = np.concatenate(            [np.ones([batch_size]), np.zeros([batch_size])], axis=0)        # Generator tries to fool the discriminator, thus targets are 1.        batch_gen_y = np.ones([batch_size])        # Training        feed_dict = {real_image_input: batch_x, noise_input: z,                     disc_target: batch_disc_y, gen_target: batch_gen_y}        _, _, gl, dl = sess.run([train_gen, train_disc, gen_loss, disc_loss],                                feed_dict=feed_dict)        if i % 100 == 0 or i == 1:            print('Step %i: Generator Loss: %f, Discriminator Loss: %f' % (i, gl, dl))    # Generate images from noise, using the generator network.    f, a = plt.subplots(4, 10, figsize=(10, 4))    for i in range(10):        # Noise input.        z = np.random.uniform(-1., 1., size=[4, noise_dim])        g = sess.run(gen_sample, feed_dict={noise_input: z})        print('g.size: ')        fig_count=0;        for j in range(4):            # Generate image from noise. Extend to 3 channels for matplot figure.            
img = np.reshape(np.repeat(g[j][:, :, np.newaxis], 3, axis=2),newshape=(28, 28, 3))            a[j][i].imshow(img)            #############            print("save image")            scipy.misc.imsave('./gen_samp/'+str(i)+str(j)+'.jpg', img)            #scipy.misc.imsave('restmp.jpg', img)    f.show()    plt.draw()    plt.waitforbuttonpress()
阅读全文
0 0
原创粉丝点击