使用Spark MLlib随机森林RandomForest+pipeline进行预测

来源:互联网 发布:淘宝好看的跑步鞋店铺 编辑:程序博客网 时间:2024/04/29 02:25

这个程序中,我们使用pipeline来完成整个预测流程,加入了10-fold cross validation。

import org.apache.spark.{SparkConf, SparkContext}import org.apache.spark.mllib.linalg.Vectorsimport org.apache.spark.mllib.regression.LabeledPointimport org.apache.spark.sql.SQLContextimport org.apache.spark.ml.Pipelineimport org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}import org.apache.spark.ml.classification.RandomForestClassifierimport org.apache.spark.ml.evaluation.BinaryClassificationEvaluatorimport org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorAssembler}/**  * Created by simon on 2017/5/8.  */object genderClassificationWithRandomForest {  def main(args: Array[String]): Unit = {    val conf = new SparkConf()    conf.setAppName("genderClassification").setMaster("local[2]")    val sc = new SparkContext(conf)    val sqlContext = new SQLContext(sc)    import sqlContext.implicits._    val trainData = sc.textFile("file:\\E:\\test.csv")// 第一步,预处理数据,构建为DataFrame格式    val data = trainData.map { line =>      val parts= line.split("\\|")      val label = toInt(parts(1)) //第二列是标签      val features = Vectors.dense(parts.slice(6,parts.length-1).map(_.toDouble)) //第7到最后一列是属性,需要转换为Doube类型      LabeledPoint(label, features) //构建LabelPoint格式    }.toDF()// 第二步,将数据随机分为训练集和测试集    val Array(training, testing) = data.randomSplit(Array(0.7, 0.3),131L)// 第三步,准备一些基本参数和标签列indexer// 设置K折交叉验证的K的数量,以及随机森林树的数量,树的数量增加会大幅度增加训练时间    val nFolds: Int = 10    val NumTrees: Int = 500 //800,2000    val indexer = new StringIndexer()      .setInputCol("label")      .setOutputCol("label_idx")// 第四步,创建随机森林分类器    val rf = new RandomForestClassifier()      .setNumTrees(NumTrees)      .setFeaturesCol("features")      .setLabelCol("label_idx")      .setFeatureSubsetStrategy("auto")      .setImpurity("gini")      .setMaxDepth(10) //2,5,7      .setMaxBins(100)// 第五步,创建pipeline    val pipeline = new Pipeline().setStages(Array(indexer,rf))// 第六步,创建参数    val paramGrid = new ParamGridBuilder().build()// 第七步,设置预测效果测量器    val evaluator = new BinaryClassificationEvaluator()      .setLabelCol("label")      .setRawPredictionCol("rawPrediction")      .setMetricName("areaUnderROC")// 第八步,创建交叉验证对象,设置好pipeline、测量器、参数、K的数量    val cv = new CrossValidator()      .setEstimator(pipeline)      .setEvaluator(evaluator)      .setEstimatorParamMaps(paramGrid)      .setNumFolds(nFolds)// 第九步,使用训练集训练模型    val model = cv.fit(training)// 第十步,拿训练好的模型预测测试集    val predictions = model.transform(testing)    predictions.show()// 第十一步,测量预测效果    val metrics  = evaluator.evaluate(predictions)    println(metrics)  }  // 将标签转换为01  def toInt(s: String): Int = {    if (s == "m") 1 else  0  }}
阅读全文
0 0
原创粉丝点击