TensorFlow读取tfrecords数据

来源:互联网 发布:mysql join 默认值 编辑:程序博客网 时间:2024/06/05 00:09

因为要用到TensorFlow,自然少不了数据的读取,这里我自己写了一个读取tfrecords数据的函数。

"""
Created on Wed Jun 28 13:56:35 2017

@author:liao
"""
import os
import pprint as pp

import numpy as np
import tensorflow as tf
from PIL import Image

# ---------------------------------------------------------------------------
# Disabled test script: reads one batch straight from the TFRecords file.
# ---------------------------------------------------------------------------
#cwd = os.getcwd()
#file_path=[cwd+os.sep+'emotion-tfrecord'+os.sep+'emotion_image_data.tfrecords']
#fileNameQue = tf.train.string_input_producer(file_path)  # build a filename queue from the list of file names
#reader = tf.TFRecordReader()  # reader dedicated to the TFRecords format
#key,value = reader.read(fileNameQue)
#features = tf.parse_single_example(value,features={ 'label': tf.FixedLenFeature([], tf.int64),
#                                           'img' : tf.FixedLenFeature([], tf.string),})
#
#img = tf.decode_raw(features["img"], tf.uint8)
#image_raw=tf.reshape(img,[48,48])  # reshape like this if a convolutional network follows
#labels = tf.cast(features["label"], tf.int32)
#
#
## Use an example queue so that samples are batched and delivered automatically.
#num_preprocess_threads = 16
#image, label = tf.train.shuffle_batch([image_raw,labels], batch_size = batch_size, num_threads=num_preprocess_threads, min_after_dequeue = 1, capacity = 10)
#
#
#init = tf.global_variables_initializer()  # note: this initialization is required
#with tf.Session() as sess:
#
#    sess.run(init)
#    # thread coordinator that controls the batching/reading threads
#    coord = tf.train.Coordinator()
#    # start the reading queues
#    threads = tf.train.start_queue_runners(coord=coord)
#    # fetch images and labels from the queue
#    images,labels=sess.run([image,label])
#    print(images,labels)
#
#
#    coord.request_stop()
#    coord.join(threads)


def read_tfrecords(file_path, batch_size, num_threads=16,
                   min_after_dequeue=1, capacity=1024, image_size=(48, 48)):
    """Build graph ops that yield shuffled batches from a TFRecords file.

    Each record is parsed as an example with two features: ``'img'`` (a raw
    byte string decoded as ``uint8``) and ``'label'`` (an ``int64`` scalar).
    The image is reshaped to ``[height, width, 1]``, standardized with
    ``tf.image.per_image_standardization`` and batched together with its
    label through ``tf.train.shuffle_batch``.

    The function only adds ops to the graph. The returned tensors are fed by
    input queues, so queue runners must be started (see
    ``tf.train.start_queue_runners``) before they are evaluated.

    Args:
        file_path: Path of a single TFRecords file.
        batch_size: Number of examples per batch.
        num_threads: Number of threads enqueuing examples into the shuffle
            queue. Defaults to 16.
        min_after_dequeue: Minimum number of elements kept in the shuffle
            queue after a dequeue. The default of 1 guarantees almost no
            mixing; pass a larger value for better shuffling.
        capacity: Maximum number of elements in the shuffle queue.
            Defaults to 1024.
        image_size: ``(height, width)`` of the stored single-channel image.
            Defaults to ``(48, 48)``.

    Returns:
        A tuple ``(image, label)``: the image batch produced by
        ``shuffle_batch`` and the label batch cast to ``int32`` and reshaped
        to ``[batch_size]``.
    """
    height, width = image_size

    # Queue of file names consumed by the TFRecords reader.
    filename_queue = tf.train.string_input_producer([file_path])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img': tf.FixedLenFeature([], tf.string),
        })

    # Raw bytes -> [depth, height, width] -> [height, width, depth], the
    # layout expected when a convolutional network follows.
    pixels = tf.decode_raw(features["img"], tf.uint8)
    image_raw = tf.reshape(pixels, [1, height, width])
    image_raw = tf.transpose(image_raw, (1, 2, 0))

    # Normalization of each image.
    image_raw = tf.image.per_image_standardization(image_raw)
    labels = tf.cast(features["label"], tf.int32)

    # Example queue: batches and delivers the samples automatically.
    image, label = tf.train.shuffle_batch(
        [image_raw, labels],
        batch_size=batch_size,
        num_threads=num_threads,
        min_after_dequeue=min_after_dequeue,
        capacity=capacity)

    return image, tf.reshape(label, [batch_size])
屏蔽掉的部分是测试程序。有关TensorFlow数据读取的过程,可以参考点击打开链接。