[TensorFlow Study] BasicLSTMCell Source Code Analysis
Source: Internet  Editor: 程序博客网  Date: 2024/05/29 13:23
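The BasicLSTMCell class is the most basic LSTM recurrent network cell. Its constructor arguments are:
- num_units: number of units in the LSTM cell
- forget_bias: bias added to the forget gates (default 1.0)
- state_is_tuple: best left as True, in which case states are accepted and returned as (c_state, m_state) 2-tuples
- activation: activation function applied to the inner states (tanh by default)
- reuse: Python boolean describing whether to reuse variables in an existing scope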
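To make the role of forget_bias concrete, here is a minimal pure-Python sketch. It assumes, purely for illustration, that the forget-gate pre-activation is close to zero at the start of training; the helper name forget_gate is hypothetical and not part of TensorFlow.

```python
import math


def sigmoid(x):
    """Logistic function, as applied to the gates in the cell."""
    return 1.0 / (1.0 + math.exp(-x))


def forget_gate(f_preactivation, forget_bias=1.0):
    """Forget-gate value as computed in the cell: sigmoid(f + forget_bias).

    Hypothetical helper for illustration only.
    """
    return sigmoid(f_preactivation + forget_bias)


# Assumption: the pre-activation is about 0 early in training.
gate_without_bias = forget_gate(0.0, forget_bias=0.0)
gate_with_bias = forget_gate(0.0, forget_bias=1.0)
print(gate_without_bias, gate_with_bias)
```

With the default bias the gate starts out larger, so more of the previous cell state is carried forward in the first training steps.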
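Notes:
- input_size is deprecated and unused; passing it only logs a warning
- state_is_tuple: the official recommendation is True. The input and output states are then 2-tuples of c (the cell state) and h (the output, called m_state in the docstring)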
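The two state layouts can be sketched with NumPy standing in for TensorFlow tensors. The sizes below are hypothetical; the concatenate and split calls mirror the column-axis concat and two-way split the cell performs when state_is_tuple is False.

```python
import numpy as np

batch_size, num_units = 2, 3  # hypothetical sizes for illustration

c = np.arange(batch_size * num_units, dtype=np.float32).reshape(batch_size, num_units)
h = -c

# state_is_tuple=True: the state is simply the pair (c, h).
tuple_state = (c, h)

# state_is_tuple=False: c and h are concatenated along the column axis,
# giving one array of shape (batch_size, 2 * num_units).
concat_state = np.concatenate([c, h], axis=1)

# Splitting into two equal parts along axis 1 recovers c and h.
c_back, h_back = np.split(concat_state, 2, axis=1)
print(concat_state.shape, c_back.shape, h_back.shape)
```

The tuple form avoids this extra concat and split on every step, which is consistent with the warning in the source that the concatenated state is slower.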
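__call__() makes an instance of the class callable: given the input inputs and the state state, it applies the LSTM equations and returns new_h together with the new state new_state, where new_state = (new_c, new_h) when state_is_tuple is True. For the underlying theory, see this paper: https://arxiv.org/pdf/1409.2329.pdf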
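For reference, the computation performed by the cell can be written as the equations below. This is a transcription of the code rather than the notation of the paper: x_t (the input), W and b (the weights and bias created by _linear), b_f (forget_bias), \sigma (sigmoid), \phi (the activation, tanh by default) and \odot (elementwise product) are symbols introduced here.

```latex
\begin{aligned}
[\,i_t,\ j_t,\ f_t,\ o_t\,] &= [\,x_t,\ h_{t-1}\,]\,W + b \\
c_t &= c_{t-1} \odot \sigma(f_t + b_f) + \sigma(i_t) \odot \phi(j_t) \\
h_t &= \phi(c_t) \odot \sigma(o_t)
\end{aligned}
```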
# Excerpt from the TensorFlow 1.x source. RNNCell, LSTMStateTuple,
# _checked_scope, _linear, array_ops, sigmoid, tanh and logging are defined
# or imported elsewhere in the TensorFlow source, so this excerpt does not
# run on its own.
class BasicLSTMCell(RNNCell):
    """Basic LSTM recurrent network cell.

    The implementation is based on: http://arxiv.org/abs/1409.2329.

    We add forget_bias (default: 1) to the biases of the forget gate in order to
    reduce the scale of forgetting in the beginning of the training.

    It does not allow cell clipping, a projection layer, and does not
    use peep-hole connections: it is the basic baseline.

    For advanced models, please use the full LSTMCell that follows.
    """

    def __init__(self, num_units, forget_bias=1.0, input_size=None,
                 state_is_tuple=True, activation=tanh, reuse=None):
        """Initialize the basic LSTM cell.

        Args:
          num_units: int, The number of units in the LSTM cell.
          forget_bias: float, The bias added to forget gates (see above).
          input_size: Deprecated and unused.
          state_is_tuple: If True, accepted and returned states are 2-tuples of
            the `c_state` and `m_state`. If False, they are concatenated
            along the column axis. The latter behavior will soon be deprecated.
          activation: Activation function of the inner states.
          reuse: (optional) Python boolean describing whether to reuse variables
            in an existing scope. If not `True`, and the existing scope already
            has the given variables, an error is raised.
        """
        if not state_is_tuple:
            logging.warn("%s: Using a concatenated state is slower and will soon be "
                         "deprecated. Use state_is_tuple=True.", self)
        if input_size is not None:
            logging.warn("%s: The input_size parameter is deprecated.", self)
        self._num_units = num_units
        self._forget_bias = forget_bias
        self._state_is_tuple = state_is_tuple
        self._activation = activation
        self._reuse = reuse

    @property
    def state_size(self):
        return (LSTMStateTuple(self._num_units, self._num_units)
                if self._state_is_tuple else 2 * self._num_units)

    @property
    def output_size(self):
        return self._num_units

    def __call__(self, inputs, state, scope=None):
        """Long short-term memory cell (LSTM)."""
        with _checked_scope(self, scope or "basic_lstm_cell", reuse=self._reuse):
            # Parameters of gates are concatenated into one multiply for
            # efficiency.
            if self._state_is_tuple:
                c, h = state
            else:
                c, h = array_ops.split(
                    value=state, num_or_size_splits=2, axis=1)
            # Linear map: concat = [inputs, h] W + b.
            # _linear creates W and b. W has shape
            # (input_dim + num_units, 4 * num_units) and b has shape
            # (4 * num_units,); together they hold the parameters of all
            # four gates. concat has shape (batch_size, 4 * num_units).
            # Note: W is (2 * num_units, 4 * num_units) only when the input
            # dimension equals num_units. Because inputs and h are
            # concatenated along axis 1, other input sizes work with the
            # same single W and b.
            concat = _linear([inputs, h], 4 * self._num_units, True)

            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            i, j, f, o = array_ops.split(
                value=concat, num_or_size_splits=4, axis=1)

            new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
                     self._activation(j))
            new_h = self._activation(new_c) * sigmoid(o)

            if self._state_is_tuple:
                new_state = LSTMStateTuple(new_c, new_h)
            else:
                new_state = array_ops.concat([new_c, new_h], 1)
            return new_h, new_state
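The body of __call__ can be mirrored step by step in NumPy. The following is a minimal sketch, not the TensorFlow implementation: basic_lstm_step is a hypothetical name, W and b are passed in explicitly instead of being created by _linear, and the sizes are arbitrary. The input dimension is deliberately different from num_units to show that the single concatenated matrix multiply works with W of shape (input_dim + num_units, 4 * num_units).

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def basic_lstm_step(inputs, c, h, W, b, forget_bias=1.0, activation=np.tanh):
    """One LSTM step mirroring BasicLSTMCell.__call__ (NumPy sketch).

    inputs: (batch, input_dim); c, h: (batch, num_units)
    W: (input_dim + num_units, 4 * num_units); b: (4 * num_units,)
    """
    # Single matrix multiply for all four gates: [inputs, h] W + b.
    concat = np.concatenate([inputs, h], axis=1) @ W + b
    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = np.split(concat, 4, axis=1)
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * activation(j)
    new_h = activation(new_c) * sigmoid(o)
    # Tuple-style state, as with state_is_tuple=True.
    return new_h, (new_c, new_h)


# Hypothetical sizes; input_dim differs from num_units on purpose.
batch, input_dim, num_units = 2, 5, 3
rng = np.random.default_rng(0)
x = rng.standard_normal((batch, input_dim))
c0 = np.zeros((batch, num_units))
h0 = np.zeros((batch, num_units))
W = rng.standard_normal((input_dim + num_units, 4 * num_units)) * 0.1
b = np.zeros(4 * num_units)

out, (c1, h1) = basic_lstm_step(x, c0, h0, W, b)
print(out.shape, c1.shape)
```

To process a sequence, call the function once per time step and feed the returned (new_c, new_h) back in as the next c and h.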