spark之MLlib机器学习-线性回归

来源:互联网 发布:php extension 编辑:程序博客网 时间:2024/05/17 04:13

此篇博文根据《Spark MLlib机器学习》实例程序编写,可作为熟悉scala和mllib编写机器学习算法的一种实践。
1、准备测试数据
可从作者博客自行下载。代码及数据下载地址
2、编写scala源码
为了进一步熟悉scala编程语言,建议自己把代码敲一次。

//import org.apache.log4j{ Level, Logger }import org.apache.spark.{SparkConf,SparkContext}import org.apache.spark.mllib.regression.LinearRegressionWithSGDimport org.apache.spark.mllib.util.MLUtilsimport org.apache.spark.mllib.regression.LabeledPointimport org.apache.spark.mllib.linalg.Vectorsimport org.apache.spark.mllib.regression.LinearRegressionModelobject LinearRegression{  def main(args:Array[String]){    val conf = new SparkConf().setAppName("LinearRegressionWithSGD")    val sc =new SparkContext(conf)  //  Logger.getRootLogger.setLevel(Level.WARN)    val data_path1="file:///usr/spark2.0/data/mllib/mydata/lpsa.data"    val data=sc.textFile(data_path1)    val examples=data.map{line=>     val parts=line.split(',')     LabeledPoint(parts(0).toDouble,Vectors.dense(parts(1).split(' ').map(_.toDouble)))    }.cache()    val numExamples=examples.count()    val numIterations=100    val stepSize=1    val miniBatchFraction=1.0    val model=LinearRegressionWithSGD.train(examples,numIterations,stepSize,miniBatchFraction)    val prediction=model.predict(examples.map(_.features))    val predictionAndLabel=prediction.zip(examples.map(_.label))    val print_predict=predictionAndLabel.take(50)    println("prediction"+"\t"+"label")    for (i <- 0 to print_predict.length-1 ){       println(print_predict(i)._1 + "\t" + print_predict(i)._2)    }    val loss =predictionAndLabel.map{        case(p,_)=>            val err = p - 1            err*err    }.reduce(_+_)    val rmse=math.sqrt(loss/numExamples)    println(s"Test RMSE = $rmse.")  }

3、使用sbt工具编译和打包

4、结果输出:
这里写图片描述

可以看出,线性拟合的结果并不理想,说明模型选择的不合理。本示例仅仅为了说明线性回归api的用法。

0 0
原创粉丝点击