TensorFlow入门(十-I)tfrecord 固定维度数据读写

来源:互联网 发布:阿里云1mbps实际网速 编辑:程序博客网 时间:2024/06/05 03:35

本例代码:https://github.com/yongyehuang/Tensorflow-Tutorial/tree/master/python/the_use_of_tfrecord

关于 tfrecord 的使用,分别介绍 tfrecord 进行三种不同类型数据的处理方法。
- 维度固定的 numpy 矩阵
- 可变长度的 序列 数据
- 图片数据

在 tf1.3 及以后版本中,推出了新的 Dataset API, 之前赶实验还没研究,可能以后都不太会用下面的方式写了。这些代码都是之前写好的,因为注释中都写得比较清楚了,所以直接上代码。

tfrecord_1_numpy_writer.py

# -*- coding:utf-8 -*- import tensorflow as tfimport numpy as npfrom tqdm import tqdm'''tfrecord 写入数据.将固定shape的矩阵写入 tfrecord 文件。这种形式的数据写入 tfrecord 是最简单的。refer: http://blog.csdn.net/qq_16949707/article/details/53483493'''# **1.创建文件,可以创建多个文件,在读取的时候只需要提供所有文件名列表就行了writer1 = tf.python_io.TFRecordWriter('../data/test1.tfrecord')writer2 = tf.python_io.TFRecordWriter('../data/test2.tfrecord')"""有一点需要注意的就是我们需要把矩阵转为数组形式才能写入就是需要经过下面的 reshape 操作在读取的时候再 reshape 回原始的 shape 就可以了"""X = np.arange(0, 100).reshape([50, -1]).astype(np.float32)y = np.arange(50)for i in tqdm(xrange(len(X))):  # **2.对于每个样本    if i >= len(y) / 2:        writer = writer2    else:        writer = writer1    X_sample = X[i].tolist()    y_sample = y[i]    # **3.定义数据类型,按照这里固定的形式写,有float_list(好像只有32位), int64_list, bytes_list.    example = tf.train.Example(        features=tf.train.Features(            feature={'X': tf.train.Feature(float_list=tf.train.FloatList(value=X_sample)),                     'y': tf.train.Feature(int64_list=tf.train.Int64List(value=[y_sample]))}))    # **4.序列化数据并写入文件中    serialized = example.SerializeToString()    writer.write(serialized)print('Finished.')writer1.close()writer2.close()

tfrecord_1_numpy_reader.py

# -*- coding:utf-8 -*- import tensorflow as tf'''read data从 tfrecord 文件中读取数据,对应数据的格式为固定shape的数据。'''# **1.把所有的 tfrecord 文件名列表写入队列中filename_queue = tf.train.string_input_producer(['../data/test1.tfrecord', '../data/test2.tfrecord'], num_epochs=None,                                                shuffle=True)# **2.创建一个读取器reader = tf.TFRecordReader()_, serialized_example = reader.read(filename_queue)# **3.根据你写入的格式对应说明读取的格式features = tf.parse_single_example(serialized_example,                                   features={                                       'X': tf.FixedLenFeature([2], tf.float32),  # 注意如果不是标量,需要说明数组长度                                       'y': tf.FixedLenFeature([], tf.int64)}     # 而标量就不用说明                                   )X_out = features['X']y_out = features['y']print(X_out)print(y_out)# **4.通过 tf.train.shuffle_batch 或者 tf.train.batch 函数读取数据"""在shuffle_batch 函数中,有几个参数的作用如下:capacity: 队列的容量,容量越大的话,shuffle 得就更加均匀,但是占用内存也会更多num_threads: 读取进程数,进程越多,读取速度相对会快些,根据个人配置决定min_after_dequeue: 保证队列中最少的数据量。   假设我们设定了队列的容量C,在我们取走部分数据m以后,队列中只剩下了 (C-m) 个数据。然后队列会不断补充数据进来,   如果后勤供应(CPU性能,线程数量)补充速度慢的话,那么下一次取数据的时候,可能才补充了一点点,如果补充完后的数据个数少于   min_after_dequeue 的话,不能取走数据,得继续等它补充超过 min_after_dequeue 个样本以后才让取走数据。   这样做保证了队列中混着足够多的数据,从而才能保证 shuffle 取值更加随机。   但是,min_after_dequeue 不能设置太大,否则补充时间很长,读取速度会很慢。"""X_batch, y_batch = tf.train.shuffle_batch([X_out, y_out], batch_size=2,                                          capacity=200, min_after_dequeue=100, num_threads=2)sess = tf.Session()init = tf.global_variables_initializer()sess.run(init)# **5.启动队列进行数据读取# 下面的 coord 是个线程协调器,把启动队列的时候加上线程协调器。# 这样,在数据读取完毕以后,调用协调器把线程全部都关了。coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess, coord=coord)y_outputs = list()for i in xrange(5):    _X_batch, _y_batch = sess.run([X_batch, y_batch])    print('** batch %d' % i)    print('_X_batch:', _X_batch)    print('_y_batch:', _y_batch)    y_outputs.extend(_y_batch.tolist())print(y_outputs)# **6.最后记得把队列关掉coord.request_stop()coord.join(threads)
原创粉丝点击