Spark SVDPlusPlus 源码分析

来源:互联网 发布:银行核心业务系统数据 编辑:程序博客网 时间:2024/06/05 23:39

  上一篇对SVD++算法的原理进行了总结,本文是对Spark SVDPlusPlus源码的分析总结。源码位置在Spark源码包的org.apache.spark.graphx.lib.SVDPlusPlus,需要引入spark-graphx相关包。

迭代公式推导

  相比于SVD和implicit ALS算法,SVD++的python单机版算法效率明显低很多,因为多了很重的一个子项|N(u)|12jN(u)yj,需要汇总用户接触物品集所表达的隐式反馈yj。模型的训练目标如下:

minq,q,b,y(u,i)K(ruiμbibuqTi(pu+|N(u)|12jN(u)yj))2+λ1(||qi||2+||pu||2+jN(u)||yj||2)+λ2(b2u+bi2)
下面是Spark SVDPlusPlus的配置类

class Conf(      var rank: Int,//因子向量维度      var maxIters: Int,//最大迭代次数      var minVal: Double,//评分下限      var maxVal: Double,//评分上限      var gamma1: Double,//b*梯度下降学习速率      var gamma2: Double,//q,p,y梯度下降学习速率      var gamma6: Double,//b*正则系数      var gamma7: Double)//q,p,y正则系数    extends Serializable

基于梯度下降(为了并行,注意这里是批量梯度下降,不是随机梯度下降)方法,对目标学习公式求偏导,可得到如下迭代公式(为了方便源码表达,这里系数名称和源码保持一致):

bubu+iN(u)γ1(euiγ6bu)bibi+uN(i)γ1(euiγ6bi)qiqi+uN(i)γ2(eui(pu+|N(u)|12jN(u)yj)γ7qi)pupu+iN(u)γ2(euiqiγ7pu)yiyi+uN(i)γ2(eui|N(u)|12qiγ7yi)

Spark Graphx

  因为Spark SVD++是基于Spark Graphx实现的,所以先对Graphx做简明总结。Graphx是用于图并行计算的spark组件,它基于RDD引入了图抽象:每个顶点和边都绑定属性的多重图(multigraph,两个顶点间有多条边)。Graphx提供了图操作和图算法集合。图的基本单元是点(Vertix)和边(Edge)组成的Triplets,下图中的蓝橙颜色块表示属性。

这里写图片描述
  对于用户的评分数据,A就表示用户,B表示物品,边上的橙色块表示评分,蓝色快储存了我们需要迭代了p、q、y、b*,会在之后详述。

Spark SVD++用到的图操作如下:

操作 解释 aggregateMessages Graphx的核心聚合操作,主要有两步操作
1.sendMsg:将edge中的属性发送到顶点
2.mergeMsg:在顶点处进行merge操作
这里我们可以联想到将物品表达的隐式反馈汇总到用户顶点 outerJoinVertices 是aggregateMessages的小伙伴
聚合操作会返回新的Vertices集合(这也是scala函数式编程的特点)
基于VertixId与原来的Triplets集合进行join,更新对应顶点里的属性

核心源码分析

下面分析关键代码片段,主要关注变量迭代,入口代码如下,Edge保存了Double类型的评分数据,srcVertex是用户,dstVertex是物品

def run(edges: RDD[Edge[Double]], conf: Conf)

顶点属性初始化和全局均值u计算

// 生成默认的顶点属性 包含四个属性(v1,v2,0,0),v1、v2表示向量def defaultF(rank: Int): (Array[Double], Array[Double], Double, Double) = {      // TODO: use a fixed random seed      val v1 = Array.fill(rank)(Random.nextDouble())      val v2 = Array.fill(rank)(Random.nextDouble())      (v1, v2, 0.0, 0.0)    }//计算全局均值u    val (rs, rc) = edges.map(e => (e.attr, 1L)).reduce((a, b) => (a._1 + b._1, a._2 + b._2))val u = rs / rc    //初始化图g    var g = Graph.fromEdges(edges, defaultF(conf.rank)).cache()    

  这里为什么顶点属性是(v1, v2, 0.0, 0.0)=(Array,Array,Double,Double)这个模样,源代码没有注释,阅读了后面的代码细节才能推断出来,觉得这是源码做的不好的地方,虽然整体处理流程设计的很精巧,可读性却一般般,不过跟猜谜语一样,也挺有趣。
我们的目标是迭代puqiyibibu,其中pubi属于用户属性,qiyibi属于物品属性。看完整个代码,位置分布谜底如下,空缺的位置是存放一些中间值。

  • 用户属性(pu_bu_)
  • 物品属性(qi,yibi_)

下面以user_property表示用户顶点属性,item_property表示物品顶点属性,下标从0开始。随着源码不断更新空缺位置,此时

  • user_property=(____)
  • item_property=(____)

bias和norm初始化(bu,bi|N(u)|12)

计算每个用户和物品的rating_count rc、rating_sum rs (*号表示不区分用户和物品)

val t0 = g.aggregateMessages[(Long, Double)](  ctx => { ctx.sendToSrc((1L, ctx.attr)); ctx.sendToDst((1L, ctx.attr)) },  (g1, g2) => (g1._1 + g2._1, g1._2 + g2._2))

初始化b=rsrcunorm=|N(u)|12

val gJoinT0 = g.outerJoinVertices(t0) {  (vid: VertexId, vd: (Array[Double], Array[Double], Double, Double),   msg: Option[(Long, Double)]) =>    (vd._1, vd._2, msg.get._2 / msg.get._1 - u, 1.0 / scala.math.sqrt(msg.get._1))}.cache()//触发spark action操作,便于缓存materialize(gJoinT0)g.unpersist()g = gJoinT0

此时,

  • user_property=(__bu|N(u)|12)
  • item_property=(__bi|N(i)|12)

这里,物品属性中的|N(i)|12并没有什么作用,作者应该是为了代码简洁,对用户属性和物品属性使用了同样的处理作用,出现了这个副产物。至此,准备工作全部ready。

迭代

阶段1 计算pu+|N(u)|12jN(u)yj

计算user_property中的pu+|N(u)|12jN(u)yj,这是为下一步计算预测评分pred,进而计算误差和因子更新迭代做准备。根据公式,需要把用户u看过的物品所表达的隐式反馈聚合到用户端(这一点太符合spark图计算了)。

1.1 聚合计算jN(u)yj
val t1 = g.aggregateMessages[Array[Double]](  //注意聚合用到了物品属性第二个位置的值,可以推断出该位置是隐式反馈yi  ctx => ctx.sendToSrc(ctx.dstAttr._2),  (g1, g2) => {    val out = g1.clone()    //向量相加操作 g1 + g2,使用了blas daxpy api,out=out+g2    blas.daxpy(out.length, 1.0, g2, 1, out, 1)    //和用户u关联的隐式反馈之和    out  })
1.2 更新user_property1=|N(u)|12jN(u)yj
val gJoinT1 = g.outerJoinVertices(t1) {  (vid: VertexId, vd: (Array[Double], Array[Double], Double, Double),   msg: Option[Array[Double]]) =>    //物品顶点接收不到,因为聚合信息只发送给了用户顶点    if (msg.isDefined) {      val out = vd._1.clone()      blas.daxpy(out.length, vd._4, msg.get, 1, out, 1)      (vd._1, out, vd._3, vd._4)    } else {      vd    }}.cache()

此时

  • user_property=(pu|N(u)|12jN(u)yjbu|N(u)|12)
  • item_property(qiyibi|N(i)|12)

阶段2 更新puqiyi

2.1 梯度求解

ctx.sendToSrc (Δpu,ΔyiΔbu)
ctx.sendToDst (Δqi,ΔyiΔbi)

def sendMsgTrainF(conf: Conf, u: Double)    (ctx: EdgeContext[      (Array[Double], Array[Double], Double, Double),      Double,      (Array[Double], Array[Double], Double)]) {  val (usr, itm) = (ctx.srcAttr, ctx.dstAttr)  //这里说明了pu,qi的位置  val (p, q) = (usr._1, itm._1)  val rank = p.length  //计算误差  var pred = u + usr._3 + itm._3 + blas.ddot(rank, q, 1, usr._2, 1)  pred = math.max(pred, conf.minVal)  pred = math.min(pred, conf.maxVal)  val err = ctx.attr - pred  // updateP = (err * q - conf.gamma7 * p) * conf.gamma2  val updateP = q.clone()  blas.dscal(rank, err * conf.gamma2, updateP, 1)  blas.daxpy(rank, -conf.gamma7 * conf.gamma2, p, 1, updateP, 1)  // updateQ = (err * usr._2 - conf.gamma7 * q) * conf.gamma2  val updateQ = usr._2.clone()  blas.dscal(rank, err * conf.gamma2, updateQ, 1)  blas.daxpy(rank, -conf.gamma7 * conf.gamma2, q, 1, updateQ, 1)  // updateY = (err * usr._4 * q - conf.gamma7 * itm._2) * conf.gamma2  val updateY = q.clone()  blas.dscal(rank, err * usr._4 * conf.gamma2, updateY, 1)  blas.daxpy(rank, -conf.gamma7 * conf.gamma2, itm._2, 1, updateY, 1)  ctx.sendToSrc((updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1))  ctx.sendToDst((updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1))}
2.2 合并为(Δp,ΔyiΔb)
val t2 = g.aggregateMessages(  sendMsgTrainF(conf, u),  (g1: (Array[Double], Array[Double], Double), g2: (Array[Double], Array[Double], Double)) =>  {    val out1 = g1._1.clone()    blas.daxpy(out1.length, 1.0, g2._1, 1, out1, 1)    val out2 = g2._2.clone()    blas.daxpy(out2.length, 1.0, g2._2, 1, out2, 1)    (out1, out2, g1._3 + g2._3)  })
2.3 更新puqiyi
val gJoinT2 = g.outerJoinVertices(t2) {  (vid: VertexId,   vd: (Array[Double], Array[Double], Double, Double),   msg: Option[(Array[Double], Array[Double], Double)]) => {    val out1 = vd._1.clone()    blas.daxpy(out1.length, 1.0, msg.get._1, 1, out1, 1)    val out2 = vd._2.clone()    blas.daxpy(out2.length, 1.0, msg.get._2, 1, out2, 1)    (out1, out2, vd._3 + msg.get._3, vd._4)  }}.cache()

总结

  整个源码阅读下来,还是挺有收获的,理解了svd++的并行原理。spark svd++处理逻辑设计的很精巧,代码简洁高效,全篇代码也就200行左右,当然这也和spark、scala语言简洁有关。spark图模块的聚合操作非常契合svd++的迭代计算。唯一觉得有点不好的地方是代码可读性稍微不足,不过也是因为自己水平不足,读起来有点费劲。越精巧的代码就应该多点注释增强可读性,方便维护和迭代。
  如果文中有哪里理解不对的地方,希望大家帮忙指正。

参考

spark svd++源码
spark svd++源码分析
spark graphx

原创粉丝点击