ALS实现电影推荐

来源:互联网 发布:java获取unix时间戳 编辑:程序博客网 时间:2024/04/26 10:15
package com.ys.scala


import org.apache.log4j.Logger
import org.apache.log4j.Level
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.mllib.recommendation.Rating
import scala.util.Random
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel
import org.apache.spark.mllib.recommendation.ALS
import org.apache.spark.rdd.RDD




/**
 * Interactive movie recommender built on Spark MLlib's ALS matrix factorization.
 *
 * Reads `ratings.dat` and `movies.dat` ("::"-separated fields) from the working
 * directory, asks the user to rate a sample of the most-rated movies, trains ALS
 * models over a small hyper-parameter grid, keeps the model with the lowest
 * validation RMSE, reports its RMSE on the test split, prints up to 50
 * recommendations and finally compares the model with a mean-rating baseline.
 */
object ScalaMovieLensALS {

  /** User id given to the ratings typed in at the console and used when predicting. */
  private val PersonalUserId = 0

  /** A trained model together with its validation RMSE and the hyper-parameters that produced it. */
  private case class Candidate(
      model: MatrixFactorizationModel,
      validationRmse: Double,
      rank: Int,
      lambda: Double,
      numIter: Int)

  /**
   * Program entry point; runs the whole load / rate / train / evaluate / recommend pipeline.
   *
   * @param args command-line arguments (not used)
   */
  def main(args: Array[String]): Unit = {

    // Suppress log output that is not needed for this demo.
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    // Jetty's loggers live under "org.eclipse.jetty"; an "org.apache.eclipse..." name matches nothing.
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

    val conf = new SparkConf().setAppName("ScalaMovieLensALS").setMaster("local")
    val sc = new SparkContext(conf)

    // try/finally so the SparkContext is stopped even when a step below throws.
    try {
      // Each line is "userId::movieId::rating::timestamp".
      // Result: (timestamp % 10, Rating(userId, movieId, rating)); the key drives the split below.
      val ratings = sc.textFile("ratings.dat").map { line =>
        val fields = line.split("::")
        (fields(3).toLong % 10, Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble))
      }

      // movieId -> movie title, collected to the driver as an in-memory map.
      val movies = sc.textFile("movies.dat").map { line =>
        val fields = line.split("::")
        (fields(0).toInt, fields(1))
      }.collect().toMap

      val numRatings = ratings.count()
      val numUsers = ratings.map(_._2.user).distinct().count()
      val numMovies = ratings.map(_._2.product).distinct().count()

      println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.")

      // Ids of the 50 movies with the most ratings, most rated first.
      val mostRatedMovieIds = ratings.map(_._2.product)
        .countByValue()
        .toSeq
        .sortBy(-_._2)
        .take(50)
        .map(_._1)

      // Fixed seed, so the same sample of popular movies is offered on every run.
      val random = new Random(0)
      val selectedMovies = mostRatedMovieIds
        .filter(_ => random.nextDouble() < 0.2)
        .map(id => (id, movies(id)))
        .toSeq
      val myRatings = getRatings(selectedMovies)
      // Turn the collected ratings into an RDD so they can be unioned into the training set.
      val myRatingsRDD = sc.parallelize(myRatings)

      // Split by the last digit of the timestamp: 0-5 training, 6-7 validation, 8-9 test.
      // The user's own ratings go into the training set; all three splits are cached.
      val numPartitions = 4
      val training = ratings.filter(_._1 < 6).values
        .union(myRatingsRDD)
        .repartition(numPartitions)
        .cache()
      val validation = ratings.filter(x => x._1 >= 6 && x._1 < 8).values
        .repartition(numPartitions)
        .cache()
      val test = ratings.filter(_._1 >= 8).values.cache()

      val numTraining = training.count()
      val numValidation = validation.count()
      val numTest = test.count()

      println(s"Training: $numTraining, validation: $numValidation, test: $numTest")

      // Hyper-parameter grid: train one model per combination, score it on the validation set.
      val ranks = List(8, 10, 12) // number of latent factors
      val lambdas = List(0.1, 1.0, 10.0) // ALS regularization parameter
      val numIterations = List(10, 20) // number of ALS iterations
      val grid = for (rank <- ranks; lambda <- lambdas; numIter <- numIterations)
        yield (rank, lambda, numIter)

      // Fold instead of collecting all models, so only the best one so far stays referenced.
      val best = grid.foldLeft(Option.empty[Candidate]) { case (bestSoFar, (rank, lambda, numIter)) =>
        val model = ALS.train(training, rank, numIter, lambda)
        val validationRmse = computeRmse(model, validation)
        println(s"RMSE (validation) = $validationRmse for the model trained with rank = $rank , lambda = $lambda ,and numIter = $numIter .")

        // Strict '<' keeps the earliest model on ties and never selects a NaN score.
        val rmseToBeat = bestSoFar.map(_.validationRmse).getOrElse(Double.MaxValue)
        if (validationRmse < rmseToBeat) Some(Candidate(model, validationRmse, rank, lambda, numIter))
        else bestSoFar
      }.getOrElse(sys.error("No model produced a usable validation RMSE"))

      // Evaluate the selected model on the held-out test set.
      val testRmse = computeRmse(best.model, test)
      println(s"The best model was trained with rank = ${best.rank} and lambda = ${best.lambda} , and numIter = ${best.numIter} , and its RMSE on the test set is $testRmse .")

      // Recommend from the movies the user has not rated, highest predicted rating first.
      val myRatedMovieIds = myRatings.map(_.product).toSet
      val candidates = sc.parallelize(movies.keys.filterNot(myRatedMovieIds).toSeq)
      val recommendations = best.model
        .predict(candidates.map(movieId => (PersonalUserId, movieId)))
        .collect()
        .sortBy(-_.rating)
        .take(50)

      println("Movies recommendation for you: ")
      recommendations.zipWithIndex.foreach { case (r, idx) =>
        println("%2d".format(idx + 1) + ": " + movies(r.product))
      }

      // Naive baseline: always predict the mean rating of training + validation.
      val meanRating = training.union(validation).map(_.rating).mean()
      val baselineRmse = math.sqrt(
        test.map(x => (meanRating - x.rating) * (meanRating - x.rating)).mean())
      val improvement = (baselineRmse - testRmse) / baselineRmse * 100
      println("The best model improves the baseline by " + "%1.2f".format(improvement) + "%.")
    } finally {
      // Clean up.
      sc.stop()
    }
  }

  /**
   * Asks on the console for a rating of every given movie.
   *
   * An answer of 0 means "not seen" and produces no rating; answers outside 0-5 or
   * non-numeric answers repeat the prompt for the same movie.
   *
   * @param movies (movieId, title) pairs to ask about
   * @return the ratings entered with a value of 1-5, attributed to user id 0
   * @throws RuntimeException if no movie received a rating
   * @throws java.io.EOFException if the console input ends before all movies are answered
   */
  def getRatings(movies: Seq[(Int, String)]): Seq[Rating] = {
    val prompt = "Please rate following movie (1-5(best), or 0 if not seen):"
    println(prompt)

    val ratings = movies.flatMap { case (movieId, title) => readRating(movieId, title, prompt) }

    if (ratings.isEmpty) sys.error("No rating provided") else ratings
  }

  /**
   * Reads one answer for a single movie, re-asking until the answer is an integer in 0-5.
   *
   * @return `Some(rating)` for 1-5, `None` for 0 (not seen)
   */
  @scala.annotation.tailrec
  private def readRating(movieId: Int, title: String, prompt: String): Option[Rating] = {
    print(title + ":")
    // readLine returns null at end of input; fail instead of re-prompting forever.
    val line = Option(Console.in.readLine()).getOrElse(
      throw new java.io.EOFException("Console has reached end of input"))
    // Only a malformed number is a reason to ask again.
    val answer = try Some(line.trim.toInt) catch { case _: NumberFormatException => None }

    answer match {
      case Some(0) => None
      case Some(r) if r >= 1 && r <= 5 => Some(Rating(PersonalUserId, movieId, r))
      case _ =>
        println(prompt)
        readRating(movieId, title, prompt)
    }
  }

  /**
   * Computes the RMSE (root mean squared error) between the ratings in `data` and
   * the ratings `model` predicts for the same (user, product) pairs.
   *
   * Observed and predicted ratings are matched with an inner join, so only pairs
   * for which the model returned a prediction contribute to the error.
   *
   * @param model trained matrix factorization model
   * @param data  observed ratings to compare against
   * @return square root of the mean squared difference between observed and predicted ratings
   */
  def computeRmse(model: MatrixFactorizationModel, data: RDD[Rating]): Double = {
    val observed = data.map { case Rating(user, product, rate) => ((user, product), rate) }

    val predicted = model.predict(observed.keys)
      .map { case Rating(user, product, rate) => ((user, product), rate) }

    // No sort is needed before averaging; the mean does not depend on key order.
    val squaredErrors = observed.join(predicted).values.map { case (actual, estimate) =>
      val err = actual - estimate
      err * err
    }

    math.sqrt(squaredErrors.mean())
  }

}


代码中用到的数据文件 movies.dat 和 ratings.dat 可以在 http://download.csdn.net/detail/u013147600/8908241 下载。

movies.dat 的数据格式如下(字段之间以 "::" 分隔):电影id::电影名::类型

1::Toy Story (1995)::Animation|Children's|Comedy
2::Jumanji (1995)::Adventure|Children's|Fantasy
3::Grumpier Old Men (1995)::Comedy|Romance
4::Waiting to Exhale (1995)::Comedy|Drama
5::Father of the Bride Part II (1995)::Comedy
6::Heat (1995)::Action|Crime|Thriller
7::Sabrina (1995)::Comedy|Romance
8::Tom and Huck (1995)::Adventure|Children's
9::Sudden Death (1995)::Action
10::GoldenEye (1995)::Action|Adventure|Thriller


ratings.dat 的数据格式如下(字段之间以 "::" 分隔):用户id::电影id::评分::时间戳

1::1193::5::978300760
1::661::3::978302109
1::914::3::978301968
1::3408::4::978300275
1::2355::5::978824291
1::1197::3::978302268
1::1287::5::978302039
1::2804::5::978300719
1::594::4::978302268
1::919::4::978301368

0 0
原创粉丝点击