EM算法训练GMM的Matlab实现过程(总结)

来源:互联网 发布:数组push方法 编辑:程序博客网 时间:2024/05/02 06:11

         最近看到论文中很多地方提到EM算法,之前对EM算法只是大概知道是一个参数优化算法,而不知道具体的过程,通过阅读相关的资料,大概了解了其推导过程以及实现过程。

   GMM模型就是由若干个高斯分量相互组成的,通过混合的高斯模型来逼近样本的真实分布。

        GMM模型估计包括三个参数:混合权重,每个高斯函数的均值以及方差,他们的递推公式如下:

                 权重的递推公式如下:

          

            均值和方差的递推公式如下:

          

 

 

 

其中M为混合高斯数,n为训练的样本数

假设现在有训练样本data集合,每一列为一个样本,行数代表样本的特征维数,采用Matlab实现EM算法的训练过程如下:

 

%演示EM训练算法的实现过程clc;clear all;load data;[dim,Num]=size(data);max_iter=10;%最大迭代次数min_improve=1e-4;% 提升的精度Ngauss=3;%混合高斯函数个数Pw=zeros(1,Ngauss);%保存权重mu= zeros(dim,Ngauss);%保存每个高斯分类的均值,每一列为一个高斯分量sigma= zeros(dim,dim,Ngauss);%保存高斯分类的协方差矩阵fprintf('采用K均值算法对各个高斯分量进行初始化\n');[cost,cm,cv,cc,cs,map] = vq_flat(data, Ngauss);%聚类过程  map:样本所对应的聚类中心mu=cm;%均值初始化for j=1:Ngauss   gauss_labels=find(map==j);%找出每个类对应的标签   Pw(j)= length(gauss_labels)/length(map);%类别为1的样本个数占总样本的个数    sigma(:,:,j)  = diag(std(data(:,gauss_labels),0,2)); %求行向量的方差,只取对角线,其他特征独立,并将其赋值给对角线endlast_loglik = -Inf;%上次的概率% 采用EM算法估计GMM的各个参数if Ngauss==1,%一个高斯函数不需要用EM进行估计    sigma(:,:,1)  = sqrtm(cov(data',1));    mu(:,1)       = mean(data,2);else     sigma_i  = squeeze(sigma(:,:,:));          iter= 0;     for iter = 1:max_iter          %E 步骤          %求每一样样本对应于GMM函数的输出以及每个高斯分量的输出,          sigma_old=sigma_i;          %E步骤。。。。。          for i=1:Ngauss          P(:,i)= Pw(i) * p_single(data, squeeze(mu(:,i)), squeeze(sigma_i(:,:,i)));%每一个样本对应每一个高斯分量的输出          end          s=sum(P,2);%        for j=1:Num            P(j,:)=P(j,:)/s(j);        end       %%%Max步骤        Pw(1:Ngauss) = 1/Num*sum(P);%权重的估计        %均值的估计        for i=1:Ngauss            sum1=0;            for j=1:Num             sum1=sum1+P(j,i).*data(:,j);            end          mu(:,i)=sum1./sum(P(:,i));        end               %方差估计按照公式类似         %sigma_i         if((sum(sum(sum(abs(sigma_i- sigma_old))))<min_improve))             break;        end                     end         end


子函数:

function p = p_single(x, mu, sigma)%返回高斯函数的值 [dim,N]=size(x); p=zeros(1,N); for i=1:N     p(i)= 1/(2*pi*abs(det(sigma)))^(length(mu)/2)*exp(-0.5*(x(:,i)-mu)'*inv(sigma)*(x(:,i)-mu)); end


 

 

注明:鉴于大家都要求vq_flat代码,这里就不一一发送到邮箱了,提供下载地http://download.csdn.net/detail/xiaoding133/5501211