Batch Normalization梯度反向传播推导
来源:互联网 发布:淘宝倒卖赚差价的生意 编辑:程序博客网 时间:2024/05/19 16:51
最近在看CS231N的课程,同时也顺带做配套的作业,在Assignment2 中关于Batch Normalization的具体数学过程则困惑了很久,通过参看一些博客自己推导了一遍,供大家参考。
Batch Normalization
首先,关于Batch Normalization的具体实现过程就不在此介绍了,想了解的可以参看论文或者博客。
对于Batch Normalization的前向传播可以参看下图的过程,它主要思路就是将每个Batch的输入根据均值
对于前向传播网络,可以很直观的给出实现代码
def batchnorm_forward(x, gamma, beta, bn_param): """ Input: - x: (N, D)维输入数据 - gamma: (D,)维尺度变化参数 - beta: (D,)维尺度变化参数 - bn_param: Dictionary with the following keys: - mode: 'train' 或者 'test' - eps: 一般取1e-8~1e-4 - momentum: 计算均值、方差的更新参数 - running_mean: (D,)动态变化array存储训练集的均值 - running_var:(D,)动态变化array存储训练集的方差 Returns a tuple of: - out: 输出y_i(N,D)维 - cache: 存储反向传播所需数据 """ mode = bn_param['mode'] eps = bn_param.get('eps', 1e-5) momentum = bn_param.get('momentum', 0.9) N, D = x.shape # 动态变量,存储训练集的均值方差 running_mean = bn_param.get('running_mean', np.zeros(D, dtype=x.dtype)) running_var = bn_param.get('running_var', np.zeros(D, dtype=x.dtype)) out, cache = None, None # TRAIN 对每个batch操作 if mode == 'train': sample_mean = np.mean(x, axis = 0) sample_var = np.var(x, axis = 0) x_hat = (x - sample_mean) / np.sqrt(sample_var + eps) out = gamma * x_hat + beta cache = (x, gamma, beta, x_hat, sample_mean, sample_var, eps) running_mean = momentum * running_mean + (1 - momentum) * sample_mean running_var = momentum * running_var + (1 - momentum) * sample_var # TEST:要用整个训练集的均值、方差 elif mode == 'test': x_hat = (x - running_mean) / np.sqrt(running_var + eps) out = gamma * x_hat + beta else: raise ValueError('Invalid forward batchnorm mode "%s"' % mode) bn_param['running_mean'] = running_mean bn_param['running_var'] = running_var return out, cache
上述代码基于CS231N Assignment2,值得注意的是Batch Normalization对于在训练和测试阶段的计算方法不一样,因为训练阶段的均值和方差是基于一个Batch的数据,而测试阶段是基于整个训练集求得。
梯度反向传播
Batch Normalization最让人头疼的就是理清楚反向传播梯度并写成代码,当然它依然遵循链式求导法则。首先我们基于上图,将变量定义如下:
σ 为一个batch所有样本的方差μ 为样本均值xˆ 为归一化后的样本数据yi 为输入样本xi 经过尺度变化的输出量γ 和β 为尺度变化系数∂L∂y 为已知,并假设x 和y 都为(N,D)维,即有N个维度为D的样本
由于网络正向传播是根据
上面两个式子都涉及到Batch中的N个样本的累加,因为N个样本的
直接求
下面,就可以求
在上面的式子中我写成
下面,我们就根据上面的步骤来完成代码:
def batchnorm_backward(dout, cache): """ Inputs: - dout: 上一层的梯度,维度(N, D),即 dL/dy - cache: 所需的中间变量,来自于前向传播 Returns a tuple of: - dx: (N, D)维的 dL/dx - dgamma: (D,)维的dL/dgamma - dbeta: (D,)维的dL/dbeta """ x, gamma, beta, x_hat, sample_mean, sample_var, eps = cache N = x.shape[0] dgamma = np.sum(dout * x_hat, axis = 0) dbeta = np.sum(dout, axis = 0) dx_hat = dout * gamma dsigma = -0.5 * np.sum(dx_hat * (x - sample_mean), axis=0) * np.power(sample_var + eps, -1.5) dmu = -np.sum(dx_hat / np.sqrt(sample_var + eps), axis=0) - 2 * dsigma*np.sum(x-sample_mean, axis=0)/ N dx = dx_hat /np.sqrt(sample_var + eps) + 2.0 * dsigma * (x - sample_mean) / N + dmu / N return dx, dgamma, dbeta
附:两个有用的博客 这里 和 这里
- Batch Normalization梯度反向传播推导
- Batch Normalization 反向传播(backpropagation )公式的推导
- batch normalization 正向传播与反向传播
- 总结2: Batch Normalization反向传播公式推导及其向量化
- Batch Normalization的前向和反向传播过程
- Batch Normalization 梯度归一化
- Batch Normalization 梯度归一化
- batch normalization 正向传播与反向传播 both naive and smart way
- 梯度下降法和误差反向传播推导
- 反向传播算法推导
- Batch Normalization反方向传播求导
- cs231n-assignment2-Batch Normalization原理推导
- 三、梯度下降与反向传播(含过程推导及证明)
- 激活函数、BP反向传播算法、三种梯度下降、softmax函数及其推导
- BP神经网络,BP推导过程,反向传播算法,误差反向传播,梯度下降,权值阈值更新推导,隐含层权重更新公式
- 反向传播算法的推导
- CNN反向传播公式推导
- CNN反向传播公式推导
- java反射中getDeclaredMethods和getMethods的区别
- Caffe中权值初始化方法
- [李景山php] 深入理解PHP内核[读书笔记]--第二章:用户代码执行--opcode处理函数查找
- 3.2 JS 变量提升&&函数参数
- 563. Binary Tree Tilt
- Batch Normalization梯度反向传播推导
- 2017-05-09 总结
- Android N 多窗口的设计
- 服务的基本用法-1
- ios 蓝牙开发总结
- hihoCoder 1518 : 最大集合
- 熔断器设计模式
- 在 Ubuntu 中手动安装任何版本的 Firefox
- Android Studio安装Genymotion插件