Tensorflow读取数据2-tfrecord
来源:互联网 发布:西部数码域名管理 编辑:程序博客网 时间:2024/05/16 05:11
原文地址: http://blog.csdn.net/u010911921/article/details/70991194
上篇博客谈到了Tensorflow从文件中读取数据,当时采用的是CIFAR-10中的二进制数据,这次记录一下官网推荐的比较通用和高效的数据文件类型的读取——TFRecord文件,这是tensorflow指定的标准格式。
1.TFRecords
TFRecords本质上是一种二进制文件,他的优点是可以更好的利用内存空间,缺点是生成过程比较耗费时间,特别是数据量比较大的情况下。文件包含了一个tf.train.Example
的缓冲协议(protocol buffer)其中协议块中包含了字段Features
.当用程序获得数据以后,就可以将其填充到Example
的协议缓冲区(protocol buffer)中,然后在将协议缓冲区序列化为字符串,最后通过tf.python_io.TFRecordWriter
将字符串写入文件。
当从TFRecords文件中读取数据时,可以利用tf.TFRecordReader
和tf.parse_single_example
解码器,将Example
缓冲协议中的内容解析为Tensor
张量
2.notMNIST 数据集
在实验中采用的数据集合时notMNIST数据集,这个数据集合是由一些各种形态的字母组成的数据集合,总共由a~j
10个字母组成,下图是a
对应的一些图片:
另外需要注意的是,下载的数据集中有几张图片有损坏,所以处理的时候注意跳过。
3.生成TFRecords文件
为了生成TFRecords文件首先是从数据集中,将图片路径放置到一个image_list
,样本的标签放置到一个label_list
中。
#!/usr/bin/env python3# --*-- encoding:utf-8 --*--import tensorflow as tfimport numpy as npimport osimport matplotlib.pyplot as pltimport skimage.io as iodef get_file(file_dir): """ get full image directory and correspond labels :param file_dir: :return: """ images =[] temp =[] for root ,sub_folders,files in os.walk(file_dir): #image directories for name in files: images.append(os.path.join(root,name)) #get 10 sub-folder names for name in sub_folders: temp.append(os.path.join(root,name)) labels =[] for one_folder in temp: n_img = len(os.listdir(one_folder)) letter = one_folder.split('/')[-1] if letter =='A': labels = np.append(labels,n_img*[1]) elif letter =="B": labels = np.append(labels,n_img*[2]) elif letter =='C': labels = np.append(labels,n_img*[3]) elif letter =="D": labels = np.append(labels,n_img*[4]) elif letter =="E": labels = np.append(labels,n_img*[5]) elif letter =="F": labels = np.append(labels,n_img*[6]) elif letter =="G": labels = np.append(labels,n_img*[7]) elif letter =="H": labels = np.append(labels,n_img*[8]) elif letter =="I": labels =np.append(labels,n_img*[9]) else: labels = np.append(labels,n_img*[10]) #shuffle temp = np.array([images,labels]) temp = temp.transpose() np.random.shuffle(temp) image_list = list(temp[:,0]) label_list = list(temp[:,1]) label_list = [int(float(i)) for i in label_list] return image_list,label_list
当取得image_list
和label_list
以后,读取图片数据,然后利用tf.train.Example
和tf.train.Features
这两个函数来构建一个example
然后将其序列化到文件中。基本上就是一个Example
中包含Features
,Features
中包含Feature
字典,Feature
字典是由float_list
、bytes_List
或int64_list
等构成。
#将label转换成int64类型,为了构建tf.train.Featuredef int64_feature(value): if not isinstance(value,list): value = [value] return tf.train.Feature(int64_list = tf.train.Int64List(value=value))#将image转换成bytes类型,同样也是为了构建tf.train.Featuredef bytes_feature(value): return tf.train.Feature(bytes_list= tf.train.BytesList(value=[value]))def convert_to_tfrecord(images,labels,save_dir,name): """ convert all images and labels to one tfrecord file :param images: :param labels: :param save_dir: :param name: :return: """ filename = os.path.join(save_dir,name+".tfrecords") n_samples = len(labels) if np.shape(images)[0] != n_samples: raise ValueError('Image size %d does not ' 'match label size %d'%(images.shape[0],n_samples)) #wait some time writer = tf.python_io.TFRecordWriter(filename) print("\n Transform start....") for i in np.arange(0,n_samples): try: image = io.imread(images[i]) image_raw = image.tostring() label= int(labels[i]) example = tf.train.Example(features =tf.train.Features(feature={'label':int64_feature(label), "image_raw":bytes_feature(image_raw)})) writer.write(example.SerializeToString()) except IOError as e: print("could not read :",images[i]) print("error:%s"%e) print('Skip it') writer.close() print("Transform done!")
这样就完成了TFRecord的生成,但是这个过程会花费较长的时间。
4.TFRecords解码
读取一个文件还是采用上一篇博客中的queue的形式来读取,首先是生成一个文件名层的队列,然后利用tf.TFRecordReader()
产生的reader
来读取,然后将其读取到的内容,用tf.parse_single_example
函数将label
和image_raw
读取以及分离出来,为后续操作做准备
def read_and_decode(tfrecords_file,batch_size): filename_queue = tf.train.string_input_producer([tfrecords_file]) reader = tf.TFRecordReader() _,serialized_example = reader.read(filename_queue) img_features = tf.parse_single_example(serialized_example, features={"label":tf.FixedLenFeature([],tf.int64), "image_raw":tf.FixedLenFeature([],tf.string),}) image = tf.decode_raw(img_features['image_raw'],tf.uint8) ################################################################ # #put dataaugmentation here ################################################################ image = tf.reshape(image,[28,28]) label = tf.cast(img_features['label'],tf.int32) image_batch, label_batch = tf.train.batch([image,label], batch_size = batch_size, num_threads = 64, capacity=2000) return image_batch,tf.reshape(label_batch,[batch_size])
解码以后的后续过程和采用queue处理二进制文件相似。
全部代码下载地址:https://github.com/ZhichengHuang/LearnTensorflowCode/blob/master/TFRecords/TFRecord_input.py
参考资料
- https://www.tensorflow.org/programmers_guide/reading_data
- http://stackoverflow.com/questions/33849617/how-do-i-convert-a-directory-of-jpeg-images-to-tfrecords-file-in-tensorflow
- https://github.com/kevin28520
- Tensorflow读取数据2-tfrecord
- tensorflow读取数据-tfrecord格式
- tensorflow读取数据-tfrecord格式
- TensorFlow高效读取数据——TFRecord
- Tensorflow中使用tfrecord方式读取数据
- 第五课 Tensorflow TFRecord读取数据
- Tensorflow-tfrecord数据
- Tensorflow使用TFRecord构建自己的数据集并读取
- TensorFlow学习记录-- 7.TensorFlow高效读取数据之tfrecord详细解读
- TensorFlow 学习(二) 制作自己的TFRecord数据集,读取,显示及代码详解
- TensorFlow学习笔记(二十四)自制TFRecord数据集 读取、显示及代码详解
- TensorFlow 制作自己的TFRecord数据集
- TensorFlow 制作自己的TFRecord数据集
- tensorflow制作数据集之TFRecord
- TensorFlow TFRecord
- tensorflow中的TFRecord格式文件的写入和读取
- Tensorflow 训练自己的数据集(二)(TFRecord)
- TFRecord —— tensorflow 下的统一数据存储格式
- ASP.NET CORE基础教程(一)-启动文件Startup
- 高性能MYSQL读书笔记1
- Cocos2d-X官方Demo---1.ActionManager
- 【python】解析中英文进阶
- JS跨域解决方案
- Tensorflow读取数据2-tfrecord
- Nessus漏洞扫描工具安装
- 二叉树的层次遍历
- kaggle入门篇二【Titanic】
- 【bzoj1677】[Usaco2005 Jan]Sumsets 求和
- 五.SpringMVC+MyBatis搭建安全与性能
- 116. Populating Next Right Pointers in Each Node
- IntelliJ IDEA 注册
- Command "python setup.py egg_info" failed with error code 1 in /tmp/pip-build-o2julgbe/xgboost/