TensorFlow RNN: Related Classes and Methods

RNN here also covers variants such as LSTM and GRU.

1. RNN

1.1 cell

  • tensorflow.python.layers.base.Layer
    A class representing a single layer in a network.

  • tensorflow.python.ops.rnn_cell_impl.RNNCell(base_layer.Layer)
    An abstract class representing one layer of an RNN; calling it a cell is more descriptive.

  • RNNCell#zero_state(self, batch_size, dtype)
    Returns zero-filled state tensor(s), typically used as an initial state; see the sketch below.
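
A minimal sketch of zero_state (TF 1.x API; the BATCH_SIZE and NUM_UNITS values are illustrative, not from the original text):

import tensorflow as tf

BATCH_SIZE, NUM_UNITS = 32, 64  # illustrative sizes

cell = tf.nn.rnn_cell.BasicRNNCell(num_units=NUM_UNITS)
# zero_state returns the zero-filled initial state for a batch; for
# BasicRNNCell it is a single tensor of shape (BATCH_SIZE, NUM_UNITS)
init_state = cell.zero_state(BATCH_SIZE, dtype=tf.float32)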

1.2 MultiRNNCell

  • tf.nn.rnn_cell.MultiRNNCell(RNNCell)
    A class: an RNN cell composed sequentially of multiple simple cells.
    __init__(self, cells, state_is_tuple=True)
    Constructor.
    cells: a list of RNNCells that will be composed in this order; see the sketch below.
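
A minimal stacking sketch (TF 1.x API; NUM_UNITS, NUM_LAYERS, and BATCH_SIZE are illustrative). Note that each layer needs its own cell object; reusing one cell object N times is a common bug:

import tensorflow as tf

NUM_UNITS, NUM_LAYERS, BATCH_SIZE = 64, 3, 32  # illustrative sizes

# One fresh cell per layer, composed in order by MultiRNNCell
cells = [tf.nn.rnn_cell.BasicLSTMCell(NUM_UNITS) for _ in range(NUM_LAYERS)]
stacked_cell = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)
# The stacked cell's zero state is a tuple with one entry per layer
init_state = stacked_cell.zero_state(BATCH_SIZE, dtype=tf.float32)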

1.3 dynamic_rnn

  • tf.nn.dynamic_rnn(cell, inputs, ...)
    Creates a recurrent neural network specified by RNNCell cell, performing fully dynamic unrolling of inputs.
    Returns (outputs, state). The outputs tensor has shape (BATCH_SIZE, TIME_STEPS, NUM_UNITS). For problems such as predicting sin(x), where only the last time step's output matters, you can slice it (a combined usage sketch follows this list):
    # before: output_rnn.shape = (5, 10, 8), namely (BATCH_SIZE, TIME_STEPS, NUM_UNITS)
    # after:  output_rnn.shape = (5, 8),     namely (BATCH_SIZE, NUM_UNITS)
    output_rnn = output_rnn[:, -1, :]
  • tf.nn.rnn_cell.DropoutWrapper(RNNCell)
    A class that adds dropout to a cell's inputs and outputs (and optionally its state).
    __init__(self, cell, input_keep_prob=1.0, output_keep_prob=1.0, state_keep_prob=1.0, ...)
    Constructor.
    keep_prob is 1 - dropout_prob.

  • tf.nn.dropout(x, keep_prob, ...)
    x: a tensor. Each element of x is kept with probability keep_prob and scaled by 1/keep_prob, so the expected sum stays the same.
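
A combined sketch of DropoutWrapper plus dynamic_rnn (TF 1.x API; the sizes are illustrative, and feeding keep_prob through a placeholder is one common pattern, not the only one):

import tensorflow as tf

TIME_STEPS, INPUT_SIZE, NUM_UNITS = 10, 8, 64  # illustrative sizes

inputs = tf.placeholder(tf.float32, [None, TIME_STEPS, INPUT_SIZE])
keep_prob = tf.placeholder(tf.float32)  # feed 1.0 at evaluation time

cell = tf.nn.rnn_cell.BasicLSTMCell(NUM_UNITS)
cell = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=keep_prob)
# dynamic_rnn unrolls the cell over the time dimension of inputs
outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
# outputs.shape == (batch, TIME_STEPS, NUM_UNITS); keep only the last step
last_output = outputs[:, -1, :]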

2. Handling input and output in TensorFlow

Wherever TIME_STEPS is involved, TensorFlow code generally handles it like this:

X = tf.placeholder(tf.float32, shape=[None, TIME_STEPS, INPUT_SIZE])
Y = tf.placeholder(tf.float32, shape=[None, OUTPUT_SIZE])

def lstm_model(input_tensor, out_tensor):
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(NUM_UNITS, state_is_tuple=True)
    # Reshape to 2-D for the matmul; the result becomes the hidden layer's input
    input_tensor = tf.reshape(input_tensor, [-1, INPUT_SIZE])
    w_in = tf.get_variable('weight_input', shape=[INPUT_SIZE, NUM_UNITS],
                           initializer=tf.truncated_normal_initializer())
    # Note: shape=[NUM_UNITS, ] is very different from shape=[NUM_UNITS, 1]!
    b_in = tf.get_variable('bias_input', shape=[NUM_UNITS, ],
                           initializer=tf.zeros_initializer())
    input_rnn = tf.matmul(input_tensor, w_in) + b_in
    # Reshape back to 3-D as input to the LSTM cell
    input_rnn = tf.reshape(input_rnn, [-1, TIME_STEPS, NUM_UNITS])
    # The state's leading dimension is normally BATCH_SIZE
    init_state = lstm_cell.zero_state(BATCH_SIZE, dtype=tf.float32)
    # output_rnn.shape = (BATCH_SIZE, TIME_STEPS, NUM_UNITS)
    output_rnn, final_states = tf.nn.dynamic_rnn(lstm_cell, input_rnn,
                                                 initial_state=init_state,
                                                 dtype=tf.float32)
    # This problem only needs the last time step, so
    # output_rnn.shape becomes (BATCH_SIZE, NUM_UNITS)
    output_rnn = output_rnn[:, -1, :]
    # A fully connected layer; this is a regression problem, so the output
    # dimension is 1 and None means no activation function
    predictions = tf.contrib.layers.fully_connected(output_rnn, 1, None)
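
The snippet stops at predictions and never uses out_tensor. A hedged sketch of how such a model is typically finished (the mean-squared-error loss and Adam optimizer here are assumptions, not from the original, and it presumes OUTPUT_SIZE == 1 so the shapes match):

    # Regression loss against the targets, then a standard training op
    loss = tf.losses.mean_squared_error(labels=out_tensor,
                                        predictions=predictions)
    train_op = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
    return predictions, loss, train_op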

3. LSTM

  • tensorflow.python.ops.rnn_cell_impl.BasicLSTMCell(RNNCell)
    A class, normally referenced as tf.nn.rnn_cell.BasicLSTMCell.
    __init__(self, num_units, forget_bias=1.0, state_is_tuple=True, activation=None, reuse=None)
    Constructor.
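
With state_is_tuple=True the LSTM state is an LSTMStateTuple of (c, h). A minimal sketch (TF 1.x API; the sizes are illustrative):

import tensorflow as tf

BATCH_SIZE, NUM_UNITS = 32, 64  # illustrative sizes

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(NUM_UNITS, forget_bias=1.0,
                                         state_is_tuple=True)
state = lstm_cell.zero_state(BATCH_SIZE, dtype=tf.float32)
# state.c is the cell state and state.h the hidden state,
# each of shape (BATCH_SIZE, NUM_UNITS)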

4. GRU

  • tensorflow.python.ops.rnn_cell_impl.GRUCell(RNNCell)
    A class, normally referenced as tf.nn.rnn_cell.GRUCell.
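
Unlike the LSTM, the GRU carries a single state tensor rather than a (c, h) tuple. A minimal sketch (TF 1.x API; the sizes are illustrative):

import tensorflow as tf

BATCH_SIZE, NUM_UNITS = 32, 64  # illustrative sizes

gru_cell = tf.nn.rnn_cell.GRUCell(NUM_UNITS)
state = gru_cell.zero_state(BATCH_SIZE, dtype=tf.float32)
# For GRUCell, state is a single tensor of shape (BATCH_SIZE, NUM_UNITS)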