TensorFlow基础5:TFRecords文件的存储与读取讲解及代码实现

来源:互联网 发布:网络系统管理 一建 编辑:程序博客网 时间:2024/06/06 07:49

上篇文章中我梳理一下在TensorFlow中几种不同类型数据读取的流程,但是没有具体说到TFRecords这种文件类型,这篇文章就来具体梳理这一文件格式。

TFRecords是TensorFlow中的设计的一种内置的文件格式,它是一种二进制文件,优点有如下几种:

  • 统一不同输入文件的框架
  • 它是更好的利用内存,更方便复制和移动(TFRecord压缩的二进制文件, protocal buffer序列化)
  • 是用于将二进制数据和标签(训练的类别标签)数据存储在同一个文件中

一、TFRecords存储

在将其他数据存储为TFRecords文件的时候,需要经过两个步骤:

  • 建立TFRecord存储器
  • 构造每个样本的Example模块

1、建立TFRecord存储器

tf.python_io.TFRecordWriter(path)

  • 写入tfrecords文件
  • path : TFRecords文件的路径
  • return : 写文件
  • 方法:
    • write(record):向文件中写入一个字符串记录(即一个样本)
    • close() : 关闭文件写入器

注:此处的字符串为一个序列化的Example,通过Example.SerializeToString()来实现,它的作用是将Example中的map压缩为二进制,节约大量空间。

2、构造每个样本的Example协议块

message Example {  Features features = 1;};message Features {  map<string, Feature> feature = 1;};message Feature {  oneof kind {    BytesList bytes_list = 1;    FloatList float_list = 2;    Int64List int64_list = 3;  }};

上面这段代码即为Example协议块的规则,详解如下:
(1)tf.train.Example(features = None)

  • 写入tfrecords文件
  • features : tf.train.Features类型的特征实例
  • return : example协议格式块

(2)tf.train.Features(feature = None)

  • 构造每个样本的信息键值对
  • feature : 字典数据,key为要保存的名字,value为tf.train.Feature实例
  • return : Features类型

(3)tf.train.Feature(**options)
options可以选择如下三种格式数据:

  • bytes_list = tf.train.BytesList(value = [Bytes])
  • int64_list = tf.train.Int64List(value = [Value])
  • float_list = tf.trian.FloatList(value = [Value])

(4)将图片数据转化为TFRecords的例子:
对每一个样本,都做如下的处理:

example = tf.train.Example(feature = tf.train.Features(feature = {                            "image":tf.train.Feature(bytes_list=tf.train.BytesList(value=[image(bytes)]))                             "label":tf.train.Feature(int64_list=tf.train.Int64List(value=[label(int)]))    }))

二、TFRecords读取方法

1.流程:

和文件阅读器的流程基本相同,只是中间多了一步解析过程

2.解析TFRecords的example协议内存块:

(1)tf.parse_single_example(serialized,features=None,name= None

  • 解析一个单一的Example原型
  • serialized : 标量字符串的Tensor,一个序列化的Example,文件经过文件阅读器之后的value
  • features :字典数据,key为读取的名字,value为FixedLenFeature
  • return : 一个键值对组成的字典,键为读取的名字

(2)tf.FixedLenFeature(shape,dtype)

  • shape : 输入数据的形状,一般不指定,为空列表
  • dtype : 输入数据类型,与存储进文件的类型要一致,类型只能是float32,int 64, string
  • return : Tensor (即使有零的部分也存储)

(3)上面(1)中features中的value还可以为tf.VarLenFeature(),但是这种方式用的比较少,它返回的是SparseTensor数据,这是一种只存储非零部分的数据格式,了解即可。

三、代码实现

1.将CSV文件转化为TFRecords文件

import tensorflow as tfimport numpy as npimport pandas as pdtrain_frame = pd.read_csv("train.csv")print(train_frame.head())train_labels_frame = train_frame.pop(item="label")train_values = train_frame.valuestrain_labels = train_labels_frame.valuesprint("values shape: ", train_values.shape)print("labels shape:", train_labels.shape)writer = tf.python_io.TFRecordWriter("csv_train.tfrecords")for i in range(train_values.shape[0]):    image_raw = train_values[i].tostring()    example = tf.train.Example(        features=tf.train.Features(            feature={                "image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])),                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[train_labels[i]]))            }        )    )    writer.write(record=example.SerializeToString())writer.close()

2.将图片文件转化为TFRecords文件

import matplotlib.pyplot as pltimport matplotlib.image as mpimgimport numpy as npimport tensorflow as tfimport pandas as pddef get_label_from_filename(filename):    return 1filenames = tf.train.match_filenames_once('.\data\*.png')writer = tf.python_io.TFRecordWriter('png_train.tfrecords')for filename in filenames:    img=mpimg.imread(filename)    print("{} shape is {}".format(filename, img.shape))    img_raw = img.tostring()    label = get_label_from_filename(filename)    example = tf.train.Example(        features=tf.train.Features(            feature={                "image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])),                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))            }        )    )    writer.write(record=example.SerializeToString())writer.close()

3.将二进制文件转化为TFRecords文件

"""读取二进制文件转换成张量,写进TFRecords,同时读取TFRcords"""#命令行参数FLAGS = tf.app.flags.FLAGS       #获取值tf.app.flags.DEFINE_string("tfrecord_dir","./tmp/cifar10.tfrecords","写入图片数据文件的文件名")#读取二进制转换文件class CifarRead(object):    """    读取二进制文件转换成张量,写进TFRecords,同时读取TFRcords    """    def __init__(self,file_list):        """        初始化图片参数        :param file_list:图片的路径名称列表        """        #文件列表        self.file_list = file_list        #图片大小,二进制文件字节数        self.height = 32        self.width = 32        self.channel = 3        self.label_bytes = 1        self.image_bytes = self.height * self.width * self.channel        self.bytes = self.label_bytes + self.image_bytes    def read_and_decode(self):        """        解析二进制文件到张量        :return: 批处理的image,label张量        """        #1.构造文件队列        file_queue = tf.train.string_input_producer(self.file_list)        #2.阅读器读取内容        reader = tf.FixedLengthRecordReader(self.bytes)        key ,value = reader.read(file_queue)    #key为文件名,value为元组        print(value)        #3.进行解码,处理格式        label_image = tf.decode_raw(value,tf.uint8)        print(label_image)        #处理格式,image,label        #进行切片处理,标签值        #tf.cast()函数是转换数据格式,此处是将label二进制数据转换成int32格式        label = tf.cast(tf.slice(label_image,[0],[self.label_bytes]),tf.int32)        #处理图片数据        image = tf.slice(label_image,[self.label_bytes],[self.image_bytes])        print(image)        #处理图片的形状,提供给批处理        #因为image的形状已经固定,此处形状用动态形状来改变        image_tensor = tf.reshape(image,[self.height,self.width,self.channel])        print(image_tensor)        #批处理图片数据        image_batch,label_batch = tf.train.batch([image_tensor,label],batch_size=10,num_threads=1,capacity=10)        return image_batch,label_batch    def write_to_tfrecords(self,image_batch,label_batch):        """        将文件写入到TFRecords文件中        :param image_batch:        :param label_batch:        :return:        """        #建立TFRecords文件存储器        writer = tf.python_io.TFRecordWriter(FLAGS.tfrecord_dir)      #传进去命令行参数        #循环取出每个样本的值,构造example协议块        for i in range(10):            #取出图片的值,  #写进去的是值,而不是tensor类型,            # 写入example需要bytes文件格式,将tensor转化为bytes用tostring()来转化            image = image_batch[i].eval().tostring()            #取出标签值,写入example中需要使用int形式,所以需要强制转换int            label = int(label_batch[i].eval()[0])            #构造每个样本的example协议块            example = tf.train.Example(features = tf.train.Features(feature = {                "image":tf.train.Feature(bytes_list = tf.train.BytesList(value = [image])),                "label":tf.train.Feature(int64_list = tf.train.Int64List(value = [label]))            }))            #写进去序列化后的值            writer.write(example.SerializeToString())     #此处其实是将其压缩成一个二进制数据        writer.close()        return None    def read_from_tfrecords(self):        """        从TFRecords文件当中读取图片数据(解析example)        :param self:        :return: image_batch,label_batch        """        #1.构造文件队列        file_queue = tf.train.string_input_producer([FLAGS.tfrecord_dir])    #参数为文件名列表        #2.构造阅读器        reader = tf.TFRecordReader()        key,value = reader.read(file_queue)        #3.解析协议块,返回的值是字典        feature = tf.parse_single_example(value,features={            "image":tf.FixedLenFeature([],tf.string),            "label":tf.FixedLenFeature([],tf.int64)        })        #feature["image"],feature["label"]        #处理标签数据    ,cast()只能在int和float之间进行转换        label = tf.cast(feature["label"],tf.int32)    #将数据类型int64 转换为int32        #处理图片数据,由于是一个string,要进行解码,  #将字节转换为数字向量表示,字节为一字符串类型的张量        #如果之前用了tostring(),那么必须要用decode_raw()转换为最初的int类型        # decode_raw()可以将数据从string,bytes转换为int,float类型的        image = tf.decode_raw(feature["image"],tf.uint8)        #转换图片的形状,此处需要用动态形状进行转换        image_tensor = tf.reshape(image,[self.height,self.width,self.channel])        #4.批处理        image_batch,label_batch = tf.train.batch([image_tensor,label],batch_size=10,num_threads=1,capacity=10)        return image_batch,label_batchif __name__ == '__main__':    # 找到文件路径,名字,构造路径+文件名的列表,"A.csv"...    # os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表    filename = os.listdir('./data/cifar10/cifar-10-batches-bin/')    #加上路径    file_list = [os.path.join('./data/cifar10/cifar-10-batches-bin/', file) for file in filename if file[-3:] == "bin"]    #初始化参数    cr = CifarRead(file_list)    #读取二进制文件    # image_batch,label_batch = cr.read_and_decode()    #从已经存储的TFRecords文件中解析出原始数据    image_batch, label_batch = cr.read_from_tfrecords()    with tf.Session() as sess:        #线程协调器        coord = tf.train.Coordinator()        #开启线程        threads = tf.train.start_queue_runners(sess,coord=coord)        print(sess.run([image_batch,label_batch]))        # print("存进TFRecords文件")        # cr.write_to_tfrecords(image_batch,label_batch)        # print("存进文件完毕")        #回收线程        coord.request_stop()        coord.join(threads)

注:
上段代码分为两个部分:

  • 第一部分是被注释掉的几行代码,表示的是将二进制文件转化为张量,并经过Example协议存储到TFRecords文件当中;
  • 第二部分是从已经存储好数据信息的TFRecords文件中,经过解析,转化为最初的二进制文件。
原创粉丝点击