神经网络中的反向传播的推导和python实现
来源:互联网 发布:淘宝卖家如何退出村淘 编辑:程序博客网 时间:2024/05/29 18:43
事先声明一下,这篇博客的适用人群是对于卷积神经网络的基本结构和每个模块都基本了解的同学。当然,如果各位大神看到我这篇博客有什么不对的地方请大家积极指出哈,我一定好好改正,毕竟学习是一个不断改进的过程。
之前学习反向传播的时候,对于矩阵的求导有些疑问,最近看到cs231n上的assignment1和assignment2上都有对应的习题,我觉得是一个不错的机会来彻底搞清楚。
这里先开一个小差,我自己特别不喜欢推公式,因为推了几篇公式之后,编程实现就是这么几行,感觉超级不爽,终于明白之前最初涉及推公式的时候看到网上有人说这个现象的时候的不爽了,感觉有一种怀才不遇的感觉。。。废话少说,下面开始进入主题。
下面是这篇博客的内容:
- affine backward
- relu backward
- svm loss backward
- softmax loss backward
affine backward
先从最简单的说起,就是最简单的affine layer(全连接层)的反向传播,在这一层中的前向传播如下公式,
整个过程可由如下图所示,这里画出示意图是为了后面求导时的方便起见,
当这层的后一层反向传播到这一层的值为
先求
而后
同理
注意上述的所有的变量都是向量,之后还要考虑矩阵的求导问题,现在我们开始着重讲一下这一点,假设输入
如何得到
可以看到
而理所当然的,
而对于
以上就是affine backward的求解,可能有些跳跃,但是总体思路还是清楚的。
下面是python 的代码:
def affine_backward(dout, cache): """ Computes the backward pass for an affine layer. Inputs: - dout: Upstream derivative, of shape (N, M) - cache: Tuple of: - x: Input data, of shape (N, d_1, ... d_k) - w: Weights, of shape (D, M) Returns a tuple of: - dx: Gradient with respect to x, of shape (N, d1, ..., d_k) - dw: Gradient with respect to w, of shape (D, M) - db: Gradient with respect to b, of shape (M,) """ x, w, b = cache dx, dw, db = None, None, None dx = np.dot(dout, w.T) # (N,D) dx = np.reshape(dx, x.shape) # (N,d1,...,d_k) x_row = x.reshape(x.shape[0], -1) # (N,D) dw = np.dot(x_row.T, dout) # (D,M) db = np.sum(dout, axis=0, keepdims=True) # (1,M) return dx, dw, db
relu_backward
relu的公式如下,
当反向传播的量为
编程实现的时候只需注意一点,
def relu_backward(dout, cache): """ Computes the backward pass for a layer of rectified linear units (ReLUs). Input: - dout: Upstream derivatives, of any shape - cache: Input x, of same shape as dout Returns: - dx: Gradient with respect to x """ dx, x = None, cache dx = dout dx[x <= 0] = 0 return dx
SVM_loss backward
在说softmax_loss之前先说一个比较轻松的话题,就是svm_loss, 虽然在神经网络的实际运用中运用的着实不多,但是还是有必要介绍一下,来扩充一下大家的知识面,这里值得注意的是下文的推导和后面的softmax_loss的推导都没有添加正则项,但是正则项的求导很好求,这里不再赘述,希望大家谅解。
原公式是这样说的,一个训练样本
而所有样本的svm_loss是如下公式:
如今需要求解这个损失函数的
而对于
根据上文的描述,计算svm_loss的python代码如下:
def svm_loss(x, y): """ Computes the loss and gradient using for multiclass SVM classification. Inputs: - x: Input data, of shape (N, C) where x[i, j] is the score for the jth class for the ith input. - y: Vector of labels, of shape (N,) where y[i] is the label for x[i] and 0 <= y[i] < C Returns a tuple of: - loss: Scalar giving the loss - dx: Gradient of the loss with respect to x """ N = x.shape[0] correct_class_scores = x[np.arange(N), y] margins = np.maximum(0, x - correct_class_scores[:, np.newaxis] + 1.0) margins[np.arange(N), y] = 0 loss = np.sum(margins) / N num_pos = np.sum(margins > 0, axis=1) dx = np.zeros_like(x) dx[margins > 0] = 1 dx[np.arange(N), y] -= num_pos dx /= N return loss, dx
softmax_loss backward
上文的svm_loss 已经把样本的准备情况交代清楚了,而后这一节来介绍softmax_loss, 具体的推导细节就不再赘述了,(如果大家想了解的再清楚一点,请参照cs231n lecture3的关于softmax_loss的解释,也可以参照cs229的softmax详解),
和上文类似,经过一系列推导得到的得分向量公式如下:
而
而所有样本的svm_loss是如下公式:
需要求解
而对于
希望大家能够仔细推导, 下面是python的代码表示:
def softmax_loss(x, y): """ Computes the loss and gradient for softmax classification. Inputs: - x: Input data, of shape (N, C) where x[i, j] is the score for the jth class for the ith input. - y: Vector of labels, of shape (N,) where y[i] is the label for x[i] and 0 <= y[i] < C Returns a tuple of: - loss: Scalar giving the loss - dx: Gradient of the loss with respect to x """ probs = np.exp(x - np.max(x, axis=1, keepdims=True)) probs /= np.sum(probs, axis=1, keepdims=True) N = x.shape[0] loss = -np.sum(np.log(probs[np.arange(N), y])) / N dx = probs.copy() dx[np.arange(N), y] -= 1 dx /= N return loss, dx
总结
本片博客介绍了四种在神经网络中比较常用的函数的反向传播的推导和python的代码实现,能够清楚的理清矩阵的求导运算,希望能够对大家在深度学习学习的路上有所帮助,如有不对的地方,希望大家能够多多指出。
- 神经网络中的反向传播的推导和python实现
- 神经网络中的反向传播算法推导
- 神经网络和反向传播算法推导
- 神经网络和反向传播算法推导
- 神经网络和反向传播算法推导
- 神经网络和反向传播算法推导
- 神经网络基础和反向传播推导
- 神经网络反向传播公式的推导
- 神经网络反向传播算法的推导
- 深度学习:神经网络中的前向传播和反向传播算法推导
- 卷积神经网络反向传播推导
- 神经网络中的反向传播法算法推导及matlab代码实现
- 神经网络中的矩阵求导及反向传播推导
- 【机器学习】反向传播神经网络推导
- 1 神经网络反向传播算法推导流程
- 卷积神经网络反向传播理论推导
- 卷积神经网络(CNN)反向传播理论推导
- 神经网络反向传播算法公式推导详解
- [LintCode]1.A + B 问题 位运算
- 拦截器
- 吐血推荐:深入理解Mysql 锁!玩MYSQL必备!
- 网络编程之服务器与客户端的建立
- springBoot(二)springboot配置读取、配置原理及其视图
- 神经网络中的反向传播的推导和python实现
- C 共用体
- JavaWeb简介
- spring,mybatis事务管理配置与@Transactional注解使用[转]
- hdu6154 CaoHaha's staff CCPC网赛1005 找规律+构造
- Python管理端口的操作
- centos 6.5下 安装R语言R-3.3.2失败?
- MOOC清华《面向对象程序设计》第3章编程题第2题:重载下标运算符以统计分段人数
- 线程基础