Simple LSTM
A few weeks ago I released some code on GitHub to help people understand how LSTMs work at the implementation level. The forward pass is well explained elsewhere and is straightforward to understand, but I derived the backprop equations myself and the backprop code came without any explanation whatsoever. The goal of this post is to explain the so-called backpropagation through time in the context of LSTMs.
If you feel like anything is confusing, please post a comment below or submit an issue on GitHub.
Note: this post assumes you understand the forward pass of an LSTM network, as that part is relatively simple. If you are not familiar with it, please read this great intro paper first; it contains a very nice introduction to LSTMs. I follow the same notation as that paper, so I recommend having it open in a separate browser tab for easy reference while reading this post.
Introduction
The forward pass of an LSTM node is defined as follows:

$$
\begin{aligned}
g(t) &= \phi(W_{gx}\, x(t) + W_{gh}\, h(t-1) + b_g) \\
i(t) &= \sigma(W_{ix}\, x(t) + W_{ih}\, h(t-1) + b_i) \\
f(t) &= \sigma(W_{fx}\, x(t) + W_{fh}\, h(t-1) + b_f) \\
o(t) &= \sigma(W_{ox}\, x(t) + W_{oh}\, h(t-1) + b_o) \\
s(t) &= g(t) * i(t) + s(t-1) * f(t) \\
h(t) &= s(t) * o(t)
\end{aligned}
$$

where $\sigma$ is the sigmoid function, $\phi$ is $\tanh$, and $*$ denotes elementwise multiplication. By concatenating the input $x(t)$ and the previous hidden output $h(t-1)$ into a single vector $x_c(t) = [x(t),\, h(t-1)]$, we can re-write parts of the above as follows:

$$
\begin{aligned}
g(t) &= \phi(W_g\, x_c(t) + b_g) \\
i(t) &= \sigma(W_i\, x_c(t) + b_i) \\
f(t) &= \sigma(W_f\, x_c(t) + b_f) \\
o(t) &= \sigma(W_o\, x_c(t) + b_o)
\end{aligned}
$$
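To make this concrete, here is a minimal numpy sketch of the forward pass of one node. It follows the concatenated equations above and mirrors the library's variable names (`wg`, `wi`, `wf`, `wo`, `xc`), but this standalone function is an illustration written for this post, not the repository's code:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_node_forward(x, h_prev, s_prev, wg, wi, wf, wo, bg, bi, bf, bo):
    """Illustrative forward step of a single LSTM node.

    Assumed shapes: x is (x_dim,), h_prev and s_prev are (mem_cell_ct,),
    and each weight matrix is (mem_cell_ct, x_dim + mem_cell_ct).
    """
    xc = np.hstack((x, h_prev))          # x_c(t) = [x(t), h(t-1)]
    g = np.tanh(np.dot(wg, xc) + bg)     # candidate cell input
    i = sigmoid(np.dot(wi, xc) + bi)     # input gate
    f = sigmoid(np.dot(wf, xc) + bf)     # forget gate
    o = sigmoid(np.dot(wo, xc) + bo)     # output gate
    s = g * i + s_prev * f               # new cell state
    h = s * o                            # new hidden output (no extra tanh on s here)
    return h, s
```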
Suppose we have a loss $l(t)$ that we would like to minimize at every time step $t$, which depends on the hidden layer $h$ and the label $y$ at the current time:

$$
l(t) = f(h(t), y(t))
$$

where $f$ can be any differentiable loss function, such as the Euclidean loss:

$$
l(t) = f(h(t), y(t)) = \|h(t) - y(t)\|^2
$$

Our ultimate goal in this case is to use gradient descent to minimize the total loss $L$ over an entire sequence of length $T$:

$$
L = \sum_{t=1}^{T} l(t)
$$

Let's work through the algebra of computing the loss gradient $\frac{dL}{dw}$, where $w$ is a scalar parameter of the model (for example, an entry of the matrix $W_{gx}$). Since the loss $l(t) = f(h(t), y(t))$ only depends on the values of the hidden layer $h(t)$ and the label $y(t)$, we have by the chain rule:

$$
\frac{dL}{dw} = \sum_{t=1}^{T} \sum_{i=1}^{M} \frac{dL}{dh_i(t)} \frac{dh_i(t)}{dw}
$$

where $h_i(t)$ is the scalar output of the $i$'th memory cell at time $t$ and $M$ is the total number of memory cells. Since the network propagates information forwards in time, changing $h_i(t)$ has no effect on the loss incurred before time $t$, which allows us to write:

$$
\frac{dL}{dh_i(t)} = \sum_{s=1}^{T} \frac{dl(s)}{dh_i(t)} = \sum_{s=t}^{T} \frac{dl(s)}{dh_i(t)}
$$

For notational convenience we introduce the variable $L(t)$, which denotes the cumulative loss from time step $t$ onwards:

$$
L(t) = \sum_{s=t}^{T} l(s)
$$

such that $L(1)$ is the loss for the entire sequence. With this in mind, we can re-write our gradient calculation as:

$$
\frac{dL}{dh_i(t)} = \frac{dL(t)}{dh_i(t)}
$$

Make sure you understand this last equation. The computation of $\frac{dh_i(t)}{dw}$ follows directly from the forward propagation equations presented earlier; what remains is to compute $\frac{dL(t)}{dh_i(t)}$, which is where the so-called backpropagation through time comes into play.
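As a concrete example of where the per-step term $\frac{dl(t)}{dh(t)}$ comes from: for the Euclidean loss above, it is simply $2\,(h(t) - y(t))$. A minimal sketch (an illustration for this post, not the library's loss layer):

```python
import numpy as np

def euclidean_loss(h_t, y_t):
    # l(t) = ||h(t) - y(t)||^2
    return np.sum((h_t - y_t) ** 2)

def euclidean_loss_diff(h_t, y_t):
    # dl(t)/dh(t) = 2 * (h(t) - y(t)); this per-step term is what gets fed
    # into the backward pass as part of top_diff_h below
    return 2.0 * (h_t - y_t)
```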
Backpropagation through time
This variable $L(t)$ allows us to express the following recursion:

$$
L(t) =
\begin{cases}
l(t) + L(t+1) & \text{if } t < T \\
l(t) & \text{if } t = T
\end{cases}
$$

Hence, given the activation $h(t)$ of an LSTM node at time $t$, we have that:

$$
\frac{dL(t)}{dh(t)} = \frac{dl(t)}{dh(t)} + \frac{dL(t+1)}{dh(t)}
$$

Now, we know where the first term on the right hand side, $\frac{dl(t)}{dh(t)}$, comes from: it is simply the elementwise derivative of the loss $l(t)$ with respect to the activations $h(t)$ at time $t$. The second term, $\frac{dL(t+1)}{dh(t)}$, is where the recurrent nature of the LSTM shows up: we need the next node's derivative information in order to compute the current node's derivative information. Since we will ultimately need $\frac{dL(t)}{dh(t)}$ for every $t = 1, 2, \ldots, T$, we start by computing

$$
\frac{dL(T)}{dh(T)} = \frac{dl(T)}{dh(T)}
$$

and work our way backwards through the network. Hence the term backpropagation through time. With these intuitions in place, we jump into the code.
Code
We now present the code that performs the backprop pass through a single node at time $1 \le t \le T$. The code takes as input:

- `top_diff_h` $= \frac{dL(t)}{dh(t)} = \frac{dl(t)}{dh(t)} + \frac{dL(t+1)}{dh(t)}$
- `top_diff_s` $= \frac{dL(t+1)}{ds(t)}$

and computes:

- `self.state.bottom_diff_s` $= \frac{dL(t)}{ds(t-1)}$
- `self.state.bottom_diff_h` $= \frac{dL(t)}{dh(t-1)}$

whose values will need to be propagated backwards in time. The code also accumulates derivatives into:

- `self.param.wi_diff` $= \frac{dL}{dW_i}$
- ...
- `self.param.bi_diff` $= \frac{dL}{db_i}$
- ...

(and likewise for the $f$, $o$, and $g$ gates), since recall that we must sum the derivatives from each time step. The weight matrices and biases are shared across time, so

$$
\frac{dL}{dW_i} = \sum_{t=1}^{T} \frac{\partial L}{\partial W_i(t)}
$$

where $W_i(t)$ denotes the shared matrix $W_i$ as it is applied at time step $t$; hence the `+=` accumulation in the code. Also, note that we use:

- `dxc` $= \frac{dL}{dx_c(t)}$

where we recall that $x_c(t)$ is the concatenation of the input $x(t)$ and the previous hidden output $h(t-1)$.
```python
def top_diff_is(self, top_diff_h, top_diff_s):
    # notice that top_diff_s is carried along the constant error carousel
    ds = self.state.o * top_diff_h + top_diff_s
    do = self.state.s * top_diff_h
    di = self.state.g * ds
    dg = self.state.i * ds
    df = self.s_prev * ds

    # diffs w.r.t. vector inside sigma / tanh function
    di_input = (1. - self.state.i) * self.state.i * di
    df_input = (1. - self.state.f) * self.state.f * df
    do_input = (1. - self.state.o) * self.state.o * do
    dg_input = (1. - self.state.g ** 2) * dg

    # diffs w.r.t. inputs
    self.param.wi_diff += np.outer(di_input, self.xc)
    self.param.wf_diff += np.outer(df_input, self.xc)
    self.param.wo_diff += np.outer(do_input, self.xc)
    self.param.wg_diff += np.outer(dg_input, self.xc)
    self.param.bi_diff += di_input
    self.param.bf_diff += df_input
    self.param.bo_diff += do_input
    self.param.bg_diff += dg_input

    # compute bottom diff
    dxc = np.zeros_like(self.xc)
    dxc += np.dot(self.param.wi.T, di_input)
    dxc += np.dot(self.param.wf.T, df_input)
    dxc += np.dot(self.param.wo.T, do_input)
    dxc += np.dot(self.param.wg.T, dg_input)

    # save bottom diffs
    self.state.bottom_diff_s = ds * self.state.f
    self.state.bottom_diff_x = dxc[:self.param.x_dim]
    self.state.bottom_diff_h = dxc[self.param.x_dim:]
```
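For context, here is a sketch of how one might drive `top_diff_is` over a whole sequence, mirroring the recursion above: start at the last time step with $\frac{dL(T)}{dh(T)} = \frac{dl(T)}{dh(T)}$ and a zero cell-state diff, then, walking backwards, add each node's `bottom_diff_h` into the earlier node's `top_diff_h` and pass along its `bottom_diff_s`. The helper below is an illustration, not the library's code; it assumes a list `nodes` of already forward-propagated `LstmNode` objects, their hidden outputs `h_list`, targets `y_list`, and a loss-derivative function such as `euclidean_loss_diff` above:

```python
import numpy as np

def backward_through_time(nodes, h_list, y_list, loss_diff, mem_cell_ct):
    """Sketch of backpropagation through time over a sequence of LSTM nodes."""
    T = len(nodes)

    # last node: dL(T)/dh(T) = dl(T)/dh(T), and no future cell-state diff yet
    diff_h = loss_diff(h_list[T - 1], y_list[T - 1])
    diff_s = np.zeros(mem_cell_ct)
    nodes[T - 1].top_diff_is(diff_h, diff_s)

    # earlier nodes: top_diff_h is the local loss term plus what the later
    # node propagated down; top_diff_s is carried along the cell-state path
    for t in range(T - 2, -1, -1):
        diff_h = loss_diff(h_list[t], y_list[t]) + nodes[t + 1].state.bottom_diff_h
        diff_s = nodes[t + 1].state.bottom_diff_s
        nodes[t].top_diff_is(diff_h, diff_s)
```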
Details
The forward propagation equations show that modifying $s(t)$ affects the loss $L(t)$ directly through the current output $h(t) = s(t) * o(t)$, and affects the future loss $L(t+1)$ only through the next cell state $s(t+1) = g(t+1) * i(t+1) + s(t) * f(t+1)$. This is exactly the split the backward code exploits: the derivative with respect to the cell state combines the contribution through the current output gate with the contribution carried back from the future, and the part handed one step further back along the cell-state path is scaled by the forget gate.
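Spelled out as a short derivation sketch (using the forward equations above and the variable names from the code):

$$
\begin{aligned}
\frac{dL(t)}{ds(t)} &= \frac{dL(t)}{dh(t)} * \frac{\partial h(t)}{\partial s(t)} + \frac{dL(t+1)}{ds(t)} = \texttt{top\_diff\_h} * o(t) + \texttt{top\_diff\_s} \\
\frac{dL(t)}{ds(t-1)} &= \frac{dL(t)}{ds(t)} * \frac{\partial s(t)}{\partial s(t-1)} = \texttt{ds} * f(t)
\end{aligned}
$$

which correspond to the lines `ds = self.state.o * top_diff_h + top_diff_s` and `self.state.bottom_diff_s = ds * self.state.f` in `top_diff_is`.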