Tensorflow基础:统一的数据存储格式

来源:互联网 发布:淘宝指数查询软件 编辑:程序博客网 时间:2024/05/23 21:18

在很多图像识别问题中,图像的亮度、对比度等属性都不应该影响最后的识别结果。本文将介绍如何对图像数据进行预处理使训练得到的神经网络模型尽可能小地被无关因素所影响。

TFRecord输入数据格式

来自实际问题的数据往往有很多格式和属性,TFRecord格式可以统一不同的原始数据格式,并更加有效地管理不同的属性。
有些程序中,使用了一个从类别名称到所有数据列表的词典来维护图像和类别的关系。这种方式的可扩展性非常差,当数据来源更加复杂、每一个样例中的信息更加丰富之后,这种方式就很难有效地记录输入数据中的信息了。于是Tensorflow提供了TFRecord的格式来统一存储数据。

TFRecord格式介绍

TFRecord文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的。以下代码给出了tf.train.Example的定义:

message Example{    Features features = 1};message Features{    map<string, Feature>feature = 1;};message Feature{    oneof kind{    BytesList bytes_list = 1;    FloatList float_list = 2;    Int64List int64_list = 3;}};

TFRecord样例程序

这里将给出具体地样例程序来读写TFRecord文件。下面的程序给出了如何将MNIST输入数据转化为TFRecord的格式:

import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_dataimport numpy as npdef _int64_feature(value):    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))def _bytes_feature(value):    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))mnist = input_data.read_data_sets("E:\科研\TensorFlow教程\MNIST_data", dtype=tf.uint8, one_hot=True)images = mnist.train.imageslabels = mnist.train.labelspixels = images.shape[1]num_examples = mnist.train.num_examplesfilename = "./output.tfrecords"writer = tf.python_io.TFRecordWriter(filename)for index in range(num_examples):    image_raw = images[index].tostring()    example = tf.train.Example(features=tf.train.Features(feature={        'pixels': _int64_feature(pixels),        'label': _int64_feature(np.argmax(labels[index])),        'image_raw': _bytes_feature(image_raw)    }))    writer.write(example.SerializeToString())writer.close()

读取TFRecord文件

Tensorflow对从文件列表中读取数据提供了很好的支持。以下程序给出了如何读取TFRecord文件中的数据:

import tensorflow as tfreader = tf.TFRecordReader()filename_queue = tf.train.string_input_producer(['./output.tfrecords'])_, serialized_example = reader.read(filename_queue)features = tf.parse_single_example(    serialized_example,    features={        'image_raw': tf.FixedLenFeature([], tf.string),        'pixels': tf.FixedLenFeature([], tf.int64),        'label': tf.FixedLenFeature([], tf.int64)    })images = tf.decode_raw(features['image_raw'], tf.uint8)labels = tf.cast(features['label'], tf.int32)pixels = tf.cast(features['pixels'], tf.int32)sess = tf.Session() coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess, coord=coord)for i in range(10):    image, label, pixel = sess.run([images, labels, pixels])