fasttext初步使用

来源:互联网 发布:淘宝客网站要备案吗 编辑:程序博客网 时间:2024/05/17 22:27

转载自:

http://blog.csdn.net/lxg0807/article/details/52960072#comments


训练数据和测试数据来自网盘:

https://pan.baidu.com/s/1jH7wyOY

https://pan.baidu.com/s/1slGlPgx



训练以上数据

# _*_coding:utf-8 _*_import logginglogging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)import fasttext#训练模型classifier = fasttext.supervised("news_fasttext_train.txt","news_fasttext.model",label_prefix="__label__")



进行测试:

# -*- coding:utf-8 -*-import logginglogging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)import fasttext#load训练好的模型classifier = fasttext.load_model('news_fasttext.model.bin', label_prefix='__label__')result = classifier.test("news_fasttext_test.txt")print result.precisionprint result.recall

注意每次训练的模型都有不同,所以测试的结果大概是0.87~0.92左右


进行最终评价:

# -*- coding:utf-8 -*-import logginglogging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)import fasttext#load训练好的模型classifier = fasttext.load_model('news_fasttext.model.bin', label_prefix='__label__')result = classifier.test("news_fasttext_test.txt")print result.precisionprint result.recalllabels_right = []texts = []with open("news_fasttext_test.txt") as fr:    lines = fr.readlines()for line in lines:    labels_right.append(line.split("\t")[1].rstrip().replace("__label__",""))    texts.append(line.split("\t")[0].decode("utf-8"))#     print labels#     print texts#     breaklabels_predict = [e[0] for e in classifier.predict(texts)] #预测输出结果为二维形式# print labels_predicttext_labels = list(set(labels_right))text_predict_labels = list(set(labels_predict))print text_predict_labelsprint text_labelsA = dict.fromkeys(text_labels,0)  #预测正确的各个类的数目B = dict.fromkeys(text_labels,0)   #测试数据集中各个类的数目C = dict.fromkeys(text_predict_labels,0) #预测结果中各个类的数目for i in range(0,len(labels_right)):    B[labels_right[i]] += 1    C[labels_predict[i]] += 1    if labels_right[i] == labels_predict[i]:        A[labels_right[i]] += 1print Aprint Bprint C#计算准确率,召回率,F值for key in B:    p = float(A[key]) / float(B[key])    r = float(A[key]) / float(C[key])    f = p * r * 2 / (p + r)    print "%s:\tp:%f\t%fr:\t%f" % (key,p,r,f)

之所以搞这么一出,是因为fasttext提供的p值(准确率)和r值(召回率)只是针对所有结果的,而不是针对各个类别分别进行计算p值(准确率)和r值(召回率)的,所以该作者自己写了计算方法。




输出结果:

[u'affairs', u'fashion', u'lottery', u'house', u'sports', u'game', u'economic', u'ent', u'edu', u'home', u'stock', u'constellation', u'science']['affairs', 'fashion', 'house', 'sports', 'game', 'economic', 'ent', 'edu', 'home', 'stock', 'science']{'science': 8921, 'affairs': 8544, 'fashion': 2148, 'house': 9572, 'sports': 9814, 'game': 9389, 'economic': 9492, 'ent': 9660, 'edu': 9671, 'home': 8027, 'stock': 8525}{'science': 10000, 'affairs': 10000, 'fashion': 3369, 'house': 10000, 'sports': 10000, 'game': 10000, 'economic': 10000, 'ent': 10000, 'edu': 10000, 'home': 10000, 'stock': 10000}{u'science': 10311, u'affairs': 8953, u'fashion': 2176, u'lottery': 28, u'house': 10502, u'sports': 10288, u'game': 10182, u'economic': 11087, u'ent': 10940, u'edu': 10991, u'home': 8171, u'constellation': 466, u'stock': 9274}science:p:0.8921000.865193r:0.878440affairs:p:0.8544000.954317r:0.901599fashion:p:0.6375780.987132r:0.774752house:p:0.9572000.911445r:0.933763sports:p:0.9814000.953927r:0.967468game:p:0.9389000.922117r:0.930433economic:p:0.9492000.856138r:0.900270ent:p:0.9660000.882998r:0.922636edu:p:0.9671000.879902r:0.921443home:p:0.8027000.982377r:0.883496stock:p:0.8525000.919237r:0.884611



原创粉丝点击