用Tensorflow处理自己的数据:制作自己的TFRecords数据集

来源:互联网 发布:淘宝查号网131458. 编辑:程序博客网 时间:2024/05/16 09:59

转载请注明作者和出处: http://blog.csdn.net/wiinter_fdd/article/details/72835939
运行平台: Windows
Python版本: Python3.x
IDE: Spyder

前言

   最近一直在研究深度学习,主要是针对卷积神经网络(CNN),接触过的数据集也有了几个,最经典的就是MNIST, CIFAR10/100, NOTMNIST, CATS_VS_DOGS 这几种,由于这几种是在深度学习入门中最被广泛应用的,所以很多深度学习框架 Tensorflow、keraspytorch都有针对这些数据集专用的数据导入的函数封装,但是一般情况下我们的数据集并不是这种很规范的形式,那么如何把自己的数据集转换成这些框架能够使用的数据形式至关重要,接下来博主将会对现有的较流行的深度学习框架封装自己的数据进行讲解,首先是针对最流行的Tensorflow。

   查阅tensorflow的官方API,在GET STARTED下面的Programmer’s Guide中有一个Reading Data的章节介绍,大体内容就是tensorflow读取数据的方式:
这里写图片描述
可以看到,tensorflow官网给出了三种读取数据的方法:
对于数据量较小而言,可能一般选择直接将数据加载进内存,然后再分batch输入网络进行训练(tip:使用这种方法时,结合yield 使用更为简洁,大家自己尝试一下吧,我就不赘述了)。但是,如果数据量较大,这样的方法就不适用了,因为太耗内存,所以这时最好使用tensorflow提供的队列queue,也就是第二种方法 从文件读取数据。对于一些特定的读取,比如csv文件格式,官网有相关的描述,在这儿我介绍一种比较通用,高效的读取方法(官网介绍的少),即使用tensorflow内定标准格式——TFRecord.

   那下面就让我们了解一下什么是TFRecord:

1. What is TFRecord?

TFRecords其实是一种二进制文件,虽然它不如其他格式好理解,但是它能更好的利用内存,更方便复制和移动,并且不需要单独的标签文件(等会儿就知道为什么了)… …总而言之,这样的文件格式好处多多,所以让我们用起来吧。这里注意:TFRecord会根据你输入的文件的类,自动给每一类打上同样的标签。

TFRecord文件中的数据是通过tf.train.Example Protocol Buffer的格式存储的,下面是tf.train.Example的定义:

message Example { Features features = 1;};message Features{ map<string,Feature> featrue = 1;};message Feature{    oneof kind{        BytesList bytes_list = 1;        FloatList float_list = 2;        Int64List int64_list = 3;    }};

从上述代码可以看到,ft.train.Example 的数据结构相对简洁。tf.train.Example中包含了一个从属性名称到取值的字典,其中属性名称为一个字符串,属性的取值可以为字符串(BytesList ),实数列表(FloatList )或整数列表(Int64List )。例如我们可以将解码前的图片作为字符串,图像对应的类别标号作为整数列表。

2. How to convert our own data to TFRecord?

终于我们关心的话题来了,怎么转换?这里我们使用Kaggle上面有名猫狗大战的数据集可以通过Dogs vs Cats来下载,为了方便演示,我们利用这个数据集创建了一个新的数据集,取猫狗图片中各100张分别放在data文件夹下面的cats和dogs子文件中,入下图所示。
这里写图片描述

数据准备好以后,我们就要开始读取数据,生成TFRecord了,下面直接上代码,对于代码内容随后会有相应的说明:

# -*- coding: utf-8 -*-import os import tensorflow as tffrom PIL import Imageimport matplotlib.pyplot as pltimport numpy as npcwd = "D://Anaconda3//spyder//Tensorflow_ReadData//data//"classes = {'cats', 'dogs'} #预先自己定义的类别writer = tf.python_io.TFRecordWriter('train.tfrecords') #输出成tfrecord文件def _int64_feature(value):    return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))def _bytes_feature(value):    return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))for index, name in enumerate(classes):    class_path = cwd + name + '//'    for img_name in os.listdir(class_path):        img_path = class_path + img_name    #每个图片的地址        img = Image.open(img_path)        img = img.resize((208, 208))        img_raw = img.tobytes()  #将图片转化为二进制格式        example = tf.train.Example(features = tf.train.Features(feature = {                                                                           "label": _int64_feature(index),                                                                           "img_raw": _bytes_feature(img_raw),                                                                                                                                                     }))        writer.write(example.SerializeToString())  #序列化为字符串writer.close()

以上代码就是将数据读去进来,生成tfrecord文件,看过Tensorflow官方API的同学们应该都可以看懂。

3. How to read data from TFRecords?

上面已经把自己的数据保存成tensorflow可以使用的tfrecord的形式了,那么tensorflow到底如何使用呢?下面继续看代码:

def read_and_decode(filename, batch_size): # read train.tfrecords    filename_queue = tf.train.string_input_producer([filename])# create a queue    reader = tf.TFRecordReader()    _, serialized_example = reader.read(filename_queue)#return file_name and file    features = tf.parse_single_example(serialized_example,                                       features={                                           'label': tf.FixedLenFeature([], tf.int64),                                           'img_raw' : tf.FixedLenFeature([], tf.string),                                       })#return image and label    img = tf.decode_raw(features['img_raw'], tf.uint8)    img = tf.reshape(img, [208, 208, 3])  #reshape image to 208*208*3    label = tf.cast(features['label'], tf.int32) #throw label tensor    img_batch, label_batch = tf.train.shuffle_batch([img, label],                                                    batch_size= batch——size,                                                    num_threads=64,                                                    capacity=2000,                                                    min_after_dequeue=1500,                                                    )    return img_batch, tf.reshape(label_batch,[batch_size])

   以上是我们定义的从tfrecord文件中读取数据的函数,在这里我们使用的tensorflow的队列读取方式。在读取到队列中后,数据输出之前还要作解码的操作从,可以使用tf.TFRecordReader的tf.parse_single_example解析器。这个操作可以将Example协议内存块(protocol buffer)解析为张量。
   可以看到,这个函数除了tfrecord文件的名这一个参数外,还有batch_size这个参数,利用tf.train.shuffle_batch()这个函数对读取到的数据进行了batch处理,这样更有利于后续的训练。
注意:当数据量加大时,也可以将数据写入多个TFRecord文件。

  我们的数据是读进来,那么到底是不是我们想要的呢,下面就是我们的测试程序。

4. How to show TFRecords’ images?

tfrecords_file = 'D://Anaconda3//spyder//Tensorflow_ReadData//train.tfrecords'BATCH_SIZE = 4image_batch, label_batch = read_and_decode(tfrecords_file,BATCH_SIZE)with tf.Session()  as sess:    i = 0    coord = tf.train.Coordinator()    threads = tf.train.start_queue_runners(coord=coord)    try:        while not coord.should_stop() and i<1:            # just plot one batch size            image, label = sess.run([image_batch, label_batch])            for j in np.arange(4):                print('label: %d' % label[j])                plt.imshow(image[j,:,:,:])                plt.show()            i+=1    except tf.errors.OutOfRangeError:        print('done!')    finally:        coord.request_stop()    coord.join(threads)

这里我们也是用的tensorflow官网推荐的队列管理形式,batch_size这里可以大家任意设定,显示几幅图片都可以,这里博主设置的是4。

这样就可以把任意格式的数据转换成tensorflow推荐的TFRecord的格式的了,是不是随你有很大帮助呢。下面是完整的代码:

# -*- coding: utf-8 -*-import os import tensorflow as tffrom PIL import Imageimport matplotlib.pyplot as pltimport numpy as np#%%cwd = "D://Anaconda3//spyder//Tensorflow_ReadData//data//"classes = {'cats', 'dogs'}writer = tf.python_io.TFRecordWriter('train.tfrecords')def _int64_feature(value):    return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))def _bytes_feature(value):    return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))for index, name in enumerate(classes):    class_path = cwd + name + '//'    for img_name in os.listdir(class_path):        img_path = class_path + img_name    #每个图片的地址        img = Image.open(img_path)        img = img.resize((208, 208))        img_raw = img.tobytes()  #将图片转化为二进制格式        example = tf.train.Example(features = tf.train.Features(feature = {                                                                           "label": _int64_feature(index),                                                                           "img_raw": _bytes_feature(img_raw),                                                                                                                                                     }))        writer.write(example.SerializeToString())  #序列化为字符串writer.close()#%%def read_and_decode(filename, batch_size): # read train.tfrecords    filename_queue = tf.train.string_input_producer([filename])# create a queue    reader = tf.TFRecordReader()    _, serialized_example = reader.read(filename_queue)#return file_name and file    features = tf.parse_single_example(serialized_example,                                       features={                                           'label': tf.FixedLenFeature([], tf.int64),                                           'img_raw' : tf.FixedLenFeature([], tf.string),                                       })#return image and label    img = tf.decode_raw(features['img_raw'], tf.uint8)    img = tf.reshape(img, [208, 208, 3])  #reshape image to 512*80*3#    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 #throw img tensor    label = tf.cast(features['label'], tf.int32) #throw label tensor    img_batch, label_batch = tf.train.shuffle_batch([img, label],                                                    batch_size= batch_size,                                                    num_threads=64,                                                    capacity=2000,                                                    min_after_dequeue=1500,                                                    )    return img_batch, tf.reshape(label_batch,[batch_size])#%%tfrecords_file = 'D://Anaconda3//spyder//Tensorflow_ReadData//train.tfrecords'BATCH_SIZE = 4image_batch, label_batch = read_and_decode(tfrecords_file, BATCH_SIZE)with tf.Session()  as sess:    i = 0    coord = tf.train.Coordinator()    threads = tf.train.start_queue_runners(coord=coord)    try:        while not coord.should_stop() and i<1:            # just plot one batch size            image, label = sess.run([image_batch, label_batch])            for j in np.arange(BATCH_SIZE):                print('label: %d' % label[j])                plt.imshow(image[j,:,:,:])                plt.show()            i+=1    except tf.errors.OutOfRangeError:        print('done!')    finally:        coord.request_stop()    coord.join(threads)

运行结果如下:
这里写图片描述

只需要Ctrl+C和V点击运行,就可以得到上面的结果了。

接下来还会讲解到另外两个深度学习框架Keras,pytorch如何将自己的数据转化为框架可以使用的格式,敬请期待吧!!!!

原创粉丝点击