lstm的数学推导

来源:互联网 发布:apache 安装 编辑:程序博客网 时间:2024/04/27 01:18

本文是根据以下三篇文章整理的LSTM推导过程,公式都源于文章,只是一些比较概念性的东西,要coding的话还要自己去吃透以下文章。

 

 

前向传播:

1、计算三个gate(in, out, forget)的输入和cell的输入:

zinj(t)=mwinjmym(t1)+v=1SjwinjcvjScvj(t1),(1)(1)zinj(t)=∑mwinjmym(t−1)+∑v=1SjwinjcjvScjv(t−1),

zφj(t)=mwφjmym(t1)+v=1SjwφjcvjScvj(t1),(2)(2)zφj(t)=∑mwφjmym(t−1)+∑v=1SjwφjcjvScjv(t−1),

zoutj(t)=mwoutjmym(t1)+v=1SjwoutjcvjScvj(t1),(3)(3)zoutj(t)=∑mwoutjmym(t−1)+∑v=1SjwoutjcjvScjv(t−1),

zctj(t)=mwctjmym(t1)+v=1SjwctjcvjScvj(t1),(4)(4)zcjt(t)=∑mwcjtmym(t−1)+∑v=1SjwcjtcjvScjv(t−1),

2、计算上述各个gate和cell的激活值:

yinj(t)=finj(zinj(t)),(5)(5)yinj(t)=finj(zinj(t)),

yφj(t)=fφj(zφj(t)),(6)(6)yφj(t)=fφj(zφj(t)),

youtj(t)=foutj(zoutj(t)),(7)(7)youtj(t)=foutj(zoutj(t)),

Scvj(0)=0,Scvj(t)=yφj(t)Scvj(t1)+yinj(t)g(zcvj(t)),(8)(8)Scjv(0)=0,Scjv(t)=yφj(t)Scjv(t−1)+yinj(t)g(zcjv(t)),

ycvj(t)=youtjScvj(t),(9)(9)ycjv(t)=youtjScjv(t),

3、假定该网络为一个标准的三层结构(如下图所示),即一个输入层,一个隐层和一个输出层。则对于一个输出单元,我们可以按下述的方式计算它的输入和激活值。其中m为所有与该输出单元连接的单元(包括输入层的和隐层的)。


 

 

zk(t)=mwkmym(t),(10)(10)zk(t)=∑mwkmym(t),

yk(t)=fk(zk(t)),(11)(11)yk(t)=fk(zk(t)),

4、计算当前时间点对应状态对input gate和、forget gate以及cell的偏导数。这里跟CNN不一样,CNN前向只是求值,没有传递梯度。但对于lstm,由于内部状态的改变依赖前一时间点的状态,因此内部状态的参数也会把错误传递到网络下一层,因此前向也涉及到梯度传递。

dSjvin,m(t)=Scvj(t)winjm=trScvj(t1)winjmyφj(t)+g(zcvj(t))finj(zinj(t))ym(t1),(12)(12)dSin,mjv(t)=∂Scjv(t)∂winjm=tr∂Scjv(t−1)∂winjmyφj(t)+g(zcjv(t))f′inj(zinj(t))ym(t−1),

dSjvφm(t)=Scvj(t)wφjm=trScvj(t1)wφjmyφj(t)+Scvj(t1)fφj(zφj(t))ym(t1),(13)(13)dSφmjv(t)=∂Scjv(t)∂wφjm=tr∂Scjv(t−1)∂wφjmyφj(t)+Scjv(t−1)f′φj(zφj(t))ym(t−1),

dSjvcm(t)=Scvj(t)wcvjm=trScvj(t1)wcvjmyφj(t)+g(zcvj(t))yinj(t)ym(t1),(14)(14)dScmjv(t)=∂Scjv(t)∂wcjvm=tr∂Scjv(t−1)∂wcjvmyφj(t)+g′(zcjv(t))yinj(t)ym(t−1),


后向传播:
1、对于每个输出单元(output unit),我们可以计算它的 输出错误如下,其中tk(t)tk(t)为前向计算的输出,yk(t)yk(t)为真实值。

ek(t)=tk(t)yk(t),(15)(15)ek(t)=tk(t)−yk(t),

2、接下来计算每个输出单元的残差,这里的计算和CNN是一样的,就是对该层网络求导。

δk(t)=fk(zk)ek(t)(16)(16)δk(t)=f′k(zk)ek(t)

3、输出output gate的残差计算方式和output unit类似。(output unit只针对每一个小单元的权重,而output gate针对的是所有output unit连接到输出层的权重)

δoutj(t)=foutj(zoutj(t))(Sjv=1h(Scvj(t))kwkcvjδk(t)),(17)(17)δoutj(t)=f′outj(zoutj(t))(∑v=1Sjh(Scjv(t))∑kwkcjvδk(t)),

4、第2和第3条针对的是外部残差,内部残差(包括input gate, forget gate和cell)计算方式如下:

eScvj(t)=youtj(t)h(Scvj(t))(kwkcvjδk(t)),(18)(18)eScjv(t)=youtj(t)h′(Scjv(t))(∑kwkcjvδk(t)),

5、最后,根据残差更新各个参数(weight),注意外部和内部的表达式不一样,具体推导见原文。

1).output unit:

Δwkm(t)=αδk(t)ym(t1),(19)(19)Δwkm(t)=αδk(t)ym(t−1),

2).output gate:

Δwout,m(t)=αδout(t)ym(t1),(20)(20)Δwout,m(t)=αδout(t)ym(t−1),

3).input gate:

Δwin,m(t)=αSjv=1eScvj(t)dSjvin,m(t),(21)(21)Δwin,m(t)=α∑v=1SjeScjv(t)dSin,mjv(t),

4).forget gate:

Δwφm(t)=αSjv=1eScvj(t)dSjvφm(t),(22)(22)Δwφm(t)=α∑v=1SjeScjv(t)dSφmjv(t),

5).cell:

Δwcvjm(t)=αeScvj(t)dSjvcm(t),(23)(23)Δwcjvm(t)=αeScjv(t)dScmjv(t),

 

0 0