K近邻法之kd树及其Python实现

来源:互联网 发布:2017年中国m2数据 编辑:程序博客网 时间:2024/06/11 02:59

作为机器学习中一种基本的分类方法,K近邻(KNN)法是一种相对简单的方法。其中一个理由是K近邻法不需要对训练集进行学习。然而,不需要对训练集进行学习,反过来也会造成对测试集进行判定时,计算与空间复杂度的增加。

K近邻法最简单的实现方法是对需要分类的目标点,计算出训练集中每一个点到其的距离(比较常用的有欧氏距离),然后选取K个距离目标最近的点,根据这些点的分类以多数表决的形式来决定目标点的分类。理论上该方法的时间复杂度为O(n)。当n的数量巨大时,时间开销是巨大的。

kd树是一种为了提高KNN方法效率的特殊数据结构,它的本质是二叉树,每一个节点代表着对k维输入空间上的某一位进行划分的超平面。构造kd树相当于不断使用垂直于坐标轴的超平面将K维空间划分。

构建树的过程中,依次按顺序选择k维空间的某一个特征,在该特征的坐标轴上通过某个实例的特征值确定超平面,该超平面垂直于选择的坐标轴,将当前的超矩形区域划分为左右两个子区域,此时输入实例以选择的特征上选择的点为界,被分在左右两个超矩形区域中。重复以上过程直到子区域没有实例为止。

下面以一个例子来说明kd树的构建过程(例子来自《统计学习方法》-李航):
给定一个二维空间的数据集 T = {(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)}

首先选取k维空间的某一特征(该例为第一个,理论上先选择哪一个都是可行的,顺序应该也是任意的,为了简化描述过程,先选择第一个特征)。

所有数据集的第一个特征为{2, 5, 9,4, 8, 7},然后选取某个特征值来确定超平面。一般情况下选取中点(事实证明选取中点并不是在任何情况下都是最佳的,通过中点构建出的kd树在搜素时未必效率最高,例如在训练集中两个类别之间的相对距离较远,且某一类的数量远大于另一类的数量)(偶数个的情况下可以选靠后的点)。该例中中点值为7.以平面x(1) = 7为超平面,将空间分为左右两部分(左空间小于中点值,右空间大于中点值),左空间点(2,3)(4,7)(5,4),再以第二个特征值画超平面,此时第二个特征值的中点为4,以x(2)= 4为超平面再次划分左半平面,以此类推。最终得到的划分如下。

kd树如下

kd树构建完毕后,利用kd树进行k近邻搜索。在kd树上进行近邻搜索时,很多时候可以不进入某个父节点的另一个子节点(省去了另一个子节点数据点的查找)。kd树查找的具体算法如下:

算法输入:构造完毕的kd树,需要分类的目标点x

算法输出:目标点x的k近邻点

算法过程:

1通过深度优先方法,在kd树中搜索到目标点的所在的叶节点。(注:该搜索并不能直接找到最近邻点)搜索方法如下,在搜索每一层的过程中,根据该层的分割特征的序数,来对目标点的该序数的特征进行分类(决定是进入左子节点还是右子节点)。如目标点为(6,1),在根节点,分割特征为x(1),目标点的x(1)为 6,6小于根节点的x(1)= 7,进入左子节点。第二层的分割特征为x(2),目标点的x(2)为1, 1小于子节点的x(2)= 4,进入左子节点,最终得到叶节点(2,3)。

2 以该叶节点作为当前的NN(最近邻)点,计算该叶节点与目标点的距离,并设为当前的最小距离。

3 计算该叶节点父节点与目标点的距离,若小于当前的最小距离,则更新当前的最小距离以及当前的NN点(被覆盖的点先记录下来)

4 判断是否要进入父节点的另一个子节点:

判断方法为:计算父节点在其分割特征上的值距离目标点在该特征上的值的距离,若该距离小于当前的最小距离,则进入另一个子节点,否则不进入。

既:检查另一子节点对应的区域是否与以目标点为球心,以目标点与当前的NN点的距离为半径的球体相交。若相交则进入,反正不进入。

a)若不进入另一个子节点,则以此父节点视为叶节点,重复步骤3。

b)若进入另一个子节点,则对右边节点以下的子树执行步骤1,找到新的叶节点。判断是否要更新NN点与当前最小距离。随后以该叶节点开始,重复步骤3。

以此类推,搜索过程中将不断向跟节点回退。在向根节点回退的过程中,不必再次进入从中退出的子节点,保证过程不会进入死循环。

5 当回退到根节点时(且以根节点与目标点的距离来更新最小距离与NN点后),最后的NN点即为x的最近邻点。且记录下来的所有NN点,对应的距离,最小的K个即为K近邻点。


下面以Python来实现kd树的生成及搜索。

##Generate KD treedef createTree(dataSet, layer = 0, feature = 2):    length = len(dataSet)    dataSetCopy = dataSet[:]    featureNum = layer % feature    dataSetCopy.sort(key = lambda x:x[featureNum])    layer += 1    if length == 0:        return None    elif length == 1:        return {'Value':dataSet[0], 'Layer':layer, 'feature':featureNum,'Left':None, 'Right':None}    elif length != 1:        midNum = length // 2        dataSetLeft = dataSetCopy[:midNum]        dataSetRight = dataSetCopy[midNum+1:]        return {'Value':dataSetCopy[midNum], 'Layer':layer, 'feature':featureNum, 'Left':createTree(dataSetLeft, layer)                , 'Right':createTree(dataSetRight, layer)}
该部分为生成kd树的函数。生成的树以字典嵌套字典的形式表示。(每个节点有‘Left’和‘Right’两个键,里面的值为该分支下的子树。有“Layer”键表示节点在树的第几层,“feature”表示该节点分割使用的是实例的第几个特征(由于python的特性,程序中的0表示x(1)))。

具体思路:当该函数的输入仅有一个实例时,返回一个左右子节点均为None的节点(既表示这是一个叶节点),若不仅有一个实例,则根据层数和相应的特征,将输入的实例分割成两个部分,该节点的左子节点为,递归本函数的返回值(输入为实例分割的左半部分),右子节点为输入为实例分割右半部的本函数的递归返回值。

以上面的例子生成的子树如下:

{'Value': (7, 2), 'Layer': 1, 'feature': 0, 'Left': {'Value': (5, 4), 'Layer': 2, 'feature': 1, 'Left': {'Value': (2, 3), 'Layer': 3, 'feature': 0, 'Left': None, 'Right': None}, 'Right': {'Value': (4, 7), 'Layer': 3, 'feature': 0, 'Left': None, 'Right': None}}, 'Right': {'Value': (9, 6), 'Layer': 2, 'feature': 1, 'Left': {'Value': (8, 1), 'Layer': 3, 'feature': 0, 'Left': None, 'Right': None}, 'Right': None}}


#calculate distancedef calDistance(sourcePoint, targetPoint):    length = len(targetPoint)    sum = 0.0    for i in range(length):        sum += (sourcePoint[i] - targetPoint[i])**2    sum = sqrt(sum)    return sum#A function use to find a point in KDtreedef findPoint(Tree, value):    if  Tree !=None and Tree['Value'] == value:        return Tree    else:        if Tree['Left'] != None:            return findPoint(Tree['Left'], value) or findPoint(Tree['Right'], value)
以上两个函数分别为距离计算函数(欧氏距离),以及根据节点的值,在kd树中找到对应的节点的函数(包括值,层,分割特征及左右子节点,准确的说已经不止是该节点,而是一颗子树)。

#DFS algorithmdef dfs(kdTree, target,tracklist = []):    tracklistCopy = tracklist[:]    if not kdTree:        return None, tracklistCopy    elif not kdTree['Left']:        tracklistCopy.append(kdTree['Value'])        return kdTree['Value'], tracklistCopy    elif kdTree['Left']:        pointValue = kdTree['Value']        feature = kdTree['feature']        tracklistCopy.append(pointValue)        #return kdTree['Value'], tracklistCopy        if target[feature] <= pointValue[feature]:            return dfs(kdTree['Left'], target, tracklistCopy)        elif target[feature] > pointValue[feature]:            return dfs(kdTree['Right'], target, tracklistCopy)
该函数为kd搜索的一部分,既深度优先算法部分。函数的返回值为最终找到的叶节点的值,以及从根节点到叶节点的跟踪路径(列表结构)。由于在本程序中,通过子节点去搜索父节点的复杂度极高,所以通过跟踪路径的方式来找寻父节点。在跟踪路径中,每一个值的上一个值,均为该值对应的节点的父节点。


#KDtree search algorithmdef kdTreeSearch(tracklist, target , usedPoint = [] , minDistance = float('inf'), minDistancePoint = None):    tracklistCopy = tracklist[:]    usedPointCopy = usedPoint[:]    if not minDistancePoint:        minDistancePoint = tracklistCopy[-1]    if len(tracklistCopy) == 1:        return minDistancePoint    else:        point = findPoint(kdTree, tracklist[-1])        if calDistance(point['Value'], target) < minDistance:            minDistance = calDistance(point['Value'], target)            minDistancePoint = point['Value']        fatherPoint = findPoint(kdTree, tracklistCopy[-2])        fatherPointval = fatherPoint['Value']        fatherPointfea = fatherPoint['feature']        if calDistance(fatherPoint['Value'], target) < minDistance:            minDistance = calDistance(fatherPoint['Value'], target)            minDistancePoint = fatherPoint['Value']        if point == fatherPoint['Left']:            anotherPoint = fatherPoint['Right']        elif point == fatherPoint['Right']:            anotherPoint = fatherPoint['Left']        if (anotherPoint == None or anotherPoint['Value'] in usedPointCopy or            abs(fatherPointval[fatherPointfea] - target[fatherPointfea]) > minDistance):            usedPoint = tracklistCopy.pop()            usedPointCopy.append(usedPoint)            return kdTreeSearch(tracklistCopy, target, usedPointCopy, minDistance, minDistancePoint)        else:            usedPoint = tracklistCopy.pop()            usedPointCopy.append(usedPoint)            subvalue, subtrackList = dfs(anotherPoint, target)            tracklistCopy.extend(subtrackList)            return kdTreeSearch(tracklistCopy, target, usedPointCopy, minDistance, minDistancePoint)
该部分为搜索函数的剩余部分,主要思想还是递归。函数的返回值为最近邻点。若需要K近邻点可以创建一个列表保存每一个曾出现过的NN点,选出距离最小的K个即可。

完整程序如下:

from math import sqrtfrom random import randint##Generate KD treedef createTree(dataSet, layer = 0, feature = 2):    length = len(dataSet)    dataSetCopy = dataSet[:]    featureNum = layer % feature    dataSetCopy.sort(key = lambda x:x[featureNum])    layer += 1    if length == 0:        return None    elif length == 1:        return {'Value':dataSet[0], 'Layer':layer, 'feature':featureNum,'Left':None, 'Right':None}    elif length != 1:        midNum = length // 2        dataSetLeft = dataSetCopy[:midNum]        dataSetRight = dataSetCopy[midNum+1:]        return {'Value':dataSetCopy[midNum], 'Layer':layer, 'feature':featureNum, 'Left':createTree(dataSetLeft, layer)                , 'Right':createTree(dataSetRight, layer)}#calculate distancedef calDistance(sourcePoint, targetPoint):    length = len(targetPoint)    sum = 0.0    for i in range(length):        sum += (sourcePoint[i] - targetPoint[i])**2    sum = sqrt(sum)    return sum#DFS algorithmdef dfs(kdTree, target,tracklist = []):    tracklistCopy = tracklist[:]    if not kdTree:        return None, tracklistCopy    elif not kdTree['Left']:        tracklistCopy.append(kdTree['Value'])        return kdTree['Value'], tracklistCopy    elif kdTree['Left']:        pointValue = kdTree['Value']        feature = kdTree['feature']        tracklistCopy.append(pointValue)        #return kdTree['Value'], tracklistCopy        if target[feature] <= pointValue[feature]:            return dfs(kdTree['Left'], target, tracklistCopy)        elif target[feature] > pointValue[feature]:            return dfs(kdTree['Right'], target, tracklistCopy)#A function use to find a point in KDtreedef findPoint(Tree, value):    if  Tree !=None and Tree['Value'] == value:        return Tree    else:        if Tree['Left'] != None:            return findPoint(Tree['Left'], value) or findPoint(Tree['Right'], value)#KDtree search algorithmdef kdTreeSearch(tracklist, target , usedPoint = [] , minDistance = float('inf'), minDistancePoint = None):    tracklistCopy = tracklist[:]    usedPointCopy = usedPoint[:]    if not minDistancePoint:        minDistancePoint = tracklistCopy[-1]    if len(tracklistCopy) == 1:        return minDistancePoint    else:        point = findPoint(kdTree, tracklist[-1])        if calDistance(point['Value'], target) < minDistance:            minDistance = calDistance(point['Value'], target)            minDistancePoint = point['Value']        fatherPoint = findPoint(kdTree, tracklistCopy[-2])        fatherPointval = fatherPoint['Value']        fatherPointfea = fatherPoint['feature']        if calDistance(fatherPoint['Value'], target) < minDistance:            minDistance = calDistance(fatherPoint['Value'], target)            minDistancePoint = fatherPoint['Value']        if point == fatherPoint['Left']:            anotherPoint = fatherPoint['Right']        elif point == fatherPoint['Right']:            anotherPoint = fatherPoint['Left']        if (anotherPoint == None or anotherPoint['Value'] in usedPointCopy or            abs(fatherPointval[fatherPointfea] - target[fatherPointfea]) > minDistance):            usedPoint = tracklistCopy.pop()            usedPointCopy.append(usedPoint)            return kdTreeSearch(tracklistCopy, target, usedPointCopy, minDistance, minDistancePoint)        else:            usedPoint = tracklistCopy.pop()            usedPointCopy.append(usedPoint)            subvalue, subtrackList = dfs(anotherPoint, target)            tracklistCopy.extend(subtrackList)            return kdTreeSearch(tracklistCopy, target, usedPointCopy, minDistance, minDistancePoint)trainingSet = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]kdTree = createTree(trainingSet)target = eval(input('Input target point:'))value, trackList = dfs(kdTree, target)nnPoint = kdTreeSearch(trackList, target)print(nnPoint)

输入目标点,即可返回最近邻点的值。


该程序仍有大量不足之处,其中一点即为kd书的储存结构并不是特别的优秀。每次在使用的时候也都需要使用查找函数去找到value对应的节点,才能获得节点的各种属性。这大大增加了时间复杂度。其次,当数据量极大的时候,递归的层数过深,空间的复杂度也是很高的。










原创粉丝点击