读tfrecords文件,一个一个读/按批次读

来源:互联网 发布:指纹软件锁 编辑:程序博客网 时间:2024/05/17 02:38
import tensorflow as tfimport matplotlib.pyplot as plttfrecords_file = '/home/lw/workspace/MicrovideoLSTM/tfrecordData/videoframe.tfrecords'filename_queue = tf.train.string_input_producer([tfrecords_file])  # 根据文件名生成一个队列reader = tf.TFRecordReader()                                       # TFRecordReader 用于读取 TFReacord_, serialized_example = reader.read(filename_queue)                # 返回文件名和文件features = tf.parse_single_example(serialized_example,                                   features={                                       'label': tf.FixedLenFeature([], tf.int64),                                       'image_raw' : tf.FixedLenFeature([], tf.string),                                   })                              # 取出包含image和label的feature对象..image = tf.decode_raw(features['image_raw'], tf.uint8)             # 要结合自己的数据大小来选择tf.uint8,tf.int32image = tf.reshape(image, [20,20])   # image = tf.reshape(image, [128, 128, 3]) ]要与具体的图像大小保持一致,取灰度图与彩色图label = tf.cast(features['label'], tf.int32)                       # 读取出标签数据image_batch, label_batch = tf.train.shuffle_batch([image, label],                                            batch_size=10,                                             capacity=2000,                                            min_after_dequeue=1000)# 生成批次with tf.Session() as sess:                                         # 开始一个会话    init_op = tf.global_variables_initializer()    sess.run(init_op)        coord=tf.train.Coordinator()                                   # 设置多线程协调器    threads= tf.train.start_queue_runners(sess = sess, coord=coord)# 开始 Queue Runners (队列运行器)        for i in range(2):        #example, l = sess.run([image,label])              # 在会话中取出image和label              在队列中一个一个取        example, l = sess.run([image_batch,label_batch])   # 在会话中取出image_batch和label_batch  在队列中按批次取,维度不同        print example.shape        print l.shape        print l        # plt.imshow(example)                              # 显示单张图像         plt.imshow(example[i,:,:])                         # 在批次里面显示单张图像        plt.show()    coord.request_stop()    coord.join(threads)

阅读全文
0 0
原创粉丝点击