使用 ODPS-GRAPH 进行变分 EM 推断一例(转载)
来源:互联网 发布:最喜欢的一句话知乎 编辑:程序博客网 时间:2024/05/18 01:50
内容简介
ODPS Graph 是基于飞天平台实现的面向迭代的图处理框架,为用户提供了类似于 Pregel 的编程接口。用户需要将问题抽象成图的表述,然后通过一些超步进行以顶点为中心的迭代更新。
对于需要迭代学习模型参数的机器学习算法来说,图计算相比 MAP/REDUCE 具有天然的优势。
这篇文章将以用户的汽车品牌分布的推断为例,说明如何利用 ODPS-GRAPH 来做复杂的变分EM 推断。
问题陈述
在汽车类目下,通过用户购买的商品的属性信息,推断用户的汽车品牌的分布
假定有 K 个汽车品牌 b1, b2, ..., b_K,通过 “适合汽车” 这个属性id,可以简单统计出每个用户 u 在这 K 个汽车品牌的购买次数.
我们能观察到的就是这个购买次数,那么如何推断出用户的汽车品牌分布。
建模
采用 Bayesian Multinomial Mixture 模型来建模观察到的计数数据,模型图如下:
相应的生成过程为:
超参数 alpha 是一个 K 维向量。
隐变量 Z = [ z1, z2, ..., zN]
混合比例 pi
不可见数据 U = (Z, pi)
可见数据 V = [ v1, v2, ..., vN]
模型参数 theta = (alpha,phi)
全数据为 D = (V, Z, pi)
对 zn 采用 1-of-K 编码,即 zn 是一个 0/1 的 K 维向量,如果 vn 来自成分 k,则
znk = 1
znj = 0 for other j
全数据似然函数:
全数据对数似然:
下面开始 EM 算法的推导过程
观察数据的似然及其变分下界(VLB)分别为
标准的 EM 算法分为如下的 E 步 和 M 步
在我们的模型中,不可见数据 (Z, pi) 的联合后验分布 intractable,这种情况下一般可以进行近似推断或者基于采样的MCMC方法。 在本文中,将使用一种近似推断方法 —— 变分推断 (Variational Inference)。
变分分布假定有如下的独立性
根据 mean-field theory,分别求解出 pi 的变分分布:
以及 Z 的变分分布为:
通过公式 (10) 和 (15) 不难看出,pi 和 Z 的后验分布之间是有联系的,互相之间通过统计量的期望值进行联系,因而实际求解过程中,需要进行迭代多轮直到两者的分布保持稳定。收敛后的分布即使两者的变分分布的最优解。
当推断出这些分布之后,将通过变分下界最大化来学习出模型参数 theta = (alpha, phi),在这个模型中,参数有封闭的求解公式,如下:
综上分析,我们给出模型的推断学习的算法
Input: 观察数据 V
Output:
a). 模型参数 alpha, phi
b) Z 的分布参数 gamma
Procedue:
1. 初始化参数 alpha, phi, gamma, sigma
2. 由 (15) 式计算 gamma
3. 由 (10), (11), (17) 迭代计算出 alpha, sigma 直到收敛
4. 由 (18) 式计算出 phi
5. 判定收敛条件是否达到,如果达到,则算法结束;否则进入 step 2.
由算法返回的 gamma 是一个 N * K 的矩阵, 矩阵的每一行对应用户 n 的后验汽车品牌分布。
ODPS-GRAPH实现
在本模型学习和推断中,涉及到迭代,而这个过程能够在 ODPS-GRAPH 上非常优雅的支持。
首先,是建立 Graph。
每个观察数据 n 是一个 Vertex,该节点维护三个信息:用户的 nick 作为 VertexId,一个 K 维向量 vn 以及一个需要推断的 K 维概率向量 gamman;
顶点之间没有直接的信息计算依赖,因此图中不需要边的存在。
然后,设计 AggregatorValue 及 Aggregator
事实上 Aggregator 能收集顶点信息,并进行一些用户定义的聚合操作。在我们的实现中,需要用 Aggregator 算出 (10) 和 (18) 式中的求和值。因而在 AggregatorValue 中维护两个数据 s4gramma 和 s4phi 分别用于计算这两个和。
另外在 AggregatorValue 中需要维护模型参数 alpha, phi 及参数 sigma。
除此之外,就是要实现 Vertex 的 compute() 方法,需要做的工作就是按照 (15) 式更新该顶点的 gamma 值。
最后是要做算法结束的判定,实现 Aggregator 的 terminate 方法,比较新旧参数的差异的 L2 范数,如果小于预先指定的容许误差 epsilon 或者超步数达到最大超步,则算法终止。
将实现的class打包即可运行,运行过程是这样的
add jar /home/weidong.yin/odps/lib/zvbmm.jar -f;add jar /home/weidong.yin/packages/commons-math3-3.3/commons-math3-3.3.jar -f;set odps.graph.worker.num=2;jar -libjars zvbmm.jar,commons-math3-3.3.jar -classpath /home/weidong.yin/odps/lib/zvbmm.jar:/home/weidong.yin/packages/commons-math3-3.3/commons-math3-3.3.jar com.taobao.graph.test.VBMMAS zecheng_vbmm_in zecheng_vbmm_out $K;
这里的 commons-math3-3.3.jar 包中有 Gamma 函数 和 Digamma 函数可供调用。
一个玩具例子的运行情况如下:
输入数据: 13 个用户在 3 个汽车品牌上的购买计数
+------------+------------+| key | info |+------------+------------+| u1 | 1,1,0 || u2 | 1,0,1 || u3 | 1,1,0 || u4 | 1,0,1 || u5 | 0,0,1 || u6 | 0,6,0 || u7 | 0,1,0 || u8 | 0,0,1 || u9 | 1,0,0 || u10 | 2,0,0 || u11 | 1,0,0 || u12 | 1,0,0 || u13 | 1,2,1 |+------------+------------+
收敛过程记录如下:
superstep:1 -- superdelta:182605.64601456857superstep:3 -- superdelta:325.1760773137442superstep:5 -- superdelta:29.823993045142903superstep:7 -- superdelta:8.858108454843975superstep:9 -- superdelta:3.3334592359224535superstep:11 -- superdelta:1.2654907417770958superstep:13 -- superdelta:0.41016235969433495superstep:15 -- superdelta:0.1008304226893457superstep:17 -- superdelta:0.03374567141079521superstep:19 -- superdelta:0.007017635448616098superstep:21 -- superdelta:0.0013885342582021832
21步内达到容许误差,收敛是比较快的。
推断结果 gamma:
+------------+------------+| key | info |+------------+------------+| u13 | 0.060984323394069596,0.9240107583094663,0.01500491829646417 || u11 | 0.9226879885707154,0.034669555298412306,0.04264245613087234 || u2 | 0.5773565581849824,0.015464121746682774,0.4071793200683348 || u4 | 0.5773565581849824,0.015464121746682774,0.4071793200683348 || u8 | 0.12802749598577526,0.05813812103740773,0.813834382976817 || u6 | 8.671555846303735E-9,0.9999999911750542,1.533899796361806E-10 || u9 | 0.9226879885707154,0.034669555298412306,0.04264245613087234 || u10 | 0.9927099076999223,0.0022000835657728438,0.0050900087343050075 || u1 | 0.5205681717937852,0.4652215734656923,0.014210254740522454 || u12 | 0.9226879885707154,0.034669555298412306,0.04264245613087234 || u3 | 0.5205681717937852,0.4652215734656923,0.014210254740522454 || u5 | 0.12802749598577526,0.05813812103740773,0.813834382976817 || u7 | 0.060984323394069596,0.9240107583094662,0.01500491829646417 |+------------+------------+
推断结果 sigma
(-0.7209447347994979,-1.1718416342165527,-1.5968869532868073)
模型参数 alpha
(345.0677795410156,220.01290893554688,144.00588989257812)
模型参数 phi
phi[0]:(0.7224296368389298,0.12706511156093872,0.15050525160013148)phi[1]:(0.15566385389422016,0.760368639366158,0.08396750673962187)phi[2]:(0.28444353588382226,0.021037927314571835,0.694518536801606)
对于大的数据的训练和推断,仿此过程即可。
注记
- 似然函数的下界 VLB 是非凸的函数,因此在做 EM 推断或者变分 EM 推断时存在局部极小问题,选择好的参数初始值对于得到合理的结果非常重要,需要细致的选择初始值;
- Muitinomial Mixture 模型可以对特征为计数的观察数据进行建模;
- ODPS-GRAPH 上可以实现任何 EM 相关算法的推断,其过程可仿效这篇文章的实现。
- 使用 ODPS-GRAPH 进行变分 EM 推断一例(转载)
- 变分推断
- 变分推断
- 机器学习:LDA_数学基础_4:变分推断:EM基础
- LDA的变分推断
- 变分推断学习笔记
- ODPS Graph
- LDA模型解析(变分推断)
- 变分推断(variational inference)
- 机器学习(2) 变分推断
- 变分推断(Variational Inference)-mean field
- 变分推断(variational inference)
- 机器学习(2) 变分推断
- 机器学习:LDA_数学基础_5:变分推断:变分推断部分
- 变分推断学习笔记(1)——概念介绍
- 转--Approximate Inference(近似推断,变分推断,KL散度,平均场, Mean Field )
- Gaussian LDA(1): LDA回顾以及变分EM
- PRML读书会第十章 Approximate Inference(近似推断,变分推断,KL散度,平均场, Mean Field )
- ubuntu 14.04 一键安装 gitlab7
- cocos2dx 3.2 + vs2012 + cocostudio开发笔记
- 2014 Multi-University Training Contest 5 1010 Matrix multiplication 涨姿势系列
- java 线程的创建与启动
- ufldl学习笔记与编程作业:Debugging: Gradient Checking(梯度检测)
- 使用 ODPS-GRAPH 进行变分 EM 推断一例(转载)
- 3.1.4、ObjectARX程序的初始化
- 10g RAC 使用service实现taf
- js判断undefined类型
- string
- HDU2674 N!Again 【数学】
- poj 1308 并查集(判断一组点对是否能够组成树)
- HDU 1711 Number Sequence KMP题解
- 在eclipse下直接部署maven工程缺少jar包问题