seq2seq


seq2seq is a general encoder-decoder framework that can be used for machine translation, text summarization, conversational modeling, and image captioning.

Source code

https://github.com/google/seq2seq
The implementation that ships with TensorFlow is at:
tensorflow/tensorflow/python/ops/seq2seq.py

Basic model

The basic model follows the paper https://arxiv.org/pdf/1409.3215.pdf; its structure is shown in Figure 1.

[Figure 1: the basic seq2seq encoder-decoder]

In the figure, the encoder reads the input sequence [A, B, C] and the decoder emits the output sequence [W, X, Y, Z, <eos>].
<eos> is short for "end of sentence"; early seq2seq work on statistical machine translation used it to mark the end of the input sentence.

LSTMCell

encoder_cell = tf.contrib.rnn.LSTMCell(encoder_hidden_units)

tf.contrib.rnn.LSTMCell is a class; both it and tf.contrib.rnn.core_rnn_cell.LSTMCell are defined in

tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py

Its __init__ method looks like this:

__init__

__init__(
    num_units,
    input_size=None,
    use_peepholes=False,
    cell_clip=None,
    initializer=None,
    num_proj=None,
    proj_clip=None,
    num_unit_shards=None,
    num_proj_shards=None,
    forget_bias=1.0,
    state_is_tuple=True,
    activation=tf.tanh,
    reuse=None)

num_units: the total number of units in the LSTM cell.
use_peepholes: whether to use diagonal (peephole) connections; the peephole variant follows the paper
https://research.google.com/pubs/archive/43905.pdf
cell_clip: a float; if provided, the cell state is clipped to this value before being passed to the cell activation function.
initializer: the initializer to use for the weight and projection matrices.
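To make these arguments concrete, here is a sketch of constructing a cell that exercises them; the values are arbitrary illustrations, not taken from the script discussed later:

import tensorflow as tf

# Illustrative values only: a 128-unit cell with peephole connections,
# cell state clipped to [-3.0, 3.0], and outputs projected down to 64 dims.
cell = tf.contrib.rnn.LSTMCell(
    num_units=128,
    use_peepholes=True,                                   # diagonal peephole connections
    cell_clip=3.0,                                        # clip cell state before activation
    initializer=tf.random_uniform_initializer(-0.1, 0.1), # weight/projection initializer
    num_proj=64,                                          # optional output projection
    forget_bias=1.0,
    state_is_tuple=True)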

The call method

call(
    inputs,
    state,
    scope=None)

Runs one step of the LSTM.

zero_state

zero_state(
    batch_size,
    dtype)

Returns zero-filled state tensor(s) for a batch of the given size.
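Putting zero_state and call together, here is a minimal sketch of driving a cell by hand for a single time step (the batch size and input width are made up for the example):

import tensorflow as tf

batch_size, num_units = 4, 20
cell = tf.contrib.rnn.LSTMCell(num_units)
x_t = tf.zeros([batch_size, num_units])          # dummy input for one time step

state = cell.zero_state(batch_size, tf.float32)  # LSTMStateTuple(c, h), all zeros
output, new_state = cell(x_t, state)             # one LSTM step
# output: [batch_size, num_units]; new_state: an LSTMStateTuple of the same shapes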

core_rnn_cell_impl.py

class BasicRNNCell(RNNCell):
class GRUCell(RNNCell):
class BasicLSTMCell(RNNCell):
class LSTMCell(RNNCell):
class _SlimRNNCell(RNNCell):

As these declarations show, whatever the RNN variant, each cell inherits from the abstract class RNNCell and implements the methods that class defines.
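As a sketch of that contract (this class is illustrative, not code from core_rnn_cell_impl.py), a minimal custom cell only has to define state_size, output_size, and __call__, following the pre-1.2 convention of overriding __call__ directly:

import tensorflow as tf
from tensorflow.contrib.rnn import RNNCell

class PlainRNNCell(RNNCell):
    """Illustrative vanilla RNN cell: h' = tanh(W * [x, h] + b)."""

    def __init__(self, num_units):
        super(PlainRNNCell, self).__init__()
        self._num_units = num_units

    @property
    def state_size(self):
        return self._num_units

    @property
    def output_size(self):
        return self._num_units

    def __call__(self, inputs, state, scope=None):
        with tf.variable_scope(scope or "plain_rnn_cell"):
            concat = tf.concat([inputs, state], 1)
            h = tf.tanh(tf.contrib.layers.linear(concat, self._num_units))
        return h, h  # for a vanilla RNN, the output and the new state coincide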

By default, LSTMCell implements the non-peephole architecture, i.e. the structure of the following paper:

  S. Hochreiter and J. Schmidhuber.  "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.

When use_peepholes is set to True, the structure follows:

https://research.google.com/pubs/archive/43905.pdf

Initialization simply stores the constructor arguments:

self._num_units = num_units            # number of units in the cell
self._use_peepholes = use_peepholes
self._cell_clip = cell_clip
self._initializer = initializer
...

call is the most important method here; in what follows we assume use_peepholes is set to False. Its signature and docstring are:

def __call__(self, inputs, state, scope=None):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, batch x num_units.
      state: if `state_is_tuple` is False, this must be a state Tensor,
        `2-D, batch x state_size`.  If `state_is_tuple` is True, this must be a
        tuple of state Tensors, both `2-D`, with column sizes `c_state` and
        `m_state`.
      scope: VariableScope for the created subgraph; defaults to "lstm_cell".

    Returns:
      A tuple containing:
      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of LSTM after reading `inputs` when
        the previous state was `state`.  Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
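Although the body is omitted above, with use_peepholes=False it computes the standard LSTM update. Here is a NumPy sketch of one step, mirroring the fused i, j, f, o gate layout the TensorFlow code uses; W and b stand for the cell's concatenated weights (shape [input_size + num_units, 4 * num_units]) and biases (shape [4 * num_units]):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, h_prev, c_prev, W, b, forget_bias=1.0):
    # One fused matmul over [inputs, previous hidden state], then split
    # into the four gates in the i, j, f, o order used by TensorFlow.
    gates = np.concatenate([x, h_prev], axis=1).dot(W) + b
    i, j, f, o = np.split(gates, 4, axis=1)
    c = sigmoid(f + forget_bias) * c_prev + sigmoid(i) * np.tanh(j)  # new cell state
    h = sigmoid(o) * np.tanh(c)                                      # new hidden state / output
    return h, c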

dynamic_rnn

encoder_outputs, encoder_final_state = tf.nn.dynamic_rnn(
    encoder_cell, encoder_inputs_embedded, dtype=tf.float32, time_major=True)
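With time_major=True, encoder_inputs_embedded has shape [max_time, batch_size, input_embedding_size], encoder_outputs comes back as [max_time, batch_size, num_units], and encoder_final_state is an LSTMStateTuple(c, h) whose parts are both [batch_size, num_units]. A quick static-shape check, assuming the sizes used in the script below (embedding and hidden layer both 20 wide):

print(encoder_inputs_embedded.get_shape())  # (?, ?, 20) -> [max_time, batch, embedding]
print(encoder_outputs.get_shape())          # (?, ?, 20) -> [max_time, batch, num_units]
print(encoder_final_state.c.get_shape())    # (?, 20)    -> [batch, num_units]
print(encoder_final_state.h.get_shape())    # (?, 20)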

Computation graph

[Figure: TensorBoard visualization of the computation graph]

source code

1-seq2seq.py

import numpy as np
import tensorflow as tf
import helpers
import matplotlib.pyplot as plt

sess = tf.InteractiveSession()

flags = tf.flags
flags.DEFINE_string("save_path", None, "Model output directory")
FLAGS = flags.FLAGS

PAD = 0
EOS = 1

def main(_):
    vocab_size = 10
    input_embedding_size = 20
    encoder_hidden_units = 20
    decoder_hidden_units = 20

    # Placeholders are [max_time, batch_size] matrices of token ids.
    with tf.name_scope('encoder_inputs'):
        encoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='encoder_inputs')
    with tf.name_scope('decoder_targets'):
        decoder_targets = tf.placeholder(shape=(None, None), dtype=tf.int32, name='decoder_targets')
    with tf.name_scope('decoder_inputs'):
        decoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='decoder_inputs')

    # One embedding matrix shared by encoder and decoder tokens.
    with tf.name_scope('embeddings'):
        embeddings = tf.Variable(tf.random_uniform([vocab_size, input_embedding_size], -0.1, 1.0), dtype=tf.float32)
    with tf.name_scope('encoder_inputs_embedded'):
        encoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, encoder_inputs)
    with tf.name_scope('decoder_inputs_embedded'):
        decoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, decoder_inputs)

    with tf.name_scope('encoder_cell'):
        encoder_cell = tf.contrib.rnn.LSTMCell(encoder_hidden_units)
    with tf.name_scope('encoder_dynamic'):
        encoder_outputs, encoder_final_state = tf.nn.dynamic_rnn(
            encoder_cell, encoder_inputs_embedded, dtype=tf.float32, time_major=True)
    # Only the encoder's final state is passed on to the decoder.
    del encoder_outputs

    with tf.name_scope('decoder_cell'):
        decoder_cell = tf.contrib.rnn.LSTMCell(decoder_hidden_units)
    with tf.name_scope('decoder_dynamic'):
        decoder_outputs, decoder_final_state = tf.nn.dynamic_rnn(
            decoder_cell, decoder_inputs_embedded, initial_state=encoder_final_state,
            dtype=tf.float32, time_major=True, scope="plain_decoder")

    with tf.name_scope('decoder_logits'):
        decoder_logits = tf.contrib.layers.linear(decoder_outputs, vocab_size)
    with tf.name_scope('decoder_prediction'):
        decoder_prediction = tf.argmax(decoder_logits, 2)

    with tf.name_scope('stepwise_cross_entropy'):
        stepwise_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
            labels=tf.one_hot(decoder_targets, depth=vocab_size, dtype=tf.float32),
            logits=decoder_logits)
    with tf.name_scope('loss'):
        loss = tf.reduce_mean(stepwise_cross_entropy)
    with tf.name_scope('train_op'):
        train_op = tf.train.AdamOptimizer().minimize(loss)

    sess.run(tf.global_variables_initializer())

    sv = tf.train.Supervisor(logdir=FLAGS.save_path)
    with sv.managed_session() as session:
        # Smoke test on a tiny hand-made batch.
        batch_ = [[6], [3, 4], [9, 8, 7]]
        batch_, batch_length_ = helpers.batch(batch_)
        # print('batch_encoded:\n' + str(batch_))
        din_, dlen_ = helpers.batch(np.ones(shape=(3, 1), dtype=np.int32), max_sequence_length=4)
        # print('decoder inputs:\n' + str(din_))
        pred_ = sess.run(decoder_prediction, feed_dict={
            encoder_inputs: batch_,
            decoder_inputs: din_,
        })
        if FLAGS.save_path:
            sv.saver.save(session, FLAGS.save_path, global_step=sv.global_step)
        # print('decoder predictions:\n' + str(pred_))

    batch_size = 100
    batches = helpers.random_sequences(length_from=3, length_to=8,
                                       vocab_lower=2, vocab_upper=10,
                                       batch_size=batch_size)
    print('head of the batch:')
    # for seq in next(batches)[:10]:
    #     print(seq)

    def next_feed():
        batch = next(batches)
        # print('next_feed batch EOS:{}'.format(EOS))
        for seq in batch:
            print(seq)
        encoder_inputs_, _ = helpers.batch(batch)
        # print('encode_input {}'.format(encoder_inputs_))
        decoder_targets_, _ = helpers.batch(
            [sequence + [EOS] for sequence in batch]
        )
        # print('decoder_targets_ {}'.format(decoder_targets_))
        decoder_inputs_, _ = helpers.batch(
            [[EOS] + sequence for sequence in batch]
        )
        # print('decode_input {}'.format(decoder_inputs_))
        return {
            encoder_inputs: encoder_inputs_,
            decoder_inputs: decoder_inputs_,
            decoder_targets: decoder_targets_,
        }

    loss_track = []
    max_batches = 3001
    batches_in_epoch = 1000
    try:
        for batch in range(max_batches):
            fd = next_feed()
            _, l = sess.run([train_op, loss], fd)
            loss_track.append(l)
            if batch == 0 or batch % batches_in_epoch == 0:
                # print('batch {}'.format(batch))
                # print('minibatch loss: {}'.format(sess.run(loss, fd)))
                predict_ = sess.run(decoder_prediction, fd)
                for i, (inp, pred) in enumerate(zip(fd[encoder_inputs].T, predict_.T)):
                    print(' sample {}:'.format(i + 1))
                    print('  input     > {}'.format(inp))
                    print('  predicted > {}'.format(pred))
                    if i >= 2:
                        break
    except KeyboardInterrupt:
        print('training interrupted')

    plt.plot(loss_track)
    # plt.show()
    print('loss {:.4f} after {} examples (batch_size={})'.format(
        loss_track[-1], len(loss_track) * batch_size, batch_size))

if __name__ == "__main__":
    tf.app.run()

helpers.py

import numpy as np

def batch(inputs, max_sequence_length=None):
    """
    Args:
        inputs:
            list of sentences (integer lists)
        max_sequence_length:
            integer specifying how large should `max_time` dimension be.
            If None, maximum sequence length would be used

    Outputs:
        inputs_time_major:
            input sentences transformed into time-major matrix
            (shape [max_time, batch_size]) padded with 0s
        sequence_lengths:
            batch-sized list of integers specifying amount of active
            time steps in each input sequence
    """
    sequence_lengths = [len(seq) for seq in inputs]
    batch_size = len(inputs)
    if max_sequence_length is None:
        max_sequence_length = max(sequence_lengths)
    inputs_batch_major = np.zeros(shape=[batch_size, max_sequence_length], dtype=np.int32)  # == PAD
    for i, seq in enumerate(inputs):
        for j, element in enumerate(seq):
            inputs_batch_major[i, j] = element
    # [batch_size, max_time] -> [max_time, batch_size]
    inputs_time_major = inputs_batch_major.swapaxes(0, 1)
    return inputs_time_major, sequence_lengths

def random_sequences(length_from, length_to,
                     vocab_lower, vocab_upper,
                     batch_size):
    """ Generates batches of random integer sequences,
        sequence length in [length_from, length_to],
        vocabulary in [vocab_lower, vocab_upper]
    """
    if length_from > length_to:
        raise ValueError('length_from > length_to')

    def random_length():
        if length_from == length_to:
            return length_from
        return np.random.randint(length_from, length_to + 1)

    while True:
        yield [
            np.random.randint(low=vocab_lower,
                              high=vocab_upper,
                              size=random_length()).tolist()
            for _ in range(batch_size)
        ]

input and target

The next_feed function builds a feed dict with three entries:

encoder_inputs: encoder_inputs_,
decoder_inputs: decoder_inputs_,
decoder_targets: decoder_targets_,
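To see what these three matrices look like, here is a hypothetical two-sequence batch pushed through helpers.batch the way next_feed does (PAD = 0, EOS = 1):

batch = [[4, 5, 6], [7, 8]]

encoder_inputs_, _  = helpers.batch(batch)
decoder_targets_, _ = helpers.batch([seq + [EOS] for seq in batch])
decoder_inputs_, _  = helpers.batch([[EOS] + seq for seq in batch])

# All three come back time-major ([max_time, batch_size]):
#
#   encoder_inputs_    decoder_inputs_    decoder_targets_
#   [[4 7]             [[1 1]             [[4 7]
#    [5 8]              [4 7]              [5 8]
#    [6 0]]             [5 8]              [6 1]
#                       [6 0]]             [1 0]]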

First, encoder_inputs_ is built from random lists whose lengths lie between 3 and 8 and whose values lie between 2 and 9; 0 serves as the padding token and 1 as the EOS token. Here is an example of encoder_inputs (time-major: each row is one time step across a batch of 100; the leading 7 in the first row is the element traced through the embedding example below):

encoder_inputs =
[[7, 6, 3, 8, 2, 2, 3, 4, 9, 8, 5, 4, 9, 3, 2, 6, 5, 3, 4, 7, 5, 6, 5, 9, 2, 8, 6, 3, 4, 6, 9, 7, 5, 3, 3, 7, 4, 9, 4, 6, 8, 2, 5, 9, 3, 9, 6, 7, 8, 8, 5, 4, 6, 4, 2, 5, 2, 2, 5, 9, 6, 5, 6, 5, 4, 5, 8, 3, 7, 5, 7, 3, 7, 4, 7, 9, 7, 9, 7, 3, 2, 4, 3, 2, 4, 7, 7, 8, 8, 2, 5, 3, 8, 3, 4, 3, 9, 9, 5, 5],
 [3, 4, 6, 2, 3, 2, 9, 8, 9, 4, 2, 5, 8, 5, 2, 3, 4, 8, 7, 6, 5, 7, 7, 4, 5, 2, 7, 3, 9, 2, 9, 2, 9, 6, 4, 2, 6, 4, 6, 9, 3, 3, 6, 2, 7, 7, 8, 7, 6, 4, 2, 9, 9, 8, 9, 4, 3, 7, 6, 3, 6, 2, 9, 9, 9, 6, 5, 4, 4, 8, 8, 3, 9, 8, 7, 7, 3, 9, 6, 7, 6, 2, 7, 6, 2, 9, 5, 9, 5, 7, 8, 4, 6, 2, 3, 2, 8, 2, 5, 4],
 [9, 4, 2, 5, 8, 5, 8, 7, 8, 7, 2, 2, 5, 2, 8, 6, 7, 2, 5, 5, 7, 3, 5, 3, 4, 9, 4, 2, 2, 4, 8, 2, 9, 5, 8, 5, 5, 7, 6, 3, 5, 6, 4, 5, 4, 4, 4, 6, 3, 6, 6, 5, 9, 8, 2, 4, 3, 4, 9, 3, 3, 5, 6, 2, 3, 6, 2, 6, 9, 4, 8, 5, 3, 8, 6, 8, 3, 3, 7, 7, 9, 7, 7, 6, 7, 7, 4, 5, 7, 7, 9, 4, 3, 5, 6, 3, 8, 7, 5, 2],
 [4, 5, 6, 3, 0, 0, 7, 5, 2, 9, 2, 4, 7, 7, 9, 9, 6, 4, 8, 0, 0, 7, 0, 9, 6, 0, 4, 7, 2, 6, 0, 2, 4, 6, 2, 8, 5, 6, 9, 0, 6, 4, 4, 5, 0, 3, 4, 7, 2, 3, 0, 4, 3, 3, 9, 9, 9, 4, 7, 6, 6, 2, 2, 5, 6, 8, 7, 5, 0, 2, 8, 0, 0, 7, 7, 3, 5, 2, 7, 5, 9, 8, 7, 3, 9, 6, 3, 7, 3, 9, 9, 5, 4, 3, 4, 4, 9, 3, 0, 5],
 [3, 0, 3, 4, 0, 0, 3, 8, 9, 6, 4, 7, 0, 0, 2, 3, 5, 4, 8, 0, 0, 3, 0, 0, 0, 0, 7, 5, 0, 6, 0, 9, 2, 2, 8, 5, 0, 7, 9, 0, 0, 0, 0, 9, 0, 7, 8, 5, 6, 7, 0, 0, 7, 2, 2, 0, 7, 0, 7, 6, 3, 0, 6, 3, 5, 2, 4, 9, 0, 9, 8, 0, 0, 0, 6, 6, 4, 6, 9, 9, 3, 6, 6, 0, 9, 7, 0, 6, 7, 2, 3, 5, 2, 2, 5, 0, 7, 9, 0, 8],
 [8, 0, 8, 2, 0, 0, 6, 0, 9, 0, 8, 6, 0, 0, 2, 0, 7, 2, 8, 0, 0, 0, 0, 0, 0, 0, 3, 6, 0, 5, 0, 9, 6, 6, 8, 3, 0, 9, 0, 0, 0, 0, 0, 5, 0, 4, 0, 2, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 8, 0, 0, 0, 8, 3, 0, 8, 0, 3, 0, 0, 8, 0, 0, 0, 7, 6, 3, 0, 8, 5, 3, 0, 0, 0, 0, 0, 0, 0, 5, 5, 4, 9, 8, 9, 8, 0, 0, 3, 0, 0],
 [3, 0, 0, 0, 0, 0, 0, 0, 9, 0, 0, 8, 0, 0, 0, 0, 0, 8, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 9, 4, 5, 4, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 9, 0, 0, 0, 0, 0, 7, 0, 0, 0, 2, 8, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 2, 2, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 4, 3, 0, 0, 6, 0, 3, 0, 0, 8, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 9, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 0, 8, 0, 0, 7, 0, 0]]
embeddings = tf.Variable(tf.random_uniform([vocab_size, input_embedding_size], -0.1, 1.0), dtype=tf.float32)

This produces a 10-row, 20-column matrix of values drawn uniformly from -0.1 to 1.0; each row is the embedding of one of the ids 0 through 9. An example embedding matrix:

>>> embeddings.eval()
array([[  8.57528567e-01,  -3.72054130e-02,   6.21403337e-01,   8.35353658e-02,
          1.53534263e-01,   2.62639582e-01,   2.23163366e-01,   2.58332491e-01,
          7.47691810e-01,   2.18696743e-02,   9.51069951e-01,   4.89469558e-01,
          4.05145854e-01,   7.03303456e-01,   4.77499455e-01,   9.39792752e-01,
          6.68925047e-01,   9.87486601e-01,   9.75592256e-01,   1.45365506e-01],
       [  8.16872954e-01,   3.05627882e-01,   3.83770078e-01,   3.83601040e-01,
          1.96932733e-01,   4.25684005e-01,   8.91459048e-01,   7.48355091e-01,
          3.97092775e-02,   1.14981048e-01,   6.42851815e-02,   4.97784287e-01,
          3.34796220e-01,  -6.49462715e-02,  -6.06450699e-02,   3.05051029e-01,
          4.70807880e-01,   5.45011580e-01,   9.60438967e-01,   4.03340489e-01],
       [  2.18355328e-01,  -2.76850909e-02,   2.97458202e-01,   4.05469924e-01,
         -3.91394496e-02,   9.02301550e-01,   5.60466722e-02,   7.92973161e-01,
          8.11957419e-01,   6.46622121e-01,   3.68399531e-01,   4.87798229e-02,
          9.15218517e-02,   4.87525791e-01,   9.68121052e-01,   1.47761852e-02,
          4.30188328e-01,   1.35585159e-01,   6.51796401e-01,   3.13758731e-01],
       [  2.77882546e-01,   7.78505564e-01,   7.75347471e-01,   6.26467705e-01,
          3.10038835e-01,   2.75827736e-01,   5.35759106e-02,   4.61564153e-01,
          8.29582095e-01,  -7.35464469e-02,   9.96127486e-01,   8.53365734e-02,
          2.35948950e-01,   6.37519360e-03,   4.87927884e-01,   2.72394747e-01,
          9.69405398e-02,   4.32679445e-01,   3.09049577e-01,   4.30735916e-01],
       [  2.81157672e-01,   4.05144840e-01,   7.27885902e-01,   3.59490007e-01,
          9.77600574e-01,   5.89127421e-01,   5.68094134e-01,   6.05683148e-01,
          3.12025309e-01,   3.77709895e-01,   9.63535190e-01,   1.96691453e-01,
         -1.54233724e-02,   9.55036640e-01,  -3.35017666e-02,   1.98448181e-01,
          6.44266680e-02,   2.89128423e-01,   3.11480612e-01,   3.50838304e-01],
       [  4.90482420e-01,   8.79411757e-01,   4.71502095e-01,   4.49019402e-01,
          4.29210216e-01,   7.45444894e-01,   7.96077847e-01,   2.09727436e-02,
          8.19468573e-02,   4.89098877e-01,   8.62105563e-02,   9.90414977e-01,
          7.82572865e-01,   4.59738284e-01,   2.73251295e-01,   7.01838210e-02,
          5.48981905e-01,   5.07066473e-02,  -7.69949108e-02,   3.06924343e-01],
       [  9.54324007e-01,   5.99865317e-01,   7.67913386e-02,   2.71800533e-02,
          6.31043375e-01,   4.90301043e-01,   3.13162744e-01,   2.09044039e-01,
          4.18955594e-01,  -9.16666761e-02,   7.82842994e-01,   7.74771631e-01,
          4.82871085e-01,   4.59463924e-01,   2.69717604e-01,   2.75358170e-01,
          5.27832091e-01,   3.54466140e-01,   2.80414283e-01,   2.35298544e-01],
       # row 7 (counting from 0): the row looked up for the token 7 in encoder_inputs
       [  7.32733130e-01,   7.41863728e-01,   3.31785470e-01,   4.86949235e-01,
          2.59555012e-01,   7.44006515e-01,   3.29493582e-01,   1.50963455e-01,
          1.70477331e-01,   1.41230315e-01,   4.46048945e-01,   7.82836020e-01,
          5.16645372e-01,   5.59585631e-01,   2.88499236e-01,   5.93656972e-02,
          6.02837741e-01,   9.35048819e-01,   4.38395411e-01,   9.98193383e-01],
       [  9.94453311e-01,   4.23193485e-01,   2.43528277e-01,   4.83180672e-01,
          1.69399709e-01,   8.52666676e-01,   6.24188781e-01,   2.93061256e-01,
          6.59859061e-01,   9.48151112e-01,   9.17655826e-01,   8.19465399e-01,
         -9.57168937e-02,   2.74811313e-02,   7.72869766e-01,  -9.81867313e-04,
          8.16195846e-01,  -9.82092842e-02,   6.67886853e-01,   6.78441525e-01],
       [  8.99791047e-02,   1.56264246e-01,   6.95984125e-01,   8.47990572e-01,
          8.40536356e-01,   2.00185269e-01,   4.41039473e-01,   6.13151312e-01,
          4.02850598e-01,   4.05698746e-01,   5.45094550e-01,   6.80389225e-01,
         -5.77304363e-02,   1.57264650e-01,   5.16325772e-01,   9.07227874e-01,
          2.62629867e-01,   8.97993386e-01,   3.52653652e-01,   6.02253318e-01]], dtype=float32)

Next comes

tf.nn.embedding_lookup

which gathers one embedding row per id (internally: flatten the id matrix, gather, reshape back), turning the [max_time, batch_size] id matrix into a [max_time, batch_size, input_embedding_size] tensor:

>>> encoder_inputs_embedded.eval()
array([[[ 0.73273313,  0.74186373,  0.33178547, ...,  0.93504882,
          0.43839541,  0.99819338],
        [ 0.95432401,  0.59986532,  0.07679134, ...,  0.35446614,
          0.28041428,  0.23529854],
        [ 0.27788255,  0.77850556,  0.77534747, ...,  0.43267944,
          0.30904958,  0.43073592],
        ...,
        [ 0.0899791 ,  0.15626425,  0.69598413, ...,  0.89799339,
          0.35265365,  0.60225332],
        [ 0.49048242,  0.87941176,  0.4715021 , ...,  0.05070665,
         -0.07699491,  0.30692434],
        [ 0.49048242,  0.87941176,  0.4715021 , ...,  0.05070665,
         -0.07699491,  0.30692434]],

       [[ 0.27788255,  0.77850556,  0.77534747, ...,  0.43267944,
          0.30904958,  0.43073592],
        [ 0.28115767,  0.40514484,  0.7278859 , ...,  0.28912842,
          0.31148061,  0.3508383 ],
        [ 0.95432401,  0.59986532,  0.07679134, ...,  0.35446614,
          0.28041428,  0.23529854],
        ...,
        [ 0.21835533, -0.02768509,  0.2974582 , ...,  0.13558516,
          0.6517964 ,  0.31375873],
        [ 0.49048242,  0.87941176,  0.4715021 , ...,  0.05070665,
         -0.07699491,  0.30692434],
        [ 0.28115767,  0.40514484,  0.7278859 , ...,  0.28912842,
          0.31148061,  0.3508383 ]],

       [[ 0.0899791 ,  0.15626425,  0.69598413, ...,  0.89799339,
          0.35265365,  0.60225332],
        [ 0.28115767,  0.40514484,  0.7278859 , ...,  0.28912842,
          0.31148061,  0.3508383 ],
        [ 0.21835533, -0.02768509,  0.2974582 , ...,  0.13558516,
          0.6517964 ,  0.31375873],
        ...,
        [ 0.73273313,  0.74186373,  0.33178547, ...,  0.93504882,
          0.43839541,  0.99819338],
        [ 0.49048242,  0.87941176,  0.4715021 , ...,  0.05070665,
         -0.07699491,  0.30692434],
        [ 0.21835533, -0.02768509,  0.2974582 , ...,  0.13558516,
          0.6517964 ,  0.31375873]],

       ...,

       [[ 0.99445331,  0.42319348,  0.24352828, ..., -0.09820928,
          0.66788685,  0.67844152],
        [ 0.85752857, -0.03720541,  0.62140334, ...,  0.9874866 ,
          0.97559226,  0.14536551],
        [ 0.99445331,  0.42319348,  0.24352828, ..., -0.09820928,
          0.66788685,  0.67844152],
        ...,
        [ 0.27788255,  0.77850556,  0.77534747, ...,  0.43267944,
          0.30904958,  0.43073592],
        [ 0.85752857, -0.03720541,  0.62140334, ...,  0.9874866 ,
          0.97559226,  0.14536551],
        [ 0.85752857, -0.03720541,  0.62140334, ...,  0.9874866 ,
          0.97559226,  0.14536551]],

       [[ 0.27788255,  0.77850556,  0.77534747, ...,  0.43267944,
          0.30904958,  0.43073592],
        [ 0.85752857, -0.03720541,  0.62140334, ...,  0.9874866 ,
          0.97559226,  0.14536551],
        [ 0.85752857, -0.03720541,  0.62140334, ...,  0.9874866 ,
          0.97559226,  0.14536551],
        ...,
        [ 0.99445331,  0.42319348,  0.24352828, ..., -0.09820928,
          0.66788685,  0.67844152],
        [ 0.85752857, -0.03720541,  0.62140334, ...,  0.9874866 ,
          0.97559226,  0.14536551],
        [ 0.85752857, -0.03720541,  0.62140334, ...,  0.9874866 ,
          0.97559226,  0.14536551]],

       [[ 0.85752857, -0.03720541,  0.62140334, ...,  0.9874866 ,
          0.97559226,  0.14536551],
        [ 0.85752857, -0.03720541,  0.62140334, ...,  0.9874866 ,
          0.97559226,  0.14536551],
        [ 0.85752857, -0.03720541,  0.62140334, ...,  0.9874866 ,
          0.97559226,  0.14536551],
        ...,
        [ 0.73273313,  0.74186373,  0.33178547, ...,  0.93504882,
          0.43839541,  0.99819338],
        [ 0.85752857, -0.03720541,  0.62140334, ...,  0.9874866 ,
          0.97559226,  0.14536551],
        [ 0.85752857, -0.03720541,  0.62140334, ...,  0.9874866 ,
          0.97559226,  0.14536551]]], dtype=float32)

In other words, each token id in encoder_inputs has been replaced by the corresponding 20-dimensional row of the embedding matrix, giving a [max_time, batch_size, input_embedding_size] tensor.
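In NumPy terms the lookup is plain row indexing into the embedding matrix, applied element-wise to the id matrix; a small sketch with made-up shapes:

import numpy as np

emb = np.random.uniform(-0.1, 1.0, (10, 20)).astype(np.float32)  # stands in for `embeddings`
ids = np.array([[7, 3],
                [6, 4]], dtype=np.int32)                         # [max_time=2, batch_size=2]

looked_up = emb[ids]                      # shape (2, 2, 20)
assert (looked_up[0, 0] == emb[7]).all()  # token 7 -> row 7 of the matrix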

decoder_cell = tf.contrib.rnn.LSTMCell(decoder_hidden_units)

creates the LSTM cell for the decoder.

encoder_outputs, encoder_final_state = tf.nn.dynamic_rnn(
    encoder_cell, encoder_inputs_embedded, dtype=tf.float32, time_major=True)

time_major specifies which axis of encoder_inputs_embedded is the time axis:

If encoder_inputs_embedded is (batches, steps, inputs) ==> time_major=False; if encoder_inputs_embedded is (steps, batches, inputs) ==> time_major=True.
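The two layouts differ only by a transpose, so a batch-major tensor can be fed to a time_major=True RNN like this (inputs_batch_major is a hypothetical [batch, steps, inputs] tensor):

# [batch, steps, inputs] -> [steps, batch, inputs]
inputs_time_major = tf.transpose(inputs_batch_major, perm=[1, 0, 2])
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs_time_major,
                                         dtype=tf.float32, time_major=True)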
