机器学习(四)高斯混合模型

来源:互联网 发布:巴黎贝甜 知乎 编辑:程序博客网 时间:2024/05/16 08:38

高斯混合模型聚类算法

原文地址http://blog.csdn.NET/hjimce/article/details/45244603 

作者:hjimce

高斯混合算法是EM算法的一个典型的应用,EM算法的推导过程这里不打算详解,直接讲GMM算法的实现。之前做图像分割grab cut 算法的时候,只知道把OpenCV中的高斯混合模型代码复制下来,然后封装成类使用,学的比较浅。结果没过几天发现高斯混合算法又忘了差不多了,于是用matlab去亲自写过一遍,终于发现了高斯混合模型的奥义。我的理解是高斯混合模型其实是进化版的k均值算法,因此学习高斯混合模型,最好还是把k均值算法写过一遍。高斯混合与k均值的本质区别在于权值问题,k均值采用的是均匀权值,而高斯混合的权值需要根据高斯模型的概率进行确定。

    开始学习高斯混合模型,需要先简单复习一下单高斯模型的参数估计方法,描述一个高斯模型其实就是要计算它的均值、协方差矩阵(一维空间为方差,二维以上称之为协方差矩阵):

假设有数据集X={x1,x2,x3……,xn},那么用这些数据来估计单高斯模型参数的计算公式为

OK,开始写代码前,先用matlab生成数据集,然后在进行聚类:

利用matlab的生成高斯模型数据集X:

[c++] view plain copy
  1. mu = [2 3];  
  2. SIGMA = [1 0; 0 2];  
  3. r1 = mvnrnd(mu,SIGMA,1000);  
  4. plot(r1(:,1),r1(:,2),'r+');  


然后利用上面的估计方法计算均值,和协方差是否满足均值为[2 3],协方差为[1 0; 0 2];测试代码如下,r2、covmat即为计算结果

[c++] view plain copy
  1. [m n]=size(r1);  
  2. center=sum(r1)./m;  
  3. r2(:,1)=r1(:,1)-center(1);  
  4. r2(:,2)=r1(:,2)-center(2);  
  5. covmat=1/m*r2'*r2;  


先把单高斯模型的函数写好,因为高斯混合模型是它的进化版,计算高斯混合模型过程中需要调用单高斯模型参数估计,写好代码后面才不会乱掉。开始高斯混合建模之前,我先用matlab生成一个测试数据集data,如下图,然后再进行算法测试。

生成数据集代码如下:

[c++] view plain copy
  1. %生成测试数据  
  2. mu = [2 3];%测试数据1  
  3. SIGMA = [1 0; 0 2];  
  4. r1 = mvnrnd(mu,SIGMA,100);  
  5. plot(r1(:,1),r1(:,2),'.');  
  6. hold on;  
  7. mu = [10 10];%测试数据2  
  8. SIGMA = [ 1 0; 0 2];  
  9. r2 = mvnrnd(mu,SIGMA,100);  
  10. plot(r2(:,1),r2(:,2),'.');  
  11. mu = [5 8];%测试数据3  
  12. SIGMA = [ 1 0; 0 2];  
  13. r3= mvnrnd(mu,SIGMA,100);  
  14. plot(r3(:,1),r3(:,2),'.');  
  15. data=[r1;r2;r3];  




ok,数据生成完毕,接着我们正式开始高斯混合算法解析,先看一下高斯混合模型的建模求参步骤:

高斯混合模型的求解,说得简单一点就是要求解高斯模型中的均值与协方差,现在我们要把上述的数据分成3类,那么我们就是要求解3个均值及其对应的3个协方差矩阵。先讲一下总体步骤,高斯混合模型包含3个步骤:

a.初始化各个高斯模型的参数,及每个高斯模型的权重;

b.根据各个高斯模型的参数及其权重,计算每个点属于各个高斯模型的权重,计算公式为:


其中:,Wj是每个高斯模型在这个模型所占用得权重。这个公式说的简单一点就是每个高斯模型的权重与其概率的乘积,这样计算出来就相当于每个高斯模型在每个数据点中的所占用的比例。

c.更新各个高斯模型的均值与方差,计算公式如下:



d.更新各个高斯模型的总权重,计算公式如下:

   

其实第c、d两个步骤,无所谓顺序,你完全可以总权重更新放在各个模型参数更新之前。迭代过程就是b、c、d三个步骤进行更新就可以了。OK,接着结合上面的公式写一写代码。

(1)初始化高斯模型参数。

这一步初始化,在实际应用中一般是先通过k均值算法进行初始聚类,然后根据聚类结果进行计算初始化参数。不过这里我为了测试,我们选择随机初始化,这样才能看出GMM算法到底能不能实现聚类。

我这里各个高斯模型初始均值(中心)的初始化方法选择跟k均值的初始化方法一样,也就是随机选择k个点位置作为k个高斯模型的初始均值。然后协方差矩阵的初始化,我选择单位矩阵,具体代码如下:


[c++] view plain copy
  1. [m n]=size(data);  
  2. kn=3;  
  3. countflag=zeros(1,kn);  
  4. tdata=cell(1,kn);%建立3个空矩阵  
  5. mu=cell(1,kn);%建立3个空矩阵  
  6. sigma=cell(1,kn);%建立3个空矩阵  
  7. %方案2 随机初始化参数  
  8. for i=1:kn  
  9.     mu{1,i}=data(i*10,:);  
  10.     sigma{1,i}=eye(2,2);  
  11.     weightp(i)=1/kn;  
  12. end  


(2)计算各个模型在各个点的权重值

这一步是计算每个数据点属于各个高斯混合的概率,说白了就是计算权值:


[c++] view plain copy
  1. pro_ij=zeros(m,kn);%存储每个点属于每个类的概率  
  2. for i=1:m  
  3.     sumpk=0;  
  4.     for j=1:kn  
  5.         pk(j)=weightp(j)*GSMPro(mu{1,j},sigma{1,j},data(i,:));  
  6.         sumpk=sumpk+pk(j);  
  7.     end  
  8.     for j=1:kn  
  9.         pro_ij(i,j)=pk(j)/sumpk;  
  10.     end  
  11. end  

(3)步骤c 更新参数

[c++] view plain copy
  1. for j=1:kn  
  2.      [mu{1,j},sigma{1,j}]=WeightGSM(data,pro_ij(:,j));   
  3.  end  

(4)步骤d 更新各个模型的总权重

 

[c++] view plain copy
  1. for j=1:kn  
  2.       weightp(j)=sum(pro_ij(:,j))/m;  
  3.   end  

然后把步骤2、3、4的代码放在循环语句中进行迭代就ok了。最后贴一下整份代码:

1、脚本文件:

[c++] view plain copy
  1. close all;  
  2. clear;  
  3. clc;  
  4. %生成测试数据  
  5. mu = [2 3];%测试数据1  
  6. SIGMA = [1 0; 0 2];  
  7. r1 = mvnrnd(mu,SIGMA,100);  
  8. plot(r1(:,1),r1(:,2),'.');  
  9. hold on;  
  10. mu = [10 10];%测试数据2  
  11. SIGMA = [ 1 0; 0 2];  
  12. r2 = mvnrnd(mu,SIGMA,100);  
  13. plot(r2(:,1),r2(:,2),'.');  
  14.   
  15. mu = [5 8];%测试数据3  
  16. SIGMA = [ 1 0; 0 2];  
  17. r3= mvnrnd(mu,SIGMA,100);  
  18. plot(r3(:,1),r3(:,2),'.');  
  19.   
  20.   
  21. data=[r1;r2;r3];  
  22.   
  23. [m n]=size(data);  
  24. kn=3;  
  25. countflag=zeros(1,kn);  
  26. tdata=cell(1,kn);%建立10个空矩阵  
  27. mu=cell(1,kn);%建立10个空矩阵  
  28. sigma=cell(1,kn);%建立10个空矩阵  
  29. % 方案1 初始化采用kmeans,做参数的初步估计  
  30. % Idx=kmeans(data,kn);  
  31. % figure(2);%绘制初始化结果  
  32. % hold on;  
  33. for i=1:m  
  34. %     if Idx(i)==1  
  35. %         plot(data(i,1),data(i,2),'.y');  
  36. %     elseif Idx(i)==2  
  37. %          plot(data(i,1),data(i,2),'.b');  
  38. %     end  
  39. % end  
  40. for i=1:m  
  41. %    tdata{1,Idx(i)}=[tdata{1,Idx(i)};data(i,:)];  
  42. % end  
  43. for i=1:kn  
  44. %     [mu{1,i},sigma{1,i}]=GSMData(tdata{1,i});  
  45. % end  
  46. for i=1:kn  
  47. %     [trow,tcol]=size(tdata{1,i});  
  48. %     weightp(i)=trow/m;  
  49. % end  
  50. %方案2 随机初始化  
  51. for i=1:kn  
  52.     mu{1,i}=data(i*10,:);  
  53.     sigma{1,i}=eye(2,2);  
  54.     weightp(i)=1/kn;  
  55. end  
  56.   
  57.   
  58.   
  59. it=1;  
  60.   
  61. while it<1000  
  62.     %E步 计算每个点处于每个类的概率  
  63.     pro_ij=zeros(m,kn);%存储每个点属于每个类的概率  
  64.     for i=1:m  
  65.         sumpk=0;  
  66.         for j=1:kn  
  67.             pk(j)=weightp(j)*GSMPro(mu{1,j},sigma{1,j},data(i,:));  
  68.             sumpk=sumpk+pk(j);  
  69.         end  
  70.         for j=1:kn  
  71.             pro_ij(i,j)=pk(j)/sumpk;  
  72.         end  
  73.     end   
  74.     %M步   
  75.     for j=1:kn  
  76.         [mu{1,j},sigma{1,j}]=WeightGSM(data,pro_ij(:,j));   
  77.     end  
  78.     %更新权值  
  79.     for j=1:kn  
  80.         weightp(j)=sum(pro_ij(:,j))/m;  
  81.     end  
  82.     sumw=sum(weightp);  
  83.     it=it+1;  
  84. end  
  85. for i=1:m  
  86.     [value index]=max(pro_ij(i,:));  
  87.     Idx(i)=index;  
  88. end  
  89. figure(2);  
  90. hold on;  
  91. for i=1:m  
  92.     if Idx(i)==1  
  93.         plot(data(i,1),data(i,2),'.y');  
  94.     elseif Idx(i)==2  
  95.          plot(data(i,1),data(i,2),'.b');  
  96.     elseif Idx(i)==3  
  97.          plot(data(i,1),data(i,2),'.r');  
  98.     end  
  99. end  
  100.   
  101.   
  102. % figure(3);  
  103. % %px=gmmstd(data,3);  
  104. for i=1:m  
  105. %     [value index]=max(px(i,:));  
  106. %     Idx(i)=index;  
  107. % end  
  108. % hold on;  
  109. for i=1:m  
  110. %     if Idx(i)==1  
  111. %         plot(data(i,1),data(i,2),'.y');  
  112. %     elseif Idx(i)==2  
  113. %          plot(data(i,1),data(i,2),'.b');  
  114. %     elseif Idx(i)==3  
  115. %          plot(data(i,1),data(i,2),'.r');  
  116. %     end  
  117. % end  
  118. %单高斯模型参数估计  
  119. % [m n]=size(r1);  
  120. % center=sum(r1)./m;  
  121. % r2(:,1)=r1(:,1)-center(1);  
  122. % r2(:,2)=r1(:,2)-center(2);  
  123. % covmat=1/m*r2'*r2;  

2、相关函数

[c++] view plain copy
  1. function [ mu ,sigma ] = WeightGSM(data,weight)  
  2.     %计算加权均值  
  3.     [m n]=size(data);  
  4.     sumweight=sum(weight);  
  5.     weightdata=[];  
  6.     for i=1:m  
  7.         weightdata(i,:)=weight(i)*data(i,:);  
  8.     end  
  9.     center=sum(weightdata)/sumweight;  
  10.     %计算加权协方差  
  11.     for i=1:n  
  12.        r2(:,i)=data(:,i)-center(i);  
  13.     end  
  14.     for i=1:m  
  15.         r1(i,:)=weight(i)*r2(i,:);  
  16.     end  
  17.      
  18.     sigma=1/sumweight*r1'*r2;  
  19.     mu=center;  
  20. end  
  21. function [pro] = GSMPro(mu ,sigma,x)  
  22.   pro=exp(-0.5*(x-mu)*inv(sigma)*(x-mu)');  
  23.   pro=1/sqrt(2*pi*det(sigma))*pro;  
  24. end  

看以下最后的测试结果:


*******************作者:hjimce     联系qq:1393852684 更多资源请关注我的博客:http://blog.csdn.net/hjimce                  原创文章,转载请注明出处 *******************


0 0
原创粉丝点击