利用spark的mllib构建GBDT模型

来源:互联网 发布:多益网络登录器 编辑:程序博客网 时间:2024/06/04 18:56

GBDT模型

GBDT模型的介绍,我主要是参考博客:http://blog.csdn.net/w28971023/article/details/8240756
在这里,我主要归纳以下几点要素:
1.GBDT中的树都是回归树;
2.回归树节点分割点衡量最好的标准是叶子个数的上限;
3.GBDT的核心在于,每个棵树学的是之前所有树结论和的残差,这个残差就是一个加预测值后能得到真实值的累加量;
4.GB为Gradient Boosting, Boosting的最大好处在于,每一步的残差计算其实变相地增大了分错instance的权重,而已经分对的instance则趋向于0;
5.GBDT采用一个Shrinkage策略,本质上,Shrinkage为每棵树设置了一个weight,累加时要乘以这个weight,但和Gradient并没有关系。

利用spark构建GBDT模型

训练GBDT模型

public void trainModel(){        //初始化spark        SparkConf conf = new SparkConf().setAppName("GBDT").setMaster("local");        conf.set("spark.testing.memory","2147480000");        SparkContext sc = new SparkContext(conf);        //加载训练文件, 使用MLUtils包        JavaRDD<LabeledPoint> lpdata = MLUtils.loadLibSVMFile(sc, this.trainsetFile).toJavaRDD();        //训练模型, 默认情况下使用均值方差作为阈值标准        int numIteration = 10;  //boosting提升迭代的次数        int maxDepth = 3;       //回归树的最大深度        BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Regression");        boostingStrategy.setNumIterations(numIteration);        boostingStrategy.getTreeStrategy().setMaxDepth(maxDepth);        //记录所有特征的连续结果        Map<Integer, Integer> categoricalFeaturesInfoMap = new HashMap<Integer, Integer>();        boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfoMap);        //gdbt模型        final GradientBoostedTreesModel model = GradientBoostedTrees.train(lpdata, boostingStrategy);        model.save(sc, modelpath);        sc.stop();    }

预测数据

public void predict() {        //初始化spark        SparkConf conf = new SparkConf().setAppName("GBDT").setMaster("local");        conf.set("spark.testing.memory","2147480000");        SparkContext sc = new SparkContext(conf);        //加载gbdt模型        final GradientBoostedTreesModel model = GradientBoostedTreesModel.load(sc, this.modelpath);        //加载测试文件        JavaRDD<LabeledPoint> testData = MLUtils.loadLibSVMFile(sc, this.predictFile).toJavaRDD();        testData.cache();        //预测数据        JavaRDD<Tuple2<Double, Double>>  predictionAndLabel = testData.map(new Prediction(model)) ;        //计算所有数据的平均值方差         Double testMSE = predictionAndLabel.map(new CountSquareError()).reduce(new ReduceSquareError()) / testData.count();         System.out.println("testData's MSE is : " + testMSE);         sc.stop();    }    static class Prediction implements Function<LabeledPoint, Tuple2<Double , Double>> {        GradientBoostedTreesModel model;        public Prediction(GradientBoostedTreesModel model){            this.model = model;        }        public Tuple2<Double, Double> call(LabeledPoint p) throws Exception {            Double score = model.predict(p.features());            return new Tuple2<Double , Double>(score, p.label());        }    }    static class CountSquareError implements Function<Tuple2<Double, Double>, Double> {        public Double call (Tuple2<Double, Double> pl) {            double diff = pl._1() - pl._2();            return diff * diff;        }    }    static  class ReduceSquareError implements Function2<Double, Double, Double> {        public Double call(Double a , Double b){            return a + b ;        }    }

关于具体的代码放至我的github上:https://github.com/Quincy1994/MachineLearning

0 0
原创粉丝点击