Tensorflow读取数据1
来源:互联网 发布:淘宝类目销售比列 编辑:程序博客网 时间:2024/05/18 20:07
原文地址: http://blog.csdn.net/u010911921/article/details/70577697
这段一直在用Tensorflow来做深度学习上的相关工作,然后对Tensorflow读取数据的方式进行实现。特地总结一下。首先是读取二进制图片数据,这里采用的是CIFAR-10的二进制数据
## 1.CIFAR-10数据集CIFAR-10数据集合是包含60000张`32*32*3`的图片,其中每个类包含6000张图片,总共10类。在这60000张图片中50000张是训练集合,10000张是测试集合。其中二进制的图片保存的格式如下所示:
2.Tensorflow读取数据
从Tensorflow的官网可以看到从文件中读取数据的流程主要是一下步骤:
- The list of filenames
- (Optional) filename shuffling
- (Optional) epoch limit
- Filename queue
- A Reader for the file format
- A decoder for a record read by the reader
- (Optional) preprocessing
- Example queue
按照这样一个流程,首选应该将CIFAR-10的训练集和测试集合,生成文件名列表,然后在讲这个文件名列表传递给tf.train.string_input_producer
函数创建一个用于保存文件名称的FIFO的队列,最后用tensor flow产生的reader
从队列中读取数据。当reader
读到数据就需要用tf.decode_raw
函数对读取到的二进制数进行解码。
结束了上述操作,下面就需要采用另一个queue去batch together examples来为训练和测试提供数据。采用tf.train.shuffle_batch
将上面生成的image
和label
传入函数即可完成。
3.开始训练
当tf.train.shuffle_batch
生成batch以后就开始利用tf.train.start_queue_runners
函数启动队列,然后开始整个计算图,官网给的建议是如下形式:
init_op = tf.global_variables_initializer()with tf.Session as sess: coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess= sess,coord = coord) try: while not coord.should_stop(): #run training steps or whatever sess.run(train_op) except tf.errors.OutOfRangeError: print('Done training --epoch limit reached') finally: # when done,ask the threads to stop coord.request_stop() coord.join(threads)
4.代码实现
在神经网络的训练中由于每训练k步以后就会对网络进行一次测试,所以需要在上述步骤中,增加动态选择文件名称队列这样一个过程,可以由tf.QueueBase.from_list
函数进行实现,然后reader
从返回的文件名称队列中读取数据。
整个过程的实现如下所示:
#!/usr/bin/env python3# --*-- encoding:utf-8 --*--import tensorflow as tfimport numpy as npimport osdef read_cifar10(data_dir,is_traing,batch_size,shuffle): """ :param data_dir:数据保存路径 :param is_traing:True从训练集获取数据,False从测试集获取数据 :param batch_size: batch_size的大小 :param shuffle: bool,是否进行shuffle操作 :return: """ img_width = 32 img_height = 32 img_depth = 3 label_bytes = 1 img_bytes = img_height * img_width *img_depth with tf.name_scope("input") as scope: #训练集合的文件列表 train_filenames = [os.path.join(data_dir, 'data_batch_%d.bin'%ii) for ii in np.arange(1,6)] #测试集合的文件列表 val_filenames = [os.path.join(data_dir,'test_batch.bin')] #训练集和测试集合的文件名称队列 train_queue = tf.train.string_input_producer(train_filenames) val_queue = tf.train.string_input_producer(val_filenames) #挑选文件队列,实现training的过程中测试 queue_select = tf.cond(is_traing, lambda :tf.constant(0), lambda :tf.constant(1) ) queue = tf.QueueBase.from_list(queue_select,[train_queue,val_queue]) #从队列中读取固定长度的数据 reader = tf.FixedLengthRecordReader(label_bytes+img_bytes) key,value = reader.read(queue) recode_bytes = tf.decode_raw(value,tf.uint8) #获取label label = tf.slice(recode_bytes,[0],[label_bytes]) label = tf.cast(label,tf.int32) #获取image image_raw = tf.slice(recode_bytes,[label_bytes],[img_bytes]) image_raw = tf.reshape(image_raw,[img_depth, img_height, img_width]) image = tf.transpose(image_raw,[1,2,0]) image = tf.cast(image,tf.float32) #对每一张图片进行标准化操作,可选操作此处可以进行对图片的各种操作 image = tf.image.per_image_standardization(image) if shuffle: images, label_batch= tf.train.shuffle_batch([image,label], batch_size=batch_size, num_threads=16, capacity=512+3*batch_size, min_after_dequeue=512, allow_smaller_final_batch=True) else: images, label_batch = tf.train.batch([image, label], batch_size=batch_size, num_threads=16, capacity=512 + 3*batch_size, allow_smaller_final_batch=True) label_batch = tf.cast(label_batch,tf.int32) return images,label_batch
整个过程是采用VGG-16的网络模型进行训练的,在迭代16000次,tensorboard展示的结果如图所示:
code下载地址https://github.com/ZhichengHuang/LearnTensorflowCode
参考资料:
https://www.tensorflow.org/versions/r1.1/programmers_guide/reading_data
http://stackoverflow.com/questions/41162955/tensorflow-queues-switching-between-train-and-validation-data
- Tensorflow读取数据1
- tensorflow爬坑行:数据读取
- Tensorflow图片数据读取
- Tensorflow读取数据
- Tensorflow读取数据
- tensorflow读取文件数据
- Tensorflow数据读取方式
- Tensorflow图片数据读取
- tensorflow图片数据读取
- Tensorflow数据读取方法
- TensorFlow读取tfrecords数据
- tensorflow读取数据
- TensorFlow数据读取
- tensorflow读取数据
- tensorflow 数据读取笔记
- TensorFlow数据读取
- TensorFlow读取数据
- tensorflow 数据读取
- mybatis中<where>标签、<set>标签、<trim>标签、<sql>标签、<foreach>标签的使用
- Activiti数据库配置
- 单例模式
- 【Shell】快速追踪哪些文件包含某个关键词
- Linux ALSA声卡驱动之二:声卡的创建
- Tensorflow读取数据1
- java 求阶乘
- 欢迎使用CSDN-markdown编辑器
- STM32 boot跳转到APP的Jump_Address()分析
- cssiot_李_TCP建立流程_讲稿
- sun.misc.BASE64Encoder找不到jar包的解决方法
- 聚类︱python实现 六大 分群质量评估指标(兰德系数、互信息、轮廓系数)
- React-Native 工程添加推送功能 (iOS 篇)
- ZOJ3878(Convert QWERTY to Dvorak)