Scala-Spark实现RF(随机森林)

来源:互联网 发布:java中的抽象类的作用 编辑:程序博客网 时间:2024/04/23 21:07
package Cii_Forecast

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext

/**
 * Daily Spark batch job that forecasts a value per hotel id with an MLlib random forest.
 *
 * Steps performed by [[main]]:
 *  1. Load the fixed training snapshot (`d = '2016-08-29'`) from
 *     `databasename.CiiFcst_hotel_Cii_Traintable`; label column is `notcancelcii`.
 *  2. Load the rows to score from `databasename.CiiFcst_hotel_Cii_Testtable`
 *     (`d` = yesterday, `starttime` = today).
 *  3. Train a random-forest regressor and predict one value per test row.
 *  4. Overwrite partition `d` = yesterday of
 *     `databasename.Tmp_CiiFcst_hotel_forecast_result` with `(masterhotel, cii_num)`.
 *
 * Original author's invocation note:
 * `/opt/app/spark-1.6.1/bin/spark-shell --master yarn-client`
 */
object Cii_Forecast_main {

  import java.text.SimpleDateFormat
  import java.util.Calendar

  // LabeledPoint / vector types and the learner itself.
  import org.apache.spark.mllib.linalg.Vectors
  import org.apache.spark.mllib.regression.LabeledPoint
  import org.apache.spark.mllib.tree.RandomForest
  import org.apache.spark.sql.Row

  /** Format of the `d` / `starttime` partition values used in the Hive queries. */
  private val DatePattern = "yyyy-MM-dd"

  /** Training partition; deliberately fixed so the model input does not change between runs. */
  private val TrainPartition = "2016-08-29"

  /** Regression target in the training table. */
  private val LabelColumn = "notcancelcii"

  /** Hotel identifier in the test table; written out as `masterhotel`. */
  private val IdColumn = "hotelid"

  /**
   * Feature columns, in model order. Shared by the training and the scoring query so the two
   * feature vectors can never drift apart.
   */
  private val FeatureColumns: Seq[String] = Seq(
    "week_constant",
    "working_day_constant",
    "cii_ahead_sameoneweek_constant",
    "cii_ahead_sametwoweeks_avg_constant",
    "cii_ahead_samethreeweeks_avg_constant",
    "cii_ahead_samefourweeks_avg_constant",
    "simple_estimate_constant",
    "cii_ahead_1day_1",
    "cii_ahead_3days_avg_1",
    "cii_ahead_7days_avg_1",
    "order_ahead_lt_1days_1",
    "order_ahead_lt_2days_1",
    "order_ahead_lt_3days_1",
    "order_ahead_lt_7days_1",
    "order_ahead_lt_14days_1",
    "order_alldays_1"
  )

  /**
   * Returns `(yesterday, today)` formatted as `yyyy-MM-dd`.
   *
   * Both values are derived from a single clock read; reading the clock twice could yield
   * dates two days apart when the job starts right around midnight.
   */
  private def runDates(): (String, String) = {
    val formatter = new SimpleDateFormat(DatePattern)
    val calendar = Calendar.getInstance()
    val today = formatter.format(calendar.getTime)
    calendar.add(Calendar.DATE, -1)
    (formatter.format(calendar.getTime), today)
  }

  /**
   * Reads cell `index` of `row` as a Double, going through its string form so that any
   * numeric or numeric-string column type is accepted.
   *
   * @throws IllegalArgumentException if the cell is NULL (instead of an opaque NPE)
   * @throws NumberFormatException    if the cell's string form is not a number
   */
  private def cellAsDouble(row: Row, index: Int): Double = {
    require(!row.isNullAt(index), s"NULL value at column index $index in row: $row")
    row(index).toString.toDouble
  }

  /** Builds the dense feature vector from cells `1 .. numFeatures` (cell 0 is label / id). */
  private def featuresOf(row: Row, numFeatures: Int): org.apache.spark.mllib.linalg.Vector =
    Vectors.dense(Array.tabulate(numFeatures)(i => cellAsDouble(row, i + 1)))

  /**
   * Job entry point.
   *
   * @param args command-line arguments; not used
   */
  def main(args: Array[String]): Unit = {
    val (yesterday, today) = runDates()

    val conf = new SparkConf().setAppName("cii_prediction")
    val sc = new SparkContext(conf)
    try {
      val sqlContext = new HiveContext(sc)
      // Needed for toDF() on the prediction RDD.
      import sqlContext.implicits._

      // Captured as a plain Int so the closures below stay small.
      val numFeatures = FeatureColumns.size

      // ---- Training set (fixed snapshot) ----
      val trainData = sqlContext
        .sql(s"select * from databasename.CiiFcst_hotel_Cii_Traintable where d='$TrainPartition'")
        .select(LabelColumn, FeatureColumns: _*)
        .rdd
        .map(row => LabeledPoint(cellAsDouble(row, 0), featuresOf(row, numFeatures)))
        .cache() // the learner scans its input more than once

      // ---- Model parameters ----
      val categoricalFeaturesInfo = Map[Int, Int]() // all features are treated as continuous
      val numTrees = 100
      val featureSubsetStrategy = "auto" // let the algorithm choose
      val impurity = "variance" // "mse" is not an accepted impurity name
      val maxDepth = 5
      val maxBins = 200

      val model = RandomForest.trainRegressor(trainData, categoricalFeaturesInfo,
        numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)

      // ---- Rows to score: yesterday's partition, stays starting today ----
      val testRows = sqlContext
        .sql("select * from databasename.CiiFcst_hotel_Cii_Testtable " +
          s"where d='$yesterday' and starttime='$today'")
        .select(IdColumn, FeatureColumns: _*)
        .rdd

      // (hotel id, predicted value). The id is parsed via Double and truncated to Int so the
      // output column type stays the same regardless of how hotelid is stored in Hive.
      val predictions = testRows.map { row =>
        (cellAsDouble(row, 0).toInt, model.predict(featuresOf(row, numFeatures)))
      }

      predictions.toDF("masterhotel", "cii_num").registerTempTable("tmp_tableresult")

      // ---- Persist the result into yesterday's partition ----
      sqlContext.sql(
        "insert overwrite table  databasename.Tmp_CiiFcst_hotel_forecast_result " +
          s"partition(d='$yesterday') select * from tmp_tableresult")
    } finally {
      // Always release cluster resources, including when the job fails.
      sc.stop()
    }
  }

}
0 1
原创粉丝点击