TensorFlow基础4:四种类型数据的读取流程及API讲解和代码实现

来源:互联网 发布:js图片视频轮播 编辑:程序博客网 时间:2024/06/05 17:11

在上篇文章中梳理了数据读取的三种方式,但是在实际项目当中,由于数据量一般会比较大,所以更多的会使用第三种方法(即直接从文件中读取)。但是对于不同的文件类型,需要不同的文件处理API,有时候比较容易弄混淆,接下来就来梳理一下。

一.文件读取流程

这里写图片描述

如上图所示,展示了文件读取的大致流程。
最左边的A、B、C是存储于磁盘中文件,经过打乱文件之后(这里是默认的乱序读取,只是文件的顺序乱,但是文件内容不受影响),进入到文件队列中(Filename Queue)。文件队列当中的文件经过阅读器(Reader)处理,存储到内存当中。接下来对文件进行解码(Decode),解码之后进入样本队列当中进行批处理,此时经过批处理之后就可以用于模型训练了。

现在举例,对于读取CSV文件,大致要经历一下几步:
1. 找到文件,并构造文件的列表(一阶张量)
2. 构造文件队列
3. 读取文件内容
4. 解码CSV并读取内容
5. 开启会话运行,得出训练结果

二.文件读取的API

1.文件队列构造

tf.train.string_inout_producer(string_tensor,num_epochs,shuffle=True)

  • 将输出字符串(例如文件名)输入到管道队列
  • string_tensor:含有文件名的一阶张量,需要指定文件路径
  • num_epochs:将全部数据循环的次数
  • return:具有输出字符串的队列

2.文件阅读器

此时需要根据文件的格式,选择对应的文件阅读器

(1) 文本文件:tf.TextLineReader()

  • 读取文本文件,逗号分隔值(CSV)格式,默认按行读取
  • return:读取器实例

(2)二进制文件:tf.FixedLengthRecordReader(record_bytes)

  • 读取每个记录是固定数量字节的二进制文件
  • record_bytes:整型,指定每次读取的字节数
  • return:读取器实例

(3)图片文件:tf.WholeReader()

  • 将文件的全部内容作为值输出,即一次读取一整个文件
  • return:读取器实例

(4)TFRecords文件:tf.TFRecordReader()

  • 读取 TFRecords文件
  • return:读取器实例
注:这几种文件格式都有一个共同的读取方法:read(file_queue)
  • 从队列中指定内容数量
  • file_name : 文件队列
  • ruturn : 返回一个Tensor元组(key,value)
    • key : 文件名
    • value : 每次读取的值(一行文本、一张图片或指定字节的值)

3.文件内容解码器

由于从文件中读取的是字符串,需要函数去解析这些字符串,最后变换成张量
(1)CSV文件:
tf.decode_csv(records,record_defaults=None,field_delim=None,name=None)

  • 将CSV文件转换成张量,需要tf.TextLineReader()搭配使用
  • records : tensor型字符串,每个字符串是CSV中的记录行(即value值)
  • record_defaults : 此参数决定了所得张量的类型,并设置一个值,如果在输入字符串中缺少则使用默认值,如[[1],[1]] 或者[[“None”],[“None”]]
  • field_dim : 默认分隔符“ ,”

(2)二进制文件:
tf.decode_raw(bytes,out_type,little_endian=None,name=None)

  • 将字节转换为一个数字向量表示,字节为以字符串类型的张量
  • 与函数tf.FixedLengthRecordReader搭配使用
  • 将二进制转换为uint8格式

(3)图像文件:

  • 1)tf.image.decode_jpeg(contens)

    • 将JPEG编码的图像解码为uint8张量
    • return : uint8张量,3-D形状[height,width,channels]
  • 2) tf.image.decode_png(contents)

    • 将PNG编码的图像解码为uint8或者uint16编码
    • return : 张量类型,3-D形状[height,width,channels]

(4)TFRecords文件:
TFRecords文件是TensorFlow中的统一格式,它的存储和读取方式较为复杂,我会在下篇文章中单独来梳理这部分的内容。

4.批处理数据

对数据进行批处理需要在会话开启之前进行
(1)tf.train.batch(tensors,batch_size,num_threads=1,capacity=32,name=None)

  • 读取指定大小(个数)的张量
  • tensor : 包含张量的列表
  • batch_size : 从队列中读取的批处理数据大小
  • num_threads : 进入队列的线程数
  • capacity : 整数,批处理队列中元素的最大数量
  • teturn : tensors

(2)tf.train.shuffle_batch(tensors,batch_size,capacity,min_after_dequeue,num_threads=1,capacity=32,name=None)

  • 乱序读取指定大小(数量)的张量
  • min_after_dequeue : 留下队列里的张量个数,能够保持随机打乱

三.示例代码

1.CSV文件读取案例

def csvread(filelist):    """    CSV文件读取    :param filelist: 文件的列表(1阶张量)    :return:None    """    #2.构造文件的队列    file_queue = tf.train.string_input_producer(filelist)    #3.读取文件内容tf.decode_csv()    #构造阅读器    reader = tf.TextLineReader()    #读队列文件内容,一行    key,value = reader.read(file_queue)    #4、解码csv文件    #指定每一行格式的默认值,类型,[[1],[2.0],[1]]    records = [["None"],["None"]]    example,label = tf.decode_csv(value,record_defaults=records)    #批处理读取数据    example_batch,label_batch = tf.train.batch([example,label],batch_size=20,num_threads=1,capacity=100)    #5、会话运行结果    with tf.Session() as sess:        #开启线程协调器        coord = tf.train.Coordinator()        #创建子线程去进行操作,返回线程列表        threads = tf.train.start_queue_runners(sess,coord = coord)        #打印        print(sess.run([example_batch,label_batch]))        #回收        coord.request_stop()   #强制请求线程停止        coord.join(threads)    #等待线程终止回收    return Noneif __name__ == '__main__':    #列出文件目录,构造路径+文件名的列表,"A.csv"...    # os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表    filename = os.listdir('./data/csvdata')    #加上路径    file_list = [os.path.join('./data/csvdata', file) for file in filename]    csvread(file_list)

2.图片文件读取案例

./data/dog文件中存储了100张 *.jpg格式的狗的图片

def picread(file_list):    """    读取狗图片并转换成张量    :param file_list:    :return:    """    #1、构造文件的队列    file_queue = tf.train.string_input_producer(file_list)    #2、生成图片读取器,读取队列内容    reader = tf.WholeFileReader()   #返回读取器实例    key ,value = reader.read(file_queue)    print(key,value)    #3.进行图片的解码    image = tf.image.decode_jpeg(value)    print(image)    #4.处理图片的大小    image_resize = tf.image.resize_images(image,[256,256])    print(image_resize)    #设置静态形状   ,动态形状也可以    image_resize.set_shape([256,256,3])    print(image_resize)    #5.进行批处理                  #此处image_siez必须指定形状,而且要为列表    image_batch = tf.train.batch([image_resize],batch_size=100,num_threads=1,capacity=100)    print(image_batch)    return image_batchif __name__ == '__main__':    # 找到文件路径,名字,构造路径+文件名的列表,"A.csv"...    # os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表    filename = os.listdir('./data/dog')    #加上路径    file_list = [os.path.join('./data/dog', file) for file in filename]    image_batch = picread(file_list)    with tf.Session() as sess:        #定义线程协调器        coord = tf.train.Coordinator()        #开启线程        threads = tf.train.start_queue_runners(sess,coord=coord)        print(sess.run(image_batch))        #回收线程        coord.request_stop()        coord.join(threads)

3.二进制文件读取案例

此案例中数据是使用的下载好的二进制的cifar10数据

#读取二进制转换文件class CifarRead(object):    """    读取二进制文件转换成张量,写进TFRecords,同时读取TFRcords    """    def __init__(self,file_list):        """        初始化图片参数        :param file_list:图片的路径名称列表        """        #文件列表        self.file_list = file_list        #图片大小,二进制文件字节数        self.height = 32        self.width = 32        self.channel = 3        self.label_bytes = 1        self.image_bytes = self.height * self.width * self.channel        self.bytes = self.label_bytes + self.image_bytes    def read_and_decode(self):        """        解析二进制文件到张量        :return: 批处理的image,label张量        """        #1.构造文件队列        file_queue = tf.train.string_input_producer(self.file_list)        #2.阅读器读取内容        reader = tf.FixedLengthRecordReader(self.bytes)        key ,value = reader.read(file_queue)    #key为文件名,value为元组        print(value)        #3.进行解码,处理格式        label_image = tf.decode_raw(value,tf.uint8)        print(label_image)        #处理格式,image,label        #进行切片处理,标签值        #tf.cast()函数是转换数据格式,此处是将label二进制数据转换成int32格式        label = tf.cast(tf.slice(label_image,[0],[self.label_bytes]),tf.int32)        #处理图片数据        image = tf.slice(label_image,[self.label_bytes],[self.image_bytes])        print(image)        #处理图片的形状,提供给批处理        #因为image的形状已经固定,此处形状用动态形状来改变        image_tensor = tf.reshape(image,[self.height,self.width,self.channel])        print(image_tensor)        #批处理图片数据        image_batch,label_batch = tf.train.batch([image_tensor,label],batch_size=10,num_threads=1,capacity=10)        return image_batch,label_batchif __name__ == '__main__':    # 找到文件路径,名字,构造路径+文件名的列表,"A.csv"...    # os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表    filename = os.listdir('./data/cifar10/cifar-10-batches-bin/')    #加上路径    file_list = [os.path.join('./data/cifar10/cifar-10-batches-bin/', file) for file in filename if file[-3:] == "bin"]    #初始化参数    cr = CifarRead(file_list)    image_batch,label_batch = cr.read_and_decode()    with tf.Session() as sess:        #线程协调器        coord = tf.train.Coordinator()        #开启线程        threads = tf.train.start_queue_runners(sess,coord=coord)        print(sess.run([image_batch,label_batch]))        #回收线程        coord.request_stop()        coord.join(threads)

TFRecords文件是TensorFlow中的统一格式,它的存储和读取方式较上面三种格式要稍微复杂一些,我会在下篇文章中单独来梳理这部分的内容。

阅读全文
0 0
原创粉丝点击