tensorflow regularizer(正则化)防止过拟合

来源:互联网 发布:java计算天干地支 编辑:程序博客网 时间:2024/05/01 17:25

Regularizer是防止网络过拟合的一种有效方法。这篇文章主要探讨如何在自己的网络模型中加入正则化,防止过拟合。

首先我们看一下正则化的基本使用方法,这篇博客给出了一个使用的例子:

http://www.cnblogs.com/linyuanzhou/p/6923607.html

#!/usr/bin/env python#-*- coding:utf-8 -*-#############################File Name: tf_regularization.py#Author: Wang #Mail: wang****@hotmail.com#Created Time:2017-08-23 11:53:34############################import tensorflow as tf from tensorflow.contrib import layersmyreg1 = layers.l1_regularizer(0.01)     #创建一个正则化方法, 0.01为系数,相当于给每个参数前乘以0.01,当然这里也可以是l2方法或者sum混合方法with tf.variable_scope('var', initializer = tf.random_normal_initializer(), regularizer = myreg1):    #高能!:参数里面指明了regularizer    weight = tf.get_variable('weight', shape=[8], initializer = tf.ones_initializer())with tf.variable_scope('var2', initializer = tf.random_normal_initializer(), regularizer = myreg1):    weight2 = tf.get_variable('weight', shape=[8], initializer = tf.ones_initializer())regularization_loss = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))        #get_collection 获得list, reduce_sum进行对list求和sess = tf.Session()init = tf.global_variables_initializer()sess.run(init)with sess.as_default():    result = regularization_loss.eval() print result

最后的输出结果为0.16。

那么当我们需要在自己的网络中加入正则化时该怎么做? 继续上代码。

首先创建一个net.py文件,这个是我们自己的网络模型:

#!/usr/bin/env python#-*- coding:utf-8 -*-#############################File Name: net.py#Author: Wang#Mail: wang**@hotmail.com#Created Time:2017-08-23 12:10:48############################import tensorflow as tfimport numpy as npfrom tensorflow.contrib import layersclass mynet:        def __init__(self):        self.myreg1 = layers.l1_regularizer(0.01)        self.inference()    def inference(self):        with tf.variable_scope('var', initializer = tf.random_normal_initializer(), regularizer = self.myreg1):    weight = tf.get_variable('weight', shape = [8], initializer = tf.ones_initializer())


然后是我们训练网络的主程序,这里面需要定义数据和loss,学习方法等:

#!/usr/bin/env python#-*- coding:utf-8 -*-#############################File Name: test.py#Author: Wang#Mail: wang***@hotmail.com#Created Time:2017-08-23 12:10:28############################import tensorflow as tffrom net import mynetsess = tf.Session()mnet = mynet()init = tf.global_variables_initializer()sess.run(init)regularization_loss = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))with sess.as_default():    result = regularization_loss.eval()print result

最后的输出结果为0.08。






原创粉丝点击