Debugging TensorFlow's seq2seq attention_decoder method


This post writes a test case for attention_decoder so that the implementation of the attention mechanism can be stepped through in a debugger.

import tensorflow as tf
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.contrib.legacy_seq2seq.python.ops import seq2seq as seq2seq_lib

with tf.Session() as sess:
    batch_size = 16
    step1 = 20        # encoder sequence length
    step2 = 10        # decoder sequence length
    input_size = 50   # encoder input dimension
    output_size = 40  # decoder input dimension
    gru_hidden = 30   # GRU hidden-state size

    # cell_fn builds a fresh GRU cell each time it is called, so the
    # encoder and decoder do not share variables.
    cell_fn = lambda: rnn_cell.GRUCell(gru_hidden)
    cell = cell_fn()

    # Encoder: step1 constant inputs, each of shape [batch_size, input_size].
    inp = [tf.constant(0.8, shape=[batch_size, input_size])] * step1
    enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=tf.float32)

    # Stack the per-step encoder outputs into the attention memory of
    # shape [batch_size, step1, gru_hidden].
    attn_states = tf.concat([
        tf.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
    ], 1)

    # Decoder: step2 constant inputs; attention_decoder attends over
    # attn_states at every step and projects each output to size 7.
    dec_inp = [tf.constant(0.3, shape=[batch_size, output_size])] * step2
    dec, mem = seq2seq_lib.attention_decoder(
        dec_inp, enc_state, attn_states, cell_fn(), output_size=7)

    sess.run([tf.global_variables_initializer()])
    res = sess.run(dec)
    print(len(res))
    print(res[0].shape)
    res = sess.run([mem])
    print(len(res))
    print(res[0].shape)
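Running this should print 10 and then (16, 7): attention_decoder returns one output per decoder input (step2 = 10), each projected to output_size = 7. It should then print 1 and (16, 30): mem is the decoder's final GRU state, whose width is gru_hidden = 30.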

Adapted from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
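As a map for what the debugger will show inside attention_decoder: at each decoder step it computes Bahdanau-style scores s_i = v^T tanh(W1 h_i + W2 q), softmaxes them, and takes a weighted sum of the encoder outputs. Below is a minimal NumPy sketch of that step under those assumptions; attention_step and the names W1, W2, v are illustrative, and the real implementation applies W1 through a 1x1 conv2d over the attention states rather than a plain matmul.

import numpy as np

def attention_step(query, attn_states, W1, W2, v):
    # Illustrative sketch, not the TF code itself.
    # query:       [batch, state_size]             current decoder cell state
    # attn_states: [batch, attn_length, attn_size] stacked encoder outputs
    # W1: [attn_size, attn_vec_size]  W2: [state_size, attn_vec_size]  v: [attn_vec_size]
    hidden_features = attn_states @ W1             # [batch, attn_length, attn_vec_size]
    y = (query @ W2)[:, None, :]                   # [batch, 1, attn_vec_size]
    s = np.tanh(hidden_features + y) @ v           # [batch, attn_length] raw scores
    a = np.exp(s - s.max(axis=1, keepdims=True))
    a /= a.sum(axis=1, keepdims=True)              # softmax over encoder steps
    d = (a[:, :, None] * attn_states).sum(axis=1)  # [batch, attn_size] context vector
    return a, d

# Shapes matching the test case above: query (16, 30), attn_states (16, 20, 30).
rng = np.random.RandomState(0)
a, d = attention_step(rng.randn(16, 30), rng.randn(16, 20, 30),
                      rng.randn(30, 30), rng.randn(30, 30), rng.randn(30))
print(a.shape, d.shape)  # (16, 20) (16, 30)

In the TF implementation the context d is then mixed back in through linear layers: it is concatenated with the decoder input before the cell runs, and with the cell output before the final projection, which is where the output_size=7 projection in the test comes from.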