TensorFlow Batch Normalization (BN)

来源:互联网 发布:golang教程 csdn 编辑:程序博客网 时间:2024/05/21 21:46

莫烦(Morvan)老师的教程传送门:https://morvanzhou.github.io/tutorials/
我把绘图(plot)部分删除了,只计算并打印 cost,发现加入 BN 之后,cost 小了很多。
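这里主要复制了教程中 batch normalization 的代码,有些地方还不太懂,比如每一层中间的 ema 都被存储下来了吗?
每经过一层,它都知道应该用哪个 ema 对应的值来更新。可能是因为每个 ema 的名字不一样,它是通过名字来识别的。
(补充:从代码看,add_layer 每次调用都会新建一个 ExponentialMovingAverage 对象,ema.apply([fc_mean, fc_var]) 会为这一层的 mean 和 var 各创建一个影子变量(shadow variable);之后 ema.average(fc_mean) 是按传入的张量本身查找对应的滑动平均值,所以各层的 ema 不会混淆。)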
(原文此处为一张图片,已缺失)

训练(train)和测试(test)时的处理不一样:训练时用当前 batch 计算出的 mean/var 做归一化,并用它们更新滑动平均(moving average);测试时不再更新,而是直接使用训练期间保存下来的滑动平均 mean/var。
(原文此处为一张图片,已缺失)

if norm:
    # Batch Normalize the fully connected product.
    fc_mean, fc_var = tf.nn.moments(
        Wx_plus_b,
        axes=[0],  # the dimension to normalize over: [0] is the batch axis here;
                   # for images you would use [0, 1, 2] ([batch, height, width]),
                   # but not the channel axis
    )
    # Learnable per-unit scale (gamma) and shift (beta).
    scale = tf.Variable(tf.ones([out_size]))
    shift = tf.Variable(tf.zeros([out_size]))
    epsilon = 0.001  # added to the variance for numerical stability

    # Keep a moving average of mean and var while training on batches.
    ema = tf.train.ExponentialMovingAverage(decay=0.5)

    def mean_var_with_update():
        # Update the moving averages first, then hand back the batch statistics.
        ema_apply_op = ema.apply([fc_mean, fc_var])
        with tf.control_dependencies([ema_apply_op]):
            return tf.identity(fc_mean), tf.identity(fc_var)

    # NOTE: as written, the current batch statistics are always used and the
    # moving averages are only updated, never read. For inference, pick
    # ema.average(fc_mean) / ema.average(fc_var) instead, e.g. via tf.cond.
    mean, var = mean_var_with_update()
    Wx_plus_b = tf.nn.batch_normalization(Wx_plus_b, mean, var, shift, scale, epsilon)
    # Roughly equivalent to these two steps:
    # Wx_plus_b = (Wx_plus_b - fc_mean) / tf.sqrt(fc_var + 0.001)
    # Wx_plus_b = Wx_plus_b * scale + shift

具体用法请参考 TensorFlow 官网文档。

完整代码

"""
visit https://morvanzhou.github.io/tutorials/ for more!

Build two networks.
1. Without batch normalization
2. With batch normalization

Run tests on these two networks.
"""

# 23 Batch Normalization

import numpy as np
import tensorflow as tf


ACTIVATION = tf.nn.tanh
N_LAYERS = 7
N_HIDDEN_UNITS = 30


def fix_seed(seed=1):
    """Seed NumPy and TensorFlow's graph-level RNG so runs are reproducible."""
    np.random.seed(seed)
    tf.set_random_seed(seed)


def _batch_norm(inputs, size, on_train=None, decay=0.5, epsilon=0.001):
    """Batch-normalize ``inputs`` over axis 0 with a learnable scale and shift.

    Args:
        inputs: tensor whose axis 0 is the batch axis (2-D [batch, features]
            everywhere in this script).
        size: length of the learnable ``scale``/``shift`` vectors; must be
            broadcastable against the last dimension of ``inputs``.
        on_train: optional scalar bool tensor. ``None`` always normalizes with
            the current batch statistics (and updates the moving averages).
            Otherwise, True does the same, while False normalizes with the
            stored moving averages and leaves them untouched (inference mode).
        decay: decay rate of the ``ExponentialMovingAverage``.
        epsilon: small constant added to the variance for numerical stability.

    Returns:
        The normalized, scaled and shifted tensor.
    """
    fc_mean, fc_var = tf.nn.moments(
        inputs,
        axes=[0],  # normalize over the batch axis; for images you would use
                   # [0, 1, 2] ([batch, height, width]) but not the channel
    )
    scale = tf.Variable(tf.ones([size]))
    shift = tf.Variable(tf.zeros([size]))
    # A fresh EMA object per call: every BN layer gets its own shadow
    # variables for its own mean and variance.
    ema = tf.train.ExponentialMovingAverage(decay=decay)

    def mean_var_with_update():
        # ema.apply is built inside this function on purpose: tf.cond only
        # executes conditionally the ops created inside its branch functions,
        # so the moving-average update must not be created outside.
        ema_apply_op = ema.apply([fc_mean, fc_var])
        with tf.control_dependencies([ema_apply_op]):
            return tf.identity(fc_mean), tf.identity(fc_var)

    if on_train is None:
        mean, var = mean_var_with_update()
    else:
        mean, var = tf.cond(
            on_train,
            mean_var_with_update,
            lambda: (ema.average(fc_mean), ema.average(fc_var)),
        )
    # Roughly equivalent to:
    #   normalized = (inputs - mean) / tf.sqrt(var + epsilon)
    #   outputs = normalized * scale + shift
    return tf.nn.batch_normalization(inputs, mean, var, shift, scale, epsilon)


def built_net(xs, ys, norm, on_train=None):
    """Build a deep MLP for 1-D regression, optionally with batch normalization.

    Args:
        xs: input placeholder of shape [num_samples, 1].
        ys: target placeholder of shape [num_samples, 1].
        norm: if True, batch-normalize the raw inputs and the pre-activation
            of every hidden layer.
        on_train: optional scalar bool tensor forwarded to every BN layer
            (see ``_batch_norm``). ``None`` keeps the original behaviour of
            always normalizing with batch statistics.

    Returns:
        ``[train_op, cost, layers_inputs]``: the SGD training op, the mean
        squared-error cost, and the list of each layer's input tensor
        (``layers_inputs[0]`` is the possibly-normalized ``xs``).
    """
    def add_layer(inputs, in_size, out_size, activation_function=None, norm=False):
        # weights and biases (deliberately bad initialization for this demo)
        Weights = tf.Variable(tf.random_normal([in_size, out_size], mean=0., stddev=1.))
        biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
        # fully connected product
        Wx_plus_b = tf.matmul(inputs, Weights) + biases
        # normalize the fully connected product before the activation
        if norm:
            Wx_plus_b = _batch_norm(Wx_plus_b, out_size, on_train)
        # activation
        if activation_function is None:
            return Wx_plus_b
        return activation_function(Wx_plus_b)

    fix_seed(1)

    if norm:
        # BN for the raw inputs as well
        xs = _batch_norm(xs, 1, on_train)

    # record inputs for every layer
    layers_inputs = [xs]

    # build hidden layers
    for l_n in range(N_LAYERS):
        layer_input = layers_inputs[l_n]
        in_size = layer_input.get_shape()[1].value
        output = add_layer(
            layer_input,     # input
            in_size,         # input size
            N_HIDDEN_UNITS,  # output size
            ACTIVATION,      # activation function
            norm,            # normalize before activation
        )
        layers_inputs.append(output)  # becomes the next layer's input

    # build output layer (no BN on the regression output); its input width is
    # the hidden width, so use the constant instead of a hard-coded 30
    prediction = add_layer(layers_inputs[-1], N_HIDDEN_UNITS, 1, activation_function=None)

    cost = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), reduction_indices=[1]))
    train_op = tf.train.GradientDescentOptimizer(0.001).minimize(cost)
    return [train_op, cost, layers_inputs]


# make up data
fix_seed(1)
x_data = np.linspace(-7, 10, 2500)[:, np.newaxis]  # [2500, 1]
np.random.shuffle(x_data)  # shuffle the sample order
noise = np.random.normal(0, 8, x_data.shape)  # noise added to y_data
y_data = np.square(x_data) - 5 + noise

xs = tf.placeholder(tf.float32, [None, 1])  # [num_samples, num_features]
ys = tf.placeholder(tf.float32, [None, 1])
# True while training (normalize with, and fold in, batch statistics);
# False when evaluating (normalize with the stored moving averages).
# Only the BN network reads it.
on_train = tf.placeholder(tf.bool, shape=[])

train_op, cost, layers_inputs = built_net(xs, ys, norm=False)  # without BN
train_op_norm, cost_norm, layers_inputs_norm = built_net(
    xs, ys, norm=True, on_train=on_train)  # with BN

sess = tf.Session()
# tf.initialize_all_variables was superseded by tf.global_variables_initializer
# (TF 0.12); detect the feature instead of parsing the version string.
if hasattr(tf, "global_variables_initializer"):
    init = tf.global_variables_initializer()
else:
    init = tf.initialize_all_variables()
sess.run(init)

# record cost
record_step = 30
batch_size = 10
for i in range(250):
    # train on batch
    start = i * batch_size
    sess.run(
        [train_op, train_op_norm],
        feed_dict={
            xs: x_data[start:start + batch_size],
            ys: y_data[start:start + batch_size],
            on_train: True,
        },
    )
    if i % record_step == 0:
        # record cost on the whole data set
        print(sess.run(cost, feed_dict={xs: x_data, ys: y_data}))
        print(sess.run(cost_norm, feed_dict={xs: x_data, ys: y_data, on_train: False}))
0 0
原创粉丝点击