tf.contrib.legacy_seq2seq.basic_rnn_seq2seq: a minimal working example

Source: Internet   Editor: 程序博客网   Time: 2024/06/18 16:46


Function documentation: https://www.tensorflow.org/api_docs/python/tf/contrib/legacy_seq2seq/basic_rnn_seq2seq

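import tensorflow as tf  # requires TensorFlow 1.x; tf.contrib was removed in TensorFlow 2.x
import numpy as np

steps = 10        # sequence length (number of time steps)
batch_size = 10
input_size = 10   # feature size of each time step

# Batch-major placeholders: [batch, steps, input_size]
encoder_inputs = tf.placeholder(tf.float32, [None, steps, input_size])
decoder_inputs = tf.placeholder(tf.float32, [None, steps, input_size])

# All-zero dummy inputs; the layout must match the placeholders (batch first, then steps)
en_input = np.zeros(shape=[batch_size, steps, input_size])
de_input = np.zeros(shape=[batch_size, steps, input_size])

cell = tf.nn.rnn_cell.BasicLSTMCell(10)  # LSTM cell with 10 hidden units

def get_result(encoder_inputs, decoder_inputs, cell):
    # basic_rnn_seq2seq expects Python lists of `steps` tensors, each [batch, input_size],
    # so split the batch-major tensors along the time axis (axis=1)
    encoder_inputs = tf.unstack(encoder_inputs, axis=1)
    decoder_inputs = tf.unstack(decoder_inputs, axis=1)
    # Returns a tuple (outputs, state): outputs is a list of `steps` tensors of
    # shape [batch, 10]; state is the decoder's final LSTM state
    result = tf.contrib.legacy_seq2seq.basic_rnn_seq2seq(
        encoder_inputs,
        decoder_inputs,
        cell,
        dtype=tf.float32,
        scope=None
    )
    return result

result = get_result(encoder_inputs, decoder_inputs, cell)
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    result_value = sess.run(result, feed_dict={encoder_inputs: en_input, decoder_inputs: de_input})
    print(result_value)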
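To make the data flow of basic_rnn_seq2seq easier to follow without a TensorFlow 1.x install, here is a rough NumPy sketch. It is not the library's implementation, and it rests on stated assumptions: a plain tanh RNN cell (the hypothetical helper make_cell) stands in for BasicLSTMCell, so the state is a single array rather than an LSTM (c, h) pair, and the encoder and decoder get separate weights. basic_rnn_seq2seq_sketch is likewise a hypothetical name. It mirrors the documented behavior: the encoder reads every encoder input and keeps only its final state, and the decoder starts from that state and emits one output per decoder input. The list comprehensions play the role of tf.unstack(x, axis=1).

```python
import numpy as np

# Hypothetical sketch of what basic_rnn_seq2seq computes.
# Assumption: a simple tanh RNN cell replaces BasicLSTMCell.

def make_cell(input_size, num_units, rng):
    """Return a tanh RNN cell: h_new = tanh(x @ W + h @ U + b)."""
    W = rng.standard_normal((input_size, num_units)) * 0.1
    U = rng.standard_normal((num_units, num_units)) * 0.1
    b = np.zeros(num_units)

    def cell(x, h):
        h_new = np.tanh(x @ W + h @ U + b)
        return h_new, h_new  # (output, new_state)

    return cell

def basic_rnn_seq2seq_sketch(encoder_inputs, decoder_inputs, enc_cell, dec_cell, num_units):
    batch = encoder_inputs[0].shape[0]
    state = np.zeros((batch, num_units))
    # Encoder: consume every encoder step, keep only the final state
    for x in encoder_inputs:
        _, state = enc_cell(x, state)
    # Decoder: start from the encoder's final state, one output per decoder input
    outputs = []
    for x in decoder_inputs:
        out, state = dec_cell(x, state)
        outputs.append(out)
    return outputs, state

steps, batch_size, input_size, num_units = 10, 10, 10, 10
rng = np.random.default_rng(0)

# Batch-major arrays, like the placeholders' [None, steps, input_size]
en = rng.standard_normal((batch_size, steps, input_size))
de = rng.standard_normal((batch_size, steps, input_size))

# Same effect as tf.unstack(x, axis=1): `steps` arrays of shape [batch, input_size]
en_list = [en[:, t, :] for t in range(steps)]
de_list = [de[:, t, :] for t in range(steps)]

outputs, state = basic_rnn_seq2seq_sketch(
    en_list, de_list,
    make_cell(input_size, num_units, rng),
    make_cell(input_size, num_units, rng),
    num_units,
)
print(len(outputs), outputs[0].shape, state.shape)  # 10 (10, 10) (10, 10)
```

Each decoder step receives the supplied decoder input directly; basic_rnn_seq2seq does not feed the previous output back in. Because this tanh cell's output is its new state, the last output equals the final state here; with the LSTM cell used above, the returned state is an LSTMStateTuple instead.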

http://www.tensorflownews.com/
