Tensorflow: 队列操作

来源:互联网 发布:win 启动mysql命令 编辑:程序博客网 时间:2024/06/05 00:29
import tensorflow as tf

Tensorflow有两种队列 Queue 类, 两者都继承与父类 tf.QueueBase :

  • 先进先出队列 tf.FIFOQueue
  • 随机队列 tf.RandomShuffleQueue

队列与Op一样也是图中的一个节点, 类中的主要方法: 插入元素 queue.enqueue() 和取出元素 queue.dequeue() 返回一个Op, 也为图中的一个节点.

Queue

tf.FIFOQueue

  • init()
'''参数:    capacity: (must)整数, 队列的能容纳的最大元素的数量;    dtypes: (must)dtype list, 列表长度与队列中单个元素包含的Tensor数量相同, 对应每个Tensor的数据类型;    shapes: TensorShape list, 与dtypes类似, 长度与单个元素包含的Tensor数量相同, 对应每个Tensor的形状;    names: string list, 与dtypes和shapes类似, 长度与单个元素包含的Tensor数量相同, 对应每个Tensor的名称;    shared_name: 不为空, 则队列在给定的多个Session之间共享;    name: 此队列的名称;'''

tf.RandomShuffleQueue

  • init()
'''性质: 出列时随机选择元素.参数:    capacity: (must)队列长度;    min_after_dequeue: (must)队列中至少保留的元素的数量;    dtypes: (must)dtype list, 列表长度与队列中单个元素包含的Tensor数量相同, 对应每个Tensor的数据类型;    shapes: TensorShape list, 与dtypes类似, 长度与单个元素包含的Tensor数量相同, 对应每个Tensor的形状;    names: string list, 与dtypes和shapes类似, 长度与单个元素包含的Tensor数量相同, 对应每个Tensor的名称;    seed: 随机种子;    shared_name: 不为空, 则队列在给定的多个Session之间共享;    name: 此队列的名称;'''

队列共同方法

  • enqueue()
'''作用: 入列单个元素.参数:    vals: (must)入列元素; 根据队列的设置, 可能为单个Tensor, 或list/tuple of Tensor, 或根据names构成的[name: Tensor]键值对字典;    name: 这个入列Op的名称.输出: Op, 入列Op.''' 
  • enqueue_many()
'''作用: 入列多个元素.参数:    vals: (must)Tensor的传入形式与enqueue()方法相同, 不同的因为是插入多个元素, 所以此处的Tensor比队列规定的Tensor多一维;    对Tensor的第一位进行分割, 得到的多个Tensor是批量插入的元素;    因此此处的Tensor的第一维的大小表示的是批量插入的元素的数量;    name: 这个入列Op的名称.输出: Op, 批量入列Op.'''queue = tf.FIFOQueue(10, dtypes=tf.float32)en_op_1 = queue.enqueue([1.0])# en_op_2 = queue.enqueue_many([1.0, 1.0, 2.0, 4.0, 5.0])  # 这样写会报错en_op_2 = queue.enqueue_many([[1.0, 1.0, 2.0, 4.0, 5.0]])with tf.Session() as sess:    sess.run(en_op_1)    sess.run(en_op_2)    for i in range(queue.size().eval()):        print(sess.run(queue.dequeue()))'''结果:1.01.01.02.04.05.0'''
  • dequeue()
'''作用: 出列单个元素.参数:    name: 这个出列Op的名称.输出: Tensor, 或tuple of Tensor, 取出的元素.'''
  • dequeue_many()
'''作用: 出列多个元素参数:    n: (must)出列的元素的数量    name: 这个出列Op的名称.输出: tuple of Tensor, 出列的若干元素集合.'''

QueueRunner 队列管理器

QueueRunner会创建若干新的线程, 每个线程负责一个入列操作, 因此一般用以在不干扰执行的情况下, 由这些线程读取庞大的数据.
所有队列管理器被默认加入图的tf.GraphKeys.QUEUE_RUNNERS集合中.

tf.train.QueueRunner

  • init()
'''作用: 创建与入列Op数量相同的线程, 维护一个队列.参数:    queue: (must)队列Queue, 被操作的队列;    enqueue_ops: (must)list of Ops, 入列Ops列表, 每个Op在不同的线程中运行;    close_op: 关闭队列的Op, 未执行的入列Ops保留;    cancel_op: 关闭队列的Op, 并且清空未执行的入列Ops;    queue_closed_exception_types: 指定的错误tuple, 表明当其中的错误出现时, 队列已关闭;    queue_runner_def: QueueRunnerDef协议缓冲区, 从这个缓冲区中重建队列, 与其他变量互斥;    import_scope: string, 使用queue_runner_def时才起作用, 添加到指定的Name Scope中.'''

方法:

  • create_threads()
'''作用: 在给定的Session中为入列Ops创建若干线程, 并决定是否马上启用.参数:    sess: (must)给定的Session;    coord: Coordinator类, 辅助回报Errors和判断停止所有线程的条件;    daemon: bool, Trueze则将这些线程设置为daemon线程;    start: bool, True则创建后马上启动这些线程, False则需要再调用这些线程的start()方法来启动线程.输出: 创建的线程list'''queue = tf.FIFOQueue(10, dtypes=[tf.uint8])counter = tf.Variable(0, dtype=tf.uint8)increment_op = tf.assign_add(counter, tf.constant(1, dtype=tf.uint8))# 建立Op之间的依赖关系with tf.control_dependencies([increment_op]):    enqueue_op = queue.enqueue(counter)qrunner = tf.train.QueueRunner(queue, enqueue_ops=[increment_op, enqueue_op]*1)sess = tf.Session()sess.run(tf.global_variables_initializer())# 启动QueueRunner中的线程enqueue_threads = qrunner.create_threads(sess, start=True)for i in range(5):    print(sess.run(queue.dequeue()))

Coordinator 线程协调

Coordinator类用来帮助多个线程协同工作, 多个线程同步终止, 并且向那个在等待所有工作线程终止的程序报告异常.

tf.train.Coordinator

  • init()
'''作用: 协调多个线程协同工作, 多个线程同步终止, 并且向那个在等待所有工作线程终止的程序报告异常.参数:     clean_stop_exception_types: 由这个Exception list中的Exception发出的request_stop(<exception>)请求, 会被当作request_stop(None)这样的请求.'''

方法:

  • should_stop()
'''作用: 判断所有线程是否应该停止, 如果线程应该停止则返回True.输出: bool, 如果线程应该停止则返回True.'''
  • request_stop()
'''作用: 请求所有线程停止, 当request_stop()被调用之后, 再调用should_stop()就会返回True参数:    ex: Exception, 指明停止线程的异常, 这个异常不要自行定义, 应该使用        try:            ...        except Exception as ex:            coord.request_stop(ex)        这样的形式进行传递.'''
  • join()
'''作用: 等待被指定的线程终止参数:    threads: list of threading.Threads, 需要等待终止的线程;    stop_grace_period_secs: 在调用request_stop()后, 给所有线程指定的秒数, 在这个时间之前停止;    ignore_live_threads: bool, 当为False是, 在经过stop_grace_period_secs时间段后, 如果仍有存活的线程, 则报错.'''

Coordinator常有用法:

queue = tf.FIFOQueue(10, dtypes=[tf.uint8])counter = tf.Variable(0, dtype=tf.uint8)increment_op = tf.assign_add(counter, tf.constant(1, dtype=tf.uint8))# 建立Op之间的依赖关系with tf.control_dependencies([increment_op]):    enqueue_op = queue.enqueue(counter)qrunner = tf.train.QueueRunner(queue, enqueue_ops=[increment_op, enqueue_op]*1)sess = tf.Session()coord = tf.train.Coordinator()sess.run(tf.global_variables_initializer())# 启动QueueRunner中的线程enqueue_threads = qrunner.create_threads(sess, coord=coord, start=True)try:    for i in range(5):        print(sess.run(queue.dequeue()))except Exception as ex:    print("Something wrong.")    coord.request_stop(ex)finally:    coord.request_stop()    coord.join(enqueue_threads)

读取数据的函数:

  • tf.train.string_input_producer()
'''作用: 生成一个FIFO队列, 将若干输入数据文件名传入队列, 生成一个带有QueueRunner管理的FIFOQueue输出文件名的队列, 以供后面读取文件数据的方法使用, 从而实现对文件的管理.参数:    string_tensor: (must)字符串队列或字符串1维Tensor, 使用这些字符串生成队列;    num_epochs: 整数, 如果指定了该参数, 则对string_tensor中的每个string循环num_epochs次数后, 如果再调用, 才报OutOfRange错误;    如果没有指定该参数, 即为None, 则可以无限调用这些strings;    注意: 如果指定了num_epochs, 生成一个local计数器epochs, 函数会将这个变量添加到tf.GraphKeys.LOCAL_VARIABLES中, 因此在run训练Op之前, 需要使用tf.local_variables_initializer()对这个变量进行初始化;    shuffle: bool, 如果为True, 则在每次epoch中, 将所有的字符串随机打乱后再被使用;    seed: 整数, 随机种子, 在shuffle为True时被使用;    capacity: 整数, 队列的长度;    shared_name: 不为空, 则队列在给定的多个Session之间共享;    name: Op的名称;    cancel_op: 关闭队列的Op, 并且清空未执行的入列Ops, 在QueueRunner中使用;输出: 内容为文件名称string的FIFOQueue, 并有相应的QueueRunner对这个队列进行管理, 并将该QueueRunner加入到此图中的QUEUE_RUNNER collection中.'''
  • tf.train.shuffle_batch()
'''作用:    对队列中的样本进行乱序处理后, 输出batch_size大小的Tensor.    主要有三个部分, 且这三个部分作为节点被添加到当前Graph中:    1. 一个shuffle queue, 将输入tensors中的tensor插入;    2. 一个dequeue_many Op, 从队列中生成batch tensors;    3. 一个QueueRunner, 并添加到QUEUE_RUNNER collection中, 来将输入tensors添加到queue中.参数:    tensors: (must)需要入列的tensor list;    batch_size: (must)生成的Tensor大小, 即一个dequeue_many从队列中取出的Tensor数量;    capacity: (must)队列的长度;    min_after_dequeue: (must)队列中最少的Tensor数量;    num_threads: 对tensors中的数据执行enqueue操作的线程数量;    seed: 随机种子, 决定队列中的打乱顺序;    enqueue_many: bool, 输入tensor list中的每个Tensor是否是单个样本数据, True则为单样本数据;    shapes: 每个样本的shape, 默认为tensor list中的数据的shape;    allow_smaller_final_batch: bool, True则允许最后一个batch在队列中的数据不足时仍然输出一个比batch_size小的batch数据;    shared_name: 不为空, 则队列在给定的多个Session之间共享;    name: Op的名称输出: 与输入tensors list长度相同的输出tensors list, 每个tensor的第一位长度为batch_size, 即一个batch数据.'''

对于enqueue_many的详解:

enqueue_many为False时, tensors中的元素表示一个样本, 例如输入tensor的shape为[x, y, z];
enqueue_many为True时, tensors中的元素表示若干样本, 其中的每个tensor的第一维应该是同样大小, 代表输入样本的数量, 例如shape为[*, x, y, z];
无论enqueue_many如何, 输出的tensor list长度与tensors的长度一样, 且一一对应, 若tensors某个tensor的shape如上所述, 对应的输出tensor的shape为[batch_size, x, y, z].

  • tf.train.start_queue_runners()
'''作用: 由于string_input_producer和shuffle_batch方法生成的队列都是由QueueRunner管理, 这些QueueRunner被添加到GraphKeys.QUEUE_RUNNERS中, 因此需要本方法统一启用这些QueueRunner, 启动他们的入列线程.参数:    sess: 执行队列Ops的Session;    coord: Coordinator类对象, 对线程进行管理;    daemon: bool, 这些入列线程是否是守护线程;    start: bool, False则只创建线程, True在创建线程后启动;    collection: 指明创建执行的QueueRunner collection, 默认为GraphKeys.QUEUE_RUNNERS.输出: 生成的线程列表, threads list.'''

队列操作例子

FILE_NAME = "./data/cifar10_data/cifar-10-batches-bin/data_batch_{}.bin"filenames = [FILE_NAME.format(i) for i in range(6)]filename_queue = tf.train.string_input_producer(filenames, num_epochs=3)reader = tf.FixedLengthRecordReader(32*32*3+1)key, value = reader.read(filename_queue)record_bytes = tf.decode_raw(value, tf.uint8)label = tf.cast(tf.slice(record_bytes, [0], [1]), tf.int32)image = tf.reshape(tf.slice(record_bytes, [1], [32*32*3]), [3, 32, 32])  # 数据格式为[depth, height, width]uint8image = tf.transpose(image, [1, 2, 0])  # 将数据由[depth, height, width]转置为[height, width, depth]格式floatimage = tf.cast(uint8image, tf.float32)batch_size = 128min_queue_examples = 50000 * 0.4images, label_batch = tf.train.shuffle_batch([floatimage, label],                                             batch_size=batch_size,                                             capacity=int(min_queue_examples+3*batch_size),                                             min_after_dequeue=int(min_queue_examples),                                             num_threads=16)  # label_batch's shape: (128, 1)labels = tf.reshape(label_batch, [batch_size])  # labels' shape: (128,)global_init_op = tf.global_variables_initializer()local_init_op = tf.local_variables_initializer()  # 对string_input_producer()方法中的epochs变量进行初始化sess = tf.Session()sess.run(local_init_op)sess.run(global_init_op)tf.train.start_queue_runners(sess=sess)
原创粉丝点击