TensorFolw学习笔记-TFRecord

来源:互联网 发布:阿里云主机多少钱 编辑:程序博客网 时间:2024/05/19 11:47

  关于 tensorflow 读取数据, 官网提供了3中方法
  1 Feeding: 在tensorflow程序运行的每一步, 用python代码在线提供数据。
  2 Reader : 在一个计算图(tf.graph)的开始前,将文件读入到流(queue)中。
  3 在声明tf.variable变量或numpy数组时保存数据。受限于内存大小,适用于数据较小的情况。
  我们在刚学习Tensorflow时,几乎所有的例子都是使用第一种或第三种方法,因为例子中的训练集数据量都比较少,而当训练集比较大时,就会占用较大的内存,效率比较低。此时,第二种方法就会发挥巨大的作用,因此它存储的是二进制文件,PC读取二进制文件是比读取格式文件要快的多。
  在第二种方法中,TensorFlow提供了TFRecord的格式统一管理存储数据。我们首先看一下TFRecord的存储格式:

# tf.train.Examplemessage Example{    Features features = 1;}message Features{    map<string,Features> feature = 1;}message Feature {    oneof kind {        BytesList bytes_list = 1;        FloateList float_list = 2;        Int64List int64_list = 3;    }}

  从定义中可以看出tf.train.Example是以字典的形式存储数据格式,string为字典的key值,字典的属性值有三种类型:bytes、float、int64。接下来通过例子说明如果通过TFRecord保存和读取文件。保存和读取用到函数分别为:tf.python_io.TFRecordWriter和tf.TFRecordReader()。

TFrecode格式保存

  我们以Mnist数据集为例,MNIST 训练集中的shape=(55000, 784),我们要将每一个数据的image_data、label、channel保存到TFRecord中。

 # 生成整数型的属性    def _int64_feature(value):        return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))    # 生成字符串类型的属性,也就是图像的内容    def _string_feature(value):        return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))    mniset = input_data.read_data_sets('./mni_data',dtype=tf.uint8, one_hot=True)    images = mniset.train.images    labels = mniset.train.labels    # (55000, 784)    pixels = images.shape[1]    # 55000    num_examples = mniset.train.num_examples    file_name = 'output.tfrecords'    writer = tf.python_io.TFRecordWriter(file_name)    for index in range(num_examples):        image_raw = images[index].tostring()        example = tf.train.Example(features = tf.train.Features(feature = {            'pixel': _int64_feature(pixels),            'label': _int64_feature(np.argmax(labels[index])),            'image_raw': _string_feature(image_raw)        }))        writer.write(example.SerializeToString())    writer.close()

TFrecode格式读取

    reader = tf.TFRecordReader()    # 创建一个队列来输入列表文件    filename_quene = tf.train.string_input_producer([path])    _, serialized_example = reader.read(filename_quene)    features = tf.parse_single_example(        serialized_example,        features = {            # 解析的结果为Tensor。            'image_raw' : tf.FixedLenFeature([], tf.string),            'pixel' : tf.FixedLenFeature([], tf.int64),            'label' : tf.FixedLenFeature([],tf.int64)        }    )    # 将字符串内容解析为图像的数组结构    images = tf.decode_raw(features['image_raw'],tf.uint8)    labels = tf.cast(features['label'],tf.int32)    pixels = tf.cast(features['pixel'],tf.int32)    with tf.Session() as sess:        # 可以使用Tensorflow的多线程进行读取        coord = tf.train.Coordinator()        threads = tf.train.start_queue_runners(sess=sess,coord=coord)        for i in range(10):            image,label,pixel = sess.run([images,labels,pixels])

  读取TFRecord文件时,使用tf.train.string_input_producer生成一个输入文件队列。这里我们的输入列表文件只有一个[path],而如果当训练数据比较大时,就需要将数据拆分多个TFRecord文件来提高处理效率。例如,Cifar10的例子中,将训练集数据拆分为5个bin文件以提高文件处理效率。
  在Cifar10中,例子使用下面方式获取所有的训练集输入文件列表,而Tensorflow既然让我们将训练数据拆分为多个TFRecord文件,那么它也提供函数tf.train.match_filenames_once,通过正则表达式获取某个目录下的输入文件列表。

filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)               for i in xrange(1, 6)]filenames =tf.train.match_filenames_once('data_batch_×')

  string_input_producer使用输入文件队列创建一个输入队列,队列中的原始元素为文件列表中的所有文件,而输入队列可以作为读取函数的参数,例子中的tf.TFRecordReader()。函数在读取文件时,会首先判断当前是否已有打开的文件可读,如果没有或打开的文件已经读完,则函数会从输入队列中出列一个文件,并从这个文件中读取数据。而文件出列方式的参数默认为shuffle=True,也就是说是随机出列的。注意:string_input_producer还有一个参数num_epochs,并且默认值为None,它的作用是什么呢?它的作用为限制加载初始文件列表的轮数。在默认的情况下:当一个队列的所有文件被处理完后,它会将初始化时提供的文件列表中的文件重新加入队列。如果num_epochs=1,则程序将输入队列中的所有文件处理完后,程序将自动停止,如果继续尝试读取输入队列中的文件,输入队列将会报:OutOfRange错误。因此在测试数据时,需要将num_epochs=1设置为1。
  输入队列读取的文件,需要进行预处理操作,以增加训练集数量,增加图片的使用率。预处理请参考:http://blog.csdn.net/lovelyaiq/article/details/78716325
  string_input_producer生成的输入队列可以被多个文件读取现成操作,而且输入队列会将队列中的文件均匀的非配给不同的线程,不会出现有些文件出现多次,而有的文件没有被处理的情况。而多线程文件获取是通过tf.train.Coordinator()完成。

原创粉丝点击