tensorflow系列(4)tfrecords的使用

来源:互联网 发布:淘宝新品怎么做爆款 编辑:程序博客网 时间:2024/05/17 04:30

利用tfrecords文件格式来高效读取数据吧。

(原文发表在我的博客,欢迎访问

0x00.前言

最近涉及到模型的测试,需要读取数据。通常我们可以直接通过文件读取,自己写一个,方便快捷。但是考虑到要在集群上运行,这个数据文件(csv,约1.4G)在每台服务器上都要有,于是我想事先处理一下数据,看看能不能压缩或是缩小文件大小。于是我想到使用tfrecords这种文件格式。

这篇文章给我一种帮助文档的感觉,所以我没那么想发。考虑到关于tfrecords没有很多文档,网上的文章也是一篇文章模子刻出来的,这里我记录一下过程。

0x01.保存为tfrecords文件

将数据保存为tfrecords文件可以视为这样一个流程:提取features -> 保存为Example结构对象 -> TFRecordWriter写入文件

1.提取features

Features的源码在https://github.com/tensorflow/tensorflow/blob/r1.2/tensorflow/core/example/feature.proto,从源码我们可以看到features总共有3种数据类型,分别是bytesfloatint64

要想将数据类型提取出features,我们首先要将数据转化为feature(即上面三种数据类型

对应类型,封装成如下三个函数:

def byte_feature(value):    return tf.train.Feature(        bytes_list = tf.train.BytesList(value=[value.encode()])    )def float_feature(value):    return tf.train.Feature(        float_list = tf.train.FloatList(value=[value])    )def int_feature(value):    return tf.train.Feature(        int64_list = tf.train.Int64List(value=[value])    )

其中,byte_feature中的value.encode()位置,当我不带.encode()时会报错TypeError: 'AAB0162' has type str, but expected one of: bytes。如果输入的是字符串类型,也可以这样:a = b'AAB0162'

函数构造好后,我们开始构造features,这里为示例。

features = tf.train.Features(feature={            'name': byte_feature(n),            'time': float_feature(t),            'data': byte_feature(d)        })

2.保存为Example结构对象

example = tf.train.Example(features=features)

3.TFRecordWriter写入文件

# writer对象destination = 'data.tfrecords'writer = tf.python_io.TFRecordWriter(destination)# 写入writer.write(example.SerializeToString())#关闭writer.close()

4.完整代码

# 首先封装好三个函数def byte_feature(value):    return tf.train.Feature(        bytes_list = tf.train.BytesList(value=[value.encode()])    )def float_feature(value):    return tf.train.Feature(        float_list = tf.train.FloatList(value=[value])    )def int_feature(value):    return tf.train.Feature(        int64_list = tf.train.Int64List(value=[value])    )# 定义好文件名source = 'test.csv'destination = 'data.tfrecords'# 创建对象writer = tf.python_io.TFRecordWriter(destination)# 打开文件reader = open(source)# 循环,按行处理for line in reader.readlines():    # 对行的提取    n, t, d = deal(line)    # features构造    features = tf.train.Features(feature={        'name': byte_feature(n),        'time': float_feature(t),        'data': byte_feature(d)    })    # 构造example对象 及 写入    example = tf.train.Example(features=features)    writer.write(example.SerializeToString())# 关闭指针writer.close()reader.close()

0x02.从tfrecords读取数据

这里分为两部,首先我们从文件中取出features,取出后的类型为Tensor,之后我们使用sess将数据还原。

1.取出features

首先定义一会要用的参数:

# tfrecords文件file = 'data.tfrecords'# 线程数量num_threads = 2num_epochs = 100# 每批次数量batch_size = 10# 样本数量下限min_after_dequeue = 10

首先定义reader:

reader = tf.TFRecordReader()

定义输入部分:

file_queue = tf.train.string_input_producer(file, num_epochs=num_epochs,)

读取:

_, example = reader.read(file_queue)

提取出features,并保存为列表:

features_dict = tf.parse_single_example(example,     features={        'name': tf.FixedLenFeature([], tf.string),        'time': tf.FixedLenFeature([], tf.float32),        'data': tf.FixedLenFeature([], tf.string)    })n = features_dict['name']t = features_dict['time']d = features_dict['data']

之后我们将其转化为批次队列:

n, t, d = tf.train.shuffle_batch(    [n, t, d],    batch_size=batch_size,    num_threads=num_threads,    capacity = min_after_dequeue + 3 * batch_size,    min_after_dequeue = min_after_dequeue    )

2.数据还原

定义session,之后将数据还原:

with tf.Session() as sess:        # 初始化        tf.group(tf.global_variables_initializer(),            tf.local_variables_initializer()).run()        tf.train.start_queue_runners(sess=sess)        a_val, b_val, c_val = sess.run([n, t, d])        print(a_val, b_val, c_val)

3.完整代码

(这里我将features的提取封装为了函数

def records_to(file, num_threads=2, num_epochs=2, batch_size=2, min_after_dequeue=2):    reader = tf.TFRecordReader()    file_queue = tf.train.string_input_producer(file, num_epochs=num_epochs,)    _, example = reader.read(file_queue)    features_dict = tf.parse_single_example(example,         features={            'name': tf.FixedLenFeature([], tf.string),            'time': tf.FixedLenFeature([], tf.float32),            'data': tf.FixedLenFeature([], tf.string)        })    # n: Tensor("ParseSingleExample/Squeeze_name:0", shape=(), dtype=string)    n = features_dict['name']    t = features_dict['time']    d = features_dict['data']    n, t, d = tf.train.shuffle_batch(        [n, t, d],        batch_size=batch_size,        num_threads=num_threads,        capacity = min_after_dequeue + 3 * batch_size,        min_after_dequeue = min_after_dequeue    )    # 数据格式为Tensor    return n, t, ddef train():    n, t, d = records_to(['data.tfrecords'])    with tf.Session() as sess:        tf.group(tf.global_variables_initializer(),            tf.local_variables_initializer()).run()        tf.train.start_queue_runners(sess=sess)        a_val, b_val, c_val = sess.run([n, t, d])        print(a_val, b_val, c_val)

4.不构造图的读取

这一种方法比较简单,也不需要构造图,可以看作是我们把数据写入tfrecords的逆过程。

record_iterator = tf.python_io.tf_record_iterator(path=file_name)for string_record in record_iterator:    example = tf.train.Example()    example.ParseFromString(string_record)    n = example.features.feature['name'].bytes_list.value[0]    t = int(example.features.feature['time'].float_list.value[0])    d = (example.features.feature['data'].bytes_list.value[0])    print(n, t, d)

0x03.参考

TensorFlow学习记录– 7.TensorFlow高效读取数据之tfrecord详细解读

Tfrecords Guide

原创粉丝点击