利用spark的mllib构建GBDT模型
来源:互联网 发布:多益网络登录器 编辑:程序博客网 时间:2024/06/04 18:56
GBDT模型
GBDT模型的介绍,我主要是参考博客:http://blog.csdn.net/w28971023/article/details/8240756
在这里,我主要归纳以下几点要素:
1.GBDT中的树都是回归树;
2.回归树节点分割点衡量最好的标准是叶子个数的上限;
3.GBDT的核心在于,每个棵树学的是之前所有树结论和的残差,这个残差就是一个加预测值后能得到真实值的累加量;
4.GB为Gradient Boosting, Boosting的最大好处在于,每一步的残差计算其实变相地增大了分错instance的权重,而已经分对的instance则趋向于0;
5.GBDT采用一个Shrinkage策略,本质上,Shrinkage为每棵树设置了一个weight,累加时要乘以这个weight,但和Gradient并没有关系。
利用spark构建GBDT模型
训练GBDT模型
public void trainModel(){ //初始化spark SparkConf conf = new SparkConf().setAppName("GBDT").setMaster("local"); conf.set("spark.testing.memory","2147480000"); SparkContext sc = new SparkContext(conf); //加载训练文件, 使用MLUtils包 JavaRDD<LabeledPoint> lpdata = MLUtils.loadLibSVMFile(sc, this.trainsetFile).toJavaRDD(); //训练模型, 默认情况下使用均值方差作为阈值标准 int numIteration = 10; //boosting提升迭代的次数 int maxDepth = 3; //回归树的最大深度 BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Regression"); boostingStrategy.setNumIterations(numIteration); boostingStrategy.getTreeStrategy().setMaxDepth(maxDepth); //记录所有特征的连续结果 Map<Integer, Integer> categoricalFeaturesInfoMap = new HashMap<Integer, Integer>(); boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfoMap); //gdbt模型 final GradientBoostedTreesModel model = GradientBoostedTrees.train(lpdata, boostingStrategy); model.save(sc, modelpath); sc.stop(); }
预测数据
public void predict() { //初始化spark SparkConf conf = new SparkConf().setAppName("GBDT").setMaster("local"); conf.set("spark.testing.memory","2147480000"); SparkContext sc = new SparkContext(conf); //加载gbdt模型 final GradientBoostedTreesModel model = GradientBoostedTreesModel.load(sc, this.modelpath); //加载测试文件 JavaRDD<LabeledPoint> testData = MLUtils.loadLibSVMFile(sc, this.predictFile).toJavaRDD(); testData.cache(); //预测数据 JavaRDD<Tuple2<Double, Double>> predictionAndLabel = testData.map(new Prediction(model)) ; //计算所有数据的平均值方差 Double testMSE = predictionAndLabel.map(new CountSquareError()).reduce(new ReduceSquareError()) / testData.count(); System.out.println("testData's MSE is : " + testMSE); sc.stop(); } static class Prediction implements Function<LabeledPoint, Tuple2<Double , Double>> { GradientBoostedTreesModel model; public Prediction(GradientBoostedTreesModel model){ this.model = model; } public Tuple2<Double, Double> call(LabeledPoint p) throws Exception { Double score = model.predict(p.features()); return new Tuple2<Double , Double>(score, p.label()); } } static class CountSquareError implements Function<Tuple2<Double, Double>, Double> { public Double call (Tuple2<Double, Double> pl) { double diff = pl._1() - pl._2(); return diff * diff; } } static class ReduceSquareError implements Function2<Double, Double, Double> { public Double call(Double a , Double b){ return a + b ; } }
关于具体的代码放至我的github上:https://github.com/Quincy1994/MachineLearning
0 0
- 利用spark的mllib构建GBDT模型
- GBDT&Spark mllib
- Spark Mllib构建简单的电影推荐系统(转)
- Spark MLlib LDA主题模型
- 基于spark mllib的LDA模型训练Scala代码实现
- 基于spark mllib的LDA模型训练源码解析
- 分享Spark MLlib训练的广告点击率预测模型
- spark mllib源码分析之DecisionTree与GBDT
- spark下线性模型 spark.mllib
- <转>spark下线性模型 spark.mllib
- spark mllib 之 Pipeline工作流构建
- Spark MLlib之分类模型源码分析
- spark利用MLlib实现kmeans算法实例
- 利用Spark mllib识别点阵文本
- 如何利用Spark MLlib进行个性推荐?
- spark mllib的优缺点分析
- spark mllib 的数据预处理
- 转:利用GBDT模型构造新特征
- 设计模式入门——工厂模式
- lib 简单实现
- ros中编译moveit_tutorials出错
- 【Python学习笔记】-json模块
- Android Multimedia框架总结(八)Stagefright框架之AwesomePlayer及数据解析器
- 利用spark的mllib构建GBDT模型
- RabbitMQ实战基础
- HTML与CSS的链接方式
- [题解]April Cook-Off 2017
- 2017百度实习生春招java笔试题 输出第三便宜价格
- ERROR 1146 (42S02): Table 'eip_fileservice.t_document_file' doesn't exist
- Android Multimedia框架总结(九)Stagefright框架之数据处理及到OMXCodec过程
- 为何与0xff进行与运算
- 开发中遇到Js缓存问题。和页面会话级别的缓存。