Spark MLLib 梯度提升树
来源:互联网 发布:维氏刀具在淘宝上买 编辑:程序博客网 时间:2024/05/23 22:07
梯度提升树是决策树的群集。GBTs为了使损失函数的值最小化,迭代式的训练决策树。和决策树一样,GBTs可以处理离散特征,也可以扩展设置为多级分类,而不需要进行特征值缩放,所以有能力处理非线性以及特征交互。
spark.mllib支持二分类GBTs以及回归GBTs,可以使用连续或者离散型特征。Spark.mllib在已经存在的决策树的实现的基础上实现的GBTs。如果想要对决策树了解更多的话,请参考决策树的文档。
注意:(Spark1.6.0的)GBTs目前尚不支持多级分类。对于多级分类问题,请使用决策树或者随机森林。
基础算法
梯度提升数迭代式的训练一系列决策树。在每轮迭代中,算法使用当前群集对于每个训练实例的预测结果与真实的结果进行比对。数据集被重新标注,使得加强对于错误预测了的实例进行训练。这样,在下一轮迭代中,决策树将会帮助修复之前的错误。
重新标注实例时使用的损失函数(下面有介绍)。在每一轮迭代中,GBTs将会降低训练数据的损失函数的值。
损失
下面的表格列举了spark.mllib当前支持的损失函数。注意每一个损失函数只支持分类或者回归中的一种,不是全部。
注意:N=实例个数,yi=实例的标签,xi=i实例的特征。F(xi)=用于预测实例i的模型。
使用说明
下面是一些使用GBTs的不同参数的说明。我们省略了一些决策树参数因为它们在决策树帮助文档里包含了。
损失:上文说明了一些损失函数以及它们的适用场景(分类、回归)。不同的损失函数会给出截然不同的结果,取决于数据集。
迭代次数:这个设置群集中树的数目。每一轮迭代生成一棵树。提升这个参数的值可以使得模型的准确度更高。然而,测试时间也许会受到影响,如果它太大的话。
学习率:这个参数不需要被调整。如果算法的效果看起来不稳定的话,降低这个值也许会提高稳定性。
算法:任务算法(分类、回归)可以在树参数里被设置
训练时验证
梯度提升也许会过拟合,如果训练了过多的树的话。为了防止过拟合,最好的办法就是训练时验证。runWithValidation方法可以提供帮助。它将一对RDD作为参数,第一个是测试数据,而第二个是验证数据。
训练停止当验证错误的改善没有高于指定阈值(BoostingStrategy的validationTol参数)。在实践中,验证错误先降低后升高。在实践中,验证错误也许是单调的,那么用户被推荐设置更大的负公差,同时使用evaluateEachIteration方法来检测验证曲线来调整迭代次数。
使用示例:
1.分类
import java.util.HashMap;import java.util.Map;import scala.Tuple2;import org.apache.spark.SparkConf;import org.apache.spark.api.java.JavaPairRDD;import org.apache.spark.api.java.JavaRDD;import org.apache.spark.api.java.JavaSparkContext;import org.apache.spark.api.java.function.Function;import org.apache.spark.api.java.function.PairFunction;import org.apache.spark.mllib.regression.LabeledPoint;import org.apache.spark.mllib.tree.GradientBoostedTrees;import org.apache.spark.mllib.tree.configuration.BoostingStrategy;import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel;import org.apache.spark.mllib.util.MLUtils;SparkConf sparkConf = new SparkConf() .setAppName("JavaGradientBoostedTreesClassificationExample");JavaSparkContext jsc = new JavaSparkContext(sparkConf);// Load and parse the data file.String datapath = "data/mllib/sample_libsvm_data.txt";JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();// Split the data into training and test sets (30% held out for testing)JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});JavaRDD<LabeledPoint> trainingData = splits[0];JavaRDD<LabeledPoint> testData = splits[1];// Train a GradientBoostedTrees model.// The defaultParams for Classification use LogLoss by default.BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Classification");boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice.boostingStrategy.getTreeStrategy().setNumClasses(2);boostingStrategy.getTreeStrategy().setMaxDepth(5);// Empty categoricalFeaturesInfo indicates all features are continuous.Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo);final GradientBoostedTreesModel model = GradientBoostedTrees.train(trainingData, boostingStrategy);// Evaluate model on test instances and compute test errorJavaPairRDD<Double, Double> predictionAndLabel = testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() { @Override public Tuple2<Double, Double> call(LabeledPoint p) { return new Tuple2<Double, Double>(model.predict(p.features()), p.label()); } });Double testErr = 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() { @Override public Boolean call(Tuple2<Double, Double> pl) { return !pl._1().equals(pl._2()); } }).count() / testData.count();System.out.println("Test Error: " + testErr);System.out.println("Learned classification GBT model:\n" + model.toDebugString());// Save and load modelmodel.save(jsc.sc(), "target/tmp/myGradientBoostingClassificationModel");GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(jsc.sc(), "target/tmp/myGradientBoostingClassificationModel");
2.回归
import java.util.HashMap;import java.util.Map;import scala.Tuple2;import org.apache.spark.SparkConf;import org.apache.spark.api.java.function.Function2;import org.apache.spark.api.java.JavaPairRDD;import org.apache.spark.api.java.JavaRDD;import org.apache.spark.api.java.JavaSparkContext;import org.apache.spark.api.java.function.Function;import org.apache.spark.api.java.function.PairFunction;import org.apache.spark.mllib.regression.LabeledPoint;import org.apache.spark.mllib.tree.GradientBoostedTrees;import org.apache.spark.mllib.tree.configuration.BoostingStrategy;import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel;import org.apache.spark.mllib.util.MLUtils;SparkConf sparkConf = new SparkConf() .setAppName("JavaGradientBoostedTreesRegressionExample");JavaSparkContext jsc = new JavaSparkContext(sparkConf);// Load and parse the data file.String datapath = "data/mllib/sample_libsvm_data.txt";JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();// Split the data into training and test sets (30% held out for testing)JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});JavaRDD<LabeledPoint> trainingData = splits[0];JavaRDD<LabeledPoint> testData = splits[1];// Train a GradientBoostedTrees model.// The defaultParams for Regression use SquaredError by default.BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Regression");boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice.boostingStrategy.getTreeStrategy().setMaxDepth(5);// Empty categoricalFeaturesInfo indicates all features are continuous.Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo);final GradientBoostedTreesModel model = GradientBoostedTrees.train(trainingData, boostingStrategy);// Evaluate model on test instances and compute test errorJavaPairRDD<Double, Double> predictionAndLabel = testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() { @Override public Tuple2<Double, Double> call(LabeledPoint p) { return new Tuple2<Double, Double>(model.predict(p.features()), p.label()); } });Double testMSE = predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() { @Override public Double call(Tuple2<Double, Double> pl) { Double diff = pl._1() - pl._2(); return diff * diff; } }).reduce(new Function2<Double, Double, Double>() { @Override public Double call(Double a, Double b) { return a + b; } }) / data.count();System.out.println("Test Mean Squared Error: " + testMSE);System.out.println("Learned regression GBT model:\n" + model.toDebugString());// Save and load modelmodel.save(jsc.sc(), "target/tmp/myGradientBoostingRegressionModel");GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(jsc.sc(), "target/tmp/myGradientBoostingRegressionModel");
- Spark MLLib 梯度提升树
- Spark中组件Mllib的学习39之梯度提升树(GBT)用于分类*
- Spark中组件Mllib的学习40之梯度提升树(GBT)用于回归*
- mllib之随机森林与梯度提升树
- Spark MLlib
- spark MLlib
- Spark MLLib
- Spark MLlib
- 梯度提升树GBDT原理
- 梯度提升树GBDT原理
- 梯度提升树(GBDT)原理
- 梯度提升树GBDT原理
- 梯度提升树GBDT原理
- 梯度树提升算法GBRT
- 梯度提升树GBDT原理
- 梯度提升树(GBDT)原理
- Spark中组件Mllib的学习23之随机梯度下降(SGD)
- 梯度迭代树(GBDT)算法原理及Spark MLlib调用实例(Scala/Java/python)
- Git安装、创建版本库、版本回退
- 【IDE-Visual Studio】关于exe的版本中“文件版本”和其他版本信息中的“文件版本”、以及“产品版本”
- maven "Generating project in Batch mode"问题
- idea maven mvn archetype:generate 速度缓慢问题
- 利用Gulp优化部署Web项目
- Spark MLLib 梯度提升树
- 邮件发送出现错误:535 Authentication failed
- java list 取随机
- 【BZOJ 1010】【HNOI2008】玩具装箱toy 【斜率优化】
- Java网络爬虫crawler4j学习笔记<18> Configurable类
- CSS 基础篇、绝对有你想要
- Realm-Android
- Python三元表达式
- Java基础入门——写在前面