SparkMLlib Java Decision Tree Classification Algorithm (DecisionTree)

Source: Internet. Editor: 程序博客网. Time: 2024/05/16 17:56

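Basic understanding of decision trees:

A decision tree uses a tree structure to make a decision level by level based on the features, and reaches a result at some level. I came across a very good diagram illustrating this on another blog.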
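To make the level-by-level idea concrete, here is a minimal sketch in plain Java (no Spark). The class `TinyDecisionTree`, the tree shape, and the thresholds are all made up for illustration; a trained Spark model would choose its own splits.

```java
public class TinyDecisionTree {

    /** A node is either a leaf holding a label or a split on a single feature. */
    static final class Node {
        final boolean leaf;
        final double label;      // used only by leaves
        final int featureIndex;  // used only by splits
        final double threshold;  // go left when features[featureIndex] <= threshold
        final Node left;
        final Node right;

        private Node(boolean leaf, double label, int featureIndex,
                     double threshold, Node left, Node right) {
            this.leaf = leaf;
            this.label = label;
            this.featureIndex = featureIndex;
            this.threshold = threshold;
            this.left = left;
            this.right = right;
        }

        static Node leaf(double label) {
            return new Node(true, label, -1, 0.0, null, null);
        }

        static Node split(int featureIndex, double threshold, Node left, Node right) {
            return new Node(false, 0.0, featureIndex, threshold, left, right);
        }
    }

    /** Walks down the tree one level at a time until a leaf is reached. */
    static double predict(Node root, double[] features) {
        Node node = root;
        while (!node.leaf) {
            node = features[node.featureIndex] <= node.threshold ? node.left : node.right;
        }
        return node.label;
    }

    /** Hypothetical tree with made-up thresholds, for illustration only. */
    static Node exampleTree() {
        return Node.split(0, 5.0,
                Node.leaf(0.0),
                Node.split(1, 2.5, Node.leaf(1.0), Node.leaf(2.0)));
    }

    public static void main(String[] args) {
        Node tree = exampleTree();
        System.out.println(predict(tree, new double[] {3.0, 9.0, 1.0})); // prints 0.0
        System.out.println(predict(tree, new double[] {7.0, 1.0, 1.0})); // prints 1.0
        System.out.println(predict(tree, new double[] {7.0, 4.0, 1.0})); // prints 2.0
    }
}
```

Each `split` node looks at one feature and each `leaf` node is where a result is produced, which is the "result at some level" described above.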

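Data used by the SparkMLlib Java program:

   Training data: C:\hello\trainData.txt

In this file, the value before the comma is the target value (the class label), and the values after the comma are the feature vector (space-separated).

   Test data: C:\hello\testData.txt

This file contains only feature vectors, space-separated.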
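The two line formats can be parsed with a few lines of plain Java. The sketch below is independent of Spark; the class `LineParser`, its method names, and the sample values are assumptions made for illustration and are not taken from the actual data files.

```java
import java.util.Arrays;

public class LineParser {

    /** Parses a space-separated feature string such as "5.1 3.5 1.4". */
    static double[] parseFeatures(String text) {
        String[] tokens = text.trim().split("\\s+");
        double[] features = new double[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            features[i] = Double.parseDouble(tokens[i]);
        }
        return features;
    }

    /** Returns the label of a training line of the form "label,f1 f2 f3". */
    static double parseLabel(String line) {
        return Double.parseDouble(line.substring(0, line.indexOf(',')).trim());
    }

    /** Returns the features of a training line of the form "label,f1 f2 f3". */
    static double[] parseTrainingFeatures(String line) {
        return parseFeatures(line.substring(line.indexOf(',') + 1));
    }

    public static void main(String[] args) {
        String trainLine = "1,5.1 3.5 1.4"; // made-up training line
        String testLine = "4.9 3.0 1.4";    // made-up test line

        System.out.println(parseLabel(trainLine));
        System.out.println(Arrays.toString(parseTrainingFeatures(trainLine)));
        System.out.println(Arrays.toString(parseFeatures(testLine)));
    }
}
```

Unlike code that reads exactly three values per line, `parseFeatures` accepts any number of features, so it would not need to change if the data gained or lost a column.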

SparkMLlib DecisionTree Java program:

package MLlibTest;

import java.util.HashMap;
import java.util.Map;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;

import scala.Tuple2;

public class DecisionTreeTest {

    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("DecisionTreeTest").setMaster("local[*]");
        JavaSparkContext jsc = new JavaSparkContext(conf);

        // Read the training data; each line has the form "label,f1 f2 f3".
        JavaRDD<String> lines = jsc.textFile("C://hello//trainData.txt");
        JavaRDD<LabeledPoint> transdata = lines.map(new Function<String, LabeledPoint>() {
            private static final long serialVersionUID = 1L;

            @Override
            public LabeledPoint call(String str) throws Exception {
                String[] t1 = str.split(",");      // t1[0] = label, t1[1] = features
                String[] t2 = t1[1].split(" ");    // the three feature values
                LabeledPoint lab = new LabeledPoint(
                        Double.parseDouble(t1[0]),
                        Vectors.dense(
                                Double.parseDouble(t2[0]),
                                Double.parseDouble(t2[1]),
                                Double.parseDouble(t2[2])));
                return lab;
            }
        });

        // Set the decision tree parameters and train the model.
        Integer numClasses = 3;
        // An empty map means all features are treated as continuous.
        Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
        String impurity = "gini";
        Integer maxDepth = 5;
        Integer maxBins = 32;
        final DecisionTreeModel tree_model = DecisionTree.trainClassifier(
                transdata, numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins);
        System.out.println("Decision tree model:");
        System.out.println(tree_model.toDebugString());

        // Save the model.
        tree_model.save(jsc.sc(), "C://hello//DecisionTreeModel");

        // Run the unlabeled test data through the model.
        JavaRDD<String> testLines = jsc.textFile("C://hello//testData.txt");
        JavaPairRDD<String, String> res = testLines.mapToPair(
                new PairFunction<String, String, String>() {
            private static final long serialVersionUID = 1L;

            @Override
            public Tuple2<String, String> call(String line) throws Exception {
                String[] t2 = line.split(" ");
                Vector v = Vectors.dense(
                        Double.parseDouble(t2[0]),
                        Double.parseDouble(t2[1]),
                        Double.parseDouble(t2[2]));
                double res = tree_model.predict(v);
                return new Tuple2<String, String>(line, Double.toString(res));
            }
        }).cache();

        // Print the results as "features : predicted label".
        res.foreach(new VoidFunction<Tuple2<String, String>>() {
            private static final long serialVersionUID = 1L;

            @Override
            public void call(Tuple2<String, String> a) throws Exception {
                System.out.println(a._1 + " : " + a._2);
            }
        });

        // Save the results to the local file system.
        res.saveAsTextFile("C://hello/res");

        // Release Spark resources.
        jsc.stop();
    }
}
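The program passes "gini" as the impurity parameter. As a rough illustration of what that measure means, the sketch below computes the Gini impurity of a single node from its per-class sample counts, using the standard definition 1 - sum(p_i^2). The class `GiniImpurity` and the counts are hypothetical, and this is not Spark's internal implementation.

```java
public class GiniImpurity {

    /**
     * Gini impurity of a node given how many samples of each class it holds:
     * 1 minus the sum of squared class proportions. Returns 0.0 for an empty node.
     */
    static double gini(long[] classCounts) {
        long total = 0;
        for (long count : classCounts) {
            total += count;
        }
        if (total == 0) {
            return 0.0;
        }
        double sumOfSquares = 0.0;
        for (long count : classCounts) {
            double proportion = (double) count / total;
            sumOfSquares += proportion * proportion;
        }
        return 1.0 - sumOfSquares;
    }

    public static void main(String[] args) {
        // Made-up class counts for three classes, matching numClasses = 3.
        System.out.println(gini(new long[] {10, 0, 0})); // pure node, prints 0.0
        System.out.println(gini(new long[] {4, 3, 3}));  // mixed node
    }
}
```

A pure node has impurity 0, and the value grows as the classes become more evenly mixed, so a tree builder using this measure favors splits whose child nodes have lower impurity.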

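Closing remarks:

     This was put together in a hurry. If you spot any mistakes, please point them out so that we can learn from each other.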


