TensorFlow (Part 1): tf.contrib.seq2seq.GreedyEmbeddingHelper


Introduction

I have recently been working on seq2seq with TensorFlow and ran into quite a few problems. The first question is which to use: tf.contrib.seq2seq or tf.contrib.legacy_seq2seq? Checking the latest API shows that tf.contrib.legacy_seq2seq has been deprecated, so the obvious choice seems to be tf.contrib.seq2seq. Unfortunately, almost every example on GitHub and CSDN uses tf.contrib.legacy_seq2seq, and running the tf.contrib.legacy_seq2seq example under tensorflow/models fails with "can't pickle _thread.lock objects". In the spirit of tackling the problem head-on, I started exploring tf.contrib.seq2seq and recording the pitfalls I stepped in along the way. For brevity, unprefixed names in the rest of this post refer to tf.contrib.seq2seq; for example, GreedyEmbeddingHelper means tf.contrib.seq2seq.GreedyEmbeddingHelper.

Environment

>>> import sys
>>> import tensorflow as tf
>>> print(sys.version)
3.6.0 |Anaconda 4.3.1 (64-bit)| (default, Dec 23 2016, 12:22:00)
[GCC 4.4.7 20120313 (Red Hat 4.4.7-1)]
>>> print(tf.__version__)
1.3.0

GreedyEmbeddingHelper

This section records the pitfalls I hit while using GreedyEmbeddingHelper.

Overview

Introducing GreedyEmbeddingHelper has to start with Helper, because every "...Helper" class derives from it. Helper is the interface for decoder sampling in seq2seq, and its instances are invoked by BasicDecoder. In short, the developers abstracted the decoder sampling procedure so that later users can reuse it (nicely done!). Building on Helper, TensorFlow derives a number of classes, structured as shown below.

[Figure: inheritance hierarchy of the Helper classes]

The red lines indicate inheritance. The classes represented by TrainingHelper drive the training process, covering both the scheduled-sampling and non-scheduled variants. The classes represented by GreedyEmbeddingHelper perform greedy decoding and are typically used for prediction.
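As a rough illustration of how the two branches are usually wired up, here is a minimal sketch against the TF 1.3 API. The sizes, placeholder names, the BasicLSTMCell, and the Dense output layer are my own illustrative choices, not code from the original model.

import tensorflow as tf
from tensorflow.python.layers.core import Dense

batch_size, vocab_size, embed_dim, hidden_size = 32, 1000, 64, 128
SOS_ID, EOS_ID = 1, 2

embedding_matrix = tf.get_variable('embedding', [vocab_size, embed_dim])
decoder_cell = tf.contrib.rnn.BasicLSTMCell(hidden_size)
# Stand-in for the real encoder final state.
encoder_final_state = decoder_cell.zero_state(batch_size, tf.float32)

# Training branch: feed the ground-truth target sequence as decoder inputs.
target_ids = tf.placeholder(tf.int32, [None, None], name='target_ids')
target_lengths = tf.placeholder(tf.int32, [None], name='target_lengths')
train_helper = tf.contrib.seq2seq.TrainingHelper(
    inputs=tf.nn.embedding_lookup(embedding_matrix, target_ids),
    sequence_length=target_lengths)

# Inference branch: at every step take the argmax token and embed it.
infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
    embedding=embedding_matrix,
    start_tokens=tf.fill([batch_size], SOS_ID),
    end_token=EOS_ID)

# Either helper plugs into the same BasicDecoder / dynamic_decode pipeline.
decoder = tf.contrib.seq2seq.BasicDecoder(
    cell=decoder_cell,
    helper=infer_helper,                 # swap in train_helper for training
    initial_state=encoder_final_state,
    output_layer=Dense(vocab_size))      # projects cell output onto the vocabulary
outputs, final_state, lengths = tf.contrib.seq2seq.dynamic_decode(
    decoder, maximum_iterations=30)

The point of the sketch is that only the helper changes between training and prediction; the decoder cell, initial state, and dynamic_decode call stay the same.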

The error

The problem I ran into while using GreedyEmbeddingHelper was:

Traceback (most recent call last):
  File "tutorial#2.py", line 217, in <module>
    model = Model(vocab_size)
  File "tutorial#2.py", line 172, in __init__
    output_time_major=False)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 286, in dynamic_decode
    swap_memory=swap_memory)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2775, in while_loop
    result = context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2604, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2554, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 234, in body
    decoder_finished) = decoder.step(time, inputs, state)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py", line 139, in step
    cell_outputs, cell_state = self._cell(inputs, state)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 180, in __call__
    return super(RNNCell, self).__call__(inputs, state)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/python/layers/base.py", line 450, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 401, in call
    concat = _linear([inputs, h], 4 * self._num_units, True)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 1027, in _linear
    "but saw %s" % (shape, shape[1]))
ValueError: linear expects shape[1] to be provided for shape (3, ?), but saw ?

The cause

Debugging located the error at line 172, which corresponds to this code:

pred_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
    decoder=decoder, maximum_iterations=30,
    impute_finished=False,
    output_time_major=False)

At first I kept digging through the TensorFlow API, suspecting the call itself was wrong, but after a long search the definition turned out to be fine. At that point I began to suspect the decoder, so I followed that lead and checked it:

decoder = tf.contrib.seq2seq.BasicDecoder(
    cell=self.decoder_cell,
    helper=pred_helper,
    initial_state=encoder_final_state)

Again I checked the API first and confirmed this definition was fine, then by process of elimination guessed that pred_helper was the problem. It was defined as follows:

self.decoder_inputs = tf.placeholder(
    shape=(None, None), dtype=tf.int32, name='decoder_inputs')
pred_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
    self.decoder_inputs,
    start_tokens=tf.fill([batch_size], SOS_ID),
    end_token=EOS_ID)

Anyone familiar with the API will spot it at a glance: the first argument of GreedyEmbeddingHelper is the embedding, not the decoder inputs, which is why the tensor shapes never matched. From a seq2seq-theory point of view it clearly should not be decoder_inputs either, since the target sequence is unknown at prediction time. I learned a lot from this bug, haha...
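For reference, a hedged sketch of the corrected definition. The embedding_matrix variable and its shape are assumptions of mine, since the original post does not show how the embedding was defined; GreedyEmbeddingHelper also accepts a callable that maps token ids to embeddings.

# Corrected version (sketch): pass the embedding matrix, not decoder_inputs.
# embedding_matrix is assumed to be a [vocab_size, embed_dim] variable.
embedding_matrix = tf.get_variable(
    'embedding', shape=[vocab_size, embed_dim], dtype=tf.float32)

pred_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
    embedding_matrix,                      # looked up internally at every step
    start_tokens=tf.fill([batch_size], SOS_ID),
    end_token=EOS_ID)

# No decoder_inputs placeholder is needed at prediction time: the helper feeds
# each step with the embedding of the previously sampled (argmax) token.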
