机器学习作业6
来源:互联网 发布:淘宝渔具代理 编辑:程序博客网 时间:2024/06/13 22:06
EM算法和朴素贝叶斯
上节课老师讲解了EM算法,然后要求我们使用EM算法完成一个低配版的朴素贝叶斯分类器。说实话网上的EM算法介绍的都比较抽象,对于数学并不是很好的我来说,看起来遇到了很大的障碍。对于EM算法的详细介绍可以参考 emma_zhang 的博文 机器学习之EM算法,下面我简单讲一下自己对于朴素贝叶斯分类器中EM算法的理解。
EM算法和朴素贝叶斯
在朴素贝叶斯中,数据的各个分量是相互独立的。如果有一维分量因为某些原因不可观测或数据缺失(举书上的例子:在西瓜数据集中,如果西瓜的根蒂掉落了,就无法观测到西瓜根蒂的形态,因此有些样本的“根蒂”维度的数据是缺失的),那么可以使用EM算法来对该分量的参数进行估计。这个缺失的分量就叫隐变量。
考虑较为简单的情况,即隐变量仅有一维,且数据服从高斯分布。在实验中,假设最后一维为隐变量,则将该维度数据从数据集中分离出来,设为
其中,前面的1到m个数据为已知数据,后面的
得到第
重复上述步骤直到收敛,即可得到较为精确的
Python代码实现
弄明白原理之后,用代码实现还是不难的。本次实验使用UCI的Iris数据集,数据维度为4,设前面3个维度数据正常,第4个维度存在数据缺失(50%),则首先对数据进行预处理,然后构造低配版的朴素贝叶斯分类器,在对最后一维数据进行处理时,仅使用其中一半的数据,然后使用EM算法估算其均值和方差。代码如下所示:
import numpy as npdef em_algorithm(data, valid_count, total_count, eps=1e-4): # data: 输入的一维数组,valid_count: 有效样本数 # total_count: 样本总数,eps: 收敛所需精度 # avg: 隐变量的均值,theta: 隐变量的方差 valid_data = data[0:valid_count] avg = np.sum(valid_data) / total_count theta = np.sum(np.square(valid_data)) / total_count - avg while True: s1 = np.sum(valid_data) + avg * (total_count - valid_count) s2 = np.sum(np.square(valid_data)) + (avg * avg + theta) * (total_count - valid_count) new_avg = s1 / total_count new_theta = s2 / total_count - new_avg * new_avg if new_avg - avg <= eps and new_theta - theta <= eps: break else: avg, theta = new_avg, new_theta return avg, thetadef elderly_man(dtype1, dtype2, latent_idx): # build NAIVE bayesian avg, var = [], [] for idx in range(latent_idx): # 对隐变量之前的数据,正常计算其均值和方差 # dim_type1 和 dim_type2 表示多维数据中的一维 dim_type1, dim_type2 = dtype1[:, idx], dtype2[:, idx] avg.append([np.average(dim_type1), np.average(dim_type2)]) var.append([np.var(dim_type1), np.var(dim_type2)]) # 假设维度 3 的数据为隐变量,只有一半的数据是可观测的 # 使用EM算法估计其均值和方差 em_avg_type1, em_var_type1 = em_algorithm(data_type1[:40, latent_idx], 20, 40) em_avg_type2, em_var_type2 = em_algorithm(data_type2[:40, latent_idx], 20, 40) # 将估计得到的均值和方差加入到数组中,并返回 avg.append([em_avg_type1, em_avg_type2]) var.append([em_var_type1, em_var_type2]) return avg, vardef calc_gaussian(x, avg, var): # 高斯分布函数 t = 1.0 / np.sqrt(2 * np.pi * var) return t * np.exp(-np.square(x - avg) / (2.0 * var))if __name__ == '__main__': data_str = open('Data/iris.data').readlines() data_type1 = np.ndarray([50, 4], np.float32) data_type2 = np.ndarray([50, 4], np.float32) for idx in range(50): data_type1[idx] = data_str[idx].strip('\n').split(',')[0:4] for idx in range(50, 100): data_type2[idx - 50] = data_str[idx].strip('\n').split(',')[0:4] a, v = elderly_man(data_type1[:40], data_type2[:40], 3) # 构造测试数据集,correct_times 表示测试结果准确的数据条数 data_test = np.concatenate((data_type1[40:], data_type2[40:])) correct_times = 0 for data_idx in range(len(data_test)): data = data_test[data_idx] # 数据集两类数据相同,因此先验概率均为0.5 val_type1, val_type2 = 0.5, 0.5 for idx in range(4): # 朴素贝叶斯计算 val_type1 *= calc_gaussian(data[idx], a[idx][0], v[idx][0]) val_type2 *= calc_gaussian(data[idx], a[idx][1], v[idx][1]) # 前10条数据为类型1,后10条数据为类型2 if val_type1 > val_type2 and data_idx < 10: correct_times += 1 elif val_type1 < val_type2 and data_idx >= 10: correct_times += 1 print("Number: %2d, Type1: %f, Type2: %f" % (data_idx + 1, val_type1, val_type2)) print("Accuracy: %.1f%%" % (correct_times * 5))
程序输出结果如下:
Number: 1, Type1: 3.068372, Type2: 0.000000
Number: 2, Type1: 0.006664, Type2: 0.000000
Number: 3, Type1: 0.614423, Type2: 0.000000
Number: 4, Type1: 0.001331, Type2: 0.000000
Number: 5, Type1: 0.026249, Type2: 0.000000
Number: 6, Type1: 1.754030, Type2: 0.000000
Number: 7, Type1: 2.557216, Type2: 0.000000
Number: 8, Type1: 2.103980, Type2: 0.000000
Number: 9, Type1: 3.413123, Type2: 0.000000
Number: 10, Type1: 5.123086, Type2: 0.000000
Number: 11, Type1: 0.000000, Type2: 0.363503
Number: 12, Type1: 0.000000, Type2: 0.513148
Number: 13, Type1: 0.000000, Type2: 0.429917
Number: 14, Type1: 0.000000, Type2: 0.000804
Number: 15, Type1: 0.000000, Type2: 0.582627
Number: 16, Type1: 0.000000, Type2: 0.450959
Number: 17, Type1: 0.000000, Type2: 0.642538
Number: 18, Type1: 0.000000, Type2: 0.743852
Number: 19, Type1: 0.000000, Type2: 0.000821
Number: 20, Type1: 0.000000, Type2: 0.630034
Accuracy: 100.0%
看得出来在简单数据集上,准确率还是很高的。那么这次作业就到这里了,源代码以及数据可以点击这里下载。完结撒花!
- 机器学习作业6
- 【机器学习】作业6-EM算法
- coursera 机器学习作业
- 机器学习作业1
- 机器学习作业2
- 机器学习作业笔记
- 机器学习作业3
- 机器学习作业4
- 机器学习作业5
- 机器学习作业7
- 机器学习作业8
- 【机器学习】作业8
- 机器学习作业9
- 机器学习技法第一次作业
- 机器学习技法作业7
- 机器学习技法第二次作业
- 机器学习基石第二次作业
- 机器学习技法第三次作业
- Google NMT 阅读笔记
- Trafodion 集成R实现数据可视化
- mybatis 连续日期统计
- 小随笔——PHP数据库数据处理、MD5
- 60个电子行业技术网站
- 机器学习作业6
- 【sail】第三篇MybatisPlus的配置以及FreeMarker的配置
- 数据库查出来的明明是时间返回却变成一串无规律的数字。解决方法 /** * 时间戳转时间格式 * @param jsondate 得到的number 型时间数 */ function
- 调试maxxaudio 新唐科技效果IC I2C通讯程序
- 海莲花团伙利用MSBuild机制免杀样本分析
- UUID 16bit和128bit切换
- 小马哥---山寨高仿苹果x 主板型号s306 机型图示展示
- JS实现倒计时操作
- 设计模式是什么(对设计模式的理解)