seq2seq里的 attention机制 的 原理 及 代码 及 个人理解
来源:互联网 发布:数据存储四种方式 编辑:程序博客网 时间:2024/06/03 01:42
其中
其中
其中
综合
观察所有输入的东西,可见是 所有encoder的输出 和 decoder的每个state 一起作为输入,搅和在一起,然后target/output就是一个类似score的东西
def attention(self, prev_state, enc_outputs): """ Attention model for Neural Machine Translation :param prev_state: the decoder hidden state at time i-1 :param enc_outputs: the encoder outputs, a length 'T' list. """ e_i = [] c_i = [] for output in enc_outputs: atten_hidden = tf.tanh(tf.add(tf.matmul(prev_state, self.attention_W), tf.matmul(output, self.attention_U))) e_i_j = tf.matmul(atten_hidden, self.attention_V) e_i.append(e_i_j) e_i = tf.concat(e_i, axis=1) alpha_i = tf.nn.softmax(e_i) alpha_i = tf.split(alpha_i, self.num_steps, 1) for alpha_i_j, output in zip(alpha_i, enc_outputs): c_i_j = tf.multiply(alpha_i_j, output) c_i.append(c_i_j) c_i = tf.reshape(tf.concat(c_i, axis=1), [-1, self.num_steps, self.hidden_dim * 2]) c_i = tf.reduce_sum(c_i, 1) return c_i#对应的decode def decode(self, cell, init_state, encoder_outputs, loop_function=None): outputs = [] prev = None state = init_state for i, inp in enumerate(self.decoder_inputs_emb):#decoder_inputs_emb是tf.placeholder #if loop_function is not None and prev is not None: # with tf.variable_scope("loop_function", reuse=True): # inp = loop_function(prev, i) #if i > 0: # tf.get_variable_scope().reuse_variables() c_i = self.attention(state, encoder_outputs) inp = tf.concat([inp, c_i], axis=1) output, state = cell(inp, state)#原本没有attention的是decoder_input和state作为输入 outputs.append(output) if loop_function is not None: prev = output return outputs
代码摘自 https://github.com/pemywei/attention-nmt
阅读全文
0 0
- seq2seq里的 attention机制 的 原理 及 代码 及 个人理解
- 带Attention机制的Seq2Seq框架梳理
- seq2seq以及Attention机制
- 对seq2seq的一些个人理解
- 对seq2seq的一些个人理解
- 关于RNN(Seq2Seq)的一点个人理解与感悟
- 关于RNN(Seq2Seq)的一点个人理解与感悟
- 阅读理解任务中的Attention-over-Attention神经网络模型原理及实现
- Tensorflow 自动文摘: 基于Seq2Seq+Attention模型的Textsum模型
- Tensorflow 自动文摘: 基于Seq2Seq+Attention模型的Textsum模型
- Tensorflow 自动文摘: 基于Seq2Seq+Attention模型的Textsum模型
- 图解RNN、RNN变体、Seq2Seq、Attention机制
- Autoreleasepool的理解及原理
- maven 的基本配置及个人理解
- maven 的基本配置及个人理解
- maven 的基本配置及个人理解
- maven 的基本配置及个人理解
- Java面向对象及个人的理解
- scala安装过程中需要注意的问题
- Sift特征点匹配过程
- ARM原子操作atomic_add详解
- jquery easyui dialog不超出父容器以及随浏览器缩放
- spring+springMVC+hibernate 三大框架整合
- seq2seq里的 attention机制 的 原理 及 代码 及 个人理解
- js随机设置8位密码
- 二进制输出所有的子集
- 欢迎使用CSDN-markdown编辑器
- logback配置日志
- Python练习-0701
- HDU 4283 You Are the One (区间dp)
- Java实现AES加密解密
- 解决qt提示:qt.network.ssl: QSslSocket: cannot call unresolved function DH_free和qt.network.ssl: QSslSocke