【Tensorflow】超大规模数据集解决方案:通过线程来预取(下)
来源:互联网 发布:java lua 编辑:程序博客网 时间:2024/05/21 22:40
环境Tensorflow1.2,python2.7
现在让我们用Tensorflow实现一个具体的Input pipeline,我们使用CoCo2014作为处理对象,网上应该可以下载到CoCo训练集,train2014这个文件。下载链接:
http://msvocds.blob.core.windows.net/coco2014/train2014.zip
一共13.5G,解压完以后大概会有8万多张图,这个数据集算得上超大规模级别了,那么问题来了,这么多图片我们怎么下手呢?难道和以前一样读到内存?如此笨重的数据集,如果仍然用内存暴力解决,那就太耗费时间空间资源了。能否在训练的同时,读数据,预处理数据呢?现在,让我们用队列+多线程去解决这个问题。
一.Beginning of an input pipeline
string_input_producer(string_tensor,num_epochs=None,shuffle=True,seed=None,capacity=32
,shared_name=None,name=None,cancel_op=None)
该函数输入字符串的Tensor或者List,返回一个字符串队列,一共8个输入参数,忽略最后三个不常用的参数,其中
string_tensor:一维的字符串Tensor,注意这个参数不传入Tensor也是可以的,也可以传入字符串的List
num_epochs:控制数量的一个参数,字符串List被放入队列的重复次数
shuffle:很好理解,是否打乱字符串List
seed:shuffle需要的随机数种子,一般可以不指定
capacity:队列的大小
类似的函数还有:
tf.train.range_input_producer
tf.train.slice_input_producer
需要注意的是这里返回的队列是添加了QueueRunner的,也就是我们需要调用线程来操作队列。还有很重要的一点,千万不要以为创建完队列以后,string_tensor的所有值就都入队了,入队也是流程化的,而入队操作通常由分线程来做,任何时刻我们都不关注队列的状态,只关注入队了什么,出队了什么。
二.Batching at the end of an input pipeline
tf.train.batch:
batch(tensors,batch_size,num_threads=1,capacity=32,enqueue_many=False
,shapes=None,dynamic_pad=False,allow_smaller_final_batch=False,shared_name=None,name=None)
该函数创建输入的Tensor中的一些batches,同样这个Tensor也可以是一个List,参数解析:
tensors
:需要注意的是,为了构成pipeline,保持一致性,这个函数也是以队列形式运行,所以Tensor的输入可以和上面描述的类似。
batch_size:一个batch的数量
num_threads:执行操作的线程数量
capacity:该函数运行的队列的长度
enqueue_many:控制是否可以一次入队多个,一般为false
dynamic_pad:
动态填充,填充维度为None的区域,不常用
allow_smaller_final_batch:控制是否允许小于batchsize的batch,如果不允许,则那几个样本会被丢弃
类似函数:
tf.train.batch_join
tf.train.maybe_batch
tf.train.shuffle_batch
同样很重要的一点,batching操作也是在Input pipeline里面,所以他也不是对全部数据来取batch,我们不需要关注当前队列中要多少样本,只需要关注取出了那些样本。
三.示例程序
让我们用一个程序来演示上面的过程。
为了演示需要,从train2014文件夹里面取20张图,然后连着文件夹拷贝到工程路径:
from os import listdirfrom os.path import isfile, joinimport tensorflow as tfdataset_path='train2014'with tf.Session() as sess: filenames = [join(dataset_path, f) for f in listdir(dataset_path) if isfile(join(dataset_path, f))] print 'number of images:', len(filenames) filename_queue = tf.train.string_input_producer(filenames, shuffle=False,num_epochs=1) reader = tf.WholeFileReader() name, img_bytes = reader.read(filename_queue) image = tf.image.decode_jpeg(img_bytes, channels=3) dataname = tf.train.batch([name], 2, dynamic_pad=True) sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()]) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) try: while not coord.should_stop(): print sess.run(dataname) except tf.errors.OutOfRangeError: print('Done training -- epoch limit reached') finally: coord.request_stop() coord.join(threads)
解决的思路:首先产生一个存储路径下所有文件名的List,然后用它作为输入产生字符串队列,这是pipeline的前端。得到了出队的字符串序列,我们可以使用tensorflow中的filereader,将文件内容读取出来,reader.read函数返回{文件名,文件内容}键值对,文件内容即是我们需要处理的对象(这里为了直观,我们使用了文件名作为输出),这个阶段是pipeline的中端。最后得到了文件名序列,再用batch函数提取一个batch,这个作为pipeline的末端。这样一个程序就完成了。
运行以后结果如下:
number of images: 20['train2014/COCO_train2014_000000000109.jpg' 'train2014/COCO_train2014_000000000071.jpg']['train2014/COCO_train2014_000000000092.jpg' 'train2014/COCO_train2014_000000000094.jpg']['train2014/COCO_train2014_000000000064.jpg' 'train2014/COCO_train2014_000000000025.jpg']['train2014/COCO_train2014_000000000072.jpg' 'train2014/COCO_train2014_000000000110.jpg']['train2014/COCO_train2014_000000000086.jpg' 'train2014/COCO_train2014_000000000030.jpg']['train2014/COCO_train2014_000000000113.jpg' 'train2014/COCO_train2014_000000000009.jpg']['train2014/COCO_train2014_000000000061.jpg' 'train2014/COCO_train2014_000000000077.jpg']['train2014/COCO_train2014_000000000081.jpg' 'train2014/COCO_train2014_000000000034.jpg']['train2014/COCO_train2014_000000000078.jpg' 'train2014/COCO_train2014_000000000036.jpg']['train2014/COCO_train2014_000000000089.jpg' 'train2014/COCO_train2014_000000000049.jpg']Done training -- epoch limit reached
这里我没有shuffle数据集,需要shuffle只要把string_input_producer中的shuffle参数改为True。
四.具体项目
这里没有为大家展示一个具体网络怎么调用以上过程来训练。因为能输出batch的数据其实已经达到我们的意图了。如果有需要一个取数据+训练完整程序的同学,请参考github上图像风格迁移的repo,这个工程使用了以上方法,并有完整的训练过程。源码地址如下:
https://github.com/hzy46/fast-neural-style-tensorflow
- 【Tensorflow】超大规模数据集解决方案:通过线程来预取(下)
- 【Tensorflow】超大规模数据集解决方案:通过线程来预取(上)
- tensorflow下数据集MNIST下载源代码
- tensorflow之浅层(输入输出层)神经网络通过softmax分类mnist数据集
- tensorflow下的队列与线程(1)
- Windows下通过Docker安装Tensorflow环境
- Windows下通过Anaconda安装TensorFlow
- Windows下通过Anaconda安装tensorflow
- Windows 7下通过anaconda安装tensorflow
- Windows10下通过anaconda安装tensorflow
- Windows10下通过anaconda安装tensorflow
- Windows 7下通过anaconda安装tensorflow
- Windows 7下通过anaconda安装tensorflow
- windows下通过Anaconda安装tensorflow
- TensorFlow下MNIST数据集下载脚本input_data.py
- TensorFlow 下 mnist 数据集的操作及可视化
- tensorflow下对MNIST数据集进行识别的程序代码
- tensorflow MNIST数据集
- jvm
- putty清除记录 修改host name(or IP address)
- 功能展示——自定义控件Spinner样式实现下拉列表
- ios-回收键盘的方法
- Fzu Problem 2253 Salty Fish(dp)
- 【Tensorflow】超大规模数据集解决方案:通过线程来预取(下)
- 打造最强RecyclerView,Item侧滑菜单,长按拖拽Item,滑动删除Item
- MySQL中的统计信息相关参数介绍
- lua next 用法 table 空的判断
- solr6.0以上安装完,http://localhost:8080/solr/index.html,报404错
- Java String 类
- HDU 1142 最短路径的数量
- HashMap 实现原理
- 《剑指offer》输出最长回文子串