tensorflow1.1/GAN生成对抗网络
来源:互联网 发布:淘宝商城休闲女鞋 编辑:程序博客网 时间:2024/06/06 00:10
环境:tensorflow1.1,python3, matplotlib2.02
生成式对抗网络(GAN)是近年来大热的深度学习模型,以生成图片为例进行说明。假设我们有两个网络,G(Generator)和D(Discriminator)G是一个生成图片的网络,它接收一个随机的噪声z,通过这个噪声生成图片,记做G(z)。D是一个判别网络,判别一张图片是不是“真实的”。它的输入参数是x,x代表一张图片,输出D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。
在训练过程中,生成网络G的目标就是尽量生成真实的图片去欺骗判别网络D。而D的目标就是尽量把G生成的图片和真实的图片分别开来。这样,G和D构成了一个动态的“博弈过程”。
在最理想的状态下,G可以生成足以“以假乱真”的图片G(z)。对于D来说,它难以判定G生成的图片究竟是不是真实的,因此D(G(z)) = 0.5
#coding:utf-8"""python 3 tensorflow 1.1"""import tensorflow as tfimport matplotlib.pyplot as pltimport numpy as npimport input_datamnist = input_data.read_data_sets('mnist/',one_hot=True)total_epoch = 100batch_size = 100learning_rate = 0.0002n_hidden = 256n_input = 28*28n_noise = 128#Descriminator网络输入图片形状x = tf.placeholder(tf.float32,[None,n_input])#Generator网络输入的是噪声z = tf.placeholder(tf.float32,[None,n_noise])#Generator网络的权重和偏置Generator_param={ 'gw_1':tf.Variable(tf.random_normal([n_noise,n_hidden],stddev=0.1)), 'gb_1':tf.Variable(tf.zeros([n_hidden])), 'gw_2':tf.Variable(tf.random_normal([n_hidden,n_input],stddev=0.1)), 'gb_2':tf.Variable(tf.zeros([n_input]))}#Discriminator网络权重和偏置Discriminator_param={ 'dw_1':tf.Variable(tf.random_normal([n_input,n_hidden],stddev=0.1)), 'db_1':tf.Variable(tf.zeros([n_hidden])), 'dw_2':tf.Variable(tf.random_normal([n_hidden,1],stddev=0.1)), 'db_2':tf.Variable(tf.zeros([1]))}#构建Generator网络def generator(noise_z): hidden = tf.nn.relu(tf.matmul(noise_z,Generator_param['gw_1'])+Generator_param['gb_1']) output = tf.nn.sigmoid(tf.matmul(hidden,Generator_param['gw_2'])+Generator_param['gb_2']) return output#构建Discriminator网络def discriminator(inputs): hidden = tf.nn.relu(tf.matmul(inputs,Discriminator_param['dw_1'])+Discriminator_param['db_1']) output = tf.nn.sigmoid(tf.matmul(hidden,Discriminator_param['dw_2'])+Discriminator_param['db_2']) return output#生成网络根据噪声生成一张图片generator_output = generator(z)#判别网络根据生成网络生成的图片片别其真假概率discriminator_pred = discriminator(generator_output)#判别网络根据真实图片判别其真假概率discriminator_real = discriminator(x)#生成网络lossgenerator_loss = tf.reduce_mean(tf.log(discriminator_pred))#判别网络lossdiscriminator_loss = tf.reduce_mean(tf.log(discriminator_real)+tf.log(1 - discriminator_pred))generator_param_list=[Generator_param['gw_1'],Generator_param['gb_1'],Generator_param['gw_2'],Generator_param['gb_2']]discriminator_param_list=[Discriminator_param['dw_1'],Discriminator_param['db_1'],Discriminator_param['dw_2'],Discriminator_param['db_2']]generator_train = tf.train.AdamOptimizer(learning_rate).minimize(-generator_loss,var_list=generator_param_list)discriminator_train = tf.train.AdamOptimizer(learning_rate).minimize(-discriminator_loss,var_list=discriminator_param_list)with tf.Session() as sess: init = tf.global_variables_initializer() sess.run(init) total_batch = int(mnist.train.num_examples/batch_size) generator_c,discriminator_c = 0,0 #开始交互模式 plt.ion() for epoch in range(total_epoch): for i in range(total_batch): batch_x,batch_y = mnist.train.next_batch(batch_size) noise = np.random.normal(size=(batch_size,n_noise)) _,generator_c = sess.run([generator_train,generator_loss],feed_dict={z:noise}) _,discriminator_c = sess.run([discriminator_train,discriminator_loss],feed_dict={x:batch_x,z:noise}) if epoch % 10 ==0: print('epoch: ',int(epoch/10),'--generator_loss: %.4f' %generator_c,'--discriminator_loss: %.4f' %discriminator_c) #图片显示 if epoch % 10 == 0: new_batch = 3 noise = np.random.normal(size=(new_batch,n_noise)) #生成图像 samples = sess.run(generator_output,feed_dict={z:noise}) fig,a = plt.subplots(1,new_batch,figsize=(new_batch*2,2)) for i in range(new_batch): a[i].clear() a[i].set_axis_off() a[i].imshow(np.reshape(samples[i],(28,28))) plt.draw() plt.pause(0.1) plt.ioff() """ if epoch % 10 == 0: new_batch = 10 noise = np.random.normal(size=(new_batch,n_noise)) #生成图像 samples = sess.run(generator_output,feed_dict={z:noise}) fig,a = plt.subplots(1,new_batch,figsize=(new_batch,1)) for i in range(new_batch): a[i].set_axis_off() a[i].imshow(np.reshape(samples[i],(28,28))) plt.savefig('samples/%i.png' %int(epoch/10)) plt.close(fig) """
结果
随着迭代次数增加,生成网络生成的图片越来越接近真实图片
阅读全文
0 0
- tensorflow1.1/GAN生成对抗网络
- 生成对抗网络GAN
- GAN生成对抗网络
- 生成对抗网络-GAN
- Gan 生成对抗网络
- [生成对抗网络] GAN
- 生成对抗网络(GAN)
- 深入浅出生成对抗网络1-GAN入门
- 了解生成对抗网络GAN
- 生成对抗网络(GAN)
- 浅谈GAN生成对抗网络
- pytorch GAN生成对抗网络
- GAN 生成式对抗网络
- 生成对抗网络GAN学习
- GAN—生成对抗网络
- 生成式对抗网络GAN汇总
- 生成式对抗网络GAN汇总 研究进展
- 生成式对抗网络GAN汇总
- Mysql数据库操作常用命令
- C++中的public、protected和private
- HttpServletResponse和HttpServletRequest详解
- HttpServletResponse和HttpServletRequest详解
- Java 多线程之 synchronized 和 volatile 的比较
- tensorflow1.1/GAN生成对抗网络
- Spring Boot 入门教程
- java.io.EOFException
- 数据应用达人之SQL基础教程分享13-存储过程与事务
- Vue, App与我(二)
- redis最通俗介绍和在window下安装
- iOS测试
- 【Android】- bindService 之 leaked ServiceConnection错误
- C API向MySQL插入批量数据的快速方法——关于mysql_autocommit