EM算法及其应用(代码)

来源:互联网 发布:手机调音量软件 编辑:程序博客网 时间:2024/05/17 03:27

最近上模式识别的课需要做EM算法的作业,看了机器学习公开课及网上的一些例子,总结如下:(中间部分公式比较多,不能直接粘贴上去,为了方便用了截图,请见谅)

概要

适用问题

EM算法是一种迭代算法,主要用于计算后验分布的众数或极大似然估计,广泛地应用于缺损数据、截尾数据、成群数据、带有讨厌参数的数据等所谓不完全数据的统计推断问题。

优缺点

优点:EM算法简单且稳定,迭代能保证观察数据对数后验似然是单调不减的。

缺点:对于大规模数据和多维高斯分布,其总的迭代过程,计算量大,迭代速度易受影响;EM算法的收敛速度,非常依赖初始值的设置,设置不当,计算时的代价是相当大的;EM算法中的M-Step依然是采用求导函数的方法,所以它找到的是极值点,即局部最优解,而不一定是全局最优解。

原理

Jensen不等式










与ML估计的关系

EM算法的E-step是建立的下界,M-step是极大化下界,不断重复这两步直到收敛。最大似然估计是求参数的一种相当普遍的方法,但是它的使用场合是除了参数未知外,其它一切都已知,即不存在隐变量的情况。当存在隐变量时,我们可以通过EM算法,消掉隐变量,当期望取极大值时,最大似然函数也会在相同点取极大值。

具体模型应用

问题映射

       用两个高斯函数生成数据,这两个高斯密度函数分别为N(1,2)和N(20,3.5),分别生成160和240个数据,也就是说其在混合模型中的权重分别为0.4 和 0.6。将这些数据放入同一数组中,经过EM算法可得出混合模型的密度函数及其权重。

  1. clc;  
  2. clear;  
  3. %generate data fromnormal distribution.  
  4. mu1= 1;  
  5. sigma1= 2;  
  6. R1 = normrnd(mu1,sigma1,160,1);  
  7. mu2 = 20;  
  8. sigma2 = 3.5;  
  9. R2 = normrnd(mu2,sigma2,240,1);  
  10. %merge  
  11. R = [R1;R2];  
  12. %shuffel  
  13. r = randperm(size(R,1));  
  14. R=R(r,:);  
  15. figure,plot(R,'ro');  
  16. %用两个高斯函数生成数据,这两个高斯密度函数分别为N(0,2)和N(20,3.5),分别生成160和240个数据,也就是说其在混合模型中的权重分别为0.4 和 0.6.  
  17. [mu,sigma,phi] = mixGuassAnalysis(R,2,15);  

[plain] view plaincopy在CODE上查看代码片派生到我的代码片
  1. function   [mu,sigma,phi]  =mixGuassAnalysis(sampleMatrix,k,maxIteration,epsilon) 
  2. %UNTITLED Analysis the kguass distribution by the input matrix m 
  3. %sampleMatrix   the matrix of sample, in which each rowrepresents a sample. 
  4. %k  the number of guass distriubtion 
  5. %maxIteration   the max times of iteration,default 100. 
  6. %epsilon    the epsilon of loglikelihood,default 0.00001. 
  7. %check parameters 
  8. if nargin < 4 
  9.     epsilon = 0.00001; 
  10.     if nargin < 3 
  11.         maxIteration = 100; 
  12.     end 
  13. end 
  14. if k==1 
  15.     mu = mean(sampleMatrix); 
  16.     sigma = var(sampleMatrix); 
  17.     phi = 1; 
  18.     return; 
  19. end 
  20. [sampleNum,dimensionality] =size(sampleMatrix); 
  21. %init k guassdistribution 
  22. mu = zeros(k,dimensionality); 
  23. for i=1:1:dimensionality 
  24.     colVector = sampleMatrix(:,i); 
  25.     maxV = max(colVector); 
  26.     minV= min(colVector); 
  27.     mu(1,i) = minV; 
  28.     mu(k,i) = maxV; 
  29.     for j=2:1:k-1 
  30.         mu(j:i) =  minv+(j-1)*(maxV-minV)/(k-1); 
  31.     end 
  32. end 
  33. sigma =zeros(k,dimensionality,dimensionality); 
  34. for i=1:1:k 
  35.     d = rand(); 
  36.     sigma(i,:) = 10*d*eye(dimensionality); 
  37. end 
  38. phi = zeros(1,k); 
  39. for i=1:1:k 
  40.     phi(1,i) = 1.0/k; 
  41. end 
  42. %the weight of sample iis generated by guass distribution j 
  43. weight = zeros(sampleNum,k); 
  44. oldlikelihood = -inf; 
  45. for iter=1:maxIteration 
  46.     loglikelihood = 0; 
  47.     %E-step 
  48.     for i=1:1:sampleNum 
  49.         for j = 1:1:k 
  50.            weight(i,j)=mvnpdf(sampleMatrix(i,:),mu(j,:),reshape(sigma(j,:),dimensionality,dimensionality))*phi(j); 
  51.         end 
  52.         
  53.         sum = 0; 
  54.         for j = 1:1:k 
  55.             sum = sum+weight(i,j); 
  56.         end 
  57.         
  58.         loglikelihood = loglikelihood +log(sum); 
  59.      
  60.         for j = 1:1:k 
  61.             weight(i,j)=weight(i,j)/sum; 
  62.         end 
  63.     end 
  64.      
  65.     if abs(loglikelihood-oldlikelihood)<epsilon 
  66.         break; 
  67.     else 
  68.         oldlikelihood = loglikelihood; 
  69.     end 
  70.     %M-step 
  71.     
  72.     %updatephi 
  73.     for i=1:1:k 
  74.         sum = 0; 
  75.         for j=1:1:sampleNum 
  76.             sum = sum+weight(j,i); 
  77.         end 
  78.         phi(i) = sum/sampleNum; 
  79.     end 
  80.     
  81.     %updatemu 
  82.     for i=1:1:k 
  83.         sum = zeros(1,dimensionality); 
  84.         for j=1:1:sampleNum 
  85.             sum =  sum+weight(j,i)*sampleMatrix(j,:); 
  86.         end 
  87.         
  88.         mu(i,:) =  sum/(phi(i)*sampleNum); 
  89.     end 
  90.     mu1(iter) = mu(1); 
  91.     mu2(iter) = mu(2); 
  92.     %updatesigma 
  93.     for i=1:1:k 
  94.         sum =zeros(dimensionality,dimensionality); 
  95.         for j=1:1:sampleNum 
  96.             sum = sum+ weight(j,i)*(sampleMatrix(j,:)-mu(i,:))'*(sampleMatrix(j,:)-mu(i,:)); 
  97.         end 
  98.         sigma(i,:) = sum/(phi(i)*sampleNum); 
  99.     end 
  100.     
  101. end 
  102. sigma = sqrt(sigma); 
  103. end 
[plain] view plain copy
  1. function   [mu,sigma,phi]  =mixGuassAnalysis(sampleMatrix,k,maxIteration,epsilon)  
  2. %UNTITLED Analysis the kguass distribution by the input matrix m  
  3. %sampleMatrix   the matrix of sample, in which each rowrepresents a sample.  
  4. %k  the number of guass distriubtion  
  5. %maxIteration   the max times of iteration,default 100.  
  6. %epsilon    the epsilon of loglikelihood,default 0.00001.  
  7. %check parameters  
  8. if nargin < 4  
  9.     epsilon = 0.00001;  
  10.     if nargin < 3  
  11.         maxIteration = 100;  
  12.     end  
  13. end  
  14. if k==1  
  15.     mu = mean(sampleMatrix);  
  16.     sigma = var(sampleMatrix);  
  17.     phi = 1;  
  18.     return;  
  19. end  
  20. [sampleNum,dimensionality] =size(sampleMatrix);  
  21. %init k guassdistribution  
  22. mu = zeros(k,dimensionality);  
  23. for i=1:1:dimensionality  
  24.     colVector = sampleMatrix(:,i);  
  25.     maxV = max(colVector);  
  26.     minV= min(colVector);  
  27.     mu(1,i) = minV;  
  28.     mu(k,i) = maxV;  
  29.     for j=2:1:k-1  
  30.         mu(j:i) =  minv+(j-1)*(maxV-minV)/(k-1);  
  31.     end  
  32. end  
  33. sigma =zeros(k,dimensionality,dimensionality);  
  34. for i=1:1:k  
  35.     d = rand();  
  36.     sigma(i,:) = 10*d*eye(dimensionality);  
  37. end  
  38. phi = zeros(1,k);  
  39. for i=1:1:k  
  40.     phi(1,i) = 1.0/k;  
  41. end  
  42. %the weight of sample iis generated by guass distribution j  
  43. weight = zeros(sampleNum,k);  
  44. oldlikelihood = -inf;  
  45. for iter=1:maxIteration  
  46.     loglikelihood = 0;  
  47.     %E-step  
  48.     for i=1:1:sampleNum  
  49.         for j = 1:1:k  
  50.            weight(i,j)=mvnpdf(sampleMatrix(i,:),mu(j,:),reshape(sigma(j,:),dimensionality,dimensionality))*phi(j);  
  51.         end  
  52.          
  53.         sum = 0;  
  54.         for j = 1:1:k  
  55.             sum = sum+weight(i,j);  
  56.         end  
  57.          
  58.         loglikelihood = loglikelihood +log(sum);  
  59.       
  60.         for j = 1:1:k  
  61.             weight(i,j)=weight(i,j)/sum;  
  62.         end  
  63.     end  
  64.       
  65.     if abs(loglikelihood-oldlikelihood)<epsilon  
  66.         break;  
  67.     else  
  68.         oldlikelihood = loglikelihood;  
  69.     end  
  70.     %M-step  
  71.      
  72.     %updatephi  
  73.     for i=1:1:k  
  74.         sum = 0;  
  75.         for j=1:1:sampleNum  
  76.             sum = sum+weight(j,i);  
  77.         end  
  78.         phi(i) = sum/sampleNum;  
  79.     end  
  80.      
  81.     %updatemu  
  82.     for i=1:1:k  
  83.         sum = zeros(1,dimensionality);  
  84.         for j=1:1:sampleNum  
  85.             sum =  sum+weight(j,i)*sampleMatrix(j,:);  
  86.         end  
  87.          
  88.         mu(i,:) =  sum/(phi(i)*sampleNum);  
  89.     end  
  90.     mu1(iter) = mu(1);  
  91.     mu2(iter) = mu(2);  
  92.     %updatesigma  
  93.     for i=1:1:k  
  94.         sum =zeros(dimensionality,dimensionality);  
  95.         for j=1:1:sampleNum  
  96.             sum = sum+ weight(j,i)*(sampleMatrix(j,:)-mu(i,:))'*(sampleMatrix(j,:)-mu(i,:));  
  97.         end  
  98.         sigma(i,:) = sum/(phi(i)*sampleNum);  
  99.     end  
  100.      
  101. end  
  102. sigma = sqrt(sigma);  
  103. end  

迭代轨迹


问题

       本例用到的是伪随机组成的数据,本身具有一定的模型,最近结果可看出迭代次数只需要6次就达到下界。实际中将会有更多噪声及隐藏变量。

原创粉丝点击