Spark Machine Learning: Logistic Regression Demo
Source: Internet  Editor: 程序博客网  Time: 2024/06/05 10:55
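These are small examples I put together while learning Spark ML. I ran all of them in spark-shell (enter multi-line snippets in :paste mode so that lines beginning with "." are read as continuations). For the detailed tutorial, see the official documentation: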
http://spark.apache.org/docs/latest/ml-pipeline.html
Code example for Estimator, Transformer, and Param:
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.{Row, SparkSession}

// Create the SparkSession (in spark-shell, `spark` already exists and getOrCreate() returns it)
val spark = SparkSession.builder()
  .appName("Spark SQL basic example")
  .config("spark.some.config.option", "some-value")
  .getOrCreate()
import spark.implicits._

// Prepare the training set: (label, features)
val training = spark.createDataFrame(Seq(
  (1.0, Vectors.dense(0.0, 1.1, 0.1)),
  (0.0, Vectors.dense(2.0, 1.0, -1.0)),
  (0.0, Vectors.dense(2.0, 1.3, 1.0)),
  (1.0, Vectors.dense(0.0, 1.2, -0.5))
)).toDF("label", "features")

// Prepare the test set
val test = spark.createDataFrame(Seq(
  (1.0, Vectors.dense(-1.0, 1.5, 1.3)),
  (0.0, Vectors.dense(3.0, 2.0, -0.1)),
  (1.0, Vectors.dense(0.0, 2.2, -1.5))
)).toDF("label", "features")

// Create a LogisticRegression instance (an Estimator), print its parameters, then set some of them
val lr = new LogisticRegression()
println("LogisticRegression parameters:\n" + lr.explainParams() + "\n")
lr.setMaxIter(10).setRegParam(0.01)

// Fit to get model1 (a Transformer) and print the parameters it was trained with
val model1 = lr.fit(training)
println("Model 1 was fit using parameters: " + model1.parent.extractParamMap)

// Specify parameters with a ParamMap; put(lr.maxIter, 30) overwrites the earlier maxIter -> 20
val paramMap = ParamMap(lr.maxIter -> 20)
  .put(lr.maxIter, 30)
  .put(lr.regParam -> 0.1, lr.threshold -> 0.55)

// Two ParamMaps can be combined; this one renames the output probability column
val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability")
val paramMapCombined = paramMap ++ paramMap2

// Fit with the combined ParamMap to get model2; these values override those set via the setters
val model2 = lr.fit(training, paramMapCombined)
println("Model 2 was fit using parameters: " + model2.parent.extractParamMap)

// Run model2 on the test set
model2.transform(test)
  .select("features", "label", "myProbability", "prediction")
  .collect()
  .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) =>
    println(s"($features, $label) -> prob=$prob, prediction=$prediction")
  }
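The following is a plain-Scala sketch with no Spark dependency. It makes two things in the example above concrete: which parameter values model2 ends up with, and how a binary logistic regression model turns a probability into a prediction. It is a toy model, not the Spark API. Parameters are a hypothetical Map[String, Double] resolved in three layers: defaults, then values set through setters, then the ParamMap passed to fit(). Later layers win on conflicts, as with ParamMap's ++. The defaults 100, 0.0 and 0.5 are LogisticRegression's defaults for maxIter, regParam and threshold. The weights, intercept and the object name are made up purely for illustration.

```scala
// Toy model of Spark ML parameter resolution and thresholded prediction (NOT the Spark API).
object ParamResolutionSketch {
  // Resolve parameters: entries in later maps overwrite earlier ones, like ParamMap's ++
  def resolve(defaults: Map[String, Double],
              setters: Map[String, Double],
              fitTime: Map[String, Double]): Map[String, Double] =
    defaults ++ setters ++ fitTime

  // Binary logistic regression: P(label = 1) = sigmoid(w . x + b)
  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  // Predict 1.0 only when P(label = 1) is strictly above the threshold
  def predict(p1: Double, threshold: Double): Double =
    if (p1 > threshold) 1.0 else 0.0

  def main(args: Array[String]): Unit = {
    val defaults = Map("maxIter" -> 100.0, "regParam" -> 0.0, "threshold" -> 0.5)
    val setters = Map("maxIter" -> 10.0, "regParam" -> 0.01)
    // maxIter -> 20 comes first and is then overwritten by maxIter -> 30
    val paramMap = Map("maxIter" -> 20.0) ++ Map("maxIter" -> 30.0) ++
      Map("regParam" -> 0.1, "threshold" -> 0.55)
    val resolved = resolve(defaults, setters, paramMap)
    println(resolved)

    // Hypothetical weights and intercept, NOT taken from a fitted model
    val w = Array(-1.0, 0.5, 0.2)
    val b = 0.1
    val x = Array(0.0, 2.2, -1.5)
    val p1 = sigmoid(w.zip(x).map { case (wi, xi) => wi * xi }.sum + b)
    println(s"prob=[${1 - p1}, $p1], prediction=${predict(p1, resolved("threshold"))}")
  }
}
```

Because model2 received threshold -> 0.55 at fit time, a test row needs a class-1 probability above 0.55 to be predicted as 1.0. Under the default threshold of 0.5, anything above 0.5 would be enough.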
Pipeline code example:
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.{Row, SparkSession}

// Create the SparkSession
val spark = SparkSession.builder()
  .appName("Spark SQL basic example")
  .config("spark.some.config.option", "some-value")
  .getOrCreate()
import spark.implicits._

// Prepare the training set: (id, text, label)
val training = spark.createDataFrame(Seq(
  (0L, "a b c d e spark", 1.0),
  (1L, "b d", 0.0),
  (2L, "spark f g h", 1.0),
  (3L, "hadoop mapreduce", 0.0)
)).toDF("id", "text", "label")

// Prepare the (unlabeled) test set
val test = spark.createDataFrame(Seq(
  (4L, "spark i j k"),
  (5L, "l m n"),
  (6L, "spark hadoop spark"),
  (7L, "apache hadoop")
)).toDF("id", "text")

// Configure the ML pipeline with three stages: tokenizer, hashingTF and lr (logistic regression)
val tokenizer = new Tokenizer()
  .setInputCol("text")
  .setOutputCol("words")
val hashingTF = new HashingTF()
  .setNumFeatures(1000)
  .setInputCol(tokenizer.getOutputCol)
  .setOutputCol("features")
val lr = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(0.001)
val pipeline = new Pipeline()
  .setStages(Array(tokenizer, hashingTF, lr))

// Fitting the Pipeline yields a PipelineModel, which is a Transformer
val model = pipeline.fit(training)

// Save the fitted model
model.write.overwrite().save("/tmp/spark-logistic-regression-model")

// Save the unfitted pipeline structure
pipeline.write.overwrite().save("/tmp/unfit-lr-model")

// Load the model back when it is needed
val sameModel = PipelineModel.load("/tmp/spark-logistic-regression-model")

// Run the model on the test set
model.transform(test)
  .select("id", "text", "probability", "prediction")
  .collect()
  .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
    println(s"($id, $text) --> prob=$prob, prediction=$prediction")
  }
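The tokenizer and hashingTF stages above turn each text into a sparse term-frequency vector using the hashing trick. The following plain-Scala sketch (no Spark dependency) imitates those two steps roughly: it lowercases the text, splits on whitespace, hashes each term to one of numFeatures buckets, and counts the terms in each bucket. It is an approximation under stated assumptions. Spark's HashingTF uses its own MurmurHash3 variant, while this sketch uses Scala's MurmurHash3.stringHash, so the bucket indices will not match Spark's. The object and method names are hypothetical.

```scala
import scala.util.hashing.MurmurHash3

// Simplified stand-in for Tokenizer + HashingTF (NOT the Spark implementation).
// Assumption: bucket indices use Scala's MurmurHash3.stringHash, so they differ from Spark's.
object HashingTFSketch {
  // Roughly like Tokenizer: lowercase the text and split on whitespace
  def tokenize(text: String): Seq[String] =
    text.toLowerCase.split("\\s+").filter(_.nonEmpty).toSeq

  // Map a term to a bucket in [0, numFeatures) using a non-negative modulo
  def indexOf(term: String, numFeatures: Int): Int = {
    val h = MurmurHash3.stringHash(term)
    ((h % numFeatures) + numFeatures) % numFeatures
  }

  // Sparse term-frequency vector: bucket index -> number of terms hashed into it
  def termFrequencies(words: Seq[String], numFeatures: Int): Map[Int, Double] =
    words.groupBy(indexOf(_, numFeatures)).map { case (i, ws) => i -> ws.size.toDouble }

  def main(args: Array[String]): Unit = {
    val words = tokenize("spark hadoop spark")
    println(words)
    println(termFrequencies(words, 1000))
  }
}
```

Different words can land in the same bucket. A larger numFeatures (1000 in the pipeline above) makes collisions less likely, at the cost of a longer feature vector.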