基于NaiveBayes的文本分类之Spark实现

来源:互联网 发布:redflag linux 8 编辑:程序博客网 时间:2024/05/16 10:36

  在尝试了用 Python 下的 sklearn 进行文本分类(http://blog.csdn.net/a_step_further/article/details/50189727)之后,我们再来看看如何用 Spark 实现文本分类,采用的算法同样是朴素贝叶斯。

    此前,我们已经在 Hadoop 集群环境下用 MapReduce 实现了中文分词(http://blog.csdn.net/a_step_further/article/details/50333961),因此文本分类也放在集群环境中进行;相对于 Python 的单机版本实现,这样无疑更方便一些。

上代码(输入文件每行一篇文档,格式为"标签<Tab>以空格分隔的分词结果";标签需能解析为数值,程序以标签 1 作为正类来计算精确率和召回率):

import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.feature.{HashingTF, IDF, IDFModel}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}

/**
 * Trains a Naive Bayes text classifier with Spark MLlib (RDD API) on TF-IDF features,
 * reports precision and recall for the positive class (label 1.0) on a held-out 40% split,
 * and saves the trained model.
 *
 * Expected input: one document per line, `label<TAB>space-separated tokens`, where `label`
 * must parse as a Double. Lines that do not split into exactly two tab-separated fields
 * are dropped.
 *
 * Usage: textClassify <inputLoc> <modelSaveLoc>
 */
object textClassify {

  def main(args: Array[String]): Unit = {
    // Validate arguments before starting Spark, so a bad invocation never leaves a context running.
    if (args.length != 2) {
      println("Usage: textClassify <inputLoc> <modelSaveLoc>")
      System.exit(-1)
    }
    val inputLoc = args(0)
    val modelSaveLoc = args(1)

    // NOTE(review): spark.akka.frameSize only exists in Spark 1.x; Spark 2+ uses
    // spark.rpc.message.maxSize instead — confirm against the target cluster version.
    val conf = new SparkConf().setAppName("text_classify").set("spark.akka.frameSize", "20")
    val sc = new SparkContext(conf)
    try {
      // Cache before splitting: randomSplit is lazy and samples each split in a separate pass,
      // so the parent must stay cached until both splits have been materialised.
      val labeled = toLabeledPoints(sc, inputLoc).cache()
      val splits = labeled.randomSplit(Array(0.6, 0.4), seed = 10L)
      val trainData = splits(0).cache()
      val testData = splits(1)

      // lambda is the additive-smoothing parameter.
      val model = NaiveBayes.train(trainData, lambda = 0.1)
      trainData.unpersist()

      // (prediction, label) pairs. Cached because they are scanned three times below.
      val predictionAndLabel: RDD[(Double, Double)] =
        testData.map(point => (model.predict(point.features), point.label)).cache()
      val actualPositive = predictionAndLabel.filter(_._2 == 1.0).count()
      val predictedPositive = predictionAndLabel.filter(_._1 == 1.0).count()
      val truePositive = predictionAndLabel.filter { case (pred, label) => pred == label && label == 1.0 }.count()

      println("results------------------------------------------------")
      // Precision = TP / predicted positives; recall = TP / actual positives.
      println(s"精确率(precision): ${formatRatio(truePositive, predictedPositive)}")
      println(s"召回率(recall): ${formatRatio(truePositive, actualPositive)}")
      println("------------------------------------------------")

      predictionAndLabel.unpersist()
      labeled.unpersist()

      model.save(sc, modelSaveLoc)
    } finally {
      // Always release cluster resources, even if training or saving fails.
      sc.stop()
    }
  }

  /**
   * Reads `label<TAB>tokens` lines from `inputLoc` and converts each document into a
   * TF-IDF weighted [[LabeledPoint]].
   *
   * @param sc       active Spark context
   * @param inputLoc path readable by `SparkContext.textFile`
   * @return one LabeledPoint per well-formed input line; `label.toDouble` throws on a
   *         non-numeric label
   */
  private def toLabeledPoints(sc: SparkContext, inputLoc: String): RDD[LabeledPoint] = {
    // Keep only lines with exactly one tab separating the label from the token string.
    val docs = sc.textFile(inputLoc).map(_.split("\t")).filter(_.length == 2).cache()
    val tokens: RDD[Seq[String]] = docs.map(_(1).split(" ").toSeq)

    val hashingTF = new HashingTF()
    // IDF makes two passes over tf (fit, then transform), so cache it.
    val tf: RDD[Vector] = hashingTF.transform(tokens).cache()
    // minDocFreq = 2: terms appearing in fewer than 2 documents get an IDF weight of 0.
    val idfModel: IDFModel = new IDF(minDocFreq = 2).fit(tf)
    val tfIdf: RDD[Vector] = idfModel.transform(tf)

    // zip pairs elements by position. Both sides are produced by element-wise maps over `docs`,
    // so they share the same partitions and per-partition element counts.
    docs.map(_(0)).zip(tfIdf).map { case (label, features) => LabeledPoint(label.toDouble, features) }
  }

  /** Formats `num / den` as a Double, or "N/A" when `den` is 0 (e.g. no positive predictions). */
  private def formatRatio(num: Long, den: Long): String =
    if (den == 0L) "N/A" else (num.toDouble / den).toString
}



0 0