ALS 推荐系统
来源:互联网 发布:域名转出 编辑:程序博客网 时间:2024/04/25 23:02
1:ALS(alternating least squares ):交替最小二乘法
在机器学习中,特指使用最小二乘法的一种协同推荐算法。如下图所示,u表示用户,v表示商品,用户给商品打分,但是并不是每一个用户都会给每一种商品打分。? 表示用户没有打分的情况,所以这个矩阵A很多元素都是空的,我们称其为“缺失值(missing value)”。协同过滤提出了一种支持不完整评分矩阵的矩阵分解方法,不用对评分矩阵进行估值填充。
和协同过滤不一样的是,ALS认为用户的评分矩阵是有用户特征矩阵和物品特征矩阵相乘得到的。
ALS 的核心假设是:打分矩阵A是近似低秩的,即一个
Am×n=Um×k×Vk×n
我们把打分理解成相似度,那么“打分矩阵
- 给定隐含特征的数量,用随机数初始化用户-特征矩阵和商品-特征矩阵
- 用梯度下降法交替的优化这两个矩阵,用商品矩阵的各维度作为用户矩阵的梯度下降的方向,反之亦然
- 优化结束后,计算用户特征向量和商品特征向量的相似度(内积、余弦……),这就是用户对商品的偏好打分
2.Spark Mllib
Spark使用的是交叉最小二乘法(ALS)来最优化损失函数。算法的思想就是:我们先随机生成然后固定它求解,再固定求解,这样交替进行下去,直到取得最优解
3. MLlib的ALS实现
ALS伴生对象是建立ALS模型的入口,其主要定义训练线性回归模型的train方法,train方法通过设置训练参数进行模型训练,其参数主要包括:
- ratings-----评分RDD格式(userID,productID,rating)对;
- rank------特征数量
- iterations------迭代次数
- lambda------正则因子(推荐值为0.01)
- blocks-----数据分隔
- seed------随机种子
4. 优化步骤
- 一个用户特征和一个商品特征相乘,得到用户对商品的偏好(单元格)
- 已知偏好的单元格,乘的结果要和已知的值尽量接近(MSE,总的方差最小)
- 用梯度下降法交替的优化用户特征和商品特征(ALS)
5. 梯度下降
- n个隐含特征=在n维空间里优化用户、商品特征
- 找一个下降最快的方向(拉格朗日乘数法、随机……)
- 朝着这个方向走一小步
- 回到1,直到总的偏差不再下降
实现coding
package hhc.mllib.label.learn.recommend;import hhc.mllib.label.learn.config.AppConfig;import hhc.mllib.label.learn.ml.CreaterBase;import org.apache.spark.api.java.JavaPairRDD;import org.apache.spark.api.java.JavaRDD;import org.apache.spark.api.java.JavaSparkContext;import org.apache.spark.api.java.function.Function;import org.apache.spark.mllib.evaluation.RegressionMetrics;import org.apache.spark.mllib.recommendation.ALS;import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;import org.apache.spark.mllib.recommendation.Rating;import org.apache.spark.rdd.RDD;import scala.Tuple2;import java.util.Arrays;import java.util.List;/** * Created by huhuichao on 2017/12/7. */public class ALSModelCreater extends CreaterBase{ private MatrixFactorizationModel model; private transient JavaSparkContext jsc; public ALSModelCreater(JavaSparkContext jsc) { this.jsc = jsc; } /** * 读取样本数据 * @param path * @return */ public static JavaRDD<Rating> getALSJavaRDD(String path, JavaSparkContext sc,String split) { JavaRDD<String> data=sc.textFile(path); JavaRDD<Rating> ratings = data.map( new Function<String, Rating>() { public Rating call(String s) { String[] sarray = s.split(split); return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]), Double.parseDouble(sarray[2])); } } ); return ratings; } public MatrixFactorizationModel training (JavaRDD<Rating> ratings,int rank, int numIterations, double v){ return ALS.train(ratings.rdd(), rank, numIterations, v); } /** * * 计算方差 * @param ratings 样本数据 * @param model model * @return */ public static double evaluateMSE(JavaRDD<Rating> ratings,MatrixFactorizationModel model) { JavaRDD<Tuple2<Object, Object>> userProducts = ratings.map( new Function<Rating, Tuple2<Object, Object>>() { private static final long serialVersionUID = 1L; @Override public Tuple2<Object, Object> call(Rating r) { return new Tuple2<Object, Object>(r.user(), r.product()); } } ); JavaPairRDD<Tuple2<Integer, Integer>, Object> predictions = JavaPairRDD.fromJavaRDD( model.predict( JavaRDD.toRDD(userProducts)).toJavaRDD().map( new Function<Rating, Tuple2<Tuple2<Integer, Integer>, Object>>() { private static final long serialVersionUID = 1L; @Override public Tuple2<Tuple2<Integer, Integer>, Object> call(Rating r) { return new Tuple2<Tuple2<Integer, Integer>, Object>( new Tuple2<>(r.user(), r.product()), r.rating()); } } )); JavaRDD<Tuple2<Object, Object>> ratesAndPreds = JavaPairRDD.fromJavaRDD(ratings.map( new Function<Rating, Tuple2<Tuple2<Integer, Integer>, Object>>() { private static final long serialVersionUID = 1L; @Override public Tuple2<Tuple2<Integer, Integer>, Object> call(Rating r) { return new Tuple2<Tuple2<Integer, Integer>, Object>( new Tuple2<>(r.user(), r.product()), r.rating()); } } )).join(predictions).values(); // Create regression metrics object RegressionMetrics regressionMetrics = new RegressionMetrics(ratesAndPreds.rdd()); return regressionMetrics.meanSquaredError(); } /** * 获取矩阵分解后的物品特征矩阵 * @param model * @return */ public static JavaPairRDD<Object, double[]> getProductPeatures(MatrixFactorizationModel model){ return JavaPairRDD.fromJavaRDD(model.productFeatures().toJavaRDD()); } /** * recommendProductsForUsers 对所有用户推荐物品,取前n个物品 * @param num * @param model * @return */ public static JavaPairRDD<Object, Rating[]> recommendProductsForUsers(int num,MatrixFactorizationModel model) { RDD<Tuple2<Object, Rating[]>> tuple2RDD = model.recommendProductsForUsers(num); JavaRDD<Tuple2<Object, Rating[]>> tuple2JavaRDD = tuple2RDD.toJavaRDD(); JavaPairRDD<Object, Rating[]> productFeatures=JavaPairRDD.fromJavaRDD(tuple2JavaRDD); return productFeatures; } public static void main(String[] args) {// ALSModelCreater alsModel=new ALSModelCreater(AppConfig.getInstance().sc);// //读取样本数据// JavaRDD<Rating> ratings= getALSJavaRDD("data/ml/recommend/als/test.data", alsModel.jsc,",");//// List<Rating> list=ratings.collect();// //建立模型// int rank=10;// int numIterations=5;// MatrixFactorizationModel model = ALS.train(ratings.rdd(), rank, numIterations, 0.01);//// System.out.println("Mean Squared Error = " + evaluateMSE(ratings,model));// model.save(alsModel.jsc.sc(),"data/ml/recommend/als/model"); MatrixFactorizationModel model=MatrixFactorizationModel.load(AppConfig.getInstance().sc.sc(),"data/ml/recommend/als/model"); System.out.println(Arrays.toString(model.recommendProducts(4,2))); JavaPairRDD<Object, double[]> productFeatures = getProductPeatures(model); List<Tuple2<Object, double[]>> list=productFeatures.collect(); System.out.println(list); JavaPairRDD<Object, Rating[]> features=recommendProductsForUsers(2,model); List<Tuple2<Object, Rating[]>> list1=features.collect(); System.out.println(list1); }}
阅读全文
1 0
- ALS推荐系统实战
- ALS 推荐系统
- 推荐系统ALS矩阵分解
- 推荐系统ALS矩阵分解
- 基于ALS的线推荐系统
- 基于Spark ALS在线推荐系统
- 基于Spark ALS在线推荐系统
- 基于Spark ALS在线推荐系统
- 基于ALS算法的简易在线推荐系统
- 基于ALS算法的简易在线推荐系统
- 基于Spark ALS的离线推荐系统实践
- ALS实现电影推荐
- ALS矩阵分解推荐模型
- Spark ALS推荐系统用户ID非整数的解决思路
- 推荐系统实践1---基于spark ALS做的电影推荐,参考网上的做的,能跑起来
- mahout0.9 分布式推荐算法ALS-MR
- spark mllib als推荐引擎学习
- SparkML之推荐算法(一)ALS
- H5移动端知识总结一
- ORMLite does not know how to store class java.util.ArrayList错误的解决
- Windows下PHP7如何连接Oracle 12c,并使用PDO
- python调用perl脚本
- sql语句中获取datetime的日期部分或时间部分
- ALS 推荐系统
- python通讯录管理系统
- 输入带小数点的键盘(小数点为2位为例)
- git退回后提交
- Apriori算法实例
- 也拆了机器,mac mini总体来说不如直接ssd啊
- Kubernetes1.9 在Unbunt16.4 安装
- RHEL更换为centros的yum
- vsftpd: refusing to run with writable root inside chroot()