Mahout Chinese Classification (2)


Original post: http://zaumreit.me/blog/2013/12/15/mahout-chinese-classification/

===================

The previous post, Mahout Chinese Classification, only covered how to deploy Mahout and how to train and test a model; it did not cover how to classify new data. This post explains how to classify new documents.

Mahout does not seem to provide a command-line tool for vectorizing new data. Fortunately, generating the vectors from the training data in the previous post also produced the `dictionary`, `df-count`, and `labelindex` files (stored on HDFS), and the Mahout API provides methods for reading them, so you can write your own code to classify new documents.
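As a quick sketch of what reading those files looks like (the paths are placeholders for wherever you keep the files, and imports are omitted; the full code at the end of the post does exactly this), the label index and dictionary can be loaded with `BayesUtils.readLabelIndex` and `SequenceFileIterable`:

```java
Configuration conf = new Configuration();

// label index: category id -> label name
Map<Integer, String> labels =
        BayesUtils.readLabelIndex(conf, new Path("./mahout/labelindex"));

// dictionary: term -> term id (a sequence file of Text/IntWritable pairs)
Map<String, Integer> dictionary = new HashMap<String, Integer>();
for (Pair<Text, IntWritable> pair :
        new SequenceFileIterable<Text, IntWritable>(
                new Path("./mahout/vectors/dictionary.file-0"), true, conf)) {
    dictionary.put(pair.getFirst().toString(), pair.getSecond().get());
}
```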

Building the TF-IDF Vector

In the previous post I trained the model on `tfidf-vectors` and also used `tfidf-vectors` for testing, so to apply the model, a new document has to be represented as a `tfidf-vector` as well.

The class `TFIDF` provides a method `calculate(int tf, int df, int length, int numDocs)` for computing a term's tf-idf value. Its four parameters are:
tf: number of times the term occurs in the new document
df: number of training documents that contain the term
length: total number of terms in the new document
numDocs: total number of documents in the training set

In Mahout's source, the `length` parameter is not used when computing the tf-idf value:

```java
public class TFIDF implements Weight {

  private final DefaultSimilarity sim = new DefaultSimilarity();

  @Override
  public double calculate(int tf, int df, int length, int numDocs) {
    // ignore length
    return sim.tf(tf) * sim.idf(df, numDocs);
  }
}
```
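For instance, with purely illustrative numbers (a term that occurs 3 times in a new document of 120 terms and appears in 94 of the 8361 training documents, matching the `df-count` sample below), the call would look like this; with Lucene's `DefaultSimilarity` the result is `sqrt(tf) * (log(numDocs / (df + 1)) + 1)`:

```java
// illustrative numbers only: tf = 3, df = 94, length = 120 (ignored), numDocs = 8361
TFIDF tfidf = new TFIDF();
double weight = tfidf.calculate(3, 94, 120, 8361);
```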

tf and length can be counted from the new document; df and numDocs have to be read from the `df-count` file. `df-count` contains a key/value pair for every term, where the key is the term ID that appears as the value in the `dictionary` file. In the `df-count` file, the entry with key = -1 holds the total number of documents in the training set.

In the `dictionary` file the key is the term and the value is its ID (I did not remove stop words or punctuation):

```
Key: !: Value: 0
Key: ": Value: 1
Key: #: Value: 2
Key: $: Value: 3
Key: %: Value: 4
Key: ': Value: 5
....
```

The `df-count` file, which is really the `part-r-00000` file under the `df-count` directory, has the term ID as its key and the number of documents containing that term as its value:

```
Key: -1: Value: 8361
Key: 0: Value: 154
Key: 1: Value: 94
Key: 2: Value: 6
Key: 3: Value: 3
Key: 4: Value: 10
....
```
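Once these two files are loaded into maps (as the `readDictionnary` and `readDocumentFrequency` helpers in the full code below do), looking up the numbers that `calculate()` needs goes roughly like this, where `word` stands for one token of the new document:

```java
// dictionary: term -> id; documentFrequency: id -> document frequency (built from the two files above)
Integer wordId = dictionary.get(word);                  // null if the term never occurred in the training set
long df = documentFrequency.get(wordId);                // number of training documents containing the term
int numDocs = documentFrequency.get(-1).intValue();     // the entry with key -1 holds the training-set size
```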

With this information we can compute the tf-idf values and build the vector. The vector used here is Mahout's `Vector` class; each element is `id : tf-idf`:

```java
Vector vector = new RandomAccessSparseVector(10000);
vector.setQuick(wordId, tfIdfValue); // set the tf-idf value of one term
```

Classification

Use the `classifyFull()` method of the `StandardNaiveBayesClassifier` class to classify the vector.

```java
// load the model from the model file
NaiveBayesModel model = NaiveBayesModel.materialize(new Path(modelPath), configuration);
// initialize the classifier with the model
StandardNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(model);
// classifyFull returns the vector's score under every label; the highest score is the final class
Vector resultVector = classifier.classifyFull(vector);
```
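To turn that score vector into a label, take the index with the highest score. A minimal way to do it (using `maxValueIndex()` on Mahout's `Vector`; `labels` is the map read from `labelindex`, and the full code below does the same thing with an explicit loop):

```java
int bestCategoryId = resultVector.maxValueIndex();      // index of the highest score
System.out.println("predicted label: " + labels.get(bestCategoryId));
```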

You can also write the computed tf-idf vectors to a sequence file and then use the command-line tool `mahout testnb` to see how well the Naive Bayes classifier performs on the new documents. The code for writing to a sequence file is below; a sequence file likewise consists of key/value pairs.

```java
// fs is a Hadoop FileSystem, e.g. FileSystem.get(configuration)
Writer writer = new SequenceFile.Writer(fs, configuration, new Path(outputFileName), Text.class, VectorWritable.class);
Text key = new Text();
VectorWritable value = new VectorWritable();
key.set("/" + label + "/" + inputFileName); // label is the expected category; inputFileName identifies the vector
value.set(vector); // value is the tf-idf vector of the input document
writer.append(key, value);
writer.close();
```

That writes the vector out to a sequence file.

Full Code

Here is the complete code I use for classification. I first copied the `model`, `labelindex`, `df-count`, and `dictionary` files down from HDFS into the project directory; you can also read them directly from HDFS:

```java
import java.io.StringReader;
import java.util.HashMap;
import java.util.Map;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.core.WhitespaceAnalyzer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.util.Version;
import org.apache.mahout.classifier.naivebayes.BayesUtils;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.Vector.Element;
import org.apache.mahout.vectorizer.TFIDF;

import com.google.common.collect.ConcurrentHashMultiset;
import com.google.common.collect.Multiset;

public class Classifier {

    public static Map<String, Integer> readDictionnary(Configuration conf, Path dictionnaryPath) {
        Map<String, Integer> dictionnary = new HashMap<String, Integer>();
        for (Pair<Text, IntWritable> pair : new SequenceFileIterable<Text, IntWritable>(dictionnaryPath, true, conf)) {
            dictionnary.put(pair.getFirst().toString(), pair.getSecond().get());
        }
        return dictionnary;
    }

    public static Map<Integer, Long> readDocumentFrequency(Configuration conf, Path documentFrequencyPath) {
        Map<Integer, Long> documentFrequency = new HashMap<Integer, Long>();
        for (Pair<IntWritable, LongWritable> pair : new SequenceFileIterable<IntWritable, LongWritable>(documentFrequencyPath, true, conf)) {
            documentFrequency.put(pair.getFirst().get(), pair.getSecond().get());
        }
        return documentFrequency;
    }

    public static void main(String[] args) throws Exception {
        // paths to the files described above
        String modelPath = "./mahout/model";
        String labelIndexPath = "./mahout/labelindex";
        String dictionaryPath = "./mahout/vectors/dictionary.file-0";
        String documentFrequencyPath = "./mahout/vectors/df-count/part-r-00000";

        Configuration configuration = new Configuration();
        // HDFS configuration (uncomment to read the files directly from HDFS)
        //configuration.set("fs.default.name", "hdfs://172.21.1.129:9000");
        //configuration.set("mapred.job.tracker", "172.21.1.129:9001");

        // load the model from the model file
        NaiveBayesModel model = NaiveBayesModel.materialize(new Path(modelPath), configuration);
        // initialize the classifier with the model
        StandardNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(model);

        // read labelindex, dictionary and df-count
        Map<Integer, String> labels = BayesUtils.readLabelIndex(configuration, new Path(labelIndexPath));
        Map<String, Integer> dictionary = readDictionnary(configuration, new Path(dictionaryPath));
        Map<Integer, Long> documentFrequency = readDocumentFrequency(configuration, new Path(documentFrequencyPath));

        // Analyzer for tokenizing the text. I had already segmented the files with FudanNLP,
        // so the input is whitespace-separated and WhitespaceAnalyzer is enough;
        // any other analyzer can be used instead. The Lucene version is 4.3.0.
        Analyzer analyzer = new WhitespaceAnalyzer(Version.LUCENE_43);

        // number of documents in the training set
        int documentCount = documentFrequency.get(-1).intValue();

        // text to classify
        String content = "";

        Multiset<String> words = ConcurrentHashMultiset.create();
        TokenStream ts = analyzer.tokenStream("text", new StringReader(content));
        CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class);
        ts.reset();
        int wordCount = 0;
        // count the terms of the new document that appear in the dictionary
        while (ts.incrementToken()) {
            if (termAtt.length() > 0) {
                String word = ts.getAttribute(CharTermAttribute.class).toString();
                Integer wordId = dictionary.get(word);
                if (wordId != null) {
                    words.add(word);
                    wordCount++;
                }
            }
        }

        // compute TF-IDF values and build the Vector
        Vector vector = new RandomAccessSparseVector(10000);
        TFIDF tfidf = new TFIDF();
        for (Multiset.Entry<String> entry : words.entrySet()) {
            String word = entry.getElement();
            int count = entry.getCount();
            Integer wordId = dictionary.get(word);
            Long freq = documentFrequency.get(wordId);
            double tfIdfValue = tfidf.calculate(count, freq.intValue(), wordCount, documentCount);
            vector.setQuick(wordId, tfIdfValue);
        }

        // classify: the label with the highest score wins
        Vector resultVector = classifier.classifyFull(vector);
        double bestScore = -Double.MAX_VALUE;
        int bestCategoryId = -1;
        for (Element element : resultVector.all()) {
            int categoryId = element.index();
            double score = element.get();
            if (score > bestScore) {
                bestScore = score;
                bestCategoryId = categoryId;
            }
            System.out.print("  " + labels.get(categoryId) + ": " + score);
        }
        System.out.println(" => " + labels.get(bestCategoryId));

        analyzer.close();
    }
}
```

That is roughly all there is to classifying with Mahout's Naive Bayes.
