EM算法--python代码和注意事项

来源:互联网 发布:公务员上岸经验知乎 编辑:程序博客网 时间:2024/05/16 00:47
一,EM-PYTHON 代码:
#! /usr/bin/env python
# coding=utf-8
"""EM for a one-dimensional Gaussian mixture (first draft, 2016-04-28).

Author: zhaojiong

Samples data from a mixture of three 1-D Gaussians and re-estimates the
means, variances and mixing weights with expectation-maximisation.

The functions share their state through module-level globals that
``init_em`` creates:

* ``mod_num``            -- number of mixture components (3)
* ``mod_prob_arr``       -- mixing weights used to generate the data
* ``mod_prob_arr_test``  -- current estimate of the mixing weights
* ``theta_mat``          -- generating parameters, one row per component
                            as ``[mean, variance]``
* ``theta_mat_temp``     -- current estimate of ``theta_mat``
* ``x_mat``              -- generated samples, shape ``(x_num, 1)``
* ``x_prob_mat``         -- responsibilities, shape ``(n_samples, mod_num)``
"""
import copy

import numpy as np

try:
    import matplotlib.pyplot as plt
except ImportError:  # only plot_data() needs matplotlib
    plt = None

# ``test`` is an interactive demo helper; it is kept out of the star-import
# surface so that test runners collecting ``test*`` names do not pick it up.
__all__ = ["init_em", "plot_data", "e_step", "m_step", "main_run"]

# Lower bound applied to estimated variances so that a component which
# collapses onto a single point cannot produce a division by zero.
_MIN_VARIANCE = 1e-6


def init_em(x_num=2000):
    """Create the shared EM state and draw ``x_num`` mixture samples.

    Returns the samples as an array of shape ``(x_num, 1)``; the same array
    is stored in the global ``x_mat``.
    """
    global mod_num, mod_prob_arr, x_prob_mat, theta_mat, theta_mat_temp
    global x_mat, mod_prob_arr_test

    mod_num = 3
    mod_prob_arr = [0.3, 0.4, 0.3]       # true mixing weights of the three states
    mod_prob_arr_test = [0.3, 0.3, 0.4]  # initial guess for the weights
    x_prob_mat = np.zeros((x_num, mod_num))
    # Each row is [mean, variance].
    theta_mat = np.array([[30.0, 4.0],
                          [80.0, 9.0],
                          [180.0, 3.0]])
    theta_mat_temp = np.array([[20.0, 3.0],
                               [60.0, 7.0],
                               [80.0, 2.0]])

    # Choose the component of every sample with ONE categorical draw.  Using
    # a fresh uniform number for each branch of an if/elif chain would skew
    # the proportions away from mod_prob_arr.
    labels = np.random.choice(mod_num, size=x_num, p=mod_prob_arr)
    noise = np.random.normal(size=x_num)
    samples = noise * np.sqrt(theta_mat[labels, 1]) + theta_mat[labels, 0]
    x_mat = samples.reshape(x_num, 1)
    return x_mat


def plot_data(x_mat):
    """Show a 200-bin histogram of the first column of ``x_mat``."""
    if plt is None:
        raise ImportError("matplotlib is required for plot_data")
    plt.hist(x_mat[:, 0], 200)
    plt.show()


def e_step(x_arr):
    """E-step: compute the responsibility of each component for each sample.

    Uses the current estimates ``theta_mat_temp`` / ``mod_prob_arr_test``.
    The result (shape ``(n_samples, mod_num)``) is stored in the global
    ``x_prob_mat`` and returned.
    """
    global x_prob_mat

    x = np.asarray(x_arr, dtype=float)[:, 0]
    means = theta_mat_temp[:, 0]
    variances = theta_mat_temp[:, 1]
    weights = np.asarray(mod_prob_arr_test, dtype=float)

    # Work with log-likelihoods: for a sample far from every current mean
    # the plain exponentials all underflow to 0 and the sample would be
    # dropped from the fit.  The 1/sqrt(2*pi) factor is omitted because it
    # cancels in the ratio.
    with np.errstate(divide="ignore"):  # log(0) for a zero weight is fine
        log_prior = np.log(weights) - 0.5 * np.log(variances)
    log_numer = log_prior - (x[:, None] - means) ** 2 / (2.0 * variances)

    # Subtract the row maximum before exponentiating (log-sum-exp trick).
    row_max = log_numer.max(axis=1, keepdims=True)
    row_max[~np.isfinite(row_max)] = 0.0
    numer = np.exp(log_numer - row_max)
    denom = numer.sum(axis=1, keepdims=True)

    # A row with no likelihood mass at all gets all-zero responsibilities.
    x_prob_mat = numer / np.where(denom > 0.0, denom, 1.0)
    return x_prob_mat


def m_step(x_arr):
    """M-step: update means, variances and weights from ``x_prob_mat``.

    ``theta_mat_temp`` and ``mod_prob_arr_test`` are updated in place;
    ``theta_mat_temp`` is returned.  The variance update uses the freshly
    computed mean of the same iteration.
    """
    x = np.asarray(x_arr, dtype=float)[:, 0]
    n_samples = x.shape[0]
    if n_samples == 0:
        return theta_mat_temp  # nothing to learn from

    mass = x_prob_mat.sum(axis=0)  # total responsibility per component
    for j in range(mod_num):
        if mass[j] <= 0.0:
            # No sample is attributed to this component: keep its previous
            # mean/variance instead of dividing by zero.
            mod_prob_arr_test[j] = 0.0
            continue
        resp = x_prob_mat[:, j]
        mean_j = np.dot(resp, x) / mass[j]
        var_j = np.dot(resp, (x - mean_j) ** 2) / mass[j]
        theta_mat_temp[j, 0] = mean_j
        theta_mat_temp[j, 1] = max(var_j, _MIN_VARIANCE)
        mod_prob_arr_test[j] = mass[j] / n_samples
    return theta_mat_temp


def main_run(iter_num=500, Epsilon=0.0001, data_num=2000):
    """Generate ``data_num`` samples and run EM on them.

    Stops after ``iter_num`` iterations, or earlier once the summed absolute
    change of all means and variances drops below ``Epsilon``.  Returns the
    estimated ``[mean, variance]`` rows.
    """
    global x_prob_mat, theta_mat_temp

    init_em(data_num)
    for i in range(iter_num):
        previous = copy.deepcopy(theta_mat_temp)
        x_prob_mat = e_step(x_mat)
        theta_mat_temp = m_step(x_mat)
        if np.sum(np.abs(theta_mat_temp - previous)) < Epsilon:
            print("第 %d 次迭代退出" % i)
            break
    return theta_mat_temp


def test(data_num):
    """Generate ``data_num`` samples and display their histogram."""
    testdata = init_em(data_num)
    plot_data(testdata)

二,注意事项

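  2.1  高斯分布定义和概率密度 (注意参数形式)

         $$N(x \mid \mu, \sigma^2) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(x-\mu)^2}{2\sigma^2}\right)$$  (注意:代码中 theta 每一行的第二个参数是方差 $\sigma^2$,而不是标准差 $\sigma$。)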

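  2.2  迭代公式 (注意方差的迭代和模型占比的参数迭代)

          E 步: $$\gamma_{ij} = \frac{\pi_j \, N(x_i \mid \mu_j, \sigma_j^2)}{\sum_{k} \pi_k \, N(x_i \mid \mu_k, \sigma_k^2)}$$  M 步: $$\mu_j^{new} = \frac{\sum_i \gamma_{ij} x_i}{\sum_i \gamma_{ij}}, \qquad (\sigma_j^2)^{new} = \frac{\sum_i \gamma_{ij} (x_i - \mu_j^{new})^2}{\sum_i \gamma_{ij}}, \qquad \pi_j^{new} = \frac{\sum_i \gamma_{ij}}{N}$$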

其中,方差的迭代使用的是本轮新求出的均值(而不是上一轮的均值)。

  附件:相关EM算法的博客参考:

            1:http://blog.csdn.net/abcjennifer/article/details/8170378

            2:http://www.cnblogs.com/jerrylead/archive/2011/04/06/2006924.html

三 ,HMM ,EM变形,应用场景:时间序列预测  待续。。。

0 0
原创粉丝点击