总结2: Batch Normalization反向传播公式推导及其向量化
来源:互联网 发布:linux下输入ftp命令 编辑:程序博客网 时间:2024/05/29 10:46
1. 背景介绍
上周学习了吴恩达的Deep Learning专项课程的第二门课,其中讲到了Batch Normalization(BN)。但在课程视频中,吴恩达并没有详细地给出Batch Normalization反向传播的公式推导,而是从high level的角度解释了为什么Batch Normalization会work,以及如何在TensorFlow等framework中使用Batch Norm。
我自己首先尝试了一下推导BN反向传播的公式,但是用代码实现后跑的结果都不甚理想,收敛速度比不使用BN还要慢甚至有时候无法收敛,应该是公式推错了。接着我在网上搜索到Google最开始提出BN的论文, 里面给出了反向传播的公式,但是不是以向量化的形式给出的。众所周知,我们实现深度网络应该依赖于向量计算。我看着公式以自己的理解写出了向量的形式,但是实现后结果仍旧不正常。
接着在网上搜索其他人介绍BN的博客文章,绝大多数文章都是前面讲一大堆BN的好处,消除Internal Convariate Shift,加快收敛,减少Dropout的使用,起到部分正则化等等,然后涉及到核心的公式部分时,话锋一转,说BN反向传播部分的推导很简单,就是利用了Chain Rule,接着就给出了与论文中一模一样的公式。看着让人很是头疼。
就这样停滞了大概三四天的时间,但是我实在不甘心仅仅会使用TensorFlow中提供的BN模块,而搞不懂BN的详细推导。终于,我下定决心抽出一整天的时间拿出纸笔一步步的演算,最终心静下来花了大概一个小时算出来,然后代码实现之后跑起来结果就正常了。
2. BN反向传播的详细推导
2.1 单个activation进行batch norm的情况
假设神经网络的第L层是BatchNorm层,其输入数据为
对
其中,
然后对
其中,
以上是Batch Norm层正向传播的严谨的公式,很多文章里都习惯于使用具有broadcasting功能的公式,如用一个矩阵减去一个向量等操作,虽然python里的numpy支持这种运算,但是公式如果也用这种方式写则很不严谨,也对我们的求导造成很大的困扰。
接下来是反向传播部分的推导。因为
现在我们有如下数据:
通过chain rule,我们有:
其中,
因为
下面求
因为
所以:
所以:
因为
So:
以下是我用Scala实现的Batch Norm反向传播方法的向量化版本,值得注意的是,由于向量对向量求导的结果是一个矩阵,多个向量对向量求导的结果是一个三维张量,而Scala中的Breeze数值计算库现并不支持张量运算,所以在我用Scala实现的Batch Normalization的反向传播版本中,不可避免地使用的for循环对
private def backWithBatchNorm(dYCurrent: DenseMatrix[Double], yPrevious: DenseMatrix[Double]): (DenseMatrix[Double], DenseMatrix[Double]) = { val numExamples = dYCurrent.rows val oneVector = DenseVector.ones[Double](numExamples) val dZDelta = dYCurrent *:* this.activationFuncEval(zDelta) val dZNorm = dZDelta *:* (oneVector * beta.t) val dAlpha = dZDelta.t * oneVector / numExamples.toDouble val dBeta = (dZDelta *:* zNorm).t * oneVector / numExamples.toDouble val dZ = DenseMatrix.zeros[Double](z.rows, z.cols) for (j <- 0 until z.cols) { val dZNormJ = dZNorm(::, j) val dZJ = (DenseMatrix.eye[Double](dZNormJ.length) / currentStddevZ(j) - DenseMatrix.ones[Double](dZNormJ.length, dZNormJ.length) / (numExamples.toDouble * currentStddevZ(j)) - (z(::, j) - currentMeanZ(j)) * (z(::, j) - currentMeanZ(j)).t / (numExamples.toDouble * pow(currentStddevZ(j), 3.0))) * dZNormJ dZ(::, j) := dZJ } val dWCurrent = yPrevious.t * dZ / numExamples.toDouble val dYPrevious = dZ * w.t val grads = DenseMatrix.vertcat(dWCurrent, dAlpha.toDenseMatrix, dBeta.toDenseMatrix) (dYPrevious, grads) }
- 总结2: Batch Normalization反向传播公式推导及其向量化
- Batch Normalization 反向传播(backpropagation )公式的推导
- Batch Normalization梯度反向传播推导
- Batch Normalization的前向和反向传播过程
- batch normalization 正向传播与反向传播
- CNN反向传播公式推导
- CNN反向传播公式推导
- Batch Normalization反方向传播求导
- 神经网络反向传播公式的推导
- 经典反向传播算法公式详细推导
- 反向传播算法公式的推导
- 手写,纯享版反向传播算法公式推导
- 神经网络反向传播算法公式推导详解
- 神经网络前向后向传播公式推导
- batch normalization 正向传播与反向传播 both naive and smart way
- 反向传播算法(过程及公式推导)
- 卷积神经网络(CNN)反向传播算法公式详细推导
- 反向传播算法(过程及公式推导)
- 冷静一下,openwrt之总结
- Unity技巧总结02 GUI绘制 Loading遮罩
- Jvisualvm监控远程linux下Tomcat
- 3.比较yield和return
- JAVA 单元测试框架
- 总结2: Batch Normalization反向传播公式推导及其向量化
- (转转)2018校园招聘开发类试题0917
- Selenium遇到的问题4 火狐浏览器用脚本打开,firebug不见了的问题
- 乱码
- 第二周第一节课
- 写几个windows文本处理方面的脚本
- HDU
- 二、极大似然估计
- (三)整合spring cloud云服务架构