关于tensorflow 的数据读取线程管理QueueRunner

来源:互联网 发布:大战神张飞进阶数据 编辑:程序博客网 时间:2024/06/05 01:58

转自 http://blog.csdn.net/sunquan_ok/article/details/51832442

TensorFlow的Session对象是可以支持多线程的,因此多个线程可以很方便地使用同一个会话(Session)并且并行地执行操作。然而,在Python程序实现这样的并行运算却并不容易。所有线程都必须能被同步终止,异常必须能被正确捕获并报告,回话终止的时候, 队列必须能被正确地关闭。

所幸TensorFlow提供了两个类来帮助多线程的实现:tf.Coordinator和 tf.QueueRunner。从设计上这两个类必须被一起使用。Coordinator类可以用来同时停止多个工作线程并且向那个在等待所有工作线程终止的程序报告异常。QueueRunner类用来协调多个工作线程同时将多个张量推入同一个队列中


~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

以上为极客中国上tensorflow的官方文档翻译中,对于线程和队列的介绍。但是可能说的不太清楚。


Coordinator还比较好理解,可以理解为信号量之类的东西。QueueRunner比较难理解,通篇看介绍文档,都没有找到QueueRunner这个代码,后来终于发现一段文字:

创建线程并使用QueueRunner对象来预取

简单来说:使用上面列出的许多tf.train函数添加QueueRunner到你的数据流图中。在你运行任何训练步骤之前,需要调用tf.train.start_queue_runners函数,否则数据流图将一直挂起。tf.train.start_queue_runners 这个函数将会启动输入管道的线程,填充样本到队列中,以便出队操作可以从队列中拿到样本。


~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

也就是说,QueueRunner是一个不存在于代码中的东西,而是后台运作的一个概念。由tf.train函数添加。


首先,我们先创建数据流图,这个数据流图由一些流水线的阶段组成,阶段间用队列连接在一起。第一阶段将生成文件名,我们读取这些文件名并且把他们排到文件名队列中。第二阶段从文件中读取数据(使用Reader),产生样本,而且把样本放在一个样本队列中。根据你的设置,实际上也可以拷贝第二阶段的样本,使得他们相互独立,这样就可以从多个文件中并行读取。在第二阶段的最后是一个排队操作,就是入队到队列中去,在下一阶段出队。因为我们是要开始运行这些入队操作的线程,所以我们的训练循环会使得样本队列中的样本不断地出队。




tf.train中要创建这些队列和执行入队操作,就要添加tf.train.QueueRunner到一个使用tf.train.add_queue_runner函数的数据流图中。每个QueueRunner负责一个阶段,处理那些需要在线程中运行的入队操作的列表。一旦数据流图构造成功,tf.train.start_queue_runners函数就会要求数据流图中每个QueueRunner去开始它的线程运行入队操作。


~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

代码中根本没有QueueRunner的啊


# Create the graph, etc.init_op = tf.initialize_all_variables()# Create a session for running operations in the Graph.sess = tf.Session()# Initialize the variables (like the epoch counter).sess.run(init_op)# Start input enqueue threads.coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess, coord=coord)try:    while not coord.should_stop():        # Run training steps or whatever        sess.run(train_op)except tf.errors.OutOfRangeError:    print 'Done training -- epoch limit reached'finally:    # When done, ask the threads to stop.    coord.request_stop()# Wait for threads to finish.coord.join(threads)sess.close()


0 0
原创粉丝点击