阅读理解任务中的Attention-over-Attention神经网络模型原理及实现
来源:互联网 发布:方正排版印刷软件 编辑:程序博客网 时间:2024/06/06 06:43
本文是“Attention-over-Attention Neural Networks for Reading Comprehension”的阅读笔记。这篇论文所处理的任务是阅读理解里面的完形填空问题。其模型架构是建立在“Text Understanding with the Attention Sum Reader Network”这篇论文至上。该论文首先提出了将Attention用于完形填空任务,本篇论文则在其基础之上添加了一个额外的Attention层,可以免去启发式的算法和一些超参数调整等问题。我们接下来结合两篇论文进行介绍。
数据集
首先介绍一下数据集,目前常用的大规模数据集主要包括CNN/Daliy Mail和Children’s Book Test(CBTest)。前面两个是新闻数据集,将一整篇新闻文档作为完形填空的文本(Document),然后将其新闻摘要中的一句话去掉一个词之后作为查询(Query),去掉的那个词作为答案(Answer)。其中Document中的命名实体会被替换成不同的标识符:@entity1、@entity2、、、等例如,第一行为网页URL(无用),第三行为Document, 第五行为Query, 第七行为answer,并且其中的命名实体均被替换:
CBT数据集是从儿童读物中获取,由于其没有摘要,所以采用,前面连续的21句话作为Document,第22句话作为Query等方式构建。然后其还根据答案的词性分为四个子集:命名实体(NE)、公共名词(CN)、动词、介词。但是由于后面两种答案与文本并没有十分紧密的关系,比如人们常常不需要读文本就可以判断出介词填空等,所以常用的是前面两种。
最终每条数据被构建为如下三元组:
<D, Q, A>
模型
首先我们可以看一下“Text Understanding with the Attention Sum Reader Network”这篇论文所提出的模型架构,如下图所示:
从上图可以看出,模型首先通过嵌入矩阵V得到Document和Query中每个单词的词向量e(w)。接下来分别使用两个encoder网络获得文本中每个单词的向量contextual embedding
和Query的表示向量。这里的encoder使用的是双向GRU循环神经网络。然后使用点积的方式将Query向量和每一个单词的contextual embedding
相乘,得到的结果可以视为每个单词对于该查询的权重,亦可理解为attention。最后使用softmax函数将权重转化为归一化的概率,将概率最大的结果视为该query的答案。
接下来我们再看一下本文提出的模型架构,如下图所示:
模型的前半部分与上面完全一样,差别在于本文提出了一种“Attention over Attention”的机制,也就是获得Document和Query的向量之后,不将Query的所有单词合为一个向量,而是直接以矩阵的形式与Document矩阵相乘,然后分别从行和列两个维度对相乘后的矩阵进行softmax操作得到document的注意力矩阵和query的注意力矩阵。在对query矩阵每一列的元素进行求和当做权重,对document的attention矩阵进行点积即可。
模型的代码实现
其实模型使用tensorflow实现的时候十分简单,直接调用tf.contrib.rnn下面的GRUCell即可,难点在于数据的处理和读取操作。这里我们可以参考github上面的两个实现方案:OlavHN,marshmelloX。第一个使用了TF内置的读取数据的API,代码十分简洁明了,我有时间需要研究一下其实现原理整理出一份博客来。第二个使用的是传统的数据处理方式,也可以参考,此外在github上面应该可以找到CNN等数据集的处理代码结合着一起学习。但是上面两个代码实现都用的是比较老的版本,如果用的是tf1.0及以上的版本可能会出现一些函数的不兼容问题,我参照第一份代码实现进行了一定的修改,可以再1。0的版本上运行。代码后续会放到我的github上面,欢迎查看。在服务器上跑需要四五天的样子,现在还没跑完==下图是结果截图:
四个参数分别代表步数,错误率,准确度,时间。可以看到准确度不是十分稳定,但是基本上达到了论文里面提到的效果。可以看一下我修改过之后的model的代码,特别是模型构建部分还是比较简单的,只用了几行命令就实现了:
import osimport timeimport randomimport numpy as npimport tensorflow as tffrom tensorflow.python.ops import sparse_opsfrom util import softmax, orthogonal_initializerflags = tf.app.flagsFLAGS = flags.FLAGSflags.DEFINE_integer('vocab_size', 119662, 'Vocabulary size')flags.DEFINE_integer('embedding_size', 384, 'Embedding dimension')flags.DEFINE_integer('hidden_size', 256, 'Hidden units')flags.DEFINE_integer('batch_size', 32, 'Batch size')flags.DEFINE_integer('epochs', 2, 'Number of epochs to train/test')flags.DEFINE_boolean('training', True, 'Training or testing a model')flags.DEFINE_string('name', 'lc_model', 'Model name (used for statistics and model path')flags.DEFINE_float('dropout_keep_prob', 0.9, 'Keep prob for embedding dropout')flags.DEFINE_float('l2_reg', 0.0001, 'l2 regularization for embeddings')model_path = 'models/' + FLAGS.nameif not os.path.exists(model_path): os.makedirs(model_path)def read_records(index=0): train_queue = tf.train.string_input_producer(['training.tfrecords'], num_epochs=FLAGS.epochs) validation_queue = tf.train.string_input_producer(['validation.tfrecords'], num_epochs=FLAGS.epochs) test_queue = tf.train.string_input_producer(['test.tfrecords'], num_epochs=FLAGS.epochs) queue = tf.QueueBase.from_list(index, [train_queue, validation_queue, test_queue]) reader = tf.TFRecordReader() _, serialized_example = reader.read(queue) features = tf.parse_single_example( serialized_example, features={ 'document': tf.VarLenFeature(tf.int64), 'query': tf.VarLenFeature(tf.int64), 'answer': tf.FixedLenFeature([], tf.int64) }) document = sparse_ops.serialize_sparse(features['document']) query = sparse_ops.serialize_sparse(features['query']) answer = features['answer'] document_batch_serialized, query_batch_serialized, answer_batch = tf.train.shuffle_batch( [document, query, answer], batch_size=FLAGS.batch_size, capacity=2000, min_after_dequeue=1000) sparse_document_batch = sparse_ops.deserialize_many_sparse(document_batch_serialized, dtype=tf.int64) sparse_query_batch = sparse_ops.deserialize_many_sparse(query_batch_serialized, dtype=tf.int64) document_batch = tf.sparse_tensor_to_dense(sparse_document_batch) document_weights = tf.sparse_to_dense(sparse_document_batch.indices, sparse_document_batch.dense_shape, 1) query_batch = tf.sparse_tensor_to_dense(sparse_query_batch) query_weights = tf.sparse_to_dense(sparse_query_batch.indices, sparse_query_batch.dense_shape, 1) return document_batch, document_weights, query_batch, query_weights, answer_batchdef inference(documents, doc_mask, query, query_mask): embedding = tf.get_variable('embedding', [FLAGS.vocab_size, FLAGS.embedding_size], initializer=tf.random_uniform_initializer(minval=-0.05, maxval=0.05)) regularizer = tf.nn.l2_loss(embedding) doc_emb = tf.nn.dropout(tf.nn.embedding_lookup(embedding, documents), FLAGS.dropout_keep_prob) doc_emb.set_shape([None, None, FLAGS.embedding_size]) query_emb = tf.nn.dropout(tf.nn.embedding_lookup(embedding, query), FLAGS.dropout_keep_prob) query_emb.set_shape([None, None, FLAGS.embedding_size]) with tf.variable_scope('document', initializer=orthogonal_initializer()): fwd_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size) back_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size) doc_len = tf.reduce_sum(doc_mask, reduction_indices=1) h, _ = tf.nn.bidirectional_dynamic_rnn( fwd_cell, back_cell, doc_emb, sequence_length=tf.to_int64(doc_len), dtype=tf.float32) #h_doc = tf.nn.dropout(tf.concat(2, h), FLAGS.dropout_keep_prob) h_doc = tf.concat(h, 2) with tf.variable_scope('query', initializer=orthogonal_initializer()): fwd_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size) back_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size) query_len = tf.reduce_sum(query_mask, reduction_indices=1) h, _ = tf.nn.bidirectional_dynamic_rnn( fwd_cell, back_cell, query_emb, sequence_length=tf.to_int64(query_len), dtype=tf.float32) #h_query = tf.nn.dropout(tf.concat(2, h), FLAGS.dropout_keep_prob) h_query = tf.concat(h, 2) M = tf.matmul(h_doc, h_query, adjoint_b=True) M_mask = tf.to_float(tf.matmul(tf.expand_dims(doc_mask, -1), tf.expand_dims(query_mask, 1))) alpha = softmax(M, 1, M_mask) beta = softmax(M, 2, M_mask) #query_importance = tf.expand_dims(tf.reduce_mean(beta, reduction_indices=1), -1) query_importance = tf.expand_dims(tf.reduce_sum(beta, 1) / tf.to_float(tf.expand_dims(doc_len, -1)), -1) s = tf.squeeze(tf.matmul(alpha, query_importance), [2]) unpacked_s = zip(tf.unstack(s, FLAGS.batch_size), tf.unstack(documents, FLAGS.batch_size)) y_hat = tf.stack([tf.unsorted_segment_sum(attentions, sentence_ids, FLAGS.vocab_size) for (attentions, sentence_ids) in unpacked_s]) return y_hat, regularizerdef train(y_hat, regularizer, document, doc_weight, answer): # Trick while we wait for tf.gather_nd - https://github.com/tensorflow/tensorflow/issues/206 # This unfortunately causes us to expand a sparse tensor into the full vocabulary index = tf.range(0, FLAGS.batch_size) * FLAGS.vocab_size + tf.to_int32(answer) flat = tf.reshape(y_hat, [-1]) relevant = tf.gather(flat, index) # mean cause reg is independent of batch size loss = -tf.reduce_mean(tf.log(relevant)) + FLAGS.l2_reg * regularizer global_step = tf.Variable(0, name="global_step", trainable=False) accuracy = tf.reduce_mean(tf.to_float(tf.equal(tf.argmax(y_hat, 1), answer))) optimizer = tf.train.AdamOptimizer() grads_and_vars = optimizer.compute_gradients(loss) capped_grads_and_vars = [(tf.clip_by_value(grad, -5, 5), var) for (grad, var) in grads_and_vars] train_op = optimizer.apply_gradients(capped_grads_and_vars, global_step=global_step) tf.summary.scalar('loss', loss) tf.summary.scalar('accuracy', accuracy) return loss, train_op, global_step, accuracydef main(): dataset = tf.placeholder_with_default(0, []) document_batch, document_weights, query_batch, query_weights, answer_batch = read_records(dataset) y_hat, reg = inference(document_batch, document_weights, query_batch, query_weights) loss, train_op, global_step, accuracy = train(y_hat, reg, document_batch, document_weights, answer_batch) summary_op = tf.summary.merge_all() with tf.Session() as sess: summary_writer = tf.summary.FileWriter(model_path, sess.graph) saver_variables = tf.all_variables() if not FLAGS.training: saver_variables = filter(lambda var: var.name != 'input_producer/limit_epochs/epochs:0', saver_variables) saver_variables = filter(lambda var: var.name != 'smooth_acc:0', saver_variables) saver_variables = filter(lambda var: var.name != 'avg_acc:0', saver_variables) saver = tf.train.Saver(saver_variables) sess.run([ tf.initialize_all_variables(), tf.initialize_local_variables()]) model = tf.train.latest_checkpoint(model_path) if model: print('Restoring ' + model) saver.restore(sess, model) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) start_time = time.time() accumulated_accuracy = 0 try: if FLAGS.training: while not coord.should_stop(): loss_t, _, step, acc = sess.run([loss, train_op, global_step, accuracy], feed_dict={dataset: 0}) elapsed_time, start_time = time.time() - start_time, time.time() print(step, loss_t, acc, elapsed_time) if step % 100 == 0: summary_str = sess.run(summary_op) summary_writer.add_summary(summary_str, step) if step % 1000 == 0: saver.save(sess, model_path + '/aoa', global_step=step) else: step = 0 while not coord.should_stop(): acc = sess.run(accuracy, feed_dict={dataset: 2}) step += 1 accumulated_accuracy += (acc - accumulated_accuracy) / step elapsed_time, start_time = time.time() - start_time, time.time() print(accumulated_accuracy, acc, elapsed_time) except tf.errors.OutOfRangeError: print('Done!') finally: coord.request_stop() coord.join(threads) ''' import pickle with open('counter.pickle', 'r') as f: counter = pickle.load(f) word, _ = zip(*counter.most_common()) '''if __name__ == "__main__": main()
- 阅读理解任务中的Attention-over-Attention神经网络模型原理及实现
- Attention-over-Attention Neural Network for Reading Comprehension----神经网络在阅读理解上的应用
- Attention, 神经网络中的注意力机制
- Attention!神经网络中的注意机制到底是什么?
- 一个Hierarchical Attention神经网络的实现
- attention
- Attention
- Attention
- attention
- Attention
- Attention
- Attention
- Attention
- Attention
- seq2seq里的 attention机制 的 原理 及 代码 及 个人理解
- Attention:注意力模型
- Attention Model 理解
- attention机制 深入理解
- [Hackerrank题目选做] Tree Pruning
- Spring整合web项目
- JSp与Servlet跳转路径配置
- 乘法逆元 小结
- .net实现求职招聘网站
- 阅读理解任务中的Attention-over-Attention神经网络模型原理及实现
- 内核printk的用法
- 烧写uboot
- 刘强东表态后 京东宣布全面接入顺丰
- 案例-----拦截有序广播
- H.264 基础
- Netty实现简单聊天室
- 全局变量和局部变量 案例
- Java基础教程0-测试人员为什么要掌握Java基础