tensorflow学习——GAN手写体生成
来源:互联网 发布:端口名称怎么查 编辑:程序博客网 时间:2024/05/29 21:31
import tensorflow as tf
import numpy as np
import pickle
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data/')
# img = mnist.train.images[0]
# plt.imshow(img.reshape((28, 28)), cmap='Greys_r')

# Hyper-parameters
img_size = mnist.train.images[0].shape[0]  # flattened real-image size: 28*28 = 784
noise_size = 100       # dimensionality of the noise fed to the generator
g_units = 128          # generator hidden units
d_units = 128          # discriminator hidden units
alpha = 0.01           # leaky ReLU negative slope
learning_rate = 0.001
smooth = 0.1           # one-sided label smoothing for the discriminator's real labels
batch_size = 64
epochs = 50            # training epochs
n_sample = 25          # number of samples drawn for inspection


class GAN():
    """Minimal fully-connected GAN for MNIST (TensorFlow 1.x graph mode).

    Usage: construct, call inference() once to build the graph, then
    train() / test().  draw_loss(), view_samples() and draw_samples()
    are plotting helpers.
    """

    def __init__(self, img_size, noise_size):
        # Graph inputs: flattened real images and generator noise.
        self.real_img = tf.placeholder(tf.float32, [None, img_size], name='real_img')
        self.noise_img = tf.placeholder(tf.float32, [None, noise_size], name='noise_img')

    @staticmethod
    def get_generator(noise_img, n_units, out_dim, reuse=False, alpha=0.01):
        """Build the generator network.

        noise_img: generator input tensor
        n_units:   hidden-layer width
        out_dim:   output size, here 28*28 = 784
        reuse:     share variables with a previously built generator
        alpha:     leaky ReLU negative slope
        Returns (logits, outputs); outputs are tanh-scaled to (-1, 1).
        """
        with tf.variable_scope("generator", reuse=reuse):
            hidden1 = tf.layers.dense(noise_img, n_units)
            hidden1 = tf.maximum(alpha * hidden1, hidden1)   # leaky ReLU
            hidden1 = tf.layers.dropout(hidden1, rate=0.2)
            logits = tf.layers.dense(hidden1, out_dim)       # shape (?, 784)
            outputs = tf.tanh(logits)                        # shape (?, 784)
            return logits, outputs

    @staticmethod
    def get_discriminator(img, n_units, reuse=False, alpha=0.01):
        """Build the discriminator network.

        img:     image tensor (real or generated), flattened
        n_units: hidden-layer width
        reuse:   share variables (real and fake paths use the same weights)
        alpha:   leaky ReLU negative slope
        Returns (logits, outputs); outputs is the sigmoid real/fake score.
        """
        with tf.variable_scope("discriminator", reuse=reuse):
            hidden1 = tf.layers.dense(img, n_units)
            hidden1 = tf.maximum(alpha * hidden1, hidden1)   # leaky ReLU
            logits = tf.layers.dense(hidden1, 1)             # real/fake decision
            outputs = tf.sigmoid(logits)
            return logits, outputs

    def draw_loss(self, losses):
        """Plot the recorded per-epoch losses (one curve per column)."""
        fig, ax = plt.subplots(figsize=(20, 7))
        losses = np.array(losses)
        plt.plot(losses.T[0], label='Discriminator Total Loss')
        plt.plot(losses.T[1], label='Discriminator Real Loss')
        plt.plot(losses.T[2], label='Discriminator Fake Loss')
        plt.plot(losses.T[3], label='Generator')
        plt.title("Training Losses")
        plt.legend()

    def inference(self):
        """Build the full training graph; returns (d_train_opt, g_train_opt)."""
        # generator, shape (?, 784)
        g_logits, g_outputs = GAN.get_generator(self.noise_img, g_units, img_size)
        # discriminator on real and on generated images (shared weights)
        d_logits_real, d_outputs_real = GAN.get_discriminator(self.real_img, d_units)
        d_logits_fake, d_outputs_fake = GAN.get_discriminator(g_outputs, d_units, reuse=True)

        # Real-image loss with one-sided label smoothing (anti-overconfidence).
        self.d_loss_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=d_logits_real,
                labels=tf.ones_like(d_logits_real)) * (1 - smooth))
        # Fake-image loss.
        self.d_loss_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=d_logits_fake,
                labels=tf.zeros_like(d_logits_fake)))
        # Total discriminator loss.
        self.d_loss = tf.add(self.d_loss_real, self.d_loss_fake)
        # Generator loss: fool the discriminator.  NOTE(fix): label smoothing
        # removed here — it applies to the discriminator's real labels only.
        self.g_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=d_logits_fake,
                labels=tf.ones_like(d_logits_fake)))

        # Sampling ops built ONCE here.  The original rebuilt the generator
        # inside the training loop every epoch, growing the graph each time.
        self.sample_ops = GAN.get_generator(self.noise_img, g_units, img_size,
                                            reuse=True)

        train_vars = tf.trainable_variables()  # all trainable variables
        # Partition variables by scope prefix so each optimizer only
        # updates its own network.
        self.g_vars = [var for var in train_vars if var.name.startswith("generator")]
        self.d_vars = [var for var in train_vars if var.name.startswith("discriminator")]
        d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(
            self.d_loss, var_list=self.d_vars)
        g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(
            self.g_loss, var_list=self.g_vars)
        # Only generator weights are checkpointed (that is all test() needs).
        self.saver = tf.train.Saver(var_list=self.g_vars)
        return d_train_opt, g_train_opt

    def train(self, d_train_opt, g_train_opt):
        """Run the training loop; returns the list of per-epoch losses.

        Also checkpoints the generator and pickles sampled images each epoch.
        """
        samples = []   # generated samples collected per epoch
        losses = []    # (d_total, d_real, d_fake, g) per epoch
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for e in range(epochs):
                for batch_i in range(mnist.train.num_examples // batch_size):
                    batch = mnist.train.next_batch(batch_size)
                    batch_images = batch[0].reshape((batch_size, 784))
                    # Scale pixels to (-1, 1) to match the generator's tanh
                    # output; real and fake share the discriminator weights.
                    batch_images = batch_images * 2 - 1
                    # Generator input noise.
                    batch_noise = np.random.uniform(-1, 1,
                                                    size=(batch_size, noise_size))
                    # Run optimizers.
                    _ = sess.run(d_train_opt,
                                 feed_dict={self.real_img: batch_images,
                                            self.noise_img: batch_noise})
                    _ = sess.run(g_train_opt,
                                 feed_dict={self.noise_img: batch_noise})

                # End-of-epoch loss evaluation (on the last batch).
                train_loss_d = sess.run(self.d_loss,
                                        feed_dict={self.real_img: batch_images,
                                                   self.noise_img: batch_noise})
                train_loss_d_real = sess.run(self.d_loss_real,
                                             feed_dict={self.real_img: batch_images,
                                                        self.noise_img: batch_noise})
                train_loss_d_fake = sess.run(self.d_loss_fake,
                                             feed_dict={self.real_img: batch_images,
                                                        self.noise_img: batch_noise})
                train_loss_g = sess.run(self.g_loss,
                                        feed_dict={self.noise_img: batch_noise})
                print("Epoch {}/{}...".format(e + 1, epochs),
                      "Discriminator Loss: {:.4f}(Real: {:.4f} + Fake: {:.4f})...".format(
                          train_loss_d, train_loss_d_real, train_loss_d_fake),
                      "Generator Loss: {:.4f}".format(train_loss_g))
                losses.append((train_loss_d, train_loss_d_real,
                               train_loss_d_fake, train_loss_g))

                # Draw samples for later inspection, reusing the sampler ops
                # built in inference() (no new graph nodes per epoch).
                sample_noise = np.random.uniform(-1, 1,
                                                 size=(n_sample, noise_size))
                gen_samples = sess.run(self.sample_ops,
                                       feed_dict={self.noise_img: sample_noise})
                samples.append(gen_samples)  # (logits, outputs) pair
                # Checkpoint the generator after every epoch.
                self.saver.save(sess, './checkpoints/generator_1.ckpt')
        with open('train_samples_1.pkl', 'wb') as f:
            pickle.dump(samples, f)
        return losses

    @staticmethod
    def view_samples(samples):
        """Show up to 25 generated images on a 5x5 grid; returns (fig, axes)."""
        fig, axes = plt.subplots(figsize=(7, 7), nrows=5, ncols=5,
                                 sharey=True, sharex=True)
        for ax, img in zip(axes.flatten(), samples):
            ax.axis('off')
            ax.imshow(img.reshape((28, 28)), cmap='Greys_r')
        return fig, axes

    def draw_samples(self):
        """Plot pickled samples from a few selected epochs."""
        # Indices must stay below epochs (50); [i][1] selects the tanh
        # outputs of the (logits, outputs) pair stored per epoch.
        epoch_idx = [0, 5, 10, 20, 40]
        show_imgs = []
        with open('train_samples_1.pkl', 'rb') as f:
            samples = pickle.load(f)
        for i in epoch_idx:
            show_imgs.append(samples[i][1])
        rows, cols = 10, 25
        fig, axes = plt.subplots(figsize=(30, 12), nrows=rows, ncols=cols,
                                 sharex=True, sharey=True)
        for sample, ax_row in zip(show_imgs, axes):
            for img, ax in zip(sample, ax_row):
                ax.imshow(img.reshape((28, 28)), cmap='Greys_r')
                ax.axis('off')

    def test(self):
        """Restore the latest generator checkpoint and sample 25 images."""
        with tf.Session() as sess:
            self.saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))
            sample_noise = np.random.uniform(-1, 1, size=(25, 100))
            # Reuse the sampler ops built in inference() instead of
            # constructing a fresh generator here.
            logits, gen_samples = sess.run(self.sample_ops,
                                           feed_dict={self.noise_img: sample_noise})
            GAN.view_samples(gen_samples)
            return gen_samples


gan = GAN(img_size, noise_size)
d_train_opt, g_train_opt = gan.inference()
#losses = gan.train(d_train_opt, g_train_opt)
#gan.draw_loss(losses)
#plt.savefig("losses_2.jpg")
#gan.draw_samples()
gen_samples = gan.test()
阅读全文
0 0
- tensorflow学习——GAN手写体生成
- tensorflow学习——DCGAN手写体生成
- 生成对抗网络GAN入门——生成mnist手写体
- CNN学习(三)—Tensorflow 进行MNIST手写体识别
- Tensorflow学习:MINIST手写体
- 深度学习笔记——TensorFlow学习笔记(三)使用TensorFlow实现的神经网络进行MNIST手写体数字识别
- GAN—生成对抗网络
- TensorFlow小试牛刀(2):GAN生成手写数字
- GAN生成对抗网络的TensorFlow实现
- 《白话深度学习与Tensorflow》学习笔记(6)生成式对抗网络(GAN)
- 无监督学习之深度生成模型——生成对抗网络GAN
- GAN学习笔记(一)——初探GAN
- 生成对抗网络GAN学习
- DCGAN例子学习-MNIST 手写体数字生成
- 【tensorflow学习】最简单的GAN 实现
- 【tensorflow学习】最简单的GAN 实现
- 深度学习的三大生成模型:VAE、GAN、GAN的变种
- tensorflow手写体识别实例
- Ext JS 构造函数、私有变量和静态变量
- js的常用功能及属性总结
- Andrew Ng机器学习课程笔记--week9(上)
- golang net包基础解析
- hdu-1370(中国剩余定理余数互质)&&hdu-1573(中国剩余定理余数不互质)
- tensorflow学习——GAN手写体生成
- 高效的枚举元素集合
- 和为S的连续正数序列
- SSL2688 2017年8月14日提高组T2 温度
- 2017.8.14
- Java中运用位运算符的屏蔽技术求得整数的各个位
- 自练题20170725
- 17.8.14B组总结
- Linked List Cycle leetcode java (链表检测环)