SparkML实战之五:SVM

来源:互联网 发布:社交网络上的跟风行为 编辑:程序博客网 时间:2024/05/16 16:03
package MLlib

import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.optimization.L1Updater
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD

/**
 * Created by root on 16-1-12.
 *
 * Example driver that trains two linear SVMs with Spark MLlib's `SVMWithSGD`
 * on a LIBSVM-formatted data set and prints the area under the ROC curve of
 * each model on a held-out test split:
 *
 *  1. a model built with `SVMWithSGD.train` and 100 iterations;
 *  2. a model built through the optimizer setters with an L1 updater,
 *     regParam 0.1 and 200 iterations.
 */
object SVM {

  /** Input used when no path is passed on the command line. */
  private val DefaultDataPath: String =
    "/usr/local/spark/spark-1.6.0-bin-hadoop2.4" +
      "/data/mllib/sample_libsvm_data.txt"

  /**
   * Program entry point.
   *
   * @param args optional; `args(0)` overrides the input path of the
   *             LIBSVM-formatted data file (defaults to the sample file
   *             shipped with the Spark distribution).
   */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("SVM").setMaster("local[4]")
    val sc = new SparkContext(conf)

    // Always release the context, even when loading or training fails.
    try {
      // Load training data in LIBSVM format.
      // Sample rows noted by the original author:
      //    1 1 0 2.52078447201548 0 0 0 2.004684436494304 2.000347299268466 0 2.228387042742021 2.228387042742023 0 0 0 0 0 0 2
      //    0 2.857738033247042 0 0 2.619965104088255 0 2.004684436494304 2.000347299268466 0 2.228387042742021 2.228387042742023 0     0 0 0 0 0
      // NOTE(review): these rows look like dense values rather than the
      // "label index:value" layout of the LIBSVM format - confirm which
      // file they were taken from.
      val dataPath = args.headOption.getOrElse(DefaultDataPath)
      val data = MLUtils.loadLibSVMFile(sc, dataPath)

      // Split data into training (60%) and test (40%).
      val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
      val training = splits(0).cache()
      val test = splits(1)

      // Run training algorithm to build the model.
      val numIterations = 100
      val model = SVMWithSGD.train(training, numIterations)

      // Clear the default threshold so predict() yields raw scores
      // instead of 0/1 labels.
      model.clearThreshold()

      val auROC = areaUnderRoc(model, test)
      println(s"Area under ROC = $auROC")

      // Save and load model.
      // After saving, "data" and "metadata" directories are created under
      // myModelPath to hold the model object.
      //    model.save(sc, "myModelPath")
      //    val sameModel = SVMModel.load(sc, "myModelPath")

      //------------------------------------------------------------------------
      // SVMWithSGD.train() applies L2 regularization by default.
      // NOTE(review): the original comment states the default regParam is 1.0;
      // confirm against the Spark version in use.
      // To configure the algorithm, create an SVMWithSGD instance and call the
      // setters on its optimizer; the spark.mllib algorithms generally support
      // this style. Below: L1 regularization, regParam 0.1, 200 iterations.
      System.out.println("使用正则参数为0.1的L1正则来训练算法递归200次----------------" +
        "-------------------------------------------------")
      val svmAlg = new SVMWithSGD()
      svmAlg.optimizer
        .setNumIterations(200)
        .setRegParam(0.1)
        .setUpdater(new L1Updater)
      val modelL1 = svmAlg.run(training)

      // The L1 model needs its own threshold cleared: without this, predict()
      // returns 0/1 labels and the ROC curve is built from class labels.
      modelL1.clearThreshold()

      // Evaluate the L1 model itself (the original scored `model` here by
      // mistake, so it printed the first model's AUC twice).
      val l1AuROC = areaUnderRoc(modelL1, test)
      println(s"L1正则的Area under ROC = $l1AuROC")
    } finally {
      sc.stop()
    }
  }

  /**
   * Scores every test point with `model` and returns the area under the
   * resulting ROC curve.
   *
   * The caller is expected to have cleared the model's threshold beforehand
   * so that the scores are raw values rather than predicted labels.
   *
   * @param model trained SVM model used for scoring
   * @param test  labeled points held out from training
   * @return area under the ROC curve computed from (score, label) pairs
   */
  private def areaUnderRoc(model: SVMModel, test: RDD[LabeledPoint]): Double = {
    val scoreAndLabels = test.map { point =>
      (model.predict(point.features), point.label)
    }
    new BinaryClassificationMetrics(scoreAndLabels).areaUnderROC()
  }
}
0 0