TensorFlow滑动平均模型

来源:互联网 发布:数据库无法远程连接 编辑:程序博客网 时间:2024/05/29 19:05

我们使用滑动平均模型可以使模型在测试数据上更准确。

在《TensorFlow:实战Google深度学习框架》一书中给了例子:

  shadow_variable = decay*shadow_variable+(1-decay)*variable

       每次使用的衰减率 = min{decay, (1+num_updates)/(10+unm_updates)} 

下面代码中, num_updates = step   

import tensorflow as tf#需要保存滑动平均值的变量v1 = tf.Variable(0, dtype=tf.float32)v2 = tf.Variable(0, dtype=tf.float32)#步数step = tf.Variable(0, trainable=False)#滑动平均模型ema = tf.train.ExponentialMovingAverage(0.99,step)#向模型提供变量averages_op = ema.apply([v1,v2])with tf.Session() as sess:    init_op = tf.initialize_all_variables()    sess.run(init_op)    print sess.run([v1,ema.average(v1),v2,ema.average(v2)])    sess.run(tf.assign(v1,5))    sess.run(tf.assign(v2,8))#想获得影子变量,需要在run一下滑动平均节点    sess.run(averages_op)    print sess.run([v1,ema.average(v1),v2,ema.average(v2)])

运行结果为:

[0.0, 0.0, 0.0, 0.0][5.0, 4.5, 8.0, 7.1999998]


分析:

上述代码中 step = 0

衰减率 = min{decay, (1+0)/(10+0)} = 0.1 

4.5 = 0.1 * 0 +0.9 * 5

7.19998 = 0.1 * 0 + 0.9 * 8 



1 0
原创粉丝点击