NER with an LSTM
Source: Internet. Editor: 程序博客网. Time: 2024/06/06 06:47
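1. I downloaded the People's Daily corpus file 199801.txt from the internet. nerTest converts the full-width characters in that file to half-width.
Then nerTest_1 preprocesses the corpus, removing English text, digits, and irregularly written tokens.
2. Split the corpus into train, test, and valid sets (7:2:1).
3. Separate the words from the tags in the corpus, then collect and sort all the words and all the tags, and build tag_to_id and word_to_id.
4. nerTest_3 defines the LSTM model, using tf.contrib.rnn.BasicLSTMCell as the default LSTM cell. If is_training is set and keep_prob is less than 1, each lstm_cell is wrapped in a dropout layer (tf.contrib.rnn.DropoutWrapper()). The cells are then stacked with tf.contrib.rnn.MultiRNNCell, and cell.zero_state sets the initial state to zero. Finally the model is saved.
5. nerTest_4 loads the model and defines the prediction function.
6. Finally, nerTest_5 loads the test set and selects the rows where data['len'] == True (here every row is True). It strips punctuation, digits, English text and so on, segments the text with jieba, and then runs the trained model on each row's segmented words, which gives every word together with its part-of-speech tag.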
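Steps 2 and 3 can be sketched roughly as follows. This is a minimal illustration rather than the project's actual code: it assumes every token has the plain `word/tag` form, and the helper names `parse_line`, `split_dataset` and `build_vocab`, the fixed shuffle seed, and the frequency-based id ordering are all assumptions made for the example. Bracketed compound entities and the date codes in the real corpus are not handled.

```python
import random
from collections import Counter


def parse_line(line):
    """Split a line of 'word/tag' tokens into parallel word and tag lists."""
    words, tags = [], []
    for token in line.split():
        word, sep, tag = token.rpartition("/")
        if sep and word and tag:  # skip tokens that carry no tag
            words.append(word)
            tags.append(tag)
    return words, tags


def split_dataset(sentences, seed=0):
    """Shuffle, then split into train, test and valid parts (7:2:1)."""
    shuffled = list(sentences)
    random.Random(seed).shuffle(shuffled)
    n_train = len(shuffled) * 7 // 10
    n_test = len(shuffled) * 2 // 10
    train = shuffled[:n_train]
    test = shuffled[n_train:n_train + n_test]
    valid = shuffled[n_train + n_test:]
    return train, test, valid


def build_vocab(items):
    """Assign ids by descending frequency, ties broken by the item itself."""
    counts = Counter(items)
    ordered = sorted(counts, key=lambda item: (-counts[item], item))
    return {item: idx for idx, item in enumerate(ordered)}


# Toy stand-in for the corpus: ten copies of one tagged sentence.
lines = ["迈向/v 充满/v 希望/n 的/u 新/a 世纪/n"] * 10
sentences = [parse_line(line) for line in lines]
train, test, valid = split_dataset(sentences)

word_to_id = build_vocab(w for words, _ in sentences for w in words)
tag_to_id = build_vocab(t for _, tags in sentences for t in tags)
```

Building the vocabularies from a sorted list keeps the ids stable between runs, which matters because the same mapping has to be used again at prediction time.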
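Some code snippets follow. I won't paste everything, since it is based on code I read from other people on GitHub.
Full-width to half-width conversion: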
def strQ2B(ustring):
    '''
    :param ustring: text containing full-width characters
    :return: the same text with half-width characters
    '''
    rstring = ""
    for uchar in ustring:
        inside_code = ord(uchar)
        if inside_code == 12288:
            # Full-width (ideographic) space -> ASCII space
            inside_code = 32
        elif 65281 <= inside_code <= 65374:
            # Full-width ASCII variants sit 65248 above their ASCII forms
            inside_code -= 65248
        rstring += chr(inside_code)
    return rstring
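The same mapping can also be written as a translation table for `str.translate`, which avoids building the result one character at a time. This is an alternative sketch, not the code used in the project, and the names `FULL_TO_HALF` and `to_half_width` are made up for the example.

```python
# Full-width ASCII variants U+FF01..U+FF5E sit 0xFEE0 (65248) above their
# ASCII counterparts; the ideographic space U+3000 maps to a plain space.
FULL_TO_HALF = {code: code - 0xFEE0 for code in range(0xFF01, 0xFF5F)}
FULL_TO_HALF[0x3000] = 0x20


def to_half_width(text):
    """Return text with full-width ASCII variants replaced by ASCII."""
    return text.translate(FULL_TO_HALF)
```

Characters outside the table, such as Chinese characters, pass through unchanged.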
Note that when stacking the cells with tf.contrib.rnn.MultiRNNCell, each layer should get its own newly created cell, so write it like this:
def lstm_cell():
    return tf.contrib.rnn.BasicLSTMCell(
        size, forget_bias=0.0, state_is_tuple=True,
        reuse=tf.get_variable_scope().reuse)

attn_cell = lstm_cell
# dropout
if is_training and config.keep_prob < 1.0:
    def attn_cell():
        return tf.contrib.rnn.DropoutWrapper(
            lstm_cell(), output_keep_prob=config.keep_prob)

cell = tf.contrib.rnn.MultiRNNCell(
    [attn_cell() for _ in range(num_layers)], state_is_tuple=True)
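The point of calling a factory function inside the list comprehension is plain Python behaviour and can be shown without TensorFlow. In the sketch below, `Cell` and `make_cell` are hypothetical stand-ins for an LSTM cell and its factory. It assumes the note above is about the difference between repeating one object and creating several.

```python
class Cell:
    """Hypothetical stand-in for an RNN cell that owns its own weights."""


def make_cell():
    return Cell()


# Multiplying a list repeats references to ONE object, so every "layer"
# would be the same cell.
shared = [make_cell()] * 3

# Calling the factory once per layer creates three independent cells.
distinct = [make_cell() for _ in range(3)]

shared_count = len({id(cell) for cell in shared})
distinct_count = len({id(cell) for cell in distinct})
```

Here `shared_count` is 1 and `distinct_count` is 3, which is why the stacked cell is built with `attn_cell()` called once per layer.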
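Next comes storing the model. The snippet below restores the latest checkpoint if one exists, and otherwise initializes fresh parameters: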
# CheckPoint State
ckpt = tf.train.get_checkpoint_state(FLAGS.pos_train_dir)
if ckpt:
    print("Loading model parameters from %s" % ckpt.model_checkpoint_path)
    m.saver.restore(session, tf.train.latest_checkpoint(FLAGS.pos_train_dir))
else:
    print("Created model with fresh parameters.")
    session.run(tf.global_variables_initializer())
Loading the model:
import glob

import numpy as np
import tensorflow as tf

# pos_model and pos_reader are modules from the project; they are not shown here.


class ModelLoader(object):

    def __init__(self, data_path, ckpt_path):
        self.data_path = data_path
        self.ckpt_path = ckpt_path  # the path of the ckpt file, e.g. ./ckpt/zh/pos.ckpt
        self.session = tf.Session()
        self.model = self._init_pos_model(self.session, self.ckpt_path)

    def predict(self, words):
        '''
        :param words: input
        :return: [word, tag]
        '''
        tagging = self._predict_pos_tags(self.session, self.model, words, self.data_path)
        return tagging

    def _init_pos_model(self, session, ckpt_path):
        config = pos_model.get_config()
        # Predict one word at a time
        config.batch_size = 1
        config.num_steps = 1
        with tf.variable_scope("pos_var_scope"):
            model = pos_model.POSTagger(is_training=False, config=config)
        print(ckpt_path + '.data*')
        if len(glob.glob(ckpt_path + '.data*')) > 0:
            # Restore only the variables that belong to the tagger's scope
            all_vars = tf.global_variables()
            model_vars = [k for k in all_vars if k.name.startswith("pos_var_scope")]
            tf.train.Saver(model_vars).restore(session, ckpt_path)
        else:
            print("Model not found")
            session.run(tf.global_variables_initializer())
        return model

    def _predict_pos_tags(self, session, model, words, data_path):
        word_data = pos_reader.sentence_to_word_ids(data_path, words)
        tag_data = [0] * len(word_data)  # dummy targets; only the logits are used
        state = session.run(model.initial_state)
        predict_id = []
        for step, (x, y) in enumerate(pos_reader.iterator(word_data, tag_data, model.batch_size, model.num_steps)):
            fetches = [model.cost, model.final_state, model.logits]
            feed_dict = {}
            feed_dict[model.input_data] = x
            feed_dict[model.targets] = y
            for i, (c, h) in enumerate(model.initial_state):
                feed_dict[c] = state[i].c
                feed_dict[h] = state[i].h
            _, _, logits = session.run(fetches, feed_dict)
            # The tag id with the highest score is the prediction
            predict_id.append(int(np.argmax(logits)))
        predict_tag = pos_reader.word_ids_to_sentence(data_path, predict_id)
        # list(...) so that Python 3 returns a list, as the predict docstring says
        return list(zip(words, predict_tag))
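The decoding step at the end of the prediction loop (take the argmax of each word's logits, then map the id back to a tag) can be illustrated in isolation. The sketch below uses plain Python lists in place of the model's output. The `id_to_tag` mapping and the `decode_tags` helper are hypothetical, since the project does this lookup through `pos_reader.word_ids_to_sentence`.

```python
def argmax(values):
    """Index of the largest value in a sequence."""
    return max(range(len(values)), key=lambda i: values[i])


def decode_tags(words, logits_per_word, id_to_tag):
    """Pair each word with the tag whose logit is highest."""
    predicted_ids = [argmax(logits) for logits in logits_per_word]
    return [(word, id_to_tag[tag_id]) for word, tag_id in zip(words, predicted_ids)]


# Toy example: two words, three possible tags.
id_to_tag = {0: "v", 1: "r", 2: "n"}
logits = [[0.1, 2.0, 0.3], [1.5, 0.2, 0.1]]
tagged = decode_tags(["我", "爱"], logits, id_to_tag)
```

With these toy scores the first word gets tag id 1 and the second gets tag id 0.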
There is too much code to paste it all here. This is more or less my first project, so I'm writing it down as a marker.