机器学习实战--k近邻算法(续)

来源:互联网 发布:超旗网络 编辑:程序博客网 时间:2024/04/29 14:42

    接续上次的k近邻算法,上一篇博文地址,这里用一个新的实例进行算法的验证。

一个手写数字识别系统,为了处理方便,书中已经将样本训练好,并转化为txt格式方便后续处理。具体格式如下:

这是0的其中一种表示方式。

    我们的目标是输入一个类似的数字,系统能够识别出来即可。

————————————————————————————————————————

一、具体算法实现:

# coding: UTF-8import numpy as npimport osimport operator# 将txt格式的数字转化为1*1024的向量格式def img2vector(filename):    return_vector = np.arange(1024)    with open(filename) as f:        for i in range(32):            line = f.readline()            for j in range(32):                return_vector[32 * i + j] = int(line[j])    return return_vector# 分类算法实现def classify(input_vector, trained_mat, class_list, k=3):    # 欧式距离计算    rows = trained_mat.shape[0]    input_mat = np.tile(input_vector, (rows, 1))    diff_mat = input_mat - trained_mat    squ_mat = diff_mat ** 2    sum_mat = squ_mat.sum(axis=1)    d = sum_mat ** 0.5    # 根据距离排序,获得排序后的索引    sorted_d = d.argsort()    # 创建用来统计某一类标签的字典    class_count = {}    for i in xrange(k):        class_label = class_list[sorted_d[i]]        class_count[class_label] = class_count.get(class_label, 0) + 1    # 根据统计得到的类别数量,进行排序,返回一个包含元组的列表,[(),(),...()]    sorted_class = sorted(class_count.iteritems(), key=operator.itemgetter(1), reverse=True)    return sorted_class[0][0]# 用N个行向量构成训练矩阵# 通过txt的命名获得分类def vector2mat():    training_file_list = os.listdir('./trainingDigits')    rows = len(training_file_list)    trained_mat = np.zeros((rows, 1024))    class_list = []    for index, each_file in enumerate(training_file_list):        digits, _ = each_file.split('.')        class_list.append(digits.split('_')[0])        trained_mat[index, :] = img2vector('./trainingDigits/%s' % each_file)    print trained_mat    return trained_mat, class_list# 系统错误率测试# 通过另一组test数据作为测试样本输入def handwriting_test(trained_mat, class_list):    test_file_list = os.listdir('./testDigits')    test_num = len(test_file_list)    err_count = 0.0    for each_file in test_file_list:        # 去掉.txt的后缀        digits = each_file.split('.')[0]        # 得到已知分类,known_label类型应该与class_list中元素类别一致        known_label = digits.split('_')[0]        input_vector = img2vector('./testDigits/%s' % each_file)        classify_result = classify(input_vector, trained_mat, class_list)        print "Predict:%s\tReal answer:%s\n" % (classify_result, known_label)        if known_label != classify_result:            err_count += 1.0    print "total err num:%d" % err_count    print "err rate:%.2f" % (err_count / (float(test_num)))def main():    trained_mat, class_list = vector2mat()    handwriting_test(trained_mat, class_list)if __name__ == '__main__':    main()

————————————————————————————————————————

二、结果分析


错误率在1%左右,说明识别准确率还是挺高的;但是因为每次输入一个样本,都要计算与所有训练样本之间的欧式距离,运算量还是挺大的,速度上稍显不足。

0 0
原创粉丝点击