Spark SVDPlusPlus 源码分析
来源:互联网 发布:银行核心业务系统数据 编辑:程序博客网 时间:2024/06/05 23:39
上一篇对SVD++算法的原理进行了总结,本文是对Spark SVDPlusPlus源码的分析总结。源码位置在Spark源码包的org.apache.spark.graphx.lib.SVDPlusPlus,需要引入spark-graphx相关包。
迭代公式推导
相比于SVD和implicit ALS算法,SVD++的python单机版算法效率明显低很多,因为多了很重的一个子项
class Conf( var rank: Int,//因子向量维度 var maxIters: Int,//最大迭代次数 var minVal: Double,//评分下限 var maxVal: Double,//评分上限 var gamma1: Double,//b*梯度下降学习速率 var gamma2: Double,//q,p,y梯度下降学习速率 var gamma6: Double,//b*正则系数 var gamma7: Double)//q,p,y正则系数 extends Serializable
基于梯度下降(为了并行,注意这里是批量梯度下降,不是随机梯度下降)方法,对目标学习公式求偏导,可得到如下迭代公式(为了方便源码表达,这里系数名称和源码保持一致):
Spark Graphx
因为Spark SVD++是基于Spark Graphx实现的,所以先对Graphx做简明总结。Graphx是用于图并行计算的spark组件,它基于RDD引入了图抽象:每个顶点和边都绑定属性的多重图(multigraph,两个顶点间有多条边)。Graphx提供了图操作和图算法集合。图的基本单元是点(Vertix)和边(Edge)组成的Triplets,下图中的蓝橙颜色块表示属性。
对于用户的评分数据,A就表示用户,B表示物品,边上的橙色块表示评分,蓝色快储存了我们需要迭代了p、q、y、b*,会在之后详述。
Spark SVD++用到的图操作如下:
1.sendMsg:将edge中的属性发送到顶点
2.mergeMsg:在顶点处进行merge操作
这里我们可以联想到将物品表达的隐式反馈汇总到用户顶点 outerJoinVertices 是aggregateMessages的小伙伴
聚合操作会返回新的Vertices集合(这也是scala函数式编程的特点)
基于VertixId与原来的Triplets集合进行join,更新对应顶点里的属性
核心源码分析
下面分析关键代码片段,主要关注变量迭代,入口代码如下,Edge保存了Double类型的评分数据,srcVertex是用户,dstVertex是物品
def run(edges: RDD[Edge[Double]], conf: Conf)
顶点属性初始化和全局均值u 计算
// 生成默认的顶点属性 包含四个属性(v1,v2,0,0),v1、v2表示向量def defaultF(rank: Int): (Array[Double], Array[Double], Double, Double) = { // TODO: use a fixed random seed val v1 = Array.fill(rank)(Random.nextDouble()) val v2 = Array.fill(rank)(Random.nextDouble()) (v1, v2, 0.0, 0.0) }//计算全局均值u val (rs, rc) = edges.map(e => (e.attr, 1L)).reduce((a, b) => (a._1 + b._1, a._2 + b._2))val u = rs / rc //初始化图g var g = Graph.fromEdges(edges, defaultF(conf.rank)).cache()
这里为什么顶点属性是(v1, v2, 0.0, 0.0)=(Array,Array,Double,Double)这个模样,源代码没有注释,阅读了后面的代码细节才能推断出来,觉得这是源码做的不好的地方,虽然整体处理流程设计的很精巧,可读性却一般般,不过跟猜谜语一样,也挺有趣。
我们的目标是迭代
- 用户属性
(pu,_,bu,_) - 物品属性
(qi,yi,bi,_)
下面以user_property表示用户顶点属性,item_property表示物品顶点属性,下标从0开始。随着源码不断更新空缺位置,此时
- user_property=
(_,_,_,_) - item_property=
(_,_,_,_)
bias和norm初始化(bu,bi 和|N(u)|−12 )
计算每个用户和物品的rating_count
val t0 = g.aggregateMessages[(Long, Double)]( ctx => { ctx.sendToSrc((1L, ctx.attr)); ctx.sendToDst((1L, ctx.attr)) }, (g1, g2) => (g1._1 + g2._1, g1._2 + g2._2))
初始化
val gJoinT0 = g.outerJoinVertices(t0) { (vid: VertexId, vd: (Array[Double], Array[Double], Double, Double), msg: Option[(Long, Double)]) => (vd._1, vd._2, msg.get._2 / msg.get._1 - u, 1.0 / scala.math.sqrt(msg.get._1))}.cache()//触发spark action操作,便于缓存materialize(gJoinT0)g.unpersist()g = gJoinT0
此时,
- user_property=
(_,_,bu,|N(u)|−12) - item_property=
(_,_,bi,|N(i)|−12)
这里,物品属性中的
迭代
阶段1 计算pu+|N(u)|−12∑j∈N(u)yj
计算user_property中的
1.1 聚合计算∑j∈N(u)yj
val t1 = g.aggregateMessages[Array[Double]]( //注意聚合用到了物品属性第二个位置的值,可以推断出该位置是隐式反馈yi ctx => ctx.sendToSrc(ctx.dstAttr._2), (g1, g2) => { val out = g1.clone() //向量相加操作 g1 + g2,使用了blas daxpy api,out=out+g2 blas.daxpy(out.length, 1.0, g2, 1, out, 1) //和用户u关联的隐式反馈之和 out })
1.2 更新user_property1=|N(u)|−12∑j∈N(u)yj
val gJoinT1 = g.outerJoinVertices(t1) { (vid: VertexId, vd: (Array[Double], Array[Double], Double, Double), msg: Option[Array[Double]]) => //物品顶点接收不到,因为聚合信息只发送给了用户顶点 if (msg.isDefined) { val out = vd._1.clone() blas.daxpy(out.length, vd._4, msg.get, 1, out, 1) (vd._1, out, vd._3, vd._4) } else { vd }}.cache()
此时
- user_property=
(pu,|N(u)|−12∑j∈N(u)yj,bu,|N(u)|−12) - item_property
(qi,yi,bi,|N(i)|−12)
阶段2 更新pu ,qi ,yi
2.1 梯度求解
ctx.sendToSrc (
ctx.sendToDst (
def sendMsgTrainF(conf: Conf, u: Double) (ctx: EdgeContext[ (Array[Double], Array[Double], Double, Double), Double, (Array[Double], Array[Double], Double)]) { val (usr, itm) = (ctx.srcAttr, ctx.dstAttr) //这里说明了pu,qi的位置 val (p, q) = (usr._1, itm._1) val rank = p.length //计算误差 var pred = u + usr._3 + itm._3 + blas.ddot(rank, q, 1, usr._2, 1) pred = math.max(pred, conf.minVal) pred = math.min(pred, conf.maxVal) val err = ctx.attr - pred // updateP = (err * q - conf.gamma7 * p) * conf.gamma2 val updateP = q.clone() blas.dscal(rank, err * conf.gamma2, updateP, 1) blas.daxpy(rank, -conf.gamma7 * conf.gamma2, p, 1, updateP, 1) // updateQ = (err * usr._2 - conf.gamma7 * q) * conf.gamma2 val updateQ = usr._2.clone() blas.dscal(rank, err * conf.gamma2, updateQ, 1) blas.daxpy(rank, -conf.gamma7 * conf.gamma2, q, 1, updateQ, 1) // updateY = (err * usr._4 * q - conf.gamma7 * itm._2) * conf.gamma2 val updateY = q.clone() blas.dscal(rank, err * usr._4 * conf.gamma2, updateY, 1) blas.daxpy(rank, -conf.gamma7 * conf.gamma2, itm._2, 1, updateY, 1) ctx.sendToSrc((updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1)) ctx.sendToDst((updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1))}
2.2 合并为(∑Δp∗,∑Δyi , ∑Δb∗ )
val t2 = g.aggregateMessages( sendMsgTrainF(conf, u), (g1: (Array[Double], Array[Double], Double), g2: (Array[Double], Array[Double], Double)) => { val out1 = g1._1.clone() blas.daxpy(out1.length, 1.0, g2._1, 1, out1, 1) val out2 = g2._2.clone() blas.daxpy(out2.length, 1.0, g2._2, 1, out2, 1) (out1, out2, g1._3 + g2._3) })
2.3 更新pu ,qi ,yi
val gJoinT2 = g.outerJoinVertices(t2) { (vid: VertexId, vd: (Array[Double], Array[Double], Double, Double), msg: Option[(Array[Double], Array[Double], Double)]) => { val out1 = vd._1.clone() blas.daxpy(out1.length, 1.0, msg.get._1, 1, out1, 1) val out2 = vd._2.clone() blas.daxpy(out2.length, 1.0, msg.get._2, 1, out2, 1) (out1, out2, vd._3 + msg.get._3, vd._4) }}.cache()
总结
整个源码阅读下来,还是挺有收获的,理解了svd++的并行原理。spark svd++处理逻辑设计的很精巧,代码简洁高效,全篇代码也就200行左右,当然这也和spark、scala语言简洁有关。spark图模块的聚合操作非常契合svd++的迭代计算。唯一觉得有点不好的地方是代码可读性稍微不足,不过也是因为自己水平不足,读起来有点费劲。越精巧的代码就应该多点注释增强可读性,方便维护和迭代。
如果文中有哪里理解不对的地方,希望大家帮忙指正。
参考
spark svd++源码
spark svd++源码分析
spark graphx
- Spark SVDPlusPlus 源码分析
- GraphX SVDPlusPlus Java源码
- Spark源码分析
- Spark Catalyst 源码分析
- Spark源码分析文章
- Spark Streaming源码分析
- Spark RDD 源码分析
- spark-streaming源码分析
- spark源码分析-storage
- Spark Broadcast源码分析
- spark-shuffle-源码分析
- Spark Broadcast源码分析
- spark 源码分析
- Spark Catalyst 源码分析
- Spark-ThriftServer源码分析
- Spark源码分析-schedule()
- Spark源码分析-worker
- 源码- Spark Broadcast源码分析
- 入门:学习《Head First HTML与CSS》
- 索引
- 从输入URL到页面加载显示完成的过程
- 8.活动的启动模式
- 1042. Shuffling Machine (20)
- Spark SVDPlusPlus 源码分析
- 解析XML文件——SAX基本操作
- js对象与json字符串的互转
- JavaScript-简单语法1
- 设计模式六大原则——里氏替换原则
- Android Studio下,gradle project sync failed 错误
- 重置MySQL Root密码
- 图像处理、显示中的行宽(linesize)、步长(stride)、间距(pitch)
- maven安装丶配置本地仓库