Converting Images to a TFRecord File and Reading the TFRecord File Back

Source: Internet   Editor: 程序博客网   Date: 2024/06/06 20:11

1 Introduction to the TFRecord Format

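For large image datasets, TensorFlow provides a unified storage format: TFRecord. A TFRecord file stores its data in binary form and is suited to reading large volumes of data sequentially. Its internal format is complex, but it uses memory efficiently, is easy to copy and move, and fits the way TensorFlow's execution engine consumes data.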
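To make the "complex internal format" a little more concrete: as far as I understand TensorFlow's record writer, each record on disk is an 8-byte little-endian length, a 4-byte masked CRC32C of that length, the payload bytes, and a 4-byte masked CRC32C of the payload. The sketch below writes and reads that framing with the standard library only. The helper names (crc32c, masked_crc, write_record, read_records) are made up for illustration and are not TensorFlow APIs.

```python
import io
import struct

_MASK_DELTA = 0xA282EAD8  # constant used when masking the checksum


def _build_table():
    # Lookup table for CRC32C (Castagnoli), reflected polynomial 0x82F63B78.
    table = []
    for i in range(256):
        crc = i
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return table


_TABLE = _build_table()


def crc32c(data):
    crc = 0xFFFFFFFF
    for byte in data:
        crc = _TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data):
    # Rotate right by 15 bits, then add the mask constant (mod 2**32).
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + _MASK_DELTA) & 0xFFFFFFFF


def write_record(stream, payload):
    header = struct.pack('<Q', len(payload))
    stream.write(header)
    stream.write(struct.pack('<I', masked_crc(header)))
    stream.write(payload)
    stream.write(struct.pack('<I', masked_crc(payload)))


def read_records(stream):
    while True:
        header = stream.read(8)
        if len(header) < 8:
            return  # end of stream
        (length,) = struct.unpack('<Q', header)
        (length_crc,) = struct.unpack('<I', stream.read(4))
        if length_crc != masked_crc(header):
            raise ValueError('corrupted length field')
        payload = stream.read(length)
        (payload_crc,) = struct.unpack('<I', stream.read(4))
        if payload_crc != masked_crc(payload):
            raise ValueError('corrupted payload')
        yield payload


# Round trip through an in-memory buffer instead of a real file.
buffer = io.BytesIO()
for item in (b'first record', b'second'):
    write_record(buffer, item)
buffer.seek(0)
records = list(read_records(buffer))
print(records)
```

Each record therefore costs 16 bytes of framing on top of its payload, which is negligible next to an image of this size.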
All data in a TFRecord file is stored in the tf.train.Example Protocol Buffer format:

message Example {
    Features features = 1;
};
message Features {
    map<string, Feature> feature = 1;
};
message Feature {
    oneof kind {
        BytesList bytes_list = 1;
        FloatList float_list = 2;
        Int64List int64_list = 3;
    }
};

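A tf.train.Example holds a dictionary that maps feature names to values. Each feature name is a string, and each value is a list of byte strings (BytesList), a list of floats (FloatList), or a list of integers (Int64List).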
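As a small illustration of that dictionary, the sketch below builds one Example with an integer feature and a bytes feature, serializes it, and parses it back. It assumes TensorFlow is installed. The feature names 'label' and 'image_raw' match the ones used later in this post, while the three-byte payload and the label value 7 are stand-ins, not real image data.

```python
import tensorflow as tf

# Stand-in for the raw bytes of an image (a real one would be H*W*C bytes).
pixels = bytes(bytearray([0, 127, 255]))

example = tf.train.Example(features=tf.train.Features(feature={
    # Integer value wrapped in an Int64List
    'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[7])),
    # Byte string wrapped in a BytesList
    'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[pixels])),
}))

# Serialize to the byte string that would be written as one TFRecord record.
serialized = example.SerializeToString()

# Parse it back and look the features up by name.
restored = tf.train.Example.FromString(serialized)
label = restored.features.feature['label'].int64_list.value[0]
image_raw = restored.features.feature['image_raw'].bytes_list.value[0]
print(label, len(image_raw))
```

Every value is a list, even when it holds a single item, which is why the helper functions in the conversion script wrap each value in [value].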

2 The Image Dataset

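The images used in this post come from the Stanford Cars dataset.
Put all of the dataset's images in the data folder, and place the label file (which I renamed to label.txt) in the same directory that contains the data folder.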
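The conversion script in the next section matches the i-th image file with the i-th line of label.txt, so the order in which files are listed matters. The sketch below shows one way to make that pairing explicit. It assumes label.txt holds one integer per line, in the same order as the sorted file names (the post does not state the order, so treat this as an assumption). The function name pair_images_with_labels and the toy file names are hypothetical.

```python
import os
import tempfile


def pair_images_with_labels(data_dir, label_path):
    # Directory listings come back in arbitrary order, so sort explicitly.
    filenames = sorted(os.listdir(data_dir))
    with open(label_path) as f:
        labels = [int(line) for line in f if line.strip()]
    if len(filenames) != len(labels):
        raise ValueError('%d images but %d labels' % (len(filenames), len(labels)))
    return list(zip(filenames, labels))


# Build a toy layout: data/ with three empty files, label.txt next to it.
root = tempfile.mkdtemp()
data_dir = os.path.join(root, 'data')
os.mkdir(data_dir)
for name in ('00002.jpg', '00001.jpg', '00003.jpg'):
    open(os.path.join(data_dir, name), 'wb').close()
label_path = os.path.join(root, 'label.txt')
with open(label_path, 'w') as f:
    f.write('14\n3\n91\n')

pairs = pair_images_with_labels(data_dir, label_path)
print(pairs)
```

Checking that the two counts agree up front catches a mismatched label file before any images are resized.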

3 Converting the Images to TFRecord

# -*- coding: utf-8 -*-
# Written against the TensorFlow 1.x API (tf.python_io, tf.app.run).
# scipy.misc.imread/imresize need an older SciPy (before 1.2) with Pillow installed.
from __future__ import absolute_import, division, print_function

import time
from os import walk
from os.path import join

import numpy as np
import tensorflow as tf
from scipy.misc import imread, imresize

# Where the images are stored
DATA_DIR = 'data/'

# Image properties
IMG_HEIGHT = 227
IMG_WIDTH = 227
IMG_CHANNELS = 3
NUM_TRAIN = 7000
# Spelling kept as-is: the reading script imports this name
NUM_VALIDARION = 1144


# Read every image and its label
def read_images(path):
    # Sort the names: walk() gives no ordering guarantee, and line i of
    # label.txt is assumed to belong to the i-th file in sorted order
    filenames = sorted(next(walk(path))[2])
    num_files = len(filenames)
    images = np.zeros((num_files, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), dtype=np.uint8)
    labels = np.zeros((num_files,), dtype=np.uint8)
    with open('label.txt') as f:
        lines = f.readlines()
    # Walk over all images and labels, resizing each image to [227, 227, 3]
    for i, filename in enumerate(filenames):
        # mode='RGB' guarantees three channels even for grayscale files
        img = imread(join(path, filename), mode='RGB')
        img = imresize(img, (IMG_HEIGHT, IMG_WIDTH))
        images[i] = img
        labels[i] = int(lines[i])
    return images, labels


# Build an integer feature
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


# Build a byte-string feature
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def convert(images, labels, name):
    # Number of images to write into the TFRecord file
    num = images.shape[0]
    # Name of the output TFRecord file
    filename = name + '.tfrecords'
    print('Writing', filename)
    # Create a writer for the TFRecord file
    writer = tf.python_io.TFRecordWriter(filename)
    for i in range(num):
        # Serialize the image array to raw bytes
        img_raw = images[i].tobytes()
        # Pack one sample into an Example protocol buffer with all needed fields
        example = tf.train.Example(features=tf.train.Features(feature={
            'label': _int64_feature(int(labels[i])),
            'image_raw': _bytes_feature(img_raw)}))
        # Write the serialized Example to the TFRecord file
        writer.write(example.SerializeToString())
    writer.close()
    print('Writing End')


def main(argv):
    print('reading images begin')
    start_time = time.time()
    train_images, train_labels = read_images(DATA_DIR)
    duration = time.time() - start_time
    print('reading images end, cost %d sec' % duration)

    # Split off the validation set
    validation_images = train_images[:NUM_VALIDARION, :, :, :]
    validation_labels = train_labels[:NUM_VALIDARION]
    train_images = train_images[NUM_VALIDARION:, :, :, :]
    train_labels = train_labels[NUM_VALIDARION:]

    # Convert to TFRecord files
    print('convert to tfrecords begin')
    start_time = time.time()
    convert(train_images, train_labels, 'train')
    convert(validation_images, validation_labels, 'validation')
    duration = time.time() - start_time
    print('convert to tfrecords end, cost %d sec' % duration)


if __name__ == '__main__':
    tf.app.run()

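Of the dataset's images, 7,000 are used for training and 1,144 for validation. The script takes the first 1,144 images as the validation set and the rest as the training set. The reading script in section 4 imports this script as convert_to_tfrecords, so save it as convert_to_tfrecords.py.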
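The split is a plain slice by position, so which images land in the validation set depends entirely on file order. The sketch below reproduces that slice on a list of indices and, as an alternative that the original script does not use, shows a shuffle with a fixed seed before slicing, which may be preferable if the files happen to be grouped by class. The variable names are local to this example.

```python
import random

NUM_TRAIN = 7000
NUM_VALIDATION = 1144

indices = list(range(NUM_TRAIN + NUM_VALIDATION))

# What the conversion script does: the first block is validation, the rest is training.
validation_idx = indices[:NUM_VALIDATION]
train_idx = indices[NUM_VALIDATION:]

# Alternative (not in the original script): shuffle once with a fixed seed
# so the split is reproducible but no longer depends on file order.
shuffled = indices[:]
random.Random(0).shuffle(shuffled)
shuffled_validation_idx = shuffled[:NUM_VALIDATION]
shuffled_train_idx = shuffled[NUM_VALIDATION:]

print(len(train_idx), len(validation_idx))
```

With the shuffled variant, the same index order has to be applied to both the images and the labels so that the pairs stay aligned.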

4 Reading the TFRecord File

# -*- coding: utf-8 -*-
# Written against the TensorFlow 1.x queue-based input API.
from __future__ import absolute_import, division, print_function

import tensorflow as tf

import convert_to_tfrecords

# TFRecord files
TRAIN_FILE = 'train.tfrecords'
VALIDATION_FILE = 'validation.tfrecords'

# Image properties
NUM_CLASSES = 196
IMG_HEIGHT = convert_to_tfrecords.IMG_HEIGHT
IMG_WIDTH = convert_to_tfrecords.IMG_WIDTH
IMG_CHANNELS = convert_to_tfrecords.IMG_CHANNELS
IMG_PIXELS = IMG_HEIGHT * IMG_WIDTH * IMG_CHANNELS
NUM_TRAIN = convert_to_tfrecords.NUM_TRAIN
NUM_VALIDARION = convert_to_tfrecords.NUM_VALIDARION


def read_and_decode(filename_queue):
    # Create a reader for the samples in the TFRecord file
    reader = tf.TFRecordReader()
    # Read one serialized sample from the file
    _, serialized_example = reader.read(filename_queue)
    # Parse that sample
    features = tf.parse_single_example(serialized_example, features={
        'label': tf.FixedLenFeature([], tf.int64),
        'image_raw': tf.FixedLenFeature([], tf.string)
    })
    # Decode the byte string back into the image's pixel array
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    label = tf.cast(features['label'], tf.int32)
    image.set_shape([IMG_PIXELS])
    image = tf.reshape(image, [IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS])
    # Scale pixel values from [0, 255] to [-0.5, 0.5]
    image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
    return image, label


# Return one batch of batch_size images and labels
def inputs(data_set, batch_size, num_epochs):
    if not num_epochs:
        num_epochs = None
    if data_set == 'train':
        filename = TRAIN_FILE
    else:
        filename = VALIDATION_FILE
    with tf.name_scope('input'):
        filename_queue = tf.train.string_input_producer([filename], num_epochs=num_epochs)
    image, label = read_and_decode(filename_queue)
    # Draw a shuffled batch of batch_size images and labels
    images, labels = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        num_threads=64,
        capacity=1000 + 3 * batch_size,
        min_after_dequeue=1000
    )
    return images, labels

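To read one batch of images and labels, just call the inputs() function. Because inputs() is built on TensorFlow 1.x input queues, the session must also initialize local variables (the num_epochs counter is stored in one) and start the queue runners with tf.train.start_queue_runners() before the batch tensors are evaluated.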

5 Results

The conversion produced a train.tfrecords file of about 1 GB and a validation.tfrecords file of about 168 MB.
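These sizes are what one would expect from the raw pixel data alone, because each image is stored uncompressed as 227 x 227 x 3 bytes. The short calculation below is a sanity check, not a measurement. It ignores the small per-record overhead for the label and the record framing, and its variable names are local to this example.

```python
IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS = 227, 227, 3
NUM_TRAIN, NUM_VALIDATION = 7000, 1144

# uint8 pixels: one byte per channel value
bytes_per_image = IMG_HEIGHT * IMG_WIDTH * IMG_CHANNELS

train_bytes = NUM_TRAIN * bytes_per_image
validation_bytes = NUM_VALIDATION * bytes_per_image

print('per image: %d bytes' % bytes_per_image)
print('train: %.2f GiB' % (train_bytes / 2.0 ** 30))
print('validation: %.1f MiB' % (validation_bytes / 2.0 ** 20))
```

The price of storing raw pixels is disk space: the files are much larger than the original JPEGs would be, in exchange for skipping image decoding at read time.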
