[NLP论文阅读]LightRNN: Memory and Computation-Efficient Recurrent Neural Networks

来源:互联网 发布:淘宝退货运费险怎么退 编辑:程序博客网 时间:2024/06/05 08:20

原文链接:LightRNN: Memory and Computation-Efficient Recurrent Neural Networks

引言

RNN已经在多个自然语言处理任务中取得了最先进的表现,例如语言建模和机器翻译。然而,随着词表大小的变大,RNN模型会变得很大(可能会超出GPU设备的内存)并且RNN模型的训练会变得不高效。在本工作中,我们提出了一种解决这一问题的新方法。核心思想是使用一个2-Component(2C) shared embedding(二部共享嵌入)来进行词表示。我们词表中的每个单词都分配到一个表中,表格的每一行和一个行向量关联,每一列和一个列向量关联。那么一个单词可以根据自己在表格中的位置,由2个component表示,即一个行向量和一个列向量。同一行的所有单词共享一个行向量,同一列的所有单词共享一个列向量,因此表示一个有|V|个不同词汇的词表只需要2|V|个向量,这比现有的方法少了很多。基于2-Component(2C) shared embedding,我们设计了一个新的RNN算法并且在语言建模任务上进行了评估。实验结果表明,我们的算法在不牺牲精度的情况下,显著降低了模型的规模的并加速了训练过程。

LightRNN模型

LightRNN算法的一个关键技术创新就是2-Component(2C) shared embedding。
如下图所示:
二部共享嵌入举例
我们把每一个单词都分配进一个单词表中,表格中的第i行对应的行向量为xri,第j列对应的列向量为xcj。那么表格中第i行第j列的单词就可以被表示为两部分:xrixcj。对于输出词向量,也是一样的情况。也就是说这里有输入和输出两套词向量。
LightRNN的具体结构如下:
LightRNN结构
有了2-Component(2C) shared embedding,我们可以通过将vanilla RNN(不知道怎么翻译好)模型的基本单元加倍的方法来构建LightRNN模型。使用n和m来表示行/列输入向量以及隐藏状态向量的维数。为了计算出wt的概率分布,我们需要使用列向量xct1ϵRn,行向量xrtϵRn以及隐藏状态向量hct1ϵRm。列向量和行向量是通过输入向量矩阵Xc,XrϵRn×|V|获得,隐藏状态hct1,hrtϵRm的计算方式如下:
隐藏状态h计算公式
其中,f是非线性的激活函数,例如sigmoid函数。
模型的输入是从行向量矩阵Xr和列向量矩阵Xc中找到单词wt1对应行列向量xrt1xct1作为输入。要计算下一个词是 wt 的概率,先根据前文计算下一个词的行向量是 wrt 的概率分布,在根据前文和 wrt 的概率分布中最大值索引对应的行向量xrr(wt1)来计算下一个词的列向量是xcc(wt1)的概率,行向量和列向量的概率乘积就是下一个词是 wt的概率。
行列向量概率分布计算公式

损失函数

语言模型的目标是最小化预测词的负对数似然函数,这就相当于优化目标概率分布和LightRNN模型的预测的交叉熵(cross-entropy)。
给定一个有T个单词的文本,负对数似然函数可以写成:

NLL=t=1TlogP(wt)=t=1Tlog(Pr(wt)Pc(wt))

NLL=w=i|V|NLLw

NNLw是对于单词w的负对数似然。
这里写图片描述
Sw是单词w在语料库中所有位置的集合;
lrlc分别被称为行损失和列损失。

模型训练

训练主要分为三步:
1. 冷启动(cold start),随机初始化单词在表格中位置;
2. 固定单词在表格中位置,训练LightRNN的输入/输出embedding vectors。收敛条件可以是训练时间、语言模型的困惑度(perplexity)达到阈值等等。
3. 固定步骤2学习到的embedding vectors,通过调整单词在表格中的位置来最小化损失函数。接着重复步骤2。

在步骤3中调整单词位置以减小的损失函数的方法,文章中使用的是最小费用最大流算法(MCMF),对于该方法不是很了解,只读懂了文中对该方法举得例子:
这里写图片描述
上图b中,w1,w2,w3,w4的原始位置分别是11,12,21,22,在这种情况下,运用MCMF算法将位置调整为21,11,12,22可以使得损失函数降低。

实验结果

论文中展示一张训练好的词表,展示出来的结果很有意思(但感觉是经过人工挑选出比较好结果)
这里写图片描述

最后

文中提到了一些future work:
1. 将LightRNN应用到更大的语料库;
2. 包括词的表示不再局限于2-Component(2C) shared embedding,将要去研究k-Component shared embedding;
3. 将LightRNN应用到其他自然语言处理任务,例如机器翻译和问答。

文末提到,作者会在未来对代码进行开源,然而。。。

0 0
原创粉丝点击