GraphX Source Code Analysis: The SVD++ Algorithm
The SVD++ implementation in Spark GraphX is based mainly on the paper:
http://public.research.att.com/~volinsky/netflix/kdd08koren.pdf
The core prediction formula is: r_ui = u + b_u + b_i + q_i · (p_u + |N(u)|^(-1/2) * sum(y_j))
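Written out in display form, with u the global average rating, b_u and b_i the user and item biases, p_u and q_i the latent factor vectors, N(u) the set of items rated by user u, and y_j the implicit-feedback vectors:

\hat{r}_{ui} = \mu + b_u + b_i + q_i^{\top}\Bigl(p_u + |N(u)|^{-1/2}\sum_{j \in N(u)} y_j\Bigr)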
Input
Input format: user,item,score
1,1,5.0
1,2,1.0
1,3,5.0
1,4,1.0
2,1,5.0
2,2,1.0
2,3,5.0
2,4,1.0
3,1,1.0
3,2,5.0
3,3,1.0
3,4,5.0
4,1,1.0
4,2,5.0
4,3,1.0
4,4,5.0
Depending on the use case, graphs fall into directed and undirected ones. Since the source and the destination of an edge carry different meanings here (user vs. item), a directed graph is used. To keep source and destination ids apart, before the data is loaded every source (user) id is multiplied by 2 and every destination (item) id is multiplied by 2 plus 1:
Edge(uid.toString.toLong * 2, live_uid.toString.toLong * 2 + 1, score.toString.toDouble)
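A minimal sketch of how the edge RDD could be built from the input above; the file path, SparkContext variable, and helper name are illustrative and not taken from the original code:

import org.apache.spark.SparkContext
import org.apache.spark.graphx.Edge
import org.apache.spark.rdd.RDD

// parse "user,item,score" lines and apply the even/odd vertex-id encoding
def loadEdges(sc: SparkContext, path: String): RDD[Edge[Double]] =
  sc.textFile(path).map { line =>
    val Array(uid, iid, score) = line.split(",")
    Edge(uid.toLong * 2, iid.toLong * 2 + 1, score.toDouble)
  }

val edges = loadEdges(sc, "/path/to/ratings.csv")   // hypothetical path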
Parameters
class Conf(
  var rank: Int,      // dimension of the latent vectors
  var maxIters: Int,  // number of iterations
  var minVal: Double, // minimum rating
  var maxVal: Double, // maximum rating
  var gamma1: Double, // learning rate for the bias terms
  var gamma2: Double, // learning rate for the latent factors
  var gamma6: Double, // regularization coefficient for the bias terms
  var gamma7: Double  // regularization coefficient for the latent factors
) extends Serializable
Algorithm input and output
run(edges: RDD[Edge[Double]], conf: Conf) : (Graph[(Array[Double], Array[Double], Double, Double), Double], Double)
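A hedged usage sketch of this entry point; the parameter values below are only illustrative defaults, not the values used by the author:

// rank = 10 latent dimensions, 20 iterations, ratings clipped to [0, 5];
// the gamma values are common illustrative choices, not tuned ones
val conf = new Conf(10, 20, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015)
val (model, meanRating) = run(edges, conf)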
Compute the average rating
// rs: sum of all ratings, rc: total number of records
val (rs, rc) = edges.map(e => (e.attr, 1L)).reduce((a, b) => (a._1 + b._1, a._2 + b._2))
// average rating
val u = rs / rc
Build the graph
Graph.fromEdges(edges, defaultF(conf.rank)).cache()
defaultF randomly generates the feature vectors according to rank; here is the defaultF method:
def defaultF(rank: Int): (Array[Double], Array[Double], Double, Double) = {
  // TODO: use a fixed random seed
  val v1 = Array.fill(rank)(Random.nextDouble())
  val v2 = Array.fill(rank)(Random.nextDouble())
  (v1, v2, 0.0, 0.0)
}
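One possible way to address the TODO above (a sketch only, not part of the GraphX source): pass in a seed so that the initial vectors are reproducible across runs.

def defaultFSeeded(rank: Int, seed: Long): (Array[Double], Array[Double], Double, Double) = {
  val rnd = new scala.util.Random(seed)   // fixed seed => reproducible initialization
  val v1 = Array.fill(rank)(rnd.nextDouble())
  val v2 = Array.fill(rank)(rnd.nextDouble())
  (v1, v2, 0.0, 0.0)
}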
Compute the rating sums and the square-root term
// per vertex: (number of incident edges, sum of the ratings on those edges)
val t0: VertexRDD[(VertexId, Double)] = g.aggregateMessages[(VertexId, Double)](
  ctx => { ctx.sendToSrc((1L, ctx.attr)); ctx.sendToDst((1L, ctx.attr)) },
  (g1, g2) => (g1._1 + g2._1, g1._2 + g2._2))

// third field: rating sum / rating count - average rating (the bias bu / bi)
// fourth field: 1 / sqrt(rating count), i.e. |N(u)|^(-1/2)
val gJoinT0 = g.outerJoinVertices(t0) {
  (vid: VertexId, vd: (Array[Double], Array[Double], Double, Double), msg: Option[(Long, Double)]) =>
    (vd._1, vd._2, msg.get._2 / msg.get._1 - u, 1.0 / scala.math.sqrt(msg.get._1))
}.cache()
The vertices of g at this point are:
Each line is (id, (p, q, bu/bi, |N(u)|^(-1/2))):
(4,([D@7ed9bdff,[D@15188d22,0.0,0.5))
(6,([D@7ed9bdff,[D@15188d22,0.0,0.5))
(3,([D@7ed9bdff,[D@15188d22,0.0,0.5))
(7,([D@7ed9bdff,[D@15188d22,0.0,0.5))
(9,([D@7ed9bdff,[D@15188d22,0.0,0.5))
(8,([D@7ed9bdff,[D@15188d22,0.0,0.5))
(5,([D@7ed9bdff,[D@15188d22,0.0,0.5))
(2,([D@7ed9bdff,[D@15188d22,0.0,0.5))
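These values can be checked against the sample input: every vertex (user or item) appears on 4 edges whose ratings are two 5.0s and two 1.0s, so its own average rating is 12 / 4 = 3.0; the global average is u = 48 / 16 = 3.0, hence bu/bi = 3.0 - 3.0 = 0.0, and |N(u)|^(-1/2) = 1 / sqrt(4) = 0.5.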
Using g as input, start iterating.
Step 1
For every source (user) vertex, aggregate (map-reduce) the second array of all the destination (item) vertices it connects to, and merge the result back into g:
// for every source (user) vertex, sum up the second array (the y vectors) of
// every destination (item) vertex it is connected to
val t1 = g.aggregateMessages[Array[Double]](
  ctx => ctx.sendToSrc(ctx.dstAttr._2),
  (g1, g2) => {
    val out = g1.clone()
    blas.daxpy(out.length, 1.0, g2, 1, out, 1)
    out
  })

// the second field becomes p_u + |N(u)|^(-1/2) * sum(y_j)
val gJoinT1 = g.outerJoinVertices(t1) {
  (vid: VertexId, vd: (Array[Double], Array[Double], Double, Double), msg: Option[Array[Double]]) =>
    if (msg.isDefined) {
      val out = vd._1.clone()
      blas.daxpy(out.length, vd._4, msg.get, 1, out, 1)
      (vd._1, out, vd._3, vd._4)
    } else {
      vd
    }
}.cache()
Note: blas.daxpy performs a scaled vector addition (y := a*x + y), not a full matrix addition; it is provided by a third-party BLAS jar.
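A small sketch of what the two BLAS routines used in this code do; the import shown is the netlib-java wrapper that Spark's GraphX used at the time of writing (adjust the import to your build if it differs):

import com.github.fommil.netlib.BLAS.{getInstance => blas}

val x = Array(1.0, 2.0, 3.0)
val y = Array(10.0, 10.0, 10.0)

// daxpy: y := a * x + y, a scaled vector addition done in place on y
blas.daxpy(x.length, 2.0, x, 1, y, 1)   // y is now [12.0, 14.0, 16.0]

// dscal: x := a * x, in-place scaling (used below in sendMsgTrainF)
blas.dscal(x.length, 0.5, x, 1)         // x is now [0.5, 1.0, 1.5]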
Step 2
// Phase 2, update p for user nodes and q, y for item nodes
g.cache()
val t2 = g.aggregateMessages(
  sendMsgTrainF(conf, u),
  (g1: (Array[Double], Array[Double], Double), g2: (Array[Double], Array[Double], Double)) => {
    val out1 = g1._1.clone()
    blas.daxpy(out1.length, 1.0, g2._1, 1, out1, 1)
    val out2 = g2._2.clone()
    blas.daxpy(out2.length, 1.0, g2._2, 1, out2, 1)
    (out1, out2, g1._3 + g2._3)
  })

val gJoinT2 = g.outerJoinVertices(t2) {
  (vid: VertexId, vd: (Array[Double], Array[Double], Double, Double), msg: Option[(Array[Double], Array[Double], Double)]) => {
    val out1 = vd._1.clone()
    blas.daxpy(out1.length, 1.0, msg.get._1, 1, out1, 1)
    val out2 = vd._2.clone()
    blas.daxpy(out2.length, 1.0, msg.get._2, 1, out2, 1)
    (out1, out2, vd._3 + msg.get._3, vd._4)
  }
}.cache()
The key function is sendMsgTrainF:
def sendMsgTrainF(conf: Conf, u: Double)
    (ctx: EdgeContext[(Array[Double], Array[Double], Double, Double), Double,
      (Array[Double], Array[Double], Double)]) {
  val (usr, itm) = (ctx.srcAttr, ctx.dstAttr)
  val (p, q) = (usr._1, itm._1)
  val rank = p.length
  var pred = u + usr._3 + itm._3 + blas.ddot(rank, q, 1, usr._2, 1)
  pred = math.max(pred, conf.minVal)
  pred = math.min(pred, conf.maxVal)
  val err = ctx.attr - pred
  // updateP = (err * q - conf.gamma7 * p) * conf.gamma2
  val updateP = q.clone()
  blas.dscal(rank, err * conf.gamma2, updateP, 1)
  blas.daxpy(rank, -conf.gamma7 * conf.gamma2, p, 1, updateP, 1)
  // updateQ = (err * usr._2 - conf.gamma7 * q) * conf.gamma2
  val updateQ = usr._2.clone()
  blas.dscal(rank, err * conf.gamma2, updateQ, 1)
  blas.daxpy(rank, -conf.gamma7 * conf.gamma2, q, 1, updateQ, 1)
  // updateY = (err * usr._4 * q - conf.gamma7 * itm._2) * conf.gamma2
  val updateY = q.clone()
  blas.dscal(rank, err * usr._4 * conf.gamma2, updateY, 1)
  blas.daxpy(rank, -conf.gamma7 * conf.gamma2, itm._2, 1, updateY, 1)
  ctx.sendToSrc((updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1))
  ctx.sendToDst((updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1))
}
pred is the rating predicted in this pass, and err is the prediction error.
updateP = (err * q - conf.gamma7 * p) * conf.gamma2
updateQ = (err * usr._2 - conf.gamma7 * q) * conf.gamma2
updateY = (err * usr._4 * q - conf.gamma7 * itm._2) * conf.gamma2
The message sent to the source (user) vertex is (updateP, updateY, bias update for b_u).
The message sent to the destination (item) vertex is (updateQ, updateY, bias update for b_i).
These results are then merged back into the first three fields of the corresponding vertices in g.
Note that the per-edge updates are first summed for each vertex over all of its incident edges and only then applied, i.e. the iteration uses a batched (per-vertex summed) form of stochastic gradient descent.
To summarize one training pass over an edge (an equivalent plain-Scala restatement follows below):
1. pred = u + user._3 + item._3 + item._1 · user._2
2. pred is clamped to the closed interval [minVal, maxVal]
3. err = actual rating - pred
4. message to the user vertex: (err * gamma2 * item._1 - gamma7 * gamma2 * user._1, err * user._4 * gamma2 * item._1 - gamma7 * gamma2 * item._2, (err - gamma6 * user._3) * gamma1)
   message to the item vertex: (err * gamma2 * user._2 - gamma7 * gamma2 * item._1, err * user._4 * gamma2 * item._1 - gamma7 * gamma2 * item._2, (err - gamma6 * item._3) * gamma1)
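To make the dscal/daxpy calls easier to read, here is an equivalent plain-Scala restatement of the per-edge update (a sketch for clarity only, not the GraphX code itself; names mirror the tuple fields above):

def edgeUpdate(conf: Conf, u: Double,
               usr: (Array[Double], Array[Double], Double, Double),
               itm: (Array[Double], Array[Double], Double, Double),
               rating: Double) = {
  val (p, q) = (usr._1, itm._1)
  val pred0 = u + usr._3 + itm._3 + q.zip(usr._2).map { case (a, b) => a * b }.sum
  val pred  = math.min(math.max(pred0, conf.minVal), conf.maxVal)
  val err   = rating - pred
  val updateP = p.indices.map(i => (err * q(i)          - conf.gamma7 * p(i))      * conf.gamma2).toArray
  val updateQ = q.indices.map(i => (err * usr._2(i)     - conf.gamma7 * q(i))      * conf.gamma2).toArray
  val updateY = q.indices.map(i => (err * usr._4 * q(i) - conf.gamma7 * itm._2(i)) * conf.gamma2).toArray
  // message to the user vertex and to the item vertex, respectively
  ((updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1),
   (updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1))
}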
The above process is repeated for each iteration.
Evaluation
val t3 = g.aggregateMessages[Double](sendMsgTestF(conf, u), _ + _)
val gJoinT3 = g.outerJoinVertices(t3) {
  (vid: VertexId, vd: (Array[Double], Array[Double], Double, Double), msg: Option[Double]) =>
    if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd
}.cache()
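sendMsgTestF itself is not listed in this post. Judging from how t3 is consumed below (summed per vertex and later divided by the edge count), it recomputes the prediction on each edge and sends the squared error to the item (destination) vertex; a sketch consistent with that usage:

// sketch of sendMsgTestF, inferred from how t3 is used
def sendMsgTestF(conf: Conf, u: Double)
    (ctx: EdgeContext[(Array[Double], Array[Double], Double, Double), Double, Double]) {
  val (usr, itm) = (ctx.srcAttr, ctx.dstAttr)
  var pred = u + usr._3 + itm._3 + blas.ddot(usr._2.length, itm._1, 1, usr._2, 1)
  pred = math.max(pred, conf.minVal)
  pred = math.min(pred, conf.maxVal)
  val err = (ctx.attr - pred) * (ctx.attr - pred)
  ctx.sendToDst(err)
}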
Step 3
Obtain err
val t3 = gJoinT2.aggregateMessages[Double](sendMsgTestF(conf, u), _ + _)
val gJoinT3 = gJoinT2.outerJoinVertices(t3) {
  (vid: VertexId, vd: (Array[Double], Array[Double], Double, Double), msg: Option[Double]) =>
    if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd
}.cache()

// item vertices have odd ids; sum the error accumulated on them and divide
// by the number of edges to get the mean error of this iteration
val err = gJoinT3.vertices.map {
  case (vid, vd) => if (vid % 2 == 1) vd._4 else 0.0
}.reduce(_ + _) / gJoinT3.numEdges
RedisUtil.setIntoRedis(i + "_ERR", err.toString)
If the value of err fluctuates across iterations, reduce gamma1 and gamma2 and run the iterations again.
In the vertex attributes, user._1 is the user's latent feature vector and item._1 is the item's latent feature vector.
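A small sketch (not from the original post) of how these learned latent vectors could be pulled out of the result graph, using the even/odd vertex-id convention introduced earlier:

// even vertex ids are users, odd vertex ids are items
val userFactors = g.vertices
  .filter { case (vid, _) => vid % 2 == 0 }
  .map { case (vid, vd) => (vid / 2, vd._1) }        // (original user id, p_u)

val itemFactors = g.vertices
  .filter { case (vid, _) => vid % 2 == 1 }
  .map { case (vid, vd) => ((vid - 1) / 2, vd._1) }  // (original item id, q_i)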
Output the results
val labels = g.triplets.map { ctx =>
  val (usr, itm) = (ctx.srcAttr, ctx.dstAttr)
  val (p, q) = (usr._1, itm._1)
  var pred = u + usr._3 + itm._3 + blas.ddot(q.length, q, 1, usr._2, 1)
  pred = math.max(pred, conf.minVal)
  pred = math.min(pred, conf.maxVal)
  // map the vertex ids back to the original user and item ids
  (ctx.srcId / 2) + "|" + ((ctx.dstId - 1) / 2) + "|" + pred
}
labels.saveAsTextFile("/spark/grxpah/svd")
Afterwards, an AUC-style evaluation of the results can be carried out.
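As a sketch of such an evaluation (the 3.0 threshold used to binarize the ratings is purely illustrative, and BinaryClassificationMetrics comes from spark-mllib):

import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

// reuse the predictions from the triplets and binarize the true ratings
val scoreAndLabel = g.triplets.map { ctx =>
  val (usr, itm) = (ctx.srcAttr, ctx.dstAttr)
  var pred = u + usr._3 + itm._3 + blas.ddot(itm._1.length, itm._1, 1, usr._2, 1)
  pred = math.min(math.max(pred, conf.minVal), conf.maxVal)
  val label = if (ctx.attr >= 3.0) 1.0 else 0.0
  (pred, label)
}
val auc = new BinaryClassificationMetrics(scoreAndLabel).areaUnderROC()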