K-D树

来源:互联网 发布:淘宝详情模板下载 编辑:程序博客网 时间:2024/05/23 05:09

http://blog.csdn.net/acdreamers/article/details/44664645

K-D树,即K-Dimensional Tree,是一种高维索引树型数据结构。常用于大规模高维数据空间的最邻近或者K邻

近查找,例如图像检索中高维图像特征向量的K邻近匹配,对KNN算法的优化等。

 

Contents

 

   1. K-D树的基本原理

   2. K-D树的改进(BBF算法)

   3. K-D树的C++实现

   4. K-D树的开源框架介绍

 

 

1. K-D树的基本原理

 

   K-D树实际上是一棵高维二叉搜索树,与普通二叉搜索树不同的是,树中存储的是一些K维数据。先回忆一下二

   叉搜索树(BST),它是一棵具有如下性质的树

 

   (1)若它的左子树不为空,那么左子树上所有节点的值均小于它的根节点的值。

   (2)若它的右子树不为空。那么右子树上所有节点的值均大于它的根节点的值。

   (3)它的左右子树也分别是一棵二叉搜索树。

 

   二叉搜索树在建树时,按照上述规则分别插入即可。而在搜索时,从根节点开始往下查找。可以看出二叉搜索

   树的建树平均时间复杂度为,最坏时间复杂度为,查找的平均时间复杂度为

   最坏时间复杂度为,由于二叉搜索树不是平衡的,可能退化为一条链,这种情况就是最坏情况了。

 

   普通的二叉搜索树是一维的,当推广到K维后,就是我们的K-D树了。在K-D树中跟二叉搜索树差不多,也是将

   一个K维的数据与根节点进行比较,然后划分的,这里的比较不是整体的比较,而是选择其中一个维度来进行比

   较。那么在K-D树中我们需要解决两个重要的问题

 

   (1)每一次划分时,应该选择哪个维度?

   (2)在某个维度上划分时,如何保证左右子树节点个数尽量相等?

 

   首先来看问题(1)每次划分时,应该选择哪个维度 ?

 

   最简单的做法就是一个维度一个维度轮流着来,但是仔细想想,这种方法不能很好地解决问题。假设有这样一

   种情况:我们需要切一个豆腐条,长度要远远大于宽度,要想把它切成尽量相同的小块,显然是先按照长度来

   切,这样更合理,如果宽度比较窄,那么这种效果更明显。所以在K-D树中,每次选取属性跨度最大的那个来

   进行划分,而衡量这个跨度的标准是什么? 无论是从数学上还是人的直观感受方面来说,如果某个属性的跨度

   越大,也就是说越分散,那么这组数据的方差就越大,所以在K-D树进行划分时,可以每次选择方差最大的属性

   来划分数据到左右子树。

 

   问题(1)已解决,现在再来看问题(2),在某个维度上划分时,如何保证左右子树节点个数尽量相等?

 

   当我们选择好划分的属性时,还要根据某个值来进行左右子树划分,而这个值就是一个划分轴,回忆一下,在快

   速排序算法中,也有一个划分轴pivot。在K-D树的划分中,这个轴的选取很关键,要保证划分后的左右子树尽

   量平衡,那么很显然选取这个属性的值对应数组的中位数作为pivot,就能保证这一点了。

 

   这样就解决了K-D树中最重要的两个问题。接下来看K-D树是如何进行查找的。

 

   假设现在已经构造好了一棵K-D树,最邻近查找的算法描述如下

 

   (1)将查询数据Q从根节点开始,按照Q与各个节点的比较结果向下遍历,直到到达叶子节点为止。到达叶子节

       点时,计算Q与叶子节点上保存的所有数据之间的距离,记录最小距离对应的数据点,假设当前最邻近点为

       p_cur,最小距离记为d_cur。

   (2)进行回溯操作,该操作的目的是找离Q更近的数据点,即在未访问过的分支里,是否还有离Q更近的点,它

       们的距离小于d_cur。

 

   以上就是K-D树的基本原理。

 

 

2. K-D树的改进(BBF算法)

 

   上述中的K-D树存在缺点,当维数比较大的时候,建树后的分支自然会增多,进而回溯的次数增加,算法效率会

   随之降低。在图像检索中,特征往往是高维的,很有必要对K-D树算法进行改进,这就是即将要介绍的BBF算法。

 

   BBF算法我就不详细说了,具体可以参考如下两篇文章

 

  (1)Kd-Tree算法原理和开源实现代码

   (2)从K近邻算法、距离度量谈到KD树、SIFT+BBF算法

 

 

3. K-D树的C++实现

 

   以HDU4347为例,给出K-D树的C++的简易代码。 题目:The Closest M Points

 

   代码:

[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. #include <iostream>  
  2. #include <string.h>  
  3. #include <algorithm>  
  4. #include <stdio.h>  
  5. #include <math.h>  
  6. #include <queue>  
  7.    
  8. using namespace std;  
  9.    
  10. #define N 50005  
  11.    
  12. #define lson rt << 1  
  13. #define rson rt << 1 | 1  
  14. #define Pair pair<double, Node>  
  15. #define Sqrt2(x) (x) * (x)  
  16.    
  17. int n, k, idx;  
  18.    
  19. struct Node  
  20. {  
  21.     int feature[5];     //定义属性数组  
  22.     bool operator < (const Node &u) const  
  23.     {  
  24.         return feature[idx] < u.feature[idx];  
  25.     }  
  26. }_data[N];   //_data[]数组代表输入的数据  
  27.    
  28. priority_queue<Pair> Q;     //队列Q用于存放离p最近的m个数据  
  29.    
  30. class KDTree{  
  31.    
  32.     public:  
  33.         void Build(intintintint);     //建树  
  34.         void Query(Node, intintint);    //查询  
  35.    
  36.     private:  
  37.         Node data[4 * N];    //data[]数组代表K-D树的所有节点数据  
  38.         int flag[4 * N];      //用于标记某个节点是否存在,1表示存在,-1表示不存在  
  39. }kd;  
  40.    
  41. //建树步骤,参数dept代表树的深度  
  42. void KDTree::Build(int l, int r, int rt, int dept)  
  43. {  
  44.     if(l > r) return;  
  45.     flag[rt] = 1;                   //表示编号为rt的节点存在  
  46.     flag[lson] = flag[rson] = -1;   //当前节点的孩子暂时标记不存在  
  47.     idx = dept % k;                 //按照编号为idx的属性进行划分  
  48.     int mid = (l + r) >> 1;  
  49.     nth_element(_data + l, _data + mid, _data + r + 1);   //nth_element()为STL中的函数  
  50.     data[rt] = _data[mid];  
  51.     Build(l, mid - 1, lson, dept + 1);  //递归左子树  
  52.     Build(mid + 1, r, rson, dept + 1);  //递归右子树  
  53. }  
  54.    
  55. //查询函数,寻找离p最近的m个特征属性  
  56. void KDTree::Query(Node p, int m, int rt, int dept)  
  57. {  
  58.     if(flag[rt] == -1) return;   //不存在的节点不遍历  
  59.     Pair cur(0, data[rt]);       //获取当前节点的数据和到p的距离  
  60.     for(int i = 0; i < k; i++)  
  61.         cur.first += Sqrt2(cur.second.feature[i] - p.feature[i]);  
  62.     int dim = dept % k;          //跟建树一样,这样能保证相同节点的dim值不变  
  63.     bool fg = 0;                 //用于标记是否需要遍历右子树  
  64.     int x = lson;  
  65.     int y = rson;  
  66.     if(p.feature[dim] >= data[rt].feature[dim]) //数据p的第dim个特征值大于等于当前的数据,则需要进入右子树  
  67.         swap(x, y);  
  68.     if(~flag[x]) Query(p, m, x, dept + 1);      //如果节点x存在,则进入子树继续遍历  
  69.    
  70.     //以下是回溯过程,维护一个优先队列  
  71.     if(Q.size() < m)   //如果队列没有满,则继续放入  
  72.     {  
  73.         Q.push(cur);  
  74.         fg = 1;  
  75.     }  
  76.     else  
  77.     {  
  78.         if(cur.first < Q.top().first)  //如果找到更小的距离,则用于替换队列Q中最大的距离的数据  
  79.         {  
  80.             Q.pop();  
  81.             Q.push(cur);  
  82.         }  
  83.         if(Sqrt2(p.feature[dim] - data[rt].feature[dim]) < Q.top().first)  
  84.         {  
  85.             fg = 1;  
  86.         }  
  87.     }  
  88.     if(~flag[y] && fg)   
  89.         Query(p, m, y, dept + 1);  
  90. }  
  91.    
  92. //输出结果  
  93. void Print(Node data)  
  94. {  
  95.     for(int i = 0; i < k; i++)  
  96.         printf("%d%c", data.feature[i], i == k - 1 ? '\n' : ' ');  
  97. }  
  98.    
  99. int main()  
  100. {  
  101.     while(scanf("%d%d", &n, &k)!=EOF)  
  102.     {  
  103.         for(int i = 0; i < n; i++)  
  104.             for(int j = 0; j < k; j++)  
  105.                 scanf("%d", &_data[i].feature[j]);  
  106.         kd.Build(0, n - 1, 1, 0);  
  107.         int t, m;  
  108.         scanf("%d", &t);  
  109.         while(t--)  
  110.         {  
  111.             Node p;  
  112.             for(int i = 0; i < k; i++)  
  113.                 scanf("%d", &p.feature[i]);  
  114.             scanf("%d", &m);  
  115.             while(!Q.empty()) Q.pop();   //事先需要清空优先队列  
  116.             kd.Query(p, m, 1, 0);  
  117.             printf("the closest %d points are:\n", m);  
  118.             Node tmp[25];  
  119.             for(int i = 0; !Q.empty(); i++)  
  120.             {  
  121.                 tmp[i] = Q.top().second;  
  122.                 Q.pop();  
  123.             }  
  124.             for(int i = m - 1; i >= 0; i--)  
  125.                 Print(tmp[i]);  
  126.         }  
  127.     }  
  128.     return 0;  
  129. }  

 

题目:http://acm.hdu.edu.cn/showproblem.php?pid=2966 

 

题意:给定n个二维点,求每个点距离其它点的最近的距离。其中n <= 100000。

 

代码:

[java] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. import java.util.Arrays;  
  2. import java.util.Scanner;  
  3.   
  4. public class Main {  
  5.       
  6.     final static int SIZE = 100005;  
  7.     final static double EPS = 1e-10;  
  8.       
  9.     private boolean[] d = null;  
  10.     private Node[] p = null;  
  11.     private long res;  
  12.     private int index;  
  13.     private int size;  
  14.       
  15.     public class Node{  
  16.         private long[] x = null;  
  17.         Node(){  
  18.             x = new long[2];  
  19.         }  
  20.     }  
  21.       
  22.     Main(int size){  
  23.         d = new boolean[size];  
  24.         p = new Node[size];  
  25.         for(int i = 0; i < size; i++)  
  26.             p[i] = new Node();  
  27.     }  
  28.   
  29.     public void setSize(int size){  
  30.         this.size = size;  
  31.         Arrays.fill(d, false);  
  32.     }  
  33.       
  34.     public void clear(){  
  35.         res = Long.MAX_VALUE;  
  36.         index = 0;  
  37.     }  
  38.       
  39.     public void Insert(int id, Node t){  
  40.         p[id] = t;  
  41.     }  
  42.       
  43.     public Node get(int id){  
  44.         return p[id];  
  45.     }  
  46.   
  47.     public void InsertSort(Node a[], int id, int l, int r){  
  48.         for(int i = l + 1; i <= r; i++){  
  49.             if(a[i - 1].x[id] > a[i].x[id]){  
  50.                 Node t = new Node();  
  51.                 t = a[i];  
  52.                 int j = i;  
  53.                 while(j > l && a[j - 1].x[id] > t.x[id])  
  54.                 {  
  55.                     a[j] = a[j - 1];  
  56.                     j--;  
  57.                 }  
  58.                 a[j] = t;  
  59.             }  
  60.         }  
  61.     }  
  62.       
  63.     public Node FindMid(Node a[], int id, int l, int r)  
  64.     {  
  65.         if(l == r) return a[l];  
  66.         int i = 0;  
  67.         int n = 0;  
  68.         for(i = l; i < r - 5; i += 5)  
  69.         {  
  70.             InsertSort(a, id, i, i + 4);  
  71.             n = i - l;  
  72.   
  73.             Node t = new Node();  
  74.             t = a[l + n / 5];  
  75.             a[l + n / 5] = a[i + 2];  
  76.             a[i + 2] = t;  
  77.         }  
  78.   
  79.         int num = r - i + 1;  
  80.         if(num > 0)  
  81.         {  
  82.             InsertSort(a, id, i, i + num - 1);  
  83.             n = i - l;  
  84.   
  85.             Node t = new Node();  
  86.             t = a[l + n / 5];  
  87.             a[l + n / 5] = a[i + num / 2];  
  88.             a[i + num / 2] = t;  
  89.         }  
  90.         n /= 5;  
  91.         if(n == l) return a[l];  
  92.         return FindMid(a, id, l, l + n);  
  93.     }  
  94.       
  95.     public boolean Equals(Node a, Node b){  
  96.         if(Math.abs(a.x[0] - b.x[0]) > EPS)   
  97.             return false;  
  98.         if(Math.abs(a.x[1] - b.x[1]) > EPS)   
  99.             return false;  
  100.         return true;  
  101.     }  
  102.       
  103.     public int FindId(Node a[], int l, int r, Node num)  
  104.     {  
  105.         for(int i = l; i <= r; i++)  
  106.             if(Equals(a[i], num))  
  107.                 return i;  
  108.         return -1;  
  109.     }  
  110.       
  111.     public int Partion(Node a[], int id, int l, int r, int p)  
  112.     {  
  113.         Node t = new Node();  
  114.         t = a[p];  
  115.         a[p] = a[l];  
  116.         a[l] = t;  
  117.   
  118.         int i = l;  
  119.         int j = r;  
  120.         Node pivot = a[l];  
  121.         while(i < j)  
  122.         {  
  123.             while(a[j].x[id] >= pivot.x[id] && i < j)  
  124.                 j--;  
  125.             a[i] = a[j];  
  126.             while(a[i].x[id] <= pivot.x[id] && i < j)  
  127.                 i++;  
  128.             a[j] = a[i];  
  129.         }  
  130.         a[i] = pivot;  
  131.         return i;  
  132.     }  
  133.       
  134.     public Node BFPTR(Node a[], int id, int l, int r, int k)  
  135.     {  
  136.         if(l > r) return null;  
  137.         Node num = FindMid(a, id, l, r);    
  138.         int p =  FindId(a, l, r, num);   
  139.         int i = Partion(a, id, l, r, p);  
  140.   
  141.         int m = i - l + 1;  
  142.         if(m == k) return a[i];  
  143.         if(m > k)  return BFPTR(a, id, l, i - 1, k);  
  144.         return BFPTR(a, id, i + 1, r, k - m);  
  145.     }  
  146.   
  147.     public Node getInterval(Node p[], int id, int l, int r){  
  148.         Node t = new Node();  
  149.         long max = Long.MIN_VALUE;  
  150.         long min = Long.MAX_VALUE;  
  151.         for(int i = l; i <= r; i++){  
  152.             if(max < p[i].x[id]) max = p[i].x[id];  
  153.             if(min > p[i].x[id]) min = p[i].x[id];  
  154.         }  
  155.         t.x[0] = min;  
  156.         t.x[1] = max;  
  157.         return t;  
  158.     }  
  159.       
  160.     public long getDist(Node a, Node b){  
  161.         return (a.x[0] - b.x[0]) * (a.x[0] - b.x[0]) + (a.x[1] - b.x[1]) * (a.x[1] - b.x[1]);  
  162.     }  
  163.   
  164.     public void Build(Node p[], int l, int r){  
  165.         if(l > r) return;  
  166.         Node t1 = getInterval(p, 0, l, r);  
  167.         long minx = t1.x[0];  
  168.         long maxx = t1.x[1];  
  169.           
  170.         Node t2 = getInterval(p, 1, l, r);  
  171.         long miny = t2.x[0];  
  172.         long maxy = t2.x[1];  
  173.           
  174.         int mid = (l + r) >> 1;  
  175.         d[mid] = (maxx - minx > maxy - miny);  
  176.           
  177.         BFPTR(p, d[mid] ? 0 : 1, l, r, mid - l + 1);  
  178.           
  179.         Build(p, l, mid - 1);  
  180.         Build(p, mid + 1, r);  
  181.     }  
  182.       
  183.     public void Find(Node p[], Node t, int l, int r){  
  184.         if(l > r) return;  
  185.         int mid = (l + r) >> 1;  
  186.           
  187.         long dist = getDist(p[mid], t);  
  188.         long df = d[mid] ? (t.x[0] - p[mid].x[0]) : (t.x[1] - p[mid].x[1]);  
  189.           
  190.         if(dist > 0 && dist < res){  
  191.             res = dist;  
  192.             index = mid;  
  193.         }  
  194.           
  195.         int l1 = l;  
  196.         int r1 = mid - 1;  
  197.         int l2 = mid + 1;  
  198.         int r2 = r;  
  199.         if (df > 0){  
  200.   
  201.             l1 ^= l2;  
  202.             l2 ^= l1;  
  203.             l1 ^= l2;  
  204.               
  205.             r1 ^= r2;  
  206.             r2 ^= r1;  
  207.             r1 ^= r2;  
  208.         }  
  209.         Find(p, t, l1, r1);  
  210.         if (df * df < res) Find(p, t, l2, r2);  
  211.     }  
  212.       
  213.     public void Build(){  
  214.         Build(p, 0, size - 1);  
  215.     }  
  216.       
  217.     public int Search(Node t){  
  218.         clear();  
  219.         Find(p, t, 0, size - 1);  
  220.         return index;  
  221.     }  
  222.       
  223.     public static void main(String[] args){  
  224.   
  225.         Scanner cin = new Scanner(System.in);  
  226.         int t = cin.nextInt();  
  227.         Main kd = new Main(SIZE);  
  228.   
  229.         Node[] node = new Node[SIZE];  
  230.         for(int i = 0; i < SIZE; i++){  
  231.              node[i] = kd.new Node();  
  232.         }  
  233.   
  234.         while(t-- > 0){  
  235.             int n = cin.nextInt();  
  236.             kd.setSize(n);  
  237.             for(int i = 0; i < n; i++){  
  238.                 node[i].x[0] = cin.nextLong();  
  239.                 node[i].x[1] = cin.nextLong();  
  240.                 kd.Insert(i, node[i]);  
  241.             }  
  242.   
  243.             kd.Build();  
  244.             for(int i = 0; i < n; i++){  
  245.                 int id = kd.Search(node[i]);  
  246.                 System.out.println(kd.getDist(kd.get(id), node[i]));  
  247.             }  
  248.         }  
  249.           
  250.     }  
  251. }  

4. K-D树的开源框架介绍

 

   K-D树的一个比较好的C++框架可以戳这里。下载后,可以参考里面的examples文件夹中的代码学习使用。




0 0
原创粉丝点击