TensorFlow基础5:TFRecords文件的存储与读取讲解及代码实现
来源:互联网 发布:网络系统管理 一建 编辑:程序博客网 时间:2024/06/06 07:49
上篇文章中我梳理一下在TensorFlow中几种不同类型数据读取的流程,但是没有具体说到TFRecords这种文件类型,这篇文章就来具体梳理这一文件格式。
TFRecords是TensorFlow中的设计的一种内置的文件格式,它是一种二进制文件,优点有如下几种:
- 统一不同输入文件的框架
- 它是更好的利用内存,更方便复制和移动(TFRecord压缩的二进制文件, protocal buffer序列化)
- 是用于将二进制数据和标签(训练的类别标签)数据存储在同一个文件中
一、TFRecords存储
在将其他数据存储为TFRecords文件的时候,需要经过两个步骤:
- 建立TFRecord存储器
- 构造每个样本的Example模块
1、建立TFRecord存储器
tf.python_io.TFRecordWriter(path)
- 写入tfrecords文件
- path : TFRecords文件的路径
- return : 写文件
- 方法:
- write(record):向文件中写入一个字符串记录(即一个样本)
- close() : 关闭文件写入器
注:此处的字符串为一个序列化的Example,通过Example.SerializeToString()
来实现,它的作用是将Example中的map压缩为二进制,节约大量空间。
2、构造每个样本的Example协议块
message Example { Features features = 1;};message Features { map<string, Feature> feature = 1;};message Feature { oneof kind { BytesList bytes_list = 1; FloatList float_list = 2; Int64List int64_list = 3; }};
上面这段代码即为Example协议块的规则,详解如下:
(1)tf.train.Example(features = None)
- 写入tfrecords文件
- features : tf.train.Features类型的特征实例
- return : example协议格式块
(2)tf.train.Features(feature = None)
- 构造每个样本的信息键值对
- feature : 字典数据,key为要保存的名字,value为tf.train.Feature实例
- return : Features类型
(3)tf.train.Feature(**options)
options可以选择如下三种格式数据:
bytes_list = tf.train.BytesList(value = [Bytes])
int64_list = tf.train.Int64List(value = [Value])
float_list = tf.trian.FloatList(value = [Value])
(4)将图片数据转化为TFRecords的例子:
对每一个样本,都做如下的处理:
example = tf.train.Example(feature = tf.train.Features(feature = { "image":tf.train.Feature(bytes_list=tf.train.BytesList(value=[image(bytes)])) "label":tf.train.Feature(int64_list=tf.train.Int64List(value=[label(int)])) }))
二、TFRecords读取方法
1.流程:
和文件阅读器的流程基本相同,只是中间多了一步解析过程
2.解析TFRecords的example协议内存块:
(1)tf.parse_single_example(serialized,features=None,name= None
- 解析一个单一的Example原型
- serialized : 标量字符串的Tensor,一个序列化的Example,文件经过文件阅读器之后的value
- features :字典数据,key为读取的名字,value为FixedLenFeature
- return : 一个键值对组成的字典,键为读取的名字
(2)tf.FixedLenFeature(shape,dtype)
- shape : 输入数据的形状,一般不指定,为空列表
- dtype : 输入数据类型,与存储进文件的类型要一致,类型只能是float32,int 64, string
- return : Tensor (即使有零的部分也存储)
(3)上面(1)中features中的value还可以为tf.VarLenFeature()
,但是这种方式用的比较少,它返回的是SparseTensor数据,这是一种只存储非零部分的数据格式,了解即可。
三、代码实现
1.将CSV文件转化为TFRecords文件
import tensorflow as tfimport numpy as npimport pandas as pdtrain_frame = pd.read_csv("train.csv")print(train_frame.head())train_labels_frame = train_frame.pop(item="label")train_values = train_frame.valuestrain_labels = train_labels_frame.valuesprint("values shape: ", train_values.shape)print("labels shape:", train_labels.shape)writer = tf.python_io.TFRecordWriter("csv_train.tfrecords")for i in range(train_values.shape[0]): image_raw = train_values[i].tostring() example = tf.train.Example( features=tf.train.Features( feature={ "image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])), "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[train_labels[i]])) } ) ) writer.write(record=example.SerializeToString())writer.close()
2.将图片文件转化为TFRecords文件
import matplotlib.pyplot as pltimport matplotlib.image as mpimgimport numpy as npimport tensorflow as tfimport pandas as pddef get_label_from_filename(filename): return 1filenames = tf.train.match_filenames_once('.\data\*.png')writer = tf.python_io.TFRecordWriter('png_train.tfrecords')for filename in filenames: img=mpimg.imread(filename) print("{} shape is {}".format(filename, img.shape)) img_raw = img.tostring() label = get_label_from_filename(filename) example = tf.train.Example( features=tf.train.Features( feature={ "image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])), "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])) } ) ) writer.write(record=example.SerializeToString())writer.close()
3.将二进制文件转化为TFRecords文件
"""读取二进制文件转换成张量,写进TFRecords,同时读取TFRcords"""#命令行参数FLAGS = tf.app.flags.FLAGS #获取值tf.app.flags.DEFINE_string("tfrecord_dir","./tmp/cifar10.tfrecords","写入图片数据文件的文件名")#读取二进制转换文件class CifarRead(object): """ 读取二进制文件转换成张量,写进TFRecords,同时读取TFRcords """ def __init__(self,file_list): """ 初始化图片参数 :param file_list:图片的路径名称列表 """ #文件列表 self.file_list = file_list #图片大小,二进制文件字节数 self.height = 32 self.width = 32 self.channel = 3 self.label_bytes = 1 self.image_bytes = self.height * self.width * self.channel self.bytes = self.label_bytes + self.image_bytes def read_and_decode(self): """ 解析二进制文件到张量 :return: 批处理的image,label张量 """ #1.构造文件队列 file_queue = tf.train.string_input_producer(self.file_list) #2.阅读器读取内容 reader = tf.FixedLengthRecordReader(self.bytes) key ,value = reader.read(file_queue) #key为文件名,value为元组 print(value) #3.进行解码,处理格式 label_image = tf.decode_raw(value,tf.uint8) print(label_image) #处理格式,image,label #进行切片处理,标签值 #tf.cast()函数是转换数据格式,此处是将label二进制数据转换成int32格式 label = tf.cast(tf.slice(label_image,[0],[self.label_bytes]),tf.int32) #处理图片数据 image = tf.slice(label_image,[self.label_bytes],[self.image_bytes]) print(image) #处理图片的形状,提供给批处理 #因为image的形状已经固定,此处形状用动态形状来改变 image_tensor = tf.reshape(image,[self.height,self.width,self.channel]) print(image_tensor) #批处理图片数据 image_batch,label_batch = tf.train.batch([image_tensor,label],batch_size=10,num_threads=1,capacity=10) return image_batch,label_batch def write_to_tfrecords(self,image_batch,label_batch): """ 将文件写入到TFRecords文件中 :param image_batch: :param label_batch: :return: """ #建立TFRecords文件存储器 writer = tf.python_io.TFRecordWriter(FLAGS.tfrecord_dir) #传进去命令行参数 #循环取出每个样本的值,构造example协议块 for i in range(10): #取出图片的值, #写进去的是值,而不是tensor类型, # 写入example需要bytes文件格式,将tensor转化为bytes用tostring()来转化 image = image_batch[i].eval().tostring() #取出标签值,写入example中需要使用int形式,所以需要强制转换int label = int(label_batch[i].eval()[0]) #构造每个样本的example协议块 example = tf.train.Example(features = tf.train.Features(feature = { "image":tf.train.Feature(bytes_list = tf.train.BytesList(value = [image])), "label":tf.train.Feature(int64_list = tf.train.Int64List(value = [label])) })) #写进去序列化后的值 writer.write(example.SerializeToString()) #此处其实是将其压缩成一个二进制数据 writer.close() return None def read_from_tfrecords(self): """ 从TFRecords文件当中读取图片数据(解析example) :param self: :return: image_batch,label_batch """ #1.构造文件队列 file_queue = tf.train.string_input_producer([FLAGS.tfrecord_dir]) #参数为文件名列表 #2.构造阅读器 reader = tf.TFRecordReader() key,value = reader.read(file_queue) #3.解析协议块,返回的值是字典 feature = tf.parse_single_example(value,features={ "image":tf.FixedLenFeature([],tf.string), "label":tf.FixedLenFeature([],tf.int64) }) #feature["image"],feature["label"] #处理标签数据 ,cast()只能在int和float之间进行转换 label = tf.cast(feature["label"],tf.int32) #将数据类型int64 转换为int32 #处理图片数据,由于是一个string,要进行解码, #将字节转换为数字向量表示,字节为一字符串类型的张量 #如果之前用了tostring(),那么必须要用decode_raw()转换为最初的int类型 # decode_raw()可以将数据从string,bytes转换为int,float类型的 image = tf.decode_raw(feature["image"],tf.uint8) #转换图片的形状,此处需要用动态形状进行转换 image_tensor = tf.reshape(image,[self.height,self.width,self.channel]) #4.批处理 image_batch,label_batch = tf.train.batch([image_tensor,label],batch_size=10,num_threads=1,capacity=10) return image_batch,label_batchif __name__ == '__main__': # 找到文件路径,名字,构造路径+文件名的列表,"A.csv"... # os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表 filename = os.listdir('./data/cifar10/cifar-10-batches-bin/') #加上路径 file_list = [os.path.join('./data/cifar10/cifar-10-batches-bin/', file) for file in filename if file[-3:] == "bin"] #初始化参数 cr = CifarRead(file_list) #读取二进制文件 # image_batch,label_batch = cr.read_and_decode() #从已经存储的TFRecords文件中解析出原始数据 image_batch, label_batch = cr.read_from_tfrecords() with tf.Session() as sess: #线程协调器 coord = tf.train.Coordinator() #开启线程 threads = tf.train.start_queue_runners(sess,coord=coord) print(sess.run([image_batch,label_batch])) # print("存进TFRecords文件") # cr.write_to_tfrecords(image_batch,label_batch) # print("存进文件完毕") #回收线程 coord.request_stop() coord.join(threads)
注:
上段代码分为两个部分:
- 第一部分是被注释掉的几行代码,表示的是将二进制文件转化为张量,并经过Example协议存储到TFRecords文件当中;
- 第二部分是从已经存储好数据信息的TFRecords文件中,经过解析,转化为最初的二进制文件。
- TensorFlow基础5:TFRecords文件的存储与读取讲解及代码实现
- TensorFlow基础4:四种类型数据的读取流程及API讲解和代码实现
- 机器学习: TensorFlow 的数据读取与TFRecords 格式
- TensorFlow读取tfrecords数据
- tensorflow数据读取之tfrecords
- Tensorflow分批量读取tfrecords
- TFRecords 文件的生成和读取
- 7.2 TensorFlow笔记(基础篇): 生成TFRecords文件
- [TFRecord格式数据]利用TFRecords存储与读取带标签的图片
- tensorflow中tfrecords文件的save和read
- tensorflow将CSV文件转为TFrecords文件
- TensorFlow .tfrecords训练文件生成、使用
- 文件的存储与读取
- tensorflow:使用tfrecords时的注意事项
- tensorflow系列(4)tfrecords的使用
- tensorflow中tfrecords格式的读写
- TensorFLow 不同大小图片的TFrecords存取
- 用tensorflow训练自己的图片集-用TFRecords将代码导入神经网络
- phpstrom注册
- HDOJ1040 As Easy As A+B
- 点击事件的event的应用
- MAC下IDEA连接MySQL数据库
- Tensorflow学习笔记:Debugging 调试Tensorflow 程序
- TensorFlow基础5:TFRecords文件的存储与读取讲解及代码实现
- 树莓派
- Cannot start container web: iptables failed: iptables -t nat -A DOCKER -p tcp -d 0/0 --dport 32797
- 结构化机器学习项目Quiz1
- Python爬虫实战:抓取淘宝MM照片
- LCD字体
- 操作系统引论
- HbaseSchool 一套包含了HbaseGo和HbaseTo的Java Hbase数据操作快速开发框架
- OkHttp框架读书总结笔记