【TensorFlow动手玩】常用集合: Variable, Summary, 自定义
来源:互联网 发布:windows pro是啥 编辑:程序博客网 时间:2024/05/14 06:39
集合
tensorflow用集合colletion
组织不同类别的对象。tf.GraphKeys
中包含了所有默认集合的名称。
collection
提供了一种“零存整取”的思路:在任意位置,任意层次都可以创造对象,存入相应collection
中;创造完成后,统一从一个collection
中取出一类变量,施加相应操作。
例如,
tf.Optimizer
只优化tf.GraphKeys.TRAINABLE_VARIABLES
中的变量。
本文介绍几个常用集合
- Variable
集合:模型参数
- Summary
集合:监测
- 自定义集合
Variable
Variable
被收集在名为tf.GraphKeys.VARIABLES
的colletion
中
定义
Tensorflow使用Variable
类表达、更新、存储模型参数。
Variable
是在可变更的,具有保持性的内存句柄,存储着Tensor
。必须使用Tensor
进行初始化。
k = tf.Variable(tf.random_normal([]), name='k')
创建的Variable
被添加到默认的collection
中。
初始化
在整个session
运行之前,图中的全部Variable
必须被初始化。
sess = tf.Session()init = tf.initialize_all_variables() sess.run(init)
在执行完初始化之后,Variable
中的值生成完毕,不会再变化。
特别强调:Variable
的值在sess.run(init)之后就确定了;Tensor
的值要在sess.run(x)之后才确定。
获取
和Tensor
, Operation
一样,Variable
也是全局的。
可以通过tf.all_variables()查看所有tf.GraphKeys.VARIABLES
中的对象:
# example for y = k*xx = tf.constant(1.0, shape=[]) # 0D tensork = tf.Variable(tf.constant(0.5, shape=[]) )y = tf.mul(x, k)v = tf.all_variables()
也可以用通用方法直接访问collection
:
v = tf.get_collection(tf.GraphKeys.VARIABLES)
各类Variable
另外,tensorflow还维护另外几个collection
:
ExponentialMovingAverage
对象会生成此类变量 tf.local_variables() LOCAL_VARIABLES 在all_variables()
之外,需要用tf.init_local_variables()初始化 tf.model_variables() MODEL_VARIABLES Summary
Summary
被收集在名为tf.GraphKeys.SUMMARIES
的colletion
中
定义
Summary
是对网络中Tensor
取值进行监测的一种Operation
。这些操作在图中是“外围”操作,不影响数据流本身。
用例
我们模仿常见的训练过程,创建一个最简单的用例。
# 迭代的计数器global_step = tf.Variable(0, trainable=False)# 迭代的+1操作increment_op = tf.assign_add(global_step, tf.constant(1))# 实例应用中,+1操作往往在`tf.train.Optimizer.apply_gradients`内部完成。# 创建一个根据计数器衰减的Tensorlr = tf.train.exponential_decay(0.1, global_step, decay_steps=1, decay_rate=0.9, staircase=False)# 把Tensor添加到观测中tf.scalar_summary('learning_rate', lr)# 并获取所有监测的操作`sum_opts`sum_ops = tf.merge_all_summaries()# 初始化sesssess = tf.Session()init = tf.initialize_all_variables()sess.run(init) # 在这里global_step被赋初值# 指定监测结果输出目录summary_writer = tf.train.SummaryWriter('/tmp/log/', sess.graph)# 启动迭代for step in range(0, 10): s_val = sess.run(sum_ops) # 获取serialized监测结果:bytes类型的字符串 summary_writer.add_summary(s_val, global_step=step) # 写入文件 sess.run(increment_op) # 计数器+1
调用tf.scalar_summary系列函数时,就会向默认的collection
中添加一个Operation
。
再次回顾“零存整取”原则:创建网络的各个层次都可以添加监测;在添加完所有监测,初始化sess之前,统一用tf.merge_all_summaries获取。
查看
SummaryWriter文件中存储的是序列化的结果,需要借助TensorBoard才能查看。
在命令行中运行tensorboard,传入存储SummaryWriter文件的目录:
tensorboard --logdir /tmp/log
完成后会提示:
You can navigate to http://127.0.1.1:6006
可以直接使用服务器本地浏览器访问这个地址(本机6006端口),或者使用远程浏览器访问服务器ip地址的6006端口。
自定义
除了默认的集合,我们也可以自己创造collection
组织对象。网络损失就是一类适宜对象。
tensorflow中的Loss提供了许多创建损失Tensor
的方式。
x1 = tf.constant(1.0)l1 = tf.nn.l2_loss(x1)x2 = tf.constant([2.5, -0.3])l2 = tf.nn.l2_loss(x2)
创建损失不会自动添加到集合中,需要手工指定一个collection
:
tf.add_to_collection("losses", l1)tf.add_to_collection("losses", l2)
创建完成后,可以统一获取所有损失,losses
是个Tensor
类型的list:
losses = tf.get_collection('losses')
另一种常见操作把所有损失累加起来得到一个Tensor
:
loss_total = tf.add_n(losses)
执行操作可以得到损失取值:
sess = tf.Session()init = tf.initialize_all_variables()sess.run(init)losses_val = sess.run(losses)loss_total_val = sess.run(loss_total)
实际上,如果使用TF-Slim包的losses系列函数创建损失,会自动添加到名为”losses”的collection
中。
- 【TensorFlow动手玩】常用集合: Variable, Summary, 自定义
- 【TensorFlow动手玩】队列
- 【TensorFlow动手玩】基本概念: Tensor, Operation, Graph
- 【TensorFlow动手玩】数据导入2
- 【TensorFlow动手玩】数据导入1
- Tensorflow variable
- tensorflow-Variable
- Oracle SET System variable Summary
- 【翻译】动手动脑玩转Web游戏之三:人物动起来、敌人出现、自定义视角
- tensorflow入门之Variable
- Tensorflow学习:Variable变量
- TensorFlow之Variable 使用方法
- tensorflow: variable初始化
- TensorFlow Variable 使用方法
- Tensorflow之Variable
- Tensorflow-get_variable、Variable
- TensorFlow Session&&Variable&&PlaceHolder
- TensorFlow--tf.Variable
- 查找东西是最浪费时间的 找搜索引擎。
- 基于iTextSharp的HTML转PDF,包含图片的转换
- 一、java io 概述
- 4、Java入门—多态
- Linux:FHS标准
- 【TensorFlow动手玩】常用集合: Variable, Summary, 自定义
- oracle查看表、字段属性和说明sql
- C/C++网络通讯编程(一)
- 记录andorid打印输出看不见的问题的探索
- 解决listview和 gridview 单行显示的方法
- *Leetcode 368. Largest Divisible Subset
- emulate touch event from adb
- C++中的协程
- 使用SWIG Python动态绑定C++对象