tf.contrib.seq2seq.sequence_loss example: sequence loss sample code
Source: Internet | Editor: 程序博客网 | Time: 2024/06/05 05:20
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Requires TensorFlow 1.x (tf.contrib and tf.placeholder do not exist in TF 2.x).
import tensorflow as tf
import numpy as np

# legacy_seq2seq is time-major: each [10, 10] input is read as [max_time, batch_size].
encoder_inputs = tf.placeholder(dtype=tf.int32, shape=[10, 10])
decoder_inputs = tf.placeholder(dtype=tf.int32, shape=[10, 10])
# sequence_loss expects batch-major targets and weights: [batch_size, sequence_length].
targets = tf.placeholder(dtype=tf.int32, shape=[10, 10])
weights = tf.placeholder(dtype=tf.float32, shape=[10, 10])

# Toy data: every token id is 1 and every position has weight 1.
train_encoder_inputs = np.ones(shape=[10, 10], dtype=np.int32)
train_decoder_inputs = np.ones(shape=[10, 10], dtype=np.int32)
train_weights = np.ones(shape=[10, 10], dtype=np.float32)

num_encoder_symbols = 10
num_decoder_symbols = 10
embedding_size = 10
cell = tf.nn.rnn_cell.BasicLSTMCell(10)


def seq2seq(encoder_inputs, decoder_inputs, cell,
            num_encoder_symbols, num_decoder_symbols, embedding_size):
    # Split [max_time, batch_size] into a list of max_time tensors of shape [batch_size].
    encoder_inputs = tf.unstack(encoder_inputs, axis=0)
    decoder_inputs = tf.unstack(decoder_inputs, axis=0)
    # Returns a list of max_time logits tensors, each [batch_size, num_decoder_symbols].
    results, states = tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq(
        encoder_inputs,
        decoder_inputs,
        cell,
        num_encoder_symbols,
        num_decoder_symbols,
        embedding_size,
        output_projection=None,
        feed_previous=False,
        dtype=None,
        scope=None)
    return results


def get_loss(logits, targets, weights):
    # Weighted softmax cross-entropy, averaged over batch and time by default.
    loss = tf.contrib.seq2seq.sequence_loss(
        logits,
        targets=targets,
        weights=weights)
    return loss


results = seq2seq(encoder_inputs, decoder_inputs, cell,
                  num_encoder_symbols, num_decoder_symbols, embedding_size)
# Stack along axis 1 to get [batch_size, sequence_length, num_decoder_symbols].
logits = tf.stack(results, axis=1)
print(logits)
loss = get_loss(logits, targets, weights)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    results_value = sess.run(results, feed_dict={encoder_inputs: train_encoder_inputs,
                                                  decoder_inputs: train_decoder_inputs})
    print(type(results_value[0]))  # each step's logits come back as a NumPy array
    print(len(results_value))      # one entry per decoder time step
    # Transpose the time-major decoder ids into batch-major targets.
    # (In a real model the targets are the decoder inputs shifted left by one step.)
    cost = sess.run(loss, feed_dict={encoder_inputs: train_encoder_inputs,
                                     decoder_inputs: train_decoder_inputs,
                                     targets: train_decoder_inputs.T,
                                     weights: train_weights})
    print(cost)
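To make the printed cost concrete, here is a minimal NumPy sketch of what tf.contrib.seq2seq.sequence_loss computes with its default arguments. It takes the softmax cross-entropy of each token, multiplies it by the weights, sums the result, and divides by the sum of the weights. The function name sequence_loss_np and the 1e-12 epsilon in the denominator are assumptions made for illustration. This is not the TensorFlow source, so check it against the documentation for the TF version you actually use.

```python
import numpy as np


def sequence_loss_np(logits, targets, weights,
                     average_across_timesteps=True,
                     average_across_batch=True):
    """Hypothetical NumPy sketch of weighted sequence cross-entropy.

    logits:  [batch_size, sequence_length, num_classes] floats
    targets: [batch_size, sequence_length] integer class ids
    weights: [batch_size, sequence_length] floats (0 masks a position)
    """
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Log-probability of the target class at every (batch, time) position.
    picked = np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    crossent = -picked * weights  # per-token weighted cross-entropy
    eps = 1e-12  # assumed epsilon; avoids division by zero for all-zero weights
    if average_across_timesteps and average_across_batch:
        return crossent.sum() / (weights.sum() + eps)
    if average_across_timesteps:
        # One value per sequence: shape [batch_size].
        return crossent.sum(axis=1) / (weights.sum(axis=1) + eps)
    if average_across_batch:
        # One value per time step: shape [sequence_length].
        return crossent.sum(axis=0) / (weights.sum(axis=0) + eps)
    return crossent


if __name__ == "__main__":
    # Same shapes as the post: batch 10, 10 time steps, 10 symbols,
    # all-ones targets and weights; the logits here are just random numbers.
    rng = np.random.RandomState(0)
    demo_logits = rng.normal(size=(10, 10, 10))
    demo_targets = np.ones((10, 10), dtype=np.int64)
    demo_weights = np.ones((10, 10))
    print(sequence_loss_np(demo_logits, demo_targets, demo_weights))
```

With uniform (all-zero) logits the loss equals log(num_classes). Positions whose weight is 0 drop out of both the numerator and the denominator, which is how padded time steps are usually masked.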
More resources, code, and tutorials:
http://www.tensorflownews.com/