Spark MLLib 梯度提升树

来源:互联网 发布:维氏刀具在淘宝上买 编辑:程序博客网 时间:2024/05/23 22:07

梯度提升树是决策树的群集。GBTs为了使损失函数的值最小化,迭代式的训练决策树。和决策树一样,GBTs可以处理离散特征,也可以扩展设置为多级分类,而不需要进行特征值缩放,所以有能力处理非线性以及特征交互。

spark.mllib支持二分类GBTs以及回归GBTs,可以使用连续或者离散型特征。Spark.mllib在已经存在的决策树的实现的基础上实现的GBTs。如果想要对决策树了解更多的话,请参考决策树的文档。

注意:(Spark1.6.0的)GBTs目前尚不支持多级分类。对于多级分类问题,请使用决策树或者随机森林。

基础算法

梯度提升数迭代式的训练一系列决策树。在每轮迭代中,算法使用当前群集对于每个训练实例的预测结果与真实的结果进行比对。数据集被重新标注,使得加强对于错误预测了的实例进行训练。这样,在下一轮迭代中,决策树将会帮助修复之前的错误。

重新标注实例时使用的损失函数(下面有介绍)。在每一轮迭代中,GBTs将会降低训练数据的损失函数的值。

损失

下面的表格列举了spark.mllib当前支持的损失函数。注意每一个损失函数只支持分类或者回归中的一种,不是全部。

注意:N=实例个数,yi=实例的标签,xi=i实例的特征。F(xi)=用于预测实例i的模型。


使用说明

下面是一些使用GBTs的不同参数的说明。我们省略了一些决策树参数因为它们在决策树帮助文档里包含了。

损失:上文说明了一些损失函数以及它们的适用场景(分类、回归)。不同的损失函数会给出截然不同的结果,取决于数据集。

迭代次数:这个设置群集中树的数目。每一轮迭代生成一棵树。提升这个参数的值可以使得模型的准确度更高。然而,测试时间也许会受到影响,如果它太大的话。

学习率:这个参数不需要被调整。如果算法的效果看起来不稳定的话,降低这个值也许会提高稳定性。

算法:任务算法(分类、回归)可以在树参数里被设置

训练时验证

梯度提升也许会过拟合,如果训练了过多的树的话。为了防止过拟合,最好的办法就是训练时验证。runWithValidation方法可以提供帮助。它将一对RDD作为参数,第一个是测试数据,而第二个是验证数据。

训练停止当验证错误的改善没有高于指定阈值(BoostingStrategy的validationTol参数)。在实践中,验证错误先降低后升高。在实践中,验证错误也许是单调的,那么用户被推荐设置更大的负公差,同时使用evaluateEachIteration方法来检测验证曲线来调整迭代次数。


使用示例:

1.分类

import java.util.HashMap;import java.util.Map;import scala.Tuple2;import org.apache.spark.SparkConf;import org.apache.spark.api.java.JavaPairRDD;import org.apache.spark.api.java.JavaRDD;import org.apache.spark.api.java.JavaSparkContext;import org.apache.spark.api.java.function.Function;import org.apache.spark.api.java.function.PairFunction;import org.apache.spark.mllib.regression.LabeledPoint;import org.apache.spark.mllib.tree.GradientBoostedTrees;import org.apache.spark.mllib.tree.configuration.BoostingStrategy;import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel;import org.apache.spark.mllib.util.MLUtils;SparkConf sparkConf = new SparkConf()  .setAppName("JavaGradientBoostedTreesClassificationExample");JavaSparkContext jsc = new JavaSparkContext(sparkConf);// Load and parse the data file.String datapath = "data/mllib/sample_libsvm_data.txt";JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();// Split the data into training and test sets (30% held out for testing)JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});JavaRDD<LabeledPoint> trainingData = splits[0];JavaRDD<LabeledPoint> testData = splits[1];// Train a GradientBoostedTrees model.// The defaultParams for Classification use LogLoss by default.BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Classification");boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice.boostingStrategy.getTreeStrategy().setNumClasses(2);boostingStrategy.getTreeStrategy().setMaxDepth(5);// Empty categoricalFeaturesInfo indicates all features are continuous.Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo);final GradientBoostedTreesModel model =  GradientBoostedTrees.train(trainingData, boostingStrategy);// Evaluate model on test instances and compute test errorJavaPairRDD<Double, Double> predictionAndLabel =  testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {    @Override    public Tuple2<Double, Double> call(LabeledPoint p) {      return new Tuple2<Double, Double>(model.predict(p.features()), p.label());    }  });Double testErr =  1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {    @Override    public Boolean call(Tuple2<Double, Double> pl) {      return !pl._1().equals(pl._2());    }  }).count() / testData.count();System.out.println("Test Error: " + testErr);System.out.println("Learned classification GBT model:\n" + model.toDebugString());// Save and load modelmodel.save(jsc.sc(), "target/tmp/myGradientBoostingClassificationModel");GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(jsc.sc(),  "target/tmp/myGradientBoostingClassificationModel");

2.回归

import java.util.HashMap;import java.util.Map;import scala.Tuple2;import org.apache.spark.SparkConf;import org.apache.spark.api.java.function.Function2;import org.apache.spark.api.java.JavaPairRDD;import org.apache.spark.api.java.JavaRDD;import org.apache.spark.api.java.JavaSparkContext;import org.apache.spark.api.java.function.Function;import org.apache.spark.api.java.function.PairFunction;import org.apache.spark.mllib.regression.LabeledPoint;import org.apache.spark.mllib.tree.GradientBoostedTrees;import org.apache.spark.mllib.tree.configuration.BoostingStrategy;import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel;import org.apache.spark.mllib.util.MLUtils;SparkConf sparkConf = new SparkConf()  .setAppName("JavaGradientBoostedTreesRegressionExample");JavaSparkContext jsc = new JavaSparkContext(sparkConf);// Load and parse the data file.String datapath = "data/mllib/sample_libsvm_data.txt";JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();// Split the data into training and test sets (30% held out for testing)JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});JavaRDD<LabeledPoint> trainingData = splits[0];JavaRDD<LabeledPoint> testData = splits[1];// Train a GradientBoostedTrees model.// The defaultParams for Regression use SquaredError by default.BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Regression");boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice.boostingStrategy.getTreeStrategy().setMaxDepth(5);// Empty categoricalFeaturesInfo indicates all features are continuous.Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo);final GradientBoostedTreesModel model =  GradientBoostedTrees.train(trainingData, boostingStrategy);// Evaluate model on test instances and compute test errorJavaPairRDD<Double, Double> predictionAndLabel =  testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {    @Override    public Tuple2<Double, Double> call(LabeledPoint p) {      return new Tuple2<Double, Double>(model.predict(p.features()), p.label());    }  });Double testMSE =  predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {    @Override    public Double call(Tuple2<Double, Double> pl) {      Double diff = pl._1() - pl._2();      return diff * diff;    }  }).reduce(new Function2<Double, Double, Double>() {    @Override    public Double call(Double a, Double b) {      return a + b;    }  }) / data.count();System.out.println("Test Mean Squared Error: " + testMSE);System.out.println("Learned regression GBT model:\n" + model.toDebugString());// Save and load modelmodel.save(jsc.sc(), "target/tmp/myGradientBoostingRegressionModel");GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(jsc.sc(),  "target/tmp/myGradientBoostingRegressionModel");


0 0
原创粉丝点击