生成对抗网络DCGAN+Tensorflow代码学习笔记(二)----utils.py
来源:互联网 发布:淘宝店铺标志尺寸 编辑:程序博客网 时间:2024/06/14 00:00
utils.py 主要定义了各种图像处理函数,负责图像的基本操作:读取图像、中心裁剪与缩放、将像素值归一化到 [-1, 1] 以及反归一化(inverse_transform,并非图像翻转)、把一批图像拼接成网格并保存,以及利用 moviepy 模块将训练过程可视化为 GIF 动画。它相当于其他 3 个文件共用的“头文件”,即公共工具模块。
"""Some codes from https://github.com/Newmu/dcgan_code"""from __future__ import divisionimport mathimport jsonimport randomimport pprintimport scipy.miscimport numpy as npfrom time import gmtime, strftimefrom six.moves import xrangeimport tensorflow as tfimport tensorflow.contrib.slim as slim#1.首先定义了一个pp = pprint.PrettyPrinter(),以方便打印数据结构信息pp = pprint.PrettyPrinter()#2.定义了get_stddev函数,是三个参数乘积后开平方的倒数,应该是为了随机化用get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1])def show_all_variables(): model_vars = tf.trainable_variables() slim.model_analyzer.analyze_vars(model_vars, print_info=True)def get_image(image_path, input_height, input_width, resize_height=64, resize_width=64, crop=True, grayscale=False):#根据图像路径参数读取路径,根据灰度化参数选择是否进行灰度化 image = imread(image_path, grayscale)#对图像参照输入的参数进行裁剪 return transform(image, input_height, input_width, resize_height, resize_width, crop)#存储新图像def save_images(images, size, image_path): return imsave(inverse_transform(images), size, image_path)#判断grayscale参数是否进行范围灰度化,并进行类型转换为np.floatdef imread(path, grayscale = False): if (grayscale): return scipy.misc.imread(path, flatten = True).astype(np.float) else: return scipy.misc.imread(path).astype(np.float)#返回新图像def merge_images(images, size): return inverse_transform(images)# 产生一个大画布,用来保存生成的 batch_size 个图像def merge(images, size): #图像的高、宽 h, w = images.shape[1], images.shape[2] if (images.shape[3] in (3,4)): #图像的通道数 c = images.shape[3] # 循环使得画布特定地方值为某一幅图像的值 img = np.zeros((h * size[0], w * size[1], c)) for idx, image in enumerate(images): i = idx % size[1] j = idx // size[1] img[j * h:j * h + h, i * w:i * w + w, :] = image return img elif images.shape[3]==1: img = np.zeros((h * size[0], w * size[1])) for idx, image in enumerate(images): i = idx % size[1] j = idx // size[1] img[j * h:j * h + h, i * w:i * w + w] = image[:,:,0] return img else: raise ValueError('in merge(images,size) images parameter ' 'must have dimensions: HxW or HxWx3 or HxWx4')#首先将merge()函数返回的图像,用 
def imsave(images, size, path):
    """Tile ``images`` into a ``size`` = [rows, cols] grid and write it to ``path``.

    The canvas from merge() is passed through np.squeeze, which drops any
    length-1 axes, before scipy.misc.imsave writes it.

    NOTE(review): scipy.misc.imsave was removed in SciPy 1.2.
    """
    image = np.squeeze(merge(images, size))
    return scipy.misc.imsave(path, image)


def center_crop(x, crop_h, crop_w,
                resize_h=64, resize_w=64):
    """Cut a crop_h x crop_w window from the centre of ``x`` and resize it.

    If ``crop_w`` is None the crop is square (crop_w = crop_h). The window
    offsets are half the size difference, rounded with round(). The crop is
    resized to [resize_h, resize_w] with scipy.misc.imresize (removed in
    SciPy 1.3).
    """
    if crop_w is None:
        crop_w = crop_h
    h, w = x.shape[:2]
    j = int(round((h - crop_h) / 2.))
    i = int(round((w - crop_w) / 2.))
    return scipy.misc.imresize(
        x[j:j + crop_h, i:i + crop_w], [resize_h, resize_w])


def transform(image, input_height, input_width,
              resize_height=64, resize_width=64, crop=True):
    """Crop and/or resize ``image`` and rescale it with x / 127.5 - 1.

    With crop=True the centre input_height x input_width window is taken
    first; otherwise the whole image is resized. For 0..255 pixel values the
    rescaling maps 0 -> -1 and 255 -> 1.
    """
    if crop:
        cropped_image = center_crop(
            image, input_height, input_width,
            resize_height, resize_width)
    else:
        cropped_image = scipy.misc.imresize(image, [resize_height, resize_width])
    return np.array(cropped_image) / 127.5 - 1.


def inverse_transform(images):
    """Map values from [-1, 1] to [0, 1] via (x + 1) / 2 (no flipping involved)."""
    return (images + 1.) / 2.


def _volume_dict(values, depth, sy=1, sx=1):
    """Build one {"sy", "sx", "depth", "w"} entry with values formatted as '%.2f'."""
    return {"sy": sy, "sx": sx, "depth": depth,
            "w": ['%.2f' % v for v in values]}


# JavaScript layer templates used by to_json(). Whitespace inside them is
# collapsed when the file is written, so only token boundaries matter.
_FC_LAYER_TEMPLATE = """
        var layer_%s = {
          "layer_type": "fc",
          "sy": 1, "sx": 1,
          "out_sx": 1, "out_sy": 1,
          "stride": 1, "pad": 0,
          "out_depth": %s, "in_depth": %s,
          "biases": %s,
          "gamma": %s,
          "beta": %s,
          "filters": %s
        };"""

_DECONV_LAYER_TEMPLATE = """
        var layer_%s = {
          "layer_type": "deconv",
          "sy": 5, "sx": 5,
          "out_sx": %s, "out_sy": %s,
          "stride": 2, "pad": 1,
          "out_depth": %s, "in_depth": %s,
          "biases": %s,
          "gamma": %s,
          "beta": %s,
          "filters": %s
        };"""


def to_json(output_path, *layers):
    """Dump layer weights as JavaScript ``var layer_<idx> = {...};`` statements.

    Each entry of ``layers`` is a ``(w, b, bn)`` triple. ``w`` and ``b`` are
    read with ``.eval()`` (presumably TF variables evaluated in the default
    session). ``bn`` is either None or an object whose ``gamma`` and ``beta``
    attributes are evaluated the same way.

    Weights whose name contains "lin/" are written as "fc" layers. All
    others are written as 5x5, stride-2 "deconv" layers with output size
    2 ** (idx + 2). The layer index is the text after the first "h" of the
    weight's top-level scope (e.g. "g_h1/w:0" -> "1"; for "g_h0_lin/..." the
    part before "_" is used).

    The whole text is built first and only then written, so a failing
    ``.eval()`` no longer leaves a truncated output file behind.
    """
    chunks = []
    for w, b, bn in layers:
        layer_idx = w.name.split('/')[0].split('h')[1]
        is_linear = "lin/" in w.name

        B = b.eval()
        if is_linear:
            W = w.eval()
            depth = W.shape[1]
        else:
            # Move axis 2 of the 4-D kernel to the front before dumping.
            W = np.rollaxis(w.eval(), 2, 0)
            depth = W.shape[0]

        biases = _volume_dict(B, depth)
        if bn is not None:
            gamma = _volume_dict(bn.gamma.eval(), depth)
            beta = _volume_dict(bn.beta.eval(), depth)
        else:
            gamma = _volume_dict([], 0)
            beta = _volume_dict([], 0)

        if is_linear:
            filters = [_volume_dict(column, W.shape[0]) for column in W.T]
            chunks.append(_FC_LAYER_TEMPLATE % (
                layer_idx.split('_')[0], W.shape[1], W.shape[0],
                biases, gamma, beta, filters))
        else:
            filters = [_volume_dict(kernel.flatten(), W.shape[3], sy=5, sx=5)
                       for kernel in W]
            out_size = 2 ** (int(layer_idx) + 2)
            chunks.append(_DECONV_LAYER_TEMPLATE % (
                layer_idx, out_size, out_size, W.shape[0], W.shape[3],
                biases, gamma, beta, filters))

    text = "".join(chunks)
    with open(output_path, "w") as layer_f:
        # Drop the quotes Python's repr puts around dict keys / formatted
        # numbers, and collapse every whitespace run to a single space.
        layer_f.write(" ".join(text.replace("'", "").split()))


def make_gif(images, fname, duration=2, true_image=False):
    """Write the sequence ``images`` as a GIF clip lasting ``duration`` seconds.

    Frames are spread evenly over the clip (fps = len(images) / duration).
    With ``true_image`` the frames are cast to uint8 as-is; otherwise they are
    mapped with (x + 1) / 2 * 255 (presumably values in [-1, 1]).
    """
    import moviepy.editor as mpy

    last_frame = len(images) - 1

    def make_frame(t):
        # At t == duration the computed index equals len(images); clamp it to
        # the last frame instead of catching every exception.
        x = images[min(int(len(images) / duration * t), last_frame)]
        if true_image:
            return x.astype(np.uint8)
        return ((x + 1) / 2 * 255).astype(np.uint8)

    clip = mpy.VideoClip(make_frame, duration=duration)
    clip.write_gif(fname, fps=len(images) / duration)


def _run_sampler(sess, dcgan, config, z_sample):
    """Run ``dcgan.sampler`` on ``z_sample``; for MNIST also feed random one-hot labels."""
    if config.dataset == "mnist":
        y = np.random.choice(10, config.batch_size)
        y_one_hot = np.zeros((config.batch_size, 10))
        y_one_hot[np.arange(config.batch_size), y] = 1
        return sess.run(dcgan.sampler,
                        feed_dict={dcgan.z: z_sample, dcgan.y: y_one_hot})
    return sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})


def visualize(sess, dcgan, config, option):
    """Generate samples from ``dcgan.sampler`` and save them under ./samples/.

    option 0: one grid of samples from z ~ U(-0.5, 0.5).
    option 1: for every z dimension, sweep it over [0, 1) across the batch
              (other dims ~ U(-1, 1)) and save a grid PNG.
    option 2: for z_dim randomly chosen dimensions, sweep that dimension on
              top of one shared z ~ U(-0.2, 0.2); save a GIF (PNG on failure).
    option 3: for every z dimension, sweep it with all other dims at 0; GIF.
    option 4: like option 3, then one merged ping-pong GIF of all sweeps.
    """
    image_frame_dim = int(math.ceil(config.batch_size ** .5))
    grid = [image_frame_dim, image_frame_dim]

    if option == 0:
        z_sample = np.random.uniform(-0.5, 0.5, size=(config.batch_size, dcgan.z_dim))
        samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
        save_images(samples, grid,
                    './samples/test_%s.png' % strftime("%Y-%m-%d-%H-%M-%S", gmtime()))
    elif option == 1:
        values = np.arange(0, 1, 1. / config.batch_size)
        for idx in xrange(dcgan.z_dim):
            print(" [*] %d" % idx)
            z_sample = np.random.uniform(-1, 1, size=(config.batch_size, dcgan.z_dim))
            for kdx, row in enumerate(z_sample):
                row[idx] = values[kdx]
            samples = _run_sampler(sess, dcgan, config, z_sample)
            save_images(samples, grid, './samples/test_arange_%s.png' % (idx))
    elif option == 2:
        values = np.arange(0, 1, 1. / config.batch_size)
        # All dimensions are drawn up front (repeats are possible).
        dims = [random.randint(0, dcgan.z_dim - 1) for _ in xrange(dcgan.z_dim)]
        for idx in dims:
            print(" [*] %d" % idx)
            z = np.random.uniform(-0.2, 0.2, size=(dcgan.z_dim))
            z_sample = np.tile(z, (config.batch_size, 1))
            for kdx, row in enumerate(z_sample):
                row[idx] = values[kdx]
            samples = _run_sampler(sess, dcgan, config, z_sample)
            try:
                make_gif(samples, './samples/test_gif_%s.gif' % (idx))
            except Exception:
                # GIF export is best-effort (e.g. moviepy may be missing);
                # fall back to a PNG grid.
                save_images(samples, grid,
                            './samples/test_%s.png' % strftime("%Y-%m-%d-%H-%M-%S", gmtime()))
    elif option == 3:
        values = np.arange(0, 1, 1. / config.batch_size)
        for idx in xrange(dcgan.z_dim):
            print(" [*] %d" % idx)
            z_sample = np.zeros([config.batch_size, dcgan.z_dim])
            for kdx, row in enumerate(z_sample):
                row[idx] = values[kdx]
            samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
            make_gif(samples, './samples/test_gif_%s.gif' % (idx))
    elif option == 4:
        image_set = []
        values = np.arange(0, 1, 1. / config.batch_size)
        for idx in xrange(dcgan.z_dim):
            print(" [*] %d" % idx)
            z_sample = np.zeros([config.batch_size, dcgan.z_dim])
            for kdx, row in enumerate(z_sample):
                row[idx] = values[kdx]
            image_set.append(sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}))
            make_gif(image_set[-1], './samples/test_gif_%s.gif' % (idx))

        # Ping-pong over batch positions 0..63..0. Python 3 range objects
        # cannot be concatenated with '+', so build a list.
        # NOTE(review): hard-coded; appears to assume batch_size >= 64 and
        # z_dim <= 100 (one image per z dim on a 10x10 grid).
        frame_order = list(range(64)) + list(range(63, -1, -1))
        new_image_set = [merge(np.array([images[pos] for images in image_set]), [10, 10])
                         for pos in frame_order]
        make_gif(new_image_set, './samples/test_gif_merged.gif', duration=8)


def image_manifold_size(num_images):
    """Return (rows, cols) = (floor(sqrt(n)), ceil(sqrt(n))) for a grid of n images.

    The grid must hold exactly ``num_images`` (e.g. 64 -> (8, 8),
    12 -> (3, 4)); otherwise the assert fails. NOTE: asserts are skipped
    under ``python -O``.
    """
    manifold_h = int(np.floor(np.sqrt(num_images)))
    manifold_w = int(np.ceil(np.sqrt(num_images)))
    assert manifold_h * manifold_w == num_images
    return manifold_h, manifold_w
阅读全文
0 0
- 生成对抗网络DCGAN+Tensorflow代码学习笔记(二)----utils.py
- 生成对抗网络DCGAN+Tensorflow代码学习笔记(一)----main.py
- 生成对抗网络DCGAN+Tensorflow代码学习笔记(三)----ops.py
- TensorFlow/对抗网络DCGAN生成图片
- 生成对抗网络学习笔记5----DCGAN(unsupervised representation learning with deep convolutional generative adv)的实现
- 生成对抗网络学习笔记5----DCGAN(unsupervised representation learning with deep convolutional generative adv)的实现
- 深度卷积对抗生成网络(DCGAN)
- 深度卷积对抗生成网络(DCGAN)
- 深度卷积对抗生成网络(DCGAN)
- 深度卷积对抗生成网络(DCGAN)
- 深度卷积生成对抗网络--DCGAN
- 《白话深度学习与Tensorflow》学习笔记(6)生成式对抗网络(GAN)
- 生成对抗网络的简单介绍(TensorFlow 代码)
- 生成对抗网络简介(包含TensorFlow代码示例)【翻译】
- 生成对抗网络(GAN)原理+tensorflow代码实现
- 生成对抗网络介绍(附TensorFlow代码)
- 学习笔记-对抗生成网络
- 【神经网络与深度学习】生成式对抗网络GAN研究进展(五)——Deep Convolutional Generative Adversarial Nerworks,DCGAN
- 网络远程连接
- python文件的写入write()
- H2数据库入门Demo(一)
- OOP,重写与重载,异常处理机制,多线程,集合框架,IO流 -- Java基础复习
- hdu 1272 小希的迷宫
- 生成对抗网络DCGAN+Tensorflow代码学习笔记(二)----utils.py
- VS项目迁移到linux环境中Makefile相关小问题集锦
- cuda-covnet 深度学习工具的权值转化为txt 方便cpp源码调用
- 吴恩达Coursera机器学习课程笔记-定义分类
- QT中使用webView控件时报错
- 最大似然估计(MLE)和最大后验概率(MAP)
- QQ聊天气泡拖动效果实现
- 不想去读spring庞大源码,欲了解其内部原理来读此文
- RabbitMQ入门教程(四):工作队列(Work Queues)