阅读理解任务中的Attention-over-Attention神经网络模型原理及实现

来源:互联网 发布:方正排版印刷软件 编辑:程序博客网 时间:2024/06/06 06:43

本文是“Attention-over-Attention Neural Networks for Reading Comprehension”的阅读笔记。这篇论文所处理的任务是阅读理解里面的完形填空问题。其模型架构是建立在“Text Understanding with the Attention Sum Reader Network”这篇论文至上。该论文首先提出了将Attention用于完形填空任务,本篇论文则在其基础之上添加了一个额外的Attention层,可以免去启发式的算法和一些超参数调整等问题。我们接下来结合两篇论文进行介绍。

数据集

首先介绍一下数据集,目前常用的大规模数据集主要包括CNN/Daliy Mail和Children’s Book Test(CBTest)。前面两个是新闻数据集,将一整篇新闻文档作为完形填空的文本(Document),然后将其新闻摘要中的一句话去掉一个词之后作为查询(Query),去掉的那个词作为答案(Answer)。其中Document中的命名实体会被替换成不同的标识符:@entity1、@entity2、、、等例如,第一行为网页URL(无用),第三行为Document, 第五行为Query, 第七行为answer,并且其中的命名实体均被替换:
这里写图片描述
CBT数据集是从儿童读物中获取,由于其没有摘要,所以采用,前面连续的21句话作为Document,第22句话作为Query等方式构建。然后其还根据答案的词性分为四个子集:命名实体(NE)、公共名词(CN)、动词、介词。但是由于后面两种答案与文本并没有十分紧密的关系,比如人们常常不需要读文本就可以判断出介词填空等,所以常用的是前面两种。
最终每条数据被构建为如下三元组:

<D, Q, A>

模型

首先我们可以看一下“Text Understanding with the Attention Sum Reader Network”这篇论文所提出的模型架构,如下图所示:
这里写图片描述
从上图可以看出,模型首先通过嵌入矩阵V得到Document和Query中每个单词的词向量e(w)。接下来分别使用两个encoder网络获得文本中每个单词的向量contextual embedding和Query的表示向量。这里的encoder使用的是双向GRU循环神经网络。然后使用点积的方式将Query向量和每一个单词的contextual embedding相乘,得到的结果可以视为每个单词对于该查询的权重,亦可理解为attention。最后使用softmax函数将权重转化为归一化的概率,将概率最大的结果视为该query的答案。
接下来我们再看一下本文提出的模型架构,如下图所示:
这里写图片描述
模型的前半部分与上面完全一样,差别在于本文提出了一种“Attention over Attention”的机制,也就是获得Document和Query的向量之后,不将Query的所有单词合为一个向量,而是直接以矩阵的形式与Document矩阵相乘,然后分别从行和列两个维度对相乘后的矩阵进行softmax操作得到document的注意力矩阵和query的注意力矩阵。在对query矩阵每一列的元素进行求和当做权重,对document的attention矩阵进行点积即可。

模型的代码实现

其实模型使用tensorflow实现的时候十分简单,直接调用tf.contrib.rnn下面的GRUCell即可,难点在于数据的处理和读取操作。这里我们可以参考github上面的两个实现方案:OlavHN,marshmelloX。第一个使用了TF内置的读取数据的API,代码十分简洁明了,我有时间需要研究一下其实现原理整理出一份博客来。第二个使用的是传统的数据处理方式,也可以参考,此外在github上面应该可以找到CNN等数据集的处理代码结合着一起学习。但是上面两个代码实现都用的是比较老的版本,如果用的是tf1.0及以上的版本可能会出现一些函数的不兼容问题,我参照第一份代码实现进行了一定的修改,可以再1。0的版本上运行。代码后续会放到我的github上面,欢迎查看。在服务器上跑需要四五天的样子,现在还没跑完==下图是结果截图:
这里写图片描述
这里写图片描述
四个参数分别代表步数,错误率,准确度,时间。可以看到准确度不是十分稳定,但是基本上达到了论文里面提到的效果。可以看一下我修改过之后的model的代码,特别是模型构建部分还是比较简单的,只用了几行命令就实现了:

import osimport timeimport randomimport numpy as npimport tensorflow as tffrom tensorflow.python.ops import sparse_opsfrom util import softmax, orthogonal_initializerflags = tf.app.flagsFLAGS = flags.FLAGSflags.DEFINE_integer('vocab_size', 119662, 'Vocabulary size')flags.DEFINE_integer('embedding_size', 384, 'Embedding dimension')flags.DEFINE_integer('hidden_size', 256, 'Hidden units')flags.DEFINE_integer('batch_size', 32, 'Batch size')flags.DEFINE_integer('epochs', 2, 'Number of epochs to train/test')flags.DEFINE_boolean('training', True, 'Training or testing a model')flags.DEFINE_string('name', 'lc_model', 'Model name (used for statistics and model path')flags.DEFINE_float('dropout_keep_prob', 0.9, 'Keep prob for embedding dropout')flags.DEFINE_float('l2_reg', 0.0001, 'l2 regularization for embeddings')model_path = 'models/' + FLAGS.nameif not os.path.exists(model_path):    os.makedirs(model_path)def read_records(index=0):  train_queue = tf.train.string_input_producer(['training.tfrecords'], num_epochs=FLAGS.epochs)  validation_queue = tf.train.string_input_producer(['validation.tfrecords'], num_epochs=FLAGS.epochs)  test_queue = tf.train.string_input_producer(['test.tfrecords'], num_epochs=FLAGS.epochs)  queue = tf.QueueBase.from_list(index, [train_queue, validation_queue, test_queue])  reader = tf.TFRecordReader()  _, serialized_example = reader.read(queue)  features = tf.parse_single_example(      serialized_example,      features={        'document': tf.VarLenFeature(tf.int64),        'query': tf.VarLenFeature(tf.int64),        'answer': tf.FixedLenFeature([], tf.int64)      })  document = sparse_ops.serialize_sparse(features['document'])  query = sparse_ops.serialize_sparse(features['query'])  answer = features['answer']  document_batch_serialized, query_batch_serialized, answer_batch = tf.train.shuffle_batch(      [document, query, answer], batch_size=FLAGS.batch_size,      capacity=2000,      min_after_dequeue=1000)  sparse_document_batch = sparse_ops.deserialize_many_sparse(document_batch_serialized, dtype=tf.int64)  sparse_query_batch = sparse_ops.deserialize_many_sparse(query_batch_serialized, dtype=tf.int64)  document_batch = tf.sparse_tensor_to_dense(sparse_document_batch)  document_weights = tf.sparse_to_dense(sparse_document_batch.indices, sparse_document_batch.dense_shape, 1)  query_batch = tf.sparse_tensor_to_dense(sparse_query_batch)  query_weights = tf.sparse_to_dense(sparse_query_batch.indices, sparse_query_batch.dense_shape, 1)  return document_batch, document_weights, query_batch, query_weights, answer_batchdef inference(documents, doc_mask, query, query_mask):  embedding = tf.get_variable('embedding',              [FLAGS.vocab_size, FLAGS.embedding_size],              initializer=tf.random_uniform_initializer(minval=-0.05, maxval=0.05))  regularizer = tf.nn.l2_loss(embedding)  doc_emb = tf.nn.dropout(tf.nn.embedding_lookup(embedding, documents), FLAGS.dropout_keep_prob)  doc_emb.set_shape([None, None, FLAGS.embedding_size])  query_emb = tf.nn.dropout(tf.nn.embedding_lookup(embedding, query), FLAGS.dropout_keep_prob)  query_emb.set_shape([None, None, FLAGS.embedding_size])  with tf.variable_scope('document', initializer=orthogonal_initializer()):    fwd_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size)    back_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size)    doc_len = tf.reduce_sum(doc_mask, reduction_indices=1)    h, _ = tf.nn.bidirectional_dynamic_rnn(        fwd_cell, back_cell, doc_emb, sequence_length=tf.to_int64(doc_len), dtype=tf.float32)    #h_doc = tf.nn.dropout(tf.concat(2, h), FLAGS.dropout_keep_prob)    h_doc = tf.concat(h, 2)  with tf.variable_scope('query', initializer=orthogonal_initializer()):    fwd_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size)    back_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size)    query_len = tf.reduce_sum(query_mask, reduction_indices=1)    h, _ = tf.nn.bidirectional_dynamic_rnn(        fwd_cell, back_cell, query_emb, sequence_length=tf.to_int64(query_len), dtype=tf.float32)    #h_query = tf.nn.dropout(tf.concat(2, h), FLAGS.dropout_keep_prob)    h_query = tf.concat(h, 2)  M = tf.matmul(h_doc, h_query, adjoint_b=True)  M_mask = tf.to_float(tf.matmul(tf.expand_dims(doc_mask, -1), tf.expand_dims(query_mask, 1)))  alpha = softmax(M, 1, M_mask)  beta = softmax(M, 2, M_mask)  #query_importance = tf.expand_dims(tf.reduce_mean(beta, reduction_indices=1), -1)  query_importance = tf.expand_dims(tf.reduce_sum(beta, 1) / tf.to_float(tf.expand_dims(doc_len, -1)), -1)  s = tf.squeeze(tf.matmul(alpha, query_importance), [2])  unpacked_s = zip(tf.unstack(s, FLAGS.batch_size), tf.unstack(documents, FLAGS.batch_size))  y_hat = tf.stack([tf.unsorted_segment_sum(attentions, sentence_ids, FLAGS.vocab_size) for (attentions, sentence_ids) in unpacked_s])  return y_hat, regularizerdef train(y_hat, regularizer, document, doc_weight, answer):  # Trick while we wait for tf.gather_nd - https://github.com/tensorflow/tensorflow/issues/206  # This unfortunately causes us to expand a sparse tensor into the full vocabulary  index = tf.range(0, FLAGS.batch_size) * FLAGS.vocab_size + tf.to_int32(answer)  flat = tf.reshape(y_hat, [-1])  relevant = tf.gather(flat, index)  # mean cause reg is independent of batch size  loss = -tf.reduce_mean(tf.log(relevant)) + FLAGS.l2_reg * regularizer   global_step = tf.Variable(0, name="global_step", trainable=False)  accuracy = tf.reduce_mean(tf.to_float(tf.equal(tf.argmax(y_hat, 1), answer)))  optimizer = tf.train.AdamOptimizer()  grads_and_vars = optimizer.compute_gradients(loss)  capped_grads_and_vars = [(tf.clip_by_value(grad, -5, 5), var) for (grad, var) in grads_and_vars]  train_op = optimizer.apply_gradients(capped_grads_and_vars, global_step=global_step)  tf.summary.scalar('loss', loss)  tf.summary.scalar('accuracy', accuracy)  return loss, train_op, global_step, accuracydef main():  dataset = tf.placeholder_with_default(0, [])  document_batch, document_weights, query_batch, query_weights, answer_batch = read_records(dataset)  y_hat, reg = inference(document_batch, document_weights, query_batch, query_weights)  loss, train_op, global_step, accuracy = train(y_hat, reg, document_batch, document_weights, answer_batch)  summary_op = tf.summary.merge_all()  with tf.Session() as sess:    summary_writer = tf.summary.FileWriter(model_path, sess.graph)    saver_variables = tf.all_variables()    if not FLAGS.training:      saver_variables = filter(lambda var: var.name != 'input_producer/limit_epochs/epochs:0', saver_variables)      saver_variables = filter(lambda var: var.name != 'smooth_acc:0', saver_variables)      saver_variables = filter(lambda var: var.name != 'avg_acc:0', saver_variables)    saver = tf.train.Saver(saver_variables)    sess.run([        tf.initialize_all_variables(),        tf.initialize_local_variables()])    model = tf.train.latest_checkpoint(model_path)    if model:      print('Restoring ' + model)      saver.restore(sess, model)    coord = tf.train.Coordinator()    threads = tf.train.start_queue_runners(coord=coord)    start_time = time.time()    accumulated_accuracy = 0    try:      if FLAGS.training:        while not coord.should_stop():          loss_t, _, step, acc = sess.run([loss, train_op, global_step, accuracy], feed_dict={dataset: 0})          elapsed_time, start_time = time.time() - start_time, time.time()          print(step, loss_t, acc, elapsed_time)          if step % 100 == 0:            summary_str = sess.run(summary_op)            summary_writer.add_summary(summary_str, step)          if step % 1000 == 0:            saver.save(sess, model_path + '/aoa', global_step=step)      else:        step = 0        while not coord.should_stop():          acc = sess.run(accuracy, feed_dict={dataset: 2})          step += 1          accumulated_accuracy += (acc - accumulated_accuracy) / step          elapsed_time, start_time = time.time() - start_time, time.time()          print(accumulated_accuracy, acc, elapsed_time)    except tf.errors.OutOfRangeError:      print('Done!')    finally:      coord.request_stop()    coord.join(threads)    '''    import pickle    with open('counter.pickle', 'r') as f:      counter = pickle.load(f)    word, _ = zip(*counter.most_common())    '''if __name__ == "__main__":  main()
原创粉丝点击