Tensorflow 构建 TFrecords

来源:互联网 发布:淘宝号的名字可以改吗 编辑:程序博客网 时间:2024/06/05 05:14

构建数据集 TFrecords

这时,我们需要构建 TFrecords

当数据量很大时,一般将数据定义为 TensorFlow 内定标准格式——TFRecords,TFRecords是一种将图像数据和标签放在一起的二进制文件,虽然它不如其他格式好理解,但是它能更好的利用内存,更方便复制和移动,并且不需要单独的标签文件

TFRecords 文件中的数据都是通过 tf.train.Example Protocol Buffer 的格式存储的。

message Example {      Features features = 1;  };  message Features {      map<string, Feature> feature = 1;  };  message Feature {      oneof kind {      BytesList bytes_list = 1;      FloatList float_list = 2;      Int64List int64_list = 3;  }  };  
#PIL python图像处理标准库$ sudo apt-get install python-imaging$ sudo gedit TFrecords.py

生成 TFrecords

关于相关API定义请官网查询

#导入必要的包import os import tensorflow as tf from PIL import Image  #注意Image,后面会用到import numpy as np#当前工作目录cwd = os.getcwd()+"/dog/"classes={'jiwawa','hashiqi'}#图片最终生成的tfrecordswriter = tf.python_io.TFRecordWriter("dog_train.tfrecords")for index, name in enumerate(classes):    class_path = cwd + name + "/"    for img_name in os.listdir(class_path):        img_path = class_path + img_name        img = Image.open(img_path)        img = img.resize((60, 60))        img_raw = img.tobytes()              ##将图片转化为二进制格式        example = tf.train.Example(features=tf.train.Features(feature={            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))        }))#example对象对label和image数据进行封装        writer.write(example.SerializeToString())  #序列化为字符串writer.close()
$ cd /home/shiyanlou/tensorflow$ python myTFrecords.py

运行 myTFrecords.py 会在当前目录下生成 dog_train.tfrecords,如果没有生成,请检查是否操作有误。

有些时候我们希望检查分类是否有误,或者在之后的网络训练过程中可以监视,输出图片,来观察分类等操作的结果,那么我们就可以session回话中,将tfrecord的图片从流中读取出来,再保存。

$ cd /home/shiyanlou/tensorflow$ sudo gedit myTestTF.py
# -*- coding: UTF-8 -*-  import os import tensorflow as tf from PIL import Image  #注意Image,后面会用到import numpy as np#当前工作目录cwd = os.getcwd()+"/dog/"filename_queue = tf.train.string_input_producer(["dog_train.tfrecords"]) #读入流中reader = tf.TFRecordReader()_, serialized_example = reader.read(filename_queue)   #返回文件名和文件features = tf.parse_single_example(serialized_example,                                   features={                                       'label': tf.FixedLenFeature([], tf.int64),                                       'img_raw' : tf.FixedLenFeature([], tf.string),                                   })  #取出包含image和label的feature对象image = tf.decode_raw(features['img_raw'], tf.uint8)image = tf.reshape(image, [60, 60, 3])label = tf.cast(features['label'], tf.int32)with tf.Session() as sess: #开始一个会话    init_op = tf.initialize_all_variables()    sess.run(init_op)    coord=tf.train.Coordinator()    threads= tf.train.start_queue_runners(coord=coord)    for i in range(20):        example, l = sess.run([image,label])#在会话中取出image和label        img=Image.fromarray(example, 'RGB')#这里Image是之前提到的        img.save(cwd+str(i)+'_''Label_'+str(l)+'.png')#存下图片        print(example, l)    coord.request_stop()    coord.join(threads)
$ python myTestTF.py

我们可以看到如下结果:

接下来打开终端,我们将读取 dog_train.tfrecords 的代码写到 myCNN.py 中:

$ cd /home/shiyanlou/tensorflow$ sudo gedit myCNN.py
def read_and_decode(filename): # 读入dog_train.tfrecords    filename_queue = tf.train.string_input_producer([filename])#生成一个queue队列    reader = tf.TFRecordReader()    _, serialized_example = reader.read(filename_queue)#返回文件名和文件    features = tf.parse_single_example(serialized_example,                                       features={                                           'label': tf.FixedLenFeature([], tf.int64),                                           'img_raw' : tf.FixedLenFeature([], tf.string),                                       })#将image数据和label取出来    img = tf.decode_raw(features['img_raw'], tf.uint8)    img = tf.reshape(img, [60, 60, 3])  #reshape为60*60的3通道图片    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 #在流中抛出img张量    label = tf.cast(features['label'], tf.int32) #在流中抛出label张量    return img, label

这样实验所需的数据会在函数调用时被导入。

原创粉丝点击