TensorFlow 数据读取

来源:互联网 发布:网络博客拉人方法 编辑:程序博客网 时间:2024/06/06 01:12

详解tensorflow数据读取机制

tensorflow读取机制图解

tensorflow使用文件名队列+内存队列双队列的形式读入文件,可以很好地管理epoch。

tensorflow读取数据机制的对应函数

文件名队列,我们使用tf.train.string_input_producer函数。这个函数需要传入一个文件名list,系统会自动将它转为一个文件名队列。

string_input_producer(    string_tensor,    num_epochs=None,    shuffle=True,    seed=None,    capacity=32,    shared_name=None,    name=None,    cancel_op=None)

在tensorflow中,内存队列不需要我们自己建立,我们只需要使用reader对象从文件名队列中读取数据就可以了,具体实现可以参考下面的实战代码。

使用tf.train.start_queue_runners之后,才会启动填充队列的线程。

实战代码

# 导入tensorflowimport tensorflow as tf # 新建一个Sessionwith tf.Session() as sess:    # 我们要读三幅图片A.jpg, B.jpg, C.jpg    filename = ['A.jpg', 'B.jpg', 'C.jpg']    # string_input_producer会产生一个文件名队列    filename_queue = tf.train.string_input_producer(filename, shuffle=False, num_epochs=5)    # reader从文件名队列中读数据。对应的方法是reader.read    reader = tf.WholeFileReader()    key, value = reader.read(filename_queue)    # tf.train.string_input_producer定义了一个epoch变量,要对它进行初始化    tf.local_variables_initializer().run()    # 使用start_queue_runners之后,才会开始填充队列    threads = tf.train.start_queue_runners(sess=sess)    i = 0    while True:        i += 1        # 获取图片数据并保存        image_data = sess.run(value)        with open('read/test_%d.jpg' % i, 'wb') as f:            f.write(image_data)

# Dataset API
此前,在TensorFlow中读取数据一般有两种方法:
* 使用placeholder读内存中的数据
* 使用queue读硬盘中的数据(关于这种方式,可以参考我之前的一篇文章:十图详解tensorflow数据读取机制)
## Dataset API的导入

在TensorFlow 1.3中,Dataset API是放在contrib包中的:

tf.contrib.data.Dataset

而在TensorFlow 1.4中,Dataset API已经从contrib包中移除,变成了核心API的一员:

tf.data.Dataset

参考资料

十图详解tensorflow数据读取机制(附代码)
TensorFlow全新的数据读取方式:Dataset API入门教程

原创粉丝点击