Reading Notes on the Keras LSTM Source Code
So far this is only a rough draft of my notes from reading the LSTM source code in Keras. It is incomplete and loosely organized — a collection of personal observations rather than a polished write-up.
References:
1. The official Keras documentation
2. The original LSTM paper
3. The Keras RNN source code
1. Interface study
1.1. The Recurrent interface
Recurrent is the parent class of LSTM (LSTM inherits from it directly, just as SimpleRNN does) and defines the interface shared by all RNNs.
1.1.1. implementation:
implementation: one of {0, 1, or 2}.
If set to 0, the RNN will use an implementation that uses fewer, larger matrix products, thus running faster on CPU but consuming more memory. If set to 1, the RNN will use more matrix products, but smaller ones, thus running slower (may actually be faster on GPU) while consuming less memory. If set to 2 (LSTM/GRU only), the RNN will combine the input gate, the forget gate and the output gate into a single matrix, enabling more time-efficient parallelization on the GPU. Note: RNN dropout must be shared for all gates, resulting in a slightly reduced regularization.
Since I mainly run code with GPU acceleration and am not concerned about memory, I usually set implementation=2, so this reading focuses on the implementation=2 code path.
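The trade-off behind implementation=2 can be illustrated with plain NumPy (a minimal sketch; the variable names here are illustrative, not Keras internals): fusing the four per-gate kernels into one wide matrix turns four small matrix products into a single large one with the same result.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, input_dim, units = 8, 16, 32
x = rng.standard_normal((batch, input_dim))

# Four separate per-gate kernels (conceptually what implementation=1 uses)
W = [rng.standard_normal((input_dim, units)) for _ in range(4)]

# implementation=2 style: one fused kernel of shape (input_dim, 4 * units)
W_fused = np.concatenate(W, axis=1)

z_separate = np.concatenate([x @ w for w in W], axis=1)  # four small matmuls
z_fused = x @ W_fused                                    # one large matmul

assert np.allclose(z_separate, z_fused)
```

The single large matmul is what parallelizes well on a GPU; the price noted in the docstring is that dropout must then be shared across all four gates.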
1.1.2. weights:
weights: list of Numpy arrays to set as initial weights.
The list should have 3 elements, of shapes: [(input_dim, output_dim), (output_dim, output_dim), (output_dim,)].
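For LSTM specifically, the three arrays are the fused kernel, the recurrent kernel, and the bias, each spanning the four gates, so every shape carries a factor of 4. A sketch of building such a list (zeros used only as placeholders):

```python
import numpy as np

input_dim, units = 16, 32

# The three LSTM weight arrays: fused kernel, recurrent kernel and bias.
# Each covers the four gates (i, f, c, o), hence the factor of 4.
weights = [
    np.zeros((input_dim, units * 4)),  # kernel
    np.zeros((units, units * 4)),      # recurrent_kernel
    np.zeros((units * 4,)),            # bias
]

assert [w.shape for w in weights] == [(16, 128), (32, 128), (128,)]
```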
1.2. The LSTM interface
1.2.1. recurrent_activation
Activation function to use for the recurrent step.
Note: the default is 'hard_sigmoid', whereas the original paper uses 'sigmoid' (for a comparison of hard_sigmoid and sigmoid, see "What is hard sigmoid in artificial neural networks? Why is it faster than standard sigmoid? Are there any disadvantages over the standard sigmoid?").
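The difference is easy to see in NumPy: Keras's hard_sigmoid is the piecewise-linear approximation 0.2x + 0.5 clipped to [0, 1], which avoids the exponential.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def hard_sigmoid(x):
    # Keras's piecewise-linear approximation: 0.2 * x + 0.5, clipped to [0, 1]
    return np.clip(0.2 * x + 0.5, 0.0, 1.0)

x = np.linspace(-4, 4, 9)
# Both pass through 0.5 at x = 0, but hard_sigmoid is exactly 0 below
# x = -2.5 and exactly 1 above x = 2.5, making it cheaper to compute.
print(np.round(sigmoid(x), 3))
print(np.round(hard_sigmoid(x), 3))
```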
2. kernel VS recurrent_kernel
2.1. kernel
1. Initialization
```python
self.kernel = self.add_weight(
    shape=(self.input_dim, self.units * 4),
    name='kernel',
    initializer=self.kernel_initializer,
    regularizer=self.kernel_regularizer,
    constraint=self.kernel_constraint)
```
2. Meaning of the blocks
```python
self.kernel_i = self.kernel[:, :self.units]
self.kernel_f = self.kernel[:, self.units: self.units * 2]
self.kernel_c = self.kernel[:, self.units * 2: self.units * 3]
self.kernel_o = self.kernel[:, self.units * 3:]
```
3. kernel is the matrix multiplied with the input x.
2.2. recurrent_kernel
1. Initialization:
```python
self.recurrent_kernel = self.add_weight(
    shape=(self.units, self.units * 4),
    name='recurrent_kernel',
    initializer=self.recurrent_initializer,
    regularizer=self.recurrent_regularizer,
    constraint=self.recurrent_constraint)
```
2. Meaning of the blocks
```python
self.recurrent_kernel_i = self.recurrent_kernel[:, :self.units]
self.recurrent_kernel_f = self.recurrent_kernel[:, self.units: self.units * 2]
self.recurrent_kernel_c = self.recurrent_kernel[:, self.units * 2: self.units * 3]
self.recurrent_kernel_o = self.recurrent_kernel[:, self.units * 3:]
```
3. recurrent_kernel is the matrix multiplied with the previous time step's hidden output h.
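Putting the two kernels together — a minimal NumPy sketch of how the fused pre-activation is formed and then sliced per gate (shapes follow the slicing above; variable names are illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)
batch, input_dim, units = 4, 8, 6

x = rng.standard_normal((batch, input_dim))    # current input
h_tm1 = rng.standard_normal((batch, units))    # previous hidden state

kernel = rng.standard_normal((input_dim, units * 4))        # multiplies x
recurrent_kernel = rng.standard_normal((units, units * 4))  # multiplies h_tm1

# One fused pre-activation for all four gates, then sliced as in the source
z = x @ kernel + h_tm1 @ recurrent_kernel
z_i, z_f, z_c, z_o = (z[:, k * units:(k + 1) * units] for k in range(4))

assert z.shape == (batch, units * 4)
assert z_i.shape == (batch, units)
```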
3. activation VS recurrent_activation
The iteration code is as follows:
```python
if self.implementation == 2:
    z = K.dot(inputs * dp_mask[0], self.kernel)
    z += K.dot(h_tm1 * rec_dp_mask[0], self.recurrent_kernel)
    if self.use_bias:
        z = K.bias_add(z, self.bias)

    z0 = z[:, :self.units]
    z1 = z[:, self.units: 2 * self.units]
    z2 = z[:, 2 * self.units: 3 * self.units]
    z3 = z[:, 3 * self.units:]

    i = self.recurrent_activation(z0)
    f = self.recurrent_activation(z1)
    c = f * c_tm1 + i * self.activation(z2)
    o = self.recurrent_activation(z3)
    h = o * self.activation(c)
```
As the code shows, recurrent_activation is applied when generating the gates i, f and o, while activation is applied to the candidate update (z2) and to the cell state c when producing h. To mimic the original paper, set activation = tanh and recurrent_activation = sigmoid.
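The whole step can be replayed outside Keras. Below is a hedged NumPy sketch of one LSTM step following the implementation=2 branch above, with sigmoid/tanh substituted for the activation callables (the function name lstm_step is mine, not from the source):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, h_tm1, c_tm1, kernel, recurrent_kernel, bias, units):
    """One LSTM step mirroring the implementation=2 branch (no dropout)."""
    z = x @ kernel + h_tm1 @ recurrent_kernel + bias
    z0, z1, z2, z3 = (z[:, k * units:(k + 1) * units] for k in range(4))
    i = sigmoid(z0)                    # input gate   (recurrent_activation)
    f = sigmoid(z1)                    # forget gate  (recurrent_activation)
    c = f * c_tm1 + i * np.tanh(z2)    # activation on the candidate z2
    o = sigmoid(z3)                    # output gate  (recurrent_activation)
    h = o * np.tanh(c)                 # activation on the cell state
    return h, c

rng = np.random.default_rng(2)
batch, input_dim, units = 3, 5, 4
h, c = lstm_step(
    rng.standard_normal((batch, input_dim)),
    np.zeros((batch, units)), np.zeros((batch, units)),
    rng.standard_normal((input_dim, units * 4)),
    rng.standard_normal((units, units * 4)),
    np.zeros(units * 4), units,
)
assert h.shape == (batch, units) and c.shape == (batch, units)
```

Since h = o * tanh(c) with o in (0, 1), every entry of h stays strictly inside (-1, 1).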
4. Guessed execution order
- First the __init__ methods run: the Recurrent part, then the LSTM part; both are simple attribute assignments.
- The Recurrent __call__ function most likely runs next (for the difference between __init__ and __call__, see a guide to Python magic methods). Note that the initial state can be specified here; normally it simply delegates to the parent class Layer's __call__.
- Then LSTM's build method runs, initializing the variables. See "Writing your own Keras layers".
- Then Recurrent's call function runs. Presumably the usual inputs argument is just a single Keras tensor, so the code below falls through to the final else branch:
```python
if isinstance(inputs, list):
    initial_state = inputs[1:]
    inputs = inputs[0]
elif initial_state is not None:
    pass
elif self.stateful:
    initial_state = self.states
else:
    initial_state = self.get_initial_state(inputs)
```
- Recurrent's get_initial_state function then runs; it simply returns an all-zeros initial state.
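The zero-state idea can be sketched in NumPy (a simplification of the real get_initial_state, which builds the zeros symbolically from the input tensor; the function here is illustrative):

```python
import numpy as np

def get_initial_state(inputs, units, num_states=2):
    # Mirrors the idea in Recurrent.get_initial_state: all-zeros states
    # (h and c for LSTM) matching the batch size of the inputs.
    batch = inputs.shape[0]
    return [np.zeros((batch, units)) for _ in range(num_states)]

# inputs of shape (batch, timesteps, input_dim)
states = get_initial_state(np.ones((7, 10, 3)), units=5)
assert len(states) == 2 and all(s.shape == (7, 5) for s in states)
```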