RNN代码解读之char-RNN with TensorFlow(model.py)

来源:互联网 发布:js正则表达式校验邮箱 编辑:程序博客网 时间:2024/06/05 08:34

此工程解读链接(建议按顺序阅读):
RNN代码解读之char-RNN with TensorFlow(model.py)
RNN代码解读之char-RNN with TensorFlow(train.py)
RNN代码解读之char-RNN with TensorFlow(util.py)
RNN代码解读之char-RNN with TensorFlow(sample.py)

最近一直在学习RNN的相关知识,个人认为相比于CNN各种模型在detection/classification/segmentation等方面超人的表现,RNN还有很长的一段路要走,毕竟现在的nlp模型单从output质量上来看只是差强人意,要和人相比还有一段距离。CNN+RNN的任务比如image caption更是有很多有待研究和提高的地方。

关于对CNN和RNN相关内容的学习和探讨,我将会在近期更新对一些经典论文的解读以及自己的看法,届时欢迎大家给予指导。

当然,CS231n中有一句名言“Don’t think too hard, just cross your fingers.” 想法还是要落地才可以看到成果,那么我们今天就一起来看一下大牛Adrew Karparthy的char-RNN模型,AK使用lua基于torch写的,git上已经有人及时的复现了TensorFlow with Python版本(https://github.com/sherjilozair/char-rnn-tensorflow)。

网上已经有很多相关的解析了,但大部分只是针对model进行解释,这对于整体模型的宏观理解以及TensorFlow的学习都是很不利的。因此,这里我会给出自己对所有代码的理解,若有错误欢迎及时指正。

这一个版本的代码共分为四个模块:model.py,train.py, util.py以及sample.py,我们将按照这个顺序,分四篇博文对四个模块进行梳理。我在代码中对所有我认为重要的地方都写了注释,有的部分甚至每一行都有明确的注释,但难免有的基本方法会让人产生疑惑。面对这种问题,我强烈建议大家一边debug一步一步的执行看结果,一边百度或者google。这样梳理一遍代码一定会全身舒畅,豁然开朗,感觉打开了新世界的大门,对于RNN模型的TensorFlow实现也会更有把握。

当然理解这一个工程并不是我们的终极目的,针对后面跟新的paper中提到的有创新的方法,我们也会再此模型的基础上进一步实现,走上我们的科研之路。

废话说太多了,下面我们先开始看最重点的model.py
注意:这里注释解释的只是训练过程中的理解,在infer过程中batch=1,sequence=1,大体理解没有差别,但是具体思想还需要大家到时候再推敲推敲。此外,此class中的sample方法这一节不讨论,到第四节sample.py的时候一并讨论。

#-*-coding:utf-8-*-import tensorflow as tffrom tensorflow.python.ops import rnn_cellfrom tensorflow.python.ops import seq2seqimport numpy as npclass Model():    def __init__(self, args, infer=False):        self.args = args        #在测试状态下(inference)才用如下选项        if infer:            args.batch_size = 1            args.seq_length = 1        #几种备选的rnn类型        if args.model == 'rnn':            cell_fn = rnn_cell.BasicRNNCell        elif args.model == 'gru':            cell_fn = rnn_cell.GRUCell        elif args.model == 'lstm':            cell_fn = rnn_cell.BasicLSTMCell        else:            raise Exception("model type not supported: {}".format(args.model))        #固定格式是例:cell = rnn_cell.GRUCelll(rnn_size)        #rnn_size指的是每个rnn单元中的神经元个数(虽然RNN途中只有一个圆圈代表,但这个圆圈代表了rnn_size个神经元)        #这里state_is_tuple根据官网解释,每个cell返回的h和c状态是储存在一个list里还是两个tuple里,官网建议设置为true        cell = cell_fn(args.rnn_size, state_is_tuple=True)        #固定格式,有几层rnn        self.cell = cell = rnn_cell.MultiRNNCell([cell] * args.num_layers, state_is_tuple=True)        #input_data&target(标签)格式:[batch_size, seq_length]        self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])        self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])        #cell的初始状态设为0,因为在前面设置cell时,cell_size已经设置好了,因此这里只需给出batch_size即可        #(一个batch内有batch_size个sequence的输入)        self.initial_state = cell.zero_state(args.batch_size, tf.float32)        #rnnlm = recurrent neural network language model        #variable_scope就是变量的作用域        with tf.variable_scope('rnnlm'):            #softmax层的参数            softmax_w = tf.get_variable("softmax_w", [args.rnn_size, args.vocab_size])            softmax_b = tf.get_variable("softmax_b", [args.vocab_size])            with tf.device("/cpu:0"):                #推荐使用tf.get_variable而不是tf.variable                #embedding矩阵是将输入转换到了cell_size,因此这样的大小设置                embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size])                #关于tf.nn.embedding_lookup(embedding, self.input_data):                #   调用tf.nn.embedding_lookup,索引与train_dataset对应的向量,相当于用train_dataset作为一个id,去检索矩阵中与这个id对应的embedding                #将第三个参数,在第1维度,切成seq_length长短的片段                #embeddinglookup得到的look_up尺寸是[batch_size, seq_length, rnn_size],这里是[50,50,128]                look_up = tf.nn.embedding_lookup(embedding, self.input_data)                #将上面的[50,50,128]切开,得到50个[50,1,128]的inputs                inputs = tf.split(1, args.seq_length, look_up)                #之后将 1 squeeze掉,50个[50,128]                inputs = [tf.squeeze(input_, [1]) for input_ in inputs]        #在infer的时候方便查看        def loop(prev, _):            prev = tf.matmul(prev, softmax_w) + softmax_b            prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))            return tf.nn.embedding_lookup(embedding, prev_symbol)        #seq2seq.rnn_decoder基于schedule sampling实现,相当于一个黑盒子,可以直接调用        #得到的两个参数shape均为50个50*128的张量,和输入是一样的        outputs, last_state = seq2seq.rnn_decoder(inputs,                                                  self.initial_state, cell,                                                  loop_function=loop if infer else None,                                                  scope='rnnlm')        #将outputsreshape在一起,形成[2500,128]的张量        output = tf.reshape(tf.concat(1, outputs), [-1, args.rnn_size])        #logits和probs的大小都是[2500,65]([2500,128]*[128,65])        self.logits = tf.matmul(output, softmax_w) + softmax_b        self.probs = tf.nn.softmax(self.logits)        #得到length为2500的loss(即每一个batch的sequence中的每一个单词输入,都会最终产生一个loss,50*50=2500)        loss = seq2seq.sequence_loss_by_example([self.logits],                [tf.reshape(self.targets, [-1])],                [tf.ones([args.batch_size * args.seq_length])],                args.vocab_size)        #得到一个batch的cost后面用于求梯度        self.cost = tf.reduce_sum(loss) / args.batch_size / args.seq_length        #将state转换一下,便于下一次继续训练        self.final_state = last_state        #因为学习率不需要BPTT更新,因此trainable=False        #具体的learning_rate是由train.py中args参数传过来的,这里只是初始化设了一个0        self.lr = tf.Variable(0.0, trainable=False)        #返回了包括前面的softmax_w/softmax_b/embedding等所有变量        tvars = tf.trainable_variables()        #求grads要使用clip避免梯度爆炸,这里设置的阈值是5(见args)        grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars),                args.grad_clip)        #使用adam优化方法        optimizer = tf.train.AdamOptimizer(self.lr)        #参考tensorflow手册,        # 将计算出的梯度应用到变量上,是函数minimize()的第二部分,返回一个应用指定的梯度的操作Operation,对global_step做自增操作        self.train_op = optimizer.apply_gradients(zip(grads, tvars))

以上就是对于model.py的代码分析,总体来说就是“模型定义+参数设置+优化”的思路,如果有哪里出错还望大家多多指教啦~!

参考资料:
http://blog.csdn.net/mydear_11000/article/details/52776295
https://github.com/sherjilozair/char-rnn-tensorflow
http://www.tensorfly.cn/tfdoc/api_docs/python/constant_op.html#truncated_normal

0 0
原创粉丝点击