tensorflow中正则化防止过拟合以及Batch Normalization
来源:互联网 发布:iyst是什么预算软件 编辑:程序博客网 时间:2024/05/01 06:45
一、正则化
正则化原理这里就不介绍了,网上资源有很多,详情可以点击这里(https://zhuanlan.zhihu.com/p/29297934)
但是网上大多数关于正则化的教程太乱,而且有很多代码 不能试用。这里说一个可以用的:
首先在权值初始话的时候,对权值进行正则化计算,利用函数:
tf.contrib.layers.l1_regularizer(lambda1)(w)
或 tf.contrib.layers.l2_regularizer(lambda1)(w)
分别是 l1 正则化和 l2 正则化,其实际效果如:
weights = tf.constant([[1., -2.], [-3., 4.]])with tf.Session() as sess: print(sess.run(tf.contrib.layers.l1_regularizer(.5)(weights))) # (1+2+3+4)*.5 ⇒ 5 print(sess.run(tf.contrib.layers.l2_regularizer(.5)(weights))) # (1+4+9+16)*.5*.5 ⇒ 7.5
对于正则化的权重,我们利用tf.add_to_collection()存到一个自定义的集合中,如:
w=tf.get_variable(name,shape,initializer=tf.truncated_normal_initializer(stddev=stddev))tf.add_to_collection("regular_loss",tf.contrib.layers.l2_regularizer(lambda)(w))
其中lambda是惩罚程度,具体只可以自己设定,w的是被正则化的一组权重。regular_loss是正则化权值被存储的位置。当很多组权重被正则化存储到regular_loss中后,我们利用tf.get_collection()函数将这些值提取出来,根据正则化公式,利用tf.add_n()对其求和,即:
regular_loss = tf.add_n(tf.get_collection("regular_loss"))
最后将regular_loss与真正的loss想加,送进优化啊器即可。:
伪代码如下:
x = 1w = tf.get_variable('w',[1],initializer=tf.truncated_normal_initializer(stddev=0.2))tf.add_to_collection("regular_loss",tf.contrib.layers.l2_regularizer(0.001)(w))y = w*xrel_loss = 1/2*(y-label)^2regular_loss = tf.add_n(tf.get_collection("regular_loss"))loss = regular_loss + rel_lossopt = tf.GradientDescentOptimizer(learning_rate).minimize(Loss)
二、Batch Normalization注意事项
tensorflow中使用Batch Normalization层教程很多,参考可见(http://www.jianshu.com/p/0312e04e4e83),
BN层的代码如下:
conv_bn = tf.contrib.layers.batch_norm(conv, momentum, scale=True, epsilon=1e-5,is_training = self.training, scope=names)
scale是指系数λ。
实际在搭建好包含BN层的网络之后,进行训练时需要注意使用如下模式代码:
rmsprop = tf.train.RMSPropOptimizer(learning_rate= self.lr)with tf.control_dependencies(self.update_ops): self.train_rmsprop = rmsprop.minimize(loss)
其中 with tf.control_dependencies(self.update_ops): 是保证在训练的时候,先将滑动计算的均值和方差更新之后,再进行梯度计算并优化,防止使用了初始化的权值和方差。
- tensorflow中正则化防止过拟合以及Batch Normalization
- tensorflow regularizer(正则化)防止过拟合
- 正则化防止过拟合
- 机器学习中防止过拟合的正则化
- 【TensorFlow】正则化(过拟合问题)
- 防止过拟合以及解决过拟合
- 防止过拟合以及解决过拟合
- 防止过拟合以及解决过拟合
- 防止过拟合以及解决过拟合
- 防止过拟合以及解决过拟合
- 加L2正则化防止过拟合前后准确率变化,以及权重初始化
- tensorflow batch normalization
- Batch Normalization Tensorflow代码
- TensorFlow实现Batch Normalization
- TensorFlow batch normalization
- Tensorflow的Batch Normalization
- tensorflow实现batch normalization
- [Tensorflow] Batch Normalization实现
- shape函数
- 2006年4月全国计算机等级考试三级数据库技术笔试试卷
- [C/C++]字符串工具类(去除左右空格、左右换行符)
- 机器学习实战笔记(3.3)-朴素贝叶斯算法(多项式模型的朴素贝叶斯实现)
- Angularjs全选/反选/表单验证
- tensorflow中正则化防止过拟合以及Batch Normalization
- python-@property 属性
- 用AD10画PCB图的定位方法
- Kafka入门经典教程
- 初学java:题目:求s=a+aa+aaa+aaaa+aa...a的值,其中a是一个数字,几个数相加有键盘控制。
- babel使用
- [BZOJ 1028][JSOI 2007] 麻将 模拟+贪心思想
- 不同的加密算法认识和对比
- 1003对齐输出