tensorflow regularizer(正则化)防止过拟合
来源:互联网 发布:java计算天干地支 编辑:程序博客网 时间:2024/05/01 17:25
Regularizer是防止网络过拟合的一种有效方法。这篇文章主要探讨如何在自己的网络模型中加入正则化,防止过拟合。
首先我们看一下正则化的基本使用方法,这篇博客给出了一个使用的例子:
http://www.cnblogs.com/linyuanzhou/p/6923607.html
#!/usr/bin/env python#-*- coding:utf-8 -*-#############################File Name: tf_regularization.py#Author: Wang #Mail: wang****@hotmail.com#Created Time:2017-08-23 11:53:34############################import tensorflow as tf from tensorflow.contrib import layersmyreg1 = layers.l1_regularizer(0.01) #创建一个正则化方法, 0.01为系数,相当于给每个参数前乘以0.01,当然这里也可以是l2方法或者sum混合方法with tf.variable_scope('var', initializer = tf.random_normal_initializer(), regularizer = myreg1): #高能!:参数里面指明了regularizer weight = tf.get_variable('weight', shape=[8], initializer = tf.ones_initializer())with tf.variable_scope('var2', initializer = tf.random_normal_initializer(), regularizer = myreg1): weight2 = tf.get_variable('weight', shape=[8], initializer = tf.ones_initializer())regularization_loss = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) #get_collection 获得list, reduce_sum进行对list求和sess = tf.Session()init = tf.global_variables_initializer()sess.run(init)with sess.as_default(): result = regularization_loss.eval() print result
最后的输出结果为0.16。
那么当我们需要在自己的网络中加入正则化时该怎么做? 继续上代码。
首先创建一个net.py文件,这个是我们自己的网络模型:
#!/usr/bin/env python#-*- coding:utf-8 -*-#############################File Name: net.py#Author: Wang#Mail: wang**@hotmail.com#Created Time:2017-08-23 12:10:48############################import tensorflow as tfimport numpy as npfrom tensorflow.contrib import layersclass mynet: def __init__(self): self.myreg1 = layers.l1_regularizer(0.01) self.inference() def inference(self): with tf.variable_scope('var', initializer = tf.random_normal_initializer(), regularizer = self.myreg1): weight = tf.get_variable('weight', shape = [8], initializer = tf.ones_initializer())
然后是我们训练网络的主程序,这里面需要定义数据和loss,学习方法等:
#!/usr/bin/env python#-*- coding:utf-8 -*-#############################File Name: test.py#Author: Wang#Mail: wang***@hotmail.com#Created Time:2017-08-23 12:10:28############################import tensorflow as tffrom net import mynetsess = tf.Session()mnet = mynet()init = tf.global_variables_initializer()sess.run(init)regularization_loss = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))with sess.as_default(): result = regularization_loss.eval()print result
最后的输出结果为0.08。
阅读全文
0 0
- tensorflow regularizer(正则化)防止过拟合
- tensorflow中正则化防止过拟合以及Batch Normalization
- 正则化防止过拟合
- 【TensorFlow】正则化(过拟合问题)
- TensorFlow中的Dropout防止过拟合overfiting
- 正则化方法:防止过拟合,提高泛化能力
- 正则化方法:防止过拟合,提高泛化能力
- 正则化方法:防止过拟合,提高泛化能力
- 机器学习--正则化(regularization)防止分类器过拟合
- 机器学习中防止过拟合的正则化
- 正则化方法:防止过拟合,提高泛化能力
- TensorFlow学习---tf.nn.dropout防止过拟合
- 【转载】TensorFlow学习---tf.nn.dropout防止过拟合
- 过拟合与正则化
- 过拟合与正则化
- 过拟合及正则化
- 正则化与过拟合
- 防止过拟合
- Mybatis和spring整合
- iOS国际化
- Tomcat8.5源码分析-StandardContext
- 数据库水平切分
- javaScript高级程序设计 第1章 javaScript简介 思维导图笔记
- tensorflow regularizer(正则化)防止过拟合
- window下MongoDB的配置与安装
- 支付的典型架构
- 万物初始
- HDU
- Matlab转c与c++代码
- Vmware12虚拟机及Oralce安装
- PHP实现上传图片素材获取mediaID
- springCloud的学习篇章