Batch Normalization的前向和反向传播过程
来源:互联网 发布:appserv怎么进入mysql 编辑:程序博客网 时间:2024/05/15 10:22
- 为什么要batch normalization?
- 前向传播
- 反向传播
1.batch normalization的原理
在反向传播的过程中,是一个w不断叠乘的结果,因为在传播过程中w时一个不确定范围的数值。在反向传播的过程中,如果w多数大于1,会造成梯度爆炸,大多数大于0小于1,会梯度弥散。
为了解决这个问题,就有了Batch Normalization的思想。假设:
可以很明显的看出,w的尺度变化不会带来经过了BN之后并不会对反向传播的结果带来影响,解决了梯度弥散、爆炸的问题。
2.前向传播
代码:
sample_mean = np.mean(x,axis=0)sample_var = np.var(x,axis=0)x_normalized = (x - sample_mean) / np.sqrt(sample_var + eps) out = gamma * x_normalized + beta cache = (x_normalized, gamma, beta, sample_mean, sample_var, x, eps)
注意求均值和方差的时候,一定要注意时对样本求均值和方差,不要直接用mean和var公式而不加axis参数。博主最近在做cs231n的作业,一开始由于粗心没仔细看,最后耽误了好多时间,为自己的蠢买单。
2.反向传播
讲到这里,不得不承认博主真的是个很粗心大意的人,所以学习了一下ustc_lijia,这位博主在他的博客http://blog.csdn.net/xiaojiajia007/article/details/54924959里面说的很棒,下面附上他的图解。按照图解一步一步可以清晰的求出反向传播的过程。这里就不重复造轮子了,求不出来的可以直接去他的博客里面看。
代码:
def batchnorm_backward(dout, cache): """ Backward pass for batch normalization. For this implementation, you should write out a computation graph for batch normalization on paper and propagate gradients backward through intermediate nodes. Inputs: - dout: Upstream derivatives, of shape (N, D) - cache: Variable of intermediates from batchnorm_forward. Returns a tuple of: - dx: Gradient with respect to inputs x, of shape (N, D) - dgamma: Gradient with respect to scale parameter gamma, of shape (D,) - dbeta: Gradient with respect to shift parameter beta, of shape (D,) """ x_normalized, gamma, beta, sample_mean, sample_var, x, eps = cache N, D = x.shape dx, dgamma, dbeta = None, None, None dbeta = np.sum(dout,axis=0).reshape(D,) dgamma = np.sum(dout*(x_normalized),axis=0).reshape(D,) dx_normalized = gamma*dout x_denominator = 1/np.sqrt(sample_var+eps) x_numerator = x - sample_mean #print("dbeta:",dbeta) #print("dgamma:",dgamma) dx_norm_numerator = x_denominator*dx_normalized dx_norm_denominator = np.sum(x_numerator*dx_normalized,axis=0) #print("denominator:",dx_norm_denominator) dx_var = -(1/2)*(1/np.sqrt(sample_var + eps))*(1/(sample_var + eps))*dx_norm_denominator #print("var:",dx_var) dx_minus_flow_2 = 2*(x-sample_mean)*(1/N)*np.ones((N,D))*dx_var dx_minus_flow_1 = dx_norm_numerator #print("flow_x2:",dx_minus_flow_2) #print("flow_x1,numerator:",dx_minus_flow_1) dx_flow_1 = dx_minus_flow_2+dx_minus_flow_1 dx_flow_2 = -(1/N)*np.ones((N,D))*np.sum(dx_flow_1,axis=0) dx = dx_flow_1 + dx_flow_2 ########################################################################### # TODO: Implement the backward pass for batch normalization. Store the # # results in the dx, dgamma, and dbeta variables. # ########################################################################### pass ########################################################################### # END OF YOUR CODE # ########################################################################### return dx, dgamma, dbeta
需要注意的是,这里的cache里面是前向传播传进来的数据。我觉得上面那位博主的节点设置的太细致了,总让人想跳步,所以我就自己跳了几步哈哈,大家按照他教的一步步做也ok的。在反向传播的时候注意维度的变化。
http://blog.csdn.net/xiaojiajia007/article/details/54924959,ustc_lijia
阅读全文
0 0
- Batch Normalization的前向和反向传播过程
- batch normalization 正向传播与反向传播
- Batch Normalization 反向传播(backpropagation )公式的推导
- Batch Normalization梯度反向传播推导
- Batch Normalization反方向传播求导
- 总结2: Batch Normalization反向传播公式推导及其向量化
- caffe学习笔记3.2--前向传播和反向传播
- 神经网络中前向传播和反向传播解析
- 深度学习之caffe 前向传播和反向传播
- caffe学习笔记3.2--前向传播和反向传播
- Caffe源码解读:relu_layer前向传播和反向传播
- 前向传播和反向传播(举例说明)
- 神经网络的前向传播和误差反向传播(NN,RNN,LSTM)(一)
- 神经网络的前向传播和误差反向传播(NN,RNN,LSTM)(二)
- 神经网络的前向传播和误差反向传播(NN,RNN,LSTM)(三)
- dl4j的BaseLayer前向与反向传播算法计算过程简介
- 前向传播与反向传播代码
- 神经网络(前向传播和反向传导)
- 2.基本套接字函数
- 在 Linux 中使用 Azure Premium 存储的基本优化指南
- BZOJ3224: Tyvj 1728 普通平衡树(无旋Treap/替罪羊)
- [MFC] RTTI应用总结(一)
- uboot烧写内核和文件系统
- Batch Normalization的前向和反向传播过程
- Servlet从入门开始学习(一)
- vue-Resource(与后端数据交互)
- java基于索引对List进行分批处理
- python中的number数字
- Tablayout+Viewpager+recyclerview简单实现
- Swap Nodes in Pairs
- Quartz中时间参数说明 即Cron表达式
- PMCAFF产品众测 | 对话随手攒CEO聊聊这款产品的设计、推广和改进(活动已结束)