文章标题

来源:互联网 发布:管家婆软件如何操作 编辑:程序博客网 时间:2024/06/08 05:37

GAN + MNIST

# -*- coding: utf-8 -*-
# @Time     : 2017/11/2 18:02
# @File     : MNISTGAN.py
# @Author   : Zhiwei Zhong
# @Function : Train a fully connected GAN on MNIST (TensorFlow 1.x), checkpoint
#             the generator each epoch, then restore it and display new digits.
import os
import pickle

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt

tf.set_random_seed(1)
np.random.seed(1)

BATCH_SIZE = 64
LR_G = 0.001
LR_D = 0.001
MEAN = [5, 5]               # not referenced by this script
SIGMA = [[1, 0], [0, 1]]    # not referenced by this script
TRAIN_STEP = 2000000        # not referenced by this script
PRINT_STEP = 5000           # not referenced by this script
EPOCHS = 30
NOISE_DIM = 100             # length of each generator input noise vector
IMG_DIM = 28 * 28           # flattened MNIST image size
CKPT_DIR = "./checkpoints"

mnist = input_data.read_data_sets('./mnist', one_hot=True)
print(mnist.train.images.shape)


def fake_data(n=BATCH_SIZE):
    """Return an (n, NOISE_DIM) array of standard-normal generator input noise."""
    return np.random.normal(0, 1, n * NOISE_DIM).reshape(n, NOISE_DIM)
    # alternative: np.random.uniform(-1, 1, size=(n, NOISE_DIM))


def noise():
    """Return a (5, 784) array of standard-normal values (not called in this script)."""
    return np.random.normal(0, 1, 5 * IMG_DIM).reshape(5, IMG_DIM)


# tf.layers.dropout defaults to training=False, which makes it an identity op.
# This flag is False unless fed, and the training steps below feed True.
is_training = tf.placeholder_with_default(False, shape=(), name="is_training")

with tf.variable_scope("GEN"):
    G_in = tf.placeholder(tf.float32, [None, NOISE_DIM], name="G_in")
    G_l1 = tf.layers.dense(G_in, 128, name="G_Layer1")
    G_l1 = tf.maximum(0.01 * G_l1, G_l1)          # leaky ReLU, slope 0.01
    G_l1 = tf.layers.dropout(G_l1, rate=0.2, training=is_training)
    G_out = tf.layers.dense(G_l1, IMG_DIM, tf.nn.sigmoid, name="G_out")


def _discriminator(x, reuse):
    """Build (or reuse) the discriminator tower on x and return raw (batch, 1) logits."""
    h = tf.layers.dense(x, 128, name="D_Layer1", reuse=reuse)
    h = tf.maximum(h, 0.01 * h)                   # leaky ReLU, slope 0.01
    return tf.layers.dense(h, 1, name="D_Out", reuse=reuse)


with tf.variable_scope("DISC"):
    D_in = tf.placeholder(tf.float32, [None, IMG_DIM], name="real_data")
    # The same weights score real images and generator output.
    D_Real_Logit = _discriminator(D_in, reuse=None)
    D_Fake_Logit = _discriminator(G_out, reuse=True)
    D_Real_Out = tf.sigmoid(D_Real_Logit)
    D_Fake_Out = tf.sigmoid(D_Fake_Logit)

# The losses are computed from logits, so log(0) can never yield inf/NaN.
# D loss = -[log D(x) + log(1 - D(G(z)))], i.e. BCE with real=1 and fake=0.
D_Loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(D_Real_Logit), logits=D_Real_Logit)
    + tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(D_Fake_Logit), logits=D_Fake_Logit))
# G loss uses the non-saturating form -log D(G(z)). Minimizing log(1 - D(G(z)))
# gives vanishing gradients early on, when D confidently rejects the fakes.
G_Loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(D_Fake_Logit), logits=D_Fake_Logit))

# Each optimizer updates only its own network's variables.
train_D = tf.train.AdamOptimizer(LR_D).minimize(
    D_Loss, var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="DISC"))
train_G = tf.train.AdamOptimizer(LR_G).minimize(
    G_Loss, var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="GEN"))

sess = tf.Session()
sess.run(tf.global_variables_initializer())
# Only the generator is checkpointed; that is all the sampling step needs.
saver = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="GEN"))
# Saver.save does not create a missing parent directory.
if not os.path.isdir(CKPT_DIR):
    os.makedirs(CKPT_DIR)

samples = []  # one entry per epoch: the first 5 generated images
for e in range(EPOCHS):
    # Floor division: the final partial batch of each epoch is skipped.
    for _ in range(mnist.train.num_examples // BATCH_SIZE):
        b_x, _labels = mnist.train.next_batch(BATCH_SIZE)
        data_x = b_x.reshape((BATCH_SIZE, IMG_DIM))
        data_z = fake_data()
        sess.run(train_D, feed_dict={G_in: data_z, D_in: data_x, is_training: True})
        sess.run(train_G, feed_dict={G_in: data_z, is_training: True})
    saver.save(sess, os.path.join(CKPT_DIR, "generator.ckpt"))

    # Report on the last batch of the epoch (dropout off: is_training not fed).
    DiscFake, DiscReal, DLoss, GLoss = sess.run(
        [D_Fake_Out, D_Real_Out, D_Loss, G_Loss], {G_in: data_z, D_in: data_x})
    print("EPOCH:{}, DISC_FAKE:{},DISC_REAL:{}, D_LOSS:{}, G_LOSS:{}".format(
        e, np.mean(DiscFake), np.mean(DiscReal), np.mean(DLoss), np.mean(GLoss)))

    pic = sess.run(G_out, {G_in: fake_data()})[:5]
    samples.append(pic)
sess.close()

# Record the per-epoch generated samples, then read them back.
with open('train_samples.pkl', 'wb') as f:
    pickle.dump(samples, f)
# Only file written just above is unpickled; never unpickle untrusted data.
with open('train_samples.pkl', 'rb') as f:
    samples = pickle.load(f)


def view_samples(epoch, samples):
    """Show the first 5 images of samples[epoch] stacked in a 5x1 figure.

    epoch   -- index into `samples` selecting which batch to display
    samples -- sequence of batches; each batch holds at least 5 flat 784-value images
    """
    fig, axes = plt.subplots(figsize=(7, 7), nrows=5, ncols=1)
    plt.ion()
    for i in range(5):
        axes[i].imshow(np.reshape(samples[epoch][i], (28, 28)), cmap='gray')
    plt.ioff()
    plt.show()


# Generate new images from the restored generator.
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint(CKPT_DIR))
    # Sample from the same N(0, 1) noise distribution used during training.
    sample_noise = fake_data(25)
    gen_samples = sess.run(G_out, {G_in: sample_noise})
# gen_samples is one (25, 784) batch; wrap it so samples[0][i] is image i.
view_samples(0, [gen_samples])
原创粉丝点击