文章标题
来源:互联网 发布:管家婆软件如何操作 编辑:程序博客网 时间:2024/06/08 05:37
GAN + MNIST
# -*- coding: utf-8 -*-# @Time : 2017/11/2 18:02# @File : MNISTGAN.py# @Author : Zhiwei Zhong# @Function :from __future__ import divisionimport tensorflow as tfimport numpy as npfrom tensorflow.examples.tutorials.mnist import input_dataimport matplotlib.pyplot as plttf.set_random_seed(1)np.random.seed(1)BATCH_SIZE = 64LR_G = 0.001LR_D = 0.001MEAN = [5, 5]SIGMA = [[1, 0], [0, 1]]TRAIN_STEP = 2000000PRINT_STEP = 5000EPOCHS = 30mnist = input_data.read_data_sets('./mnist', one_hot=True)print(mnist.train.images.shape)def fake_data(): return np.random.normal(0, 1, BATCH_SIZE * 100).reshape(BATCH_SIZE, 100) # return np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))def noise(): return np.random.normal(0, 1, 5 * 784).reshape(5, 784)with tf.variable_scope("GEN"): G_in = tf.placeholder(tf.float32, [None, 100], name="G_in") G_l1 = tf.layers.dense(G_in, 128, name="G_Layer1") # leak_relu 系数为0.01 G_l1 = tf.maximum(0.01*G_l1, G_l1) G_l1 = tf.layers.dropout(G_l1, rate=0.2) # 随机失活 # G_l2 = tf.layers.dense(G_l1, 128, tf.nn.sigmoid, name="G_Layer2") G_out = tf.layers.dense(G_l1, 28*28, tf.nn.sigmoid, name="G_out")with tf.variable_scope("DISC"): D_in = tf.placeholder(tf.float32, [None, 28*28], name="real_data") D_l1 = tf.layers.dense(D_in, 128, name="D_Layer1") D_l1 = tf.maximum(D_l1, 0.01*D_l1) # D_l1 = tf.layers.dropout(D_l1, 0.2) # D_l2 = tf.layers.dense(D_l1, 64, tf.nn.tanh, name="D_Layer2") D_Real_Out = tf.layers.dense(D_l1, 1, tf.nn.sigmoid, name="D_Out") D_l3 = tf.layers.dense(G_out, 128, name="D_Layer1", reuse=True) D_l3 = tf.maximum(D_l3, 0.01*D_l3) # D_l4 = tf.layers.dense(D_l3, 64, tf.nn.relu, name="D_Layer2", reuse=True) D_Fake_Out = tf.layers.dense(D_l3, 1, tf.nn.sigmoid, name="D_Out", reuse=True)D_Loss = -tf.reduce_mean(tf.log(D_Real_Out) + tf.log(1 - D_Fake_Out))G_Loss = tf.reduce_mean(tf.log(1 - D_Fake_Out))train_D = tf.train.AdamOptimizer(LR_D).minimize( D_Loss, var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="DISC"))train_G = tf.train.AdamOptimizer(LR_G).minimize( G_Loss, var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="GEN"))sess = tf.Session()sess.run(tf.global_variables_initializer())saver = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="GEN"))# f, a = plt.subplots(1, 5, figsize=(5, 2))# plt.ion()samples = []for e in range(EPOCHS): for _ in range(mnist.train.num_examples // BATCH_SIZE): # 向下取整 b_x, by = mnist.train.next_batch(BATCH_SIZE) b_x = b_x.reshape((BATCH_SIZE, 784)) # data_x = b_x * 2 - 1 # 缩放到-1 1之间 data_z = fake_data() # b_x[: 5] = sess.run(G_out, {G_in: data_z})[: 5] data_x = b_x sess.run(train_D, feed_dict={G_in: data_z, D_in: data_x}) sess.run(train_G, feed_dict={G_in: data_z}) saver.save(sess, "./checkpoints/generator.ckpt") DiscFake, DiscReal, DLoss, GLoss = sess.run([D_Fake_Out, D_Real_Out, D_Loss, G_Loss], {G_in: data_z, D_in: data_x}) print("EPOCH:{}, DISC_FAKE:{},DISC_REAL:{}, D_LOSS:{}, G_LOSS:{}".format(e, np.mean(DiscFake), np.mean(DiscReal), np.mean(DLoss),np.mean(GLoss))) data_z = fake_data() pic = sess.run(G_out, {G_in: data_z})[: 5] """try: for i in range(5): a[i].clear() a[i].imshow(np.reshape(pic[i], (28, 28)), cmap='gray') a[i].set_xticks(()) a[i].set_yticks(()) plt.draw() plt.pause(0.01) except: pass""" samples.append(pic)import pickle# 将sample的生成数据记录下来with open('train_samples.pkl', 'wb') as f: pickle.dump(samples, f)with open('train_samples.pkl', 'rb') as f: samples = pickle.load(f)def view_samples(epoch, samples): """ epoch代表第几次迭代的图像 samples为我们的采样结果 """ fig, axes = plt.subplots(figsize=(7, 7), nrows=5, ncols=1) plt.ion() for i in range(5): # for j in range(5): axes[i].imshow(np.reshape(samples[epoch][i], (28, 28)), cmap='gray') #plt.draw() plt.ioff() plt.show()# 生成新的图片saver = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="GEN"))with tf.Session() as sess: saver.restore(sess, tf.train.latest_checkpoint('checkpoints')) sample_noise = np.random.uniform(-1, 1, size=(25, 100)) gen_samples = sess.run(G_out, {G_in: sample_noise})view_samples(10, gen_samples)
阅读全文
0 0
- 文章标题文章标题文章标题文章标题文章标题文章标题文章标题文章标题文章标题文章标题文章标题文章标题文章标题文章标题文章标题文章标题文章标题
- 文章标题
- 文章标题
- 文章标题
- 文章标题 文章标题 文章标题 文章标题
- 文章标题
- 文章标题
- 文章标题
- 文章标题
- 文章标题
- 文章标题
- 文章标题
- 文章标题
- 文章标题
- 文章标题
- 文章标题
- 文章标题
- 文章标题
- 最全最好用的Android Studio插件整理
- Git查看、删除、重命名远程分支和tag
- 8.接口与抽象类
- 为什么有些手机连接代理后上不了--Fiddler工具安装介绍
- android 发布版本 自动更新 注意事项
- 文章标题
- Spring Cloud推荐开源项目(AG-ADMIN)
- machine learning
- 实用的shell命令
- Linux学习笔记(2)--Linux安装
- GitHub 上最火的 Java 框架
- 趣图:现实和理想的差距
- mfc c++ 多线程AfxBeginThread 例子( 一)
- OpenGL场景截取后存储(BMP图片)