Problem with LinearRegressionWithSGD

Source: Internet | Editor: 程序博客网 | Time: 2024/06/05 07:45

Contents of the data file file/data/mllib/input/ridge-data/defDemo1 (label, feature):

42,0.10
43.5,0.11
45,0.12
45.5,0.13
45,0.14
47.5,0.15
49,0.16
53,0.17
50,0.18
55,0.20
55,0.21
60,0.23


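Code:

  import org.apache.spark.{SparkConf, SparkContext}
  import org.apache.spark.mllib.linalg.Vectors
  import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD}

  def defDemo1(): Unit = {
    val conf = new SparkConf().setMaster("local").setAppName(this.getClass().getSimpleName().filter(!_.equals('$')))
    val sc = new SparkContext(conf)

    val data = sc.textFile("file/data/mllib/input/ridge-data/defDemo1") // path to the data set
    val parsedData = data.map { line => // parse each line of the data set
      val parts = line.split(',') // split on the comma: parts(0) = label, parts(1) = feature(s)
      LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).trim().split(' ').map(_.toDouble)))
    } // convert every line to a LabeledPoint

    parsedData.foreach(line => {
      println(line.label + " , " + line.features) // print each parsed point
    })

    // train the model with numIterations = 1000 and stepSize = 0.001
    val model = LinearRegressionWithSGD.train(parsedData, 1000, 0.001)
    val result = model.predict(Vectors.dense(0.19)) // predict the label for x = 0.19

    println("model weights:")
    println(model.weights)
    println("model intercept:")
    println(model.intercept)
    println("result:")
    println(result) // print the prediction
    sc.stop()
  }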


Output:

model weights:
[0.11670307429843765]
model intercept:
0.0
result:
0.022173584116703154
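As a quick check on the printed numbers: `predict` returns weight × feature + intercept, so the result is simply the weight scaled by 0.19. With the intercept at 0.0, the fitted model is a line forced through the origin:

$$
\hat{y} = w \cdot x + b = 0.11670307429843765 \times 0.19 + 0.0 \approx 0.0221736
$$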


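The actual linear function should be close to y = 130.835x + 28.493, which is the ordinary least-squares fit to the 12 points above.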
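To double-check that reference line, here is a minimal plain-Scala sketch (no Spark) of the closed-form least-squares fit, slope = Sxy / Sxx and intercept = mean(y) − slope · mean(x). The object name `OlsCheck` and its arrays are illustrative; the values are copied from the data file above.

```scala
// Closed-form simple linear regression on the defDemo1 data (plain Scala, no Spark).
object OlsCheck {
  // y = label (first column), x = feature (second column) of the data file
  val ys: Array[Double] = Array(42.0, 43.5, 45.0, 45.5, 45.0, 47.5, 49.0, 53.0, 50.0, 55.0, 55.0, 60.0)
  val xs: Array[Double] = Array(0.10, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.20, 0.21, 0.23)

  /** Returns (slope, intercept) minimising the sum of squared residuals. */
  def fit(x: Array[Double], y: Array[Double]): (Double, Double) = {
    require(x.length == y.length && x.nonEmpty, "x and y must be non-empty and the same length")
    val n = x.length.toDouble
    val meanX = x.sum / n
    val meanY = y.sum / n
    // Sxy = sum of (x - meanX)(y - meanY); Sxx = sum of (x - meanX)^2
    val sxy = x.indices.map(i => (x(i) - meanX) * (y(i) - meanY)).sum
    val sxx = x.map(xi => (xi - meanX) * (xi - meanX)).sum
    val slope = sxy / sxx
    (slope, meanY - slope * meanX)
  }

  def main(args: Array[String]): Unit = {
    val (slope, intercept) = fit(xs, ys)
    println(f"y = $slope%.3f x + $intercept%.3f") // prints "y = 130.835 x + 28.493"
  }
}
```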

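The model produced by LinearRegressionWithSGD does not match this function: the fitted weight is about 0.117 instead of about 130.8, and the intercept is 0.0 instead of about 28.5.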
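A likely explanation, offered as an assumption rather than a confirmed diagnosis: the static `LinearRegressionWithSGD.train(data, numIterations, stepSize)` does not fit an intercept (hence `model intercept: 0.0`), and with a raw feature between 0.10 and 0.23 a step size of 0.001 moves the weight very slowly. In Spark 1.5 and later, MLlib's `GradientDescent` may also stop early once the update falls below its `convergenceTol` (default 0.001). In MLlib the intercept can be enabled with `setIntercept(true)` on a `new LinearRegressionWithSGD()` instance, with step size and iteration count set through its `optimizer`; features can be scaled with `org.apache.spark.mllib.feature.StandardScaler`. The sketch below does not use Spark and is not MLlib's algorithm (MLlib decays the step as stepSize/√iteration). It is a hypothetical plain-Scala batch gradient descent with a constant step, a standardised feature and an explicit intercept (`GdSketch`, `stepSize = 0.1` and `numIterations = 1000` are illustrative choices), meant only to show that gradient descent set up this way recovers roughly y = 130.835x + 28.493.

```scala
// Plain-Scala batch gradient descent (NOT MLlib's implementation) fitting y = w*x + b
// with an explicit intercept and a standardised feature.
object GdSketch {
  val ys: Array[Double] = Array(42.0, 43.5, 45.0, 45.5, 45.0, 47.5, 49.0, 53.0, 50.0, 55.0, 55.0, 60.0)
  val xs: Array[Double] = Array(0.10, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.20, 0.21, 0.23)

  /** Returns (weight, intercept) on the ORIGINAL feature scale. Assumes x is not constant. */
  def fit(x: Array[Double], y: Array[Double], stepSize: Double = 0.1, numIterations: Int = 1000): (Double, Double) = {
    val n = x.length.toDouble
    val meanX = x.sum / n
    val stdX = math.sqrt(x.map(xi => (xi - meanX) * (xi - meanX)).sum / n) // population std dev
    val z = x.map(xi => (xi - meanX) / stdX) // standardised feature: mean 0, variance 1

    var w = 0.0 // weight on z
    var b = 0.0 // intercept, fitted explicitly
    for (_ <- 1 to numIterations) {
      // gradient of (1/2n) * sum of (w*z + b - y)^2 with respect to w and b
      val residuals = z.indices.map(i => w * z(i) + b - y(i))
      val gradW = z.indices.map(i => residuals(i) * z(i)).sum / n
      val gradB = residuals.sum / n
      w -= stepSize * gradW
      b -= stepSize * gradB
    }
    // undo the scaling: w*z + b = (w/stdX)*x + (b - w*meanX/stdX)
    (w / stdX, b - w * meanX / stdX)
  }

  def main(args: Array[String]): Unit = {
    val (weight, intercept) = fit(xs, ys)
    println(f"weight = $weight%.3f, intercept = $intercept%.3f, prediction at 0.19 = ${weight * 0.19 + intercept}%.3f")
  }
}
```

Standardising makes the loss surface well conditioned, so a fixed step of 0.1 converges quickly; on the raw 0.10 to 0.23 feature the same update would crawl, which is consistent with what the question observes.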

  

