tensorflow实战3-利用seq2seq实现一个聊天机器人

来源:互联网 发布:mac彩妆专柜几月打折 编辑:程序博客网 时间:2024/05/29 18:34

具体项目代码详见github:https://github.com/jacksonsshi/chat_rnn

具体介绍网络结构与训练这块
1、seq2seq代码

import tensorflow as tfimport numpy as npimport sysclass Seq2Seq(object):    def __init__(self, xseq_len, yseq_len,             xvocab_size, yvocab_size,            emb_dim, num_layers, ckpt_path,            lr=0.0001,             epochs=10000, model_name='seq2seq_model'):        # attach these arguments to self        self.xseq_len = xseq_len        self.yseq_len = yseq_len        self.ckpt_path = ckpt_path        self.epochs = epochs        self.model_name = model_name        # build thy graph        #  attach any part of the graph that needs to be exposed, to the self        def __graph__():            # placeholders            tf.reset_default_graph()            #  encoder inputs : list of indices of length xseq_len            self.enc_ip = [ tf.placeholder(shape=[None,],                             dtype=tf.int64,                             name='ei_{}'.format(t)) for t in range(xseq_len) ]            #  labels that represent the real outputs            self.labels = [ tf.placeholder(shape=[None,],                             dtype=tf.int64,                             name='ei_{}'.format(t)) for t in range(yseq_len) ]            #  decoder inputs : 'GO' + [ y1, y2, ... y_t-1 ]            self.dec_ip = [ tf.zeros_like(self.enc_ip[0], dtype=tf.int64, name='GO') ] + self.labels[:-1]            # Basic LSTM cell wrapped in Dropout Wrapper            self.keep_prob = tf.placeholder(tf.float32)            # define the basic cell            basic_cell = tf.contrib.rnn.core_rnn_cell.DropoutWrapper(                    tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(emb_dim, state_is_tuple=True),                    output_keep_prob=self.keep_prob)            # stack cells together : n layered model            stacked_lstm = tf.contrib.rnn.core_rnn_cell.MultiRNNCell([basic_cell]*num_layers, state_is_tuple=True)            # for parameter sharing between training model            #  and testing model            with tf.variable_scope('decoder') as scope:                # build the seq2seq model                 #  inputs : encoder, decoder inputs, LSTM cell type, vocabulary sizes, embedding dimensions                self.decode_outputs, self.decode_states = tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq(self.enc_ip,self.dec_ip, stacked_lstm,                                                    xvocab_size, yvocab_size, emb_dim)                # share parameters                scope.reuse_variables()                # testing model, where output of previous timestep is fed as input                 #  to the next timestep                self.decode_outputs_test, self.decode_states_test = tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq(                    self.enc_ip, self.dec_ip, stacked_lstm, xvocab_size, yvocab_size,emb_dim,                    feed_previous=True)            # now, for training,            #  build loss function            # weighted loss            #  TODO : add parameter hint            loss_weights = [ tf.ones_like(label, dtype=tf.float32) for label in self.labels ]            self.loss = tf.contrib.legacy_seq2seq.sequence_loss(self.decode_outputs, self.labels, loss_weights, yvocab_size)            # train op to minimize the loss            self.train_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(self.loss)        sys.stdout.write('<log> Building Graph ')        # build comput graph        __graph__()        sys.stdout.write('</log>')    '''        Training and Evaluation    '''    # get the feed dictionary    def get_feed(self, X, Y, keep_prob):        feed_dict = {self.enc_ip[t]: X[t] for t in range(self.xseq_len)}        feed_dict.update({self.labels[t]: Y[t] for t in range(self.yseq_len)})        feed_dict[self.keep_prob] = keep_prob # dropout prob        return feed_dict    # run one batch for training    def train_batch(self, sess, train_batch_gen):        # get batches        batchX, batchY = train_batch_gen.__next__()        # build feed        feed_dict = self.get_feed(batchX, batchY, keep_prob=0.5)        _, loss_v = sess.run([self.train_op, self.loss], feed_dict)        return loss_v    def eval_step(self, sess, eval_batch_gen):        # get batches        batchX, batchY = eval_batch_gen.__next__()        # build feed        feed_dict = self.get_feed(batchX, batchY, keep_prob=1.)        loss_v, dec_op_v = sess.run([self.loss, self.decode_outputs_test], feed_dict)        # dec_op_v is a list; also need to transpose 0,1 indices         #  (interchange batch_size and timesteps dimensions        dec_op_v = np.array(dec_op_v).transpose([1,0,2])        return loss_v, dec_op_v, batchX, batchY    # evaluate 'num_batches' batches    def eval_batches(self, sess, eval_batch_gen, num_batches):        losses = []        for i in range(num_batches):            loss_v, dec_op_v, batchX, batchY = self.eval_step(sess, eval_batch_gen)            losses.append(loss_v)        return np.mean(losses)    # finally the train function that    #  runs the train_op in a session    #   evaluates on valid set periodically    #    prints statistics    def train(self, train_set, valid_set, sess=None ):        # we need to save the model periodically        saver = tf.train.Saver()        # if no session is given        if not sess:            # create a session            sess = tf.Session()            # init all variables            sess.run(tf.global_variables_initializer())        sys.stdout.write('\n<log> Training started </log>\n')        # run M epochs        for i in range(self.epochs):            try:                self.train_batch(sess, train_set)                if i  % 1000  == 0 and i != 0: # TODO : make this tunable by the user                    # save model to disk                    saver.save(sess, self.ckpt_path + self.model_name + '.ckpt', global_step=i)                    # evaluate to get validation loss                    val_loss = self.eval_batches(sess, valid_set, 16) # TODO : and this                    # print stats                    print('\nModel saved to disk at iteration #{}'.format(i))                    print('val   loss : {0:.6f}'.format(val_loss))                    sys.stdout.flush()            except KeyboardInterrupt: # this will most definitely happen, so handle it                print('Interrupted by user at iteration {}'.format(i))                self.session = sess                return sess    def restore_last_session(self):        saver = tf.train.Saver()        # create a session        sess = tf.Session()        # get checkpoint state        ckpt = tf.train.get_checkpoint_state(self.ckpt_path)        # restore session        if ckpt and ckpt.model_checkpoint_path:            saver.restore(sess, ckpt.model_checkpoint_path)        # return to user        return sess    # prediction    def predict(self, sess, X):        feed_dict = {self.enc_ip[t]: X[t] for t in range(self.xseq_len)}        feed_dict[self.keep_prob] = 1.        dec_op_v = sess.run(self.decode_outputs_test, feed_dict)        # dec_op_v is a list; also need to transpose 0,1 indices         #  (interchange batch_size and timesteps dimensions        dec_op_v = np.array(dec_op_v).transpose([1,0,2])        # return the index of item with highest probability        return np.argmax(dec_op_v, axis=2)
阅读全文
0 0
原创粉丝点击