Running XGBoost on Spark: the Scala API
Source: Internet · Editor: 程序博客网 · Date: 2024/06/07 11:15
Overview

XGBoost can run on Spark. I am using XGBoost version 0.7, which currently only supports Spark 2.0 and above.
Build the jar and install it into your local Maven repository:
mvn install:install-file -Dfile=xgboost4j-spark-0.7-jar-with-dependencies.jar -DgroupId=ml.dmlc -DartifactId=xgboost4j-spark -Dversion=0.7 -Dpackaging=jar
Add the dependencies:
<dependencies>
  <dependency>
    <groupId>ml.dmlc</groupId>
    <artifactId>xgboost4j-spark</artifactId>
    <version>0.7</version>
  </dependency>
  <dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-core_2.10</artifactId>
    <version>2.0.0</version>
  </dependency>
  <dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-mllib_2.10</artifactId>
    <version>2.0.0</version>
  </dependency>
</dependencies>
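If you build with sbt instead of Maven, the equivalent dependencies might look like the sketch below (this assumes the locally installed xgboost4j-spark jar is visible to sbt's resolvers; it is not part of the original article):

```scala
// build.sbt fragment (hypothetical sbt equivalent of the Maven dependencies above)
scalaVersion := "2.10.6"

libraryDependencies ++= Seq(
  "ml.dmlc"          %  "xgboost4j-spark" % "0.7",
  "org.apache.spark" %% "spark-core"      % "2.0.0",
  "org.apache.spark" %% "spark-mllib"     % "2.0.0"
)
```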
The RDD API:
package com.meituan.spark_xgboost

import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkConf, SparkContext}
import ml.dmlc.xgboost4j.scala.spark.XGBoost
import org.apache.spark.sql.{SparkSession, Row}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors

object XgboostR {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)
    val spark = SparkSession.builder.master("local").appName("example")
      .config("spark.sql.warehouse.dir", s"file:///Users/shuubiasahi/Documents/spark-warehouse")
      .config("spark.sql.shuffle.partitions", "20")
      .getOrCreate()
    spark.conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    val path = "/Users/shuubiasahi/Documents/workspace/xgboost/demo/data/"
    val trainString = "agaricus.txt.train"
    val testString = "agaricus.txt.test"
    val train = MLUtils.loadLibSVMFile(spark.sparkContext, path + trainString)
    val test = MLUtils.loadLibSVMFile(spark.sparkContext, path + testString)
    // Convert mllib LabeledPoints into ml LabeledPoints with dense vectors
    val traindata = train.map { x =>
      val f = x.features.toArray
      val v = x.label
      LabeledPoint(v, Vectors.dense(f))
    }
    val testdata = test.map { x =>
      val f = x.features.toArray
      val v = x.label
      Vectors.dense(f)
    }
    val numRound = 15
    // "objective" -> "reg:linear", // learning task and objective, for regression
    // "eval_metric" -> "rmse",     // evaluation metric for validation data, used for regression
    val paramMap = List(
      "eta" -> 1f,
      "max_depth" -> 5, // maximum depth of a tree; default 6, range [1, ∞]
      "silent" -> 1,    // 0 prints runtime messages, 1 runs silently; default 0
      "objective" -> "binary:logistic", // the learning task and objective
      "lambda" -> 2.5,
      "nthread" -> 1    // number of threads XGBoost uses; defaults to the maximum available
    ).toMap
    println(paramMap)
    val model = XGBoost.trainWithRDD(traindata, paramMap, numRound, 55, null, null,
      useExternalMemory = false, Float.NaN)
    println("success")
    val result = model.predict(testdata)
    result.take(10).foreach(println)
    spark.stop()
  }
}
The DataFrame API:
package com.meituan.spark_xgboost

import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkConf, SparkContext}
import ml.dmlc.xgboost4j.scala.spark.XGBoost
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.sql.{SparkSession, Row}

object XgboostD {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)
    val spark = SparkSession.builder.master("local").appName("example")
      .config("spark.sql.warehouse.dir", s"file:///Users/shuubiasahi/Documents/spark-warehouse")
      .config("spark.sql.shuffle.partitions", "20")
      .getOrCreate()
    spark.conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    val path = "/Users/shuubiasahi/Documents/workspace/xgboost/demo/data/"
    val trainString = "agaricus.txt.train"
    val testString = "agaricus.txt.test"
    val train = spark.read.format("libsvm").load(path + trainString).toDF("label", "feature")
    val test = spark.read.format("libsvm").load(path + testString).toDF("label", "feature")
    val numRound = 15
    // "objective" -> "reg:linear", // learning task and objective, for regression
    // "eval_metric" -> "rmse",     // evaluation metric for validation data, used for regression
    val paramMap = List(
      "eta" -> 1f,
      "max_depth" -> 5, // maximum depth of a tree; default 6, range [1, ∞]
      "silent" -> 1,    // 0 prints runtime messages, 1 runs silently; default 0
      "objective" -> "binary:logistic", // the learning task and objective
      "lambda" -> 2.5,
      "nthread" -> 1    // number of threads XGBoost uses; defaults to the maximum available
    ).toMap
    val model = XGBoost.trainWithDataFrame(train, paramMap, numRound, 45,
      obj = null, eval = null, useExternalMemory = false, Float.NaN, "feature", "label")
    val predict = model.transform(test)
    val scoreAndLabels = predict.select(model.getPredictionCol, model.getLabelCol)
      .rdd
      .map { case Row(score: Double, label: Double) => (score, label) }
    // Compute the AUC
    val metric = new BinaryClassificationMetrics(scoreAndLabels)
    val auc = metric.areaUnderROC()
    println("auc:" + auc)
  }
}
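The AUC that BinaryClassificationMetrics reports can be sanity-checked against a tiny pure-Scala implementation. The sketch below (not from the original article; `AucSketch` and `auc` are names of my own) uses the pairwise definition: AUC is the probability that a randomly chosen positive example is scored above a randomly chosen negative one, with ties counting half.

```scala
object AucSketch {
  // AUC via the Mann-Whitney statistic: for every (positive, negative) pair,
  // score 1.0 if the positive ranks higher, 0.5 on a tie, 0.0 otherwise,
  // then average over all pairs.
  def auc(scoreAndLabels: Seq[(Double, Double)]): Double = {
    val pos = scoreAndLabels.collect { case (s, 1.0) => s }
    val neg = scoreAndLabels.collect { case (s, 0.0) => s }
    val wins = for (p <- pos; n <- neg)
      yield if (p > n) 1.0 else if (p == n) 0.5 else 0.0
    wins.sum / (pos.size * neg.size)
  }

  def main(args: Array[String]): Unit = {
    // A scorer that perfectly separates the classes has AUC 1.0
    println(auc(Seq((0.9, 1.0), (0.8, 1.0), (0.3, 0.0), (0.1, 0.0)))) // prints 1.0
  }
}
```

On large data this O(pos × neg) loop is impractical, which is why Spark's metric works from a ranked, binned ROC curve instead; for a quick check on a sample, the two should agree.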