tensorflow框架输入数据格式

来源:互联网 发布:经传证券炒股软件 编辑:程序博客网 时间:2024/05/19 03:45

关于TensorFlow的输入数据格式已经研究了几天了,官方的格式是TFRecords格式,但是依然很困惑,仍然不知道用于自己的数据怎么搞。

下面总结一下使用自己的图片的数据格式代码。

1. 写入TFRecord文件

首先,将自己的图片统一格式,写入到 TFRecord 文件中。其中images保存图片的路径,label为相应图片的标签。代码如下:

#coding=utf-8import tensorflow as tfimport cv2filename = "train.tfrecords" #要存入的文件writer = tf.python_io.TFRecordWriter(filename)for i in range(len(images)):    image = tf.gfile.FastGFile(images[i], 'rb').read() #image type is string    label = label[i]    image_shape = cv2.imread(images[i]).shape    width = image_shape[1]    height = image_shape[0]    example = tf.train.Example(        features = tf.train.Features(feature = {            'image': tf.train.Feature(bytes_list = tf.train.BytesList(value = [image])),            'label': tf.train.Feature(int64_list = tf.train.Int64List(value = [label])),            'height': tf.train.Feature(int64_list = tf.train.Int64List(value = [height])),            'width': tf.train.Feature(int64_list = tf.train.Int64List(value = [width])),        })    )    #将Example写入TFRecord文件    writer.write(example.SerializeToString())


以上是将所有图片数据存储到一个 TFRecord文件中,当数据量比较大时,也可以将图片数据写入到多个TFRecord文件中。

2. 读取TFRecord文件

首先读取TFRecord文件,对读取到TFRecords文件进行解码,根据保存的serialize example和feature字典返回feature所对应的值。此时获得的值都是string,

需要进一步解码为所需的数据类型。下面是读取文件的代码。

import tensorflow as tfimport matplotlib.pyplot as pltfilename = "train.tfrecords"filename_queue = tf.train.string_input_producer([filename], shuffle = False)reader = tf.TFRecordReader()# read in serialized example data_, serialized_example = reader.read(filename_queue)features = tf.parse_single_example(    serialized_example,    features = {        'image' : tf.FixedLenFeature([], tf.string),        'label' : tf.FixedLenFeature([], tf.int64),        'height' : tf.FixedLenFeature([], tf.int64),        'width' : tf.FixedLenFeature([], tf.int64)    })image =  tf.image.decode_jpeg(features['image'])image = tf.image.convert_image_dtype(image, dtype=tf.float32)label = tf.cast(features['label'], tf.int64)height = tf.cast(features['height'], tf.int32)width = tf.cast(features['width'], tf.int32)image = tf.reshape(image, [height, width, 3])image = tf.image.resize_images(image, [512, 512])image_batch, label_batch = tf.train.shuffle_batch(    [image, label], batch_size = 10, capacity = 10, min_after_dequeue = 10)with tf.Session() as sess:      tf.initialize_all_variables().run()    coord = tf.train.Coordinator()    threads = tf.train.start_queue_runners(sess = sess, coord=coord)
     for i in range(TRAINING_ROUDNS):        ...        coord.request_stop()    coord.join(threads)    sess.close()




原创粉丝点击