tensorflow数据集制作/文件队列读取方式
来源:互联网 发布:windows桌面增强小工具 编辑:程序博客网 时间:2024/06/10 20:09
3种数据读取方式
TensorFlow程序读取数据一共有3种方法:
供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据。
从文件读取数据: 在TensorFlow图的起始, 让一个输入管线从文件中读取数据。
预加载数据: 在TensorFlow图中定义常量或变量来保存所有数据(仅适用于数据量比较小的情况)。
以上3种方式官网中均有介绍
1. input = tf.placeholder(tf.float32) feed方式,先定义一个占位符,需要sess.run()的时候把数据传进去。
3.预加载,如下,讲数据保存在常量中,这个方法对于大数据不合适,内存资源不够。
training_data = ...training_labels = ...with tf.Session(): input_data = tf.constant(training_data) input_labels = tf.constant(training_labels)
文件读取方式
使用原因:例如图片分类场景中,我们要使用自己的数据集,数据集比较大,需要动态的去添加数据,这样可以使用占位符,也可以利用文件读取的方式,这种方式更方便。下面我们介绍文件读取方式。(这里特别补充一个问题,队列是用来解决gpu空闲和内存问题的。所有的图片同时读到内存中是非常大的,内存可能承受不了,有了队列之后,每一次都从队列加载到内存队列中,这样就可以添加比较大的图片了。http://geek.csdn.net/news/detail/201552 详解TENSORFLOW读取机制)
官网中给出的一般步骤一共典型的文件读取管线会包含下面这些步骤:`文件名列表`可配置的 文件名乱序(shuffling)`可配置的 最大训练迭代数(epoch limit)`文件名队列`针对输入文件格式的阅读器`纪录解析器`可配置的预处理器`样本队列
这里介绍使用tensorflow的二进制格式来处理,一是因为这个二进制文件操作方便,而且网上有比较通用的处理流程,2是利用图片生成方便,200M的二进制文件可以很快速的生成,但是如果是csv文件的话,生成速度特别慢,而且200M的excel也打不开的。。(亲测)所以还是用二进制文件吧!
一下假设有了二进制文件,先别急,后面我会告诉大家如何去制作一个属于自己的数据集!
具体的队列描述,参考官网即可,我也没太看懂。
http://www.tensorfly.cn/tfdoc/how_tos/reading_data.html#AUTOGENERATED-preloaded-data
TensorFlow提供了两个类来帮助多线程的实现:tf.Coordinator和 tf.QueueRunner。从设计上这两个类必须被一起使用。Coordinator类可以用来同时停止多个工作线程并且向那个在等待所有工作线程终止的程序报告异常。QueueRunner类用来协调多个工作线程同时将多个张量推入同一个队列中。QueueRunner类会创建一组线程, 这些线程可以重复的执行Enquene操作, 他们使用同一个Coordinator来处理线程同步终止。
basePath = '/home/user/xxxxx'classes = {'c1','c2'}#生成数据集def create_record(): writer = tf.python_io.TFRecordWriter("train.tfrecords") for index, name in enumerate(classes): class_path = basePath +"/"+ name+"/" for img_name in os.listdir(class_path): img_path = class_path + img_name img = Image.open(img_path) img = img.resize((320, 240)) img_raw = img.tobytes() #将图片转化为原生bytes #print index,img_raw example = tf.train.Example( features=tf.train.Features( feature={ "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])), 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])) } ) ) writer.write(example.SerializeToString()) writer.close()#读取二进制数据img, label = read_and_decode("../train.tfrecords")#分块处理,这里img_batch,就可以当做输入看待了,以后每次sess.run()相关操作都会取出一部分,此处相当于你自己写一个队列操作去feed x的数据。(个人理解)img_batch, label_batch = tf.train.shuffle_batch([img, label], batch_size=4, capacity=2000, min_after_dequeue=1000)#官网推荐处理模板# Create the graph, etc.init_op = tf.initialize_all_variables()# Create a session for running operations in the Graph.sess = tf.Session()# Initialize the variables (like the epoch counter).sess.run(init_op)# Start input enqueue threads.coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess, coord=coord)try: while not coord.should_stop(): # Run training steps or whatever sess.run(train_op)except tf.errors.OutOfRangeError: print 'Done training -- epoch limit reached'finally: # When done, ask the threads to stop. coord.request_stop()# Wait for threads to finish.coord.join(threads)sess.close()
制作自己的数据集
思路:TFRecords文件。
获取这种格式的文件方式为,首先将一般的数据格式填入Example protocol buffer中,再将 protocol buffer序列化为一个字符串,然后使用tf.python_io.TFRecordWriter类的相关方法将字符串写入一个TFRecords文件中。
#制作二进制数据def create_record(): writer = tf.python_io.TFRecordWriter("train.tfrecords") for index, name in enumerate(classes): class_path = basePath +"/"+ name+"/" for img_name in os.listdir(class_path): img_path = class_path + img_name img = Image.open(img_path) img = img.resize((320, 240)) img_raw = img.tobytes() #将图片转化为原生bytes #print index,img_raw example = tf.train.Example( features=tf.train.Features( feature={ "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])), 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])) } ) ) writer.write(example.SerializeToString()) writer.close()
小问题
利用CNN分类的时候,出现了资源耗尽问题,网络铺的太大,显卡内存就用光了。mnist中使用28*28,我的图片320*240。主要原因来自于最后全连接层,参数过多。需要单独处理了!
- tensorflow数据集制作/文件队列读取方式
- Tensorflow数据读取方式
- tensorflow读取文件数据
- Tensorflow从文件读取数据
- TensorFlow读取二进制文件数据到队列
- tensorflow读取数据到队列当中
- tensorflow读取数据到队列当中
- tensorflow读取数据到队列当中
- 利用Tensorflow的队列多线程读取数据
- Tensorflow中使用tfrecord方式读取数据
- Tensorflow数据读取有三种方式(next_batch)
- tensorflow文件数据读取机制剖析
- TensorFlow 学习(二) 制作自己的TFRecord数据集,读取,显示及代码详解
- tensorflow mnist实战笔记(二)制作和读取自己的数据集
- 【tensorflow】文件队列的两种创建和加载方式
- 云端TensorFlow读取数据IO的高效方式
- 云端TensorFlow读取数据IO的高效方式
- Tensorflow读取数据的4种方式(8)---《深度学习》
- 了解
- 安装错误:Installation failed with message:INSTALL_CANCELED_BY_USER
- MyEclipse 2015优化技巧
- C++时间复杂度与空间复杂度
- Git 版本管理工具
- tensorflow数据集制作/文件队列读取方式
- 关于flexible的一些归纳
- Spring源码学习第二节
- ReactNative 自定义单选对话框 SingleChoiceDialog
- 采用SASS编写CSS
- 为什么要控制控制 pcb 阻抗
- (38)骨网格物体骼Actor
- 青蛙跳
- 虚拟机字节码执行引擎1