tf.contrib.seq2seq.sequence_loss example:sequence loss 实例代码

来源:互联网 发布:淘宝购买充气娃娃图 编辑:程序博客网 时间:2024/06/05 05:20

tf.contrib.seq2seq.sequence_loss example:sequence loss 实例代码

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Minimal TensorFlow 1.x example of ``tf.contrib.seq2seq.sequence_loss``.

Builds an embedding RNN seq2seq model with
``tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq``, stacks its per-step
outputs into a single logits tensor, and evaluates the weighted sequence
loss on all-ones dummy data.
"""
import tensorflow as tf
import numpy as np

# NOTE: not referenced anywhere else in this script; kept only so the
# module-level name still exists.
params = np.random.normal(loc=0.0, scale=1.0, size=[10, 10])

# Graph inputs. No placeholder is created for the logits: they are produced
# by the model below, never fed.
encoder_inputs = tf.placeholder(dtype=tf.int32, shape=[10, 10])
decoder_inputs = tf.placeholder(dtype=tf.int32, shape=[10, 10])
targets = tf.placeholder(dtype=tf.int32, shape=[10, 10])
weights = tf.placeholder(dtype=tf.float32, shape=[10, 10])

# Dummy feed data: every token id is 1 and every position has weight 1.0.
train_encoder_inputs = np.ones(shape=[10, 10], dtype=np.int32)
train_decoder_inputs = np.ones(shape=[10, 10], dtype=np.int32)
train_weights = np.ones(shape=[10, 10], dtype=np.float32)

num_encoder_symbols = 10
num_decoder_symbols = 10
embedding_size = 10
cell = tf.nn.rnn_cell.BasicLSTMCell(10)


def seq2seq(encoder_inputs, decoder_inputs, cell, num_encoder_symbols,
            num_decoder_symbols, embedding_size):
    """Build the embedding RNN seq2seq graph and return its outputs.

    Args:
        encoder_inputs: 2-D int32 tensor; it is unstacked along axis 0 into
            a list of 1-D tensors before being handed to the model.
        decoder_inputs: 2-D int32 tensor, unstacked the same way.
        cell: RNN cell instance passed through to the model.
        num_encoder_symbols: Encoder vocabulary size.
        num_decoder_symbols: Decoder vocabulary size.
        embedding_size: Embedding dimension.

    Returns:
        The first element of the ``(outputs, state)`` pair returned by
        ``embedding_rnn_seq2seq``; the state is discarded.
    """
    # The legacy seq2seq API takes Python lists of per-step tensors, so
    # axis 0 of each input is presumably the time axis -- TODO confirm.
    encoder_inputs = tf.unstack(encoder_inputs, axis=0)
    decoder_inputs = tf.unstack(decoder_inputs, axis=0)
    results, states = tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq(
        encoder_inputs,
        decoder_inputs,
        cell,
        num_encoder_symbols,
        num_decoder_symbols,
        embedding_size,
        output_projection=None,
        feed_previous=False,
        dtype=None,
        scope=None,
    )
    return results


def get_loss(logits, targets, weights):
    """Return ``tf.contrib.seq2seq.sequence_loss`` for the given tensors.

    Args:
        logits: 3-D float tensor of unnormalised scores.
        targets: 2-D int32 tensor of target ids.
        weights: 2-D float tensor of per-position weights.

    Returns:
        The loss tensor produced by ``sequence_loss`` with its default
        averaging options.
    """
    loss = tf.contrib.seq2seq.sequence_loss(
        logits,
        targets=targets,
        weights=weights,
    )
    return loss


results = seq2seq(encoder_inputs, decoder_inputs, cell, num_encoder_symbols,
                  num_decoder_symbols, embedding_size)

# NOTE(review): stacking on axis 0 keeps the layout produced by the unstack
# calls above (step-major), while sequence_loss is documented with
# batch-major logits. Every dimension here is 10 and targets/weights use the
# same layout, so shapes line up -- confirm the intended layout before
# reusing this with real data.
logits = tf.stack(results, axis=0)
print(logits)
loss = get_loss(logits, targets, weights)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Evaluate the per-step outputs on their own first.
    results_value = sess.run(
        results,
        feed_dict={
            encoder_inputs: train_encoder_inputs,
            decoder_inputs: train_decoder_inputs,
        },
    )
    print(type(results_value[0]))
    print(len(results_value))

    # The decoder inputs double as the targets in this toy example.
    cost = sess.run(
        loss,
        feed_dict={
            encoder_inputs: train_encoder_inputs,
            targets: train_decoder_inputs,
            weights: train_weights,
            decoder_inputs: train_decoder_inputs,
        },
    )
    print(cost)

更多资源,代码,教程:
http://www.tensorflownews.com/

原创粉丝点击