Gibbs Sampling for Gaussian Mixture Model

来源:互联网 发布:华盛顿大学知乎 编辑:程序博客网 时间:2024/05/21 17:58

MCMC是我不太容易理解的一个技术点,需要多做一些联系。


MLaPP第24.2.3节介绍了一个使用Gibbs Sampling确定Gaussian Mixture Model(GMM)的范例,但过于简单;同时代码库中提供了一个范例函数gaussMissingFitGibbs,但并未详细介绍如何使用。


我在此范例程序的基础上,修改完成一个针对GMM数据的聚类程序。


下列程序与范例相比gaussMissingFitGibbs相比,1. 删除了x数据有缺失的部分代码;2. 完成了完整的GMM聚类过程(因此需要引入Dirichlet抽样);3. 增加了自动生成聚类数的代码(但是,这部分不太稳定,还需要继续研究)。


在这个过程中,除了理解Gibbs Sampling算法之外,个人认为最重要的是找到必须的抽样函数,包括Dirichlet抽样和IW抽样,这两部分都是使用了MLaPP提供的范例函数。


输出结果如下:



代码(主程序)

clear all;close all;rng(2);%% ParametersN = 1000;                                   % 总数据量D = 2;                                      % 数据维度K = 3;                                      % 类别数目Pi = rand([K,1]);                           % 随机生成各类比例Pi = Pi/sum(Pi);% 数据初始化,与之前的EM聚类程序相同mu = [1 2; -6 2; 7 1];sigma=zeros(K,D,D);sigma(1,:,:)=[2 -1.5; -1.5 2];sigma(2,:,:)=[5 -2.; -2. 3];sigma(3,:,:)=[1 0.1; 0.1 2];%% Data Generation and displayx = zeros(N,D);PzCDF1 = 0;figure(1); subplot(2,3,1); hold on;figure(2); hold on;for ii = 1:K,    PzCDF2 = PzCDF1 + Pi(ii);    PzIdx1 = round(PzCDF1*N);    PzIdx2 = round(PzCDF2*N);    x(PzIdx1+1:PzIdx2,:) = mvnrnd(mu(ii,:), squeeze(sigma(ii,:,:)), PzIdx2-PzIdx1);    PzCDF1 = PzCDF2;        figure(1); subplot(2,3,1); hold on;    plot(x(PzIdx1+1:PzIdx2,1),x(PzIdx1+1:PzIdx2,2),'o');end;[~, tmpidx] = sort(rand(N,1));x = x(tmpidx,:);                        % shuffle datafigure(1); subplot(2,3,1);plot(mu(:,1),mu(:,2),'k*');axis([-10,10,-4,8]);title('1.Generated Data (original)', 'fontsize', 20);xlabel('x1');ylabel('x2');figure(2);plot(x(:,1),x(:,2),'o');figure(2);plot(mu(:,1),mu(:,2),'k*');axis([-10,10,-4,8]);title('Generated Data (original)');xlabel('x1');ylabel('x2');fprintf('\n$$ Data generation and display completed...\n');save('GMM_data.mat', 'x', 'K');%% clustering: Matlab k-meansclear all;load('GMM_data.mat');[N,D] = size(x);k_idx=kmeans(x,K);                  % 使用Matlab现有k-means算法figure(1); subplot(2,3,2); hold on;for ii=1:K,    idx=(k_idx==ii);    plot(x(idx,1),x(idx,2),'o');    center = mean(x(idx,:));    plot(center(1),center(2),'k*');end;axis([-10,10,-4,8]);title('2.Clustering: Matlab k-means', 'fontsize', 20);xlabel('x1');ylabel('x2');fprintf('\n$$ K-means clustering completed...\n');%% clustering: EM% Refer to pp.351, MLaPP% Pw: weight% mu: u of Gaussion distribution% sigma: Covariance matrix of Gaussion distribution% r(i,k): responsibility; rk: sum of r over i% px: p(x|mu,sigma)% 上面的聚类结果作为EM算法的初始值Pw=zeros(K,1);for ii=1:K,    idx=(k_idx==ii);    Pw(ii)=sum(idx)*1.0/N;    mu(ii,:)=mean(x(idx,:));    sigma(ii,:,:)=cov(x(idx,1),x(idx,2));end;px=zeros(N,K);for jj=1:100, % 简单起见,直接循环,不做结束判断    for ii=1:K,        px(:,ii)=GaussPDF(x,mu(ii,:),squeeze(sigma(ii,:,:)));        % 使用Matlab自带的mvnpdf,有时会出现sigma非正定的错误    end;        % E step    temp=px.*repmat(Pw',N,1);    r=temp./repmat(sum(temp,2),1,K);    % M step    rk=sum(r);    Pw=rk'/N;    mu=r'*x./repmat(rk',1,D);    for ii=1:K        sigma(ii,:,:)=x'*(repmat(r(:,ii),1,D).*x)/rk(ii)-mu(ii,:)'*mu(ii,:);    end;end;% display[~,clst_idx]=max(px,[],2);figure(1); subplot(2,3,3); hold on;for ii=1:K,    idx=(clst_idx==ii);    plot(x(idx,1),x(idx,2),'o');    center = mean(x(idx,:));    sigma(ii,:,:)=cov(x(idx,1),x(idx,2));    plot(center(1),center(2),'k*');end;axis([-10,10,-4,8]);title('3.Clustering: GMM/EM', 'fontsize', 20);xlabel('x1');ylabel('x2');fprintf('\n$$ Gaussian Mixture using EM completed...\n');%% Variational Bayes EM% Refer to ch.10.2, PRML% x: visible variable, N * D% z: latent variable, N * K% z: Pz, Ppi, alp0, alpk%    Pz = P(z|pi);                                          PRML(10.37)%    Ppi = Dir(pi|alp0)                                     PRML(10.39)% x: Px, Pz, Ppi, mu, lambda, m0, beta0, W0, nu0%    Px = P(x|z, mu, lambda);        高斯分布               PRML(10.38)%    P(mu, lambda) = P(mu|lambda)*P(lambda)                PRML(10.40)%        = N(mu|m0, (beta0*lambda)^-1) * Wi(lambda|W0, nu0)% rho: N*K,定义参见PRML(10.46)% r: N*K, responsibility; 归一化之后的rho,定义参见PRML(10.49)% N_k: sum of r over n                    定义参见PRML(10.51)% xbar_k:                                 定义参见PRML(10.52)% S_k                                     定义参见PRML(10.53)clear all;load('GMM_data.mat');[N,D] = size(x);K = 6;                  % 增加分类数,利用VBEM自动选择分类数k_idx=kmeans(x,K);      % 使用Matlab自带的k-means聚类,结果作为VBEM的初始值for ii=1:K,    idx=(k_idx==ii);    mu(ii,:) = mean(x(idx,:));    sigma(ii,:,:)=cov(x(idx,1),x(idx,2));    px(:,ii)=GaussPDF(x,mu(ii,:),squeeze(sigma(ii,:,:)));    % 使用Matlab自带的mvnpdf,有时会出现sigma非正定的错误,特使用自编函数GaussPDFend;% 初始化,具体定义参见PRML式(10.40)alp0 = 0.0001;          % alpha0,应<<1,以实现类别数自动筛选m0 = 0;beta0 = rand()+0.5;         % 拍脑袋初始化W0 = squeeze(mean(sigma));W0inv = pinv(W0);nu0 = D*2;                  % 拍脑袋初始化S_k = zeros(K,D,D);W_k = zeros(K,D,D);E_mu_lmbd = zeros(N,K);     % 即PRML中式(10.64)的等号左侧r = px./repmat(sum(px,2),1,K);                  % N*KN_k = ones(1,K)*(-100);for ii = 1:1000,    % M-step    N_k_new = sum(r);                           % 1*K,式(11.51)    N_k_new(N_k_new<N/1000.0)=1e-4;             % 避免出现特别小或为零的Nk    if sum(abs(N_k_new-N_k))<0.001,                      break;  % early stop,如果Nk基本没变化了,则停止迭代    else        N_k = N_k_new;    end;        xbar_k = r'*x./repmat(N_k', 1, D);          % K*D,PRML式(10.52)    for jj = 1:K,        dx = x-repmat(xbar_k(jj,:), N, 1);      % N*D        S_k(jj,:,:) = dx'*(dx.*repmat(r(:,jj),1,D))/N_k(jj); % D*D,PRML式(10.53)    end;        alp_k = alp0 + N_k;         % PRML式(10.58)    beta_k = beta0 + N_k;       % PRML式(10.60)    m_k = (beta0*m0 + repmat(N_k',1,D).*xbar_k)./...        repmat(beta_k',1,D);    % K*D,PRML式(10.61)    for jj = 1:K,        dxm = xbar_k(jj,:)-m0;        Wkinv = W0inv + N_k(jj)*squeeze(S_k(jj,:,:)) + ...            dxm'*dxm*beta0*N_k(jj)/(beta0+N_k(jj));        W_k(jj,:,:) = pinv(Wkinv);           % K*D*D,PRML式(10.62)    end;    nu_k = nu0 + N_k;                        % 1*K,PRML式(10.63)        % E-step: 迭代计算r    alp_tilde = sum(alp_k);    E_ln_pi = psi(alp_k) - psi(alp_tilde);      % PRML式(10.66)    E_ln_lambda = D*log(2)*ones(1,K);               for jj = 1:D,        E_ln_lambda = E_ln_lambda + psi((nu_k+1-jj)/2);     end;    for jj = 1:K,        E_ln_lambda(jj) = E_ln_lambda(jj) + ...            log(det(squeeze(W_k(jj,:,:))));     % PRML式(10.65)        dxm = x-repmat(m_k(jj,:),N,1);          % N*D        Dbeta = D/beta_k(jj);        for nn = 1:N,            E_mu_lmbd(nn,jj) = Dbeta+nu_k(jj)*(dxm(nn,:)*...                squeeze(W_k(jj,:,:))*dxm(nn,:)');   % PRML式(10.64)        end;    end;        rho = exp(repmat(E_ln_pi,N,1)+repmat(E_ln_lambda,N,1)/2-...        E_mu_lmbd/2);                           % PRML式(10.46)    r = rho./repmat(sum(rho,2),1,K);            % PRML式(10.49)    end;    [~,clst_idx]=max(r,[],2);figure(1); subplot(2,3,4); hold on;Nclst = 0;for ii=1:K,    idx=(clst_idx==ii);    if sum(idx)/N>0.01,        Nclst = Nclst+1;        plot(x(idx,1),x(idx,2),'o');        center = mean(x(idx,:));        plot(center(1),center(2),'k*');    end;end;fprintf('\n$$ GMM using VBEM completed, and totally %d clusters found.\n', Nclst);axis([-10,10,-4,8]);title('4.Clustering: Variational Bayes EM', 'fontsize', 20);xlabel('x1');ylabel('x2');%% Gibbs sampling for Gaussian Mixture Model% Latent Variables:% z: N*K, x所处的类别% mu:1*K, 第k类分布的均值% sig:K*D*D,第k类分布的方差% pz:N*K,z(i)属于K类的分布概率clear all; rng(1);load('GMM_data.mat');[N,D] = size(x);K = 6;                  % 增加分类数,自动选择分类数?Nth = N/K/20;            % 阈值threshold,当某一分类样本数少于此值时,抛弃此分类k0 = 0.0;dof = 0;Nsmpl = 60;            % 总抽样数Nbnin = 20;            % 前面需要扔掉的抽样数,只取后面的抽样(稳定后的抽样)z = zeros(N,K);         % z(i)中只有一个为1,其它为0pz = zeros(N,K);        % z(i)属于K类的概率,用于最终聚类pi = ones(1,K)/K;       % K类的总概率px = zeros(N,K);        % N(x(i)|mu(k),sigma(k))pxtmp = zeros(size(px));mu = zeros(K,D);sig = zeros(K,D,D);xbar = zeros(1,D);Nk = zeros(1,K);ClstMask = ones(1,K);  % Cluster MaskpiSamples = zeros(Nsmpl-Nbnin, K);muSamples = zeros(Nsmpl-Nbnin, K, D);sigSamples = zeros(Nsmpl-Nbnin, K, D, D);k_idx=kmeans(x,K);      % 使用Matlab自带的k-means聚类,结果作为GS的初始值figure(1); subplot(2,3,5); hold on;for ii=1:K,    idx=(k_idx==ii);    mu(ii,:) = mean(x(idx,:));    sig(ii,:,:)=cov(x(idx,1),x(idx,2));    px(:,ii)=GaussPDF(x,mu(ii,:),squeeze(sig(ii,:,:)));    % 使用Matlab自带的mvnpdf,有时会出现sigma非正定的错误,因此使用自编函数GaussPDF        plot(x(idx,1),x(idx,2),'o');    plot(mu(ii,1),mu(ii,2),'*');end;axis([-10,10,-4,8]);title('5.Clustering: Gibbs Sampling (initial)', 'fontsize', 20);xlabel('x1');ylabel('x2');for s = 1:Nsmpl,    % need to be refreshed: pi, px, mu, sig    pz_k = px.*repmat(pi,N,1);    [~,tmpidx] = max(pz_k,[],2);        z = zeros(N,K);    for ii = 1:K,        idx=(tmpidx==ii);        z(idx,ii) = 1;        Nk(ii) = sum(z(:,ii));        if Nk(ii)<Nth,      % 如果某一分类样本数少于阈值Nth,则抛弃            ClstMask(ii) = 0;            Nk(ii) = 0;            px(:,ii) = 0;            break;        end;                % 如下代码借鉴了MLaPP所附gaussMissingFitGibbs函数        xbar = mean(x(idx,:));        muPost = (Nk(ii)*xbar + k0*mu(ii,:)) / (Nk(ii) + k0);        sigPost = squeeze(sig(ii,:,:)) + Nk(ii)*cov(x(idx,:),1) + ...            Nk(ii)*k0/(Nk(ii)+k0) * (xbar - mu(ii,:))*(xbar - mu(ii,:))';        sig(ii,:,:) = invWishartSample(struct('Sigma', sigPost, 'dof', k0 + Nk(ii)));        mu(ii,:) = mvnrnd(muPost, squeeze(sig(ii,:,:))/(k0 + Nk(ii)));                px(:,ii)=GaussPDF(x,mu(ii,:),squeeze(sig(ii,:,:)));            end;    pi = dirichlet_sample(Nk).*ClstMask;    pi = pi/sum(pi);        if s > Nbnin,        muSamples(s - Nbnin,:,:) = mu;        sigSamples(s - Nbnin,:,:,:) = sig;        piSamples(s - Nbnin,:) = pi;    end;    end;muMean = squeeze(mean(muSamples));sigMean = squeeze(mean(sigSamples));piMean = squeeze(mean(piSamples)).*ClstMask;for ii = 1:K,    if ClstMask(ii)==1,        px(:,ii)=GaussPDF(x,muMean(ii,:),squeeze(sigMean(ii,:,:)));    else        px(:,ii)=0;    end;end;pz_k = px.*repmat(piMean,N,1);[~,tmpidx] = max(pz_k,[],2);figure(1); subplot(2,3,6); hold on;Nclst = 0;for ii = 1:K,    idx=(tmpidx==ii);    if sum(idx)>=Nth,        Nclst = Nclst + 1;        plot(x(idx,1),x(idx,2),'o');        plot(muMean(ii,1),muMean(ii,2),'*');    end;end;axis([-10,10,-4,8]);fprintf('\n$$ GMM using Gibbs sampling completed, and totally %d clusters found.\n\n', Nclst);title('6.Clustering: Gibbs Sampling (final)', 'fontsize', 20);xlabel('x1');ylabel('x2');

函数Dirichlet抽样:

function r = dirichlet_sample(a,n)% DIRICHLET_SAMPLE   Sample from Dirichlet distribution.%% DIRICHLET_SAMPLE(a) returns a probability vector sampled from a % Dirichlet distribution with parameter vector A.% DIRICHLET_SAMPLE(a,n) returns N samples, collected into a matrix, each % vector having the same orientation as A.%%   References:%      [1]  L. Devroye, "Non-Uniform Random Variate Generation", %      Springer-Verlag, 1986% This is essentially a generalization of the method for Beta rv's.% Theorem 4.1, p.594if nargin < 2  n = 1;endrow = (size(a, 1) == 1);a = a(:);y = gamrnd(repmat(a, 1, n),1);% randgamma is faster%y = randgamma(repmat(a, 1, n));%r = col_sum(y);r = sum(y,1);r(find(r == 0)) = 1;r = y./repmat(r, size(y, 1), 1);if row  r = r';endend

函数IW抽样:

function S = invWishartSample(model, n)% S(:, :, 1:n) ~ IW(model.Sigma, model.dof)% This file is from pmtk3.googlecode.comif nargin < 2, n = 1; endSigma = model.Sigma;dof   = model.dof;d     = size(Sigma, 1);C     = chol(Sigma)';S     = zeros(d, d, n);for i=1:n    if (dof <= 81+d) && (dof==round(dof))        Z = randn(dof, d);    else        Z = diag(sqrt(2.*randg((dof-(0:d-1))./2))); % randgamma改为randg        Z(utri(d)) = randn(d*(d-1)/2, 1);    end    [Q, R] = qr(Z, 0);    M = C / R;    S(:, :, i) = M*M';endend

函数(IW抽样函数需要用到的一个小函数,不知道用途)

function ndx = utri(d)% Return the indices of the upper triangluar part of a square d-by-d matrix% Does not include the main diagonal.% This file is from pmtk3.googlecode.comndx = ones(d*(d-1)/2,1);ndx(1+cumsum(0:d-2)) = d+1:-1:3;ndx = cumsum(ndx);end

函数GaussPDF(等效于Matlab自带的mvnpdf函数,之前用mvnpdf有时会出现非正定矩阵问题)

function p = GaussPDF(x, mu, sigma)[N, D] = size(x);x_u = x-repmat(mu, N, 1);p = zeros(N,1);for ii=1:N,    p(ii) = exp(-0.5*x_u(ii,:)*pinv(sigma)*x_u(ii,:)')/...        sqrt(det(sigma)*(2*pi)^D);end;end