机器学习入门算法及其java实现-EM(Expectation Maxium)算法

来源:互联网 发布:社区控烟网络会议记录 编辑:程序博客网 时间:2024/06/05 01:52

1、算法基本原理:

  • EM算法一般用于存在隐变量或潜在变量的概率模型,可以算是一种含有隐的概率模型参数的极大似然估计法;
  • 假设θ为模型的参数, 为模型的观测数据,γ模型中存在的隐藏变量,EM算法的是通过最大化观测数据logP(Y|θ)的方法来求出θ的极大似然估计,可以转化为表达式:θ^=argmaxθ(logP(Y|θ))
  • 经过转化,可以将问题转化为最大化E(γ)的问题,即θ^=argmaxγ(E(γ))

2、算法推导过程:

  • 根据极大似然法的原理,我们的目标是极大化观测数据Y关于参数θ的对数似然函数,即:
    L(θ)=logP(Y|θ)=logγP(Y,γ|θ)
    =log(λP(Y|γ,θ)P(Z|θ))
  • 因为EM算法是通过迭代的办法逐步接近极大L(θ)的,假设在第 i次迭代后θi,此我们希望能够使L(θ)L(θ(i))0
    L(θ)L(θi)=log(γP(Y|γ,θ)P(γ|θ))log(P(Y|θi)
    =log(γP(γ|Y,θi)P(Y|γ,θ)P(γ|θ)P(γ|Y,θi))logP(Y|θi)
    γP(γ|Y,θi)log(P(Y|γ,θ)P(γ|θ)P(γ|Y,θi))logP(Y|θi)
    =γP(γ|Y,θi)log(P(Y|γ,θ)P(γ|θ)P(γ|Y,θi)logP(Y|θi))
    B(θ,θi)=L(θi)+γP(γ|Y,θi)log(P(Y|γ,θ)P(γ|θ)P(γ|Y,θi)logP(Y|θi))
    L(θ)B(θ,θi)
    即函数B(θ,θi)L(θ)的一个下界,并且易知:L(θi)B(θi,θi),因此可以使B(θ,θi)增大的θ也可以使L(θ)增大,为了使L(θ)有尽可能大的增大,选择θi+1使B(θ,θi)打到极大,即:
    θ(i+1)=argmaxθB(θ,θi)
    上式可以改写为:
    θ(i+1)=argmaxθ(L(θi)+γP(γ|Y,θi)log(P(Y|γ,θ)P(γ|θ)P(γ|Y,θi)logP(Y|θi)))
    =argmaxθγP(γ|Y),θilog(P(Y|γ,θ)P(γ|θ))
    =argmaxθγP(γ|Y,θi)log(P(Y,γ|θ))
    =argmaxθQ(θ,θi)
    3、EM算法收敛性证明:根据对数函数函数性质:若P(Y|θi)单调递增且收敛到某一值则Q(θ,θi)收敛。单调性:
    P(Y|θ)=P(Y,γ|θ)P(γ|Y,θ)
    logP(Y|θ)=logP(Y,γ|θ)logP(γ|Y,θ)
    Q(θ,θi)=γlogP(Y,γ|θ)P(γ|Y,θi)
    H(θ,θi)=γlogP(γ|Y,θ)P(γ|Y,θi)
    于是对数似然函数可以写成:
    logP(Y|θ)=Q(θ,θi)H(θ,θi)
    上式中θ分别取为θiθi+1并相减,有:
    logP(Y|θi+1)logP(Y|θi)
    =[Q(θi+1,θi)Q(θi,θi)][H(θi+1,θi)H(θi,θi)]
    因为θi+1使Q(\theta,\theta^{i})达到极大,所以有:
    Q(θi+1,θi)Q(θi,θi)0
    其第2项,可以推导得出:
    H(θi+1,θi)H(θi,θi)
    =γ(logp(γ|Y,θi+1)P(γ|Y,θi))P(γ|Y,θi)
    log(γP(γ|Y,θi+1)P(γ|Y,θi)P(γ|Y,θi))
    =log(P(γ|Y,θi+1))=0
    又因为P(Y|θi)有界,所以L(θi)=log(P(Y|θi))收敛到某一值L

4、算法步骤:

  • 选择参数的初值θ0,开始迭代;
  • E步:记θi为第i次迭代参数θ的估计值,在第i次迭代的E步,计算:
    Q(θ,θi)=Eγ[logP(Y,γ|θ)|Y,θ]
    =γlog(P(Y,γ|θ)P(γ|Y,θi))
  • M步:求使Q(θ,θi)极大化的θ,确定第i+1次迭代的参数的估计值θi+1
    θi+1=argmaxθQ(θ,θi)

    -重复第E步和第M步,直到对于较小的正数ξ1ξ2,若满足 :
    ||θi+1θi||ξq
    ||Q(θi+1,θi)Q(θi,θi)||ξ2
    则停止迭代。
package binorandom;public class binomain {    public static void main(String[] args) {        int[] b=new int[1000];        for (int i=0;i<1000;i++){        b[i]=binorandom.getBinomial(1, 0.4);        }        int[] a=new int[1000];        for ( int i=0;i<999;i++){            if (b[i]==1){                a[i]=binorandom.getBinomial(1,0.5);            }            if(b[i]==0){                a[i]=binorandom.getBinomial(1,0.6);            }            System.out.print(a[i]+" ");        }        System.out.print(a[999]);    }}package binorandom;public class binorandom {    public static int getBinomial(int n, double p) {         int x = 0;         for(int i = 0; i < n; i++) {         if(Math.random() < p)          x++;         }         return x;        }}//生成数据集合package EMpackage;import java.util.Scanner;public class EMmain {    public static void main(String[] args){        System.out.println("请输入观测值个数");        Scanner input=new Scanner(System.in);        int datanumber=input.nextInt();        System.out.println("请输入观测值(0或者1):");        Scanner input1=new Scanner(System.in);        int[] obdata=new int[datanumber];        for(int i=0; i<datanumber;i++){         obdata[i]=input1.nextInt();        }        System.out.println("您输入的是:"+" ");        for (int b=0;b<datanumber-1;b++){            System.out.print(obdata[b]+" ");        }        System.out.println(obdata[datanumber-1]);        double[] original=new double[3];        original=ori.original();        double eq=ori.eq();        System.out.println("初始条件为:"+" "+original[0]+" "+original[1]+" "+original[2]);        System.out.println("停止条件为:"+" "+eq);        input1.close();        input.close();        double[] original1=new double[3];        original1=EM.original1(original, obdata, datanumber);           int x=0;        while (euclid(minus(original1,original))>eq){        original=original1;        original1=EM.original1(original,obdata,datanumber);        x=x+1;        }        System.out.println("pi="+original1[0]+"\n"+"p="+original1[1]+"\n"+"q="+original1[2]+"\n"+x);    }private static double euclid(double[] x) {    double sum=0;    for (int i=0;i<3;i++){        sum=sum+Math.pow(x[i], 2);    }    double euclid=Math.sqrt(sum);    return euclid;}private static double[] minus(double[] x,double[] y) {    double[] temp=new double[3];    for (int i=0;i<3;i++){        temp[i]=x[i]-y[i];    }    return temp;  }}package EMpackage;public class EM {    public static double[] original1(double[] original,int[] obdata,int datanumber){        double[] ybl=new double[datanumber];        double[] uybl=new double[datanumber];        double[] l=new double[datanumber];        double datanumber1=datanumber;        for (int i=0;i<datanumber;i++){            ybl[i]=(original[0]*Math.pow(original[1],obdata[i] )*Math.pow(1-original[1],1-obdata[i] ))/(original[0]*Math.pow(original[1],obdata[i])*Math.pow((1-original[1]),(1-obdata[i]))+(1-original[0])*Math.pow(original[2],obdata[i])*Math.pow((1-original[2]),(1-obdata[i])));            uybl[i]=obdata[i]*(original[0]*Math.pow(original[1],obdata[i] )*Math.pow(1-original[1],1-obdata[i] ))/(original[0]*Math.pow(original[1],obdata[i])*Math.pow((1-original[1]),(1-obdata[i]))+(1-original[0])*Math.pow(original[2],obdata[i])*Math.pow((1-original[2]),(1-obdata[i])));            l[i]=1;        }        double[] original1=new double[3];        original1[0]=(1/datanumber1)*sum(ybl,datanumber);        original1[1]=(sum(uybl,datanumber)/sum(ybl,datanumber));        original1[2]=(sum(ybl,datanumber)-sum(uybl,datanumber))/(sum(l,datanumber)-sum(ybl,datanumber));        return original1;       }    private static double sum(double[] ybl,int datanumber) {        double sum=0;        for (int i=0;i<datanumber;i++){            sum=sum+ybl[i];        }        return sum;    }}package EMpackage;import java.util.Scanner;public class ori{     public static double[] original(){        System.out.println("请输入初始条件条件:"+" ");    Scanner input=new Scanner(System.in);    double original[]=new double[3];    for(int d=0; d<3;d++){         original[d]=input.nextDouble();        }    return original;    }    public static double eq(){        System.out.println("请输入停止条件:"+" ");        Scanner input=new Scanner(System.in);        double eq=input.nextDouble();        return eq;    } }//EM算法主程序

实验结果及实例分析
这里写图片描述
多次运算结果对比:
原始系数pi,p,q(0.4、0.5、0.6):

初始迭代系数 (0.5、0.5、0.5) (0.4、0.4、0.4) (0.4、0.4、0.5) (0.5、0.4、0.6) (0.4、0.5、0.4) (0.5、0.4、0.5) (0.5、0.6、0.4) 运算结果 (0.5、0.73、0.32) (0.54、0.84、0.19) (0.55、0.84、0.19) (0.56、0.77、0.29) (0.56、0.85、0.19) (0.56、0.77、0.30) (0.56、0.76、0.30)

原始系数pi,p,q(0.5、0.5、0.5):

初始迭代系数 (0.4、0.4、0.4) (0.3、0.4、0.4) (0.4、0.4、0.5) (0.4、0.4、0.6) (0.4、0.5、0.4) (0.5、0.4、0.3) (0.5、0.6、0.4) 运算结果 (0.49、0.7、0.23) (0.49、0.85、0.14) (0.5、0.76、0.23) (0.49、0.75、0.24) (0.49、0.76、0.23) (0.49、1.02、-0.02) (0.5、0.53、0.45)

从以上两表不难看出EM算法受到初始迭代值的影响十分大,但是其优点在于需要的迭代次数少,收敛速度十分迅速。

原创粉丝点击