Spark MLlib 入门学习笔记

来源:互联网 发布:php判断时间的大小 编辑:程序博客网 时间:2024/04/28 07:52

在官方的API文档可以查到用法。

def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel
  input: RDD of (label, array of features) pairs. Every vector should be a frequency vector or a count vector.
  lambda: The smoothing parameter.
  modelType: The type of NB model to fit, from the enumeration NaiveBayesModels; can be "multinomial" or "bernoulli"(缺省为 multinomial)。
调用方法很简单,用Iris数据集进行测试,代码如下。

package classify

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.evaluation.MulticlassMetrics

/** Naive Bayes classification of the Iris dataset using the Spark MLlib RDD API.
 *
 *  Expected arguments:
 *    args(0) — Spark master URL (e.g. "local[*]")
 *    args(1) — path to the Iris CSV file (4 numeric features + species name per line)
 */
object bayes {

  /** A line is valid when it has exactly 5 comma-separated fields
   *  (the 4 features plus the species label). Filters out blank/malformed rows. */
  def isValid(line: String): Boolean = line.split(",").length == 5

  /** Parses one CSV line into a LabeledPoint: the first four fields become the
   *  dense feature vector, the species name is mapped to a numeric class label.
   *
   *  @throws IllegalArgumentException if the species name is not one of the
   *          three known Iris classes (previously this surfaced as an opaque
   *          MatchError from a non-exhaustive match).
   */
  def parseLine(line: String): LabeledPoint = {
    val parts = line.split(",")
    val features: Vector =
      Vectors.dense(parts(0).toDouble, parts(1).toDouble, parts(2).toDouble, parts(3).toDouble)
    // Match *expression* instead of mutating a var; exhaustive via the default case.
    val label: Double = parts(4) match {
      case "Iris-setosa"     => 1.0
      case "Iris-versicolor" => 2.0
      case "Iris-virginica"  => 3.0
      case other             => throw new IllegalArgumentException(s"Unknown species label: $other")
    }
    LabeledPoint(label, features)
  }

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster(args(0)).setAppName("IrisNaiveBayes")
    val sc = new SparkContext(conf)
    try {
      val data = sc.textFile(args(1)).filter(isValid).map(parseLine)

      // 70/30 train/test split; the fixed seed keeps runs reproducible.
      val Array(trainData, testData) = data.randomSplit(Array(0.7, 0.3), seed = 11L)

      val model = NaiveBayes.train(trainData, lambda = 1.0)

      // (prediction, trueLabel) pairs over the held-out 30%.
      val predictionAndLabel = testData.map(p => (model.predict(p.features), p.label))
      predictionAndLabel.foreach(println)

      val metrics = new MulticlassMetrics(predictionAndLabel)
      // NOTE(review): the no-arg `precision` is deprecated in newer MLlib in
      // favour of `accuracy` (same value); kept here to preserve output.
      val precision = metrics.precision
      println("Precision = " + precision)
    } finally {
      // Release cluster resources even when the job fails (was previously leaked).
      sc.stop()
    }
  }
}


原创粉丝点击