RNN 的训练

来源:互联网 发布:微商吸粉软件 编辑:程序博客网 时间:2024/05/30 13:41

RNN 的训练

针对那些使用(Tensorflow, Torch, Caffe..) 开源库来组建 RNN 已经非常娴熟的朋友们,大家都是搞基术的,没有手刻过源代码,剖析过训练算术公式,都不好意思和同僚们打招呼不是?

那这篇Blog 就是专为那些已经会 forward-pass , backward-pass, 和 SGD 但不太理解 BPTT 的朋友准备的。

注意: 在本文中所有 RNN 均指代 Vanilla Recurrent Nueron Network. 虽然 LSTM 已经是公认的 RNN 默认 function, 但是为了方便初学者的理解, 我们还是先来看看最最基础的 RNN.

(LSTM 是瑞士科学家 Jürgen Schmidhuber 于 1997 发布的 Long Short-Term Memory RNN 核心函数的简称)

(一大波公式逼近中….)

让我们快速回顾下 RNN 的基本结构及前向 inference 公式。这应该非常简单,让我们先来看一下它的结构:
这里写图片描述

我们可以暂时扮演一下万能的神明,把时间拧成一根轴,让所有时间段的 RNN 状态在纸上物理展开,是不是非常棒 :)
这里写图片描述
配方如下
这里写图片描述
还有 loss.
这里写图片描述
为了降低 loss 至 0, 我们依旧用万能的牛顿导数来做梯度下降。与传统的梯度下降略有不同,我们的误差来自于“过去的时间”,名字 Backward Propagation Through Time 不就来了么。得益于我们之前“展开了”时间,所以问题从对时间上的反向梯度传播,转变成了在物理结构上的反向传播, 是不是非常棒?因为这实在太简单了。

我们来计算一下在经过 3个时间段后,各个权重的更新。
这里写图片描述
(vanilla RNN 一层最少拥有3个 Tensor block, Wi 连接 input, Wo 连接 output, Wh 连接前后 hidden state)

我们还是应用链式法则进行推导。 RNN 与传统的 DNN 的更新方式有所不同在于 RNN 将网络以时间展开后,偏导的依赖需要累加,并没有什么 magic。

下面是 Python 代码,我从其他地方超来的…

def bptt(self, x, y):    T = len(y)    # Perform forward propagation    o, s = self.forward_propagation(x)    # We accumulate the gradients in these variables    dLdU = np.zeros(self.U.shape)    dLdV = np.zeros(self.V.shape)    dLdW = np.zeros(self.W.shape)    delta_o = o    delta_o[np.arange(len(y)), y] -= 1.    # For each output backwards...    for t in np.arange(T)[::-1]:        dLdV += np.outer(delta_o[t], s[t].T)        # Initial delta calculation: dL/dz        delta_t = self.V.T.dot(delta_o[t]) * (1 - (s[t] ** 2))        # Backpropagation through time (for at most self.bptt_truncate steps)        for bptt_step in np.arange(max(0, t-self.bptt_truncate), t+1)[::-1]:            # print "Backpropagation step t=%d bptt step=%d " % (t, bptt_step)            # Add to gradients at each previous step            dLdW += np.outer(delta_t, s[bptt_step-1])                          dLdU[:,x[bptt_step]] += delta_t            # Update delta for next step dL/dz at t-1            delta_t = self.W.T.dot(delta_t) * (1 - s[bptt_step-1] ** 2)    return [dLdU, dLdV, dLdW]

关于传统的 RNN 模型,它存在梯度消失与爆炸问题,想想也很简单,当我们展开的模型达到一定长度,比如1000,那么根据链式法则的乘法,1.01^1000 = 20959.155, 0.99^1000 = 0.000043,RNN的训练变成不可能,而1000个字母汉字或者毫秒的音频,又或者像素点的训练长度在实作中是十分合理的。故此我们需要更好的模型 LSTM(1997)或 GRU(2014) 来解决此问题,2014年的paper 是不是非常新? 这就是机器学习的乐趣所在了,踏着时代的 edge。

0 0
原创粉丝点击