tensorflow实战3-利用seq2seq实现一个聊天机器人
来源:互联网 发布:mac彩妆专柜几月打折 编辑:程序博客网 时间:2024/05/29 18:34
具体项目代码详见github:https://github.com/jacksonsshi/chat_rnn
具体介绍网络结构与训练这块
1、seq2seq代码
import sys

import numpy as np
import tensorflow as tf


class Seq2Seq(object):
    """Sequence-to-sequence chatbot model (TensorFlow 1.x, legacy_seq2seq).

    Builds two parameter-sharing embedding-RNN seq2seq graphs:
      * a training graph fed with ground-truth decoder inputs, and
      * a test graph (``feed_previous=True``) that feeds each predicted
        token back in as the next decoder input.

    Args:
        xseq_len: encoder (input) sequence length.
        yseq_len: decoder (output) sequence length.
        xvocab_size: input vocabulary size.
        yvocab_size: output vocabulary size.
        emb_dim: embedding / LSTM hidden dimension.
        num_layers: number of stacked LSTM layers.
        ckpt_path: directory prefix for checkpoints.
        lr: Adam learning rate.
        epochs: number of training iterations (one batch each).
        model_name: checkpoint file name prefix.
    """

    def __init__(self, xseq_len, yseq_len, xvocab_size, yvocab_size,
                 emb_dim, num_layers, ckpt_path,
                 lr=0.0001, epochs=10000, model_name='seq2seq_model'):
        # attach these arguments to self
        self.xseq_len = xseq_len
        self.yseq_len = yseq_len
        self.ckpt_path = ckpt_path
        self.epochs = epochs
        self.model_name = model_name

        # build the graph; anything that needs to be exposed is attached to self
        def __graph__():
            tf.reset_default_graph()
            # encoder inputs: list of token-index placeholders of length xseq_len
            self.enc_ip = [tf.placeholder(shape=[None, ], dtype=tf.int64,
                                          name='ei_{}'.format(t))
                           for t in range(xseq_len)]
            # labels that represent the real outputs.
            # FIX: these were originally also named 'ei_{}', colliding with the
            # encoder placeholders above; use a distinct prefix.
            self.labels = [tf.placeholder(shape=[None, ], dtype=tf.int64,
                                          name='labels_{}'.format(t))
                           for t in range(yseq_len)]
            # decoder inputs: 'GO' + [y1, y2, ... y_{t-1}]
            self.dec_ip = [tf.zeros_like(self.enc_ip[0], dtype=tf.int64,
                                         name='GO')] + self.labels[:-1]

            # dropout keep probability (fed at run time)
            self.keep_prob = tf.placeholder(tf.float32)

            # FIX: build a FRESH cell per layer. The original used
            # [basic_cell] * num_layers, which reuses one cell object for
            # every layer, making all layers share weights (and raising a
            # variable-scope error in later TF 1.x releases).
            def make_cell():
                return tf.contrib.rnn.core_rnn_cell.DropoutWrapper(
                    tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(
                        emb_dim, state_is_tuple=True),
                    output_keep_prob=self.keep_prob)

            # stack cells together: n-layered model
            stacked_lstm = tf.contrib.rnn.core_rnn_cell.MultiRNNCell(
                [make_cell() for _ in range(num_layers)], state_is_tuple=True)

            # share parameters between the training and testing models
            with tf.variable_scope('decoder') as scope:
                # training model: ground-truth decoder inputs
                self.decode_outputs, self.decode_states = \
                    tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq(
                        self.enc_ip, self.dec_ip, stacked_lstm,
                        xvocab_size, yvocab_size, emb_dim)
                scope.reuse_variables()
                # testing model: output of the previous timestep is fed as
                # input to the next timestep
                self.decode_outputs_test, self.decode_states_test = \
                    tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq(
                        self.enc_ip, self.dec_ip, stacked_lstm,
                        xvocab_size, yvocab_size, emb_dim,
                        feed_previous=True)

            # weighted cross-entropy loss over the output sequence
            # TODO: make the per-timestep weights a parameter
            loss_weights = [tf.ones_like(label, dtype=tf.float32)
                            for label in self.labels]
            self.loss = tf.contrib.legacy_seq2seq.sequence_loss(
                self.decode_outputs, self.labels, loss_weights, yvocab_size)
            # train op to minimize the loss
            self.train_op = tf.train.AdamOptimizer(
                learning_rate=lr).minimize(self.loss)

        sys.stdout.write('<log> Building Graph ')
        __graph__()
        sys.stdout.write('</log>')

    # ------------------------------------------------------------------
    # Training and evaluation
    # ------------------------------------------------------------------

    def get_feed(self, X, Y, keep_prob):
        """Map a time-major batch (X, Y) onto the graph placeholders.

        X[t] / Y[t] are the token-index vectors for timestep t; presumably
        shaped (batch_size,) — TODO confirm against the batch generator.
        """
        feed_dict = {self.enc_ip[t]: X[t] for t in range(self.xseq_len)}
        feed_dict.update({self.labels[t]: Y[t] for t in range(self.yseq_len)})
        feed_dict[self.keep_prob] = keep_prob  # dropout keep probability
        return feed_dict

    def train_batch(self, sess, train_batch_gen):
        """Run one optimization step on the next batch; return its loss."""
        batchX, batchY = next(train_batch_gen)  # idiom: next() over .__next__()
        feed_dict = self.get_feed(batchX, batchY, keep_prob=0.5)
        _, loss_v = sess.run([self.train_op, self.loss], feed_dict)
        return loss_v

    def eval_step(self, sess, eval_batch_gen):
        """Evaluate one batch with dropout disabled (keep_prob=1)."""
        batchX, batchY = next(eval_batch_gen)
        feed_dict = self.get_feed(batchX, batchY, keep_prob=1.)
        loss_v, dec_op_v = sess.run(
            [self.loss, self.decode_outputs_test], feed_dict)
        # dec_op_v is a time-major list of per-step logits; transpose to
        # (batch_size, timesteps, vocab)
        dec_op_v = np.array(dec_op_v).transpose([1, 0, 2])
        return loss_v, dec_op_v, batchX, batchY

    def eval_batches(self, sess, eval_batch_gen, num_batches):
        """Return the mean validation loss over `num_batches` batches."""
        losses = [self.eval_step(sess, eval_batch_gen)[0]
                  for _ in range(num_batches)]
        return np.mean(losses)

    def train(self, train_set, valid_set, sess=None):
        """Train for self.epochs iterations, checkpointing every 1000.

        Returns the (possibly newly created) session. Interrupting with
        Ctrl-C stops training cleanly and still returns the session.
        """
        # we need to save the model periodically
        saver = tf.train.Saver()
        if not sess:
            # no session given: create one and init all variables
            sess = tf.Session()
            sess.run(tf.global_variables_initializer())
        sys.stdout.write('\n<log> Training started </log>\n')
        for i in range(self.epochs):
            try:
                self.train_batch(sess, train_set)
                if i % 1000 == 0 and i != 0:
                    # TODO: make checkpoint interval tunable by the user
                    saver.save(sess,
                               self.ckpt_path + self.model_name + '.ckpt',
                               global_step=i)
                    # evaluate to get validation loss
                    val_loss = self.eval_batches(sess, valid_set, 16)  # TODO: tunable
                    print('\nModel saved to disk at iteration #{}'.format(i))
                    print('val loss : {0:.6f}'.format(val_loss))
                    sys.stdout.flush()
            except KeyboardInterrupt:
                # this will most definitely happen, so handle it
                print('Interrupted by user at iteration {}'.format(i))
                self.session = sess
                return sess
        # FIX: the original fell off the end and returned None after a
        # complete run; always hand the session back.
        self.session = sess
        return sess

    def restore_last_session(self):
        """Open a new session restored from the latest checkpoint (if any)."""
        saver = tf.train.Saver()
        sess = tf.Session()
        ckpt = tf.train.get_checkpoint_state(self.ckpt_path)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        return sess

    def predict(self, sess, X):
        """Greedy decode: return argmax token indices, shape (batch, yseq_len)."""
        feed_dict = {self.enc_ip[t]: X[t] for t in range(self.xseq_len)}
        feed_dict[self.keep_prob] = 1.  # no dropout at inference
        dec_op_v = sess.run(self.decode_outputs_test, feed_dict)
        # time-major list -> (batch_size, timesteps, vocab)
        dec_op_v = np.array(dec_op_v).transpose([1, 0, 2])
        # index of the highest-probability token at each timestep
        return np.argmax(dec_op_v, axis=2)
阅读全文
0 0
- tensorflow实战3-利用seq2seq实现一个聊天机器人
- 使用seq2seq模型实现一个聊天机器人
- ChatGirl 一个基于 TensorFlow Seq2Seq 模型的聊天机器人[中文文档]
- RNN聊天机器人与Beam Search [Tensorflow Seq2Seq]
- 利用tensorflow制作一个简单的聊天机器人
- Tensorflow搞一个聊天机器人
- 我用 tensorflow 实现的“一个神经聊天模型”:一个基于深度学习的聊天机器人
- 我用 tensorflow 实现的“一个神经聊天模型”:一个基于深度学习的聊天机器人
- 利用webQQ实现聊天机器人。
- TensorFlow实现seq2seq
- TensorFlow实现seq2seq
- 构建聊天机器人:检索、seq2seq、RL、SeqGAN
- 基于Seq2seq的中文聊天机器人
- Seq2Seq Chatbot 聊天机器人:基于Torch的一个Demo搭建 手札
- TensorFlow 聊天机器人
- tensorflow聊天机器人后续
- 4.利用socket实现聊天机器人
- 用tensorflow实现seq2seq模型
- Leetcode 78. Subsets
- 【头条】新时代的新格局 从《猎场》看中国HR市场之变
- 编译型和解释型、静态类型和动态类型、强类型和弱类型语言
- Hexo+Yilia搭建github Pages个人博客
- SpringBoot的入门搭建(问题集)
- tensorflow实战3-利用seq2seq实现一个聊天机器人
- MOOC_人工智能原理学习笔记1
- Spring-Resource
- Java注解Annotation知识点
- SpringBoot非官方教程 | 第十九篇: 验证表单信息
- 01_Zookeeper_简介
- 初入IT行业
- hexo中npm WARN checkPermissions Missing write access
- C语言数组与指针一本道来