Forward and Backward Propagation of Batch Normalization

  1. Why batch normalization?
  2. Forward propagation
  3. Backward propagation

1. The principle of batch normalization

  During backpropagation, the gradient reaching the earlier layers is the result of repeatedly multiplying by the weights w, and the scale of w is not constrained to any particular range. If most of the weights are greater than 1, the gradients explode; if most of them lie between 0 and 1, the gradients vanish.
  Batch Normalization was introduced to address this problem. Suppose:

w_{\text{unknown}} = \alpha\, w_{[0,1]}

\mathrm{BN}(w_{\text{unknown}}\, x) = \mathrm{BN}(w_{[0,1]}\, x)
\frac{\partial\, \mathrm{BN}(w_{\text{unknown}}\, x)}{\partial x} = \frac{\partial\, \mathrm{BN}(w_{[0,1]}\, x)}{\partial x}

  It is clear from this that once the activations pass through BN, a change in the scale of w no longer affects the result of backpropagation, which alleviates the vanishing and exploding gradient problems.
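A quick numerical illustration of this scale invariance (a minimal sketch; the normalize helper, the random shapes, and the factor alpha = 10 are my own choices for demonstration, not part of the original derivation):

import numpy as np

def normalize(z, eps=1e-5):
    # The core of BN: zero mean and unit variance per feature over the batch.
    return (z - z.mean(axis=0)) / np.sqrt(z.var(axis=0) + eps)

x = np.random.randn(4, 3)          # a batch of 4 samples, 3 features
w = np.random.randn(3, 3)
alpha = 10.0                       # rescale the weights

out_w = normalize(x.dot(w))
out_aw = normalize(x.dot(alpha * w))
print(np.allclose(out_w, out_aw, atol=1e-4))   # True: the BN output ignores the scale of w

Because the normalized output no longer depends on the scale of w, neither does its gradient with respect to x.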

2. Forward propagation

\mu = \frac{1}{m}\sum_{i=1}^{m} x_i
\mathrm{var} = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu)^2
\hat{x}_i = \frac{x_i - \mu}{\sqrt{\mathrm{var} + \epsilon}}
y_i = \gamma\, \hat{x}_i + \beta

Code:

sample_mean = np.mean(x, axis=0)
sample_var = np.var(x, axis=0)
x_normalized = (x - sample_mean) / np.sqrt(sample_var + eps)
out = gamma * x_normalized + beta
cache = (x_normalized, gamma, beta, sample_mean, sample_var, x, eps)
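To match the cache layout that the backward pass below unpacks, these lines can be wrapped into a small function. This is only a training-mode sketch; the full cs231n batchnorm_forward also handles test-time statistics (running averages), which is omitted here:

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    # Training-mode forward pass: normalize over the batch, then scale and shift.
    sample_mean = np.mean(x, axis=0)
    sample_var = np.var(x, axis=0)
    x_normalized = (x - sample_mean) / np.sqrt(sample_var + eps)
    out = gamma * x_normalized + beta
    # Everything the backward pass needs, in the order it expects.
    cache = (x_normalized, gamma, beta, sample_mean, sample_var, x, eps)
    return out, cache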

Note that the mean and variance must be computed per feature over the samples in the batch, so do not call mean and var without the axis argument. I have been working through the cs231n assignments recently, overlooked this out of carelessness, and paid for it with a lot of wasted time.
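A quick illustration of why the axis argument matters (the shapes here are just an example):

x = np.random.randn(4, 3)            # N = 4 samples, D = 3 features
print(np.mean(x).shape)              # ()   -- one scalar over the whole batch, wrong for BN
print(np.mean(x, axis=0).shape)      # (3,) -- one mean per feature, what BN needs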

3. Backward propagation

  At this point I have to admit that I am a rather careless person, so I studied the post by ustc_lijia, who explains this very well at http://blog.csdn.net/xiaojiajia007/article/details/54924959; his diagram is shown below. Following the computation graph step by step, the backward pass can be derived clearly. I will not reinvent the wheel here; if you cannot work it out, go read his post directly.

[Figure: ustc_lijia's computation-graph diagram of the batch normalization backward pass; see the linked post.]
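Since the diagram itself is not reproduced here, for reference these are the standard chain-rule results obtained by differentiating the forward equations above (the code below is just one way of grouping these terms; note that the second term of \partial L/\partial\mu vanishes because \sum_i (x_i - \mu) = 0):

\frac{\partial L}{\partial \beta} = \sum_{i=1}^{m} \frac{\partial L}{\partial y_i}
\frac{\partial L}{\partial \gamma} = \sum_{i=1}^{m} \frac{\partial L}{\partial y_i}\, \hat{x}_i
\frac{\partial L}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i}\, \gamma
\frac{\partial L}{\partial \mathrm{var}} = \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i}\, (x_i - \mu) \cdot \left(-\frac{1}{2}\right) (\mathrm{var} + \epsilon)^{-3/2}
\frac{\partial L}{\partial \mu} = -\sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i}\, \frac{1}{\sqrt{\mathrm{var} + \epsilon}} + \frac{\partial L}{\partial \mathrm{var}} \cdot \frac{1}{m} \sum_{i=1}^{m} \left(-2 (x_i - \mu)\right)
\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial \hat{x}_i}\, \frac{1}{\sqrt{\mathrm{var} + \epsilon}} + \frac{\partial L}{\partial \mathrm{var}} \cdot \frac{2 (x_i - \mu)}{m} + \frac{\partial L}{\partial \mu} \cdot \frac{1}{m}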

Code:

def batchnorm_backward(dout, cache):
    """
    Backward pass for batch normalization.

    For this implementation, you should write out a computation graph for
    batch normalization on paper and propagate gradients backward through
    intermediate nodes.

    Inputs:
    - dout: Upstream derivatives, of shape (N, D)
    - cache: Variable of intermediates from batchnorm_forward.

    Returns a tuple of:
    - dx: Gradient with respect to inputs x, of shape (N, D)
    - dgamma: Gradient with respect to scale parameter gamma, of shape (D,)
    - dbeta: Gradient with respect to shift parameter beta, of shape (D,)
    """
    x_normalized, gamma, beta, sample_mean, sample_var, x, eps = cache
    N, D = x.shape

    # Gradients of the affine output y = gamma * x_hat + beta.
    dbeta = np.sum(dout, axis=0)
    dgamma = np.sum(dout * x_normalized, axis=0)
    dx_normalized = gamma * dout

    x_denominator = 1.0 / np.sqrt(sample_var + eps)   # 1 / sqrt(var + eps)
    x_numerator = x - sample_mean                     # x - mu

    # Gradient flowing through the numerator (x - mu) directly.
    dx_norm_numerator = x_denominator * dx_normalized
    # Gradient with respect to 1 / sqrt(var + eps).
    dx_norm_denominator = np.sum(x_numerator * dx_normalized, axis=0)
    # Gradient with respect to the variance.
    dx_var = -0.5 * x_denominator / (sample_var + eps) * dx_norm_denominator
    # Back through var = (1/N) * sum((x - mu)^2) to the (x - mu) node.
    dx_minus_flow_2 = 2.0 * x_numerator / N * dx_var
    dx_minus_flow_1 = dx_norm_numerator
    # Total gradient at the (x - mu) node, then split between x and mu.
    dx_flow_1 = dx_minus_flow_1 + dx_minus_flow_2
    dx_flow_2 = -np.sum(dx_flow_1, axis=0) / N   # contribution through mu = (1/N) * sum(x)
    dx = dx_flow_1 + dx_flow_2

    return dx, dgamma, dbeta

  Note that cache here holds the data handed over from the forward pass. I found the nodes in that blogger's graph a bit too fine-grained, which kept tempting me to skip steps, so I skipped a few in my version; following his derivation step by step is also perfectly fine. Pay attention to how the shapes change during the backward pass.
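To check an implementation like this, a numerical gradient check is the easiest safeguard. The sketch below assumes the batchnorm_forward wrapper from section 2 and a simple finite-difference helper; numerical_gradient is my own helper name, not a cs231n function:

def numerical_gradient(f, x, h=1e-5):
    # Centered finite differences of a scalar-valued f with respect to the array x.
    grad = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        old = x[idx]
        x[idx] = old + h
        fp = f(x)
        x[idx] = old - h
        fm = f(x)
        x[idx] = old                       # restore the original value
        grad[idx] = (fp - fm) / (2 * h)
    return grad

N, D = 4, 5
x = np.random.randn(N, D)
gamma = np.random.randn(D)
beta = np.random.randn(D)
dout = np.random.randn(N, D)

out, cache = batchnorm_forward(x, gamma, beta)
dx, dgamma, dbeta = batchnorm_backward(dout, cache)

# Using sum(out * dout) as the loss makes its upstream gradient exactly dout.
f = lambda x_: np.sum(batchnorm_forward(x_, gamma, beta)[0] * dout)
dx_num = numerical_gradient(f, x)
print(np.max(np.abs(dx - dx_num)))   # should be tiny, e.g. below 1e-7

dgamma and dbeta can be checked the same way by wrapping gamma or beta in the lambda instead of x.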

Reference: ustc_lijia, http://blog.csdn.net/xiaojiajia007/article/details/54924959
