Spark-MLlib实例——决策树

来源:互联网 发布:淘宝卖家如何发布微淘 编辑:程序博客网 时间:2024/05/29 15:07

Spark-MLlib实例——决策树

通俗来说,决策树分类的思想类似于找对象。现想象一个女孩的母亲要给这个女孩介绍男朋友,于是有了下面的对话:

女儿:多大年纪了?母亲:26。女儿:长的帅不帅?母亲:挺帅的。女儿:收入高不?母亲:不算很高,中等情况。女儿:是公务员不?母亲:是,在税务局上班呢。女儿:那好,我去见见。




以上是决策的经典例子,用spark-mllib怎么实现训练与预测呢


1、首先准备测试数据集

训练数据集 Tree1

字段说明:

是否见面, 年龄  是否帅  收入(1 高 2 中等 0 少)  是否公务员

0,32 1 1 00,25 1 2 01,29 1 2 11,24 1 1 00,31 1 1 01,35 1 2 10,30 0 1 00,31 1 1 01,30 1 2 11,21 1 1 00,21 1 2 01,21 1 2 10,29 0 2 10,29 1 0 10,29 0 2 11,30 1 1 0


测试数据集 Tree2

0,32 1 2 01,27 1 1 11,29 1 1 01,25 1 2 10,23 0 2 1

2、Spark-MLlib决策树应用代码

import org.apache.log4j.{Level, Logger}import org.apache.spark.mllib.feature.HashingTFimport org.apache.spark.mllib.linalg.Vectorsimport org.apache.spark.mllib.regression.LabeledPointimport org.apache.spark.mllib.tree.DecisionTreeimport org.apache.spark.mllib.util.MLUtilsimport org.apache.spark.{SparkConf, SparkContext}/**  * 决策树分类  */object TreeDemo {  def main(args: Array[String]) {    val conf = new SparkConf().setAppName("DecisionTree").setMaster("local")    val sc = new SparkContext(conf)    Logger.getRootLogger.setLevel(Level.WARN)    //训练数据    val data1 = sc.textFile("data/Tree1.txt")    //测试数据    val data2 = sc.textFile("data/Tree2.txt")    //转换成向量    val tree1 = data1.map { line =>      val parts = line.split(',')      LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))    }    val tree2 = data2.map { line =>      val parts = line.split(',')      LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))    }    //赋值    val (trainingData, testData) = (tree1, tree2)    //分类    val numClasses = 2    val categoricalFeaturesInfo = Map[Int, Int]()    val impurity = "gini"    //最大深度    val maxDepth = 5    //最大分支    val maxBins = 32    //模型训练    val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,      impurity, maxDepth, maxBins)    //模型预测    val labelAndPreds = testData.map { point =>      val prediction = model.predict(point.features)      (point.label, prediction)    }    //测试值与真实值对比    val print_predict = labelAndPreds.take(15)    println("label" + "\t" + "prediction")    for (i <- 0 to print_predict.length - 1) {      println(print_predict(i)._1 + "\t" + print_predict(i)._2)    }    //树的错误率    val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()    println("Test Error = " + testErr)    //打印树的判断值    println("Learned classification tree model:\n" + model.toDebugString)  }}


3、测试结果:

labelprediction0.00.01.01.01.01.01.01.00.00.0Test Error = 0.0Learned classification tree model:
可见真实值与预测值一致,Error为0


打印决策树的分支值,这里最大深度为 5 ,对应的树结构:

Learned classification tree model:DecisionTreeModel classifier of depth 4 with 11 nodes  If (feature 1 <= 0.0)   Predict: 0.0  Else (feature 1 > 0.0)   If (feature 3 <= 0.0)    If (feature 0 <= 30.0)     If (feature 2 <= 1.0)      Predict: 1.0     Else (feature 2 > 1.0)      Predict: 0.0    Else (feature 0 > 30.0)     Predict: 0.0   Else (feature 3 > 0.0)    If (feature 2 <= 0.0)     Predict: 0.0    Else (feature 2 > 0.0)     Predict: 1.0
可见预测出的分界值与真实一致,准确率与决策树算法,参数设置及训练样本的选择覆盖有关!


1 0