13、TensorFlow 数据读取
来源:互联网 发布:国家电网总部 知乎 编辑:程序博客网 时间:2024/05/16 15:15
- 一使用 placeholder feed_dict 传入数据
- 二使用 TFRecords 统一输入数据的格式
- 0TFRecords 数据格式的优缺点
- 1将数据转换为 tfrecords 文件
- a获得图片的保存路径和标签
- b指定编码函数
- c将图片数据和标签或其它需要需要保存的数据都转成 TFRecods 格式
- 2读取并解码 tfrecords 文件并生成 batch
- a指定想要读取的 tfrecords 文件列表
- b创建一个输入文件名队列来维护输入文件列表
- c读取并解码
- 3将 batch 数据喂入计算图并开始训练验证测试等
- 三参考资料
一、使用 placeholder + feed_dict 传入数据
placeholder 是 Tensorflow 中的占位符,必须要指定将传给该占位符的值的数据类型 dtype ,一般为
tf.float32
形式;然后通过 sess.run() 的可选参数 feed_dict 为给占位符喂入实际的数据.eg: sess.run(***, feed_dict={input: **})
input = tf.placeholder(tf.float32, shape=[2], name="my_input")
- dtype:指定了将传给该占位符的值的数据类型。该参数是必须指定的,因为需要确保不出现类型不匹配的错误
- shape:指定了所要传入的 Tensor 对象的形状,shape 参数的默认值为
None
,表示可接收任意形状的Tensor对象
- name:与任何 op 一样,也可在 tf.placeholder 中指定一个 name 标识符
input1 = tf.placeholder(tf.float32)input2 = tf.placeholder(tf.float32)output = tf.add(input1, input2)with tf.Session() as sess: print sess.run([output], feed_dict={input1:[7.], input2:[2.]})>>> [array([ 9.], dtype=float32)]
Note:在 shape 的一个维度上使用 None 可以方便的使用不同 batch 的大小。在训练时,把数据分成比较小的 batch,但在测试时,可以一次使用全部的数据。但要注意,当数据集比较大时,将大量数据放入一个 batch 可能导致内存溢出
二、使用 TFRecords 统一输入数据的格式
0、TFRecords 数据格式的优缺点
- 优点:
- 可以统一不同的原始数据格式
- 更加有效的管理不同的属性、更好的利用内存、更方便的复制和移动
- 缺点:
- 转换过后 tfrecords 文件会占用较大内存
1、将数据转换为 .tfrecords 文件
a、获得图片的保存路径和标签
# 获得图片的保存路径和标签,以便后面的读取和转换def get_file(file_dir): '''Get full image directory and corresponding labels Args: file_dir: file directory Returns: images: image directories, list, string labels: label, list, int '''
b、指定编码函数
tf.train.Example
的数据结构中包含了一个从属性到取值的字典。
属性名称(feature name)
为一个字符串属性的取值(feature value)
可以为字符串列表(BytesList)、实数列表(FloatList)或者整数列表(Int64List)通过以下函数编码为Example proto
形式的返回值
# Wrapper for inserting int64 features into Example protodef _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))# Wrapper for inserting bytes features into Example protodef _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
c、将图片数据和标签(或其它需要需要保存的数据)都转成 TFRecods 格式
- 指定转换数据格式后的保存路径和文件名称
- 创建一个实例对象 writer,用于后面序列化数据的写入
- 将所有数据按照
tf.train.Example Protocol Buffer
的格式存储
- 取得图片的样本总数
- 循环读取图片和标签的内容:将图片内容转换为字符串型,当有多个标签时,应将多标签内容也转换为字符串型
- 使用编码函数将一个样例的所有数据(图片和标签内容等)转换为
Example Protocol Buffer
- 调用实例对象 writer 的 write 方法将
序列化后的 Example Protocol Buffer
写入 TFRecords 文件- 当所有样本数据都转换完毕时,调用实例对象 writer 的 close 方法结束写入过程
import tensorflow as tfimport numpy as npimport osimport skimage.io as io# 将图片数据和标签(或者其它需要需要保存的数据)都转成 TFRecods 格式的数据def convert_to_tfrecord(images, labels, save_dir, name): '''convert all images and labels to one tfrecord file. Args: images: list of image directories, string type labels: list of labels, int type save_dir: the directory to save tfrecord file, e.g.: '/home/folder1/' name: the name of tfrecord file, string type, e.g.: 'train' Return: no return ''' # 指定数据转换格式后的保存路径和名称 filename = os.path.join(save_dir, name + '.tfrecords') # 创建一个实例对象 writer,用于后面序列化数据的写入 writer = tf.python_io.TFRecordWriter(filename) # 取得图片的样本总数 n_samples = len(labels) print('\nTransform start......') # 将所有数据(包括标签等)按照 tf.train.Example Protocol Buffer 的格式存储 for i in np.arange(n_samples): try: image = io.imread(images[i]) # read a image, returned image type must be array! image_raw = image.tostring() # 将图片矩阵转化为字符串,tobytes同理 label = int(labels[i]) # 当单个label为字符串时,需要将其转换为int型 # 创建tf.train.Example 协议内存块,把标签、图片数据作为特定字段存入(数据类型转换) example = tf.train.Example(features=tf.train.Features(feature={ 'label': _int64_feature(label), 'image_raw': _bytes_feature(image_raw)})) # 调用实例对象 writer 的 write 方法将序列化后的 example 协议内存块写入 TFRecord 文件 writer.write(example.SerializeToString()) # 跳过不能读取的图片 except IOError as e: print('Could not read:', images[i]) print('error: %s' % e) print('Skip it!\n') # 调用实例对象 writer 的 close 方法结束写入过程 writer.close() print('Transform done!')
2、读取并解码 .tfrecords 文件并生成 batch
A typical pipeline for reading records from files has the following stages:
- The list of filenames
- Filename queue
- Optional filename shuffling
- Optional epoch limit
- A Reader for the file format
- A decoder for a record read by the reader
- Optional preprocessing
- Example queue
a、指定想要读取的 .tfrecords 文件列表
# 直接指定文件列表filenames = ['/path/to/train_dataset1.tfrecords', '/path/to/train_dataset2.tfrecords']# 通过 tf.train.match_filenames_once 函数获取文件列表filenames = tf.train.match_filenames_once(os.path.join(FLAGS.data_dir, 'train_*.tfrecords'))# 通过 python 中的 glob 模块获取文件列表filenames = glob.glob(os.path.join(FLAGS.data_dir, 'train_*.tfrecords'))
b、创建一个输入文件名队列来维护输入文件列表
- 通过
tf.train.string_input_producer(filenames, shuffle=True, num_epochs=None)
函数来产生输入文件名队列 - 可参考 十图详解tensorflow数据读取机制 进行理解,如下图所示,当系统检测到了“结束”,就会自动抛出一个异常(OutOfRange)外部捕捉到这个异常后就可以结束程序了,不过个人理解这里A、B、C 应该为
.tfrecords
格式的文件,即类似上面filenames
中的内容
tf.train.string_input_producer( string_tensor, num_epochs=None, shuffle=True, seed=None, capacity=32, shared_name=None, name=None, cancel_op=None)# 参数string_tensor: A 1-D string tensor with the strings to produce, 如上面的filenamesnum_epochs: An integer (optional). If specified, string_input_producer produces each string from string_tensor num_epochs times before generating an OutOfRange error. If not specified, string_input_producer can cycle through the strings in string_tensor an unlimited number of times.shuffle: Boolean. If true, the strings are randomly shuffled within each epoch.capacity: An integer. Sets the queue capacity.# 返回值A queue with the output strings. A QueueRunner for the Queue is added to the current Graph's QUEUE_RUNNER collection.
c、读取并解码
- 创建一个实例对象 reader,用于读取
.tfrecords
中的样例 - 调用实例对象 reader 的 read 方法,读取
文件名队列
中的一个样例,得到文件名和序列化的 Example Protocol Buffer
- 按照字段格式,使用
tf.parse_single_example()
解码器对上述序列化的 Example Protocol Buffer
的一个样例进行解码,返回一个dict
(mapping feature keys to Tensor and SparseTensor values) - 通过
tf.decode_raw()
函数将字符串解析成图像对应的像素数组、tf.cast()
函数转换标签的数据类型 - 图像预处理
- 构造批处理器
tf.train.shuffle_batch
,来产生一个批次的数据,用于神经网络的输入
def read_and_decode(filenames, batch_size, num_epochs=None): '''read and decode tfrecord file, generate (image, label) batches Args: filenames: the directory of tfrecord filenames, list batch_size: number of images in each batch num_epochs: None, cycle through the strings in string_tensor an unlimited number of times Returns: image: 4D tensor - [batch_size, width, height, channel] label: 1D tensor - [batch_size] ''' # Creates a FIFO queue for holding the filenames until the reader needs them filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs, shuffle=True) # 创建一个实例对象 reader, 用于读取 TFRecord 中的样例 reader = tf.TFRecordReader() # 调用实例对象 reader 的 read 方法,读取文件名队列中的一个样例,得到文件名和序列化的协议内存块 _, serialized_example = reader.read(filename_queue) # 按照字段格式,解析读入的一个样例(序列化的协议内存块) img_features = tf.parse_single_example( serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'image_raw': tf.FixedLenFeature([], tf.string), }) # 将字符串解析成图像对应的像素数组 Tensor("DecodeRaw:0", shape=(?,), dtype=uint8) # 注意:转成字符串之前是什么类型的数据,那么这里的参数就要填成对应的类型,否则会报错 image = tf.decode_raw(img_features['image_raw'], tf.uint8) # Tensor("Cast:0", shape=(), dtype=int32) label = tf.cast(img_features['label'], tf.int32) ################***** Preprocessing *****#################### # 图像预处理(resize, reshape, crop, flip, distortion, per_image_standardization ......) image.set_shape([FLAGS.height, FLAGS.width, FLAGS.depth]) # 将图片内容转换成多维数组形式 image = tf.image.resize_images(image, [48, 160]) # 统一图片的尺寸 ... ... ... ############***** 构造批处理器,来产生一个批次的数据 *****############## # num_threads:可以指定多个线程同时执行入队操作(数据读取和预处理),通过队列实现多线程处理机制 # capacity: 队列中最多可以存储的样例个数 # min_after_dequeue:限制了出队时队列中元素的最少个数,从而保证随机打乱顺序的作用 image_batch, label_batch = tf.train.shuffle_batch([image, label], batch_size=batch_size, num_threads=16, capacity=min_queue_examples + 3 * batch_size, min_after_dequeue = min_queue_examples) return image_batch, tf.reshape(label_batch, [batch_size])
3、将 batch 数据喂入计算图并开始训练、验证、测试等
filenames = tf.train.match_filenames_once(os.path.join(FLAGS.data_dir, 'train_*.tfrecords'))image_batch, label_batch = read_and_decode(filenames, batch_size=BATCH_SIZE)# tf.train.string_input_producer() 定义了一个局部变量 num_epochs,所以使用前要对其初始化init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())with tf.Session() as sess: sess.run(init) # 声明一个 tf.train.Coordinator() 对象来协同多个线程的工作 coord = tf.train.Coordinator() # 使用 tf.train.start_queue_runners() 之后,才会开始填充队列 threads = tf.train.start_queue_runners(sess=sess, coord=coord) try: # 运行 FLAGS.iteration 个 batch for itr in range(FLAGS.iteration): # just plot one batch size image, label = sess.run([image_batch, label_batch]) plot_images(image, label) except tf.errors.OutOfRangeError: print('Done training -- epoch limit reached') finally: coord.request_stop() # 通知其它线程退出,同时 corrd.should_stop()被设置成 True # 等待所有的线程退出 coord.join(threads)
三、参考资料
1、https://www.tensorflow.org/versions/r1.2/programmers_guide/reading_data
2、tensorflow/examples/how_tos/reading_data
3、十图详解tensorflow数据读取机制(附代码)
阅读全文
0 0
- 13、TensorFlow 数据读取
- tensorflow爬坑行:数据读取
- Tensorflow图片数据读取
- Tensorflow读取数据
- Tensorflow读取数据
- tensorflow读取文件数据
- Tensorflow读取数据1
- Tensorflow数据读取方式
- Tensorflow图片数据读取
- tensorflow图片数据读取
- Tensorflow数据读取方法
- TensorFlow读取tfrecords数据
- tensorflow读取数据
- TensorFlow数据读取
- tensorflow读取数据
- tensorflow 数据读取笔记
- TensorFlow数据读取
- TensorFlow读取数据
- centos7 firewall操作指南
- 自定义控件的演示
- JAVA | 26
- 【转载】code review几处小问题集锦
- linux常用操作记录
- 13、TensorFlow 数据读取
- Ionic3项目webapp 怎么做
- iOS 解决UITableView最后一个cell不显示分割线问题
- 抽象模板-计算程序的执行时间
- html网页渲染的基本过程
- Spring
- 新手安卓开发小知识
- 文章标题
- css3 加载瀑布流 column-coun属性