生成对抗网络DCGAN+Tensorflow代码学习笔记(二)----utils.py

来源:互联网 发布:淘宝店铺标志尺寸 编辑:程序博客网 时间:2024/06/14 00:00

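utils.py 主要定义了各种图像处理函数,负责图像的基本操作:读取图像、中心裁剪与缩放、保存图像(将一个 batch 拼接成网格图)、像素值的归一化与逆变换(在 [0,255] 与 [-1,1]、以及 [-1,1] 与 [0,1] 之间转换,并非图像翻转),以及利用 moviepy 模块把生成结果制作成 GIF 动画以可视化训练过程。它相当于其他 3 个文件共用的工具模块(类似于 C 语言中的头文件)。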

"""Some codes from https://github.com/Newmu/dcgan_code"""from __future__ import divisionimport mathimport jsonimport randomimport pprintimport scipy.miscimport numpy as npfrom time import gmtime, strftimefrom six.moves import xrangeimport tensorflow as tfimport tensorflow.contrib.slim as slim#1.首先定义了一个pp = pprint.PrettyPrinter(),以方便打印数据结构信息pp = pprint.PrettyPrinter()#2.定义了get_stddev函数,是三个参数乘积后开平方的倒数,应该是为了随机化用get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1])def show_all_variables():  model_vars = tf.trainable_variables()  slim.model_analyzer.analyze_vars(model_vars, print_info=True)def get_image(image_path, input_height, input_width,              resize_height=64, resize_width=64,              crop=True, grayscale=False):#根据图像路径参数读取路径,根据灰度化参数选择是否进行灰度化  image = imread(image_path, grayscale)#对图像参照输入的参数进行裁剪  return transform(image, input_height, input_width,                   resize_height, resize_width, crop)#存储新图像def save_images(images, size, image_path):  return imsave(inverse_transform(images), size, image_path)#判断grayscale参数是否进行范围灰度化,并进行类型转换为np.floatdef imread(path, grayscale = False):  if (grayscale):    return scipy.misc.imread(path, flatten = True).astype(np.float)  else:    return scipy.misc.imread(path).astype(np.float)#返回新图像def merge_images(images, size):  return inverse_transform(images)# 产生一个大画布,用来保存生成的 batch_size 个图像def merge(images, size):  #图像的高、宽  h, w = images.shape[1], images.shape[2]  if (images.shape[3] in (3,4)):    #图像的通道数    c = images.shape[3]    # 循环使得画布特定地方值为某一幅图像的值    img = np.zeros((h * size[0], w * size[1], c))    for idx, image in enumerate(images):      i = idx % size[1]      j = idx // size[1]      img[j * h:j * h + h, i * w:i * w + w, :] = image    return img  elif images.shape[3]==1:    img = np.zeros((h * size[0], w * size[1]))    for idx, image in enumerate(images):      i = idx % size[1]      j = idx // size[1]      img[j * h:j * h + h, i * w:i * w + w] = image[:,:,0]    return img  else:    raise 
ValueError('in merge(images,size) images parameter '                     'must have dimensions: HxW or HxWx3 or HxWx4')#首先将merge()函数返回的图像,用 np.squeeze()函数移除长度为1的轴。# 然后利用scipy.misc.imsave()函数将新图像保存到指定路径中。def imsave(images, size, path):  image = np.squeeze(merge(images, size))  return scipy.misc.imsave(path, image)#对图像的H和W与crop的H和W相减,得到取整的值,根据这个值作为下标依据来scipy.misc.resize图像。def center_crop(x, crop_h, crop_w,                resize_h=64, resize_w=64):  if crop_w is None:    crop_w = crop_h  h, w = x.shape[:2]  j = int(round((h - crop_h)/2.))  i = int(round((w - crop_w)/2.))  return scipy.misc.imresize(      x[j:j+crop_h, i:i+crop_w], [resize_h, resize_w])#对输入的图像进行裁剪def transform(image, input_height, input_width,               resize_height=64, resize_width=64, crop=True):  if crop:    cropped_image = center_crop(      image, input_height, input_width,       resize_height, resize_width)  else:    cropped_image = scipy.misc.imresize(image, [resize_height, resize_width])  return np.array(cropped_image)/127.5 - 1.#对图像进行翻转后返回新图像def inverse_transform(images):  return (images+1.)/2.#获取每一层的权值、偏置值def to_json(output_path, *layers):  with open(output_path, "w") as layer_f:    lines = ""    for w, b, bn in layers:      layer_idx = w.name.split('/')[0].split('h')[1]      B = b.eval()      if "lin/" in w.name:        W = w.eval()        depth = W.shape[1]      else:        W = np.rollaxis(w.eval(), 2, 0)        depth = W.shape[0]      biases = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(B)]}      if bn != None:        gamma = bn.gamma.eval()        beta = bn.beta.eval()        gamma = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(gamma)]}        beta = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(beta)]}      else:        gamma = {"sy": 1, "sx": 1, "depth": 0, "w": []}        beta = {"sy": 1, "sx": 1, "depth": 0, "w": []}      if "lin/" in w.name:        fs = []        for w in W.T:          
fs.append({"sy": 1, "sx": 1, "depth": W.shape[0], "w": ['%.2f' % elem for elem in list(w)]})        lines += """          var layer_%s = {            "layer_type": "fc",             "sy": 1, "sx": 1,             "out_sx": 1, "out_sy": 1,            "stride": 1, "pad": 0,            "out_depth": %s, "in_depth": %s,            "biases": %s,            "gamma": %s,            "beta": %s,            "filters": %s          };""" % (layer_idx.split('_')[0], W.shape[1], W.shape[0], biases, gamma, beta, fs)      else:        fs = []        for w_ in W:          fs.append({"sy": 5, "sx": 5, "depth": W.shape[3], "w": ['%.2f' % elem for elem in list(w_.flatten())]})        lines += """          var layer_%s = {            "layer_type": "deconv",             "sy": 5, "sx": 5,            "out_sx": %s, "out_sy": %s,            "stride": 2, "pad": 1,            "out_depth": %s, "in_depth": %s,            "biases": %s,            "gamma": %s,            "beta": %s,            "filters": %s          };""" % (layer_idx, 2**(int(layer_idx)+2), 2**(int(layer_idx)+2),               W.shape[0], W.shape[3], biases, gamma, beta, fs)    layer_f.write(" ".join(lines.replace("'","").split()))#根据图像集的长度和持续的时间做一个除法,然后返回每帧图像。最后视频修剪并制作成GIF动画。def make_gif(images, fname, duration=2, true_image=False):  import moviepy.editor as mpy  def make_frame(t):    try:      x = images[int(len(images)/duration*t)]    except:      x = images[-1]    if true_image:      return x.astype(np.uint8)    else:      return ((x+1)/2*255).astype(np.uint8)  clip = mpy.VideoClip(make_frame, duration=duration)  clip.write_gif(fname, fps = len(images) / duration)#保存图像可视化def visualize(sess, dcgan, config, option):  image_frame_dim = int(math.ceil(config.batch_size**.5))  if option == 0:    z_sample = np.random.uniform(-0.5, 0.5, size=(config.batch_size, dcgan.z_dim))    samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})    save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_%s.png' % 
strftime("%Y-%m-%d-%H-%M-%S", gmtime()))  elif option == 1:    values = np.arange(0, 1, 1./config.batch_size)    for idx in xrange(dcgan.z_dim):      print(" [*] %d" % idx)      z_sample = np.random.uniform(-1, 1, size=(config.batch_size , dcgan.z_dim))      for kdx, z in enumerate(z_sample):        z[idx] = values[kdx]      if config.dataset == "mnist":        y = np.random.choice(10, config.batch_size)        y_one_hot = np.zeros((config.batch_size, 10))        y_one_hot[np.arange(config.batch_size), y] = 1        samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample, dcgan.y: y_one_hot})      else:        samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})      save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_arange_%s.png' % (idx))  elif option == 2:    values = np.arange(0, 1, 1./config.batch_size)    for idx in [random.randint(0, dcgan.z_dim - 1) for _ in xrange(dcgan.z_dim)]:      print(" [*] %d" % idx)      z = np.random.uniform(-0.2, 0.2, size=(dcgan.z_dim))      z_sample = np.tile(z, (config.batch_size, 1))      #z_sample = np.zeros([config.batch_size, dcgan.z_dim])      for kdx, z in enumerate(z_sample):        z[idx] = values[kdx]      if config.dataset == "mnist":        y = np.random.choice(10, config.batch_size)        y_one_hot = np.zeros((config.batch_size, 10))        y_one_hot[np.arange(config.batch_size), y] = 1        samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample, dcgan.y: y_one_hot})      else:        samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})      try:        make_gif(samples, './samples/test_gif_%s.gif' % (idx))      except:        save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_%s.png' % strftime("%Y-%m-%d-%H-%M-%S", gmtime()))  elif option == 3:    values = np.arange(0, 1, 1./config.batch_size)    for idx in xrange(dcgan.z_dim):      print(" [*] %d" % idx)      z_sample = np.zeros([config.batch_size, dcgan.z_dim])      for kdx, z 
in enumerate(z_sample):        z[idx] = values[kdx]      samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})      make_gif(samples, './samples/test_gif_%s.gif' % (idx))  elif option == 4:    image_set = []    values = np.arange(0, 1, 1./config.batch_size)    for idx in xrange(dcgan.z_dim):      print(" [*] %d" % idx)      z_sample = np.zeros([config.batch_size, dcgan.z_dim])      for kdx, z in enumerate(z_sample): z[idx] = values[kdx]      image_set.append(sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}))      make_gif(image_set[-1], './samples/test_gif_%s.gif' % (idx))    new_image_set = [merge(np.array([images[idx] for images in image_set]), [10, 10]) \        for idx in range(64) + range(63, -1, -1)]    make_gif(new_image_set, './samples/test_gif_merged.gif', duration=8)def image_manifold_size(num_images):  manifold_h = int(np.floor(np.sqrt(num_images)))  manifold_w = int(np.ceil(np.sqrt(num_images)))  assert manifold_h * manifold_w == num_images  return manifold_h, manifold_w


阅读全文
0 0