EM算法求解混合伯努利模型

来源:互联网 发布:手游运营数据分析 编辑:程序博客网 时间:2024/05/28 09:31

本文记录了EM算法求解混合伯努利分布的推导,并提供了matlab实验代码。

混合伯努利分布

单个D维伯努利分布的分布率为:

p(x|μ)=d=1Dμxdd(1μd)1xd,

这里x=(x1,,xD)TD维0-1向量,μ=(μ1,,μD)T为对应维度上事件发生的概率。
混合伯努利分布是指由K个单个D维伯努利分布混合构成的分布,分布率如下:
p(x|μ,π)=k=1Kπkp(x|μk),

这里μ={μ1,,μK},π={π1,,πK},
p(x|μk)=d=1Dμxdkd(1μkd)1xd.

EM算法混合伯努利分布

现在我们有一个来自于混合伯努利分布的数据集X={x1,,xN},我们要最大化其似然函数:

maxμ,πlnp(X|μ,π)=n=1Nln{k=1Kπkp(xn|μk)}.

由于ln函数中的求和项,无法得到其闭式解,我们用EM算法求其一个局部数值解。关于EM算法的内容见上一篇博客。对于样本xn,我们构造一个隐变量zn用来表示xn来自于哪个模型。也即,zn=k当且仅当xn来自于第k个伯努利模型。令Z={z1,,zn},则X,Z联合分布为:
p(X,Z|μ,π)=n=1Nπznp(xn|μzn).

M步我们要求解
maxμ,πZp(Z|X,μ¯,π¯)lnp(X,Z|μ,π)=n=1Nk=1Kp(k|xn,μ¯,π¯)lnp(xn|μk,πk)(1)

观察到p(k|xn,μ¯,π¯)可看做关于n,k的常数,因此令
γnk=p(k|xn,μ¯,π¯)=πkp(xn|μk)i=1Kπip(xn|μi).

则式(1)中的优化问题可以进一步写作
maxμ,π=f(π,μ)=s.t.n=1Nk=1Kγnk[lnπk+xTnlnμk+(1xn)Tln(1μk)]k=1Kπk=1,0<μk<1(k=1,,K).

这里lnx表示对向量x逐元素操作。对μk求偏导并令其等于0,
μk=n=1Nγnk(xnμk1xn1μk)=n=1Nγnk(xnμk)μk(1μk)=0.

这里xyxy分别表示两个向量的表示逐元素除法和逐元素乘法。解得
μk=1Nkn=1Nγnkxn,

其中Nk=n=1Nγnk。 为了求解π,构造拉格朗日函数
L(π,λ)=f(π,μ)+λ(k=1Kπk1),

利用KKT条件得
πkL(π,λ)=1πkn=1Nγnk+λ=0k=1Kπk=1.k=1,,K

利用n=1Nk=1Kγnk=N解得
πk=NkN.

实验部分

这里我们利用MINIST手写数字数据集进行实验,这里有处理好的matlab数据(60000*785矩阵,每一行为一个样本,其中28*28=784为灰度图片拉成的行向量,最后一列为对应数字标签)。将灰度向量二值化就得到了0-1向量,我们把这个0-1向量当做由混合伯努利分布生成的一个样本,并且希望最后学习到的混合分布中的不同模型对应不同的数字向量概率分布。由于计算后验概率分布时涉及到概率的连乘操作,要注意精度问题。

这里写图片描述

function[]=main()    RR = 1;            CC = 3;             R = 10;    C = 10;    r = 28;    c = 28;    Eps = 1e-5;    scale = 0.8;        %每个digit图片的初始大小为28*28,这里为了计算效率和精度长宽都缩放为0.8倍    digit = [1 3 4];    %选用的digits    K = RR*CC;          %混合的模型个数    N = R*C*K;          %总共样本个数,即每个digit有R*C个样本    D = r*c;    load('Data.mat');   %导入MNIST数据集,Data为60000*785的矩阵,每一行对应一个样本,样本最后一个维度    label = Data(:,D+1);    index = zeros(N,1);    for k = 1:K;        t = find(label==digit(k));        index((k-1)*R*C+1:k*R*C) = (t:t+R*C-1)';    end    X = Data(index, 1:D);    tX = zeros(N, ceil(r*scale)*ceil(c*scale));    for n = 1:N %缩放并二值化        t = imresize(reshape(X(n,:), r, c), [ceil(r*scale) ceil(c*scale)]);        tX(n, :) = reshape(t, 1, ceil(r*scale)*ceil(c*scale));        tX(n,:) = im2bw(tX(n, :), graythresh(tX(n,:)));    end    X = tX;    r = ceil(r*scale);    c = ceil(c*scale);    D = r*c;    Pi = rand(K, 1);    Pi = Pi/(ones(1, K)*Pi);    Mu = rand(K, D);    logLikeBar = calLogLikeBound(X, K, Pi, Mu);    fprintf('Initial LogLikelihoodBound: %.6f\n', logLikeBar);    cnt = 0;    Gamma = zeros(N, K);    while true        %E step:         for n=1:N            t = 0;            %tic;            for k = 1:K                digits(100);                Gamma(n, k) = vpa(exp(vpa(X(n, :))*log(Mu(k, :)')+vpa(ones(1,D)-X(n,:))*log(ones(D,1)-Mu(k, :)'))); %这里用了ln变换和高精度运算                t = t+ Gamma(n, k);            end            Gamma(n,:) = Gamma(n,:)/t;            %toc;         end         Nk = (ones(1, N)*Gamma)';         %M step:         Mu = (X'*Gamma*(diag(Nk))^(-1))';         Pi = Nk/N;         Mu(Mu <= 0) = 1e-5;         Mu(Mu >= 1) = 1-1e-5;         cnt = cnt+1;         logLike = calLogLikeBound(X, K, Pi, Mu);         relGap = abs(logLike-logLikeBar)/abs(logLikeBar);         disp(['Iteration: ', num2str(cnt), ' RelativeGap: ', sprintf('%.6f LogLikelihoodBound: %.6f', relGap, logLike)]);         logLikeBar = logLike;         label = zeros(N, 1);         for k = 1:K            label(Gamma(:,k)'==max(Gamma'))=k;         end         showRes(RR, CC, R, C, r, c, X, cnt, label, lambda);        if relGap < Eps            break;        end    endendfunction [logLike] = calLogLikeBound(X, K, Pi, Mu) %计算对数似然函数值    logLike = 0;    N = size(X,1);    D = size(X,2);    for n =1:N        t = exp(X(n,:)*log(Mu')+(ones(1, D)-X(n, :))*log(ones(D, K)-Mu'))';        logLike = logLike+log(Pi'*t);    endendfunction [] = showRes(RR, CC, R, C, r, c, X, cnt, label, lambda) %绘制图片    lambda = 0.35;  %控制图片颜色程度    colorNum = 20;    mycolor = colorcube(colorNum);    for II = 1:RR        for JJ = 1:CC           for I = 1:R               for J = 1:C                   t1 = zeros(r, c, 3);                   t2 = (II-1)*CC*(R*C)+(JJ-1)*(R*C)+(I-1)*C+J;                   for i = 1:r                        t1(i,:,1)  = X(t2, (i-1)*c+1:i*c);                        t1(i,:,2) = t1(i,:,1);                        t1(i,:,3) = t1(i,:,1);                   end                   t1(:,:,1) = t1(:,:,1)*lambda+mycolor(mod(label(t2)-1, colorNum)+1, 1)*(1-lambda);                   t1(:,:,2) = t1(:,:,2)*lambda+mycolor(mod(label(t2)-1, colorNum)+1, 2)*(1-lambda);                   t1(:,:,3) = t1(:,:,3)*lambda+mycolor(mod(label(t2)-1, colorNum)+1, 3)*(1-lambda);                   tr = (II-1)*R*r+(I-1)*r;                   tc = (JJ-1)*C*c+(J-1)*c;                   RGBMat(tr+1:tr+r, tc+1:tc+c,:)=t1;               end           end        end    end    figure(cnt);    imshow(RGBMat);    saveas(gcf,sprintf('%d.jpg', cnt));end

参考文献

Christopher M.. Bishop. Pattern recognition and machine learning. pp. 423-450

0 0
原创粉丝点击