SparkMLlib---LinearRegression(线性回归)、LogisticRegression(逻辑回归)

来源:互联网 发布:网络安全员日常工作 编辑:程序博客网 时间:2024/06/06 01:56

1、随机梯度下降

首先介绍一下随机梯度下降算法:

1.1、代码一:

package mllib

import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkContext, SparkConf}

import scala.collection.mutable.HashMap

/**
  * Stochastic gradient descent demo.
  *
  * Fits the single coefficient `a` of y = a * x on the synthetic points
  * (x, 2x) for x = 1..50, updating `a` once per sample.
  * Created by 汪本成 on 2016/8/7.
  */
object SGD {
  // Keep unnecessary log output off the console.
  Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
  // Jetty's package is org.eclipse.jetty; the former "org.apache.eclipse.jetty.server"
  // matched no logger, so that call silenced nothing.
  Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

  // Program entry: Spark configuration and context.
  // NOTE(review): the SGD loop below runs on the driver only and never uses `sc`.
  val conf = new SparkConf()
    .setMaster("local[1]")
    .setAppName(this.getClass().getSimpleName()
      .filter(!_.equals('$')))
  println(this.getClass().getSimpleName().filter(!_.equals('$')))
  val sc = new SparkContext(conf)

  /** Data set as x -> y pairs; filled by `getData()`. */
  val data = new HashMap[Int, Int]()

  /**
    * Fills `data` with y = 2x for x in 1..50.
    *
    * @return the same (shared, mutable) `data` map
    */
  def getData(): HashMap[Int, Int] = {
    for (i <- 1 to 50) {
      data += (i -> (2 * i)) // y = 2x
    }
    data
  }

  /** Current estimate of the coefficient; starts at 0. */
  var a: Double = 0
  /** Step size (learning rate). */
  var b: Double = 0.1

  /**
    * One update step: a <- a - b * (a * x - y).
    *
    * NOTE(review): the exact gradient of (a*x - y)^2 / 2 is (a*x - y) * x;
    * this simplified step omits the `x` factor and is kept as in the original.
    */
  def sgd(x: Double, y: Double): Unit = {
    a = a - b * ((a * x) - y)
  }

  def main(args: Array[String]): Unit = {
    try {
      // Build the data set.
      val dataSource = getData()
      println("data: ")
      dataSource.foreach(each => println(each + " "))
      println("\nresult: ")
      var num = 1
      // Iterate; HashMap iteration order is unspecified, so samples are not visited as 1..50.
      dataSource.foreach { case (x, y) =>
        println(num + ":" + a + "(" + x + "," + y + ")")
        sgd(x, y)
        num += 1
      }
      // Print the final estimate.
      println("最终结果a " + a)
    } finally {
      // Release the SparkContext even if something above throws.
      sc.stop()
    }
  }
}
结果请大家自己验证。

2、线性回归

2.1、数据

首先用一个小数据集做实验,测试的公式在代码中有说明。每行格式为"标签,特征1 特征2",实验数据如下:

5,1 1
7,2 1
10,2 2
9,3 2
11,4 1
19,5 3
18,6 2

2.2、代码二:

package mllib

import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LinearRegressionWithSGD, LabeledPoint}
import org.apache.spark.{SparkContext, SparkConf}

/**
  * Linear regression on a small data set.
  *
  * Formula: f(x) = a*x1 + b*x2
  * Each input line has the form "label,x1 x2", e.g. "5,1 1".
  * Created by 汪本成 on 2016/8/6.
  */
object LinearRegression1 {
  // Keep unnecessary log output off the console.
  Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
  // Jetty's package is org.eclipse.jetty; the former "org.apache.eclipse.jetty.server"
  // matched no logger, so that call silenced nothing.
  Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

  // Program entry: Spark configuration and context.
  val conf = new SparkConf()
    .setMaster("local[1]")
    .setAppName(this.getClass().getSimpleName().filter(!_.equals('$')))
  println(this.getClass().getSimpleName().filter(!_.equals('$')))
  val sc = new SparkContext(conf)

  def main(args: Array[String]): Unit = {
    try {
      // Path of the data set.
      val data = sc.textFile("G:\\MLlibData\\lpsa2.txt")
      // Parse "label,x1 x2" into LabeledPoint; blank lines (e.g. a trailing newline)
      // are skipped because they would make toDouble throw.
      val parsedData = data
        .filter(_.trim.nonEmpty)
        .map { line =>
          val parts = line.split(',')
          LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).trim.split(' ').map(_.toDouble)))
        }
        .cache() // SGD makes many passes over the data

      // Train the model.
      val numIterations = 100
      val stepSize = 0.1
      val model = LinearRegressionWithSGD.train(parsedData, numIterations, stepSize)

      // Predict the label for the point (x1, x2) = (2, 1).
      val result = model.predict(Vectors.dense(2, 1))
      println("model weights:")
      // The two learned coefficients, as a vector.
      println(model.weights)
      println(result)
    } finally {
      // Release the SparkContext even if parsing or training throws.
      sc.stop()
    }
  }
}

3、回归曲线

在回归曲线这部分,我们不仅要对比预测结果和真实结果,还要计算回归曲线的均方误差(MSE)。

3.1、数据

-0.4307829,-1.63735562648104 -2.00621178480549 -1.86242597251066 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306-0.1625189,-1.98898046126935 -0.722008756122123 -0.787896192088153 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306-0.1625189,-1.57881887548545 -2.1887840293994 1.36116336875686 -1.02470580167082 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.155348103855541-0.1625189,-2.16691708463163 -0.807993896938655 -0.787896192088153 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.8644665073373060.3715636,-0.507874475300631 -0.458834049396776 -0.250631301876899 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.8644665073373060.7654678,-2.03612849966376 -0.933954647105133 -1.86242597251066 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.8644665073373060.8544153,-0.557312518810673 -0.208756571683607 -0.787896192088153 0.990146852537193 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.8644665073373061.2669476,-0.929360463147704 -0.0578991819441687 0.152317365781542 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.8644665073373061.2669476,-2.28833047634983 -0.0706369432557794 -0.116315079324086 0.80409888772376 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.8644665073373061.2669476,0.223498042876113 -1.41471935455355 -0.116315079324086 -1.02470580167082 -0.522940888712441 -0.29928234305568 0.342627053981254 0.1992110978853411.3480731,0.107785900236813 -1.47221551299731 0.420949810887169 -1.02470580167082 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.6871869064668651.446919,0.162180092313795 -1.32557369901905 0.286633588334355 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.8644665073373061.4701758,-1.49795329918548 -0.263601072284232 0.823898478545609 0.788388310173035 -0.522940888712441 
-0.29928234305568 0.342627053981254 0.1992110978853411.4929041,0.796247055396743 0.0476559407005752 0.286633588334355 -1.02470580167082 -0.522940888712441 0.394013435896129 -1.04215728919298 -0.8644665073373061.5581446,-1.62233848461465 -0.843294091975396 -3.07127197548598 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.8644665073373061.5993876,-0.990720665490831 0.458513517212311 0.823898478545609 1.07379746308195 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.8644665073373061.6389967,-0.171901281967138 -0.489197399065355 -0.65357996953534 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.8644665073373061.6956156,-1.60758252338831 -0.590700340358265 -0.65357996953534 -0.619561070667254 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.8644665073373061.7137979,0.366273918511144 -0.414014962912583 -0.116315079324086 0.232904453212813 -0.522940888712441 0.971228997418125 0.342627053981254 1.262888703107991.8000583,-0.710307384579833 0.211731938156277 0.152317365781542 -1.02470580167082 -0.522940888712441 -0.442797990776478 0.342627053981254 1.617447904848871.8484548,-0.262791728113881 -1.16708345615721 0.420949810887169 0.0846342590816532 -0.522940888712441 0.163172393491611 0.342627053981254 1.972007106589751.8946169,0.899043117369237 -0.590700340358265 0.152317365781542 -1.02470580167082 -0.522940888712441 1.28643254437683 -1.04215728919298 -0.8644665073373061.9242487,-0.903451690500615 1.07659722048274 0.152317365781542 1.28380453408541 -0.522940888712441 -0.442797990776478 -1.04215728919298 -0.8644665073373062.008214,-0.0633337899773081 -1.38088970920094 0.958214701098423 0.80409888772376 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.8644665073373062.0476928,-1.15393789990757 -0.961853075398404 -0.116315079324086 -1.02470580167082 -0.522940888712441 -0.442797990776478 -1.04215728919298 -0.8644665073373062.1575593,0.0620203721138446 0.0657973885499142 1.22684714620405 
-0.468824786336838 -0.522940888712441 1.31421001659859 1.72741139715549 -0.3326277047259832.1916535,-0.75731027755674 -2.92717970468456 0.018001143228728 -1.02470580167082 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.3326277047259832.2137539,1.11226993252773 1.06484916245061 0.555266033439982 0.877691038550889 1.89254797819741 1.43890404648442 0.342627053981254 0.3764906987557832.2772673,-0.468768642850639 -1.43754788774533 -1.05652863719378 0.576050411655607 -0.522940888712441 0.0120483832567209 0.342627053981254 -0.6871869064668652.2975726,-0.618884859896728 -1.1366360750781 -0.519263746982526 -1.02470580167082 -0.522940888712441 -0.863171185425945 3.11219574032972 1.972007106589752.3272777,-0.651431999123483 0.55329161145762 -0.250631301876899 1.11210019001038 -0.522940888712441 -0.179808625688859 -1.04215728919298 -0.8644665073373062.5217206,0.115499102435224 -0.512233676577595 0.286633588334355 1.13650173283446 -0.522940888712441 -0.179808625688859 0.342627053981254 -0.1553481038555412.5533438,0.266341329949937 -0.551137885443386 -0.384947524429713 0.354857790686005 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.3326277047259832.5687881,1.16902610257751 0.855491905752846 2.03274448152093 1.22628985326088 1.89254797819741 2.02833774827712 3.11219574032972 2.681125510071522.6567569,-0.218972367124187 0.851192298581141 0.555266033439982 -1.02470580167082 -0.522940888712441 -0.863171185425945 0.342627053981254 0.9083295013671062.677591,0.263121415733908 1.4142681068416 0.018001143228728 1.35980653053822 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.8644665073373062.7180005,-0.0704736333296423 1.52000996595417 0.286633588334355 1.39364261119802 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.3326277047259832.7942279,-0.751957286017338 0.316843561689933 -1.99674219506348 0.911736065044475 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.8644665073373062.8063861,-0.685277652430997 1.28214038482516 
0.823898478545609 0.232904453212813 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.1553481038555412.8124102,-0.244991501432929 0.51882005949686 -0.384947524429713 0.823246560137838 -0.522940888712441 -0.863171185425945 0.342627053981254 0.5537702996262242.8419982,-0.75731027755674 2.09041984898851 1.22684714620405 1.53428167116843 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.8644665073373062.8535925,1.20962937075363 -0.242882661178889 1.09253092365124 -1.02470580167082 -0.522940888712441 1.24263233939889 3.11219574032972 2.503845909201082.9204698,0.570886990493502 0.58243883987948 0.555266033439982 1.16006887775962 -0.522940888712441 1.07357183940747 0.342627053981254 1.617447904848872.9626924,0.719758684343624 0.984970304132004 1.09253092365124 1.52137230773457 -0.522940888712441 -0.179808625688859 0.342627053981254 -0.5099073055964242.9626924,-1.52406140158064 1.81975700990333 0.689582255992796 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.8644665073373062.9729753,-0.132431544081234 2.68769877553723 1.09253092365124 1.53428167116843 -0.522940888712441 -0.442797990776478 0.342627053981254 -0.6871869064668653.0130809,0.436161292804989 -0.0834447307428255 -0.519263746982526 -1.02470580167082 1.89254797819741 1.07357183940747 0.342627053981254 1.262888703107993.0373539,-0.161195191984091 -0.671900359186746 1.7641120364153 1.13650173283446 -0.522940888712441 -0.863171185425945 0.342627053981254 0.02193149701493.2752562,1.39927182372944 0.513852869452676 0.689582255992796 -1.02470580167082 1.89254797819741 1.49394503405693 0.342627053981254 -0.1553481038555413.3375474,1.51967002306341 -0.852203755696565 0.555266033439982 -0.104527297798983 1.89254797819741 1.85927724828569 0.342627053981254 0.9083295013671063.3928291,0.560725834706224 1.87867703391426 1.09253092365124 1.39364261119802 -0.522940888712441 0.486423065822545 0.342627053981254 1.262888703107993.4355988,1.00765532502814 1.69426310090641 
1.89842825896812 1.53428167116843 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.5099073055964243.4578927,1.10152996153577 -0.10927271844907 0.689582255992796 -1.02470580167082 1.89254797819741 1.97630171771485 0.342627053981254 1.617447904848873.5160131,0.100001934217311 -1.30380956369388 0.286633588334355 0.316555063757567 -0.522940888712441 0.28786643052924 0.342627053981254 0.5537702996262243.5307626,0.987291634724086 -0.36279314978779 -0.922212414640967 0.232904453212813 -0.522940888712441 1.79270085261407 0.342627053981254 1.262888703107993.5652984,1.07158528137575 0.606453149641961 1.7641120364153 -0.432854616994416 1.89254797819741 0.528504607720369 0.342627053981254 0.1992110978853413.5876769,0.180156323255198 0.188987436375017 -0.519263746982526 1.09956763075594 -0.522940888712441 0.708239632330506 0.342627053981254 0.1992110978853413.6309855,1.65687973755377 -0.256675483533719 0.018001143228728 -1.02470580167082 1.89254797819741 1.79270085261407 0.342627053981254 1.262888703107993.6800909,0.5720085322365 0.239854450210939 -0.787896192088153 1.0605418233138 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.8644665073373063.7123518,0.323806133438225 -0.606717660886078 -0.250631301876899 -1.02470580167082 1.89254797819741 0.342907418101747 0.342627053981254 0.1992110978853413.9843437,1.23668206715898 2.54220539083611 0.152317365781542 -1.02470580167082 1.89254797819741 1.89037692416194 0.342627053981254 1.262888703107993.993603,0.180156323255198 0.154448192444669 1.62979581386249 0.576050411655607 1.89254797819741 0.708239632330506 0.342627053981254 1.794727505719314.029806,1.60906277046565 1.10378605019827 0.555266033439982 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.8644665073373064.1295508,1.0036214996026 0.113496885050331 -0.384947524429713 0.860016436332751 1.89254797819741 -0.863171185425945 0.342627053981254 -0.3326277047259834.3851468,1.25591974271076 0.577607033774471 0.555266033439982 
-1.02470580167082 1.89254797819741 1.07357183940747 0.342627053981254 1.262888703107994.6844434,2.09650591351268 0.625488598331018 -2.66832330782754 -1.02470580167082 1.89254797819741 1.67954222367555 0.342627053981254 0.5537702996262245.477509,1.30028987435881 0.338383613253713 0.555266033439982 1.00481276295349 1.89254797819741 1.24263233939889 0.342627053981254 1.97200710658975

3.2、代码三:

package mllib

import java.text.SimpleDateFormat
import java.util.Date

import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LinearRegressionWithSGD, LabeledPoint}
import org.apache.spark.{SparkContext, SparkConf}

/**
  * Computes the MSE of a fitted regression line.
  *
  * Trains a model on several features, predicts each training point,
  * prints the model weights, saves (label, prediction) pairs and reports the MSE.
  * Formula: f(x) = a1*X1 + a2*X2 + a3*X3 + ...
  * Created by 汪本成 on 2016/8/7.
  */
object LinearRegression2 {
  // Keep unnecessary log output off the console.
  Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
  // Jetty's package is org.eclipse.jetty; the former "org.apache.eclipse.jetty.server"
  // matched no logger, so that call silenced nothing.
  Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

  // Program entry: Spark configuration and context.
  val conf = new SparkConf()
    .setMaster("local[1]")
    .setAppName(this.getClass().getSimpleName().filter(!_.equals('$')))
  println(this.getClass().getSimpleName().filter(!_.equals('$')))
  val sc = new SparkContext(conf)

  def main(args: Array[String]): Unit = {
    try {
      // Path of the data set (one partition).
      val data = sc.textFile("G:\\MLlibData\\lpsa.data", 1)
      // Parse "label,f1 f2 ..." into LabeledPoint, skipping blank lines.
      // Cached because it is read by every SGD iteration and again for scoring.
      val parsedData = data
        .filter(_.trim.nonEmpty)
        .map { line =>
          val parts = line.split(",")
          LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).trim.split(' ').map(_.toDouble)))
        }
        .cache()

      // Train the model. Training and evaluation use the same data (no hold-out split).
      val numIterations = 100
      val model = LinearRegressionWithSGD.train(parsedData, numIterations, 0.1)

      // (true label, predicted label) for every point; cached because it is used
      // by two actions below (save and MSE).
      val valuesAndPreds = parsedData.map { point =>
        (point.label, model.predict(point.features))
      }.cache()

      // Print the weights.
      val weights = model.weights
      println("model.weights" + weights)

      // Save results under a timestamped directory.
      val timestamp = new SimpleDateFormat("yyyyMMddHHmmssSSS").format(new Date())
      val path = "G:\\MLlibData\\saveFile\\" + timestamp + "\\results"
      valuesAndPreds.saveAsTextFile(path)

      // Mean squared error; mean() avoids reduce()'s exception on an empty RDD.
      val MSE = valuesAndPreds.map { case (v, p) => math.pow(v - p, 2) }.mean()
      println("训练的数据集的均方误差是: " + MSE)
    } finally {
      // Release the SparkContext even if any step above throws.
      sc.stop()
    }
  }
}

注意:MLlib中的线性回归比较适合做一元线性回归而非多元线性回归,当回归系数比较多时,算法产生的过拟合现象较为严重。

4、逻辑回归

4.1、数据

这里包括了我写的一元逻辑回归和多元逻辑回归。多元逻辑回归用的是Spark工程下的sample_libsvm_data.txt文件,一元逻辑回归用的是我自己构造的logisticRegression1.data(每行格式为"标签|特征"),其内容如下:

1|2
1|3
1|4
1|5
1|6
0|7
0|8
0|9
0|10
0|11

4.2、代码四

package mllib

import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithSGD}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkContext, SparkConf}

/**
  * Logistic regression demo: a univariate model trained on "label|x" lines
  * and a multivariate model trained on a LibSVM file with a 60/40 train/test split.
  * Created by 汪本成 on 2016/8/7.
  */
object LogisticRegression {
  // Keep unnecessary log output off the console.
  Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
  // Jetty's package is org.eclipse.jetty; the former "org.apache.eclipse.jetty.server"
  // matched no logger, so that call silenced nothing.
  Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

  val conf = new SparkConf()
    .setMaster("local[4]")
    .setAppName(this.getClass().getSimpleName().filter(!_.equals('$')))
  val sc = new SparkContext(conf)
  var logisticRegression = new LogisticRegression

  // Univariate logistic regression data set ("label|x" per line).
  val LR1_PATH = "file\\data\\mllib\\input\\regression\\logisticRegression1.data"
  // Multivariate logistic regression data set (LibSVM format).
  val LR2_PATH = "file\\data\\mllib\\input\\regression\\sample_libsvm_data.txt"

  val data = sc.textFile(LR1_PATH)
  val svmData = MLUtils.loadLibSVMFile(sc, LR2_PATH)

  // Split the LibSVM data 60/40 into training and test sets (fixed seed).
  val splits = svmData.randomSplit(Array(0.6, 0.4), seed = 11L)
  // Cached: SGD iterates over the training split, and it is counted again in main.
  val parsedData_SVM = splits(0).cache()
  val parsedTest_SVM = splits(1)

  // Parse "label|x1 x2 ..." into LabeledPoint, skipping blank lines.
  val parsedData = data
    .filter(_.trim.nonEmpty)
    .map { line =>
      val parts = line.split('|')
      LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).trim.split(' ').map(_.toDouble)))
    }
    .cache()

  // Train both models (50 SGD iterations each).
  val model = LogisticRegressionWithSGD.train(parsedData, 50)
  val svmModel = LogisticRegressionWithSGD.train(parsedData_SVM, 50)

  // Univariate test point and its predicted class.
  val target = Vectors.dense(-1)
  val predict = model.predict(target)

  // (label, prediction) pairs for the multivariate test split.
  val predict_svm = logisticRegression.predictAndLabels(parsedTest_SVM, svmModel)

  // Evaluation metrics over the test split.
  val metrics = new MulticlassMetrics(predict_svm)
  val precision = metrics.precision

  def main(args: Array[String]): Unit = {
    try {
      println("一元逻辑回归:")
      parsedData.foreach(println)
      // Print the weights.
      println("权重: " + model.weights)
      println(predict)
      println(model.predict(Vectors.dense(10)))
      println("*************************************************************")
      println("多元逻辑回归:")
      println("svmData记录数:" + svmData.count())
      println("parsedData_SVM" + parsedData_SVM.count())
      println("parsedTest_SVM" + parsedTest_SVM.count())
      println("Precision = " + precision) // print the evaluation value
      predict_svm.take(10).foreach(println)
      println("权重: " + svmModel.weights)
      println("weights 个数是: " + svmModel.weights.size)
      // Number of non-zero weights of the multivariate model
      // (previously counted the univariate `model` by mistake).
      println("weights不为0的个数是: " + svmModel.weights.toArray.count(_ != 0))
    } finally {
      // Release the SparkContext even if an action above throws.
      sc.stop()
    }
  }
}

/** Helper for scoring a logistic regression model on labeled data. */
class LogisticRegression {
  /**
    * Pairs each point's true label with the model's prediction.
    *
    * @param data  labeled points to score (in this file: the svmData test split)
    * @param model trained LogisticRegressionModel
    * @return lazily evaluated RDD of (label, prediction) pairs
    */
  def predictAndLabels(
      data: RDD[LabeledPoint],
      model: LogisticRegressionModel): RDD[(Double, Double)] =
    data.map(point => (point.label, model.predict(point.features)))
}
运行结果请读者自行实验验证,这里就不截图了。

0 0
原创粉丝点击