KNN近邻算法总结

来源:互联网 发布:淘宝运费险赔付多少钱 编辑:程序博客网 时间:2024/05/22 14:04

K-近邻算法

1.什么是K近邻算法

K近邻(k-Nearest NeighborKNN)分类算法,是一个理论上比较成熟的方法,也是最简单的机器学习算法之一。该方法的思路是:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。

 

2.分类结果的决定因素

(1)K为近邻的对象个数,结果影响取决于K的值。

(2)测试样本的准确度与分布情况。

如下图,绿色圆要被决定赋予哪个类,是红色三角形还是蓝色四方形?如果K=3,由于红色三角形所占比例为2/3,绿色圆将被赋予红色三角形那个类,如果K=5,由于蓝色四方形比例为3/5,因此绿色圆被赋予蓝色四方形类。由此也说明了KNN算法的结果很大程度取决于K的选择。


3.模型构建基础

近邻分类模型的算法基于“距离”二字,取得K个与其最为临近的样本坐标进行分类匹配,距离公式有欧式距离,曼哈顿距离,切比雪夫等等距离公式将近10多余种类。

 

(1)欧式距离

欧氏距离是最易于理解的一种距离计算方法,源自欧氏空间中两点间的距离公式。

(1)二维平面上两点a(x1,y1)b(x2,y2)间的欧氏距离:

(2)三维空间两点a(x1,y1,z1)b(x2,y2,z2)间的欧氏距离:

(3)两个n维向量a(x11,x12,…,x1n) b(x21,x22,…,x2n)间的欧氏距离:

 

(2)曼哈顿距离

从名字就可以猜出这种距离的计算方法了。想象你在曼哈顿要从一个十字路口开车到另外一个十字路口,驾驶距离是两点间的直线距离吗?显然不是,除非你能穿越大楼。实际驾驶距离就是这个曼哈顿距离。而这也是曼哈顿距离名称的来源,曼哈顿距离也称为城市街区距离(CityBlock distance)

(1)二维平面两点a(x1,y1)b(x2,y2)间的曼哈顿距离

(2)两个n维向量a(x11,x12,…,x1n) b(x21,x22,…,x2n)间的曼哈顿距离

绿色为:欧式距离 即两点最短距离

其他颜色为:曼哈顿距离

 

(3)切比雪夫距距离

在平面几何中,若二点p及q的直角坐标系坐标为及,则切比雪夫距离为:

玩过国际象棋的朋友或许知道,国王走一步能够移动到相邻的8个方格中的任意一个。那么国王从格子(x1,y1)走到格子(x2,y2)最少需要多少步?。你会发现最少步数总是max( | x2-x1 | , | y2-y1 | ) 步 。有一种类似的一种距离度量方法叫切比雪夫距离。

(1)二维平面两点a(x1,y1)与b(x2,y2)间的切比雪夫距离 

(2)两个n维向量a(x11,x12,…,x1n)与 b(x21,x22,…,x2n)间的切比雪夫距离   

 

4.模型构建方式

1)计算测试数据与各个训练数据之间的距离;

2)按照距离的递增关系进行排序;

3)选取距离最小的K个点;

4)确定前K个点所在类别的出现频率;

5)返回前K个点中出现频率最高的类别作为测试数据的预测分类。

利用欧式距离制作分类器

需求:

判断如下属性的人是不是小丽喜欢的人

依据三个属性来判断:

1.每年万游戏时间的百分比

2.每年喝饮料的公升数

3.每年出行旅游的旅程数

类别为:

1.喜欢 2.非常喜欢 3.不喜欢

小丽给出了自己喜欢的一些标准数据

 


python3实现的代码 

#-*-coding:utf-8 -*-from numpy import *import operatorimport matplotlibimport matplotlib.pyplot as pltfrom mpl_toolkits.mplot3d import Axes3Dimport operator#读取文件数据def file2matrix(filename):    fr=open(filename)#打开文件    arrayOLines=fr.readlines()#将文件读入一个字符串列表,在列表中每个字符串就是一行    numberOFlines=len(arrayOLines)#读入字符串列表的数量,即文件的行数    returnMat=zeros((numberOFlines,3))#创建numberOFlines行3列的numpy矩阵    classLabelVector=[]#创建标签数组    index=0    for line in arrayOLines:        line=line.strip()#删除每行两侧的空格        listFormLine=line.split('\t')#将每行的字符串列表以‘\t’为间隔分为序列        returnMat[index,:]=listFormLine[0:3]#将每一行数据存入returnMat数组中        classLabelVector.append(int(listFormLine[-1]))#将每一行的最后一列即标签存入classLabelVector中        index+=1    return returnMat,classLabelVector#返回样本特征矩阵与标签向量#归一化数据def autoNorm(dataset):    minVals=dataset.min(0)#列中最小值    maxVals=dataset.max(0)#列中的最大值    ranges=maxVals-minVals    normDataSet=zeros(shape(dataset))#创建与样本特征矩阵同大小的数值全是0的矩阵    m=dataset.shape[0]#m是dataset的列数,即样本特征的维数    normDataSet=dataset-tile(minVals,(m,1))#tile()是将minVals复制成m行3列,即与dataset同大小的矩阵    normDataSet=normDataSet/tile(ranges,(m,1))    return normDataSet,ranges,minVals#返回归一化的样本特征矩阵,范围,每列最小值#K近邻分类def classify(inX,dataSet,labels,k):    dataSetSize=dataSet.shape[0]#读取样本的特征矩阵的维数    diffMat=tile(inX,(dataSetSize,1))-dataSet#计算测试数据与每一个样本特征矩阵的欧氏距离    sqDiffMat=diffMat**2    sqDistances=sqDiffMat.sum(axis=1)#每一行的相加    distances=sqDistances**0.5    sortedDistIndicies=distances.argsort()#测试数据与每一个样本特征矩阵的欧氏距离从小到大排列后,将原样本的索引值赋值给sortedDistIndicies    classCount={}#创建字典    for i in range(k):        voteIlabel=labels[sortedDistIndicies[i]]#将sortedDistIndicies相对应的标签赋值给voteIlabel        classCount[voteIlabel]=classCount.get(voteIlabel,0)+1#get是取字典里的元素,                              #如果之前这个voteIlabel是有的,那么就返回字典里这个voteIlabel里的值,                              #如果没有就返回0(后面写的),这行代码的意思就是算离目标点距离最近的k个点的类别,                        #这个点是哪个类别哪个类别就加1        sortedClassCount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)#key=operator.itemgetter(1)的意思是按照字典里的第一个排序,        #{A:1,B:2},要按照第1个(AB是第0个),即‘1’‘2’排序。reverse=True是降序排序        return sortedClassCount[0][0]#返回发生频率最高的元素标签def classifyPerson():   resultList=['not at all','in small doses','in large doses']   percentTats=float(input("percentage of time spent playing vidio games?"))   ffMines=float(input("frequent flier miles earned per year?"))   iceCream=float(input("liters of ice cream consumed per year?"))   datingDataMat,datingLabels=file2matrix('datingTestSet.txt')   normMat,ranges,minVals=autoNorm(datingDataMat)   inArr=array([ffMines,percentTats,iceCream])   classifierResult=classify((inArr-minVals)/ranges,normMat,datingLabels,3)   print("你对这个人的喜欢程度:",resultList[classifierResult - 1])   dataArr = array(datingDataMat)   n = shape(dataArr)[0]   xcord1 = []; ycord1 = [];zcord1=[]   xcord2 = []; ycord2 = [];zcord2=[]   xcord3 = []; ycord3 = [];zcord3=[]   for i in range(n):      if int(datingLabels[i])== 1:         xcord1.append(dataArr[i,0]); ycord1.append(dataArr[i,1]);zcord1.append(dataArr[i,2])      elif int(datingLabels[i])== 2:         xcord2.append(dataArr[i,0]); ycord2.append(dataArr[i,1]);zcord2.append(dataArr[i,2])      elif int(datingLabels[i])== 3:         xcord3.append(dataArr[i,0]); ycord3.append(dataArr[i,1]);zcord3.append(dataArr[i,2])   fig = plt.figure()   ax = fig.add_subplot(111, projection='3d')   ax.set_title('KNN')   type1=ax.scatter(xcord1, ycord1,zcord1, s=30, c='red', marker='s')   type2=ax.scatter(xcord2, ycord2,zcord2, s=30, c='green',marker='o')   type3=ax.scatter(xcord3, ycord3,zcord3, s=30, c='b',marker='+')   ax.scatter(inArr[0], inArr[1],inArr[2], s=100, c='k', marker='8')   plt.figtext(0.02,0.92,'class1:Did Not Like',color='red')   plt.figtext(0.02,0.90,'class2:Liked in Small Doses',color='green')   plt.figtext(0.02,0.88,'class3:Liked in Large Doses',color='b')   ax.set_zlabel('frequent flier miles earned per year')   ax.set_ylabel('percentage of time spent playing vidio games')   ax.set_xlabel('liters of ice cream consumed per year')   plt.show()

测试代码

#coding:utf-8from numpy import *import operatorfrom collections import Counterimport matplotlibimport matplotlib.pyplot as pltimport d3d3.classifyPerson()

测试样本

409208.3269760.9539521144887.1534691.6739042260521.4418710.80512437513613.1473940.4289643383441.6697880.13429637299310.1417401.0329553359486.8307921.21319214266613.2763690.5438801674978.6315770.74927833548312.2731691.5080531502423.7234980.8319173632758.3858791.669485355694.8754350.7286582510524.6800980.62522437737215.2995700.3313513436731.8894610.1912833613647.5167541.26916436967314.2391950.2613333156690.0000001.25018522848810.5285551.304844164873.5402650.8224832377082.9915510.8339203226205.2978650.6383062287826.5938030.1871081197392.8167601.68620923678812.4582580.649617157410.0000001.6564182285679.9686480.731232168081.3648380.6401032


5.优缺点总结
优点: 
1.简单好用,容易理解,精度高,理论成熟,既可以用来做分类也可以用来做回归; 
缺点: 
1.一般数值很大的时候不用这个,计算量太大。但是单个样本又不能太少 否则容易发生误分。 
2.最大的缺点是无法给出数据的内在含义,只是单纯意义上的寻找临近分类。

 

 

 

 

原创粉丝点击