Tensorflow: 队列操作
来源:互联网 发布:win 启动mysql命令 编辑:程序博客网 时间:2024/06/05 00:29
import tensorflow as tf
Tensorflow有两种队列 Queue
类, 两者都继承与父类 tf.QueueBase
:
- 先进先出队列
tf.FIFOQueue
- 随机队列
tf.RandomShuffleQueue
队列与Op一样也是图中的一个节点, 类中的主要方法: 插入元素 queue.enqueue()
和取出元素 queue.dequeue()
返回一个Op, 也为图中的一个节点.
Queue
tf.FIFOQueue
- init()
'''参数: capacity: (must)整数, 队列的能容纳的最大元素的数量; dtypes: (must)dtype list, 列表长度与队列中单个元素包含的Tensor数量相同, 对应每个Tensor的数据类型; shapes: TensorShape list, 与dtypes类似, 长度与单个元素包含的Tensor数量相同, 对应每个Tensor的形状; names: string list, 与dtypes和shapes类似, 长度与单个元素包含的Tensor数量相同, 对应每个Tensor的名称; shared_name: 不为空, 则队列在给定的多个Session之间共享; name: 此队列的名称;'''
tf.RandomShuffleQueue
- init()
'''性质: 出列时随机选择元素.参数: capacity: (must)队列长度; min_after_dequeue: (must)队列中至少保留的元素的数量; dtypes: (must)dtype list, 列表长度与队列中单个元素包含的Tensor数量相同, 对应每个Tensor的数据类型; shapes: TensorShape list, 与dtypes类似, 长度与单个元素包含的Tensor数量相同, 对应每个Tensor的形状; names: string list, 与dtypes和shapes类似, 长度与单个元素包含的Tensor数量相同, 对应每个Tensor的名称; seed: 随机种子; shared_name: 不为空, 则队列在给定的多个Session之间共享; name: 此队列的名称;'''
队列共同方法
- enqueue()
'''作用: 入列单个元素.参数: vals: (must)入列元素; 根据队列的设置, 可能为单个Tensor, 或list/tuple of Tensor, 或根据names构成的[name: Tensor]键值对字典; name: 这个入列Op的名称.输出: Op, 入列Op.'''
- enqueue_many()
'''作用: 入列多个元素.参数: vals: (must)Tensor的传入形式与enqueue()方法相同, 不同的因为是插入多个元素, 所以此处的Tensor比队列规定的Tensor多一维; 对Tensor的第一位进行分割, 得到的多个Tensor是批量插入的元素; 因此此处的Tensor的第一维的大小表示的是批量插入的元素的数量; name: 这个入列Op的名称.输出: Op, 批量入列Op.'''queue = tf.FIFOQueue(10, dtypes=tf.float32)en_op_1 = queue.enqueue([1.0])# en_op_2 = queue.enqueue_many([1.0, 1.0, 2.0, 4.0, 5.0]) # 这样写会报错en_op_2 = queue.enqueue_many([[1.0, 1.0, 2.0, 4.0, 5.0]])with tf.Session() as sess: sess.run(en_op_1) sess.run(en_op_2) for i in range(queue.size().eval()): print(sess.run(queue.dequeue()))'''结果:1.01.01.02.04.05.0'''
- dequeue()
'''作用: 出列单个元素.参数: name: 这个出列Op的名称.输出: Tensor, 或tuple of Tensor, 取出的元素.'''
- dequeue_many()
'''作用: 出列多个元素参数: n: (must)出列的元素的数量 name: 这个出列Op的名称.输出: tuple of Tensor, 出列的若干元素集合.'''
QueueRunner 队列管理器
QueueRunner会创建若干新的线程, 每个线程负责一个入列操作, 因此一般用以在不干扰执行的情况下, 由这些线程读取庞大的数据.
所有队列管理器被默认加入图的tf.GraphKeys.QUEUE_RUNNERS集合中.
tf.train.QueueRunner
- init()
'''作用: 创建与入列Op数量相同的线程, 维护一个队列.参数: queue: (must)队列Queue, 被操作的队列; enqueue_ops: (must)list of Ops, 入列Ops列表, 每个Op在不同的线程中运行; close_op: 关闭队列的Op, 未执行的入列Ops保留; cancel_op: 关闭队列的Op, 并且清空未执行的入列Ops; queue_closed_exception_types: 指定的错误tuple, 表明当其中的错误出现时, 队列已关闭; queue_runner_def: QueueRunnerDef协议缓冲区, 从这个缓冲区中重建队列, 与其他变量互斥; import_scope: string, 使用queue_runner_def时才起作用, 添加到指定的Name Scope中.'''
方法:
- create_threads()
'''作用: 在给定的Session中为入列Ops创建若干线程, 并决定是否马上启用.参数: sess: (must)给定的Session; coord: Coordinator类, 辅助回报Errors和判断停止所有线程的条件; daemon: bool, Trueze则将这些线程设置为daemon线程; start: bool, True则创建后马上启动这些线程, False则需要再调用这些线程的start()方法来启动线程.输出: 创建的线程list'''queue = tf.FIFOQueue(10, dtypes=[tf.uint8])counter = tf.Variable(0, dtype=tf.uint8)increment_op = tf.assign_add(counter, tf.constant(1, dtype=tf.uint8))# 建立Op之间的依赖关系with tf.control_dependencies([increment_op]): enqueue_op = queue.enqueue(counter)qrunner = tf.train.QueueRunner(queue, enqueue_ops=[increment_op, enqueue_op]*1)sess = tf.Session()sess.run(tf.global_variables_initializer())# 启动QueueRunner中的线程enqueue_threads = qrunner.create_threads(sess, start=True)for i in range(5): print(sess.run(queue.dequeue()))
Coordinator 线程协调
Coordinator类用来帮助多个线程协同工作, 多个线程同步终止, 并且向那个在等待所有工作线程终止的程序报告异常.
tf.train.Coordinator
- init()
'''作用: 协调多个线程协同工作, 多个线程同步终止, 并且向那个在等待所有工作线程终止的程序报告异常.参数: clean_stop_exception_types: 由这个Exception list中的Exception发出的request_stop(<exception>)请求, 会被当作request_stop(None)这样的请求.'''
方法:
- should_stop()
'''作用: 判断所有线程是否应该停止, 如果线程应该停止则返回True.输出: bool, 如果线程应该停止则返回True.'''
- request_stop()
'''作用: 请求所有线程停止, 当request_stop()被调用之后, 再调用should_stop()就会返回True参数: ex: Exception, 指明停止线程的异常, 这个异常不要自行定义, 应该使用 try: ... except Exception as ex: coord.request_stop(ex) 这样的形式进行传递.'''
- join()
'''作用: 等待被指定的线程终止参数: threads: list of threading.Threads, 需要等待终止的线程; stop_grace_period_secs: 在调用request_stop()后, 给所有线程指定的秒数, 在这个时间之前停止; ignore_live_threads: bool, 当为False是, 在经过stop_grace_period_secs时间段后, 如果仍有存活的线程, 则报错.'''
Coordinator常有用法:
queue = tf.FIFOQueue(10, dtypes=[tf.uint8])counter = tf.Variable(0, dtype=tf.uint8)increment_op = tf.assign_add(counter, tf.constant(1, dtype=tf.uint8))# 建立Op之间的依赖关系with tf.control_dependencies([increment_op]): enqueue_op = queue.enqueue(counter)qrunner = tf.train.QueueRunner(queue, enqueue_ops=[increment_op, enqueue_op]*1)sess = tf.Session()coord = tf.train.Coordinator()sess.run(tf.global_variables_initializer())# 启动QueueRunner中的线程enqueue_threads = qrunner.create_threads(sess, coord=coord, start=True)try: for i in range(5): print(sess.run(queue.dequeue()))except Exception as ex: print("Something wrong.") coord.request_stop(ex)finally: coord.request_stop() coord.join(enqueue_threads)
读取数据的函数:
- tf.train.string_input_producer()
'''作用: 生成一个FIFO队列, 将若干输入数据文件名传入队列, 生成一个带有QueueRunner管理的FIFOQueue输出文件名的队列, 以供后面读取文件数据的方法使用, 从而实现对文件的管理.参数: string_tensor: (must)字符串队列或字符串1维Tensor, 使用这些字符串生成队列; num_epochs: 整数, 如果指定了该参数, 则对string_tensor中的每个string循环num_epochs次数后, 如果再调用, 才报OutOfRange错误; 如果没有指定该参数, 即为None, 则可以无限调用这些strings; 注意: 如果指定了num_epochs, 生成一个local计数器epochs, 函数会将这个变量添加到tf.GraphKeys.LOCAL_VARIABLES中, 因此在run训练Op之前, 需要使用tf.local_variables_initializer()对这个变量进行初始化; shuffle: bool, 如果为True, 则在每次epoch中, 将所有的字符串随机打乱后再被使用; seed: 整数, 随机种子, 在shuffle为True时被使用; capacity: 整数, 队列的长度; shared_name: 不为空, 则队列在给定的多个Session之间共享; name: Op的名称; cancel_op: 关闭队列的Op, 并且清空未执行的入列Ops, 在QueueRunner中使用;输出: 内容为文件名称string的FIFOQueue, 并有相应的QueueRunner对这个队列进行管理, 并将该QueueRunner加入到此图中的QUEUE_RUNNER collection中.'''
- tf.train.shuffle_batch()
'''作用: 对队列中的样本进行乱序处理后, 输出batch_size大小的Tensor. 主要有三个部分, 且这三个部分作为节点被添加到当前Graph中: 1. 一个shuffle queue, 将输入tensors中的tensor插入; 2. 一个dequeue_many Op, 从队列中生成batch tensors; 3. 一个QueueRunner, 并添加到QUEUE_RUNNER collection中, 来将输入tensors添加到queue中.参数: tensors: (must)需要入列的tensor list; batch_size: (must)生成的Tensor大小, 即一个dequeue_many从队列中取出的Tensor数量; capacity: (must)队列的长度; min_after_dequeue: (must)队列中最少的Tensor数量; num_threads: 对tensors中的数据执行enqueue操作的线程数量; seed: 随机种子, 决定队列中的打乱顺序; enqueue_many: bool, 输入tensor list中的每个Tensor是否是单个样本数据, True则为单样本数据; shapes: 每个样本的shape, 默认为tensor list中的数据的shape; allow_smaller_final_batch: bool, True则允许最后一个batch在队列中的数据不足时仍然输出一个比batch_size小的batch数据; shared_name: 不为空, 则队列在给定的多个Session之间共享; name: Op的名称输出: 与输入tensors list长度相同的输出tensors list, 每个tensor的第一位长度为batch_size, 即一个batch数据.'''
对于enqueue_many的详解:
enqueue_many为False时, tensors中的元素表示一个样本, 例如输入tensor的shape为[x, y, z];
enqueue_many为True时, tensors中的元素表示若干样本, 其中的每个tensor的第一维应该是同样大小, 代表输入样本的数量, 例如shape为[*, x, y, z];
无论enqueue_many如何, 输出的tensor list长度与tensors的长度一样, 且一一对应, 若tensors某个tensor的shape如上所述, 对应的输出tensor的shape为[batch_size, x, y, z].
- tf.train.start_queue_runners()
'''作用: 由于string_input_producer和shuffle_batch方法生成的队列都是由QueueRunner管理, 这些QueueRunner被添加到GraphKeys.QUEUE_RUNNERS中, 因此需要本方法统一启用这些QueueRunner, 启动他们的入列线程.参数: sess: 执行队列Ops的Session; coord: Coordinator类对象, 对线程进行管理; daemon: bool, 这些入列线程是否是守护线程; start: bool, False则只创建线程, True在创建线程后启动; collection: 指明创建执行的QueueRunner collection, 默认为GraphKeys.QUEUE_RUNNERS.输出: 生成的线程列表, threads list.'''
队列操作例子
FILE_NAME = "./data/cifar10_data/cifar-10-batches-bin/data_batch_{}.bin"filenames = [FILE_NAME.format(i) for i in range(6)]filename_queue = tf.train.string_input_producer(filenames, num_epochs=3)reader = tf.FixedLengthRecordReader(32*32*3+1)key, value = reader.read(filename_queue)record_bytes = tf.decode_raw(value, tf.uint8)label = tf.cast(tf.slice(record_bytes, [0], [1]), tf.int32)image = tf.reshape(tf.slice(record_bytes, [1], [32*32*3]), [3, 32, 32]) # 数据格式为[depth, height, width]uint8image = tf.transpose(image, [1, 2, 0]) # 将数据由[depth, height, width]转置为[height, width, depth]格式floatimage = tf.cast(uint8image, tf.float32)batch_size = 128min_queue_examples = 50000 * 0.4images, label_batch = tf.train.shuffle_batch([floatimage, label], batch_size=batch_size, capacity=int(min_queue_examples+3*batch_size), min_after_dequeue=int(min_queue_examples), num_threads=16) # label_batch's shape: (128, 1)labels = tf.reshape(label_batch, [batch_size]) # labels' shape: (128,)global_init_op = tf.global_variables_initializer()local_init_op = tf.local_variables_initializer() # 对string_input_producer()方法中的epochs变量进行初始化sess = tf.Session()sess.run(local_init_op)sess.run(global_init_op)tf.train.start_queue_runners(sess=sess)
阅读全文
0 0
- Tensorflow: 队列操作
- 【TensorFlow动手玩】队列
- TensorFlow 队列与多线程
- tensorflow-队列与多线程
- TensorFlow中的队列
- 学习笔记TF049:TensorFlow 模型存储加载、队列线程、加载数据、自定义操作
- Tensorflow实战学习(四十九)【模型存储加载,队列线程,加载数据,自定义操作】
- 队列操作
- 队列操作
- 队列操作
- 队列操作
- 队列操作
- 队列操作
- 队列操作
- 队列操作
- 操作队列
- 队列操作
- 队列操作
- 欢迎使用CSDN-markdown编辑器
- 图像三大特征
- UML类图实例
- 抽象类和接口
- PHP 7 新特性你知道多少?
- Tensorflow: 队列操作
- Error:Execution failed for task ':app:validateDebugSigning'. > Keystore file F:\myAndroid3\android_s
- 删除Maven仓库无用的版本
- ArrayList,Vector,HashMap,HashSet,HashTable之间的区别与联系
- MAVEN搭建多模块企业级项目
- codeforces round 22 B.The Golden Age(枚举)
- 【工作记录0022】C#(.NET)调用Java开发的WebService(wsdl),客户端传递非string类型参数(int,double,bool等),而服务端无法获取到参数值的解决方案
- 记录-如何测试服务器是否支持ipv6
- 读 Effective java