Tensorflow读取数据2-tfrecord

来源:互联网 发布:西部数码域名管理 编辑:程序博客网 时间:2024/05/16 05:11

原文地址: http://blog.csdn.net/u010911921/article/details/70991194

上篇博客谈到了Tensorflow从文件中读取数据,当时采用的是CIFAR-10中的二进制数据,这次记录一下官网推荐的比较通用和高效的数据文件类型的读取——TFRecord文件,这是tensorflow指定的标准格式。

1.TFRecords

TFRecords本质上是一种二进制文件,他的优点是可以更好的利用内存空间,缺点是生成过程比较耗费时间,特别是数据量比较大的情况下。文件包含了一个tf.train.Example的缓冲协议(protocol buffer)其中协议块中包含了字段Features.当用程序获得数据以后,就可以将其填充到Example的协议缓冲区(protocol buffer)中,然后在将协议缓冲区序列化为字符串,最后通过tf.python_io.TFRecordWriter将字符串写入文件。

当从TFRecords文件中读取数据时,可以利用tf.TFRecordReadertf.parse_single_example解码器,将Example缓冲协议中的内容解析为Tensor张量

2.notMNIST 数据集

在实验中采用的数据集合时notMNIST数据集,这个数据集合是由一些各种形态的字母组成的数据集合,总共由a~j10个字母组成,下图是a对应的一些图片:

另外需要注意的是,下载的数据集中有几张图片有损坏,所以处理的时候注意跳过。

3.生成TFRecords文件

为了生成TFRecords文件首先是从数据集中,将图片路径放置到一个image_list,样本的标签放置到一个label_list中。

#!/usr/bin/env python3# --*-- encoding:utf-8 --*--import tensorflow as tfimport numpy as npimport osimport matplotlib.pyplot as pltimport skimage.io as iodef get_file(file_dir):    """    get full image directory and correspond labels    :param file_dir:     :return:     """    images =[]    temp =[]    for root ,sub_folders,files in os.walk(file_dir):        #image directories        for name in files:            images.append(os.path.join(root,name))        #get 10 sub-folder names        for name in sub_folders:            temp.append(os.path.join(root,name))    labels =[]    for one_folder in temp:        n_img = len(os.listdir(one_folder))        letter = one_folder.split('/')[-1]        if letter =='A':            labels = np.append(labels,n_img*[1])        elif letter =="B":            labels = np.append(labels,n_img*[2])        elif letter =='C':            labels = np.append(labels,n_img*[3])        elif letter =="D":            labels = np.append(labels,n_img*[4])        elif letter =="E":            labels = np.append(labels,n_img*[5])        elif letter =="F":            labels = np.append(labels,n_img*[6])        elif letter =="G":            labels = np.append(labels,n_img*[7])        elif letter =="H":            labels = np.append(labels,n_img*[8])        elif letter =="I":            labels =np.append(labels,n_img*[9])        else:            labels = np.append(labels,n_img*[10])    #shuffle    temp = np.array([images,labels])    temp = temp.transpose()    np.random.shuffle(temp)    image_list = list(temp[:,0])    label_list = list(temp[:,1])    label_list = [int(float(i)) for i in label_list]    return image_list,label_list

当取得image_listlabel_list以后,读取图片数据,然后利用tf.train.Exampletf.train.Features 这两个函数来构建一个example然后将其序列化到文件中。基本上就是一个Example中包含FeaturesFeatures中包含Feature字典,Feature字典是由float_listbytes_Listint64_list等构成。

#将label转换成int64类型,为了构建tf.train.Featuredef int64_feature(value):    if not isinstance(value,list):        value = [value]    return tf.train.Feature(int64_list = tf.train.Int64List(value=value))#将image转换成bytes类型,同样也是为了构建tf.train.Featuredef bytes_feature(value):    return tf.train.Feature(bytes_list= tf.train.BytesList(value=[value]))def convert_to_tfrecord(images,labels,save_dir,name):    """    convert all images and labels to one tfrecord file    :param images:     :param labels:     :param save_dir:     :param name:     :return:     """    filename = os.path.join(save_dir,name+".tfrecords")    n_samples = len(labels)    if np.shape(images)[0] != n_samples:        raise ValueError('Image size %d does not '                         'match label size %d'%(images.shape[0],n_samples))    #wait some time    writer = tf.python_io.TFRecordWriter(filename)    print("\n Transform start....")    for i in np.arange(0,n_samples):        try:            image = io.imread(images[i])            image_raw = image.tostring()            label= int(labels[i])            example = tf.train.Example(features =tf.train.Features(feature={'label':int64_feature(label),                                                                             "image_raw":bytes_feature(image_raw)}))            writer.write(example.SerializeToString())        except IOError as e:            print("could not read :",images[i])            print("error:%s"%e)            print('Skip it')    writer.close()    print("Transform done!")

这样就完成了TFRecord的生成,但是这个过程会花费较长的时间。

4.TFRecords解码

读取一个文件还是采用上一篇博客中的queue的形式来读取,首先是生成一个文件名层的队列,然后利用tf.TFRecordReader()产生的reader来读取,然后将其读取到的内容,用tf.parse_single_example函数将labelimage_raw读取以及分离出来,为后续操作做准备

def read_and_decode(tfrecords_file,batch_size):    filename_queue = tf.train.string_input_producer([tfrecords_file])    reader = tf.TFRecordReader()    _,serialized_example = reader.read(filename_queue)    img_features = tf.parse_single_example(serialized_example,                                           features={"label":tf.FixedLenFeature([],tf.int64),                                                     "image_raw":tf.FixedLenFeature([],tf.string),})    image = tf.decode_raw(img_features['image_raw'],tf.uint8)    ################################################################    #    #put dataaugmentation here    ################################################################    image = tf.reshape(image,[28,28])    label = tf.cast(img_features['label'],tf.int32)    image_batch, label_batch = tf.train.batch([image,label],                                              batch_size = batch_size,                                              num_threads = 64,                                              capacity=2000)    return image_batch,tf.reshape(label_batch,[batch_size])

解码以后的后续过程和采用queue处理二进制文件相似。

全部代码下载地址:https://github.com/ZhichengHuang/LearnTensorflowCode/blob/master/TFRecords/TFRecord_input.py

参考资料

  1. https://www.tensorflow.org/programmers_guide/reading_data
  2. http://stackoverflow.com/questions/33849617/how-do-i-convert-a-directory-of-jpeg-images-to-tfrecords-file-in-tensorflow
  3. https://github.com/kevin28520
0 0
原创粉丝点击