Spark ML Lib中的Tf-Idf生成的向量不能直接用于其他算法的问题

来源:互联网 发布:usb001端口 编辑:程序博客网 时间:2024/05/16 15:03

Spark ML Lib中提供了文档转为Tf-Idf加权的向量的功能,但是Tf是用的Hash方式将token进行映射,并且向量直接存储出来的格式并不能直接用于SVM、Naive Bayes等算法,因此需要做一些其它工作:

1.调整向量格式

生成TF部分代码不做改变

JavaRDD<String> text = sc.textFile(inputPath);JavaPairRDD<String,List<String>> document= text.mapToPair(new PairFunction<String,String,List<String>>(){public Tuple2<String,List<String>> call(String s){String str[] = s.split("\t");return  new Tuple2<String,List<String>>(str[0],Arrays.asList(s.split(" ")) );}});HashingTF tf = new HashingTF();JavaRDD<List<String>> features = document.values();    termFreqs = tf.transform(features);

下面是IDF部分,注意匿名函数里面对向量的形式做了一些改变


IDF idf = new IDF();   JavaRDD<Vector> tfIdfs = idf.fit(termFreqs).transform(termFreqs);  JavaRDD<String> tfidfVector = tfIdfs.map(new Function<Vector,String>(){    public String call(Vector v){    final StringBuilder builder = new StringBuilder();    AbstractFunction2<Object, Object, BoxedUnit> f = new AbstractFunction2<Object, Object, BoxedUnit>() {            public BoxedUnit apply(Object t1, Object t2) {            int dim =  (Integer)t1;            if( dim >1){            builder.append( dim+":"+(Double)t2+" " );            }                return BoxedUnit.UNIT;            }        };        v.foreachActive(f);        builder.deleteCharAt( builder.length()-1 );        return builder.toString();    }    });
改变之后的向量就按照稀疏向量的格式保存下来,之后添加上分类标记就可以直接用来跑ML Mlib里面的算法了(如SVM、NaiveBayes)

22909:1.0986122886681098 119158:1.0986122886681098 639018:1.0986122886681098 735243:1.098612288668109820154:1.0986122886681098 24456:0.4054651081081644 37117:0.6931471805599453 201116:0.6931471805599453 875579:1.0986122886681098113009:1.0986122886681098 127612:1.0986122886681098 686294:1.0986122886681098 736858:1.0986122886681098 832444:1.098612288668109820250:0.6931471805599453 21644:1.0986122886681098 24456:0.4054651081081644 25105:1.0986122886681098 37117:0.6931471805599453 119301:1.0986122886681098 201116:0.6931471805599453 730991:1.098612288668109820250:0.6931471805599453 24456:0.4054651081081644 26469:1.0986122886681098 30340:2.1972245773362196 35828:1.0986122886681098 38271:1.0986122886681098 689163:1.0986122886681098 704478:1.0986122886681098 750005:1.0986122886681098 779641:1.0986122886681098 796407:1.0986122886681098 798459:1.0986122886681098
注意代码中有这么三行代码:

if( dim >1){builder.append( dim+":"+(Double)t2+" " );}
这是为了将维度号为0和1的维度给过滤掉,如果不过滤,运行SVM或NB会出现数组越界异常。



0 0
原创粉丝点击