tensorflow ExponentialMovingAverage
来源:互联网 发布:淘宝机票平台出租 编辑:程序博客网 时间:2024/05/24 07:06
作用
假如我们训练模型迭代了100K,每2K步保存一个snapshot。在evaluation时, 我们可以只使用最后得到的model-100K,也可以通过cross validation选出一最佳的model,如model-98K。 但Googlers发现(https://www.tensorflow.org/versions/r1.0/api_docs/python/tf/train/ExponentialMovingAverage):
When training a model, it is often beneficial to maintain moving averages of the trained parameters. Evaluations that use averaged parameters sometimes produce significantly better results than the final trained values.
大概意思是对模型参数进行平均得到的模型往往比单个模型的结果要好很多,注意,他们使用了significantly better
这个修饰, 虽然前边还有个sometimes
.
于是, tensorflow里就有了tf.train.ExponentialMovingAverage
这个接口。
在training时使用
文档中的例子:
# Create variables.var0 = tf.Variable(...)var1 = tf.Variable(...)# ... use the variables to build a training model...# Create an op that applies the optimizer. This is what we usually# would use as a training op.opt_op = opt.minimize(my_loss, [var0, var1])# Create an ExponentialMovingAverage objectema = tf.train.ExponentialMovingAverage(decay=0.9999)# Create the shadow variables, and add ops to maintain moving averages of var0 and var1.maintain_averages_op = ema.apply([var0, var1])# Create an op that will update the moving averages after each training step. This is what we will use in place of the usual training op.with tf.control_dependencies([opt_op]): training_op = tf.group(maintain_averages_op)#...train the model by running training_op...
这个例子体现出使用MovingAverage的三个要素。
1. 指定decay
参数创建实例: ema = tf.train.ExponentialMovingAverage(decay=0.9999)
2. 对模型变量使用apply
方法: maintain_averages_op = ema.apply([var0, var1])
3. 在优化方法使用梯度更新模型参数后执行MovingAverage:
with tf.control_dependencies([opt_op]): training_op = tf.group(maintain_averages_op)
其中,tf.group
将传入的操作捆绑成一个操作,详细可参考文档。
原理:影子变量与decay
apply
方法会为每个变量(也可以指定特定变量)创建各自的shadow variable
, 即影子变量。之所以叫影子变量,是因为它会全程跟随训练中的模型变量。影子变量会被初始化为模型变量的值,然后,每训练一个step,就更新一次。更新的方式为:
shadow_variable = decay * shadow_variable + (1 - decay) * updated_model_variable
decay
的值一般很接近于1,例如0.999, 0.9999
。很明显, 这个操作会增加模型在训练过程中的稳定性。
还有一点值得指出——创建tf.train.ExponentialMovingAverage
实例时还可以传入num_updates
参数,一般使用global_step
的值:
ema = tf.train.ExponentialMovingAverage(decay=0.9999, num_updates = tf_global_step)
它的作用让decay
变成动态的,训练前期的值小,后期的值大。因为这时真实decay的计算方式为:
decay = min(decay, (1 + num_updates) / (10 + num_updates))
Restore from checkpoint
训练时若使用了ExponentialMovingAverage
,在保存checkpoint时,不仅仅会保存模型参数,优化器参数(如Momentum), 还会保存ExponentialMovingAverage
的shadow variable
。
之前,我们可以直接使用以下代码restore模型参数, 但不会利用ExponentialMovingAverage
的结果:
saver = tf.Saver()saver.restore(sess, save_path)
若要使用ExponentialMovingAverage
保存的参数:
variables_to_restore = ema.variables_to_restore()saver = tf.train.Saver(variables_to_restore)saver.restore(sess, save_path)
当然,还有其他方式可以实现同样的效果。
写了一个完整的可运行demo, tensorflow r1.1.0。
看看它的输出就一目了然了(字符串是变量名,数字是对应的变量值),两者加载的变量值是不同的:
variables in checkpoint: bias/ExponentialMovingAverage 0.664593 bias/Momentum 4.12663 weight [[ 0.01567289] [ 0.17180483]] weight/ExponentialMovingAverage [[ 0.10421171] [ 0.26470858]] weight/Momentum [[ 5.95625305] [ 6.24084663]] bias 0.602739==============================================variables restored not from ExponentialMovingAverage: weight:0 [[ 0.01567289] [ 0.17180483]] bias:0 0.602739==============================================variables restored from ExponentialMovingAverage: weight:0 [[ 0.10421171] [ 0.26470858]] bias:0 0.664593
- tensorflow ExponentialMovingAverage
- TensorFlow学习笔记-ExponentialMovingAverage
- tensorflow学习笔记(三十三):ExponentialMovingAverage
- tensorflow 滑动平均模型 ExponentialMovingAverage
- Tensorflow cifar10_multi_gpu问题:Variable conv1/weights/ExponentialMovingAverage/ does not exist
- Tensorflow滑动平均模型tf.train.ExponentialMovingAverage解析
- tensorflow 下的滑动平均模型 —— tf.train.ExponentialMovingAverage
- Tensorflow中提供tf.train.ExponentialMovingAverage函数实现(滑动平均模型)
- tf.train.ExponentialMovingAverage解析
- tf.train.ExponentialMovingAverage用法
- tf.train.ExponentialMovingAverage的用法
- 指数滑动平均(ExponentialMovingAverage)EMA
- ExponentialMovingAverage 学习笔记(二))
- tensorflow
- TensorFlow
- TensorFlow
- tensorflow
- tensorflow
- 天下没有免费的午餐
- decorators.xml wap项目中用到了
- 要不要换个开发工具?——IntelliJ IDEA
- 我的R学习笔记
- Android中数据库的基本操作
- tensorflow ExponentialMovingAverage
- 【微信小程序】发送消息模板教程
- 一个java文件中可以有多个类
- redis支持的数据类型、操作指令及使用场景
- python入门系列20―——GUI Tkinter入门
- 【Centos7笔记七】用户及文件权限管理
- 尝试后可以成功在Ubuntu安装node.js的方法
- 源码日记——ArrayList
- JS中的prototype