TensorFlow (1): tf.contrib.seq2seq.GreedyEmbeddingHelper
Introduction
I have recently been working on seq2seq with TensorFlow and ran into quite a few problems. The first was deciding between tf.contrib.seq2seq and tf.contrib.legacy_seq2seq. The latest API docs show that tf.contrib.legacy_seq2seq has been deprecated, so the natural choice seems to be tf.contrib.seq2seq. Unfortunately, the examples on GitHub and CSDN are all written against tf.contrib.legacy_seq2seq, and running the tf.contrib.legacy_seq2seq example under tensorflow/models fails with "can't pickle _thread.lock objects". In the spirit of meeting difficulty head-on, I started exploring tf.contrib.seq2seq and recording the pitfalls I hit along the way. For brevity, unqualified names below refer to tf.contrib.seq2seq; for example, GreedyEmbeddingHelper means tf.contrib.seq2seq.GreedyEmbeddingHelper.
System environment
>>> import sys
>>> import tensorflow as tf
>>> print(sys.version)
3.6.0 |Anaconda 4.3.1 (64-bit)| (default, Dec 23 2016, 12:22:00)
[GCC 4.4.7 20120313 (Red Hat 4.4.7-1)]
>>> print(tf.__version__)
1.3.0
GreedyEmbeddingHelper
This section records the pitfalls I hit while using GreedyEmbeddingHelper.
Introduction
Introducing GreedyEmbeddingHelper has to start with Helper, because every "...Helper" class derives from it. Helper is the interface for decoder sampling in seq2seq, and its instances are consumed by BasicDecoder. In short, the developers abstracted the decoder's sampling step out of the decoding loop so that later users can reuse it (hats off to them!). From Helper, TensorFlow derives a family of classes, structured as shown below, where the red lines indicate inheritance. Classes represented by TrainingHelper drive the training process, in both scheduled and non-scheduled variants; classes represented by GreedyEmbeddingHelper perform greedy decoding and are typically used for prediction.
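To make the helper's job concrete, here is a minimal NumPy sketch (not the TensorFlow API) of what a greedy embedding helper does at each decoding step: it embeds the previous step's prediction, runs one decoder step, takes the argmax of the logits, and stops when the end token appears. The toy cell and one-hot embedding matrix below are purely illustrative assumptions.

```python
import numpy as np

def greedy_decode(step_fn, embedding, sos_id, eos_id, max_steps):
    """Greedy decoding loop: at each step, embed the previous
    prediction, run one decoder step, and pick the argmax token."""
    token = sos_id
    state = None
    output = []
    for _ in range(max_steps):
        logits, state = step_fn(embedding[token], state)  # one decoder step
        token = int(np.argmax(logits))                    # greedy sampling
        if token == eos_id:                               # the helper's stop test
            break
        output.append(token)
    return output

# Toy "decoder cell": always predicts (previous token id + 1).
vocab_size = 6
embedding = np.eye(vocab_size)  # illustrative one-hot embedding matrix

def toy_step(inp, state):
    prev = int(np.argmax(inp))  # recover previous token from its embedding
    logits = np.zeros(vocab_size)
    logits[(prev + 1) % vocab_size] = 1.0
    return logits, state

print(greedy_decode(toy_step, embedding, sos_id=0, eos_id=5, max_steps=10))
# -> [1, 2, 3, 4]
```

Note that the loop only ever consumes the embedding matrix and its own previous prediction; no ground-truth decoder inputs appear anywhere, which is the key difference from a training-time helper.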
Error
The problem I hit while using GreedyEmbeddingHelper was:
Traceback (most recent call last):
  File "tutorial#2.py", line 217, in <module>
    model = Model(vocab_size)
  File "tutorial#2.py", line 172, in __init__
    output_time_major=False)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 286, in dynamic_decode
    swap_memory=swap_memory)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2775, in while_loop
    result = context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2604, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2554, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 234, in body
    decoder_finished) = decoder.step(time, inputs, state)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py", line 139, in step
    cell_outputs, cell_state = self._cell(inputs, state)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 180, in __call__
    return super(RNNCell, self).__call__(inputs, state)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/python/layers/base.py", line 450, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 401, in call
    concat = _linear([inputs, h], 4 * self._num_units, True)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 1027, in _linear
    "but saw %s" % (shape, shape[1]))
ValueError: linear expects shape[1] to be provided for shape (3, ?), but saw ?
Cause
While debugging, the error was traced to line 172, which corresponds to this code:
pred_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
    decoder=decoder,
    maximum_iterations=30,
    impute_finished=False,
    output_time_major=False)
At first I kept checking the TensorFlow API, suspecting the call signature was wrong, but after a long look the definition turned out to be fine. I then began to suspect the decoder itself, so I inspected its construction:
decoder = tf.contrib.seq2seq.BasicDecoder(
    cell=self.decoder_cell,
    helper=pred_helper,
    initial_state=encoder_final_state)
Again the API checked out, so by elimination I guessed that pred_helper was the problem. It was defined as:
self.decoder_inputs = tf.placeholder(
    shape=(None, None), dtype=tf.int32, name='decoder_inputs')
pred_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
    self.decoder_inputs,
    start_tokens=tf.fill([batch_size], SOS_ID),
    end_token=EOS_ID)
Anyone familiar with the API will spot it at a glance: the first argument of GreedyEmbeddingHelper is the embedding, not the decoder's inputs, which is why the tensor shapes never matched. From the theory of seq2seq, this clearly should not be decoder_inputs anyway, since at prediction time the target sequence is unknown. This bug taught me a lot, haha...
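For reference, a corrected construction might look like the sketch below, written against the TF 1.3 contrib API. The illustrative sizes (vocab_size, embed_dim, batch_size) and the SOS_ID/EOS_ID values are assumptions here; in practice, embedding_matrix should be the same matrix used to embed the decoder inputs during training.

```python
import tensorflow as tf

# Illustrative values; in the original model these come from the data.
vocab_size, embed_dim, batch_size = 10000, 128, 32
SOS_ID, EOS_ID = 1, 2

# The first argument must be the embedding matrix (or a callable that
# maps ids to embeddings), NOT the decoder inputs: at prediction time
# the helper embeds its own previous predictions.
embedding_matrix = tf.get_variable(
    "embedding", shape=[vocab_size, embed_dim], dtype=tf.float32)

pred_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
    embedding=embedding_matrix,
    start_tokens=tf.fill([batch_size], SOS_ID),
    end_token=EOS_ID)
```

With the embedding matrix in place, the helper can look up a fixed-width embedding at every step, so the `(3, ?)` shape that _linear complained about becomes fully defined.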