Training RNNs as Fast as CNNs

来源:互联网 发布:淘宝发货后买家退款 编辑:程序博客网 时间:2024/05/21 17:44

摘要

RNN的并行性比较差,主要因为它在计算state的时候不能并行,比如要计算输出h(t),它必须依赖于前一步的输出h(t-1),这个是并行化的瓶颈。
在这篇论文提出一种可选择的RNN结构,它的递归单元可以和卷积层一样快,是cud优化的LSTM的5-10倍。我们通过一序列的实验包括分类,qa,语言模型,翻译以及语音识别来证明这种卷积单元确实是高效的。
论文源码用 PyTorch 和 CNTK实现过 https://github.com/taolei87/sru

简介

最近深度学习的发展主要归功于模型的容量和计算能力的提升上,模型经常通过更深以及更宽的结构来引入更多的超参来实现。不断增长的模型以及参数个数使得计算量会急剧上升。
比如,需要训练一个较好效果的翻译或者语音识别模型,需要花费几天的时间在训练上面。
很显然,计算性能已经变成现在研究的一个很大的瓶颈。现在卷积网络或者attention模型在利用GPU加速上面并行性做的很好,但是rnn在这方面做的不够。之前也有一些工作是在优化LSTM的计算速度上面,但是
和CNN比还是有10倍的速度落后。

在本文我们介绍一种 Simple Recurrent Unit (SRU),比传统的RNN有显著的速度提升。它是通过简化state的计算来实现- 不是每一步的计算都需要依赖上一步的输出的。
简单的说就是 复杂的计算,比如矩阵计算比如 forget gate, input gate等计算都不依赖于上一步的输出h(t),只依赖于当前step的输出x(t)
状态的更新c(t)的更新必须依赖于上一步的状态c(t-1),但是他的计算都是element-wise,很简单的计算。
和cuDNN LSTM 和conv2d 类似,我们也做了 cuda 维度的优化。

模型

SRU 实现

比较流行的RNN结构比如LSTM 和GRU 都是通过gates的机制来控制信息流,下面开始介绍实现:
首先是状态层更新做了一些简化:

ct=ftct1+itx1t 
=ftct1+(1ft)x1t 
ft,it分别是和lstm一致的 forget gate 和 input gate,
是一个sigmoid gate,x1t是对输入的x处理变换,
这里我们只是用一个简单的线性变换,Wxt
it的计算方式直接用1ft来简单化.

在lstm中x1t,it的计算不仅和xt有关,而且和上一步的输出ht1有关。

然后输出状态ct输入一个激活函数g(.)来生成新的输出ht=g(ct)

然后对输出的ht做额外的处理:

第一,我们在引入skip connection layer

h1t=rtht+(1rt)xt 
=rtg(ct)+(1rt)xt 
rt 代表 reset gate。

第二,我们实现 variational dropout 机制作为标准dropout正则的补充
普通的 droput 是作用在 输出ht之后
它是作用在输入xt上面的

速度优化

普通的RNN的gate,比如: forget gate
ft=σ(Wfxt+Rfht1+bf) 
必须依赖上一步的输出ht1,所以ht1它破坏了并行性,
我们把这种连接也去除了

我们整体介绍下各个步骤:
x1t=Wxt  ————-(3)
ft=σ(Wfxt+bf)  ——-(4)
rt=σ(Wrxt+br)  ——-(5)
ct=ftct1+(1ft)x1t  ——-(6)
ht=rtg(ct)+(1rt)xt  ——-(7)

下3-5 完全可以并行计算的,6-7虽然不能并行,但是他们的计算是很快的,
因为他们都是些element-wise的计算。

CUDA级别优化

简单的SRU优化大概可以相对LSTM做到5倍的速度提升,
下面我们介绍CUDA级别的优化,
1, 所有step的矩阵相乘都可以一起计算,这样可以提高GPU的使用
2. 所有element-wise计算都是可以融合成一个核函数,如果不这么做,比如 + 或者 sigmoid 运算,是不同的函数调用,这个带来额外的底层核函数调用以及数据加载,整体会有额外的开销。

相关工作

网络结构:
RCNN Lei et al., 2015,2016
Quasi-RNN Bradbury et al., 2017 这个是本编主要参考的

CUDA 优化:
cuDNN SLTM Appleyard elt al., 2016

实验

在 文本分类,问答,语言模型,机器翻译,语音识别上面都做了实验,做到了和CNN同样的速度,效果也不错

原创粉丝点击