seq2seq代码部分解析

来源:互联网 发布:什么营销软件最好 编辑:程序博客网 时间:2024/06/11 10:28
encoder_outputs, encoder_state = rnn.rnn(        encoder_cell, encoder_inputs, dtype=dtype)
top_states = [array_ops.reshape(e, [-1, 1, cell.output_size])                  for e in encoder_outputs]

encoder_outputs的维度取决于input的维度,与batch_size无关.若输入为1*1,则有多少个hidden_unit,应该就有多少输出,输出应为1*hidden_unit.

例如:若encoder_inputs=[input_size,batch_size],则输出encoder_outputs的维度应该为[input_size,hidden_units]

原创粉丝点击