Tensorflow 实战 google 深度学习框架 笔记(一)滑动模型
来源:互联网 发布:知乎 商业摄影师 编辑:程序博客网 时间:2024/05/21 14:58
1、滑动模型是用来干嘛的?
在采用随机梯度下降算法训练神经网络时,使用 tf.train.ExponentialMovingAverage 滑动平均操作的意义在于提高模型在测试数据上的健壮性(robustness)。
2、原理
tensorflow 下的 tf.train.ExponentialMovingAverage 需要提供一个衰减率(decay)。该衰减率用于控制模型更新的速度。该衰减率用于控制模型更新的速度,ExponentialMovingAverage 对每一个(待更新训练学习的)变量(variable)都会维护一个影子变量(shadow variable)。影子变量的初始值就是这个变量的初始值,
shadow_variable=decay×shadow_variable+(1−decay)×variable
由上述公式可知, decay 控制着模型更新的速度,越大越趋于稳定。实际运用中,decay 一般会设置为十分接近 1 的常数(0.99或0.999)。为了使得模型在训练的初始阶段更新得更快,ExponentialMovingAverage 还提供了 num_updates 参数来动态设置 decay 的大小:
decay=min{decay,1+num_updates10+num_updates}
3、实现
#!/usr/bin/env python3# -*- coding: utf-8 -*-"""Created on Sun Aug 20 11:52:00 2017@author: xiaolian"""import tensorflow as tf# 用于计算滑动平均v1 = tf.Variable(0, dtype = tf.float32)# step 模拟神经网络中迭代的轮数, 可以用于动态控制衰减率step = tf.Variable(0, trainable = False)# 定义一个滑动平均的类,初始化时给定了衰减率(0.99)和 控制衰减率的变量 stepema = tf.train.ExponentialMovingAverage(0.99, step)# 定义一个更新变量滑动平均的操作,这里需要给定一个列表,每次执行这个操作时,这个列表中的变量都会被更新maintain_average_op = ema.apply([v1])with tf.Session() as sess: sess.run(tf.global_variables_initializer()) # 通过ema.average(v1) 获得滑动平均之后变量的取值,在初始化之后变量 v1 的值 和 v1 的滑动平均都为 0 print(sess.run(v1)) # 更新 v1 到 5 sess.run(tf.assign(v1, 5)) # 更新 v1 的滑动平均值,衰减率为 min(0.99, (1 + step) / (10 + step) = 0.1) = 0.1 # 所以 v1 的滑动平均会被更新为 0.1 x 0 + 0.9 x 5 = 4.5 sess.run(maintain_average_op) print(sess.run(ema.average(v1))) print(sess.run(v1)) sess.run(tf.assign(step, 10000)) sess.run(tf.assign(v1, 10)) sess.run(maintain_average_op) print(sess.run(ema.average(v1))) print(sess.run(v1)) sess.run(maintain_average_op) print(sess.run(ema.average(v1))) print(sess.run(v1))
output:
0.04.55.04.55510.04.6094510.0
阅读全文
1 0
- Tensorflow 实战 google 深度学习框架 笔记(一)滑动模型
- tensorflow07 《TensorFlow实战Google深度学习框架》笔记-04-05滑动平均模型
- TensorFlow实战Google深度学习框架(一)
- Tensorflow实战Google深度学习框架 笔记
- 我是初学者——TensorFlow实战Google深度学习框架(学习笔记一)
- tensorflow10 《TensorFlow实战Google深度学习框架》笔记-05-03模型持久化code
- Tensorflow实战Google深度学习框架-学习笔记
- tensorflow02 《TensorFlow实战Google深度学习框架》笔记-03
- tensorflow03 《TensorFlow实战Google深度学习框架》笔记-04-01
- 关于《TensorFlow 实战Google深度学习框架》
- TensorFlow实战Google深度学习框架
- tensorflow17《TensorFlow实战Google深度学习框架》笔记-08-02 使用循环神经网络实现语言模型 code
- tensorflow26《TensorFlow实战Google深度学习框架》笔记-10-03 分布式TensorFlow code
- tensorflow05 《TensorFlow实战Google深度学习框架》笔记-04-03学习率设置
- tensorflow14《TensorFlow实战Google深度学习框架》笔记-06-03 迁移学习 code
- tensorflow04 《TensorFlow实战Google深度学习框架》笔记-04-02自定义损失函数
- tensorflow06 《TensorFlow实战Google深度学习框架》笔记-04-04正则化
- tensorflow08 《TensorFlow实战Google深度学习框架》笔记-05-01minist数字识别问题code
- POJ 有关动态规划的题目
- 测试实习随笔 (三)
- Hibernate-一对多和多对一
- java关键字
- sanlyShi的前端之路五:add/remove EventListener()
- Tensorflow 实战 google 深度学习框架 笔记(一)滑动模型
- 同步代码块2
- Android菜鸟修行
- Spring面试题
- 洛谷1279 字串距离
- android QQ第三方登陆 回掉信息 头像+昵称(核心代码)
- 轮播
- javabean简介
- SQLServer2008/2012 删除所有表视图存储过程