生成对抗网络DCGAN+Tensorflow代码学习笔记(一)----main.py

来源:互联网 发布:mac cosmetics美国官网 编辑:程序博客网 时间:2024/05/19 12:37

深度学习中对图像处理应用最好的模型是CNN,把CNN与GAN结合便产生了DCGAN。本文主要讨论tensorflow版本代码,代码地址为:https://github.com/Newmu/dcgan_code 

本文主要研究main.py,该文件主要是调用定义好的模型,图像处理方法,来进行训练或者测试,是整个程序的入口。

执行main函数之前首先进行flags的解析,TensorFlow底层使用了python-gflags项目,然后封装成tf.app.flags接口,也就是说TensorFlow通过设置flags来传递tf.app.run()所需要的参数,我们可以直接在程序运行前初始化flags,也可以在运行程序的时候设置命令行参数来达到传参的目的。

FLAGS参数:

  1. epoch:迭代次数
  2. learning_rate:Adam学习速率,默认是0.002
  3. beta1:Adam的动量项(Momentum term of Adam),默认为0.5 
  4. train_size:训练图像的个数,默认为np.inf 
  5. batch_size:每次迭代的图像数量
  6. input_height:需要指定输入图像的高
  7. input_width:需要指定输入图像的宽
  8. output_height:需要指定输出图像的高
  9. output_width:需要指定输出图像的宽
  10. dataset:需要指定处理哪个数据集
  11. input_fname_pattern:输入的图片类型,默认为*.jpg 
  12. checkpoint_dir:存放checkpoint的目录名
  13. sample_dir:存放生成图片的目录名
  14. train:True for training, False for testing
  15. crop:True for training, False for testing
  16. visualize:可视化为True,不可视化为False,默认为False
import osimport scipy.miscimport numpy as npfrom model import DCGANfrom utils import pp, visualize, to_json, show_all_variablesimport tensorflow as tfflags = tf.app.flagsflags.DEFINE_integer("epoch", 25, "Epoch to train [25]")flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]")flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]")flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")flags.DEFINE_integer("input_height", 108, "The size of image to use (will be center cropped). [108]")flags.DEFINE_integer("input_width", None, "The size of image to use (will be center cropped). If None, same value as input_height [None]")flags.DEFINE_integer("output_height", 64, "The size of the output images to produce [64]")flags.DEFINE_integer("output_width", None, "The size of the output images to produce. If None, same value as output_height [None]")flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]")flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]")flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")flags.DEFINE_boolean("train", False, "True for training, False for testing [False]")flags.DEFINE_boolean("crop", False, "True for training, False for testing [False]")flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]")flags.DEFINE_integer("generate_test_images", 100, "Number of images to generate during test. [100]")FLAGS = flags.FLAGSdef main(_):# 打印参数数据,然后判断输入图像的输出图像的宽是否指定,如果没有指定,则等于其图像的高。  pp.pprint(flags.FLAGS.__flags)  if FLAGS.input_width is None:    FLAGS.input_width = FLAGS.input_height  if FLAGS.output_width is None:    FLAGS.output_width = FLAGS.output_height#判断checkpoint和sample的文件是否存在,不存在则创建。  if not os.path.exists(FLAGS.checkpoint_dir):    os.makedirs(FLAGS.checkpoint_dir)  if not os.path.exists(FLAGS.sample_dir):    os.makedirs(FLAGS.sample_dir)#tf.ConfigProto一般用在创建session的时候,用来对session进行参数配置  #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)# 使用allow_growth option,刚一开始分配少量的GPU容量,然后按需慢慢的增加,由于不会释放内存,所以会导致碎片  run_config = tf.ConfigProto()  run_config.gpu_options.allow_growth=True#运行session,首先判断处理的是哪个数据集,然后对应使用不同参数的DCGAN类,这个类会在model.py文件中定义。  with tf.Session(config=run_config) as sess:    if FLAGS.dataset == 'mnist':      dcgan = DCGAN(          sess,          input_width=FLAGS.input_width,          input_height=FLAGS.input_height,          output_width=FLAGS.output_width,          output_height=FLAGS.output_height,          batch_size=FLAGS.batch_size,          sample_num=FLAGS.batch_size,          y_dim=10,          z_dim=FLAGS.generate_test_images,          dataset_name=FLAGS.dataset,          input_fname_pattern=FLAGS.input_fname_pattern,          crop=FLAGS.crop,          checkpoint_dir=FLAGS.checkpoint_dir,          sample_dir=FLAGS.sample_dir)    else:      dcgan = DCGAN(          sess,          input_width=FLAGS.input_width,          input_height=FLAGS.input_height,          output_width=FLAGS.output_width,          output_height=FLAGS.output_height,          batch_size=FLAGS.batch_size,          sample_num=FLAGS.batch_size,          z_dim=FLAGS.generate_test_images,          dataset_name=FLAGS.dataset,          input_fname_pattern=FLAGS.input_fname_pattern,          crop=FLAGS.crop,          checkpoint_dir=FLAGS.checkpoint_dir,          sample_dir=FLAGS.sample_dir)#show所有与训练相关的变量    show_all_variables()#判断是训练还是测试,如果是训练,则进行训练;如果不是,判断是否有训练好的model,# 然后进行测试,如果没有先训练,则会提示“[!] Train a model first, then run test mode”。    if FLAGS.train:      dcgan.train(FLAGS)    else:      if not dcgan.load(FLAGS.checkpoint_dir)[0]:        raise Exception("[!] Train a model first, then run test mode")          # to_json("./web/js/layers.js", [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0],    #                 [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1],    #                 [dcgan.h2_w, dcgan.h2_b, dcgan.g_bn2],    #                 [dcgan.h3_w, dcgan.h3_b, dcgan.g_bn3],    #                 [dcgan.h4_w, dcgan.h4_b, None])    #进行可视化,visualize(sess, dcgan, FLAGS, OPTION)    # Below is codes for visualization    OPTION = 1    visualize(sess, dcgan, FLAGS, OPTION)if __name__ == '__main__':  tf.app.run()



阅读全文
0 0
原创粉丝点击