TensorFlow学习笔记-ExponentialMovingAverage
来源:互联网 发布:山东广电网络集团上市 编辑:程序博客网 时间:2024/06/03 22:02
作用:使用随机梯度下降算法训练神经网络时,使用滑动平均模型在很多应用中都可以在一定程度上提高最终模型在测试数据上的表现。tensorflow提供的滑动平均模型的接口为:tf.train.ExponentialMovingAverage。
函数的定义如下:
def __init__(self, decay, num_updates=None, zero_debias=False, name="ExponentialMovingAverage")
decay:衰减率,用于控制模型的更新速度,它越大,则模型越趋于稳定,在实际的运行中,它的值非常接近1,例如0.99,0.999.
num_updates:为了使模型在训练前期更新更快,如果提供该值,则decay的更新为:
ExponentialMovingAverage对每一个变量都会维护一个影子变量(shadow_var),它的初始值为对应变量的值,每次运行变量更新时,影子变量的值很更新为:
通过代码为:
v1 = tf.Variable(0,dtype=tf.float32) step = tf.Variable(0,trainable=False) ema = tf.train.ExponentialMovingAverage(0.99,step) # 使用apply创建shadow variables.这是需要给定一个列表,每次执行这个操作时, # 列表中的变量将会被更新 maintain_step = ema.apply([v1]) with tf.Session() as sess: init_ops = tf.global_variables_initializer() sess.run(init_ops) # [0.0, 0.0] print(sess.run([v1,ema.average(v1)])) sess.run(tf.assign(v1,5.0)) # update shadow_var sess.run(maintain_step) # [5.0, 4.5] print(sess.run([v1, ema.average(v1)])) sess.run(tf.assign(step, 10000)) sess.run(tf.assign(v1, 10)) sess.run(maintain_step) # [10.0, 4.5549998] print(sess.run([v1,ema.average(v1)])) sess.run(maintain_step) # [10.0, 4.6094499] print(sess.run([v1,ema.average(v1)]))
关于ExponentialMovingAverage类主要用到的函数有apply,average,average_name,他们的作用分别为:
- apply:创建影子变量,参数必须是一个列表。
- average:获取影子变量的值,一般用在创建模型时。
- average_name:通过影子变量的名称获取影子变量的值,主要用在模型的恢复时。
- variables_to_restore:获取变量名与其对应的影子变量的MAP,例如:
conv/batchnorm/gamma/ExponentialMovingAverage: conv/batchnorm/gamma, conv_4/conv2d_params/ExponentialMovingAverage: conv_4/conv2d_params, global_step: global_step
官网给出的方法为:
方法一
# Create a Saver that loads variables from their saved shadow values.shadow_var0_name = ema.average_name(var0)shadow_var1_name = ema.average_name(var1)saver = tf.train.Saver({shadow_var0_name: var0, shadow_var1_name: var1})saver.restore(...checkpoint filename...)# var0 and var1 now hold the moving average values
方法二
#Returns a map of names to Variables to restore.variables_to_restore = ema.variables_to_restore()saver = tf.train.Saver(variables_to_restore)...saver.restore(...checkpoint filename...)
例如基于上面的例子:
模型保存
# [10.0, 4.6094499]# print(sess.run([v1,ema.average(v1)]))saver = tf.train.Saver())saver.save(sess,'model/mode.ckpt')
模型加载:
# If None, it will default to variables.moving_average_variables() + variables.trainable_variables()# 这里我们只恢复滑动平均值,所以只使用moving_average_variables()saver = tf.train.Saver(ema.variables_to_restore(tf.moving_average_variables()))saver.restore(sess,'model/mode.ckpt')#4.60945,v1对应的影子变量。print(sess.run(v1))
以上为个人对滑动平均模型的理解,如果不对,请指出,一起学习。
阅读全文
0 0
- TensorFlow学习笔记-ExponentialMovingAverage
- tensorflow学习笔记(三十三):ExponentialMovingAverage
- tensorflow ExponentialMovingAverage
- ExponentialMovingAverage 学习笔记(二))
- tensorflow 滑动平均模型 ExponentialMovingAverage
- Tensorflow cifar10_multi_gpu问题:Variable conv1/weights/ExponentialMovingAverage/ does not exist
- Tensorflow滑动平均模型tf.train.ExponentialMovingAverage解析
- TensorFlow学习笔记-1
- TensorFlow学习笔记
- TensorFlow 深度学习笔记
- TensorFlow学习笔记1
- tensorflow-Alexnet学习笔记
- TensorFlow学习笔记
- Tensorflow学习笔记
- Tensorflow学习笔记(1)
- Tensorflow学习笔记(2)
- tensorflow学习笔记
- Tensorflow学习笔记
- cdh上使用spark-thriftserver操作carbondata
- 实现一个函数,打印乘法口诀表,口诀表的行数和列数自己指定。
- dhcp服务器的配置
- ExtJs的列模式column详解
- [2017.11.25]verlauf
- TensorFlow学习笔记-ExponentialMovingAverage
- IAP实现之一内建模式 — cocos2dx
- sublime配置PHP环境
- caffe安装配置
- 深入浅出Zookeeper(一) Zookeeper架构及FastLeaderElection机制
- Play2 for Java(一:简介)
- 蓝桥杯-算法训练 出现次数最多的整数
- linux中数据库的操作命令
- 【深度学习笔记】(二)Hello, Tensorflow!