【TensorFlow】数据处理(队列和多线程)

来源:互联网 发布:在手机上怎么注册淘宝 编辑:程序博客网 时间:2024/06/12 20:53

项目已上传至 GitHub —— queue-threading

Queue


TensorFlow 提供了两种列队

  • FIFOQueue:先进先出队列
  • RandomShuffleQueue:随机出队队列

修改队列状态的函数主要有

  • enqueue():入队
  • enqueue_many():将多个元素入队
  • dequeue():出队

以下代码示范了怎么使用这些函数(以下所有代码都实现自《TensorFlow:实战Google深度学习框架》)

import tensorflow as tf# 创建一个先进先出队列,指定队列最多保存两个元素,类型为整数q = tf.FIFOQueue(2, 'int32')# 初始化队列中的元素init = q.enqueue_many(([0, 10], ))# 执行出队操作x = q.dequeue()# 将元素+1后入队y = x + 1q_inc = q.enqueue([y])with tf.Session() as sess:    # 运行初始化队列的操作    init.run()    # 执行出队+1入队的操作    for _ in range(5):        v, _ = sess.run([x, q_inc])        print(v)

运行结果如下

$ python q.py0101112

Coordinator


tf.Coordinator 主要用于协同多个进程一起停止,提供了以下函数

  • should_stop()
  • request_stop()
  • join()

request_stop() 函数将 should_stop() 函数的返回值设置为 True,当 should_stop() 函数的返回值为 True 后,线程将同时终止

以下程序示范了如何使用 tf.Coordinator

import tensorflow as tfimport numpy as npimport threadingimport time# 线程中运行的程序,每隔1秒判断是否停止并打印IDdef MyLoop(coord, worker_id):    while not coord.should_stop():        # 随机停止所有线程        if np.random.rand() < 0.1:            print('Stoping from id: %d\n' % (worker_id))            coord.request_stop()        else:            print('Working on id: %d\n' % (worker_id))        time.sleep(1)# 创建Coordinatorcoord = tf.train.Coordinator()# 声明创建5个线程threads = [threading.Thread(target=MyLoop, args=(coord, i)) for i in range(5)]# 启动所有线程for t in threads:    t.start()# 等待所有线程退出coord.join()

运行结果如下

$ python coord.pyWorking on id: 4Working on id: 1Working on id: 4......Working on id: 0Working on id: 1Stoping from id: 4

QueueRunner


tf.QueueRunner 主要用于启动多个线程来操作同一个队列,启动的线程可以通过 tf.Coordinator 来统一管理

以下代码示范了如何通过 tf.QueueRunner 和 tf.Coordinator 来管理多线程队列操作

import tensorflow as tf# 声明一个FIFO队列,最多100个实数元素q = tf.FIFOQueue(100, 'float')# 定义队列的入队操作enqueue_op = q.enqueue([tf.random_normal([1])])# 使用QueueRunner来创建多个线程运行队列的入队操作qr = tf.train.QueueRunner(q, [enqueue_op] * 5)# 将定义过的QueueRunner加入计算图上指定的集合tf.train.add_queue_runner(qr)# 定义出队操作dequeue_op = q.dequeue()with tf.Session() as sess:    # 使用Coordinator来协同启动的线程    coord = tf.train.Coordinator()    # 使用QueueRunner时需要明确调用tf.train.start_queue_runners来启动所有进程    threads = tf.train.start_queue_runners(coord=coord)    # 获取队列中的取值    for _ in range(3):        print(sess.run(dequeue_op)[0])    # 停止所有线程    coord.request_stop()    coord.join(threads)

运行结果如下

$ python qr.py1.1977-1.01813-0.617845
原创粉丝点击