机器学习实战代码详细注释之kNN算法

来源:互联网 发布:油菜花粉 知乎 编辑:程序博客网 时间:2024/05/21 23:31

近来在学习机器学习的八大算法,看到了《机器学习实战》这本好书,然而我感觉它的代码架构太过突然,好像难以让初学者迅速看懂到底每一步发生了什么,我决定将代码全部注释一遍
这个实现的难点在于这一段的取巧,使用了矩阵来整体处理数据,建议看不懂的话直接去debug模式看看矩阵长什么样子

请配合原书观看

from numpy import *import operatordef createDataSet():    group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])    labels = ['A','A','B','B']    return group,labelsgroup,labels = createDataSet()##inX是被分类的向量,dataSet是数据集,labels是数据集标签,k是参数def classify(inX,dataSet,labels,k):    dataSetSize = dataSet.shape[0]    ##将矩阵作为分块子阵,创建dataSetSize行,1列的矩阵,并且整个矩阵减去数据集,整个矩阵变成了坐标差    diffMat = tile(inX,(dataSetSize,1))-dataSet    sqDiffMat = diffMat**2    ##将所有列向量求和    sqDistances = sqDiffMat.sum(axis=1)##axis就是沿着第二个轴的意思,也就是列    distances = sqDistances**0.5##现在得到了一个1行dataSet列的矩阵,代表的是每个数据和这个点的距离    ##利用argsort()方法得到排序后的下标    sortedDistIndicies = distances.argsort()    ##对k个邻居进行投票    classCount={}    for i in range(k):        voteIlabel = labels[sortedDistIndicies[i]]        ##这一步就可以直接利用map.get()中的默认值,创建不存在的key        classCount[voteIlabel] = classCount.get(voteIlabel,0)+1    sortedClassCount = sorted(classCount.items(),key = operator.itemgetter(1),reverse=True)    return sortedClassCount[0][0]print(classify([0,0],group,labels,3))
原创粉丝点击