TensorFlow数据处理方法

来源:互联网 发布:select to SQL 编辑:程序博客网 时间:2024/06/01 21:20

学习深度学习一年多了,一个感觉是实验结果的好坏在很大程度上取决于数据;数据对于深度学习算法十分关键,数据集的大小影响着模型的精度和泛化能力,好的数据处理技巧锦上添花,而合适的数据输入输出方法使Tensor“流动”得更加顺畅更好的发挥机器的性能,为模型的训练节约时间。许多情况下,对于数据的处理花的时间往往比模型的修改花的时间多,因此本文专门针对数据处理(图像类)进行一次梳理归纳,利人利己。

TensorFlow有三种数据读取方式:
1. 预先加载数据
2. 使用python将数据feedTensor
3. 从文件读取数据

预先加载数据

第一种方式直接把数据写在代码里进行运算,这种方式在一些简单的演示算法中很常见

import tensorflow as tfa = tf.constant(3.0)b = tf.constant(4.0)c = a + bwith tf.Session() as sess:    print(sess.run(c))

feed

第二种方法是利用tf.placeholder提供一个数据输入的接口,在启动计算图时将数据通过这个接口输入计算图

import tensorfow as tftrain_images = ...train_labels = ...X = tf.placeholder([], dtypes=tf.float32)Y = tf.placeholder([], dtypes=tf.uint8)train_op = ...with tf.Session() as sess:    sess.run(train_op, feed_dict={X: train_images, Y: train_labels})

从文件读取

第三种方法从文件中读取,涉及到数据的转换和读取两个方面,数据的转换又有各种格式可以选择,这里简单列举几个常用的数据存储与读取方法,最后介绍TensorFlow标准存储格式TFRecord的转换和读取方法。

1、 .pkl
.pkl文件是一种特殊的串行化存储的二进制格式文件,可以存储大部分常见的Python对象,使用起来十分方便

import pickledef data_to_file(image_data, label):    with open('somedata.pkl', 'wb') as f:        pickle.dump([image_data, label], f)def file_to_data(pkl_file):    with open(pkl_file, 'rb') as f:        image_data, label = pickle.load(f)

实际应用案例可参考上一篇博客

2、 TFRecord
如上所述,TFRecord是TensorFlow的标准存储格式,尽管这种数据格式的转换方式不是很直观,不是一两行代码就能搞定的,但是在使用时TensorFlow设计了一套高效的API来专门处理这种文件,配合TensorFlow图像处理的API使其在数据处理方面就显得更有优势了。下面的代码简单的展示了怎样将一张图片转换成.tfrecord文件,以及从文件解析出图片。

"""Converts image data to TFRecords file format with Example protos."""from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionfrom PIL import Imageimport tensorflow as tf# Input must be type int or long.def _int64_feature(value):    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))# Input must be type bytesdef _bytes_feature(value):    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))def convert_to(data, name):    filename = name + '.tfrecords'    print('Writing', filename)    writer = tf.python_io.TFRecordWriter(filename)    # when there is one picture    example = tf.train.Example(features=tf.train.Features(feature={        'label': _int64_feature(data[1]),        'image_raw': _bytes_feature(data[0])}))    writer.write(example.SerializeToString())    # # TODO    # # when there is many pictures    # for index in range(num_examples):    #     ....    writer.close()def main_1():    # Get the data.    images = Image.open('image_0006.jpg')    images.resize((224,224))    image_raw = images.tobytes()    labels = 0    data_sets = [image_raw, labels]    # TODO: for large scale image dataset,     # a better way is reading while saving    # Convert to Examples and write the result to TFRecords.    convert_to(data_sets, 'test')def main_2():    # Get the data.    img_file = tf.read_file('image_0006.jpg')    images = tf.image.decode_jpeg(img_file)    images = tf.image.resize_images(images, [224,224])    with tf.Session() as sess:        image_raw = sess.run(tf.cast(images, tf.uint8))    image_raw = image_raw.tobytes()    labels = 0    data_sets = [image_raw, labels]    convert_to(data_sets, 'test')if __name__ == '__main__':    main_2()
from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionfrom PIL import Imageimport tensorflow as tfdef read_and_decode(filename):    filename_queue = tf.train.string_input_producer([filename])    reader = tf.TFRecordReader()    _, serialized_example = reader.read(filename_queue)    features = tf.parse_single_example(serialized_example,                                       features={                                           'label': tf.FixedLenFeature([], tf.int64),                                           'img_raw' : tf.FixedLenFeature([], tf.string),                                       })    image = tf.image.decode_jpeg(features['img_raw'], channels=3)    image = tf.image.convert_image_dtype(image, dtype=tf.float32)    label = tf.cast(features['label'], tf.int32)    return image, label

说明:
1、代码中的example是样本的意思哦
2、本代码只展示了将一张图转化为tfrecord格式
3、图片的编解码、裁剪、缩放、旋转等操作,TensorFlow都有自己的函数可以代替第三方库的功能,根据习惯自己选择。

References:

  1. tensorflow/g3doc/how_tos/reading_data/index.md
  2. http://blog.csdn.net/u012759136/article/details/52232266
0 0
原创粉丝点击