EM算法学习

来源:互联网 发布:软件开发基础知识 编辑:程序博客网 时间:2024/06/02 02:24

EM算法学习记录

  • 开始学机器学习，٩(๑>◡<๑)۶

算法思路

参照：https://www.cnblogs.com/jerrylead/archive/2011/04/06/2006936.html

实现代码

import math

import numpy

# Background: Jensen's inequality. If f''(x) >= 0 (f is convex) then
# E[f(X)] >= f(E[X]); this is the inequality EM's lower bound is built on.


def _as_points(values):
    """Return ``values`` as a 2-D float array holding one point per row."""
    points = numpy.asarray(values, dtype=float)
    if points.ndim == 1:
        # a flat list of scalars is treated as n one-dimensional points
        points = points.reshape(-1, 1)
    return points


class gaussian_distribution:
    """Isotropic Gaussian N(mu, sigma**2 * I), used as one mixture component.

    Attributes:
        mu: mean; a scalar for a 1-D model or a list of coordinates.
        sigma: standard deviation, shared by every dimension.
        probability: mixing weight of this component inside a mixture.
    """

    def __init__(self, mu, sigma):
        # Stored per instance: assigning to the class instead would make every
        # component share a single mean, sigma and weight.
        self.mu = mu
        self.sigma = sigma
        self.probability = 0.0

    def get_log_value(self, x):
        """Return the natural log of the density at ``x``.

        ``x`` is a scalar (1-D model) or a list/tuple/ndarray of coordinates.
        Raises ZeroDivisionError when ``sigma`` is 0.
        """
        if isinstance(x, (list, tuple, numpy.ndarray)):
            diff = numpy.asarray(x, dtype=float) - numpy.asarray(self.mu, dtype=float)
            dim = diff.size
            sq_dist = float(numpy.sum(diff * diff))
        else:
            dim = 1
            sq_dist = float((x - self.mu) ** 2)
        variance = float(self.sigma) ** 2
        # log of (2*pi*sigma^2)^(-d/2) * exp(-|x - mu|^2 / (2*sigma^2))
        return -sq_dist / (2.0 * variance) - 0.5 * dim * math.log(2.0 * math.pi * variance)

    def get_value(self, x):
        """Return the probability density of this Gaussian at ``x``."""
        return math.exp(self.get_log_value(x))

    def change_mu(self, mu):
        self.mu = mu

    def change_sigma(self, sigma):
        self.sigma = sigma

    def get_average_value(self, value_list):
        """Return the per-coordinate mean of the points in ``value_list`` as a list.

        Raises ValueError if ``value_list`` is empty.
        """
        points = _as_points(value_list)
        if points.shape[0] == 0:
            raise ValueError("value_list is empty")
        return points.mean(axis=0).tolist()

    def get_sigma(self, value_list):
        """Return the isotropic standard deviation of ``value_list`` about its mean.

        Computed as sqrt(mean squared distance to the mean / dimension), the
        maximum-likelihood sigma of a single N(mu, sigma**2 * I).
        Raises ValueError if ``value_list`` is empty.
        """
        points = _as_points(value_list)
        if points.shape[0] == 0:
            raise ValueError("value_list is empty")
        centred = points - points.mean(axis=0)
        return math.sqrt(float(numpy.sum(centred * centred)) / centred.size)


class EM_gauss:
    """Fit a mixture of ``k`` isotropic Gaussians to ``db`` with EM.

    Attributes:
        data_list: the points being fitted (one sequence of coordinates per point).
        gauss_dis_list: the gaussian_distribution components of the mixture.
    """

    def __init__(self, k, db):
        if k < 1:
            raise ValueError("k must be at least 1")
        if len(db) == 0:
            raise ValueError("db must contain at least one point")
        self.data_list = db
        self.gauss_dis_list = []
        points = _as_points(db)
        sigma = gaussian_distribution(0, 1).get_sigma(db)
        if sigma <= 0:
            sigma = 1.0  # every point identical: any positive width is valid
        for index in self._seed_indices(points, k):
            gauss_dis = gaussian_distribution(points[index].tolist(), sigma)
            gauss_dis.probability = 1.0 / k
            self.gauss_dis_list.append(gauss_dis)

    @staticmethod
    def _seed_indices(points, k):
        """Pick ``k`` row indices of ``points`` by farthest-point selection.

        Distinct starting means break the mixture's symmetry: components that
        start identical receive identical responsibilities and stay identical.
        When k exceeds the number of distinct points, indices repeat.
        """
        chosen = [0]
        nearest = numpy.sum((points - points[0]) ** 2, axis=1)
        while len(chosen) < k:
            index = int(numpy.argmax(nearest))
            chosen.append(index)
            nearest = numpy.minimum(nearest, numpy.sum((points - points[index]) ** 2, axis=1))
        return chosen

    def _log_weighted(self, point):
        """Return log(weight_j * density_j(point)) for every component j.

        Components with zero weight or zero sigma contribute -inf (probability 0).
        """
        logs = []
        for gauss_dis in self.gauss_dis_list:
            if gauss_dis.probability > 0 and gauss_dis.sigma > 0:
                logs.append(math.log(gauss_dis.probability) + gauss_dis.get_log_value(point))
            else:
                logs.append(-math.inf)
        return logs

    @staticmethod
    def _posterior(logs):
        """Normalise log-weights into probabilities using the log-sum-exp trick."""
        peak = max(logs)
        if peak == -math.inf:
            # no component can explain the point; split it evenly
            return [1.0 / len(logs)] * len(logs)
        weights = [math.exp(value - peak) for value in logs]
        total = sum(weights)
        return [weight / total for weight in weights]

    def _responsibilities(self, points):
        """Return an (n, k) array whose row i holds P(component j | point i)."""
        rows = [self._posterior(self._log_weighted(point)) for point in points]
        return numpy.array(rows, dtype=float).reshape(len(points), len(self.gauss_dis_list))

    def get_data_probability(self, data_index, model_index):
        """Return the posterior P(model ``model_index`` | point ``data_index``)."""
        point = _as_points(self.data_list)[data_index]
        return self._posterior(self._log_weighted(point))[model_index]

    def get_model_probability(self, model_index):
        """Return the expected number of points owned by a model (sum of its posteriors)."""
        points = _as_points(self.data_list)
        return float(self._responsibilities(points)[:, model_index].sum())

    def round(self):
        """Run one EM iteration, updating every component in place.

        E-step: responsibilities are computed once from the current parameters,
        so updating one component cannot leak into another component's update.
        M-step: maximum-likelihood mean, isotropic sigma and weight per component.
        A component that owns no data gets weight 0 and sigma 0 so that
        ``rounds`` can prune it.
        """
        points = _as_points(self.data_list)
        count, dim = points.shape
        resp = self._responsibilities(points)
        owned = resp.sum(axis=0)
        for j, gauss_dis in enumerate(self.gauss_dis_list):
            mass = float(owned[j])
            if mass <= 0:
                gauss_dis.probability = 0.0
                gauss_dis.sigma = 0.0
                continue
            weights = resp[:, j]
            mean = numpy.dot(weights, points) / mass
            centred = points - mean
            sq_dist = numpy.sum(centred * centred, axis=1)
            variance = float(numpy.dot(weights, sq_dist)) / (dim * mass)
            gauss_dis.mu = mean.tolist()
            gauss_dis.sigma = math.sqrt(variance)
            gauss_dis.probability = mass / count

    def _prune_degenerate(self):
        """Remove components with sigma <= 0 and renormalise the remaining weights.

        Raises RuntimeError if every component has collapsed.
        """
        alive = [gauss_dis for gauss_dis in self.gauss_dis_list if gauss_dis.sigma > 0]
        if not alive:
            raise RuntimeError("every Gaussian component has collapsed")
        if len(alive) != len(self.gauss_dis_list):
            self.gauss_dis_list[:] = alive  # in place, like the original `del`
            total = sum(gauss_dis.probability for gauss_dis in alive)
            if total > 0:
                for gauss_dis in alive:
                    gauss_dis.probability /= total

    def _flatten_parameters(self):
        """Return [*mu_0, sigma_0, *mu_1, sigma_1, ...] for convergence checks."""
        flat = []
        for gauss_dis in self.gauss_dis_list:
            flat.extend(numpy.ravel(gauss_dis.mu).tolist())
            flat.append(gauss_dis.sigma)
        return flat

    def rounds(self, error=0.0001, max_iter=1000):
        """Iterate ``round`` until no parameter moves by more than ``error``.

        Components whose sigma has collapsed to 0 are removed before each
        iteration, and the remaining weights are renormalised. ``max_iter``
        caps the number of iterations; ``None`` means no cap.
        """
        iteration = 0
        while max_iter is None or iteration < max_iter:
            self._prune_degenerate()
            record_last = self._flatten_parameters()
            self.round()
            iteration += 1
            if self.check_gauss_equal(record_last, error):
                break
        self._prune_degenerate()

    def check_gauss_equal(self, list_a, error=0.0001):
        """Return True when ``list_a`` matches the current parameters within ``error``.

        ``list_a`` lists each component's mean coordinates followed by its
        sigma, in component order. Lists of different length (for example,
        after a component was pruned) are never considered equal.
        """
        list_b = self._flatten_parameters()
        if len(list_a) != len(list_b):
            return False
        return all(abs(a - b) <= error for a, b in zip(list_a, list_b))


class point_data_reader:
    """Reader for point files.

    Expected format: one point per line, coordinates separated by whitespace
    (e.g. ``num_x num_y num_z ...``).
    """

    file_name = str()

    def __init__(self, file_name):
        self.file_name = file_name

    def get_data_list(self, num_lost):
        """Return the points in the file as a list of lists of floats.

        Fields that are not finite numbers are replaced by ``num_lost``.
        Blank lines are skipped.
        """
        db = []
        # Read-only mode; the values are parsed with float(), never eval(),
        # so file content cannot execute code.
        with open(self.file_name, 'r', encoding='utf-8', errors='replace') as handle:
            for line in handle:
                fields = line.split()
                if not fields:
                    continue
                point = []
                for field in fields:
                    try:
                        value = float(field)
                    except ValueError:
                        value = num_lost
                    else:
                        if not math.isfinite(value):
                            value = num_lost  # 'nan'/'inf' count as missing data
                    point.append(value)
                db.append(point)
        return db


if __name__ == '__main__':
    # num_lost is the value substituted for missing / unparseable fields.
    db = point_data_reader('text.dat').get_data_list(num_lost=1.0)
    # 2 is the number of Gaussian components.
    em = EM_gauss(2, db)
    em.rounds()
    for gauss_dis in em.gauss_dis_list:
        print(gauss_dis.mu, gauss_dis.sigma, gauss_dis.probability)
原创粉丝点击