Spark机器学习之协同过滤算法使用-Java篇

来源:互联网 发布:网络cv设备 编辑:程序博客网 时间:2024/06/15 21:57
协同过滤通常用于推荐系统,这些技术旨在填补用户和项目关联矩阵里面缺少的值。Spark目前实现基于模型的协同过滤,其中模型的用户和项目由一组小的潜在因素所描述,可用于预测缺少的值。Spark使用交替最小二乘法alternating least squares(ALS)算法来学习这些潜在因素。

1. ALS的参数
  • numBlocks:用户和项目将会被分区的块数,以便并行化计算(默认值为10)
  • rank:模型中潜在因素的数值(默认值为10)
  • maxIter:要运行的最大迭代次数(默认值为10)
  • regParam:指定的正则化参数(默认值为1.0)
  • implicitPrefs:是否使用隐式反馈(默认为false,使用显式反馈)
  • alpha:当使用隐式反馈时,用于控制偏好观察的基线置信度(默认值为1.0)
  • nonnegative:是否对最小二乘法使用非负约束 (默认值为false)
2. 冷启动(Cold-start)策略

当使用ALSModel进行预测时,在训练模型期间,普遍会在测试数据集中遇到用户和/或项目不存在的情况。这一般出现在以下两种情型:
  • 在生产环境中,对于没有评级历史的新用户或项目,和未经过训练的模型(这是“冷启动问题”)
  • 在交叉验证期间,数据被拆分成训练集和评估集。当使用Spark的CrossValidator或TrainValidationSplit中的简单随机拆分时,评估集里面的用户和/或项目不在训练集里面是非常常见的
默认地,当模型中不存在的用户和/或项目因素时,Spark在调用ALSModel.transform方法时,预测的值会是NaN。这在生产系统中可以是有用的,因为NaN表示一个新的用户或项目,因此系统可以预测作出一些回退的决定。

然而,在交叉验证期间这是不可取的,因为任何NaN预测值将导致评估指标的NaN结果(例如当使用RegressionEvaluator的时候)。这使得模型的选择变得不可能。

Spark允许用户将coldStartStrategy参数设置为”drop”,以便删除DataFrame中包含预测NaN值的任何行,然后会根据非NaN的数据计算评估指标。

注意:目前支持的冷启动策略是“nan”(默认)和“drop”,未来可能会支持其它的策略。

3. Java代码例子

本文使用Spark 2.2.0、Java 1.8版本,测试数据可以在以下链接下载:

http://files.grouplens.org/datasets/movielens/ml-100k.zip

import java.io.Serializable;import org.apache.spark.api.java.JavaRDD;import org.apache.spark.ml.evaluation.RegressionEvaluator;import org.apache.spark.ml.recommendation.ALS;import org.apache.spark.ml.recommendation.ALSModel;import org.apache.spark.sql.Dataset;import org.apache.spark.sql.Row;import org.apache.spark.sql.SparkSession;public class JavaALSExample {public static class Rating implements Serializable {private static final long serialVersionUID = 1L;private int userId;private int movieId;private float rating;private long timestamp;public Rating() {}public Rating(int userId, int movieId, float rating, long timestamp) {this.userId = userId;this.movieId = movieId;this.rating = rating;this.timestamp = timestamp;}public int getUserId() {return userId;}public int getMovieId() {return movieId;}public float getRating() {return rating;}public long getTimestamp() {return timestamp;}public static Rating parseRating(String str) {String[] fields = str.split("\\t");if (fields.length != 4) {throw new IllegalArgumentException("Each line must contain 4 fields");}int userId = Integer.parseInt(fields[0]);int movieId = Integer.parseInt(fields[1]);float rating = Float.parseFloat(fields[2]);long timestamp = Long.parseLong(fields[3]);return new Rating(userId, movieId, rating, timestamp);}}public static void main(String[] args) {    // 测试数据文件路径String path = "ml-100k/u.data";// 使用本地所有可用线程local[*]SparkSession spark = SparkSession.builder().master("local[*]").appName("JavaALSExample").getOrCreate();JavaRDD<Rating> ratingsRDD = spark.read().textFile(path).javaRDD().map(Rating::parseRating);Dataset<Row> ratings = spark.createDataFrame(ratingsRDD, Rating.class);// 按比例随机拆分数据Dataset<Row>[] splits = ratings.randomSplit(new double[] { 0.8, 0.2 });Dataset<Row> training = splits[0];Dataset<Row> test = splits[1];// 对训练数据集使用ALS算法构建建议模型ALS als = new ALS().setMaxIter(5).setRegParam(0.01).setUserCol("userId").setItemCol("movieId").setRatingCol("rating");ALSModel model = als.fit(training);// Evaluate the model by computing the RMSE on the test data// 通过计算均方根误差RMSE(Root Mean Squared Error)对测试数据集评估模型// 注意下面使用冷启动策略drop,确保不会有NaN评估指标model.setColdStartStrategy("drop");Dataset<Row> predictions = model.transform(test);// 打印predictions的schema        predictions.printSchema();// predictions的schema输出        // root// |-- movieId: integer (nullable = false)// |-- rating: float (nullable = false)// |-- timestamp: long (nullable = false)// |-- userId: integer (nullable = false)// |-- prediction: float (nullable = true)RegressionEvaluator evaluator = new RegressionEvaluator().setMetricName("rmse").setLabelCol("rating").setPredictionCol("prediction");double rmse = evaluator.evaluate(predictions);// 打印均方根误差System.out.println("Root-mean-square error = " + rmse);}}

打印均方根误差结果为:Root-mean-square error = 1.0645093959897054,这个值是越小越好,如果得出的值不符合预期,可以调整ALS的参数重新计算直到符合预期为止。然后可以分别对所有用户和项目进行建议:

// Generate top 10 movie recommendations for each userDataset<Row> userRecs = model.recommendForAllUsers(10);// Generate top 10 user recommendations for each movieDataset<Row> movieRecs = model.recommendForAllItems(10);

* 参考Spark Collaborative Filtering官方链接:http://spark.apache.org/docs/latest/ml-collaborative-filtering.html

END O(∩_∩)O