Tensorflow: RNNCell
来源:互联网 发布:js获取对象的第一个值 编辑:程序博客网 时间:2024/06/08 13:00
这里记录Tensorflow一些常见的RNNCell, 所有的RNNCell都在tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl
这个模块中.
tf.contrib.rnn.BasicLSTMCell
基础的LSTM Cell, LSTM中最常见的Cell结构.
这里的Cell限制: 输入变量
x
与中间状态h
,C
中的特征数量相同(或理解为含有的神经元相同), 为参数num_units
定义.对于t时刻的cell状态, 接受的输入有两个, 一个是输入
input(t)
, 另一个是中间变量context(t-1)
, 输出为output(t)
, 通过cell的结构可以看出,output(t)
即是新的中间变量context(t)
. 因此x
,h
具有相同的shape: [num_units].因此, 对于cell的中某个权值
W
, 与num_units长度的x
相乘, 也于num_units长度的h
相乘, 获得的结果也是num_units长度的, 因此对于x
的权值Wx
的shape为(num_units, num_units)
, 同理对于h
的权值Wh
的shape也为(num_units, num_units)
, 将两者叠加在一起, 就为单个激活函数前的连接权值W
, shape为(2*num_units, num_units)
. 同时, 相应的偏置向量长度为(num_units, )
.又已知基础的BasicLSTMCell中有四个需要这样计算的激活函数前的网络, 因此有4套
W
和b
参数, 再将这些参数组合起来, 得到最终的W
和b
的大小:W
的shape为(2*num_units, 4*num_units);b
的shape为(4*num_units, ).以上是对BasicLSTMCell的解释.
'''参数: num_units: cell中网络的神经元个数, 同时也是输入和中间变量和输出长度. forget_bias: float, 添加到遗忘门中的偏置, 为了减少在开始训练时遗忘的规模。 input_size: input数据的shape, 废弃使用; 一般input_size为[batch_size, max_time, size], batch_size是每次批量传入的样本数量, max_time表示在时间轴上推进的次数(语句分析代表最长的那句话是多长?), size代表输入的特征数量, 即embedding_size, 也即为num_units; state_is_tuple: bool, 定义调用时接受和返回的中间变量的格式. True返回两个tuple, 分别代表中间变量C和中间变量h, shape为[batch_size, num_units], False则是将两个tuple在列方向上合并的结构, shape为[batch_size, 2*num_units]; activation: 输出门的激活函数; reuse: bool, 是否复用Cell中的变量.'''
tf.contrib.rnn.DropoutWrapper
在给定的cell上的对输入和输出添加dropout操作. 这里的dropout指的是两层LSTM之间, 同一时刻
t
的cell之间传递时进行dropout操作.
'''参数: cell: (must)需要添加Dropout操作的RNNCell; input_keep_prob: 输入进行dropout操作的保留概率; output_keep_prob: 输出进行dropout操作的保留概率; state_keep_prob: 对中间状态进行dropout操作的保留概率(不推荐设置); variational_recurrent: bool, True则在每次调用__call__方法时, 使用同样的dropout方式, 同时要求指明了input_size参数; input_size: 输入的shape; dtype: 输入, 中间变量和输出的数据类型; seed: 随机种子'''lstm_cell = tf.contrib.rnn.BasicLSTMCell(size, forget_bias=0.0)lstm_cell = tf.contrib.rnn.DropoutWrapper(lstm_cell, output_keep_prob=config.keep_prob)
tf.contrib.rnn.MultiRNNCell
堆叠多个RNNCell, 形成层级, 成为一个整体的多层级RNNCell.
'''参数: cells: (must)list of RNNCells, 按先后顺序排列; state_is_tuple: bool, True则调用__call__方法时传入和输出的中间状态为n个tuple, n为cells的长度, 即RNNCell的个数, False则将这些中间状态在列方向上拼接起来.'''
方法
- cell.zero_state()
经过DropoutWrapper, MultiRNNCell生成的cell调用这个方法时, 真正调用的是传入的RNNCell的zero_state方法.
'''作用: 返回0值的中间变量Tensors.参数: batch_size: 一个batch样本的数量; dtype: 输出的Tensor的数据类型.输出: 1. RNNCell, DropoutWrapper: 一个LSTMStateTuple, 这个对象相当于一个tuple, 长度为2, 含有中间状态C和h的值, 分别为一个Tensor, 这两个Tensor的shape都为[batch_size, num_units]; 2. MultiRNNCell: 返回一个含有中间状态的tuple, 长度为cells的数量, 如包含两个DropoutWrapper, 则该tuple的长度为2, 其中的每个元素为LSTMStateTuple, LSTMStateTuple具体见上条.'''
执行RNNCell的方法
cell.call()
即对创建的对象传入参数进行调用, 对象的call方法中实现了RNNCell的执行方法, 输入当前时间的输入和上个时间的中间状态(初始时为全0), 得到这个时间的输出和中间状态(输出即为中间状态h).
'''参数: inputs: 一个时间点t的输入变量, 即inputs[:, time_step, :]; state: 中间状态Tensor, 初始时传入zero_state()方法生成的初始状态, 中间传入上一步执行时产生的状态 scope: 在过程中创建的变量使用在此scope之下.输出: output: t时刻结果Tensor, shape=[batch_size, num_units], 如果cell为MultiRNNCell, 输出为最后一层cell的输出; state: t时刻更新后的中间状态, 形式以及结果大小同zero_state()方法相同, 即如果cell为MultiRNNCell, 与output不同, 会按顺序记录每一层cell的中间状态在最后输出的tuple中.'''
tf.nn.dynamic_run()
与直接调用call方法每次传入一个时间点t的数据不同, 此方法一次传入按时间展开的所有数据, 即input的shape为[batch_size, max_time, size], 返回的结果也是所有时间点t的总输出, shape为[batch_size, max_time, size], 中间状态state与之前相同, 为最后一个时间输出的state, 如果cell为多层, 这个state包含每一层的结果, 为一个tuple.
'''参数: cell: (must)RNNCell; inputs: (must)包含每个时间点t的总输入Tensor, 当time_major=Fasle时, shape=[batch_size, max_time, size], 当time_major=True时, shape=[max_time, batch_size, size]: sequence_length: list of int, 长度为batch_size, 指定batch中的每个样本在时间上展开的长度; initial_state: 初始状态, 传入zero_state()方法生成的初始状态; dtype: initial_state和outputTensor的数据类型; parallel_iterations: 并行同时迭代的数量(不明这里的迭代指的是什么?); swap_memory: (好像是控制前向传播生成的Tensor在后向传播使用时, GPU和CPU之间的数据传递关系); time_major: bool, 控制inputs的shape的参数; scope: 在过程中创建的变量使用在此scope之下, 默认为'run'.输出: output: batch中所有样本在最后一层(对应每个样本的sequence意义下)的输出; time_major=False: [batch_size, max_time, size]; time_major=True: [max_time, batch_size, size]; state: 与__call__()方法在最后一个时间点的state相同.'''
- Tensorflow: RNNCell
- tensorflow中RNNcell源码分析以及自定义RNNCell的方法
- Tensorflow RNN源代码解析笔记1:RNNCell的基本实现
- Tensorflow ValueError: Attempt to reuse RNNCell with a different variable scope than its first
- tensorflow
- TensorFlow
- TensorFlow
- tensorflow
- tensorflow
- tensorflow
- Tensorflow
- Tensorflow
- tensorFlow
- tensorflow
- Tensorflow
- TensorFlow
- tensorflow
- Tensorflow
- sklearn-GridSearchCV,CV调节超参使用方法
- svn搭建
- C++指针函数和函数指针
- Merge k sorted lists
- Android高级控件系列三之第三方控件XListView下拉刷新实现代码
- Tensorflow: RNNCell
- 管理Activity生命周期
- 说说QtQuick提供的类型
- 数据结构之单链表
- 最小gcc.exe编译器(C语言)
- Spark SQL下Parquet内幕深度解密
- Android 拨打电话
- 胡语录6.10
- go json数据格式化输出