【Tensorflow】怎样为你的网络预加工和打包训练数据?(一)
来源:互联网 发布:淘宝客服忙吗 编辑:程序博客网 时间:2024/05/17 01:49
面对五花八门的数据集,各种各样的数据存储形式,刚新手入门的我们在处理这些情况的时候是否会手足无措?反正一路走来,我的经验告诉我,deep learning的实验阶段,数据准备和处理过程往往会让你碰一鼻子灰。明明知道如何搭建网络,还是完成不了实验,究其原因,是数据工程经验的不足。
我打算做这个系列,主要是记录针对不同种类,格式的数据的处理方案。数据预处理的首篇,我为大家展示一种常见情形的处理方法
一.问题背景
问题的背景是面对raw image数据集,但是图片按label为文件夹存放。以Office-31数据集为例。
Office-31数据集是一个用于迁移学习算法性能测试的数据集,我已经上传到网上,下载地址在下面:
domain_adaptation_images.part1.rar
domain_adaptation_images.part2.rar
权限不够,上传了两个分卷。解压完以后出现这个文件
下面又是三个文件,这三个就是不同环境下拍摄的图,我们只需要进amazon即可
最后这个文件夹下有各种各样的类,每一个类文件夹,相当于一个label。
进到具体label下面,则出现各种各样的图片。
描述这样一个问题背景是有意义的,因为实际上很多图片数据集都是以这样的形式来存放。
以此为范例,下面来记录一个这个问题的具体解决方案。
二.解决方案
首先说一下需要用到的辅助工具,前一篇讲到的skimage(【Tensorflow】辅助工具篇——scikit-image介绍),cPickle,matplotlib
鉴于这里有三个domain的数据,我们只做amazon这个文件夹下图片的处理
先上代码。
def build_dataset(data_dir, out_dir, weight=100,hight=100): data_dir = os.path.join(data_dir,"images") for _, dirnames, _ in os.walk(data_dir): for dirname in dirnames: index = dirnames.index(dirname) workdir = os.path.join(data_dir, dirname) #images = io.imread_collection(workdir + '/*.jpg') processed_images = io.ImageCollection(workdir + '/*.jpg', load_func=process_image, weight=weight,hight=hight) label = np.full(len(processed_images), fill_value=index, dtype=np.int32) images = io.concatenate_images(processed_images) if index == 0: data = images labels = label else: data = np.vstack((data,images)) labels = np.append(labels,label) if not os.path.exists(out_dir): os.makedirs(out_dir) print "data shape:",data.shape print "label shape:",labels.shape save_pickle(data, out_dir+'/'+'amazon_images.pkl') save_pickle(labels, out_dir+'/'+'amazon_labels.pkl')
解决思路还是比较传统的。首先要遍历文件夹,对于每一个文件夹下面的所有图片,用skimage批量读出来
读取的过程是通过imread_collection函数将所有jgp图片读取出来,返回一个类(注意此时这个类并不是np数组,而是skimage中的ImageCollection类,所以他并不能直接使用,我们要通过concatenate_images函数将多个图片连接起来成为一个np数组)
但是我们没有使用imread_collection函数,而是使用了ImageCollection类的构造函数,直接构造一个ImageCollection类,主要是因为如果图像的大小像素不同会导致连接的时候报错(维度不同),所以我们要先完成图像的预处理,处理完了将所有的图resize到相同的大小。构造ImageCollection类的时候可以load进去一个处理函数,在这里是process_image函数:
def process_image(image, weight, hight): img = io.imread(image) img = transform.resize(img, (weight,hight), mode='reflect') return img
当然process_image函数里面我们还可以添加其他内容(裁剪,填充等)
另外,如果是可以保证原始图像的像素全部相等,那么我们也可以imread_collection读进来以后统一处理。这里我们主要针对的是更复杂的情况。
最后,使用pkl文件来保存。
def save_pickle(data, path): with open(path, 'wb') as f: pickle.dump(data, f, pickle.HIGHEST_PROTOCOL) print ('Saved %s..' %path)
在数据量不大的情况下,pkl是一种常用的保存手段,同时使用gzip来压缩,(我这里为了方便没有用),最常见的mnist就是用的pkl.gz这种后缀。同时cPickle又是pickle的升级版,压缩率好过pickle,大家可以尝试一下。
但是在数据量很大的清况下,我们一般使用hdf5,hdf5在性能方面是好过cPickle很多。这种方法后面会介绍。
当然也可以构造图片预处理的pipeline。这种方法是所有方案的终极版,专门针对超大数据集(ImageNet,CoCo)不可能全部load到内存中使用的,例如用CoCo数据集来做style transfer训练的时候用的就是线程读图片的方式,同时这种方法也是最难去实现的,同样后面也会介绍。
大功告成了!最后看一看结果吧。
同样还是用matplotlib来显示多个图片
import cPicklefrom mpl_toolkits.axes_grid1 import ImageGridimport matplotlib.pyplot as pltdef imshow_grid(images, shape=[2, 8]): """Plot images in a grid of a given shape.""" fig = plt.figure(1) grid = ImageGrid(fig, 111, nrows_ncols=shape, axes_pad=0.05) size = shape[0] * shape[1] for i in range(size): grid[i].axis('off') grid[i].imshow(images[i]) # The AxesGrid object work as a list of axes. plt.show()def load_amazon(): data = cPickle.load(open('prosessed_data/amazon_images.pkl')) labels = cPickle.load(open('prosessed_data/amazon_labels.pkl')) return data,labelsdata,labels = load_amazon()print "show image..."imshow_grid(data[90:106])print labels[90:106]
图片和label,看到是可以对上了,然后我们就可以下一步了。
[0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1]
不要忘了数据的归一化!!比较简单的做法是计算(像素值-127.5)/127.5,这种做法是归一到-1到1之间,也可以算每个通道的均值,然后每个通道分别归一。
归一化相信大家都会,就不赘述了。
三.实验源码
import tensorflow as tfimport tensorflow.contrib.slim as slimimport osfrom skimage import io,transformfrom mpl_toolkits.axes_grid1 import ImageGridimport argparseimport numpy as npimport cPickleimport matplotlib.pyplot as pltparser = argparse.ArgumentParser(description='')parser.add_argument('--dataset', dest='dataset', default='amazon', help='dataset name')def save_pickle(data, path): with open(path, 'wb') as f: cPickle.dump(data, f, cPickle.HIGHEST_PROTOCOL) print ('Saved %s..' %path)def imshow_grid(images, shape=[2, 8]): """Plot images in a grid of a given shape.""" fig = plt.figure(1) grid = ImageGrid(fig, 111, nrows_ncols=shape, axes_pad=0.05) size = shape[0] * shape[1] for i in range(size): grid[i].axis('off') grid[i].imshow(images[i]) # The AxesGrid object work as a list of axes. plt.show()def process_image(image, weight, hight): img = io.imread(image) img = transform.resize(img, (weight,hight), mode='reflect') return imgdef build_dataset(data_dir, out_dir, name,weight=100,hight=100): data_dir = os.path.join(data_dir,"images") for _, dirnames, _ in os.walk(data_dir): for dirname in dirnames: index = dirnames.index(dirname) workdir = os.path.join(data_dir, dirname) #images = io.imread_collection(workdir + '/*.jpg') processed_images = io.ImageCollection(workdir + '/*.jpg', load_func=process_image, weight=weight,hight=hight) label = np.full(len(processed_images), fill_value=index, dtype=np.int32) images = io.concatenate_images(processed_images) if index == 0: data = images labels = label else: data = np.vstack((data,images)) labels = np.append(labels,label) if not os.path.exists(out_dir): os.makedirs(out_dir) print("data shape:") print(data.shape) print("label shape:") print(labels.shape) save_pickle(data, out_dir+'/'+name+'_images.pkl') save_pickle(labels, out_dir+'/'+name+'_labels.pkl')def load_amazon(): images = cPickle.load(open('prosessed_data/amazon/amazon_images.pkl')) labels = cPickle.load(open('prosessed_data/amazon/amazon_labels.pkl')) images = images*2 - 1 print ('finished loading amazon image dataset..!') return images,labelsdef load_dslr(): images = cPickle.load(open('prosessed_data/dslr/dslr_images.pkl')) labels = cPickle.load(open('prosessed_data/dslr/dslr_labels.pkl')) images = images * 2 - 1 print ('finished loading dslr image dataset..!') return images,labelsdef load_webcam(): images = cPickle.load(open('prosessed_data/webcam/webcam_images.pkl')) labels = cPickle.load(open('prosessed_data/webcam/webcam_labels.pkl')) images = images * 2 - 1 print ('finished loading webcam image dataset..!') return images, labelsargs = parser.parse_args()def main(): print "make dataset..." if args.dataset == 'amazon': build_dataset("domain_adaptation_images/amazon","prosessed_data/amazon",args.dataset,weight=64,hight=64) print "read dataset..." images,label = load_amazon() print "show image..." imshow_grid((images[90:106]+1)/2) print label[90:106] elif args.dataset == 'dslr': build_dataset("domain_adaptation_images/dslr", "prosessed_data/dslr",args.dataset, weight=64, hight=64) print "read dataset..." images, label =load_dslr() print "show image..." imshow_grid((images[90:106]+1)/2) print label[90:106] elif args.dataset == 'webcam': build_dataset("domain_adaptation_images/webcam", "prosessed_data/webcam",args.dataset, weight=64, hight=64) print "read dataset..." images, label =load_webcam() print "show image..." imshow_grid((images[90:106]+1)/2) print label[90:106] else: raise Exception("wrong args!!") print "loading successful!"if __name__ == "__main__": main()
- 【Tensorflow】怎样为你的网络预加工和打包训练数据?(一)
- 【Tensorflow】怎样为你的网络预加工和打包训练数据?(二):小数据集的处理方案
- TensorFlow——训练自己的数据(一)数据处理
- Tensorflow + ResNet101 + fasterRcnn 训练自己的模型 数据(一)
- 利用tensorflow训练自己的图片数据(5)——测试训练网络
- Tensorflow 训练自己的数据集(一)(数据直接导入到内存)
- TensorFlow——训练自己的数据——CIFAR10(一)数据准备
- win10 tensorflow faster rcnn训练自己的数据集(一、制作VOC2007数据集)
- 利用tensorflow训练自己的图片数据(3)——建立网络模型
- tensorflow 实战 猫狗大战(一)训练自己的数据
- 用自己的数据训练Faster-RCNN,tensorflow版本(一)
- tensorflow 实战 猫狗大战(一)训练自己的数据
- tensorflow保存网络参数 使用训练好的网络参数进行数据的预测
- TensorFlow学习笔记(四):Tensorflow网络构建和TensorBoard进行训练过程可视化
- 学习TensorFlow,调用预训练好的网络(Alex, VGG, ResNet etc)
- nw.js node-webkit系列(17)怎样打包和分发你的应用
- Tensorflow 训练自己的数据集(二)(TFRecord)
- 使用Tensorflow训练自己的分割数据
- 设计模式:简单工厂
- arduino教程
- ion-tab和ng-click()一起使用,选项不跳转问题解决办法
- SSD的王者 PCIe固态硬盘的未来在哪里
- 利用DelayQueue实现延时消息队列(简易版MQ)
- 【Tensorflow】怎样为你的网络预加工和打包训练数据?(一)
- 浅谈Java中的equals和==
- 65_常用类_Date类的使用_JDk源码分析
- 32位Linux系统虚拟地址映射
- 欢迎使用CSDN-markdown编辑器
- Spdylay
- webjs--获取上传图片的缓存路径展示在页面上
- sync fence 软件接口-------------sw_sync_timeline和sw_sync_pt
- 【Java】Java中Integer和int比较大小出现的错误