Machine Learning In Action -- kNN (k Nearest Neighbors)

来源:互联网 发布:防蹭网软件怎么用 编辑:程序博客网 时间:2024/05/07 20:45

k最近邻分类算法:k Nearest Neighbors

k最近邻分类算法是最简单的机器学习算法之一,主要应用在对未知事物的识别。

主要思想:

如果一个样本在特性空间的k个最相似样本的大多数都以属于同一个类别,那么这个样本也属于该类别。

算法优点

  • 算法准确度较高
  • 对数据不作假设
  • 适用于交叉或重叠较多的待分样本集

算法缺点

  • 计算量大
  • 内存消耗大
  • 样本数量不平衡时易受影响

示例图

这里写图片描述
绿色圆点表示未分类的样本,令其为A。如果我们把k设成3, 那么离A最近的3个样本就是黑色圆中所包含的样本。由于红色三角形有2个,而蓝色正方形只有一个。所以最终的分类结果为红色三角形。

代码 Python

这里主要参考了Machine Learning In Action这本书中的代码。其kNN的具体python实现代码如下。
Note: 运行代码之前,请安装好matplotlib。

import numpy as npimport operatorfrom os import listdirdef knn_classify(vec_in, data_set, labels, k):    rows = data_set.shape[0]    diffs = np.tile(vec_in, (rows, 1)) - data_set    sq_diffs = diffs ** 2    sq_distances = sq_diffs.sum(axis = 1)    distances = sq_distances ** 0.5    sorted_dist_indices = distances.argsort()    class_cnt = {}    for i in range(k):        vote_label = labels[sorted_dist_indices[i]]        class_cnt[vote_label] = class_cnt.get(vote_label, 0) + 1    sorted_class_cnt = sorted(class_cnt.iteritems(), key=operator.itemgetter(1), reverse=True)    return sorted_class_cnt[0][0]

实例:手写数字0-9的识别

在Machine Learning In Action中,有一个将kNN算法用于手写数字识别的例子。该例子中的training set共包含2000个样本,也就是说每个数字大约有200个样本。每个数字通过处理,以32x32大小的0/1数字组成。
应用kNN算法,先将每个手写数字变成1x1024的向量。保存在training set数组中。当未判别的数字出现时,用该数字的向量于training set中的每一个向量计算距离,选其中的top k个样本进行投票,最后哪个类别的数量最多,就将该数字判定成那个类别。
这里写图片描述
具体代码:

def img2vector(filename):    ret_vec = np.zeros((1, 1024))    fp = open(filename)    for i in range(32):        line_str = fp.readline()        for j in range(32):            ret_vec[0, 32*i + j] = int(line_str[j])    fp.close()    return ret_vecdef hand_writing_test():    training_path = './digits/trainingDigits'    test_path = './digits/testDigits'    training_files = listdir(training_path)    m = len(training_files)    training_mat = np.zeros((m, 1024))    labels = []    for i in range(m):        filename = training_files[i]        class_num_str = filename.split('_')[0]        labels.append(class_num_str)        training_mat[i, :] = img2vector(training_path + '/%s' % filename)    # print training_mat    test_files = listdir(test_path)    m = len(test_files)    err_cnt = 0.0    for i in range(m):        filename = test_files[i]        class_num_str = filename.split('_')[0]        vec_in = img2vector(test_path + '/%s' % filename)        ret = knn_classify(vec_in, training_mat, labels, 3)        if str(ret) != class_num_str:            print "file %s, classifier result: %s, real ans: %s." % (filename, ret, class_num_str)            err_cnt += 1.0    print "The total number of error is %d." % err_cnt    print "The failure rate is %f." % (err_cnt / float(m))if __name__ == '__main__':    hand_writing_test()

这里写图片描述

0 0
原创粉丝点击