LSTM for NER

Source: Internet  Published: data simulation, textbook  Editor: 程序博客网  Time: 2024/06/06 06:47

1. I downloaded the People's Daily corpus file 199801.txt from the web. nerTest converts the full-width characters in that file to half-width.

Then nerTest_1 preprocesses the corpus by removing English text, digits, and irregular notation.

2. Split the corpus into train, test, and valid sets (7:2:1).

3. Separate the words from the tags in the corpus. Merge and sort all words and all tags, then build tag_to_id and word_to_id.

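4. Next, nerTest_3 defines the LSTM model and uses tf.contrib.rnn.BasicLSTMCell as the default LSTM cell. If the model is in is_training mode and keep_prob is less than 1, each lstm_cell is wrapped in a dropout layer (tf.contrib.rnn.DropoutWrapper()). The cells are then stacked with tf.contrib.rnn.MultiRNNCell, and cell.zero_state sets the initial state to zeros. Finally, the model is saved.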

5. nerTest_4 loads the model and defines the prediction function.

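6. Finally, nerTest_5 loads the test set and keeps the rows where data['len'] == True (here every row is True). It then strips punctuation, digits, English, and similar content, and segments each line with jieba. The trained model predicts on each segmented line, which gives every word together with its part-of-speech tag.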
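The filtering in step 1 can be sketched as follows. This minimal version drops every token whose word part contains English letters or digits. It rests on several assumptions:

- Each corpus line is a run of whitespace-separated word/tag tokens, as in the 199801 file (e.g. 19980101-01-001-001/m  迈向/v).
- Full-width characters were already converted by the strQ2B function in the snippets below.

The post does not define "irregular notation", so the sketch leaves that part out. The name clean_line is hypothetical.

```python
import re

# Any ASCII letter or digit. Assumes full-width characters were already
# converted to half-width (e.g. with strQ2B), so "１９" is now "19".
ASCII_ALNUM = re.compile(r"[A-Za-z0-9]")


def clean_line(line):
    """Keep only word/tag tokens whose word part has no English letters or digits.

    Hypothetical helper: the bracket markup for compound entities in the
    People's Daily corpus (e.g. "[中央/n  人民/n]nt") is not handled.
    """
    kept = []
    for token in line.split():
        word, sep, _tag = token.rpartition("/")
        if not sep or not word:
            continue  # not a word/tag token
        if ASCII_ALNUM.search(word):
            continue  # drops the line ID, numbers, and English words
        kept.append(token)
    return "  ".join(kept)  # rejoin with two spaces, matching the sample line


line = "19980101-01-001-001/m  迈向/v  充满/v  希望/v  1997年/t  NBA/nx"
print(clean_line(line))  # 迈向/v  充满/v  希望/v
```

This sketch drops whole tokens rather than deleting characters inside a word, so no word/tag pair is left half-edited.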
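Step 2 gives only the 7:2:1 ratio. Below is a minimal sketch of such a split. The post does not say whether the sentences were shuffled, so three things here are assumptions: shuffling with a fixed seed, reading the ratio in the order train, test, valid, and the name split_dataset.

```python
import random


def split_dataset(sentences, ratios=(7, 2, 1), seed=42):
    """Shuffle, then split into (train, test, valid) by the given ratio."""
    items = list(sentences)
    random.Random(seed).shuffle(items)  # fixed seed keeps the split reproducible
    total = sum(ratios)
    n_train = len(items) * ratios[0] // total
    n_test = len(items) * ratios[1] // total
    train = items[:n_train]
    test = items[n_train:n_train + n_test]
    valid = items[n_train + n_test:]  # any rounding remainder goes to valid
    return train, test, valid


sentences = [f"s{i}" for i in range(100)]
train, test, valid = split_dataset(sentences)
print(len(train), len(test), len(valid))  # 70 20 10
```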
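For step 3, here is a minimal sketch of separating words from tags and building the two id maps. The post only says the items are "merged and sorted". Sorting by descending frequency, with ties broken by the string itself, is an assumption. split_word_tag and build_vocab are hypothetical names, and the two example lines are made up.

```python
from collections import Counter


def split_word_tag(lines):
    """Split each line's "word/tag" tokens into a word list and a tag list."""
    words, tags = [], []
    for line in lines:
        pairs = [tok.rsplit("/", 1) for tok in line.split() if "/" in tok]
        words.append([w for w, _ in pairs])
        tags.append([t for _, t in pairs])
    return words, tags


def build_vocab(sequences):
    """Assign ids by descending frequency; ties are broken by the item itself."""
    counter = Counter(item for seq in sequences for item in seq)
    ordered = sorted(counter.items(), key=lambda kv: (-kv[1], kv[0]))
    return {item: idx for idx, (item, _) in enumerate(ordered)}


lines = ["迈向/v  充满/v  希望/n  的/u  新/a  世纪/n", "希望/n  的/u  田野/n"]
words, tags = split_word_tag(lines)
word_to_id = build_vocab(words)
tag_to_id = build_vocab(tags)
print(tag_to_id)  # {'n': 0, 'u': 1, 'v': 2, 'a': 3}
```

Many implementations also reserve an extra id for unknown words, so that words unseen in training can still be looked up at test time.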
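For the cleaning in step 6, one simple approach keeps only runs of CJK ideographs and treats everything else (punctuation, digits, English) as a separator. This is an assumption, not necessarily what nerTest_5 does. Each remaining chunk could then be segmented with jieba.lcut before prediction. The regex covers only the basic CJK Unified Ideographs block (U+4E00 to U+9FFF), and chinese_chunks is a hypothetical name.

```python
import re

# Any run of characters outside the basic CJK Unified Ideographs block.
NON_CHINESE = re.compile(r"[^\u4e00-\u9fff]+")


def chinese_chunks(text):
    """Split text on punctuation, digits, English, etc., keeping the Chinese runs."""
    return [chunk for chunk in NON_CHINESE.split(text) if chunk]


print(chinese_chunks("今天，我们去了NBA比赛。"))
# ['今天', '我们去了', '比赛']
# Each chunk could then be segmented, e.g. jieba.lcut(chunk), and passed to predict.
```

Splitting, rather than deleting the removed characters, keeps words on either side of a comma or an English word from being glued together before segmentation.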


Code snippets follow. I won't paste all of it; the code is adapted from someone else's project on GitHub.

Converting full-width characters to half-width:
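```python
def strQ2B(ustring):
    '''
    Convert full-width characters in a string to half-width.
    :param ustring: string that may contain full-width characters
    :return: the string with full-width characters converted to half-width
    '''
    rstring = ""
    for uchar in ustring:
        inside_code = ord(uchar)
        if inside_code == 12288:  # ideographic space U+3000 -> ASCII space
            inside_code = 32
        elif 65281 <= inside_code <= 65374:  # U+FF01..U+FF5E -> U+0021..U+007E
            inside_code -= 65248
        rstring += chr(inside_code)
    return rstring
```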
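Python's standard library offers an alternative to the code-point arithmetic in strQ2B; the original project does not use it. unicodedata.normalize with NFKC also maps full-width letters, digits, punctuation, and the ideographic space to their ASCII forms. However, NFKC changes more than strQ2B does. It expands the ellipsis "…" into three periods, which matters for Chinese text where "……" is common. It also turns characters such as ① into 1.

```python
import unicodedata


def to_half_width(text):
    """Full-width -> half-width via Unicode NFKC compatibility normalization."""
    return unicodedata.normalize("NFKC", text)


print(to_half_width("ＡＢＣ１２３，　ｘ"))  # ABC123, x
print(to_half_width("好……。"))  # 好......。
```

If ellipses and similar characters must survive untouched, the explicit code-point ranges in strQ2B are the safer choice.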


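Note: when stacking the cells with tf.contrib.rnn.MultiRNNCell, build a new cell for every layer (call attn_cell() once per layer) instead of reusing a single cell object. Write it like this: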
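```python
import tensorflow as tf  # TensorFlow 1.x; tf.contrib no longer exists in 2.x

# Inside the model's constructor: size, num_layers, is_training and config
# come from the enclosing model definition.
def lstm_cell():
    return tf.contrib.rnn.BasicLSTMCell(
        size, forget_bias=0.0, state_is_tuple=True,
        reuse=tf.get_variable_scope().reuse)

attn_cell = lstm_cell
# dropout on each cell's output, only while training
if is_training and config.keep_prob < 1.0:
    def attn_cell():
        return tf.contrib.rnn.DropoutWrapper(
            lstm_cell(), output_keep_prob=config.keep_prob)

# a fresh cell object for every layer
cell = tf.contrib.rnn.MultiRNNCell(
    [attn_cell() for _ in range(num_layers)], state_is_tuple=True)
```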
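The list comprehension matters because of ordinary Python list semantics. [cell] * num_layers repeats a reference to one object, while [attn_cell() for _ in range(num_layers)] calls the factory once per layer. From TensorFlow 1.1 on, an RNN cell keeps the variables it creates on first use. Sharing one instance across layers therefore typically fails with a ValueError about reusing the cell. The sketch below uses a hypothetical Cell class, with no TensorFlow, to show the difference.

```python
class Cell:
    """Hypothetical stand-in for an RNN cell that owns its own weights."""

    def __init__(self):
        self.weights = None  # a real cell would create these on first use


def make_cell():
    return Cell()


num_layers = 3
shared = [make_cell()] * num_layers                  # one object, three references
separate = [make_cell() for _ in range(num_layers)]  # three distinct objects

print(len({id(c) for c in shared}), len({id(c) for c in separate}))  # 1 3
```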
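The snippet above configures an LSTM cell, dropout, and layer stacking. The sketch below makes that concrete: a dependency-free version of one time step of a stacked LSTM that mirrors the configuration:

- BasicLSTMCell's gate order (i, j, f, o), with forget_bias=0.0.
- Dropout on each layer's output only, as DropoutWrapper(output_keep_prob=...) does.
- Each layer's output feeding the next layer.
- An all-zero starting state, like cell.zero_state.

Several details are assumptions for illustration: the weight layout (one row per gate unit), the random init, and every function name. TensorFlow stores its kernel differently and computes this with batched tensor ops.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def lstm_step(x, state, weights, bias, forget_bias=0.0):
    """One BasicLSTMCell-style step. weights: 4*hidden rows, len(x)+hidden columns."""
    c, h = state
    hidden = len(h)
    inputs = x + h  # concatenate the input with the previous output
    z = [sum(w * v for w, v in zip(row, inputs)) + b for row, b in zip(weights, bias)]
    i, j, f, o = (z[k * hidden:(k + 1) * hidden] for k in range(4))
    new_c = [ck * sigmoid(fk + forget_bias) + sigmoid(ik) * math.tanh(jk)
             for ck, ik, jk, fk in zip(c, i, j, f)]
    new_h = [math.tanh(ck) * sigmoid(ok) for ck, ok in zip(new_c, o)]
    return new_h, (new_c, new_h)


def dropout(values, keep_prob, rng):
    """Inverted dropout: keep each unit with prob keep_prob and rescale by 1/keep_prob."""
    return [v / keep_prob if rng.random() < keep_prob else 0.0 for v in values]


def init_layer(input_size, hidden, rng, scale=0.1):
    weights = [[rng.uniform(-scale, scale) for _ in range(input_size + hidden)]
               for _ in range(4 * hidden)]
    return weights, [0.0] * (4 * hidden)


def zero_state(num_layers, hidden):
    """Per-layer (c, h) pairs filled with zeros, like cell.zero_state."""
    return [([0.0] * hidden, [0.0] * hidden) for _ in range(num_layers)]


def stacked_step(x, states, layers, is_training=False, keep_prob=1.0, rng=None):
    """Run one time step through all layers; return the top output and new states."""
    new_states = []
    inp = x
    for (weights, bias), state in zip(layers, states):
        out, new_state = lstm_step(inp, state, weights, bias)
        if is_training and keep_prob < 1.0:
            out = dropout(out, keep_prob, rng)  # output only; the state is not dropped
        new_states.append(new_state)
        inp = out  # this layer's output is the next layer's input
    return inp, new_states


rng = random.Random(0)
input_size, hidden, num_layers = 2, 3, 2
layers = [init_layer(input_size if k == 0 else hidden, hidden, rng) for k in range(num_layers)]
states = zero_state(num_layers, hidden)
for x in ([0.5, -0.2], [0.1, 0.3]):  # two time steps, carrying the state forward
    out, states = stacked_step(x, states, layers)
print(out)
```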

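Then checkpoint handling: if a checkpoint exists in FLAGS.pos_train_dir, its parameters are restored; otherwise the model starts from freshly initialized parameters: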

```python
# CheckPoint State
ckpt = tf.train.get_checkpoint_state(FLAGS.pos_train_dir)
if ckpt:
    print("Loading model parameters from %s" % ckpt.model_checkpoint_path)
    m.saver.restore(session, tf.train.latest_checkpoint(FLAGS.pos_train_dir))
else:
    print("Created model with fresh parameters.")
    session.run(tf.global_variables_initializer())
```

Loading the model:
```python
import glob

import numpy as np
import tensorflow as tf  # TensorFlow 1.x

# pos_model and pos_reader are modules from the GitHub project this code was
# adapted from (not shown in this post); adjust the import path to your layout.
import pos_model
import pos_reader


class ModelLoader(object):
    def __init__(self, data_path, ckpt_path):
        self.data_path = data_path
        self.ckpt_path = ckpt_path  # the path of the ckpt file, e.g. ./ckpt/zh/pos.ckpt
        self.session = tf.Session()
        self.model = self._init_pos_model(self.session, self.ckpt_path)

    def predict(self, words):
        '''
        :param words: list of segmented words
        :return: list of (word, tag) pairs
        '''
        tagging = self._predict_pos_tags(self.session, self.model, words, self.data_path)
        return tagging

    def _init_pos_model(self, session, ckpt_path):
        config = pos_model.get_config()
        config.batch_size = 1  # predict one word per step
        config.num_steps = 1
        with tf.variable_scope("pos_var_scope"):
            model = pos_model.POSTagger(is_training=False, config=config)
        print(ckpt_path + '.data*')
        if len(glob.glob(ckpt_path + '.data*')) > 0:
            # restore only the variables created under pos_var_scope
            all_vars = tf.global_variables()
            model_vars = [k for k in all_vars if k.name.startswith("pos_var_scope")]
            tf.train.Saver(model_vars).restore(session, ckpt_path)
        else:
            print("Model not found")
            session.run(tf.global_variables_initializer())
        return model

    def _predict_pos_tags(self, session, model, words, data_path):
        word_data = pos_reader.sentence_to_word_ids(data_path, words)
        tag_data = [0] * len(word_data)  # dummy targets; only the logits are used
        state = session.run(model.initial_state)
        predict_id = []
        for step, (x, y) in enumerate(pos_reader.iterator(word_data, tag_data, model.batch_size, model.num_steps)):
            fetches = [model.cost, model.final_state, model.logits]
            feed_dict = {}
            feed_dict[model.input_data] = x
            feed_dict[model.targets] = y
            for i, (c, h) in enumerate(model.initial_state):
                feed_dict[c] = state[i].c
                feed_dict[h] = state[i].h
            # keep final_state so each word is tagged with the context of the previous ones
            _, state, logits = session.run(fetches, feed_dict)
            predict_id.append(int(np.argmax(logits)))
        predict_tag = pos_reader.word_ids_to_sentence(data_path, predict_id)
        return list(zip(words, predict_tag))  # zip is lazy in Python 3
```
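At each step, the prediction loop above takes np.argmax over the logits, then maps the ids back to tags with pos_reader.word_ids_to_sentence. Below is a dependency-free sketch of that decoding step. It assumes a tag_to_id dict like the one built in step 3 and logits given as plain lists; decode_tags is a hypothetical name.

```python
def decode_tags(logits_per_step, tag_to_id):
    """Argmax each step's logits and map the winning id back to its tag."""
    id_to_tag = {idx: tag for tag, idx in tag_to_id.items()}
    tags = []
    for logits in logits_per_step:
        best = max(range(len(logits)), key=logits.__getitem__)  # first max on ties, like np.argmax
        tags.append(id_to_tag[best])
    return tags


words = ["迈向", "新", "世纪"]
logits = [[2.0, 0.1, 0.3], [0.0, 0.2, 1.5], [0.1, 1.9, 0.4]]
tag_to_id = {"v": 0, "n": 1, "a": 2}
print(list(zip(words, decode_tags(logits, tag_to_id))))
# [('迈向', 'v'), ('新', 'a'), ('世纪', 'n')]
```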


There's too much code to paste all of it. This is more or less my first project, so I'm noting it down here.