TensorFlow中的队列

来源:互联网 发布:淘宝网店开店的策划书 编辑:程序博客网 时间:2024/06/06 09:21

在上一篇文章中,虽然最终运行结果正确, 但是在运行结果最后报了一个错误:

_1_input_producer: Skipping cancelled enqueue attempt with queue not closed

这主要是主线程已经关闭,但是读取数据入队线程还在执行入队。这篇文章转自《理解TensorFlow的Queue》一文,文章对TF队列讲的很详细,受益匪浅,很有必要转载过来。这里对原文有部分用词做了修改,比如“卡住”改为“阻塞”。

这篇文章来说说TensorFlow里与Queue有关的概念和用法。其实概念只有三个:

  • Queue是TF队列和缓存机制的实现
  • QueueRunner是TF中对操作Queue的线程的封装
  • Coordinator是TF中用来协调线程运行的工具

虽然它们经常同时出现,但这三样东西在TensorFlow里面是可以单独使用的,不妨先分开来看待。

1. Queue

根据实现的方式不同,分成具体的几种类型,例如:

  • tf.FIFOQueue 按入列顺序出列的队列
  • tf.RandomShuffleQueue 随机顺序出列的队列
  • tf.PaddingFIFOQueue 以固定长度批量出列的队列
  • tf.PriorityQueue 带优先级出列的队列
  • … …

这些类型的Queue除了自身的性质不太一样外,创建、使用的方法基本是相同的。

创建函数的参数:

tf.FIFOQueue(capacity, dtypes, shapes=None, names=None ...)

Queue主要包含入列(enqueue)和出列(dequeue)两个操作。enqueue操作返回计算图中的一个Operation节点,dequeue操作返回一个Tensor值。Tensor在创建时同样只是一个定义(或称为“声明”),需要放在Session中运行才能获得真正的数值。下面是一个单独使用Queue的例子:

import tensorflow as tftf.InteractiveSession()q = tf.FIFOQueue(2, "float")init = q.enqueue_many(([0,0],))x = q.dequeue()y = x+1q_inc = q.enqueue([y])init.run()q_inc.run()q_inc.run()q_inc.run()x.eval()  # 返回1x.eval()  # 返回2x.eval()  # 阻塞

注意,如果一次性入列超过Queue Size的数据,enqueue操作会阻塞,直到有数据(被其他线程)从队列取出。对一个已经取空的队列使用dequeue操作也会阻塞,直到有新的数据(从其他线程)写入。

2. QueueRunner

Tensorflow的计算主要在使用CPU/GPU和内存,而数据读取涉及磁盘操作,速度远低于前者操作。因此通常会使用多个线程读取数据,然后使用一个线程消费数据。QueueRunner就是来管理这些读写队列的线程的。

QueueRunner需要与Queue一起使用(这名字已经注定了它和Queue脱不开干系),但并不一定必须使用Coordinator。看下面这个例子:

import tensorflow as tf  import sys  q = tf.FIFOQueue(10, "float")  counter = tf.Variable(0.0)  #计数器# 给计数器加一increment_op = tf.assign_add(counter, 1.0)# 将计数器加入队列enqueue_op = q.enqueue(counter)# 创建QueueRunner# 用多个线程向队列添加数据# 这里实际创建了4个线程,两个增加计数,两个执行入队qr = tf.train.QueueRunner(q, enqueue_ops=[increment_op, enqueue_op] * 2)# 主线程sess = tf.InteractiveSession()tf.global_variables_initializer().run()# 启动入队线程qr.create_threads(sess, start=True)for i in range(20):    print (sess.run(q.dequeue()))

增加计数的进程会不停的后台运行,执行入队的进程会先执行10次(因为队列长度只有10),然后主线程开始消费数据,当一部分数据消费被后,入队的进程又会开始执行。最终主线程消费完20个数据后停止,但其他线程继续运行,程序不会结束。

3. Coordinator

Coordinator是个用来保存线程组运行状态的协调器对象,它和TensorFlow的Queue没有必然关系,是可以单独和Python线程使用的。例如:

import tensorflow as tfimport threading, time# 子线程函数def loop(coord, id):    t = 0    while not coord.should_stop():        print(id)        time.sleep(1)        t += 1        # 只有1号线程调用request_stop方法        if (t >= 2 and id == 1):            coord.request_stop()# 主线程coord = tf.train.Coordinator()# 使用Python API创建10个线程threads = [threading.Thread(target=loop, args=(coord, i)) for i in range(10)]# 启动所有线程,并等待线程结束for t in threads: t.start()coord.join(threads)

将这个程序运行起来,会发现所有的子线程执行完两个周期后都会停止,主线程会等待所有子线程都停止后结束,从而使整个程序结束。由此可见,只要有任何一个线程调用了Coordinator的request_stop方法,所有的线程都可以通过should_stop方法感知并停止当前线程。

将QueueRunner和Coordinator一起使用,实际上就是封装了这个判断操作,从而使任何一个现成出现异常时,能够正常结束整个程序,同时主线程也可以直接调用request_stop方法来停止所有子线程的执行。

4. 在一起

在TensorFlow中用Queue的经典模式有两种,都是配合了QueueRunner和Coordinator一起使用的。

第一种,显式的创建QueueRunner,然后调用它的create_threads方法启动线程。例如下面这段代码:

import tensorflow as tf# 1000个4维输入向量,每个数取值为1-10之间的随机数data = 10 * np.random.randn(1000, 4) + 1# 1000个随机的目标值,值为0或1target = np.random.randint(0, 2, size=1000)# 创建Queue,队列中每一项包含一个输入数据和相应的目标值queue = tf.FIFOQueue(capacity=50, dtypes=[tf.float32, tf.int32], shapes=[[4], []])# 批量入列数据(这是一个Operation)enqueue_op = queue.enqueue_many([data, target])# 出列数据(这是一个Tensor定义)data_sample, label_sample = queue.dequeue()# 创建包含4个线程的QueueRunnerqr = tf.train.QueueRunner(queue, [enqueue_op] * 4)with tf.Session() as sess:    # 创建Coordinator    coord = tf.train.Coordinator()    # 启动QueueRunner管理的线程    enqueue_threads = qr.create_threads(sess, coord=coord, start=True)    # 主线程,消费100个数据    for step in range(100):        if coord.should_stop():            break        data_batch, label_batch = sess.run([data_sample, label_sample])    # 主线程计算完成,停止所有采集数据的进程    coord.request_stop()    coord.join(enqueue_threads)

第二种,使用全局的start_queue_runners方法启动线程。

import tensorflow as tf# 同时打开多个文件,显示创建Queue,同时隐含了QueueRunner的创建filename_queue = tf.train.string_input_producer(["data1.csv","data2.csv"])reader = tf.TextLineReader(skip_header_lines=1)# Tensorflow的Reader对象可以直接接受一个Queue作为输入key, value = reader.read(filename_queue)with tf.Session() as sess:    coord = tf.train.Coordinator()    # 启动计算图中所有的队列线程    threads = tf.train.start_queue_runners(coord=coord)    # 主线程,消费100个数据    for _ in range(100):        features, labels = sess.run([data_batch, label_batch])    # 主线程计算完成,停止所有采集数据的进程    coord.request_stop()    coord.join(threads)

在这个例子中,tf.train.string_input_produecer会将一个隐含的QueueRunner添加到全局图中(类似的操作还有tf.train.shuffle_batch等)。

由于没有显式地返回QueueRunner来用create_threads启动线程,这里使用了tf.train.start_queue_runners方法直接启动tf.GraphKeys.QUEUE_RUNNERS集合中的所有队列线程。

这两种方式在效果上是等效的。

原创粉丝点击