seq2seq 做翻译,预测时 用argmax的原因

来源:互联网 发布:小号男士衣服淘宝店铺 编辑:程序博客网 时间:2024/05/22 12:56

在这里https://github.com/tensorflow/models/blob/master/tutorials/rnn/translate/translate.py#L281-L282

# This is a greedy decoder - outputs are just argmaxes of output_logits.outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]

最初很难理解,其实就是在所有词里选概率最大的

因为这段代码的存在https://github.com/tensorflow/models/blob/master/tutorials/rnn/translate/seq2seq_model.py#L168-L173

      if output_projection is not None:        for b in xrange(len(buckets)):          self.outputs[b] = [              tf.matmul(output, output_projection[0]) + output_projection[1]              for output in self.outputs[b]          ]

里面output的维数就是embedding之后的维数,比如1024
output_projection的维数是[1024,vocabulary_size]

0 0