Spark机器学习之协同过滤算法使用-Java篇
来源:互联网 发布:网络cv设备 编辑:程序博客网 时间:2024/06/15 21:57
协同过滤通常用于推荐系统,这些技术旨在填补用户和项目关联矩阵里面缺少的值。Spark目前实现基于模型的协同过滤,其中模型的用户和项目由一组小的潜在因素所描述,可用于预测缺少的值。Spark使用交替最小二乘法alternating least squares(ALS)算法来学习这些潜在因素。
1. ALS的参数
当使用ALSModel进行预测时,在训练模型期间,普遍会在测试数据集中遇到用户和/或项目不存在的情况。这一般出现在以下两种情型:
然而,在交叉验证期间这是不可取的,因为任何NaN预测值将导致评估指标的NaN结果(例如当使用RegressionEvaluator的时候)。这使得模型的选择变得不可能。
Spark允许用户将coldStartStrategy参数设置为”drop”,以便删除DataFrame中包含预测NaN值的任何行,然后会根据非NaN的数据计算评估指标。
注意:目前支持的冷启动策略是“nan”(默认)和“drop”,未来可能会支持其它的策略。
3. Java代码例子
本文使用Spark 2.2.0、Java 1.8版本,测试数据可以在以下链接下载:
http://files.grouplens.org/datasets/movielens/ml-100k.zip
打印均方根误差结果为:Root-mean-square error = 1.0645093959897054,这个值是越小越好,如果得出的值不符合预期,可以调整ALS的参数重新计算直到符合预期为止。然后可以分别对所有用户和项目进行建议:
* 参考Spark Collaborative Filtering官方链接:http://spark.apache.org/docs/latest/ml-collaborative-filtering.html
END O(∩_∩)O
1. ALS的参数
- numBlocks:用户和项目将会被分区的块数,以便并行化计算(默认值为10)
- rank:模型中潜在因素的数值(默认值为10)
- maxIter:要运行的最大迭代次数(默认值为10)
- regParam:指定的正则化参数(默认值为1.0)
- implicitPrefs:是否使用隐式反馈(默认为false,使用显式反馈)
- alpha:当使用隐式反馈时,用于控制偏好观察的基线置信度(默认值为1.0)
- nonnegative:是否对最小二乘法使用非负约束 (默认值为false)
当使用ALSModel进行预测时,在训练模型期间,普遍会在测试数据集中遇到用户和/或项目不存在的情况。这一般出现在以下两种情型:
- 在生产环境中,对于没有评级历史的新用户或项目,和未经过训练的模型(这是“冷启动问题”)
- 在交叉验证期间,数据被拆分成训练集和评估集。当使用Spark的CrossValidator或TrainValidationSplit中的简单随机拆分时,评估集里面的用户和/或项目不在训练集里面是非常常见的
然而,在交叉验证期间这是不可取的,因为任何NaN预测值将导致评估指标的NaN结果(例如当使用RegressionEvaluator的时候)。这使得模型的选择变得不可能。
Spark允许用户将coldStartStrategy参数设置为”drop”,以便删除DataFrame中包含预测NaN值的任何行,然后会根据非NaN的数据计算评估指标。
注意:目前支持的冷启动策略是“nan”(默认)和“drop”,未来可能会支持其它的策略。
3. Java代码例子
本文使用Spark 2.2.0、Java 1.8版本,测试数据可以在以下链接下载:
http://files.grouplens.org/datasets/movielens/ml-100k.zip
import java.io.Serializable;import org.apache.spark.api.java.JavaRDD;import org.apache.spark.ml.evaluation.RegressionEvaluator;import org.apache.spark.ml.recommendation.ALS;import org.apache.spark.ml.recommendation.ALSModel;import org.apache.spark.sql.Dataset;import org.apache.spark.sql.Row;import org.apache.spark.sql.SparkSession;public class JavaALSExample {public static class Rating implements Serializable {private static final long serialVersionUID = 1L;private int userId;private int movieId;private float rating;private long timestamp;public Rating() {}public Rating(int userId, int movieId, float rating, long timestamp) {this.userId = userId;this.movieId = movieId;this.rating = rating;this.timestamp = timestamp;}public int getUserId() {return userId;}public int getMovieId() {return movieId;}public float getRating() {return rating;}public long getTimestamp() {return timestamp;}public static Rating parseRating(String str) {String[] fields = str.split("\\t");if (fields.length != 4) {throw new IllegalArgumentException("Each line must contain 4 fields");}int userId = Integer.parseInt(fields[0]);int movieId = Integer.parseInt(fields[1]);float rating = Float.parseFloat(fields[2]);long timestamp = Long.parseLong(fields[3]);return new Rating(userId, movieId, rating, timestamp);}}public static void main(String[] args) { // 测试数据文件路径String path = "ml-100k/u.data";// 使用本地所有可用线程local[*]SparkSession spark = SparkSession.builder().master("local[*]").appName("JavaALSExample").getOrCreate();JavaRDD<Rating> ratingsRDD = spark.read().textFile(path).javaRDD().map(Rating::parseRating);Dataset<Row> ratings = spark.createDataFrame(ratingsRDD, Rating.class);// 按比例随机拆分数据Dataset<Row>[] splits = ratings.randomSplit(new double[] { 0.8, 0.2 });Dataset<Row> training = splits[0];Dataset<Row> test = splits[1];// 对训练数据集使用ALS算法构建建议模型ALS als = new ALS().setMaxIter(5).setRegParam(0.01).setUserCol("userId").setItemCol("movieId").setRatingCol("rating");ALSModel model = als.fit(training);// Evaluate the model by computing the RMSE on the test data// 通过计算均方根误差RMSE(Root Mean Squared Error)对测试数据集评估模型// 注意下面使用冷启动策略drop,确保不会有NaN评估指标model.setColdStartStrategy("drop");Dataset<Row> predictions = model.transform(test);// 打印predictions的schema predictions.printSchema();// predictions的schema输出 // root// |-- movieId: integer (nullable = false)// |-- rating: float (nullable = false)// |-- timestamp: long (nullable = false)// |-- userId: integer (nullable = false)// |-- prediction: float (nullable = true)RegressionEvaluator evaluator = new RegressionEvaluator().setMetricName("rmse").setLabelCol("rating").setPredictionCol("prediction");double rmse = evaluator.evaluate(predictions);// 打印均方根误差System.out.println("Root-mean-square error = " + rmse);}}
打印均方根误差结果为:Root-mean-square error = 1.0645093959897054,这个值是越小越好,如果得出的值不符合预期,可以调整ALS的参数重新计算直到符合预期为止。然后可以分别对所有用户和项目进行建议:
// Generate top 10 movie recommendations for each userDataset<Row> userRecs = model.recommendForAllUsers(10);// Generate top 10 user recommendations for each movieDataset<Row> movieRecs = model.recommendForAllItems(10);
* 参考Spark Collaborative Filtering官方链接:http://spark.apache.org/docs/latest/ml-collaborative-filtering.html
END O(∩_∩)O
阅读全文
0 0
- Spark机器学习之协同过滤算法使用-Java篇
- Spark机器学习之协同过滤算法
- Spark机器学习之协同过滤
- Spark机器学习库mllib之协同过滤
- 机器学习,协同过滤算法
- 机器学习之协同过滤
- 机器学习之协同过滤
- 机器学习----推荐系统之协同过滤算法
- [机器学习]推荐系统之协同过滤算法
- 机器学习常用算法三:协同过滤
- 机器学习-协同过滤
- 离线轻量级大数据平台Spark之MLib机器学习协同过滤ALS实例
- Spark 机器学习-实例演示-协同过滤《三》
- spark/MLlib 协同过滤算法
- Spark MLlib之协同过滤
- Spark MLlib之协同过滤
- spark之CF协同过滤
- 机器学习之基于协同过滤的推荐引擎
- object转化成json,json格式字符串转字典,数组或字典转为json串
- usb转串口异步读取数据
- 背包DP合辑
- 某外企C++面试题
- Android BLE 开发资料汇总
- Spark机器学习之协同过滤算法使用-Java篇
- vue-router(1)
- MySQL 第四天
- python,Windows环境安装及导入beautifulsoup
- RSA密钥,JAVA与.NET之间转换
- java发送邮件
- 【yoyo】计算2018年1月1日距当天事件还剩多少天,多少小时,多少分钟,多少秒;
- Android初级开发(十一)——(转载)一篇文章轻松掌握Material Design
- 【多校训练】hdu 6085 Rikka with Candies bitset