TensorFlow 队列与多线程

来源:互联网 发布:普联软件 怎么样 编辑:程序博客网 时间:2024/05/17 04:22

正如TensorFlow中的其他组件一样,队列就是TensorFlow图中的节点。这是一种有状态的节点,就像变量一样:其他节点可以修改它的内容。具体来说,其他节点可以把新元素插入到队列后端,也可以把队列前端的元素删除。为了感受一下队列,让我们来看一个非常简单的例子:

# coding:utf-8import tensorflow as tf#创建一个先入先出队列,队列最多可以保存两个元素,并制定类行为整型q=tf.FIFOQueue(2, "int32")#使用 enqueue_many 函数来初始化队列中的元素。和变量初始化类似,在使用队列之前需要明确的调用这些初始化过程init = q.enqueue_many(([0, 10],))#通过 dequeue 取出队列中第一个元素x = q.dequeue()y = x+1#通过 enqueue 将y加入队列q_inc = q.enqueue([y])with tf.Session() as sess:    init.run()    for _ in range(6):        v, _ = sess.run([x, q_inc])        print v #输出结果为:#0#10#1#11#2#12

TensorFlow提供了 FIFOQueue 和 RandomShuffleQueue 两种队列。 FIFOQueue 实现的是先进先出的队列,RandomShuffleQueue 会将队列中元素打乱,出队列得到的是当前队列中随机选择的一个。

在 TensorFlow 中,队列不仅仅是一种数据结构,更是“异步张量取值”的一个重要机制。比如多个线程可以同时向一个队列中写元素,或者同时读取一个队列中的元素。TensorFlow提供了两个类来帮助多线程的实现:tf.Coordinator和 tf.QueueRunner。从设计上这两个类必须被一起使用。

Coordinator类用来帮助多个线程协同工作,多个线程同步终止。 其主要方法有:
should_stop():如果线程应该停止则返回True。
request_stop(<exception>): 请求该线程停止。
join(<list of threads>):等待被指定的线程终止。


首先创建一个 Coordinator 对象,然后建立一些使用Coordinator对象的线程。这些线程通常一直循环运行,一直到should_stop()返回True时停止。 任何线程都可以决定计算什么时候应该停止。它只需要调用request_stop(),同时其他线程的should_stop()将会返回True,然后都停下来。

接下来我们通过一个简单的例子来看一下 Coordinator 对象的使用:

# coding:utf-8import tensorflow as tfimport numpy as npimport threadingimport time#线程中运行的程序,这个程序每隔1秒判断是否需要停止并打印自己的IDdef MyLoop(coord, worker_id):    while not coord.should_stop():        #随即停止线程        if np.random.rand() < 0.1 :            print "stopinig form id: %d\n" % worker_id,            coord.request_stop()        else:            print "Workinig on id: %d\n" % worker_id,        #暂停1妙        time.sleep(1)#声明一个 tf.train.Coordinator 类来协同多个线程coord = tf.train.Coordinator()#创建5个线程threads = [threading.Thread(target=MyLoop, args=(coord,i)) for i in xrange(5)]#启动所有的线程for t in threads: t.start()#等待所有线程退出coord.join(threads)


tf.QueueRunner 主要是用于启动多个线程来操作同一个队列,启动这些线程可以通过 tf.Coordinator 来统一管理

# coding:utf-8#启动5个线程来执行队列的入队操作,其中每一个线程都是将随机数写入队列。于是在每次运行出队操作时,就可以得到一个随机数。import tensorflow as tf#声明一个先进先出的队列,队列最多100个元素,类型为实数queue = tf.FIFOQueue(100, "float")#定义队列的入队操作enqueue_op = queue.enqueue([tf.random_normal([1])])#使用 tf.train.QueueRunner 来创建多个线程运行队列的入队操作#tf.train.QueueRunner 第一个参数给出了被操作的队列#[enqueue_op] × 5 表示需要启动5个线程,每个此案城中运行的是enqueue_op操作qr = tf.train.QueueRunner(queue, [enqueue_op] * 5)#将 QueueRunner 加入TensorFlow 计算图上制定的集合tf.train.add_queue_runner(qr)#定义出队操作out_tensor = queue.dequeue()with tf.Session() as sess:    #使用Coordinator 来协同启动的线程    coord = tf.train.Coordinator()    threads = tf.train.start_queue_runners(sess=sess, coord=coord)    #获取队列中的取值    for i in range(3):print sess.run(out_tensor)[0]    #停止所有线程    coord.request_stop()    coord.join(threads)






原创粉丝点击