Spark MLlib 机器学习之六：ALS

来源:互联网 发布:三国杀张鲁淘宝价格 编辑:程序博客网 时间:2024/04/28 13:03

协同过滤（ALS）示例，采用音乐推荐的数据，数据集下载地址如下：

http://www.iro.umontreal.ca/~lisa/datasets/profiledata_06-May-2005.tar.gz  


package com.agm.practice



import java.io.File
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.log4j.{ Level, Logger }


/**
 * Music recommendation example built on Spark MLlib's implicit-feedback ALS.
 *
 * Reads three text files from a data directory:
 *  - `artist_data.txt`      : `artistId<TAB>artistName`
 *  - `artist_alias.txt`     : `aliasArtistId<TAB>canonicalArtistId`
 *  - `user_artist_data.txt` : `userId artistId playCount` (space separated)
 *
 * It trains an ALS model on the play counts, prints the artists one user has
 * already listened to, and then prints the top recommendations for that user.
 *
 * The data directory may be passed as the first program argument; when no
 * argument is given the original hard-coded location is used.
 */
object adviceMusic {
  import scala.util.Try

  /** Data directory used when none is supplied on the command line. */
  private val DefaultDataDir = "D://Spark//文档//profiledata_06-May-2005"

  /** User whose listening history and recommendations are printed. */
  private val TargetUser = 2093760

  /** Artist ids whose names are printed as a quick sanity check of the artist table. */
  private val SampleArtistIds = Seq(6803336, 1000010)

  /**
   * Parses one `id<TAB>name` line of the artist file.
   *
   * @return `None` when the line has no tab-separated name part or the id is not an integer.
   */
  private def parseArtist(line: String): Option[(Int, String)] = {
    val (id, name) = line.span(_ != '\t')
    if (name.isEmpty) None
    else Try(id.toInt).toOption.map(artistId => (artistId, name.trim))
  }

  /**
   * Parses one `alias<TAB>canonical` line of the alias file.
   *
   * @return `None` when the first id is empty, the second id is missing, or either is not an integer.
   */
  private def parseAlias(line: String): Option[(Int, Int)] =
    line.split('\t') match {
      case Array(from, to, _*) if from.nonEmpty => Try((from.toInt, to.toInt)).toOption
      case _                                    => None
    }

  /**
   * Parses one `user artist count` line of the user/artist play file.
   *
   * @return `None` when the line does not have exactly three integer fields.
   */
  private def parsePlay(line: String): Option[(Int, Int, Int)] =
    line.split(' ') match {
      case Array(user, artist, count) => Try((user.toInt, artist.toInt, count.toInt)).toOption
      case _                          => None
    }

  /**
   * Program entry point.
   *
   * @param args optional; `args(0)` overrides the data directory.
   */
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.ERROR)
    val dataDir = args.headOption.getOrElse(DefaultDataDir)
    val conf = new SparkConf().setAppName("Simple Application") // name the application
    conf.setMaster("local[2]")
    val sc = new SparkContext(conf)
    // Always release the context, even when a job fails.
    try run(sc, dataDir)
    finally sc.stop()
  }

  /**
   * Loads the data set, trains the model and prints the results.
   *
   * @param sc      active Spark context (not stopped here; the caller owns it)
   * @param dataDir directory containing the three input files
   */
  private def run(sc: SparkContext, dataDir: String): Unit = {
    import org.apache.spark.mllib.recommendation._

    val artistByID = sc.textFile(dataDir + "//artist_data.txt")
      .flatMap(line => parseArtist(line))

    val artistAlias = sc.textFile(dataDir + "//artist_alias.txt")
      .flatMap(line => parseAlias(line))
      .collectAsMap()

    // Print the sample artists' names; an id absent from the table is reported, not fatal.
    SampleArtistIds.foreach { id =>
      println(artistByID.lookup(id).headOption.getOrElse(s"(artist $id not found)"))
    }

    // Parse the play file once and reuse it for both training and the history listing.
    val plays = sc.textFile(dataDir + "//user_artist_data.txt")
      .flatMap(line => parsePlay(line))

    val bArtistAlias = sc.broadcast(artistAlias)
    val trainData = plays.map { case (userID, artistID, count) =>
      // Map alias ids onto their canonical artist id before training.
      val finalArtistID = bArtistAlias.value.getOrElse(artistID, artistID)
      Rating(userID, finalArtistID, count.toDouble)
    }.cache()

    // Arguments: rank = 10, iterations = 5, lambda = 0.01, alpha = 1.0
    val model = ALS.trainImplicit(trainData, 10, 5, 0.01, 1.0)

    // Copy to a local so the closure below captures only an Int.
    val user = TargetUser
    val existingProducts = plays
      .filter { case (userID, _, _) => userID == user }
      .map { case (_, artistID, _) => artistID }
      .collect()
      .toSet

    // Names of the artists the target user has already played.
    artistByID
      .filter { case (id, _) => existingProducts.contains(id) }
      .values
      .collect()
      .foreach(println)

    val recommendations = model.recommendProducts(user, 5)
    recommendations.foreach(println)
  }
}
0 0
原创粉丝点击