TensorFlow学习之CNN-Cifar10代码阅读与详解(一):cifar10数据批量读取

来源:互联网 发布:淘宝自己做衣服卖 编辑:程序博客网 时间:2024/05/18 19:43

本文将详细分析TensorFlow官方文档中关于在Cifar10数据上进行CNN分类的官方代码。官方文档地址如下:
https://www.tensorflow.org/versions/r0.12/tutorials/deep_cnn/index.html#convolutional-neural-networks
其代码组织形式如下:
这里写图片描述

由于该代码文件、参数较多,本文将分步骤,详叙该代码每部分的具体运行情况。

1. Cifar10数据集读取

数据集读取的功能主要由文件cifar10_input.py完成,该部分主要由函数read_cifar10、_generate_image_and_label_batch、以及inputs函数组成。各函数的主要功能如下:

  • ead_cifar10:从输入队列中读取一条记录信息,即一副图像的数据以及其label。

  • generate_image_and_label_batch:根据输入队列,采用多线程方法,从输入队列中批量读取一批数据,每批的中包含的图像数目为batch_size。

  • inputs:创建cifar10文件名队列,并调用read_cifar10函数以及_generate_image_and_label_batch函数,读取一批图像数据。

数据读取部分主要涉及到tensorflow队列和多线程管理的概念,说明如下:

1.1 tensorflow队列—————-
在TensorFlow中,队列和变量类似,都是算图上有状态的节点。所有队列管理器被默认加入图的tf.GraphKeys.QUEUE_RUNNERS集合中。
操作队列的类和方法如下:

  • FIFOQueue(): 创建一个先入先出(FIFO)的队列
  • RandomShuffleQueue():创建一个随机出队的队列
  • enqueue_many(): 初始化队列中的元素
  • enqueue(): 入队
  • dequeue(): 出队

1.2 TensorFlow多线程操作队列
为提高数据读取的速度,一般采用多线程对队列进行操作。TensorFlow提供的多线程协同操作的类—tf.Coordinator,其方法主要有:

  • should_stop(): 确定当前线程是否退出

  • request_stop():通知其他线程退出

  • join(): 等待所有线程终止

TensorFlow提供了队列tf.QueueRunner类处理多个线程操作同一队列,启动的线程由上面提到的tf.Coordinator类统一管理,常用的操作有:

  • QueueRunner():启动线程,第一个参数为线程需要操作的队列,第二个参数为对队列的操作,如enqueue_op,此时的enqueue_op = queue.enqueue()

  • add_queue_runner():在图中的一个集合中加‘QueueRunner’,如果没有指定的合集的话会被添加到tf.GraphKeys.QUEUE_RUNNERS合集

  • start_queue_runners():启动所有被添加到图中的线程

1.3 代码详解
这里写图片描述

这里写图片描述

这里写图片描述
这里写图片描述

这里写图片描述

程序的运行结果如下:
这里写图片描述qi

即在当前目录下,会保存10个读取到的图像文件