Tensorflow: RNNCell

来源:互联网 发布:js获取对象的第一个值 编辑:程序博客网 时间:2024/06/08 13:00

这里记录Tensorflow一些常见的RNNCell, 所有的RNNCell都在tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl这个模块中.

  • tf.contrib.rnn.BasicLSTMCell

    基础的LSTM Cell, LSTM中最常见的Cell结构.

    这里的Cell限制: 输入变量x与中间状态h, C中的特征数量相同(或理解为含有的神经元相同), 为参数num_units定义.

    对于t时刻的cell状态, 接受的输入有两个, 一个是输入input(t), 另一个是中间变量context(t-1), 输出为output(t), 通过cell的结构可以看出, output(t)即是新的中间变量context(t). 因此x, h具有相同的shape: [num_units].

    因此, 对于cell的中某个权值W, 与num_units长度的x相乘, 也于num_units长度的h相乘, 获得的结果也是num_units长度的, 因此对于x的权值Wx的shape为(num_units, num_units), 同理对于h的权值Wh的shape也为(num_units, num_units), 将两者叠加在一起, 就为单个激活函数前的连接权值W, shape为(2*num_units, num_units). 同时, 相应的偏置向量长度为(num_units, ).

    又已知基础的BasicLSTMCell中有四个需要这样计算的激活函数前的网络, 因此有4套Wb参数, 再将这些参数组合起来, 得到最终的Wb的大小:
    W的shape为(2*num_units, 4*num_units);
    b的shape为(4*num_units, ).

    以上是对BasicLSTMCell的解释.

'''参数:    num_units: cell中网络的神经元个数, 同时也是输入和中间变量和输出长度.    forget_bias: float, 添加到遗忘门中的偏置, 为了减少在开始训练时遗忘的规模。    input_size: input数据的shape, 废弃使用; 一般input_size为[batch_size, max_time, size], batch_size是每次批量传入的样本数量, max_time表示在时间轴上推进的次数(语句分析代表最长的那句话是多长?), size代表输入的特征数量, 即embedding_size, 也即为num_units;    state_is_tuple: bool, 定义调用时接受和返回的中间变量的格式. True返回两个tuple, 分别代表中间变量C和中间变量h, shape为[batch_size, num_units], False则是将两个tuple在列方向上合并的结构, shape为[batch_size, 2*num_units];    activation: 输出门的激活函数;    reuse: bool, 是否复用Cell中的变量.'''
  • tf.contrib.rnn.DropoutWrapper

    在给定的cell上的对输入和输出添加dropout操作. 这里的dropout指的是两层LSTM之间, 同一时刻t的cell之间传递时进行dropout操作.

'''参数:    cell: (must)需要添加Dropout操作的RNNCell;    input_keep_prob: 输入进行dropout操作的保留概率;    output_keep_prob: 输出进行dropout操作的保留概率;    state_keep_prob: 对中间状态进行dropout操作的保留概率(不推荐设置);    variational_recurrent: bool, True则在每次调用__call__方法时, 使用同样的dropout方式, 同时要求指明了input_size参数;    input_size: 输入的shape;    dtype: 输入, 中间变量和输出的数据类型;    seed: 随机种子'''lstm_cell = tf.contrib.rnn.BasicLSTMCell(size, forget_bias=0.0)lstm_cell = tf.contrib.rnn.DropoutWrapper(lstm_cell, output_keep_prob=config.keep_prob)
  • tf.contrib.rnn.MultiRNNCell

    堆叠多个RNNCell, 形成层级, 成为一个整体的多层级RNNCell.

'''参数:    cells: (must)list of RNNCells, 按先后顺序排列;    state_is_tuple: bool, True则调用__call__方法时传入和输出的中间状态为n个tuple, n为cells的长度, 即RNNCell的个数, False则将这些中间状态在列方向上拼接起来.'''

方法

  • cell.zero_state()

经过DropoutWrapper, MultiRNNCell生成的cell调用这个方法时, 真正调用的是传入的RNNCell的zero_state方法.

'''作用: 返回0值的中间变量Tensors.参数:    batch_size: 一个batch样本的数量;    dtype: 输出的Tensor的数据类型.输出:    1. RNNCell, DropoutWrapper: 一个LSTMStateTuple, 这个对象相当于一个tuple, 长度为2, 含有中间状态C和h的值, 分别为一个Tensor, 这两个Tensor的shape都为[batch_size, num_units];    2. MultiRNNCell: 返回一个含有中间状态的tuple, 长度为cells的数量, 如包含两个DropoutWrapper, 则该tuple的长度为2, 其中的每个元素为LSTMStateTuple, LSTMStateTuple具体见上条.'''

执行RNNCell的方法

  1. cell.call()

    即对创建的对象传入参数进行调用, 对象的call方法中实现了RNNCell的执行方法, 输入当前时间的输入和上个时间的中间状态(初始时为全0), 得到这个时间的输出和中间状态(输出即为中间状态h).

'''参数:    inputs: 一个时间点t的输入变量, 即inputs[:, time_step, :];    state: 中间状态Tensor, 初始时传入zero_state()方法生成的初始状态, 中间传入上一步执行时产生的状态    scope: 在过程中创建的变量使用在此scope之下.输出:    output: t时刻结果Tensor, shape=[batch_size, num_units], 如果cell为MultiRNNCell, 输出为最后一层cell的输出;    state: t时刻更新后的中间状态, 形式以及结果大小同zero_state()方法相同, 即如果cell为MultiRNNCell, 与output不同, 会按顺序记录每一层cell的中间状态在最后输出的tuple中.'''
  1. tf.nn.dynamic_run()

    与直接调用call方法每次传入一个时间点t的数据不同, 此方法一次传入按时间展开的所有数据, 即input的shape为[batch_size, max_time, size], 返回的结果也是所有时间点t的总输出, shape为[batch_size, max_time, size], 中间状态state与之前相同, 为最后一个时间输出的state, 如果cell为多层, 这个state包含每一层的结果, 为一个tuple.

'''参数:    cell: (must)RNNCell;    inputs: (must)包含每个时间点t的总输入Tensor, 当time_major=Fasle时, shape=[batch_size, max_time, size], 当time_major=True时, shape=[max_time, batch_size, size]:    sequence_length: list of int, 长度为batch_size, 指定batch中的每个样本在时间上展开的长度;    initial_state: 初始状态, 传入zero_state()方法生成的初始状态;    dtype: initial_state和outputTensor的数据类型;    parallel_iterations: 并行同时迭代的数量(不明这里的迭代指的是什么?);    swap_memory: (好像是控制前向传播生成的Tensor在后向传播使用时, GPU和CPU之间的数据传递关系);    time_major: bool, 控制inputs的shape的参数;    scope: 在过程中创建的变量使用在此scope之下, 默认为'run'.输出:    output: batch中所有样本在最后一层(对应每个样本的sequence意义下)的输出;        time_major=False: [batch_size, max_time, size];        time_major=True: [max_time, batch_size, size];    state: 与__call__()方法在最后一个时间点的state相同.'''
原创粉丝点击