1 TFRecord格式介绍
对于大量的图像数据,TensorFlow提供了一种统一的格式来存储数据——TFRecord。TFRecord文件以二进制形式存储数据,适合以串行的方式读取大批量数据。虽然它的内部格式比较复杂,但是它可以很好地利用内存,方便复制和移动,也更符合TensorFlow执行引擎的处理方式。
TFRecord文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的。
// Record type in which the data of a TFRecord file is stored.
message Example{
  Features features = 1;
};

// Dictionary from attribute name (a string) to that attribute's value.
message Features{
  map<string,Feature> feature = 1;
};

// One attribute value: exactly one of a bytes list, a float list or an
// int64 list.
message Feature{
  oneof kind{
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};
tf.train.Example中包含了一个从属性名称到取值的字典。其中属性名称为一个字符串,属性的取值可以为字符串(BytesList)、实数列表(FloatList)或者整数列表(Int64List)。
2 图像数据集
本文采用的图像数据集来自Stanford Cars Dataset。
将数据集的图片全部放入data文件夹下,label文件(我已改名为label.txt)放在与data文件夹同根目录下。
3 将图像转为TFRecord
from __future__ import absolute_import,division,print_functionimport numpy as npimport tensorflow as tfimport timefrom scipy.misc import imread,imresizefrom os import walkfrom os.path import joinDATA_DIR = 'data/'IMG_HEIGHT = 227IMG_WIDTH = 227IMG_CHANNELS = 3NUM_TRAIN = 7000NUM_VALIDARION = 1144def read_images(path): filenames = next(walk(path))[2] num_files = len(filenames) images = np.zeros((num_files,IMG_HEIGHT,IMG_WIDTH,IMG_CHANNELS),dtype=np.uint8) labels = np.zeros((num_files, ), dtype=np.uint8) f = open('label.txt') lines = f.readlines() for i,filename in enumerate(filenames): img = imread(join(path,filename)) img = imresize(img,(IMG_HEIGHT,IMG_WIDTH)) images[i] = img labels[i] = int(lines[i]) f.close() return images,labelsdef _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))def convert(images,labels,name): num = images.shape[0] filename = name+'.tfrecords' print('Writting',filename) writer = tf.python_io.TFRecordWriter(filename) for i in range(num): img_raw = images[i].tostring() example = tf.train.Example(features=tf.train.Features(feature={ 'label': _int64_feature(int(labels[i])), 'image_raw': _bytes_feature(img_raw)})) writer.write(example.SerializeToString()) writer.close() print('Writting End')def main(argv): print('reading images begin') start_time = time.time() train_images,train_labels = read_images(DATA_DIR) duration = time.time() - start_time print("reading images end , cost %d sec" %duration) validation_images = train_images[:NUM_VALIDARION,:,:,:] validation_labels = train_labels[:NUM_VALIDARION] train_images = train_images[NUM_VALIDARION:,:,:,:] train_labels = train_labels[NUM_VALIDARION:] print('convert to tfrecords begin') start_time = time.time() convert(train_images,train_labels,'train') convert(validation_images,validation_labels,'validation') duration = time.time() - start_time print('convert to 
tfrecords end , cost %d sec' %duration)if __name__ == '__main__': tf.app.run()
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
本文将数据集中的7000张用于训练,1144张用于验证。
4 读取TFRecord文件
from __future__ import absolute_import,division,print_functionimport numpy as npfrom os.path import joinimport tensorflow as tfimport convert_to_tfrecordsTRAIN_FILE = 'train.tfrecords'VALIDATION_FILE = 'validation.tfrecords'NUM_CLASSES = 196IMG_HEIGHT = convert_to_tfrecords.IMG_HEIGHTIMG_WIDTH = convert_to_tfrecords.IMG_WIDTHIMG_CHANNELS = convert_to_tfrecords.IMG_CHANNELSIMG_PIXELS = IMG_HEIGHT * IMG_WIDTH * IMG_CHANNELSNUM_TRAIN = convert_to_tfrecords.NUM_TRAINNUM_VALIDARION = convert_to_tfrecords.NUM_VALIDARIONdef read_and_decode(filename_queue): reader = tf.TFRecordReader() _,serialized_example = reader.read(filename_queue) features = tf.parse_single_example(serialized_example,features={ 'label':tf.FixedLenFeature([],tf.int64), 'image_raw':tf.FixedLenFeature([],tf.string) }) image = tf.decode_raw(features['image_raw'],tf.uint8) label = tf.cast(features['label'],tf.int32) image.set_shape([IMG_PIXELS]) image = tf.reshape(image,[IMG_HEIGHT,IMG_WIDTH,IMG_CHANNELS]) image = tf.cast(image, tf.float32) * (1. / 255) - 0.5 return image,labeldef inputs(data_set,batch_size,num_epochs): if not num_epochs: num_epochs = None if data_set == 'train': file = TRAIN_FILE else: file = VALIDATION_FILE with tf.name_scope('input') as scope: filename_queue = tf.train.string_input_producer([file], num_epochs=num_epochs) image,label = read_and_decode(filename_queue) images,labels = tf.train.shuffle_batch([image, label], batch_size=batch_size, num_threads=64, capacity=1000 + 3 * batch_size, min_after_dequeue=1000 ) return images,labels
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
读取一个batch的图像和label只需要调用inputs()函数就行了。
5 结果
结果生成了一个1GB的train.tfrecords和168MB的validation.tfrecords