Tensorflow数据读取方式
来源:互联网 发布:python 缺省参数 编辑:程序博客网 时间:2024/06/05 07:59
Tensorflow数据读取方式
关于tensorflow(简称TF)数据读取方式,官方给出了三种:
供给数据(Feeding):在TF程序运行的每一步,让python代码来供给数据。
从文件读取数据:在TF图的起始,让每一个管线从文件中读取数据。
预加载数据 :在TF图中定义常量或者变量来保存数据(使用数据量较小的情况)。
的
一、供给数据
TF的数据供给机制允许在TF运算图中将数据注入发到任一张量(tensor)。
通过run(),eval()函数输入到feed_dict()中,如:
with tf.Session() as sess: sess.run(init) ...... train_accuracy = accuracy.eval(feed_dict={x_input: batch_x, y_labels: batch_y, keep_prob: 1.0}) ...... train_step.run(feed_dict={x_input: batch_x, y_labels: batch_y, keep_prob: 0.5}) ......
上述代码中,x_input,y_labels为张量。虽然可以使用常量和变量来代替张量,但是在TF中,最好还是使用 op节点
x_input = tf.placeholder(tf.float32, [None, 32,32,3],name='Mul')y_labels = tf.placeholder(tf.float32,[None,62])
上述代码声明了x_input,y_labels张量,但是张量未被初始化,也不包含数据。
二、从文件读取数据
Kaggle比赛中最常见的数据格式是CSV文件,以读取CSV文件为例进行说明。直接上代码:
(代码来至http://wiki.jikexueyuan.com/project/tensorflow-zh/how_tos/reading_data.html#Feeding)
filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"])reader = tf.TextLineReader()key, value = reader.read(filename_queue)# Default values, in case of empty columns. Also specifies the type of the# decoded result.record_defaults = [[1], [1], [1], [1], [1]]col1, col2, col3, col4, col5 = tf.decode_csv(value, record_defaults=record_defaults)features = tf.concat(0, [col1, col2, col3, col4])with tf.Session() as sess: # Start populating the filename queue. coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) for i in range(1200): # Retrieve a single instance: example, label = sess.run([features, col5]) coord.request_stop() coord.join(threads)
首先将文件名列表交给tf.train.string_input_producer函数来生成一个先进先出的队列。
文件列表的表达方式为[“file0.csv”, “file1.csv”],也可以用[(“file%d” % i) for i in range(2)] 或者使用tf.train.match_filenames_once函数来生成。
使用TF读取CSV文件需要使用tf.TextLineReader和tf.decode_csv两个函数,每次read操作都会从文件中读取一行内容,而decode_csv会将这一行内容转为张量列表,如果输入的参数有缺失,record_default参数会根据张量的类型来设置默认值。
在调用run和eval去执行read时,需要使用tf.train_start_queue_runners来将文件名填充到队列。否则,read操作会被阻塞,直到文件名队列中有值为止。(从二进制文件中读取固定长度记录,可以使用tf.FixedLengthRecordReader的tf.decode_raw操作,tf.decode_raw可以将一个字符串转换成Uint8的张量)
三、预加载数据
当数据量较小时,一般选择直接将数据加载到内存中,然后再分Batch的输入到网络中。
这边举个简单的例子:使用python读取图片数据
不同类别的数据存储在四个不同的文件夹下。使用如下代码进行读取。
def load_data(data_dir): directories = [d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))] labels = [] images = [] for d in directories: label_dir = os.path.join(data_dir, d) file_names = [os.path.join(label_dir, f) for f in os.listdir(label_dir) if f.endswith(".jpg")] for f in file_names: img = skimage.data.imread(f) img299 = skimage.transform.resize(img,(299,299)) images.append(img299) labels.append(int(d)) return images, labels
调用上述代码得到的是list数据,需要调用如下代码转变成array类型。
images299 = [image for image in images]images_x = np.array(images299)labels_x= np.array(labels)
在数据量较大时,预加载数据就不现实了,因为太耗内存。所以这时就是使用上诉三种方法中的第二种:从文件读取数据。
如果要读取图片数据,可以将其转换成TF中的标准支持格式tfrecords,它是一种二进制文件,能够很好的利用内存,且方便复制和移动。
直接上代码(主要参考:http://blog.csdn.net/u012759136/article/details/52232266)
import osimport tensorflow as tf from PIL import Imagecwd = os.getcwd()writer = tf.python_io.TFRecordWriter("train.tfrecords")classes = ['0','1','2','3']for index, name in enumerate(classes): class_path = cwd + '\\' + name + '\\' for img_name in os.listdir(class_path): img_path = class_path + img_name img = Image.open(img_path) img = img.resize((100, 100)) img_raw = img.tobytes() example = tf.train.Example(features=tf.train.Features(feature={ "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])), 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])) })) writer.write(example.SerializeToString())writer.close()
- Tensorflow数据读取方式
- Tensorflow中使用tfrecord方式读取数据
- Tensorflow数据读取有三种方式(next_batch)
- tensorflow数据集制作/文件队列读取方式
- 云端TensorFlow读取数据IO的高效方式
- 云端TensorFlow读取数据IO的高效方式
- Tensorflow读取数据的4种方式(8)---《深度学习》
- TensorFlow全新的数据读取方式:Dataset API入门教程(转)
- TensorFlow全新的数据读取方式:Dataset API入门教程
- TensorFlow全新的数据读取方式:Dataset API入门教程
- TensorFlow全新的数据读取方式:Dataset API入门教程
- TensorFlow全新的数据读取方式:Dataset API入门教程
- TensorFlow全新的数据读取方式:Dataset API入门教程
- TensorFlow全新的数据读取方式:Dataset API入门教程
- TensorFlow基础3:数据读取的三种方式
- TensorFlow全新的数据读取方式:Dataset API入门教程
- tensorflow爬坑行:数据读取
- Tensorflow图片数据读取
- C类型与数值存储
- 山东第八届acm大赛F题quadratic equation,山东理工oj 3898
- 05-树7 堆中的路径 (25分)
- 关于解决SSM-shiro的Spring注入问题
- R语言缺失值处理
- Tensorflow数据读取方式
- android TV开发之RecyclerView的使用以及自动加载
- Unity脚本类的继承关系
- python爬虫(10)身边的翻译专家——获取有道翻译结果
- 从C#到TypeScript
- 守护进程
- opengl随机颜色
- Android内存泄漏原因
- faster r-cnn使用Pascal VOC2007+2012联合训练