Tensorflow数据读取方式

来源:互联网 发布:python 缺省参数 编辑:程序博客网 时间:2024/06/05 07:59

Tensorflow数据读取方式
关于tensorflow(简称TF)数据读取方式,官方给出了三种:
供给数据(Feeding):在TF程序运行的每一步,让python代码来供给数据。
从文件读取数据:在TF图的起始,让每一个管线从文件中读取数据。
预加载数据 :在TF图中定义常量或者变量来保存数据(使用数据量较小的情况)。

一、供给数据
TF的数据供给机制允许在TF运算图中将数据注入发到任一张量(tensor)。
通过run(),eval()函数输入到feed_dict()中,如:

with tf.Session() as sess:    sess.run(init)    ......    train_accuracy = accuracy.eval(feed_dict={x_input: batch_x, y_labels: batch_y, keep_prob: 1.0})    ......    train_step.run(feed_dict={x_input: batch_x, y_labels: batch_y, keep_prob: 0.5})    ......

上述代码中,x_input,y_labels为张量。虽然可以使用常量和变量来代替张量,但是在TF中,最好还是使用 op节点

x_input = tf.placeholder(tf.float32, [None, 32,32,3],name='Mul')y_labels = tf.placeholder(tf.float32,[None,62])

上述代码声明了x_input,y_labels张量,但是张量未被初始化,也不包含数据。
二、从文件读取数据
Kaggle比赛中最常见的数据格式是CSV文件,以读取CSV文件为例进行说明。直接上代码:
(代码来至http://wiki.jikexueyuan.com/project/tensorflow-zh/how_tos/reading_data.html#Feeding)

filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"])reader = tf.TextLineReader()key, value = reader.read(filename_queue)# Default values, in case of empty columns. Also specifies the type of the# decoded result.record_defaults = [[1], [1], [1], [1], [1]]col1, col2, col3, col4, col5 = tf.decode_csv(value, record_defaults=record_defaults)features = tf.concat(0, [col1, col2, col3, col4])with tf.Session() as sess:  # Start populating the filename queue.  coord = tf.train.Coordinator()  threads = tf.train.start_queue_runners(coord=coord)  for i in range(1200):    # Retrieve a single instance:    example, label = sess.run([features, col5])  coord.request_stop()  coord.join(threads)

首先将文件名列表交给tf.train.string_input_producer函数来生成一个先进先出的队列。
文件列表的表达方式为[“file0.csv”, “file1.csv”],也可以用[(“file%d” % i) for i in range(2)] 或者使用tf.train.match_filenames_once函数来生成。
使用TF读取CSV文件需要使用tf.TextLineReader和tf.decode_csv两个函数,每次read操作都会从文件中读取一行内容,而decode_csv会将这一行内容转为张量列表,如果输入的参数有缺失,record_default参数会根据张量的类型来设置默认值。
在调用run和eval去执行read时,需要使用tf.train_start_queue_runners来将文件名填充到队列。否则,read操作会被阻塞,直到文件名队列中有值为止。(从二进制文件中读取固定长度记录,可以使用tf.FixedLengthRecordReader的tf.decode_raw操作,tf.decode_raw可以将一个字符串转换成Uint8的张量)

三、预加载数据
当数据量较小时,一般选择直接将数据加载到内存中,然后再分Batch的输入到网络中。
这边举个简单的例子:使用python读取图片数据

不同类别的数据存储在四个不同的文件夹下。使用如下代码进行读取。

def load_data(data_dir):    directories = [d for d in os.listdir(data_dir)  if os.path.isdir(os.path.join(data_dir, d))]    labels = []    images = []    for d in directories:        label_dir = os.path.join(data_dir, d)        file_names = [os.path.join(label_dir, f)  for f in os.listdir(label_dir) if f.endswith(".jpg")]        for f in file_names:            img = skimage.data.imread(f)            img299 = skimage.transform.resize(img,(299,299))            images.append(img299)            labels.append(int(d))    return images, labels

调用上述代码得到的是list数据,需要调用如下代码转变成array类型。

images299 = [image for image in images]images_x = np.array(images299)labels_x= np.array(labels)

在数据量较大时,预加载数据就不现实了,因为太耗内存。所以这时就是使用上诉三种方法中的第二种:从文件读取数据。
如果要读取图片数据,可以将其转换成TF中的标准支持格式tfrecords,它是一种二进制文件,能够很好的利用内存,且方便复制和移动。
直接上代码(主要参考:http://blog.csdn.net/u012759136/article/details/52232266)

import osimport tensorflow as tf from PIL import Imagecwd = os.getcwd()writer = tf.python_io.TFRecordWriter("train.tfrecords")classes = ['0','1','2','3']for index, name in enumerate(classes):    class_path = cwd + '\\' + name + '\\'    for img_name in os.listdir(class_path):        img_path = class_path + img_name        img = Image.open(img_path)        img = img.resize((100, 100))        img_raw = img.tobytes()         example = tf.train.Example(features=tf.train.Features(feature={            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))        }))        writer.write(example.SerializeToString())writer.close()
0 0