赵哲焕 Clock work RNN(CW-RNN)

来源:互联网 发布:在线淘宝网 编辑:程序博客网 时间:2024/06/05 15:41

CW-RNNICML2014上提出的一篇论文,与LSTM模型目的是相同的,就是为了解决经典的SRN对于长距离信息丢失的问题。但是与LSTM(基于三个门进行过滤和调节)的思想完全不同,CW-RNN利用的思想非常简单。下面介绍一下CW-RNN

提出我们要解决的问题:我们要做的事情是序列标注问题,输入时一个序列,输出是对应的标签序列。如下:

Input   =       (X1,  X2,   … ,    Xt-1,          Xt,    …)

Output =       (Y1,  Y2,   …,     Yt-1,          Yt,    …)

首先,介绍一下SRN(就是经典的RNN),因为CW-RNN是在SRN基础上进行的简单改进。

1

1就是SRN的结构图,相对于多层感知,区别在于多了红色连线的边对隐含层节点进行了全连接。就靠这张图可能不容易理解,图2对红色边进行了进一步地展开,以便大家理解SRN模型。

2

我们可以看到,红色边是有时序性的。它们是从上一个时刻(t-1)的隐含层的节点到这一刻(t)的隐含层节点的全连接。

如果我们进一步在时序上进行展开,就会得到图3所示的结构。

3


                                          1


   

                                           2

公式1和公式2给出了SRN的前馈计算步骤。

回顾了SRN模型,那CW-RNNSRN的基础上进行了什么样的改进呢?我们可以对比一下图4和图1

4

对比中我们可以发现他们存在区别(也就是CW-RNN的改进):

1, 把隐含层节点分成了若干个模块(在图4中分成了3个模块,是为了方便说明,实际中的模块个数可以自定义),而且每个模块都分配了一个时钟周期(Ti),便于独立管理。

2, 隐含层之间的连接,一个模块内部是全连接,但是模块之间是有方向的。模块之间的连接是从高时钟频率的模块指向低时钟频率的模块。

那么这些改进,具体是如何实现的呢?首先,它把公式一中的两个参数(WxWh)根据模块的个数分成了相应的模块。如公式3和公式4.

                                                                3

                                                                 4

运算的时候,会选择部分模块参与运算,不参与运算的模块就置0.具体操作通过公式5实现。其中,本文中。这个也可以根据自己的实际情况进行改进。

                      5

这样实现了分块管理,那模块之间连接的方向(高时钟频率模块指向低时钟频率模块)是如何实现的呢?这个也非常简单,利用公式6就可以实现。公式6强制性地把Wh参数矩阵设置成上三角阵。那为什么设置成上三角就可以完成方向的控制了呢?看下面的例子就容易理解了。

                                    6

让我们看一个具体的例子来进一步理解他们运行的过程吧。看图5

5

如图5所示,当我们要处理序列中第6t=6)个元素的时候,通过t与每个模块的时钟周期进行MOD(求余数)计算后可以得到只有前两个模块会参与运算。所以WhWx矩阵除了上面两行之外,其他元素的值都是0。经过计算之后,得到的ht也只有前两个模块有值。所以,我们也可以把CW-RNN过程看成是通过一些人工的干预,选择不同的隐含层节点进行工作。

现在回答上面遗留下的问题,我们如何通过上三角矩阵实现模块之间的单项传播的呢?

我们先不考虑WxXt之间的计算(因为这部分不涉及到隐含层模块之间的传播问题)。我们看看隐含层第二个模块的计算过程,它是通过Wh的第二行和ht-1向量进行内积得到的。现在Wh的第二行第一个模块对应的值是0(上三角嘛),所以它与ht-1内积运算的过程中,ht-1的第一个模块中的值很自然就被忽略掉了(乘以0)。所以ht的第二个模块的值是ht-1的第二个模块之后所有模块的加权和。也可以看成是ht-1的大于等于第二个模块的所有模块传向ht的第二个模块。如果Wh不是上三角阵,那么ht-1的第一个模块就不会被忽略,导致ht-1的所有模块的加权和传到了ht的第二个模块。所以上三角矩阵可以实现上一个时刻的高时钟周期的隐含层模块传向当前时刻的低时钟周期的隐含层模块。

CW-RNN的工作过程基本就解释完了,那为什么这样就可以提升效果呢?我们分析一下这样分模块管理的好处。

先给出每个模块的时钟周期。本文时钟周期计算方法:


   

 



1. t为不同的时刻值。当t=16时,表示当前输入为序列中第十六个元素。红色星型符号表示在当前时刻,参与运算的模块。是否参与运算通过公式5得到。

1中给出了每个时刻参与运算的模块。我们可以分析得到如下结论:

1.      第一个模块每次都会参与运算。

2.      随着时间的推移,参与运算的模块数会越来越多。或者说,高时钟频率的模块只有在后面的时刻参会参与到运算中。

这样我们可以推测一下,近距离的信息需要经常更新,所以由低时钟频率的模块主要处理,而那些长距离的信息(t比较大的,这时要考虑前面的信息,距离比较远),就由那些高时钟周期的模块来处理。这样模块就有了明确的分工,长距离的信息也得到了比较好的处理

说了这么多CW-RNN,那它的性能到底如何呢?尤其是与已经成名已久,在各个领域都得到认可的LSTM比较呢?我们从两个方面分析一下吧。从模型的原理和实验结果。

从模型的原理上分析的话,还需要回顾一下LSTM模型

LSTM是由igate(输入门),fgate(忘记门),ogate(输出门)进行调节和过滤得到的结果。这些门都是每个元素的值在01之间的实数构成的向量。它会与对应的向量进行内积进而得到过滤(忘记)的过程。当对应的门值接近零的时候,对应的信息就会大大减弱(忘记),相反如果接近1的话,就会大部分保留。

在这里我们考虑比较极端的情况,假设每个门中元素的值要么是0要么是1。我们首先看一下c~t(通过公式可以看出它计算得到的就是SRN的隐含层值)。我们首先考虑ct中第一部分用红框框起来的部分。用输入门过滤后就会使部分隐含层被过滤掉。做到这里,是不是觉得有些类似CW-RNN。都是通过一些方法使部分隐含层的值为0,或者破坏掉部分隐含层的值。只不过CW-RNN比较多的是人工干预,而LSTM是通过计算自动选择部分隐含层节点。LSTM在过滤部分隐含层节点的基础上,还考虑到了上一层细胞状态(ct-1准隐含层值,经过输出门的调节得到隐含层值),通过忘记门(fgate)选择部分上次细胞状态添加进来,最后用输出门(ogate)再进一步进行调节。

从公式的分析可以得到如下结论:

LSTM的功能要比CW-RNN强大很多,而且实现方法也是优雅很多。CW-RNN只是每次选择部分隐含层进行工作,而这只是LSTM功能(还包括忘记门和输出门的调节)的子集。而且CW-RNN有太多人工成分,分模块和时钟频率等等,而LSTM都是自动学到,比较智能。

再从实验中分析,论文1中说,CW-RNN要比LSTM好,但是它实验设置是两个模型训练的时间是一样的。也就是说迭代次数相同。但是LSTM的模型要比CW-RNN复杂很多,一个复杂的模型想要训练好,那么需要的时间肯定也会更长,所以我觉得这种实验设置得到的结果没有可比性。而且再看论文2中的实验效果,CW–RNN并没有比经典的SRN好多少。而LSTM在业界,很多领域都已经公认要比SRN好很多。

但是有一点需要承认,CW-RNN的时间复杂度要比SRN还要低,因为它每次更行参数都是选择部分更新,而不是全部,相比之下,LSTM在时间上肯定要比CW-RNN慢很多,看它的那些公式,那些需要训练的参数就不难得到这种结论了。

综上,CW-RNN的性能有待进一步的考验,目前看来,并没有论文1中说的那么神奇。

原创粉丝点击