机器学习实战--k近邻算法(续)
来源:互联网 发布:超旗网络 编辑:程序博客网 时间:2024/04/29 14:42
接续上次的k近邻算法,上一篇博文地址,这里用一个新的实例进行算法的验证。
一个手写数字识别系统,为了处理方便,书中已经将样本训练好,并转化为txt格式方便后续处理。具体格式如下:
这是0的其中一种表示方式。
我们的目标是输入一个类似的数字,系统能够识别出来即可。
————————————————————————————————————————
一、具体算法实现:
# coding: UTF-8import numpy as npimport osimport operator# 将txt格式的数字转化为1*1024的向量格式def img2vector(filename): return_vector = np.arange(1024) with open(filename) as f: for i in range(32): line = f.readline() for j in range(32): return_vector[32 * i + j] = int(line[j]) return return_vector# 分类算法实现def classify(input_vector, trained_mat, class_list, k=3): # 欧式距离计算 rows = trained_mat.shape[0] input_mat = np.tile(input_vector, (rows, 1)) diff_mat = input_mat - trained_mat squ_mat = diff_mat ** 2 sum_mat = squ_mat.sum(axis=1) d = sum_mat ** 0.5 # 根据距离排序,获得排序后的索引 sorted_d = d.argsort() # 创建用来统计某一类标签的字典 class_count = {} for i in xrange(k): class_label = class_list[sorted_d[i]] class_count[class_label] = class_count.get(class_label, 0) + 1 # 根据统计得到的类别数量,进行排序,返回一个包含元组的列表,[(),(),...()] sorted_class = sorted(class_count.iteritems(), key=operator.itemgetter(1), reverse=True) return sorted_class[0][0]# 用N个行向量构成训练矩阵# 通过txt的命名获得分类def vector2mat(): training_file_list = os.listdir('./trainingDigits') rows = len(training_file_list) trained_mat = np.zeros((rows, 1024)) class_list = [] for index, each_file in enumerate(training_file_list): digits, _ = each_file.split('.') class_list.append(digits.split('_')[0]) trained_mat[index, :] = img2vector('./trainingDigits/%s' % each_file) print trained_mat return trained_mat, class_list# 系统错误率测试# 通过另一组test数据作为测试样本输入def handwriting_test(trained_mat, class_list): test_file_list = os.listdir('./testDigits') test_num = len(test_file_list) err_count = 0.0 for each_file in test_file_list: # 去掉.txt的后缀 digits = each_file.split('.')[0] # 得到已知分类,known_label类型应该与class_list中元素类别一致 known_label = digits.split('_')[0] input_vector = img2vector('./testDigits/%s' % each_file) classify_result = classify(input_vector, trained_mat, class_list) print "Predict:%s\tReal answer:%s\n" % (classify_result, known_label) if known_label != classify_result: err_count += 1.0 print "total err num:%d" % err_count print "err rate:%.2f" % (err_count / (float(test_num)))def main(): trained_mat, class_list = vector2mat() handwriting_test(trained_mat, class_list)if __name__ == '__main__': main()
————————————————————————————————————————
二、结果分析
错误率在1%左右,说明识别准确率还是挺高的;但是因为每次输入一个样本,都要计算与所有训练样本之间的欧式距离,运算量还是挺大的,速度上稍显不足。
0 0
- K近邻算法(机器学习实战)
- 机器学习实战--k近邻算法(续)
- 机器学习实战之K-近邻算法
- 机器学习实战笔记 K近邻算法
- 《机器学习实战》之K-近邻算法
- 机器学习实战-k近邻算法
- 机器学习实战 k-近邻算法
- 【机器学习实战】-k近邻算法
- 《机器学习实战》—K-近邻算法
- 【机器学习实战一:K-近邻算法】
- 机器学习实战(k-近邻算法)
- 机器学习实战笔记:K近邻算法
- 机器学习实战笔记 k-近邻算法
- 机器学习实战之k-近邻算法
- 机器学习实战--k近邻算法
- 机器学习实战:K近邻算法(kNN)
- 【机器学习实战02】k-近邻算法
- 机器学习实战-K-近邻算法
- CIImage CGImage UIImage 区别
- 复选框的全选与取消全选
- c++成员函数的重载、覆盖、隐藏区别
- Spark 1.6 内存管理模型( Unified Memory Management)分析
- oracleconsoleorcl服务不能启动原因分析
- 机器学习实战--k近邻算法(续)
- Android 上拉固定某一布局到顶部
- FragmentPagerAdapter
- java中this关键字
- hdoj1232
- Java MD5加密算法
- Android属性动画propertAnimation
- Apache Spark Jobs 性能调优(二)
- css cursor属性