循环神经网络的训练(2)

来源:互联网 发布:衡水中学 知乎 编辑:程序博客网 时间:2024/06/05 19:07

权重梯度的计算

现在,我们终于来到了BPTT算法的最后一步:计算每个权重的梯度。

首先,我们计算误差函数E对权重矩阵W的梯度EW

上图展示了我们到目前为止,在前两步中已经计算得到的量,包括每个时刻t 循环层的输出值st,以及误差项δt

回忆一下我们在文章零基础入门深度学习(3) - 神经网络和反向传播算法介绍的全连接网络的权重梯度计算算法:只要知道了任意一个时刻的误差项δt,以及上一个时刻循环层的输出值st1,就可以按照下面的公式求出权重矩阵在t时刻的梯度WtE

WtE=δt1st11δt2st11..δtnst11δt1st12δt2st12δtnst12.........δt1st1nδt2st1nδtnst1n(5)

式5中,δti表示t时刻误差项向量的第i个分量;st1i表示t-1时刻循环层第i个神经元的输出值。

我们下面可以简单推导一下式5

我们知道:

nett=nett1nett2..nettn==Uxt+Wst1Uxt+w11w21..wn1w12w22wn2.........w1nw2nwnnst11st12..st1nUxt+w11st11+w12st12...w1nst1nw21st11+w22st12...w2nst1n..wn1st11+wn2st12...wnnst1n(44)(45)(46)

因为对W求导与Uxt无关,我们不再考虑。现在,我们考虑对权重项wji求导。通过观察上式我们可以看到wji只与nettj有关,所以:

Ewji==Enettjnettjwjiδtjst1i(47)(48)

按照上面的规律就可以生成式5里面的矩阵。

我们已经求得了权重矩阵W在t时刻的梯度WtE,最终的梯度WE是各个时刻的梯度之和

WE==i=1tWiEδt1st11δt2st11..δtnst11δt1st12δt2st12δtnst12.........δt1st1nδt2st1nδtnst1n+...+δ11s01δ12s01..δ1ns01δ11s02δ12s02δ1ns02.........δ11s0nδ12s0nδ1ns0n(6)(49)(50)

式6就是计算循环层权重矩阵W的梯度的公式。

----------数学公式超高能预警----------

前面已经介绍了WE的计算方法,看上去还是比较直观的。然而,读者也许会困惑,为什么最终的梯度是各个时刻的梯度之和呢?我们前面只是直接用了这个结论,实际上这里面是有道理的,只是这个数学推导比较绕脑子。感兴趣的同学可以仔细阅读接下来这一段,它用到了矩阵对矩阵求导、张量与向量相乘运算的一些法则。

我们还是从这个式子开始:

nett=Uxt+Wf(nett1)

因为Uxt与W完全无关,我们把它看做常量。现在,考虑第一个式子加号右边的部分,因为W和f(nett1)都是W的函数,因此我们要用到大学里面都学过的导数乘法运算:

(uv)=uv+uv

因此,上面第一个式子写成:

nettW=WWf(nett1)+Wf(nett1)W

我们最终需要计算的是WE

WE===EWEnettnettWδTtWWf(nett1)+δTtWf(nett1)W(7)(51)(52)(53)

我们先计算式7加号左边的部分。WW矩阵对矩阵求导,其结果是一个四维张量(tensor),如下所示:

WW===w11Ww21W..wn1Ww12Ww22Wwn2W.........w1nWw2nWwnnWw11w11w11w21..w11wn1w11w12w11w22w11wn2.........w111nw112nw11nn..w12w11w12w21..w12wn1w12w12w12w22w12wn2.........w121nw122nw12nn...10..0000.........000..00..0100.........000...(54)(55)(56)

接下来,我们知道st1=f(nett1),它是一个列向量。我们让上面的四维张量与这个向量相乘,得到了一个三维张量,再左乘行向量δTt,最终得到一个矩阵:

δTtWWf(nett1)======δTtWWst1δTt10..0000.........000..00..0100.........000...st11st12..st1nδTtst110..0..st120..0...[δt1δt2...δtn]st110..0..st120..0...δt1st11δt2st11..δtnst11δt1st12δt2st12δtnst12.........δt1st1nδt2st1nδtnst1nWtE(57)(58)(59)(60)(61)(62)

接下来,我们计算式7加号右边的部分:

δTtWf(nett1)W====δTtWf(nett1)nett1nett1WδTtWf(nett1)nett1WδTtnettnett1nett1WδTt1nett1W(63)(64)(65)(66)

于是,我们得到了如下递推公式:

WE======EWEnettnettWWtE+δTt1nett1WWtE+Wt1E+δTt2nett2WWtE+Wt1E+...+W1Ek=1tWkE(67)(68)(69)(70)(71)(72)

这样,我们就证明了:最终的梯度WE是各个时刻的梯度之和。

----------数学公式超高能预警解除----------

同权重矩阵W类似,我们可以得到权重矩阵U的计算方法。

UtE=δt1xt1δt2xt1..δtnxt1δt1xt2δt2xt2δtnxt2.........δt1xtmδt2xtmδtnxtm(8)

式8是误差函数在t时刻对权重矩阵U的梯度。和权重矩阵W一样,最终的梯度也是各个时刻的梯度之和:

UE=i=1tUiE

具体的证明这里就不再赘述了,感兴趣的读者可以练习推导一下。

RNN的梯度爆炸和消失问题

不幸的是,实践中前面介绍的几种RNNs并不能很好的处理较长的序列。一个主要的原因是,RNN在训练中很容易发生梯度爆炸梯度消失,这导致训练时梯度不能在较长序列中一直传递下去,从而使RNN无法捕捉到长距离的影响。

为什么RNN会产生梯度爆炸和消失问题呢?我们接下来将详细分析一下原因。我们根据式3可得:

δTk=δTkδTti=kt1Wdiag[f(neti)]δTti=kt1Wdiag[f(neti)]δTt(βWβf)tk(73)(74)(75)

上式的β定义为矩阵的模的上界。因为上式是一个指数函数,如果t-k很大的话(也就是向前看很远的时候),会导致对应的误差项的值增长或缩小的非常快,这样就会导致相应的梯度爆炸梯度消失问题(取决于β大于1还是小于1)。

通常来说,梯度爆炸更容易处理一些。因为梯度爆炸的时候,我们的程序会收到NaN错误。我们也可以设置一个梯度阈值,当梯度超过这个阈值的时候可以直接截取。

梯度消失更难检测,而且也更难处理一些。总的来说,我们有三种方法应对梯度消失问题:

  1. 合理的初始化权重值。初始化权重,使每个神经元尽可能不要取极大或极小值,以躲开梯度消失的区域。
  2. 使用relu代替sigmoid和tanh作为激活函数。原理请参考上一篇文章零基础入门深度学习(4) - 卷积神经网络的激活函数一节。
  3. 使用其他结构的RNNs,比如长短时记忆网络(LTSM)和Gated Recurrent Unit(GRU),这是最流行的做法。我们将在以后的文章中介绍这两种网络。
原创粉丝点击