Spark MLlib Beginner Study Notes

Source: Internet. Editor: 程序博客网. Date: 2024/04/28 04:09

The usage of DecisionTree.trainClassifier can be looked up in the official API documentation.

def trainClassifier(input: RDD[LabeledPoint], numClasses: Int, categoricalFeaturesInfo: Map[Int, Int], impurity: String, maxDepth: Int, maxBins: Int): DecisionTreeModel

Method to train a decision tree model for binary or multiclass classification.

input: Training dataset: RDD of org.apache.spark.mllib.regression.LabeledPoint. Labels should take values {0, 1, ..., numClasses-1}.
numClasses: number of classes for classification.
categoricalFeaturesInfo: Map storing arity of categorical features. E.g., an entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}.
impurity: Criterion used for information gain calculation. Supported values: "gini" (recommended) or "entropy".
maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (suggested value: 5)
maxBins: maximum number of bins used for splitting features (suggested value: 32)
returns: DecisionTreeModel that can be used for prediction
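
The impurity parameter chooses between two ways of scoring how mixed the labels at a node are. As a minimal sketch in plain Python (not MLlib's own code), the two measures can be computed from per-class label counts as follows. The function names gini and entropy are hypothetical, and the base-2 logarithm for entropy is an assumption made here.

```python
import math

def gini(counts):
    """Gini impurity of a node, given the number of samples in each class."""
    total = sum(counts)
    if total == 0:
        return 0.0
    return 1.0 - sum((c / total) ** 2 for c in counts)

def entropy(counts):
    """Entropy of a node in bits (base-2 logarithm assumed)."""
    total = sum(counts)
    if total == 0:
        return 0.0
    # Classes with zero samples contribute nothing to the sum.
    return sum(-(c / total) * math.log2(c / total) for c in counts if c > 0)

# A perfectly mixed two-class node is maximally impure under both measures.
print(gini([5, 5]))     # 0.5
print(entropy([5, 5]))  # 1.0
```

Both measures are 0 for a node that contains a single class, so a split that produces purer children lowers the weighted impurity and yields a positive information gain.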
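
The kyphosis dataset

Meaning of each column in the kyphosis dataset:
The data come from children who underwent corrective spinal surgery. The dataset has 4 columns and 81 rows (81 cases).
1. Kyphosis: a factor (absent or present) indicating whether kyphosis (a hunched spine) is still present after the operation
2. Age: the child's age, in months
3. Number: the number of vertebrae involved in the operation
4. Start: the number of the first (topmost) vertebra operated on, counting from the top of the spine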
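
To make the column layout concrete, here is a minimal sketch in plain Python that turns one row into a numeric label and a feature list. The helper name parse_row and the LABELS table are hypothetical; the encoding absent = 0, present = 1 is an assumption that matches what the test code in these notes uses.

```python
# Hypothetical helper, not part of Spark: parse one kyphosis row.
LABELS = {"absent": 0.0, "present": 1.0}  # assumed label encoding

def parse_row(row):
    """Split a row on whitespace and return (label, [age, number, start])."""
    kyphosis, age, number, start = row.split()
    return LABELS[kyphosis], [float(age), float(number), float(start)]

print(parse_row("present 128 4 5"))  # (1.0, [128.0, 4.0, 5.0])
```

Because str.split() with no argument splits on any run of whitespace, this sketch accepts rows separated by one or several spaces.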

absent 158 3 14
present 128 4 5
absent 2 5 1
absent 1 4 15
absent 1 2 16
absent 61 2 17
absent 37 3 16
absent 113 2 16
present 59 6 12
present 82 5 14
absent 148 3 16
absent 18 5 2
absent 1 4 12
absent 168 3 18
absent 1 3 16
absent 78 6 15
absent 175 5 13
absent 80 5 16
absent 27 4 9
absent 22 2 16
present 105 6 5
present 96 3 12
absent 131 2 3
present 15 7 2
absent 9 5 13
absent 8 3 6
absent 100 3 14
absent 4 3 16
absent 151 2 16
absent 31 3 16
absent 125 2 11
absent 130 5 13
absent 112 3 16
absent 140 5 11
absent 93 3 16
absent 1 3 9
present 52 5 6
absent 20 6 9
present 91 5 12
present 73 5 1
absent 35 3 13
absent 143 9 3
absent 61 4 1
absent 97 3 16
present 139 3 10
absent 136 4 15
absent 131 5 13
present 121 3 3
absent 177 2 14
absent 68 5 10
absent 9 2 17
present 139 10 6
absent 2 2 17
absent 140 4 15
absent 72 5 15
absent 2 3 13
present 120 5 8
absent 51 7 9
absent 102 3 13
present 130 4 1
present 114 7 8
absent 81 4 1
absent 118 3 16
absent 118 4 16
absent 17 4 10
absent 195 2 17
absent 159 4 13
absent 18 4 11
absent 15 5 16
absent 158 5 14
absent 127 4 12
absent 87 4 16
absent 206 4 10
absent 11 3 15
absent 178 4 15
present 157 3 13
absent 26 7 13
absent 120 2 13
present 42 7 6
absent 36 4 13
By default the fields in this dataset are separated by multiple spaces, so I wrote a Python script that collapses them into a single space.

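import re

# Collapse every run of two or more whitespace characters into one space.
# Iterating over the file handles every line, including the first one,
# and stops cleanly at end of file.
with open("e:/MyProject/SparkDiscover/data/kyphosis.data", "r") as f:
    for line in f:
        line = line.strip("\n")
        out = re.sub(r"\s{2,}", " ", line)
        print(out)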

Test code

package classify

import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint

object DsTree {

  // Parse one line such as "absent 158 3 14" into a LabeledPoint.
  def parseLine(line: String): LabeledPoint = {
    val parts = line.split(" ")
    val vd: Vector = Vectors.dense(parts(1).toDouble, parts(2).toDouble, parts(3).toDouble)
    // Encode the label: absent -> 0, present -> 1
    val target = parts(0) match {
      case "absent" => 0.0
      case "present" => 1.0
    }
    LabeledPoint(target, vd)
  }

  def main(args: Array[String]): Unit = {
    // args(0): Spark master URL; args(1): path to the preprocessed data file
    val conf = new SparkConf().setMaster(args(0)).setAppName("Iris")
    val sc = new SparkContext(conf)
    val data = sc.textFile(args(1)).map(parseLine(_))

    // 70% of the rows for training, 30% for testing
    val splits = data.randomSplit(Array(0.7, 0.3), seed = 11L)
    val trainData = splits(0)
    val testData = splits(1)

    val numClasses = 2 // number of classes
    val categoricalFeaturesInfo = Map[Int, Int]() // empty map: all features are treated as continuous
    val impurity = "entropy" // criterion for information gain; the alternative is "gini"
    val maxDepth = 5 // maximum depth of the tree
    val maxBins = 3 // maximum number of bins used when splitting features

    val model = DecisionTree.trainClassifier(trainData, numClasses, categoricalFeaturesInfo,
      impurity, maxDepth, maxBins)

    val predictionAndLabel = testData.map(p => (model.predict(p.features), p.label))
    predictionAndLabel.foreach(println)

    val metrics = new MulticlassMetrics(predictionAndLabel)
    // Overall precision, i.e. the fraction of correct predictions.
    // From Spark 2.0 on, this no-argument method is deprecated in favor of metrics.accuracy.
    val precision = metrics.precision
    println("Precision = " + precision)
  }
}
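
The single number printed at the end is the share of test rows whose predicted label equals the true label. As a minimal sketch in plain Python (the function name accuracy is hypothetical and this is not the MulticlassMetrics implementation), the same quantity can be computed from a list of (prediction, label) pairs. The sample pairs are made up for illustration and are not results from the kyphosis data.

```python
def accuracy(prediction_and_label):
    """Fraction of (prediction, label) pairs where the two values agree."""
    pairs = list(prediction_and_label)
    if not pairs:
        raise ValueError("no predictions to evaluate")
    correct = sum(1 for predicted, label in pairs if predicted == label)
    return correct / len(pairs)

# Illustrative pairs only: three of the four predictions match their labels.
pairs = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 0.0)]
print(accuracy(pairs))  # 0.75
```

With only about 30% of a small dataset held out for testing, this number can change noticeably from one random split to another, so it is best read as a rough indication.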



