tensorflow系列(4)tfrecords的使用
来源:互联网 发布:淘宝新品怎么做爆款 编辑:程序博客网 时间:2024/05/17 04:30
利用tfrecords文件格式来高效读取数据吧。
(原文发表在我的博客,欢迎访问
0x00.前言
最近涉及到模型的测试,需要读取数据。通常我们可以直接通过文件读取,自己写一个,方便快捷。但是考虑到要在集群上运行,这个数据文件(csv,约1.4G)在每台服务器上都要有,于是我想事先处理一下数据,看看能不能压缩或是缩小文件大小。于是我想到使用tfrecords这种文件格式。
这篇文章给我一种帮助文档的感觉,所以我没那么想发。考虑到关于tfrecords没有很多文档,网上的文章也是一篇文章模子刻出来的,这里我记录一下过程。
0x01.保存为tfrecords文件
将数据保存为tfrecords文件可以视为这样一个流程:提取features -> 保存为Example结构对象 -> TFRecordWriter写入文件
1.提取features
Features的源码在https://github.com/tensorflow/tensorflow/blob/r1.2/tensorflow/core/example/feature.proto,从源码我们可以看到features总共有3种数据类型,分别是bytes
、float
、int64
。
要想将数据类型提取出features,我们首先要将数据转化为feature(即上面三种数据类型
对应类型,封装成如下三个函数:
def byte_feature(value): return tf.train.Feature( bytes_list = tf.train.BytesList(value=[value.encode()]) )def float_feature(value): return tf.train.Feature( float_list = tf.train.FloatList(value=[value]) )def int_feature(value): return tf.train.Feature( int64_list = tf.train.Int64List(value=[value]) )
其中,byte_feature
中的value.encode()
位置,当我不带.encode()
时会报错TypeError: 'AAB0162' has type str, but expected one of: bytes
。如果输入的是字符串类型,也可以这样:a = b'AAB0162'
。
函数构造好后,我们开始构造features,这里为示例。
features = tf.train.Features(feature={ 'name': byte_feature(n), 'time': float_feature(t), 'data': byte_feature(d) })
2.保存为Example结构对象
example = tf.train.Example(features=features)
3.TFRecordWriter写入文件
# writer对象destination = 'data.tfrecords'writer = tf.python_io.TFRecordWriter(destination)# 写入writer.write(example.SerializeToString())#关闭writer.close()
4.完整代码
# 首先封装好三个函数def byte_feature(value): return tf.train.Feature( bytes_list = tf.train.BytesList(value=[value.encode()]) )def float_feature(value): return tf.train.Feature( float_list = tf.train.FloatList(value=[value]) )def int_feature(value): return tf.train.Feature( int64_list = tf.train.Int64List(value=[value]) )# 定义好文件名source = 'test.csv'destination = 'data.tfrecords'# 创建对象writer = tf.python_io.TFRecordWriter(destination)# 打开文件reader = open(source)# 循环,按行处理for line in reader.readlines(): # 对行的提取 n, t, d = deal(line) # features构造 features = tf.train.Features(feature={ 'name': byte_feature(n), 'time': float_feature(t), 'data': byte_feature(d) }) # 构造example对象 及 写入 example = tf.train.Example(features=features) writer.write(example.SerializeToString())# 关闭指针writer.close()reader.close()
0x02.从tfrecords读取数据
这里分为两部,首先我们从文件中取出features,取出后的类型为Tensor,之后我们使用sess将数据还原。
1.取出features
首先定义一会要用的参数:
# tfrecords文件file = 'data.tfrecords'# 线程数量num_threads = 2num_epochs = 100# 每批次数量batch_size = 10# 样本数量下限min_after_dequeue = 10
首先定义reader:
reader = tf.TFRecordReader()
定义输入部分:
file_queue = tf.train.string_input_producer(file, num_epochs=num_epochs,)
读取:
_, example = reader.read(file_queue)
提取出features,并保存为列表:
features_dict = tf.parse_single_example(example, features={ 'name': tf.FixedLenFeature([], tf.string), 'time': tf.FixedLenFeature([], tf.float32), 'data': tf.FixedLenFeature([], tf.string) })n = features_dict['name']t = features_dict['time']d = features_dict['data']
之后我们将其转化为批次队列:
n, t, d = tf.train.shuffle_batch( [n, t, d], batch_size=batch_size, num_threads=num_threads, capacity = min_after_dequeue + 3 * batch_size, min_after_dequeue = min_after_dequeue )
2.数据还原
定义session,之后将数据还原:
with tf.Session() as sess: # 初始化 tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()).run() tf.train.start_queue_runners(sess=sess) a_val, b_val, c_val = sess.run([n, t, d]) print(a_val, b_val, c_val)
3.完整代码
(这里我将features的提取封装为了函数
def records_to(file, num_threads=2, num_epochs=2, batch_size=2, min_after_dequeue=2): reader = tf.TFRecordReader() file_queue = tf.train.string_input_producer(file, num_epochs=num_epochs,) _, example = reader.read(file_queue) features_dict = tf.parse_single_example(example, features={ 'name': tf.FixedLenFeature([], tf.string), 'time': tf.FixedLenFeature([], tf.float32), 'data': tf.FixedLenFeature([], tf.string) }) # n: Tensor("ParseSingleExample/Squeeze_name:0", shape=(), dtype=string) n = features_dict['name'] t = features_dict['time'] d = features_dict['data'] n, t, d = tf.train.shuffle_batch( [n, t, d], batch_size=batch_size, num_threads=num_threads, capacity = min_after_dequeue + 3 * batch_size, min_after_dequeue = min_after_dequeue ) # 数据格式为Tensor return n, t, ddef train(): n, t, d = records_to(['data.tfrecords']) with tf.Session() as sess: tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()).run() tf.train.start_queue_runners(sess=sess) a_val, b_val, c_val = sess.run([n, t, d]) print(a_val, b_val, c_val)
4.不构造图的读取
这一种方法比较简单,也不需要构造图,可以看作是我们把数据写入tfrecords的逆过程。
record_iterator = tf.python_io.tf_record_iterator(path=file_name)for string_record in record_iterator: example = tf.train.Example() example.ParseFromString(string_record) n = example.features.feature['name'].bytes_list.value[0] t = int(example.features.feature['time'].float_list.value[0]) d = (example.features.feature['data'].bytes_list.value[0]) print(n, t, d)
0x03.参考
TensorFlow学习记录– 7.TensorFlow高效读取数据之tfrecord详细解读
Tfrecords Guide
- tensorflow系列(4)tfrecords的使用
- tensorflow:使用tfrecords时的注意事项
- tensorflow中tfrecords使用介绍
- TensorFlow .tfrecords训练文件生成、使用
- tensorflow中tfrecords格式的读写
- TensorFLow 不同大小图片的TFrecords存取
- Tensorflow 生成tfrecords
- TensorFlow读取tfrecords数据
- Tensorflow 构建 TFrecords
- Tensorflow之构建自己的图片数据集TFrecords
- Tensorflow构建自己的图片数据集TFrecords
- Tensorflow之构建自己的图片数据集TFrecords(精)
- Tensorflow之构建自己的图片数据集TFrecords
- 机器学习: TensorFlow 的数据读取与TFRecords 格式
- Tensorflow构建自己的图片数据集TFrecords
- tensorflow中tfrecords文件的save和read
- tensorflow官网Cifar-10改为自己的TFRecords数据集
- 利用Tensorflow构建自己的图片数据集TFrecords
- matlab快速入门2——数据载入与保存
- python django日志器的使用及配置
- beego使用orm插入大量数据,回滚报错:buffer busy
- ArrayList概念及手写代码
- ThreadLocal原理
- tensorflow系列(4)tfrecords的使用
- git merge,rebase和*(no branch)
- java的Map和Map.Entry
- 837C_Two Seals
- 【HDU
- [caioj]1492: 基于连通性状态压缩的动态规划问题:Pipes
- HDU
- Eclipse如何修改指定项目的JDK版本
- SSM框架搭建配置问题(1)--------spring和Mybatis整合包版本冲突