kdtree c++版本

来源:互联网 发布:如何评价魔兽世界 知乎 编辑:程序博客网 时间:2024/06/05 08:53


http://blog.csdn.net/zhl30041839/article/details/9277807


等我实现一个python版本的再贴上来

ok,实现了python版本的kdtree,并增加了本文没有实现的查询k-nn的函数,按照http://web.stanford.edu/class/cs106l/handouts/assignment-3-kdtree.pdf给的思路实现的。

# -*-mport randomimport numpy as npclass Treenode(object):    def __init__(self, current_node = None, split = None, left = None, right = None):        self.current_node = None        self.split = split        self.left = left        self.right = rightdef findSplitPoint(datapoints, split):    local_split = split % (datapoints.shape[0])    datapoints = datapoints[datapoints[:,local_split].argsort()]    return datapointsdef buildKdtree(datapoints, split):    if datapoints.size == 0:        return    datapoints = findSplitPoint(datapoints, split)    numpoints = datapoints.shape[0]    middle = numpoints/2    left_datapoints = datapoints[:middle,:]    right_datapoints = datapoints[middle+1:,:]    current_node = Treenode()    current_node.split = split    current_node.current_node = datapoints[middle,:]    current_node.left = buildKdtree(left_datapoints, split+1)    current_node.right = buildKdtree(right_datapoints, split+1)    return current_nodedef printKdtree(treenode):    print treenode.current_node, treenode.split    if treenode.left:        printKdtree(treenode.left)    if treenode.right:        printKdtree(treenode.right)def distance(node1, node2):    return np.linalg.norm(node1-node2)def findNearestNeighbor(root, x):    p = root    dim = p.current_node.shape[0]    search_path = list()    dist = np.finfo(np.float64()).max    nearest_neighbor = None    while p.current_node.size <> 0:        if (not p.left) and (not p.right):            current_dist = distance(p.current_node, x)            if current_dist < dist:                dist = current_dist                nearest_neighbor = p.current_node            break        search_path.append(p)        local_split = p.split % dim        if x[local_split] < p.current_node[local_split]:            p = p.left        else:            p = p.right    search_path = np.array(search_path)    while search_path.size > 0:        #for item in search_path:        #    print 'yes', item.current_node,        #distance between the point x to the separate plane        current_node = search_path[-1]        search_path = search_path[:-1]        local_split = current_node.split % len(x)        dist_point_plane = x[local_split] - current_node.current_node[local_split]        if dist_point_plane < dist:            current_distance =  distance(current_node.current_node, x)            if current_distance < dist:                dist = current_distance                nearest_neighbor = current_node.current_node            if (not current_node.left) and (not current_node.right):                continue        #    print 'abc', x[local_split], current_node.current_node        #    print x[local_split] <= current_node.current_node[local_split], local_split            if x[local_split] <= current_node.current_node[local_split]:                np.append(search_path, [current_node.right])            else:                search_path = np.append(search_path, [current_node.left])    return dist, nearest_neighbordef findKNearestNeighbor(root, k, x):    res = [0]*k    elementnum = 0    p = root    dim = p.current_node.shape[0]    search_path = list()    dist = np.finfo(np.float64()).max    nearest_neighbor = None    while p.current_node.size <> 0:        if (not p.left) and (not p.right):            current_dist = distance(p.current_node, x)            if current_dist < dist:                dist = current_dist                nearest_neighbor = p.current_node                print dist, p.current_node                res[elementnum] = (dist, (p.current_node))                elementnum += 1            break        search_path.append(p)        local_split = p.split % dim        if x[local_split] < p.current_node[local_split]:            p = p.left        else:            p = p.right    search_path = np.array(search_path)    while search_path.size > 0:        #for item in search_path:        #    print 'yes', item.current_node,        #distance between the point x to the separate plane        current_node = search_path[-1]        search_path = search_path[:-1]        local_split = current_node.split % len(x)        dist_point_plane = x[local_split] - current_node.current_node[local_split]        dist = res[elementnum-1][0]        if dist_point_plane < dist or elementnum < k:            current_distance =  distance(current_node.current_node, x)            # if the res is not full, insert current node            if elementnum < k or current_distance < dist:                local_index = elementnum-1                while local_index >= 0 and res[local_index][0] > current_distance:                    if local_index == k-1:                        local_index -= 1                        elementnum -= 1                        continue                    res[local_index+1] = res[local_index]                    local_index -= 1                res[local_index+1] = (current_distance, current_node.current_node)                elementnum += 1            if current_distance < dist:                dist = current_distance                nearest_neighbor = current_node.current_node            if (not current_node.left) and (not current_node.right):                continue        #    print 'abc', x[local_split], current_node.current_node        #    print x[local_split] <= current_node.current_node[local_split], local_split            if dist_point_plane < dist or elementnum < k:                if x[local_split] <= current_node.current_node[local_split]:                    np.append(search_path, [current_node.right])                else:                    search_path = np.append(search_path, [current_node.left])    return resif __name__ == "__main__":    datapoints = list()    datapoints = [(2,3), (5,4), (9,6), (4,7),(8,1),(7,2)]    ndim = 2    #for i in range(10):    #    data = list()    #    for j in range(ndim):    #        data.append(random.randint(1, 10))    #    datapoints.append(data)    print datapoints    datapoints = np.array(datapoints)    for data in datapoints:        print data    root = buildKdtree(datapoints, 0)    printKdtree(root)    #find 1 nearest neighbor example    #res = findNearestNeighbor(root, (2, 4.5))    #find k nearest neighbor examples    res = findKNearestNeighbor(root, 2, (2, 4.5))    print res

一、如何高效率地实现k近邻法?

  在SIFT图像特征匹配等应用中,需要在高维特征空间中快速找到距离目标图像特征最近邻的那个特征点,往往需要进行比较的特征向量的数量很大,如果进行朴素最近邻搜索,也就是依次计算目标点和每一个待匹配特征的距离,然后再算出最短距离这样的策略,那么特征匹配算法的时间复杂度将会高得令人难以接受。因此,我们需要借助一种存储和表示k维数据的数据结构,既能够方便地存储k维数据,又能够进行高效率的搜索。


二、k-d树的基本思想

  k-d树由斯坦福大学本科生Jon Louis Bentley于1975年首次提出。k-d树是每个节点都为k维点的二叉树。其中k表示存储的数据的维度,d就是dimension的意思。所有非叶子节点可以视作用一个超平面把空间分割成两部分。在超平面左边的点代表节点的左子树,在超平面右边的点代表节点的右子树。超平面的方向可以用下述方法来选择:每个节点都与k维中垂直于超平面的那一维有关。因此,如果选择按照x轴划分,所有x值小于指定值的节点都会出现在左子树,所有x值大于指定值的节点都会出现在右子树。这样,超平面可以用该x值来确定,其法矢为x轴的单位向量。一个三维空间内的3-d树如下所示:



  当特征空间维度大于20时,k-d tree算法的性能会剧烈下降,对于高维数据,David Lowe在1997的一篇文章中提出一种近似算法best-bins-first,可以有效改善这种情况。


kd树是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形结构。kd树从本质上来说是二叉树,表示对k维空间的一个划分。构造kd树相当于不断地用垂直于坐标轴的超平面切分k维空间,构成一系列的k维超矩形区域,kd树的每一个结点都对应于一个超矩形区域,非叶结点的左右子树分别表示划分得到的两个区域。在2维情形,当划分超平面平行于x轴时,在划分超平面以下的数据点将存储在此划分结点的左子树,在超平面以上的点存储在此划分结点右子树;若划分超平面平行于y轴,在划分超平面左侧的数据点将存储在此划分结点的左子树,在超平面右侧的点存储在此划分结点右子树。


构造kd树的方法:首先构造根节点,根节点对应于整个k维空间,包含所有的实例点,(至于如何选取划分点,有不同的策略。最常用的是一种方法是:对于所有的样本点,统计它们在每个维上的方差,挑选出方差中的最大值,对应的维就是要进行数据切分的维度。数据方差最大表明沿该维度数据点分散得比较开,这个方向上进行数据分割可以获得最好的分辨率;然后再将所有样本点按切分维度的值进行排序,位于正中间的那个数据点选为分裂结点。)。然后利用递归的方法,分别构造k-d树根节点的左右子树。在超矩形区域上选择一个坐标轴(切分维度)和一个分裂结点,以通过此分裂结点且垂直于切分方向坐标轴的直线作为分隔线,将当前超矩形区域分隔成左右或者上下两个子超矩形区域,对应于分裂结点的左右子树的根节点。实例也就被分到两个不相交的区域中。重复此过程直到子区域内没有实例点时终止。终止时的结点为叶结点。

通常依次选择坐标轴对空间切分,选择训练实例点在选定坐标轴上的中位数为切分点,这样得到的kd树是平衡的,但并不一定能保证检索的效率最优。

三、k-d tree的实现

  k-d tree是英文K-dimension tree的缩写,是对数据点在k维空间中划分的一种数据结构。k-d tree实际上是一种二叉树。每个结点的内容如下:

域名类型描述dom_eltkd维的向量kd维空间中的一个样本点split整数分裂维的序号,也是垂直于分割超面的方向轴序号leftkd-tree由位于该结点分割超面左子空间内所有数据点构成的kd-treerightkd-tree由位于该结点分割超面右子空间内所有数据点构成的kd-tree  k-d树算法可以分为两大部分,一部分是有关k-d树本身这种数据结构建立的算法,另一部分是在建立的k-d树上如何进行最邻近查找的算法。

  先以一个简单直观的实例来介绍k-d树算法。假设有6个二维数据点{(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)},数据点位于二维空间内(如图1中黑点所示)。k-d树算法就是要确定图1中这些分割空间的分割线(多维空间即为分割平面,一般为超平面)。下面就要通过一步步展示k-d树是如何确定这些分割线的。


由于此例简单,数据维度只有2维,所以可以简单地给x,y两个方向轴编号为0,1,也即split={0,1}。

  (1)确定split域的首先该取的值。分别计算x,y方向上数据的方差得知x方向上的方差最大,所以split域值首先取0,也就是x轴方向;

  (2)确定Node-data的域值。根据x轴方向的值2,5,9,4,8,7排序选出中值为7,所以Node-data = (7,2)。这样,该节点的分割超平面就是通过(7,2)并垂直于split = 0(x轴)的直线x = 7;

  (3)确定左子空间和右子空间。分割超平面x = 7将整个空间分为两部分,如图2所示。x < =  7的部分为左子空间,包含3个节点{(2,3),(5,4),(4,7)};另一部分为右子空间,包含2个节点{(9,6),(8,1)}。


如算法所述,k-d树的构建是一个递归的过程。然后对左子空间和右子空间内的数据重复根节点的过程就可以得到下一级子节点(5,4)和(9,6)(也就是左右子空间的'根'节点),同时将空间和数据集进一步细分。如此反复直到空间中只包含一个数据点,如图1所示。最后生成的k-d树如图3所示。




从上面的表也可以看出k-d tree本质上是一种二叉树,因此k-d tree的构建是一个逐级展开的递归过程。

[cpp] view plain copy
  1. 算法:createKDTree 构建一棵k-d tree   
  2.    
  3. 输入:exm_set 样本集   
  4.    
  5. 输出 : Kd, 类型为kd-tree   
  6.    
  7. 1. 如果exm_set是空的,则返回空的kd-tree   
  8.    
  9. 2.调用分裂结点选择程序(输入是exm_set),返回两个值   
  10.    
  11.        dom_elt:= exm_set中的一个样本点   
  12.    
  13.        split := 分裂维的序号   
  14.    
  15. 3.exm_set_left = {exm∈exm_set – dom_elt && exm[split] <= dom_elt[split]}   
  16.    
  17.    exm_set_right = {exm∈exm_set – dom_elt && exm[split] > dom_elt[split]}   
  18.    
  19. 4.left = createKDTree(exm_set_left)   
  20.    
  21. right = createKDTree(exm_set_right)   

k-d tree最近邻搜索算法

  如前所述,在k-d tree树中进行数据的k近邻搜索是特征匹配的重要环节,其目的是检索在k-d tree中与待查询点距离最近的k个数据点。

  最近邻搜索是k近邻的特例,也就是1近邻。将1近邻改扩展到k近邻非常容易。下面介绍最简单的k-d tree最近邻搜索算法。

  基本的思路很简单:首先通过二叉树搜索(比较待查询节点和分裂节点的分裂维的值,小于等于就进入左子树分支,等于就进入右子树分支直到叶子结点),顺着“搜索路径”很快能找到最近邻的近似点,也就是与待查询点处于同一个子空间的叶子结点;然后再回溯搜索路径,并判断搜索路径上的结点的其他子结点空间中是否可能有距离查询点更近的数据点,如果有可能,则需要跳到其他子结点空间中去搜索(将其他子结点加入到搜索路径)。重复这个过程直到搜索路径为空。下面给出k-d tree最近邻搜索的伪代码:

[cpp] view plain copy
  1. 算法:kdtreeFindNearest /* k-d tree的最近邻搜索 */   
  2.    
  3. 输入:Kd /* k-d tree类型*/   
  4.    
  5. target /* 待查询数据点 */   
  6.    
  7. 输出 : nearest /* 最近邻数据结点 */   
  8.    
  9. dist /* 最近邻和查询点的距离 */   
  10.    
  11. 1. 如果Kd是空的,则设dist为无穷大返回   
  12.    
  13. 2. 向下搜索直到叶子结点   
  14.    
  15. pSearch = &Kd   
  16.    
  17. while(pSearch != NULL)    
  18. {    
  19. pSearch加入到search_path中;    
  20. if(target[pSearch->split] <= pSearch->dom_elt[pSearch->split]) /* 如果小于就进入左子树 */    
  21. {    
  22. pSearch = pSearch->left;    
  23. }    
  24. else    
  25. {    
  26. pSearch = pSearch->right;    
  27. }    
  28. }    
  29. 取出search_path最后一个赋给nearest   
  30.    
  31. dist = Distance(nearest, target);    
  32. 3. 回溯搜索路径   
  33.    
  34. while(search_path不为空)    
  35. {    
  36. 取出search_path最后一个结点赋给pBack   
  37.    
  38. if(pBack->left为空 && pBack->right为空) /* 如果pBack为叶子结点 */   
  39.    
  40. {   
  41.    
  42. if( Distance(nearest, target) > Distance(pBack->dom_elt, target) )    
  43. {    
  44. nearest = pBack->dom_elt;    
  45. dist = Distance(pBack->dom_elt, target);    
  46. }   
  47.    
  48. }   
  49.    
  50. else   
  51.    
  52. {   
  53.    
  54. s = pBack->split;    
  55. if( abs(pBack->dom_elt[s] - target[s]) < dist) /* 如果以target为中心的圆(球或超球),半径为dist的圆与分割超平面相交, 那么就要跳到另一边的子空间去搜索 */    
  56. {    
  57. if( Distance(nearest, target) > Distance(pBack->dom_elt, target) )    
  58. {    
  59. nearest = pBack->dom_elt;    
  60. dist = Distance(pBack->dom_elt, target);    
  61. }    
  62. if(target[s] <= pBack->dom_elt[s]) /* 如果target位于pBack的左子空间,那么就要跳到右子空间去搜索 */    
  63. pSearch = pBack->right;    
  64. else    
  65. pSearch = pBack->left; /* 如果target位于pBack的右子空间,那么就要跳到左子空间去搜索 */    
  66. if(pSearch != NULL)    
  67. pSearch加入到search_path中    
  68. }   
  69.    
  70. }    
  71. }   

假设我们的k-d tree就是上面通过样本集{(2,3), (5,4), (9,6), (4,7), (8,1), (7,2)}创建的。

我们来查找点(2.1,3.1),在(7,2)点测试到达(5,4),在(5,4)点测试到达(2,3),然后search_path中的结点为<(7,2), (5,4), (2,3)>,从search_path中取出(2,3)作为当前最佳结点nearest, dist为0.141;

然后回溯至(5,4),以(2.1,3.1)为圆心,以dist=0.141为半径画一个圆,并不和超平面y=4相交,如下图,所以不必跳到结点(5,4)的右子空间去搜索,因为右子空间中不可能有更近样本点了。

于是在回溯至(7,2),同理,以(2.1,3.1)为圆心,以dist=0.141为半径画一个圆并不和超平面x=7相交,所以也不用跳到结点(7,2)的右子空间去搜索

至此,search_path为空,结束整个搜索,返回nearest(2,3)作为(2.1,3.1)的最近邻点,最近距离为0.141。



再举一个稍微复杂的例子,我们来查找点(2,4.5),在(7,2)处测试到达(5,4),在(5,4)处测试到达(4,7),然后search_path中的结点为<(7,2), (5,4), (4,7)>,从search_path中取出(4,7)作为当前最佳结点nearest, dist为3.202;

然后回溯至(5,4),以(2,4.5)为圆心,以dist=3.202为半径画一个圆与超平面y=4相交,如下图,所以需要跳到(5,4)的左子空间去搜索。所以要将(2,3)加入到search_path中,现在search_path中的结点为<(7,2), (2, 3)>;另外,(5,4)与(2,4.5)的距离为3.04 < dist = 3.202,所以将(5,4)赋给nearest,并且dist=3.04。

回溯至(2,3),(2,3)是叶子节点,直接平判断(2,3)是否离(2,4.5)更近,计算得到距离为1.5,所以nearest更新为(2,3),dist更新为(1.5)

回溯至(7,2),同理,以(2,4.5)为圆心,以dist=1.5为半径画一个圆并不和超平面x=7相交, 所以不用跳到结点(7,2)的右子空间去搜索

至此,search_path为空,结束整个搜索,返回nearest(2,3)作为(2,4.5)的最近邻点,最近距离为1.5。



  以下是k-d树的c++代码实现,包括建树过程和搜索过程。算法main函数输入k-d树训练实例点,算法会完成建树操作,随后可以输入待查询的目标点,程序将会搜索K-d树找出与输入目标点最近邻的训练实例点。本程序只实现了1近邻搜索,如果要实现k近邻搜索,只需对程序稍作修改。比如可以对每个结点添加一个标记,如果已经输出该结点为最近邻结点,那么就继续查找次近邻的结点,直到输出k个结点后算法结束。
[cpp] view plain copy
  1. #include <iostream>    
  2. #include <algorithm>    
  3. #include <stack>    
  4. #include <math.h>    
  5. using namespace std;    
  6. /*function of this program: build a 2d tree using the input training data  
  7.  the input is exm_set which contains a list of tuples (x,y)  
  8.  the output is a 2d tree pointer*/    
  9.     
  10.     
  11. struct data    
  12. {    
  13.     double x = 0;    
  14.     double y = 0;    
  15. };    
  16.     
  17. struct Tnode    
  18. {    
  19.     struct data dom_elt;    
  20.     int split;    
  21.     struct Tnode * left;    
  22.     struct Tnode * right;    
  23. };    
  24.     
  25. bool cmp1(data a, data b){    
  26.     return a.x < b.x;    
  27. }    
  28.     
  29. bool cmp2(data a, data b){    
  30.     return a.y < b.y;    
  31. }    
  32.     
  33. bool equal(data a, data b){    
  34.     if (a.x == b.x && a.y == b.y)    
  35.     {    
  36.         return true;    
  37.     }    
  38.     else{    
  39.         return false;    
  40.     }    
  41. }    
  42.     
  43. void ChooseSplit(data exm_set[], int size, int &split, data &SplitChoice){    
  44.     /*compute the variance on every dimension. Set split as the dismension that have the biggest  
  45.      variance. Then choose the instance which is the median on this split dimension.*/    
  46.     /*compute variance on the x,y dimension. DX=EX^2-(EX)^2*/    
  47.     double tmp1,tmp2;    
  48.     tmp1 = tmp2 = 0;    
  49.     for (int i = 0; i < size; ++i)    
  50.     {    
  51.         tmp1 += 1.0 / (double)size * exm_set[i].x * exm_set[i].x;    
  52.         tmp2 += 1.0 / (double)size * exm_set[i].x;    
  53.     }    
  54.     double v1 = tmp1 - tmp2 * tmp2;  //compute variance on the x dimension    
  55.         
  56.     tmp1 = tmp2 = 0;    
  57.     for (int i = 0; i < size; ++i)    
  58.     {    
  59.         tmp1 += 1.0 / (double)size * exm_set[i].y * exm_set[i].y;    
  60.         tmp2 += 1.0 / (double)size * exm_set[i].y;    
  61.     }    
  62.     double v2 = tmp1 - tmp2 * tmp2;  //compute variance on the y dimension    
  63.         
  64.     split = v1 > v2 ? 0:1; //set the split dimension    
  65.         
  66.     if (split == 0)    
  67.     {    
  68.         sort(exm_set,exm_set + size, cmp1);    
  69.     }    
  70.     else{    
  71.         sort(exm_set,exm_set + size, cmp2);    
  72.     }    
  73.         
  74.     //set the split point value    
  75.     SplitChoice.x = exm_set[size / 2].x;    
  76.     SplitChoice.y = exm_set[size / 2].y;    
  77.         
  78. }    
  79.     
  80. Tnode* build_kdtree(data exm_set[], int size, Tnode* T){    
  81.     //call function ChooseSplit to choose the split dimension and split point    
  82.     if (size == 0){    
  83.         return NULL;    
  84.     }    
  85.     else{    
  86.         int split;    
  87.         data dom_elt;    
  88.         ChooseSplit(exm_set, size, split, dom_elt);    
  89.         data exm_set_right [100];    
  90.         data exm_set_left [100];    
  91.         int sizeleft ,sizeright;    
  92.         sizeleft = sizeright = 0;    
  93.             
  94.         if (split == 0)    
  95.         {    
  96.             for (int i = 0; i < size; ++i)    
  97.             {    
  98.                     
  99.                 if (!equal(exm_set[i],dom_elt) && exm_set[i].x <= dom_elt.x)    
  100.                 {    
  101.                     exm_set_left[sizeleft].x = exm_set[i].x;    
  102.                     exm_set_left[sizeleft].y = exm_set[i].y;    
  103.                     sizeleft++;    
  104.                 }    
  105.                 else if (!equal(exm_set[i],dom_elt) && exm_set[i].x > dom_elt.x)    
  106.                 {    
  107.                     exm_set_right[sizeright].x = exm_set[i].x;    
  108.                     exm_set_right[sizeright].y = exm_set[i].y;    
  109.                     sizeright++;    
  110.                 }    
  111.             }    
  112.         }    
  113.         else{    
  114.             for (int i = 0; i < size; ++i)    
  115.             {    
  116.                     
  117.                 if (!equal(exm_set[i],dom_elt) && exm_set[i].y <= dom_elt.y)    
  118.                 {    
  119.                     exm_set_left[sizeleft].x = exm_set[i].x;    
  120.                     exm_set_left[sizeleft].y = exm_set[i].y;    
  121.                     sizeleft++;    
  122.                 }    
  123.                 else if (!equal(exm_set[i],dom_elt) && exm_set[i].y > dom_elt.y)    
  124.                 {    
  125.                     exm_set_right[sizeright].x = exm_set[i].x;    
  126.                     exm_set_right[sizeright].y = exm_set[i].y;    
  127.                     sizeright++;    
  128.                 }    
  129.             }    
  130.         }    
  131.         T = new Tnode;    
  132.         T->dom_elt.x = dom_elt.x;    
  133.         T->dom_elt.y = dom_elt.y;    
  134.         T->split = split;    
  135.         T->left = build_kdtree(exm_set_left, sizeleft, T->left);    
  136.         T->right = build_kdtree(exm_set_right, sizeright, T->right);    
  137.         return T;    
  138.             
  139.     }    
  140. }    
  141.     
  142.     
  143. double Distance(data a, data b){    
  144.     double tmp = (a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y);    
  145.     return sqrt(tmp);    
  146. }    
  147.     
  148.     
  149. void searchNearest(Tnode * Kd, data target, data &nearestpoint, double & distance){    
  150.         
  151.     //1. 如果Kd是空的,则设dist为无穷大返回    
  152.         
  153.     //2. 向下搜索直到叶子结点    
  154.         
  155.     stack<Tnode*> search_path;    
  156.     Tnode* pSearch = Kd;    
  157.     data nearest;    
  158.     double dist;    
  159.         
  160.     while(pSearch != NULL)    
  161.     {    
  162.         //pSearch加入到search_path中;    
  163.         search_path.push(pSearch);    
  164.             
  165.         if (pSearch->split == 0)    
  166.         {    
  167.             if(target.x <= pSearch->dom_elt.x) /* 如果小于就进入左子树 */    
  168.             {    
  169.                 pSearch = pSearch->left;    
  170.             }    
  171.             else    
  172.             {    
  173.                 pSearch = pSearch->right;    
  174.             }    
  175.         }    
  176.         else{    
  177.             if(target.y <= pSearch->dom_elt.y) /* 如果小于就进入左子树 */    
  178.             {    
  179.                 pSearch = pSearch->left;    
  180.             }    
  181.             else    
  182.             {    
  183.                 pSearch = pSearch->right;    
  184.             }    
  185.         }    
  186.     }    
  187.     //取出search_path最后一个赋给nearest    
  188.     nearest.x = search_path.top()->dom_elt.x;    
  189.     nearest.y = search_path.top()->dom_elt.y;    
  190.     search_path.pop();    
  191.         
  192.         
  193.     dist = Distance(nearest, target);    
  194.     //3. 回溯搜索路径    
  195.         
  196.     Tnode* pBack;    
  197.         
  198.     while(search_path.size() != 0)    
  199.     {    
  200.         //取出search_path最后一个结点赋给pBack    
  201.         pBack = search_path.top();    
  202.         search_path.pop();    
  203.             
  204.         if(pBack->left == NULL && pBack->right == NULL) /* 如果pBack为叶子结点 */    
  205.                 
  206.         {    
  207.                 
  208.             if( Distance(nearest, target) > Distance(pBack->dom_elt, target) )    
  209.             {    
  210.                 nearest = pBack->dom_elt;    
  211.                 dist = Distance(pBack->dom_elt, target);    
  212.             }    
  213.                 
  214.         }    
  215.             
  216.         else    
  217.                 
  218.         {    
  219.                 
  220.             int s = pBack->split;    
  221.             if (s == 0)    
  222.             {    
  223.                 if( fabs(pBack->dom_elt.x - target.x) < dist) /* 如果以target为中心的圆(球或超球),半径为dist的圆与分割超平面相交, 那么就要跳到另一边的子空间去搜索 */    
  224.                 {    
  225.                     if( Distance(nearest, target) > Distance(pBack->dom_elt, target) )    
  226.                     {    
  227.                         nearest = pBack->dom_elt;    
  228.                         dist = Distance(pBack->dom_elt, target);    
  229.                     }    
  230.                     if(target.x <= pBack->dom_elt.x) /* 如果target位于pBack的左子空间,那么就要跳到右子空间去搜索 */    
  231.                         pSearch = pBack->right;    
  232.                     else    
  233.                         pSearch = pBack->left; /* 如果target位于pBack的右子空间,那么就要跳到左子空间去搜索 */    
  234.                     if(pSearch != NULL)    
  235.                         //pSearch加入到search_path中    
  236.                         search_path.push(pSearch);    
  237.                 }    
  238.             }    
  239.             else {    
  240.                 if( fabs(pBack->dom_elt.y - target.y) < dist) /* 如果以target为中心的圆(球或超球),半径为dist的圆与分割超平面相交, 那么就要跳到另一边的子空间去搜索 */    
  241.                 {    
  242.                     if( Distance(nearest, target) > Distance(pBack->dom_elt, target) )    
  243.                     {    
  244.                         nearest = pBack->dom_elt;    
  245.                         dist = Distance(pBack->dom_elt, target);    
  246.                     }    
  247.                     if(target.y <= pBack->dom_elt.y) /* 如果target位于pBack的左子空间,那么就要跳到右子空间去搜索 */    
  248.                         pSearch = pBack->right;    
  249.                     else    
  250.                         pSearch = pBack->left; /* 如果target位于pBack的右子空间,那么就要跳到左子空间去搜索 */    
  251.                     if(pSearch != NULL)    
  252.                        // pSearch加入到search_path中    
  253.                         search_path.push(pSearch);    
  254.                 }    
  255.             }    
  256.                 
  257.         }    
  258.     }    
  259.         
  260.     nearestpoint.x = nearest.x;    
  261.     nearestpoint.y = nearest.y;    
  262.     distance = dist;    
  263.         
  264. }    
  265.     
  266. int main(){    
  267.     data exm_set[100]; //assume the max training set size is 100    
  268.     double x,y;    
  269.     int id = 0;    
  270.     cout<<"Please input the training data in the form x y. One instance per line. Enter -1 -1 to stop."<<endl;    
  271.     while (cin>>x>>y){    
  272.         if (x == -1)    
  273.         {    
  274.             break;    
  275.         }    
  276.         else{    
  277.             exm_set[id].x = x;    
  278.             exm_set[id].y = y;    
  279.             id++;    
  280.         }    
  281.     }    
  282.     struct Tnode * root = NULL;    
  283.     root = build_kdtree(exm_set, id, root);    
  284.         
  285.     data nearestpoint;    
  286.     double distance;    
  287.     data target;    
  288.     cout <<"Enter search point"<<endl;    
  289.     while (cin>>target.x>>target.y)    
  290.     {    
  291.         searchNearest(root, target, nearestpoint, distance);    
  292.         cout<<"The nearest distance is "<<distance<<",and the nearest point is "<<nearestpoint.x<<","<<nearestpoint.y<<endl;    
  293.         cout <<"Enter search point"<<endl;    
  294.     
  295.     }    
  296. }  

0 0
原创粉丝点击