K-D树
来源:互联网 发布:淘宝详情模板下载 编辑:程序博客网 时间:2024/05/23 05:09
http://blog.csdn.net/acdreamers/article/details/44664645
K-D树,即K-Dimensional Tree,是一种高维索引树型数据结构。常用于大规模高维数据空间的最邻近或者K邻
近查找,例如图像检索中高维图像特征向量的K邻近匹配,对KNN算法的优化等。
Contents
1. K-D树的基本原理
2. K-D树的改进(BBF算法)
3. K-D树的C++实现
4. K-D树的开源框架介绍
1. K-D树的基本原理
K-D树实际上是一棵高维二叉搜索树,与普通二叉搜索树不同的是,树中存储的是一些K维数据。先回忆一下二
叉搜索树(BST),它是一棵具有如下性质的树
(1)若它的左子树不为空,那么左子树上所有节点的值均小于它的根节点的值。
(2)若它的右子树不为空。那么右子树上所有节点的值均大于它的根节点的值。
(3)它的左右子树也分别是一棵二叉搜索树。
二叉搜索树在建树时,按照上述规则分别插入即可。而在搜索时,从根节点开始往下查找。可以看出二叉搜索
树的建树平均时间复杂度为,最坏时间复杂度为,查找的平均时间复杂度为,
最坏时间复杂度为,由于二叉搜索树不是平衡的,可能退化为一条链,这种情况就是最坏情况了。
普通的二叉搜索树是一维的,当推广到K维后,就是我们的K-D树了。在K-D树中跟二叉搜索树差不多,也是将
一个K维的数据与根节点进行比较,然后划分的,这里的比较不是整体的比较,而是选择其中一个维度来进行比
较。那么在K-D树中我们需要解决两个重要的问题
(1)每一次划分时,应该选择哪个维度?
(2)在某个维度上划分时,如何保证左右子树节点个数尽量相等?
首先来看问题(1)每次划分时,应该选择哪个维度 ?
最简单的做法就是一个维度一个维度轮流着来,但是仔细想想,这种方法不能很好地解决问题。假设有这样一
种情况:我们需要切一个豆腐条,长度要远远大于宽度,要想把它切成尽量相同的小块,显然是先按照长度来
切,这样更合理,如果宽度比较窄,那么这种效果更明显。所以在K-D树中,每次选取属性跨度最大的那个来
进行划分,而衡量这个跨度的标准是什么? 无论是从数学上还是人的直观感受方面来说,如果某个属性的跨度
越大,也就是说越分散,那么这组数据的方差就越大,所以在K-D树进行划分时,可以每次选择方差最大的属性
来划分数据到左右子树。
问题(1)已解决,现在再来看问题(2),在某个维度上划分时,如何保证左右子树节点个数尽量相等?
当我们选择好划分的属性时,还要根据某个值来进行左右子树划分,而这个值就是一个划分轴,回忆一下,在快
速排序算法中,也有一个划分轴pivot。在K-D树的划分中,这个轴的选取很关键,要保证划分后的左右子树尽
量平衡,那么很显然选取这个属性的值对应数组的中位数作为pivot,就能保证这一点了。
这样就解决了K-D树中最重要的两个问题。接下来看K-D树是如何进行查找的。
假设现在已经构造好了一棵K-D树,最邻近查找的算法描述如下
(1)将查询数据Q从根节点开始,按照Q与各个节点的比较结果向下遍历,直到到达叶子节点为止。到达叶子节
点时,计算Q与叶子节点上保存的所有数据之间的距离,记录最小距离对应的数据点,假设当前最邻近点为
p_cur,最小距离记为d_cur。
(2)进行回溯操作,该操作的目的是找离Q更近的数据点,即在未访问过的分支里,是否还有离Q更近的点,它
们的距离小于d_cur。
以上就是K-D树的基本原理。
2. K-D树的改进(BBF算法)
上述中的K-D树存在缺点,当维数比较大的时候,建树后的分支自然会增多,进而回溯的次数增加,算法效率会
随之降低。在图像检索中,特征往往是高维的,很有必要对K-D树算法进行改进,这就是即将要介绍的BBF算法。
BBF算法我就不详细说了,具体可以参考如下两篇文章
(1)Kd-Tree算法原理和开源实现代码
(2)从K近邻算法、距离度量谈到KD树、SIFT+BBF算法
3. K-D树的C++实现
以HDU4347为例,给出K-D树的C++的简易代码。 题目:The Closest M Points
代码:
- #include <iostream>
- #include <string.h>
- #include <algorithm>
- #include <stdio.h>
- #include <math.h>
- #include <queue>
- using namespace std;
- #define N 50005
- #define lson rt << 1
- #define rson rt << 1 | 1
- #define Pair pair<double, Node>
- #define Sqrt2(x) (x) * (x)
- int n, k, idx;
- struct Node
- {
- int feature[5]; //定义属性数组
- bool operator < (const Node &u) const
- {
- return feature[idx] < u.feature[idx];
- }
- }_data[N]; //_data[]数组代表输入的数据
- priority_queue<Pair> Q; //队列Q用于存放离p最近的m个数据
- class KDTree{
- public:
- void Build(int, int, int, int); //建树
- void Query(Node, int, int, int); //查询
- private:
- Node data[4 * N]; //data[]数组代表K-D树的所有节点数据
- int flag[4 * N]; //用于标记某个节点是否存在,1表示存在,-1表示不存在
- }kd;
- //建树步骤,参数dept代表树的深度
- void KDTree::Build(int l, int r, int rt, int dept)
- {
- if(l > r) return;
- flag[rt] = 1; //表示编号为rt的节点存在
- flag[lson] = flag[rson] = -1; //当前节点的孩子暂时标记不存在
- idx = dept % k; //按照编号为idx的属性进行划分
- int mid = (l + r) >> 1;
- nth_element(_data + l, _data + mid, _data + r + 1); //nth_element()为STL中的函数
- data[rt] = _data[mid];
- Build(l, mid - 1, lson, dept + 1); //递归左子树
- Build(mid + 1, r, rson, dept + 1); //递归右子树
- }
- //查询函数,寻找离p最近的m个特征属性
- void KDTree::Query(Node p, int m, int rt, int dept)
- {
- if(flag[rt] == -1) return; //不存在的节点不遍历
- Pair cur(0, data[rt]); //获取当前节点的数据和到p的距离
- for(int i = 0; i < k; i++)
- cur.first += Sqrt2(cur.second.feature[i] - p.feature[i]);
- int dim = dept % k; //跟建树一样,这样能保证相同节点的dim值不变
- bool fg = 0; //用于标记是否需要遍历右子树
- int x = lson;
- int y = rson;
- if(p.feature[dim] >= data[rt].feature[dim]) //数据p的第dim个特征值大于等于当前的数据,则需要进入右子树
- swap(x, y);
- if(~flag[x]) Query(p, m, x, dept + 1); //如果节点x存在,则进入子树继续遍历
- //以下是回溯过程,维护一个优先队列
- if(Q.size() < m) //如果队列没有满,则继续放入
- {
- Q.push(cur);
- fg = 1;
- }
- else
- {
- if(cur.first < Q.top().first) //如果找到更小的距离,则用于替换队列Q中最大的距离的数据
- {
- Q.pop();
- Q.push(cur);
- }
- if(Sqrt2(p.feature[dim] - data[rt].feature[dim]) < Q.top().first)
- {
- fg = 1;
- }
- }
- if(~flag[y] && fg)
- Query(p, m, y, dept + 1);
- }
- //输出结果
- void Print(Node data)
- {
- for(int i = 0; i < k; i++)
- printf("%d%c", data.feature[i], i == k - 1 ? '\n' : ' ');
- }
- int main()
- {
- while(scanf("%d%d", &n, &k)!=EOF)
- {
- for(int i = 0; i < n; i++)
- for(int j = 0; j < k; j++)
- scanf("%d", &_data[i].feature[j]);
- kd.Build(0, n - 1, 1, 0);
- int t, m;
- scanf("%d", &t);
- while(t--)
- {
- Node p;
- for(int i = 0; i < k; i++)
- scanf("%d", &p.feature[i]);
- scanf("%d", &m);
- while(!Q.empty()) Q.pop(); //事先需要清空优先队列
- kd.Query(p, m, 1, 0);
- printf("the closest %d points are:\n", m);
- Node tmp[25];
- for(int i = 0; !Q.empty(); i++)
- {
- tmp[i] = Q.top().second;
- Q.pop();
- }
- for(int i = m - 1; i >= 0; i--)
- Print(tmp[i]);
- }
- }
- return 0;
- }
题目:http://acm.hdu.edu.cn/showproblem.php?pid=2966
题意:给定n个二维点,求每个点距离其它点的最近的距离。其中n <= 100000。
代码:
- import java.util.Arrays;
- import java.util.Scanner;
- public class Main {
- final static int SIZE = 100005;
- final static double EPS = 1e-10;
- private boolean[] d = null;
- private Node[] p = null;
- private long res;
- private int index;
- private int size;
- public class Node{
- private long[] x = null;
- Node(){
- x = new long[2];
- }
- }
- Main(int size){
- d = new boolean[size];
- p = new Node[size];
- for(int i = 0; i < size; i++)
- p[i] = new Node();
- }
- public void setSize(int size){
- this.size = size;
- Arrays.fill(d, false);
- }
- public void clear(){
- res = Long.MAX_VALUE;
- index = 0;
- }
- public void Insert(int id, Node t){
- p[id] = t;
- }
- public Node get(int id){
- return p[id];
- }
- public void InsertSort(Node a[], int id, int l, int r){
- for(int i = l + 1; i <= r; i++){
- if(a[i - 1].x[id] > a[i].x[id]){
- Node t = new Node();
- t = a[i];
- int j = i;
- while(j > l && a[j - 1].x[id] > t.x[id])
- {
- a[j] = a[j - 1];
- j--;
- }
- a[j] = t;
- }
- }
- }
- public Node FindMid(Node a[], int id, int l, int r)
- {
- if(l == r) return a[l];
- int i = 0;
- int n = 0;
- for(i = l; i < r - 5; i += 5)
- {
- InsertSort(a, id, i, i + 4);
- n = i - l;
- Node t = new Node();
- t = a[l + n / 5];
- a[l + n / 5] = a[i + 2];
- a[i + 2] = t;
- }
- int num = r - i + 1;
- if(num > 0)
- {
- InsertSort(a, id, i, i + num - 1);
- n = i - l;
- Node t = new Node();
- t = a[l + n / 5];
- a[l + n / 5] = a[i + num / 2];
- a[i + num / 2] = t;
- }
- n /= 5;
- if(n == l) return a[l];
- return FindMid(a, id, l, l + n);
- }
- public boolean Equals(Node a, Node b){
- if(Math.abs(a.x[0] - b.x[0]) > EPS)
- return false;
- if(Math.abs(a.x[1] - b.x[1]) > EPS)
- return false;
- return true;
- }
- public int FindId(Node a[], int l, int r, Node num)
- {
- for(int i = l; i <= r; i++)
- if(Equals(a[i], num))
- return i;
- return -1;
- }
- public int Partion(Node a[], int id, int l, int r, int p)
- {
- Node t = new Node();
- t = a[p];
- a[p] = a[l];
- a[l] = t;
- int i = l;
- int j = r;
- Node pivot = a[l];
- while(i < j)
- {
- while(a[j].x[id] >= pivot.x[id] && i < j)
- j--;
- a[i] = a[j];
- while(a[i].x[id] <= pivot.x[id] && i < j)
- i++;
- a[j] = a[i];
- }
- a[i] = pivot;
- return i;
- }
- public Node BFPTR(Node a[], int id, int l, int r, int k)
- {
- if(l > r) return null;
- Node num = FindMid(a, id, l, r);
- int p = FindId(a, l, r, num);
- int i = Partion(a, id, l, r, p);
- int m = i - l + 1;
- if(m == k) return a[i];
- if(m > k) return BFPTR(a, id, l, i - 1, k);
- return BFPTR(a, id, i + 1, r, k - m);
- }
- public Node getInterval(Node p[], int id, int l, int r){
- Node t = new Node();
- long max = Long.MIN_VALUE;
- long min = Long.MAX_VALUE;
- for(int i = l; i <= r; i++){
- if(max < p[i].x[id]) max = p[i].x[id];
- if(min > p[i].x[id]) min = p[i].x[id];
- }
- t.x[0] = min;
- t.x[1] = max;
- return t;
- }
- public long getDist(Node a, Node b){
- return (a.x[0] - b.x[0]) * (a.x[0] - b.x[0]) + (a.x[1] - b.x[1]) * (a.x[1] - b.x[1]);
- }
- public void Build(Node p[], int l, int r){
- if(l > r) return;
- Node t1 = getInterval(p, 0, l, r);
- long minx = t1.x[0];
- long maxx = t1.x[1];
- Node t2 = getInterval(p, 1, l, r);
- long miny = t2.x[0];
- long maxy = t2.x[1];
- int mid = (l + r) >> 1;
- d[mid] = (maxx - minx > maxy - miny);
- BFPTR(p, d[mid] ? 0 : 1, l, r, mid - l + 1);
- Build(p, l, mid - 1);
- Build(p, mid + 1, r);
- }
- public void Find(Node p[], Node t, int l, int r){
- if(l > r) return;
- int mid = (l + r) >> 1;
- long dist = getDist(p[mid], t);
- long df = d[mid] ? (t.x[0] - p[mid].x[0]) : (t.x[1] - p[mid].x[1]);
- if(dist > 0 && dist < res){
- res = dist;
- index = mid;
- }
- int l1 = l;
- int r1 = mid - 1;
- int l2 = mid + 1;
- int r2 = r;
- if (df > 0){
- l1 ^= l2;
- l2 ^= l1;
- l1 ^= l2;
- r1 ^= r2;
- r2 ^= r1;
- r1 ^= r2;
- }
- Find(p, t, l1, r1);
- if (df * df < res) Find(p, t, l2, r2);
- }
- public void Build(){
- Build(p, 0, size - 1);
- }
- public int Search(Node t){
- clear();
- Find(p, t, 0, size - 1);
- return index;
- }
- public static void main(String[] args){
- Scanner cin = new Scanner(System.in);
- int t = cin.nextInt();
- Main kd = new Main(SIZE);
- Node[] node = new Node[SIZE];
- for(int i = 0; i < SIZE; i++){
- node[i] = kd.new Node();
- }
- while(t-- > 0){
- int n = cin.nextInt();
- kd.setSize(n);
- for(int i = 0; i < n; i++){
- node[i].x[0] = cin.nextLong();
- node[i].x[1] = cin.nextLong();
- kd.Insert(i, node[i]);
- }
- kd.Build();
- for(int i = 0; i < n; i++){
- int id = kd.Search(node[i]);
- System.out.println(kd.getDist(kd.get(id), node[i]));
- }
- }
- }
- }
4. K-D树的开源框架介绍
K-D树的一个比较好的C++框架可以戳这里。下载后,可以参考里面的examples文件夹中的代码学习使用。
- K-d树
- k-d树学习
- K-d树详解
- k-D树强烈推荐
- k-d树
- K-D树
- K-D树
- K-D树小结
- K-D树 Hdu4347
- k近邻与k-d树
- k近邻与k-d树
- K-D树 C++实现
- k-d树(hdu2966)
- k近邻算法——k-d 树的实现
- k-d树头文件C语言
- k-d树实现文件C语言
- k-d树算法的研究
- BZOJ 2626 JZPFAR K-D树
- PAT1021
- 杭电ACM----2017字符串统计
- 数据结构基础之推导遍历结果
- 刷题、OJ 各种A+b Problem、、
- 杭电ACM----2018母牛的故事
- K-D树
- 杭电ACM----2019 数列有序!
- 杭电ACM----2020 绝对值排序
- nyoj733圣诞派对
- 二叉树
- Spark实战-Spark SQL(一)
- java小的误区
- 快速求幂取模
- 图(拓扑排序)2