Spark MLlib --- the SGD (Stochastic Gradient Descent) Algorithm


Code:

```scala
package mllib

import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkConf, SparkContext}

import scala.collection.mutable.HashMap

/**
  * Stochastic gradient descent (SGD)
  * Created by 汪本成 on 2016/8/5.
  */
object SGD {

  // Suppress unneeded log output on the console
  Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
  Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF)

  // Program entry point
  val conf = new SparkConf()
    .setMaster("local[1]")
    .setAppName(this.getClass.getSimpleName.filter(!_.equals('$')))
  println(this.getClass.getSimpleName.filter(!_.equals('$')))
  val sc = new SparkContext(conf)

  // HashMap holding the data set
  val data = new HashMap[Int, Int]()

  // Generate the data set
  def getData(): HashMap[Int, Int] = {
    for (i <- 1 to 50) {
      data += (i -> (2 * i))  // points on the line y = 2x
    }
    data
  }

  // Initial guess a = 0
  var a: Double = 0
  // Step size
  var b: Double = 0.1

  // Update rule (note: the exact gradient of 0.5*(a*x - y)^2 would be
  // (a*x - y)*x; this simplified update drops the trailing factor x)
  def sgd(x: Double, y: Double) = {
    a = a - b * ((a * x) - y)
  }

  def main(args: Array[String]) {
    // Build the data set
    val dataSource = getData()
    println("data: ")
    dataSource.foreach(each => println(each + " "))
    println("\nresult: ")
    var num = 1
    // Iterate once over the data
    dataSource.foreach(myMap => {
      println(num + ":" + a + "(" + myMap._1 + "," + myMap._2 + ")")
      sgd(myMap._1, myMap._2)
      num = num + 1
    })
    // Print the result
    println("final a = " + a)
  }
}
```
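Stripped of Spark, the loop is just a scalar iteration, so it can be checked with a few lines of plain Python. The sketch below (the name `sgd_pass` is mine, not from the article) reproduces the same update rule; since Scala's `HashMap` iterates in arbitrary order while this version walks the pairs in ascending order of x, the intermediate values differ from the run shown below, but the end result is the same:

```python
# Spark-free sketch of the article's SGD loop.
# The update a <- a - b*(a*x - y) matches the Scala code; note it omits
# the factor x that the exact gradient of 0.5*(a*x - y)**2 would carry.

def sgd_pass(pairs, b=0.1, a=0.0):
    """One pass of the article's update over (x, y) pairs."""
    for x, y in pairs:
        a = a - b * (a * x - y)
    return a

data = [(i, 2 * i) for i in range(1, 51)]  # the same y = 2x data set
a = sgd_pass(data)
print(a)  # prints a value very close to 2, the true slope
```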

Run output:

```
[JVM launcher command line, classpath, and SLF4J binding warnings omitted]
SGD
16/08/05 00:48:28 INFO Slf4jLogger: Slf4jLogger started
16/08/05 00:48:28 INFO Remoting: Starting remoting
16/08/05 00:48:28 INFO Remoting: Remoting started; listening on addresses :[akka.tcp://sparkDriverActorSystem@192.168.43.1:24009]
data: 
(23,46) (50,100) (32,64) (41,82) (17,34) (8,16) (35,70) (44,88) (26,52) (11,22)
(29,58) (38,76) (47,94) (20,40) (2,4) (5,10) (14,28) (46,92) (40,80) (49,98)
(4,8) (13,26) (22,44) (31,62) (16,32) (7,14) (43,86) (25,50) (34,68) (10,20)
(37,74) (1,2) (19,38) (28,56) (45,90) (27,54) (36,72) (18,36) (9,18) (21,42)
(48,96) (3,6) (12,24) (30,60) (39,78) (15,30) (42,84) (24,48) (6,12) (33,66)

result: 
1:0.0(23,46)
2:4.6000000000000005(50,100)
3:-8.400000000000002(32,64)
4:24.880000000000006(41,82)
5:-68.92800000000003(17,34)
6:51.649600000000035(8,16)
7:11.929920000000003(35,70)
8:-22.82480000000001(44,88)
9:86.40432000000006(26,52)
10:-133.04691200000013(11,22)
11:15.504691199999996(29,58)
12:-23.65891328(38,76)
13:73.84495718400001(47,94)
14:-263.82634158080003(20,40)
15:267.82634158080003(2,4)
16:214.66107326464004(5,10)
17:108.33053663232002(14,28)
18:-40.53221465292802(46,92)
19:155.1159727505409(40,80)
20:-457.3479182516227(49,98)
21:1793.4568811813288(4,8)
22:1076.8741287087973(13,26)
23:-320.46223861263934(22,44)
24:388.95468633516725(31,62)
25:-810.6048413038511(16,32)
26:489.56290478231085(7,14)
27:148.2688714346932(43,86)
28:-480.6872757344877(25,50)
29:726.0309136017315(34,68)
30:-1735.6741926441557(10,20)
31:2.0000000000002274(37,74)
32:1.999999999999386(1,2)
33:1.9999999999994476(19,38)
34:2.000000000000497(28,56)
35:1.9999999999991056(45,90)
36:2.00000000000313(27,54)
37:1.9999999999946787(36,72)
38:2.000000000013835(18,36)
39:1.999999999988932(9,18)
40:1.999999999998893(21,42)
41:2.0000000000012172(48,96)
42:1.9999999999953737(3,6)
43:1.9999999999967615(12,24)
44:2.000000000000648(30,60)
45:1.999999999998704(39,78)
46:2.0000000000037588(15,30)
47:1.9999999999981206(42,84)
48:2.0000000000060134(24,48)
49:1.999999999991581(6,12)
50:1.9999999999966325(33,66)
final a = 2.0000000000077454
16/08/05 00:48:28 INFO RemoteActorRefProvider$RemotingTerminator: Shutting down remote daemon.
Process finished with exit code 0
```
Analysis:
When the step size α (the variable b in the code) is 0.1, roughly 30 iterations are generally enough to reach the answer; at 0.5, about 15 iterations already give the correct result; at 1, even all 50 iterations produce no usable result.
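These iteration counts are not a coincidence: since every sample satisfies y = 2x, substituting into the update gives a' - 2 = (a - 2)(1 - b·x), i.e. each sample multiplies the error by (1 - b·x), and the error vanishes exactly the first time a sample with x = 1/b arrives. The sketch below checks this against the traversal order printed in the run above (the x values hardcoded here are copied from that output):

```python
# Each sample multiplies the error (a - 2) by (1 - b*x), so the error is
# zeroed the first time a sample with x = 1/b is processed.

# x traversal order from the run above (Scala HashMap iteration order):
xs = [23, 50, 32, 41, 17, 8, 35, 44, 26, 11, 29, 38, 47, 20, 2, 5, 14,
      46, 40, 49, 4, 13, 22, 31, 16, 7, 43, 25, 34, 10, 37, 1, 19, 28,
      45, 27, 36, 18, 9, 21, 48, 3, 12, 30, 39, 15, 42, 24, 6, 33]

for b in (0.1, 0.5, 1.0):
    killer = int(round(1 / b))  # the x value that zeroes the error
    step = xs.index(killer) + 1
    print(f"b={b}: error zeroed at iteration {step}")
```

This matches the observations: with b = 0.1 the run snaps from −1735 to ≈2 at iteration 30, where the sample (10,20) appears; with b = 0.5 the zeroing sample (2,4) sits at position 15; with b = 1 the error is zeroed at step 32 by (1,2), but every later sample multiplies the residual rounding error by |1 − x| (up to 47), so the value blows up again, which is why 50 iterations yield no result.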
