LSTM/GRU: getting the state at every time step instead of only the last state

Source: the Internet. Editor: 程序博客网. Date: 2024/06/05 09:06
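The function below unrolls the RNN cell manually with a Python loop. It keeps the output and the state produced at every time step, then stacks each list into a single tensor:

```python
import tensorflow as tf  # TensorFlow 1.x API (tf.variable_scope); use tf.compat.v1 under TF 2.x


def custom_net(cell, inputs, init_state, timesteps, time_major=False, scope='custom_net_0'):
    # convert to time-major format: [batch, time, depth] -> [time, batch, depth]
    if not time_major:
        inputs_tm = tf.transpose(inputs, [1, 0, 2], name="input_time_major")
    else:
        inputs_tm = inputs
    # collection of states and outputs, seeded with the initial state
    states, outputs = [init_state], []
    with tf.variable_scope(scope):
        for i in range(timesteps):
            # reuse the cell's weights on every step after the first
            if i > 0:
                tf.get_variable_scope().reuse_variables()
            output, state = cell(inputs_tm[i], states[-1])
            outputs.append(output)
            states.append(state)
    # drop init_state so that entry t is the state after time step t
    return tf.stack(outputs), tf.stack(states[1:])
```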

Code taken from https://github.com/ai-guild/r-net/blob/master/lib/recurrence.py
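Why unroll by hand at all? `tf.nn.dynamic_rnn` returns the output of every step but only the final state. For a GRU cell the output and the new state are the same tensor, so the outputs already are all the states. For an LSTM cell the output is only `h`, so the per-step cell state `c` is lost. The framework-free NumPy sketch below shows the same unrolling idea for an LSTM and makes the shapes explicit. The helper names `lstm_cell` and `unroll_all_states`, the weight layout, and the gate order are assumptions of this sketch. They do not describe TensorFlow's internal implementation.

```python
from functools import partial

import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def lstm_cell(x, state, W, U, b):
    """One LSTM step (hypothetical helper). x: [batch, d_in]; state: (c, h), each [batch, units].
    The gate order (input, forget, candidate, output) is an arbitrary choice for this sketch."""
    c, h = state
    z = x @ W + h @ U + b                        # [batch, 4 * units]
    i, f, g, o = np.split(z, 4, axis=1)
    c_new = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h_new = sigmoid(o) * np.tanh(c_new)
    return h_new, (c_new, h_new)                 # (output, new_state), output is just h


def unroll_all_states(cell, inputs, init_state, time_major=False):
    """Run `cell` step by step and keep every state, not only the last one."""
    if not time_major:
        inputs = np.transpose(inputs, (1, 0, 2))  # [batch, time, d] -> [time, batch, d]
    states, outputs = [init_state], []
    for x_t in inputs:                            # iterate over the time axis
        output, state = cell(x_t, states[-1])
        outputs.append(output)
        states.append(state)
    # each (c, h) tuple becomes [2, batch, units]; stacking over time gives [time, 2, batch, units]
    return np.stack(outputs), np.stack([np.stack(s) for s in states[1:]])


# Usage with random weights (illustrative values only)
rng = np.random.default_rng(0)
batch, time, d_in, units = 3, 5, 4, 6
W = rng.normal(size=(d_in, 4 * units)) * 0.1
U = rng.normal(size=(units, 4 * units)) * 0.1
b = np.zeros(4 * units)
cell = partial(lstm_cell, W=W, U=U, b=b)
inputs = rng.normal(size=(batch, time, d_in))
zeros = np.zeros((batch, units))

outputs, states = unroll_all_states(cell, inputs, (zeros, zeros))
print(outputs.shape)  # (5, 3, 6)
print(states.shape)   # (5, 2, 3, 6)
```

Here `states[:, 0]` holds the cell state `c` at every step and `states[:, 1]` holds `h`, which equals `outputs`. `states[-1]` is the final state that `dynamic_rnn` would have returned on its own.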