tensorflow1.1/GAN生成对抗网络

来源:互联网 发布:淘宝商城休闲女鞋 编辑:程序博客网 时间:2024/06/06 00:10

环境:tensorflow1.1,python3, matplotlib2.02

生成式对抗网络(GAN)是近年来大热的深度学习模型,以生成图片为例进行说明。假设我们有两个网络,G(Generator)和D(Discriminator)G是一个生成图片的网络,它接收一个随机的噪声z,通过这个噪声生成图片,记做G(z)。D是一个判别网络,判别一张图片是不是“真实的”。它的输入参数是x,x代表一张图片,输出D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。
在训练过程中,生成网络G的目标就是尽量生成真实的图片去欺骗判别网络D。而D的目标就是尽量把G生成的图片和真实的图片分别开来。这样,G和D构成了一个动态的“博弈过程”。
在最理想的状态下,G可以生成足以“以假乱真”的图片G(z)。对于D来说,它难以判定G生成的图片究竟是不是真实的,因此D(G(z)) = 0.5
这里写图片描述

#coding:utf-8"""python 3 tensorflow 1.1"""import tensorflow as tfimport matplotlib.pyplot as pltimport numpy as npimport input_datamnist = input_data.read_data_sets('mnist/',one_hot=True)total_epoch = 100batch_size = 100learning_rate = 0.0002n_hidden = 256n_input = 28*28n_noise = 128#Descriminator网络输入图片形状x = tf.placeholder(tf.float32,[None,n_input])#Generator网络输入的是噪声z = tf.placeholder(tf.float32,[None,n_noise])#Generator网络的权重和偏置Generator_param={    'gw_1':tf.Variable(tf.random_normal([n_noise,n_hidden],stddev=0.1)),    'gb_1':tf.Variable(tf.zeros([n_hidden])),    'gw_2':tf.Variable(tf.random_normal([n_hidden,n_input],stddev=0.1)),    'gb_2':tf.Variable(tf.zeros([n_input]))}#Discriminator网络权重和偏置Discriminator_param={    'dw_1':tf.Variable(tf.random_normal([n_input,n_hidden],stddev=0.1)),    'db_1':tf.Variable(tf.zeros([n_hidden])),    'dw_2':tf.Variable(tf.random_normal([n_hidden,1],stddev=0.1)),    'db_2':tf.Variable(tf.zeros([1]))}#构建Generator网络def generator(noise_z):    hidden = tf.nn.relu(tf.matmul(noise_z,Generator_param['gw_1'])+Generator_param['gb_1'])    output = tf.nn.sigmoid(tf.matmul(hidden,Generator_param['gw_2'])+Generator_param['gb_2'])    return output#构建Discriminator网络def discriminator(inputs):    hidden = tf.nn.relu(tf.matmul(inputs,Discriminator_param['dw_1'])+Discriminator_param['db_1'])    output = tf.nn.sigmoid(tf.matmul(hidden,Discriminator_param['dw_2'])+Discriminator_param['db_2'])    return output#生成网络根据噪声生成一张图片generator_output = generator(z)#判别网络根据生成网络生成的图片片别其真假概率discriminator_pred = discriminator(generator_output)#判别网络根据真实图片判别其真假概率discriminator_real = discriminator(x)#生成网络lossgenerator_loss = tf.reduce_mean(tf.log(discriminator_pred))#判别网络lossdiscriminator_loss = tf.reduce_mean(tf.log(discriminator_real)+tf.log(1 - discriminator_pred))generator_param_list=[Generator_param['gw_1'],Generator_param['gb_1'],Generator_param['gw_2'],Generator_param['gb_2']]discriminator_param_list=[Discriminator_param['dw_1'],Discriminator_param['db_1'],Discriminator_param['dw_2'],Discriminator_param['db_2']]generator_train = tf.train.AdamOptimizer(learning_rate).minimize(-generator_loss,var_list=generator_param_list)discriminator_train = tf.train.AdamOptimizer(learning_rate).minimize(-discriminator_loss,var_list=discriminator_param_list)with tf.Session() as sess:    init = tf.global_variables_initializer()    sess.run(init)    total_batch = int(mnist.train.num_examples/batch_size)    generator_c,discriminator_c = 0,0    #开始交互模式    plt.ion()    for epoch in range(total_epoch):        for i in range(total_batch):            batch_x,batch_y = mnist.train.next_batch(batch_size)            noise = np.random.normal(size=(batch_size,n_noise))            _,generator_c = sess.run([generator_train,generator_loss],feed_dict={z:noise})            _,discriminator_c = sess.run([discriminator_train,discriminator_loss],feed_dict={x:batch_x,z:noise})        if epoch % 10 ==0:            print('epoch: ',int(epoch/10),'--generator_loss: %.4f' %generator_c,'--discriminator_loss: %.4f' %discriminator_c)        #图片显示        if epoch % 10 == 0:            new_batch = 3            noise = np.random.normal(size=(new_batch,n_noise))            #生成图像            samples = sess.run(generator_output,feed_dict={z:noise})            fig,a = plt.subplots(1,new_batch,figsize=(new_batch*2,2))            for i in range(new_batch):                a[i].clear()                a[i].set_axis_off()                a[i].imshow(np.reshape(samples[i],(28,28)))            plt.draw()            plt.pause(0.1)    plt.ioff()    """        if epoch % 10 == 0:            new_batch = 10            noise = np.random.normal(size=(new_batch,n_noise))            #生成图像            samples = sess.run(generator_output,feed_dict={z:noise})            fig,a = plt.subplots(1,new_batch,figsize=(new_batch,1))            for i in range(new_batch):                a[i].set_axis_off()                a[i].imshow(np.reshape(samples[i],(28,28)))            plt.savefig('samples/%i.png' %int(epoch/10))            plt.close(fig)    """

结果

随着迭代次数增加,生成网络生成的图片越来越接近真实图片
这里写图片描述

这里写图片描述

这里写图片描述

这里写图片描述

原创粉丝点击