使用 ODPS-GRAPH 进行变分 EM 推断一例(转载)

来源:互联网 发布:最喜欢的一句话知乎 编辑:程序博客网 时间:2024/05/18 01:50

内容简介

ODPS Graph 是基于飞天平台实现的面向迭代的图处理框架,为用户提供了类似于 Pregel 的编程接口。用户需要将问题抽象成图的表述,然后通过一些超步进行以顶点为中心的迭代更新。
对于需要迭代学习模型参数的机器学习算法来说,图计算相比 MAP/REDUCE 具有天然的优势。
这篇文章将以用户的汽车品牌分布的推断为例,说明如何利用 ODPS-GRAPH 来做复杂的变分EM 推断。

问题陈述

在汽车类目下,通过用户购买的商品的属性信息,推断用户的汽车品牌的分布
假定有 K 个汽车品牌 b1, b2, ..., b_K,通过 “适合汽车” 这个属性id,可以简单统计出每个用户 u 在这 K 个汽车品牌的购买次数. 
我们能观察到的就是这个购买次数,那么如何推断出用户的汽车品牌分布。

建模

采用 Bayesian Multinomial Mixture 模型来建模观察到的计数数据,模型图如下:

pgm

相应的生成过程为:
equ

超参数 alpha 是一个 K 维向量。

隐变量 Z = [ z1, z2, ..., zN] 
混合比例 pi
不可见数据 U = (Z, pi)
可见数据 V = [ v1, v2, ..., vN]
模型参数 theta = (alpha,phi)

全数据为 D = (V, Z, pi)

对 zn 采用 1-of-K 编码,即 zn 是一个 0/1 的 K 维向量,如果 vn 来自成分 k,则 
znk = 1 
znj = 0 for other j

全数据似然函数:
eq1

全数据对数似然:
eq2

下面开始 EM 算法的推导过程
观察数据的似然及其变分下界(VLB)分别为
eq3_4

标准的 EM 算法分为如下的 E 步 和 M 步
eq5_6

在我们的模型中,不可见数据 (Z, pi) 的联合后验分布 intractable,这种情况下一般可以进行近似推断或者基于采样的MCMC方法。 在本文中,将使用一种近似推断方法 —— 变分推断 (Variational Inference)。

变分分布假定有如下的独立性
eq7

根据 mean-field theory,分别求解出 pi 的变分分布:
eq8_11

以及 Z 的变分分布为:
eq12_15

通过公式 (10) 和 (15) 不难看出,pi 和 Z 的后验分布之间是有联系的,互相之间通过统计量的期望值进行联系,因而实际求解过程中,需要进行迭代多轮直到两者的分布保持稳定。收敛后的分布即使两者的变分分布的最优解。

当推断出这些分布之后,将通过变分下界最大化来学习出模型参数 theta = (alpha, phi),在这个模型中,参数有封闭的求解公式,如下:
eq16_18

综上分析,我们给出模型的推断学习的算法
Input: 观察数据 V
Output: 
a). 模型参数 alpha, phi
b) Z 的分布参数 gamma
Procedue:
1. 初始化参数 alpha, phi, gamma, sigma
2. 由 (15) 式计算 gamma
3. 由 (10), (11), (17) 迭代计算出 alpha, sigma 直到收敛
4. 由 (18) 式计算出 phi
5. 判定收敛条件是否达到,如果达到,则算法结束;否则进入 step 2.

由算法返回的 gamma 是一个 N * K 的矩阵, 矩阵的每一行对应用户 n 的后验汽车品牌分布。

ODPS-GRAPH实现

在本模型学习和推断中,涉及到迭代,而这个过程能够在 ODPS-GRAPH 上非常优雅的支持。
首先,是建立 Graph。
每个观察数据 n 是一个 Vertex,该节点维护三个信息:用户的 nick 作为 VertexId,一个 K 维向量 vn 以及一个需要推断的 K 维概率向量 gamman;
顶点之间没有直接的信息计算依赖,因此图中不需要边的存在。
graph

然后,设计 AggregatorValue 及 Aggregator
事实上 Aggregator 能收集顶点信息,并进行一些用户定义的聚合操作。在我们的实现中,需要用 Aggregator 算出 (10) 和 (18) 式中的求和值。因而在 AggregatorValue 中维护两个数据 s4gramma 和 s4phi 分别用于计算这两个和。
另外在 AggregatorValue 中需要维护模型参数 alpha, phi 及参数 sigma。

除此之外,就是要实现 Vertex 的 compute() 方法,需要做的工作就是按照 (15) 式更新该顶点的 gamma 值。

最后是要做算法结束的判定,实现 Aggregator 的 terminate 方法,比较新旧参数的差异的 L2 范数,如果小于预先指定的容许误差 epsilon 或者超步数达到最大超步,则算法终止。

将实现的class打包即可运行,运行过程是这样的

add jar /home/weidong.yin/odps/lib/zvbmm.jar -f;add jar /home/weidong.yin/packages/commons-math3-3.3/commons-math3-3.3.jar -f;set odps.graph.worker.num=2;jar -libjars zvbmm.jar,commons-math3-3.3.jar -classpath /home/weidong.yin/odps/lib/zvbmm.jar:/home/weidong.yin/packages/commons-math3-3.3/commons-math3-3.3.jar com.taobao.graph.test.VBMMAS zecheng_vbmm_in zecheng_vbmm_out $K;

这里的 commons-math3-3.3.jar 包中有 Gamma 函数 和 Digamma 函数可供调用。

一个玩具例子的运行情况如下:
输入数据: 13 个用户在 3 个汽车品牌上的购买计数

+------------+------------+| key        | info       |+------------+------------+| u1         | 1,1,0      || u2         | 1,0,1      || u3         | 1,1,0      || u4         | 1,0,1      || u5         | 0,0,1      || u6         | 0,6,0      || u7         | 0,1,0      || u8         | 0,0,1      || u9         | 1,0,0      || u10        | 2,0,0      || u11        | 1,0,0      || u12        | 1,0,0      || u13        | 1,2,1      |+------------+------------+

收敛过程记录如下:

superstep:1 -- superdelta:182605.64601456857superstep:3 -- superdelta:325.1760773137442superstep:5 -- superdelta:29.823993045142903superstep:7 -- superdelta:8.858108454843975superstep:9 -- superdelta:3.3334592359224535superstep:11 -- superdelta:1.2654907417770958superstep:13 -- superdelta:0.41016235969433495superstep:15 -- superdelta:0.1008304226893457superstep:17 -- superdelta:0.03374567141079521superstep:19 -- superdelta:0.007017635448616098superstep:21 -- superdelta:0.0013885342582021832

21步内达到容许误差,收敛是比较快的。

推断结果 gamma:

+------------+------------+| key        | info       |+------------+------------+| u13        | 0.060984323394069596,0.9240107583094663,0.01500491829646417 || u11        | 0.9226879885707154,0.034669555298412306,0.04264245613087234 || u2         | 0.5773565581849824,0.015464121746682774,0.4071793200683348 || u4         | 0.5773565581849824,0.015464121746682774,0.4071793200683348 || u8         | 0.12802749598577526,0.05813812103740773,0.813834382976817 || u6         | 8.671555846303735E-9,0.9999999911750542,1.533899796361806E-10 || u9         | 0.9226879885707154,0.034669555298412306,0.04264245613087234 || u10        | 0.9927099076999223,0.0022000835657728438,0.0050900087343050075 || u1         | 0.5205681717937852,0.4652215734656923,0.014210254740522454 || u12        | 0.9226879885707154,0.034669555298412306,0.04264245613087234 || u3         | 0.5205681717937852,0.4652215734656923,0.014210254740522454 || u5         | 0.12802749598577526,0.05813812103740773,0.813834382976817 || u7         | 0.060984323394069596,0.9240107583094662,0.01500491829646417 |+------------+------------+

推断结果 sigma

(-0.7209447347994979,-1.1718416342165527,-1.5968869532868073)

模型参数 alpha

(345.0677795410156,220.01290893554688,144.00588989257812)

模型参数 phi

phi[0]:(0.7224296368389298,0.12706511156093872,0.15050525160013148)phi[1]:(0.15566385389422016,0.760368639366158,0.08396750673962187)phi[2]:(0.28444353588382226,0.021037927314571835,0.694518536801606)

对于大的数据的训练和推断,仿此过程即可。

注记

  1. 似然函数的下界 VLB 是非凸的函数,因此在做 EM 推断或者变分 EM 推断时存在局部极小问题,选择好的参数初始值对于得到合理的结果非常重要,需要细致的选择初始值;
  2. Muitinomial Mixture 模型可以对特征为计数的观察数据进行建模;
  3. ODPS-GRAPH 上可以实现任何 EM 相关算法的推断,其过程可仿效这篇文章的实现。
0 0
原创粉丝点击