TensorFlow 使用正则化

来源:互联网 发布:最短路径的算法 编辑:程序博客网 时间:2024/06/10 17:45

TensorFlow 使用正则化

import tensorflow.contrib.layers as layers

def _regularized_fc(inputs, num_outputs, reg, activation_fn):
    """Add one L2-regularized, Xavier-initialized fully connected layer."""
    return layers.fully_connected(
        inputs,
        num_outputs=num_outputs,
        weights_initializer=layers.xavier_initializer(uniform=True),
        weights_regularizer=layers.l2_regularizer(scale=reg),
        activation_fn=activation_fn)


def easier_network(x, reg, hidden_units=(200, 200), num_outputs=10):
    """Build a small fully connected classifier on top of input `x`.

    A network based on tf.contrib.layers. It flattens `x` and applies one
    tanh fully connected layer per entry in `hidden_units`. A final linear
    layer then produces the logits. Every layer uses Xavier (uniform) weight
    initialization and an L2 weight regularizer of scale `reg`. Each layer
    therefore registers its penalty in the
    ``tf.GraphKeys.REGULARIZATION_LOSSES`` collection. The caller must add
    those penalties to the training loss itself.

    Args:
        x: Input tensor; it is flattened before the first layer.
        reg: Scale passed to ``layers.l2_regularizer`` for every layer.
        hidden_units: Width of each hidden tanh layer. The default (two
            layers of 200 units) reproduces the original network.
        num_outputs: Number of output logits. The default is 10, because
            there are ten MNIST digits.

    Returns:
        The logits tensor of the final layer (no activation applied).
    """
    with tf.variable_scope('EasyNet'):
        out = layers.flatten(x)
        for units in hidden_units:
            out = _regularized_fc(out, units, reg, tf.nn.tanh)
        # Linear output layer: raw logits, meant for a softmax cross-entropy loss.
        out = _regularized_fc(out, num_outputs, reg, None)
        return out


def _print_tensor_shapes(tensors):
    """Print each tensor's name, static shape and element count, then a blank line."""
    for t in tensors:
        shp = t.get_shape().as_list()
        print("- {} shape:{} size:{}".format(t.name, shp, np.prod(shp)))
    print("")


def main(_):
    """Build the regularized MNIST training graph.

    Creates input/label placeholders and builds `easier_network` with L2
    scale `FLAGS.regu`. It prints the trainable variables and the
    regularization-loss tensors found under the 'EasyNet' scope. It then
    defines a loss of cross-entropy plus the summed regularization
    penalties, minimized with Adam (learning rate 1e-4).

    Args:
        _: Unused (command-line arguments passed by the app runner).
    """
    # NOTE(review): `mnist` and `train_step` are not used in the visible
    # code. The training loop appears to have been cut from this excerpt,
    # so confirm before removing them.
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
    x = tf.placeholder(tf.float32, [None, 784])
    y_ = tf.placeholder(tf.float32, [None, 10])

    # Make a network with regularization.
    y_conv = easier_network(x, FLAGS.regu)

    weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'EasyNet')
    print("")
    _print_tensor_shapes(weights)

    # These entries are the per-layer L2 penalty tensors added by the
    # weights_regularizer arguments, not the weights themselves.
    reg_ws = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES, 'EasyNet')
    _print_tensor_shapes(reg_ws)

    # Make the loss function `loss_fn` with regularization.
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
    # reduce_sum over the list of scalar penalties. Unlike tf.add_n, it also
    # tolerates an empty list (e.g. when regularization is disabled).
    loss_fn = cross_entropy + tf.reduce_sum(reg_ws)
    train_step = tf.train.AdamOptimizer(1e-4).minimize(loss_fn)

`tf.GraphKeys.REGULARIZATION_LOSSES` 集合中保存了图中各层的正则化损失，可以通过 `tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)` 获取，并将其加到总损失中。

也可以使用 `tf.contrib.layers.apply_regularization` 直接对指定的变量列表施加正则化，例如：`regularizer = tf.contrib.layers.apply_regularization(tf.contrib.layers.l2_regularizer(weight_decay), gen_vars + d_vars)`。它返回一个表示总正则化惩罚的标量，同样需要加到总损失中。



阅读全文
0 0
原创粉丝点击