【深度学习】RNN的梯度消失/爆炸与正交初始化

来源:互联网 发布:田馥甄唱功 知乎 编辑:程序博客网 时间:2024/06/08 00:06

在训练较为复杂的RNN类型网络时,有时会采取正交方法初始化(orthogonal initialization)网络参数。本文用简单的例子介绍其中的原因。

本文较大程度参考了这篇博客。

简单例子

RNN具有如下形式:

ht=fh(Wht1+Vxt)

yt=fy(Uht)

我们考虑一个极端简化的版本:没有输入,激活函数为直通,直接输出隐变量。

yt=Wyt1

计算第t步的输出时,需要计算参数矩阵的t次幂:

yt=Wty0

为了计算简便,可以把方阵W进行正交分解:
W=QΛQ1

yt=QΛtQ1y0

其中Q是单位正交矩阵;Λ是对角阵,计算其t次幂只需要把对角线上的特征值进行幂运算即可。

优化网络参数时,使用简单的二范数代价:

E=||ytyt¯¯¯||2

为了更新参数,需要计算代价对于参数的导数(是个标量):

EWi=2(ytyt¯¯¯)TytWi

梯度消失/爆炸

当RNN步数t增加时,yt/Wi会怎样变化呢?

为书写直观假设y是个二维向量。于是W有四个参数,我们用正交分解的形式表示出来。

Q=[w1w2w2w1],Q1=[w1w2w2w1]

Λ=diag(w3,w4)

可以直接写出yt的表达式(善用Matlab的syms功能):

yt=[w21wt3+w22wt4w1w2(wt4wt3)w1w2(wt4wt3)w21wt4+w22wt3]y0

分别写出对四个参数的导数(长度为2的矢量):

ytw1=[2w1wt3w2(wt4wt3)w2(wt4wt3)2w1wt4]y0

ytw2=[2w2wt4w1(wt4wt3)w1(wt4wt3)2w2wt3]y0

ytw3=[tw21wt13w1w2wt13tw1w2wt13tw22wt13]y0

ytw4=[tw22wt14w1w2wt14tw1w2wt14tw21wt14]y0

重点:每一项里都有w3,w4的t或t-1次幂。不考虑细节,这个推导说明:

代价对于梯度的导数参数矩阵特征值λi的t次方。

如果|λi|>1,则步数增加时λt超出浮点范围,发生梯度爆炸,优化无法收敛;
如果|λi|<1,步数增加时λt变为0,发生梯度消失,优化停滞不前。

正交初始化

理想的情况是,特征值绝对值为1。则无论步数增加多少,梯度都在数值计算的精度内。

这样的参数矩阵W单位正交阵

把转移矩阵初始化为单位正交阵,可以避免在训练一开始就发生梯度爆炸/消失现象,称为orthogonal initialization。

其他解决方法

除了正交初始化,在RNN类型网络训练中,还可以使用如下方法解决梯度消失/爆炸问题:
- 使用ReLU激活函数->解决梯度消失
- 对梯度进行剪切(gradient clipping)->解决梯度爆炸
- 引入更复杂的结构,例如LSTM、GRU->解决梯度消失

0 0
原创粉丝点击