TensorFlow基础4:四种类型数据的读取流程及API讲解和代码实现
来源:互联网 发布:js图片视频轮播 编辑:程序博客网 时间:2024/06/05 17:11
在上篇文章中梳理了数据读取的三种方式,但是在实际项目当中,由于数据量一般会比较大,所以更多的会使用第三种方法(即直接从文件中读取)。但是对于不同的文件类型,需要不同的文件处理API,有时候比较容易弄混淆,接下来就来梳理一下。
一.文件读取流程
如上图所示,展示了文件读取的大致流程。
最左边的A、B、C是存储于磁盘中文件,经过打乱文件之后(这里是默认的乱序读取,只是文件的顺序乱,但是文件内容不受影响),进入到文件队列中(Filename Queue)。文件队列当中的文件经过阅读器(Reader)处理,存储到内存当中。接下来对文件进行解码(Decode),解码之后进入样本队列当中进行批处理,此时经过批处理之后就可以用于模型训练了。
现在举例,对于读取CSV文件,大致要经历一下几步:
1. 找到文件,并构造文件的列表(一阶张量)
2. 构造文件队列
3. 读取文件内容
4. 解码CSV并读取内容
5. 开启会话运行,得出训练结果
二.文件读取的API
1.文件队列构造
tf.train.string_inout_producer(string_tensor,num_epochs,shuffle=True)
- 将输出字符串(例如文件名)输入到管道队列
string_tensor
:含有文件名的一阶张量,需要指定文件路径num_epochs
:将全部数据循环的次数return
:具有输出字符串的队列
2.文件阅读器
此时需要根据文件的格式,选择对应的文件阅读器
(1) 文本文件:tf.TextLineReader()
- 读取文本文件,逗号分隔值(CSV)格式,默认按行读取
- return:读取器实例
(2)二进制文件:tf.FixedLengthRecordReader(record_bytes)
- 读取每个记录是固定数量字节的二进制文件
- record_bytes:整型,指定每次读取的字节数
- return:读取器实例
(3)图片文件:tf.WholeReader()
- 将文件的全部内容作为值输出,即一次读取一整个文件
- return:读取器实例
(4)TFRecords文件:tf.TFRecordReader()
- 读取 TFRecords文件
- return:读取器实例
注:这几种文件格式都有一个共同的读取方法:read(file_queue)
- 从队列中指定内容数量
- file_name : 文件队列
- ruturn : 返回一个Tensor元组(key,value)
- key : 文件名
- value : 每次读取的值(一行文本、一张图片或指定字节的值)
3.文件内容解码器
由于从文件中读取的是字符串,需要函数去解析这些字符串,最后变换成张量
(1)CSV文件: tf.decode_csv(records,record_defaults=None,field_delim=None,name=None)
- 将CSV文件转换成张量,需要
tf.TextLineReader()
搭配使用 - records : tensor型字符串,每个字符串是CSV中的记录行(即value值)
- record_defaults : 此参数决定了所得张量的类型,并设置一个值,如果在输入字符串中缺少则使用默认值,如[[1],[1]] 或者[[“None”],[“None”]]
- field_dim : 默认分隔符“ ,”
(2)二进制文件: tf.decode_raw(bytes,out_type,little_endian=None,name=None)
- 将字节转换为一个数字向量表示,字节为以字符串类型的张量
- 与函数tf.FixedLengthRecordReader搭配使用
- 将二进制转换为uint8格式
(3)图像文件:
1)
tf.image.decode_jpeg(contens)
- 将JPEG编码的图像解码为uint8张量
- return : uint8张量,3-D形状[height,width,channels]
2)
tf.image.decode_png(contents)
- 将PNG编码的图像解码为uint8或者uint16编码
- return : 张量类型,3-D形状[height,width,channels]
(4)TFRecords文件:
TFRecords文件是TensorFlow中的统一格式,它的存储和读取方式较为复杂,我会在下篇文章中单独来梳理这部分的内容。
4.批处理数据
对数据进行批处理需要在会话开启之前进行
(1)tf.train.batch(tensors,batch_size,num_threads=1,capacity=32,name=None)
- 读取指定大小(个数)的张量
- tensor : 包含张量的列表
- batch_size : 从队列中读取的批处理数据大小
- num_threads : 进入队列的线程数
- capacity : 整数,批处理队列中元素的最大数量
- teturn : tensors
(2)tf.train.shuffle_batch(tensors,batch_size,capacity,min_after_dequeue,num_threads=1,capacity=32,name=None)
- 乱序读取指定大小(数量)的张量
- min_after_dequeue : 留下队列里的张量个数,能够保持随机打乱
三.示例代码
1.CSV文件读取案例
def csvread(filelist): """ CSV文件读取 :param filelist: 文件的列表(1阶张量) :return:None """ #2.构造文件的队列 file_queue = tf.train.string_input_producer(filelist) #3.读取文件内容tf.decode_csv() #构造阅读器 reader = tf.TextLineReader() #读队列文件内容,一行 key,value = reader.read(file_queue) #4、解码csv文件 #指定每一行格式的默认值,类型,[[1],[2.0],[1]] records = [["None"],["None"]] example,label = tf.decode_csv(value,record_defaults=records) #批处理读取数据 example_batch,label_batch = tf.train.batch([example,label],batch_size=20,num_threads=1,capacity=100) #5、会话运行结果 with tf.Session() as sess: #开启线程协调器 coord = tf.train.Coordinator() #创建子线程去进行操作,返回线程列表 threads = tf.train.start_queue_runners(sess,coord = coord) #打印 print(sess.run([example_batch,label_batch])) #回收 coord.request_stop() #强制请求线程停止 coord.join(threads) #等待线程终止回收 return Noneif __name__ == '__main__': #列出文件目录,构造路径+文件名的列表,"A.csv"... # os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表 filename = os.listdir('./data/csvdata') #加上路径 file_list = [os.path.join('./data/csvdata', file) for file in filename] csvread(file_list)
2.图片文件读取案例
./data/dog文件中存储了100张 *.jpg格式的狗的图片
def picread(file_list): """ 读取狗图片并转换成张量 :param file_list: :return: """ #1、构造文件的队列 file_queue = tf.train.string_input_producer(file_list) #2、生成图片读取器,读取队列内容 reader = tf.WholeFileReader() #返回读取器实例 key ,value = reader.read(file_queue) print(key,value) #3.进行图片的解码 image = tf.image.decode_jpeg(value) print(image) #4.处理图片的大小 image_resize = tf.image.resize_images(image,[256,256]) print(image_resize) #设置静态形状 ,动态形状也可以 image_resize.set_shape([256,256,3]) print(image_resize) #5.进行批处理 #此处image_siez必须指定形状,而且要为列表 image_batch = tf.train.batch([image_resize],batch_size=100,num_threads=1,capacity=100) print(image_batch) return image_batchif __name__ == '__main__': # 找到文件路径,名字,构造路径+文件名的列表,"A.csv"... # os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表 filename = os.listdir('./data/dog') #加上路径 file_list = [os.path.join('./data/dog', file) for file in filename] image_batch = picread(file_list) with tf.Session() as sess: #定义线程协调器 coord = tf.train.Coordinator() #开启线程 threads = tf.train.start_queue_runners(sess,coord=coord) print(sess.run(image_batch)) #回收线程 coord.request_stop() coord.join(threads)
3.二进制文件读取案例
此案例中数据是使用的下载好的二进制的cifar10数据
#读取二进制转换文件class CifarRead(object): """ 读取二进制文件转换成张量,写进TFRecords,同时读取TFRcords """ def __init__(self,file_list): """ 初始化图片参数 :param file_list:图片的路径名称列表 """ #文件列表 self.file_list = file_list #图片大小,二进制文件字节数 self.height = 32 self.width = 32 self.channel = 3 self.label_bytes = 1 self.image_bytes = self.height * self.width * self.channel self.bytes = self.label_bytes + self.image_bytes def read_and_decode(self): """ 解析二进制文件到张量 :return: 批处理的image,label张量 """ #1.构造文件队列 file_queue = tf.train.string_input_producer(self.file_list) #2.阅读器读取内容 reader = tf.FixedLengthRecordReader(self.bytes) key ,value = reader.read(file_queue) #key为文件名,value为元组 print(value) #3.进行解码,处理格式 label_image = tf.decode_raw(value,tf.uint8) print(label_image) #处理格式,image,label #进行切片处理,标签值 #tf.cast()函数是转换数据格式,此处是将label二进制数据转换成int32格式 label = tf.cast(tf.slice(label_image,[0],[self.label_bytes]),tf.int32) #处理图片数据 image = tf.slice(label_image,[self.label_bytes],[self.image_bytes]) print(image) #处理图片的形状,提供给批处理 #因为image的形状已经固定,此处形状用动态形状来改变 image_tensor = tf.reshape(image,[self.height,self.width,self.channel]) print(image_tensor) #批处理图片数据 image_batch,label_batch = tf.train.batch([image_tensor,label],batch_size=10,num_threads=1,capacity=10) return image_batch,label_batchif __name__ == '__main__': # 找到文件路径,名字,构造路径+文件名的列表,"A.csv"... # os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表 filename = os.listdir('./data/cifar10/cifar-10-batches-bin/') #加上路径 file_list = [os.path.join('./data/cifar10/cifar-10-batches-bin/', file) for file in filename if file[-3:] == "bin"] #初始化参数 cr = CifarRead(file_list) image_batch,label_batch = cr.read_and_decode() with tf.Session() as sess: #线程协调器 coord = tf.train.Coordinator() #开启线程 threads = tf.train.start_queue_runners(sess,coord=coord) print(sess.run([image_batch,label_batch])) #回收线程 coord.request_stop() coord.join(threads)
TFRecords文件是TensorFlow中的统一格式,它的存储和读取方式较上面三种格式要稍微复杂一些,我会在下篇文章中单独来梳理这部分的内容。
- TensorFlow基础4:四种类型数据的读取流程及API讲解和代码实现
- TensorFlow基础5:TFRecords文件的存储与读取讲解及代码实现
- API的四种类型
- TensorFlow 读取CSV数据代码实现
- TensorFlow读取CSV数据的实现
- 图像的四种类型及简述
- TensorFlow 学习(二) 制作自己的TFRecord数据集,读取,显示及代码详解
- TensorFlow全新的数据读取方式:Dataset API入门教程(转)
- TensorFlow全新的数据读取方式:Dataset API入门教程
- TensorFlow全新的数据读取方式:Dataset API入门教程
- TensorFlow全新的数据读取方式:Dataset API入门教程
- TensorFlow全新的数据读取方式:Dataset API入门教程
- TensorFlow全新的数据读取方式:Dataset API入门教程
- TensorFlow全新的数据读取方式:Dataset API入门教程
- TensorFlow全新的数据读取方式:Dataset API入门教程
- C++/面试 - 四种类型转换(cast)的关键字 详解 及 代码
- Tensorflow中padding的两种类型SAME和VALID
- Tensorflow中padding的两种类型SAME和VALID
- php将连续回车(换行)、空格正则替换为1个
- Java并发学习(八)-AtomicIntegerArray数组类型类
- Mybatis全部标签
- ES6/ES2015(下)
- 大数据篇_Hadoop入门介绍
- TensorFlow基础4:四种类型数据的读取流程及API讲解和代码实现
- 梯度下降算法实现
- 面向对象-接口 interface关键字
- HDFS NameNode重启优化
- [新手向] c++大数运算 (1)
- 网络优化之MobileNet
- 安卓智能地图开发与实施二十二:展示三维场景
- Unity3D
- ELF函数重定位问题