Generic Algorithms: k-Nearest Neighbors with a KD-Tree (kd tree)

Source: Internet | Editor: 程序博客网 | Time: 2024/06/10 14:21

I. Dataset and algorithm


Data:

T={(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)}


The algorithm for building a KD-tree is fairly easy to follow; see this article: click here
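As a rough illustration of the construction (a sketch, not the linked article's code), the C++11 function below builds a 2-D kd-tree by cycling the split dimension (x, y, x, ...) and storing the median point at each node. It assumes the same median rule as the KdTree.hpp code later in this article (the element at index size/2 along the current dimension) but uses `std::nth_element` instead of a full sort. `Point`, `Node` and `build` are illustrative names.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <memory>
#include <vector>

// 2-D point; coordinate 0 is x, coordinate 1 is y
typedef std::array<double, 2> Point;

struct Node {
    Point point;                       // median point stored at this node
    int axis;                          // splitting dimension of this node
    std::unique_ptr<Node> left, right; // points below / above the median
};

// Builds a kd-tree over [first, last), splitting on `axis` at this level.
// Assumption: the median is the element at index size/2 along `axis`.
std::unique_ptr<Node> build(std::vector<Point>::iterator first,
                            std::vector<Point>::iterator last, int axis) {
    assert(axis == 0 || axis == 1);
    if (first == last) return nullptr;
    std::vector<Point>::iterator mid = first + (last - first) / 2;
    // Partial sort: *mid becomes the element a full sort would put there,
    // smaller coordinates end up before it and larger ones after it
    std::nth_element(first, mid, last, [axis](const Point &a, const Point &b) {
        return a[axis] < b[axis];
    });
    std::unique_ptr<Node> node(new Node);
    node->point = *mid;
    node->axis = axis;
    int next = (axis + 1) % 2;         // cycle to the other dimension
    node->left = build(first, mid, next);
    node->right = build(mid + 1, last, next);
    return node;
}
```

For T this rule gives root (7, 2) split on x, with children (5, 4) and (9, 6) split on y; (2, 3) and (4, 7) are the children of (5, 4), and (8, 1) is the only (left) child of (9, 6).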

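I looked at many query algorithms online. Most of them only give pseudocode, many are copied from one another, and they are not necessarily correct. So I wrote code that actually runs and tested it several times without finding problems.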
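One simple way to test any kd-tree query (a hedged sketch, not the author's test procedure) is to compare its output with a brute-force scan over all points. `bruteForceKnn` is a hypothetical helper; it assumes Euclidean distance, matching the client's `getDist`, and returns (distance, point) pairs sorted by distance like `KdTree::query`.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

// Brute-force k-nearest neighbors: computes every distance and keeps the k smallest.
// Returns (distance, point) pairs in ascending order of distance.
std::vector<std::pair<double, std::vector<double> > >
bruteForceKnn(const std::vector<std::vector<double> > &data,
              const std::vector<double> &q, std::size_t k) {
    std::vector<std::pair<double, std::vector<double> > > all;
    for (const std::vector<double> &p : data) {
        assert(p.size() == q.size());
        double sum = 0;
        for (std::size_t i = 0; i < p.size(); ++i)
            sum += (p[i] - q[i]) * (p[i] - q[i]);
        all.emplace_back(std::sqrt(sum), p);
    }
    // Only the first k entries need to be ordered
    k = std::min(k, all.size());
    std::partial_sort(all.begin(), all.begin() + static_cast<std::ptrdiff_t>(k), all.end());
    all.resize(k);
    return all;
}
```

Running both on the same queries, including a k larger than the data set, quickly exposes pruning mistakes.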


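Here is a fairly reliable nearest-neighbor algorithm (the code further down performs a k-nearest-neighbor query, but the principle is the same, with Dcur becoming the distance to the current k-th nearest candidate):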

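(1) Starting at the root, move the query point Q down the KD-tree by comparing Q with each node, until a leaf is reached.
Comparing Q with a node means comparing Q's value in the node's split dimension s with the node's split value m: if Q(s) < m, visit the left subtree; otherwise visit the right subtree. At the leaf, compute the distance between Q and the point stored there, and record that point as the current nearest neighbor Pcur and that distance as the current minimum distance Dcur.
(2) Backtrack to look for a neighbor closer to Q, i.e. check whether any unvisited branch contains a point whose distance to Q is less than Dcur. Each node passed on the way back up is itself a candidate: if its point is closer than Dcur, it becomes the new Pcur.
If the distance from Q to the splitting hyperplane of a parent node (the boundary of its unvisited branch) is less than Dcur, that branch may contain a point closer to Q: enter it and repeat the descent of step (1). If a closer point is found, make it the new Pcur and update Dcur.
If that distance is greater than Dcur, the branch cannot contain a point closer to Q.
Backtracking proceeds bottom-up and stops once it has returned to the root and no branch that could contain a closer point remains.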
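The two steps above can be sketched in C++11 for k = 1 on 2-D points. This is an illustrative version, not the author's code; `Node`, `build`, `descend` and `nearest` are hypothetical names, and the tree uses the median rule of the KdTree.hpp code (index size/2). An explicit stack records the search path of step (1); each popped node is a candidate for Pcur, and the splitting-hyperplane test of step (2) decides whether to descend into the side not yet visited. The last node reached in step (1) is treated like every other node, so if it has a child on the side not taken, that child is still examined.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <limits>
#include <memory>
#include <stack>
#include <vector>

typedef std::array<double, 2> Point;

struct Node {
    Point point;
    int axis;                          // splitting dimension s
    std::unique_ptr<Node> left, right;
};

// Median-split construction (median = element at index size/2 along axis)
std::unique_ptr<Node> build(std::vector<Point>::iterator first,
                            std::vector<Point>::iterator last, int axis) {
    if (first == last) return nullptr;
    std::vector<Point>::iterator mid = first + (last - first) / 2;
    std::nth_element(first, mid, last, [axis](const Point &a, const Point &b) {
        return a[axis] < b[axis];
    });
    std::unique_ptr<Node> node(new Node);
    node->point = *mid;
    node->axis = axis;
    node->left = build(first, mid, (axis + 1) % 2);
    node->right = build(mid + 1, last, (axis + 1) % 2);
    return node;
}

double dist(const Point &a, const Point &b) {
    return std::hypot(a[0] - b[0], a[1] - b[1]);  // Euclidean distance
}

// Step (1): walk from `node` toward a leaf, pushing every node visited
void descend(const Node *node, const Point &q, std::stack<const Node *> &path) {
    while (node) {
        path.push(node);
        node = q[node->axis] < node->point[node->axis] ? node->left.get()
                                                       : node->right.get();
    }
}

// Returns Pcur and writes Dcur to *outDist
Point nearest(const Node *root, const Point &q, double *outDist) {
    assert(root != nullptr);
    std::stack<const Node *> path;
    descend(root, q, path);
    Point pcur = root->point;
    double dcur = std::numeric_limits<double>::infinity();
    // Step (2): backtrack bottom-up
    while (!path.empty()) {
        const Node *n = path.top();
        path.pop();
        double d = dist(n->point, q);
        if (d < dcur) { dcur = d; pcur = n->point; }
        // Is Q closer to this node's splitting hyperplane than Dcur?
        if (std::fabs(q[n->axis] - n->point[n->axis]) < dcur) {
            // Then the unvisited side may hold a closer point: search it as in step (1)
            const Node *other = q[n->axis] < n->point[n->axis] ? n->right.get()
                                                               : n->left.get();
            descend(other, q, path);
        }
    }
    *outDist = dcur;
    return pcur;
}
```

For Q = (2, 4.5) on T, the descent reaches (4, 7), and backtracking then finds (5, 4) and finally (2, 3) at distance 1.5; the root's hyperplane x = 7 is 5 away, so the right half of the tree is pruned.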


The results of a run:




A VS2008 project that compiles and runs as-is: click here to download


II. Helper class templates and function templates that simplify the later computations


For background on function objects, see *The C++ Standard Library* and *The Annotated STL Sources* (《STL源码剖析》).

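Here you only need to understand how bind3rd is used. The standard library does not provide bind3rd, so we have to write it ourselves. It works like std::bind2nd, except that it fixes the third argument instead of the second; if the usage is unclear, look up std::bind2nd for reference. (Note that std::bind2nd was deprecated in C++11 and removed in C++17.)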
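Since C++11, the job bind3rd does here (turning a three-argument comparison into a two-argument comparator for std::sort) can also be done with a lambda that captures the split dimension, or with std::bind and placeholders. A sketch with hypothetical names `sortByDimension`, `lessInDim` and `sortByDimensionBind`, assuming points stored as `std::vector<double>`:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Lambda version: `split` is captured, playing the role of bind3rd's fixed third argument
void sortByDimension(std::vector<std::vector<double> > &points, std::size_t split) {
    std::sort(points.begin(), points.end(),
              [split](const std::vector<double> &a, const std::vector<double> &b) {
                  assert(split < a.size() && split < b.size());
                  return a[split] < b[split];
              });
}

// Ternary comparison: the third argument selects the dimension
bool lessInDim(const std::vector<double> &a, const std::vector<double> &b, std::size_t d) {
    return a[d] < b[d];
}

// std::bind version: fixes the third argument and forwards the first two,
// which is exactly what bind3rd does
void sortByDimensionBind(std::vector<std::vector<double> > &points, std::size_t split) {
    using namespace std::placeholders;
    std::sort(points.begin(), points.end(), std::bind(lessInDim, _1, _2, split));
}
```

The lambda is usually the more readable choice; std::bind is shown only because it mirrors the binder style of the original header.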

1. Header file myfunctional.hpp


#pragma once
#include <functional>

// Ternary analogue of std::binary_function: only supplies the type names
template<class _Arg1, class _Arg2, class _Arg3, class _Result>
struct tenary_function
{
    typedef _Arg1 first_argument_type;
    typedef _Arg2 second_argument_type;
    typedef _Arg3 third_argument_type;
    typedef _Result result_type;
};

// Adapts a ternary function object into a binary one by fixing its third argument.
// std::binary_function was removed in C++17, so the typedefs are declared here directly.
template<class _Operation>
class binder3rd
{
public:
    typedef typename _Operation::first_argument_type first_argument_type;
    typedef typename _Operation::second_argument_type second_argument_type;
    typedef typename _Operation::result_type result_type;

    binder3rd(const _Operation &_Func, const typename _Operation::third_argument_type &_Third)
        : op(_Func), value(_Third) {}

    result_type operator()(const first_argument_type &__x, const second_argument_type &__y) const
    {
        return op(__x, __y, value);
    }
protected:
    _Operation op;
    typename _Operation::third_argument_type value;
};

// Helper that deduces the binder type, in the style of std::bind2nd
template<class _Operation, class _Ty>
inline binder3rd<_Operation> bind3rd(const _Operation &_Func, const _Ty &_Third)
{
    typename _Operation::third_argument_type _Val(_Third);
    return binder3rd<_Operation>(_Func, _Val);
}


2. heap.hpp

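A max-heap class template that holds the current k nearest neighbors. This code is by @江南烟雨; I modified parts of it to make it easier to use from KdTree.

If the max-heap concept is unfamiliar, see his blog: click here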

#pragma once
// STL-style heap algorithms (max-heap by default)
// MyHeap stores its elements in a std::vector
#include <vector>
#include <functional>

const int StartIndex = 1; // heap elements start at index 1; index 0 holds a placeholder

// Heap class template; the default ordering std::less gives a max-heap
template <class ElemType, class Compare = std::less<ElemType> >
class MyHeap
{
private:
    std::vector<ElemType> heapDataVec; // element storage
    int numCounts;                     // number of elements in the heap
    Compare comp;                      // ordering
public:
    MyHeap();
    std::vector<ElemType> getVec();
    bool empty();
    int size();
    void initHeap(ElemType *data, const int n); // copy elements in (call makeHeap afterwards)
    void makeHeap();                           // build the heap
    void pushHeap(ElemType elem);              // insert an element
    void popHeap();                            // remove the top element
    void clear();
    std::vector<ElemType> sortHeap();          // empties the heap; returns elements in ascending order
    ElemType getTop();                         // top (largest) element
private:
    void adjustHeap(int childTree, ElemType adjustValue);                // sift down within a subtree
    void percolateUp(int holeIndex, ElemType adjustValue, int topIndex); // sift up, never above topIndex
};

template <class ElemType, class Compare>
std::vector<ElemType> MyHeap<ElemType, Compare>::sortHeap()
{
    std::vector<ElemType> result(numCounts);
    for (int i = numCounts - 1; i >= 0; --i)
    {
        result[i] = getTop(); // the largest remaining element goes to the back
        popHeap();
    }
    return result;
}

template <class ElemType, class Compare>
void MyHeap<ElemType, Compare>::clear()
{
    heapDataVec.clear();
    heapDataVec.push_back(ElemType()); // placeholder at index 0
    numCounts = 0;
}

template <class ElemType, class Compare>
int MyHeap<ElemType, Compare>::size() { return numCounts; }

template <class ElemType, class Compare>
bool MyHeap<ElemType, Compare>::empty() { return numCounts == 0; }

template <class ElemType, class Compare>
ElemType MyHeap<ElemType, Compare>::getTop() { return heapDataVec[StartIndex]; }

template <class ElemType, class Compare>
MyHeap<ElemType, Compare>::MyHeap() : numCounts(0)
{
    heapDataVec.push_back(ElemType()); // placeholder at index 0
}

template <class ElemType, class Compare>
std::vector<ElemType> MyHeap<ElemType, Compare>::getVec() { return heapDataVec; }

template <class ElemType, class Compare>
void MyHeap<ElemType, Compare>::initHeap(ElemType *data, const int n)
{
    // Copy the elements into the vector
    for (int i = 0; i < n; ++i)
    {
        heapDataVec.push_back(data[i]);
        ++numCounts;
    }
}

template <class ElemType, class Compare>
void MyHeap<ElemType, Compare>::makeHeap()
{
    // Building a heap means calling adjustHeap on every subtree, bottom-up
    if (numCounts < 2)
        return;
    int parent = numCounts / 2; // root of the first subtree that needs adjusting
    while (true)
    {
        adjustHeap(parent, heapDataVec[parent]);
        if (StartIndex == parent) // reached the root
            return;
        --parent;
    }
}

template <class ElemType, class Compare>
void MyHeap<ElemType, Compare>::pushHeap(ElemType elem)
{
    // Append the new element, then sift it up to restore the max-heap property
    heapDataVec.push_back(elem);
    ++numCounts;
    percolateUp(numCounts, heapDataVec[numCounts], StartIndex);
}

template <class ElemType, class Compare>
void MyHeap<ElemType, Compare>::popHeap()
{
    // Move the top element to the end of the vector, then rebuild the heap
    // using the former last element as the adjust value
    ElemType adjustValue = heapDataVec[numCounts];
    heapDataVec[numCounts] = heapDataVec[StartIndex];
    --numCounts;
    adjustHeap(StartIndex, adjustValue);
    heapDataVec.pop_back(); // drop the old top
}

// Turn the subtree rooted at childTree into a heap
template <class ElemType, class Compare>
void MyHeap<ElemType, Compare>::adjustHeap(int childTree, ElemType adjustValue)
{
    int holeIndex = childTree;           // index of the hole
    int secondChild = 2 * holeIndex + 1; // right child of the hole (indices start at 1)
    while (secondChild <= numCounts)
    {
        if (comp(heapDataVec[secondChild], heapDataVec[secondChild - 1]))
            --secondChild;               // pick the larger of the two children
        // Sift down: the larger child moves up into the hole
        heapDataVec[holeIndex] = heapDataVec[secondChild];
        holeIndex = secondChild;
        secondChild = 2 * secondChild + 1;
    }
    // The hole has only a left child
    if (secondChild == numCounts + 1)
    {
        heapDataVec[holeIndex] = heapDataVec[secondChild - 1];
        holeIndex = secondChild - 1;
    }
    // Put the adjust value into the hole; it may still be larger than its parents,
    // so sift it up, but never above the subtree root childTree (otherwise makeHeap
    // could push an unadjusted ancestor down into an already-built subtree)
    heapDataVec[holeIndex] = adjustValue;
    percolateUp(holeIndex, adjustValue, childTree);
}

// Sift up: while the value is larger than its parent, move the parent down.
// Stops at topIndex or as soon as no swap is needed.
template <class ElemType, class Compare>
void MyHeap<ElemType, Compare>::percolateUp(int holeIndex, ElemType adjustValue, int topIndex)
{
    int parentIndex = holeIndex / 2;
    while (holeIndex > topIndex && comp(heapDataVec[parentIndex], adjustValue))
    {
        heapDataVec[holeIndex] = heapDataVec[parentIndex];
        holeIndex = parentIndex;
        parentIndex /= 2;
    }
    heapDataVec[holeIndex] = adjustValue; // final position of the value
}
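For comparison, the same "keep the k best" role can be played by `std::priority_queue`, which is a max-heap by default, so its top is always the current k-th smallest distance. This is a sketch with a hypothetical `KBest` class; to keep it short it stores an int index into the data set instead of the point itself, which is an assumption, not what MyHeap does.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <queue>
#include <utility>
#include <vector>

// Keeps the k candidates with the smallest distances seen so far
class KBest {
public:
    explicit KBest(std::size_t k) : k_(k) { assert(k > 0); }

    void offer(double dist, int id) {
        if (heap_.size() < k_) {
            heap_.push(std::make_pair(dist, id));   // not full yet: always keep it
        } else if (dist < heap_.top().first) {
            heap_.pop();                            // drop the farthest of the k
            heap_.push(std::make_pair(dist, id));
        }
    }

    // Pruning radius: infinite until k candidates exist, then the k-th best distance
    double radius() const {
        return heap_.size() < k_ ? std::numeric_limits<double>::infinity()
                                 : heap_.top().first;
    }

    // Drains the heap; returns candidates in ascending order of distance
    std::vector<std::pair<double, int> > sorted() {
        std::vector<std::pair<double, int> > out(heap_.size());
        for (std::size_t i = out.size(); i-- > 0;) {
            out[i] = heap_.top();
            heap_.pop();
        }
        return out;
    }

private:
    std::size_t k_;
    std::priority_queue<std::pair<double, int> > heap_; // max-heap on distance
};
```

In a k-NN query, each visited node would be passed to `offer`, and the distance to its splitting hyperplane compared against `radius()` to decide whether the other side must be searched.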


3. KdTree.hpp

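This class template accepts any data type. The client must derive from it and override its pure virtual methods: both getDist overloads and less.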

#pragma once
#include <vector>
#include <stack>
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <utility>
#include "myfunctional.hpp"
#include "heap.hpp"

template<class DataType, unsigned N>
class KdTree;

template<class DataType, unsigned N>
class KdNode
{
    friend class KdTree<DataType, N>;
public:
    KdNode() : _split(0), _left(NULL), _right(NULL) {}
    ~KdNode()
    {
        delete _left;   // deleting NULL is a no-op
        delete _right;
    }
private:
    std::vector<DataType> _data;  // the point stored at this node
    int _split;                   // splitting dimension
    KdNode<DataType, N>* _left;
    KdNode<DataType, N>* _right;
};

//////////////////////////////////////////////////////////////////////////
template<class DataType, unsigned N>
class KdTree
{
public:
    KdTree();
    virtual ~KdTree();
    // The data must support a distance measure
    virtual double getDist(const std::vector<DataType> &first, const std::vector<DataType> &second) = 0;
    virtual double getDist(const DataType &first, const DataType &second) = 0;
    // Values in any single dimension must be comparable
    virtual bool less(const DataType &first, const DataType &second) const = 0;
    void createKdTree(const std::vector<DataType> *dataset, int size);
    std::vector<std::pair<double, std::vector<DataType> > > query(const std::vector<DataType> &queryData, int k);
    // Find the median along dimension split
    std::vector<DataType> getMedium(std::vector<DataType> *first, std::vector<DataType> *last, int split);
private:
    KdTree(const KdTree<DataType, N>&);                          // non-copyable
    KdTree<DataType, N>& operator=(const KdTree<DataType, N>&);
    KdNode<DataType, N>* createKdTree(std::vector<DataType> *first, std::vector<DataType> *last, int split);
private:
    KdNode<DataType, N> *_head;
    std::vector<DataType> *_copydata;
    std::stack<KdNode<DataType, N>*> _search_path;
    MyHeap<std::pair<double, std::vector<DataType> > > _heap;   // max-heap of the current k best
};

// Ordering by one dimension; the third argument selects the dimension
template<class DataType, unsigned N>
struct _less : public tenary_function<std::vector<DataType>, std::vector<DataType>, int, bool>
{
    const KdTree<DataType, N> *_kdTree;
    bool operator()(const std::vector<DataType> &__x, const std::vector<DataType> &__y, const int __z) const
    {
        return _kdTree->less(__x[__z], __y[__z]);
    }
    _less(const KdTree<DataType, N> *kdTree) : _kdTree(kdTree) {}
};

template<class DataType, unsigned N>
std::vector<std::pair<double, std::vector<DataType> > > KdTree<DataType, N>::query(const std::vector<DataType> &queryData, int k)
{
    _heap.clear();
    if (k <= 0 || _head == NULL)
        return std::vector<std::pair<double, std::vector<DataType> > >();

    // Descend from the root to a leaf, recording the search path
    KdNode<DataType, N> *p = _head;
    while (p != NULL)
    {
        _search_path.push(p);
        p = less(queryData[p->_split], p->_data[p->_split]) ? p->_left : p->_right;
    }

    // Backtrack. Every node on the path, including the last one reached above, is
    // handled the same way, so the other child of a one-child node is not skipped.
    while (!_search_path.empty())
    {
        KdNode<DataType, N> *backPoint = _search_path.top();
        _search_path.pop();
        double temp = getDist(backPoint->_data, queryData);
        if (_heap.size() < k)
        {
            // Fewer than k candidates so far: add directly
            _heap.pushHeap(std::make_pair(temp, backPoint->_data));
        }
        else if (temp < _heap.getTop().first)
        {
            // Closer than the current k-th neighbor: replace the heap top
            _heap.popHeap();
            _heap.pushHeap(std::make_pair(temp, backPoint->_data));
        }
        // Radius of the current search hypersphere: the k-th best distance, or
        // infinity while fewer than k candidates are known (nothing can be pruned yet)
        double minDist = _heap.size() < k ? std::numeric_limits<double>::infinity()
                                          : _heap.getTop().first;
        int split = backPoint->_split;
        // If the splitting hyperplane intersects the hypersphere, search the other half-space
        if (getDist(backPoint->_data[split], queryData[split]) <= minDist)
        {
            // Enter the side the descent did not take, then go down to a leaf
            p = less(queryData[split], backPoint->_data[split]) ? backPoint->_right : backPoint->_left;
            while (p != NULL)
            {
                _search_path.push(p);
                p = less(queryData[p->_split], p->_data[p->_split]) ? p->_left : p->_right;
            }
        }
    }
    return _heap.sortHeap(); // ascending order of distance
}

template<class DataType, unsigned N>
KdTree<DataType, N>::~KdTree()
{
    delete _head;
    delete[] _copydata;
}

template<class DataType, unsigned N>
KdTree<DataType, N>::KdTree() : _head(NULL), _copydata(NULL) {}

template<class DataType, unsigned N>
std::vector<DataType> KdTree<DataType, N>::getMedium(std::vector<DataType> *first, std::vector<DataType> *last, int split)
{
    std::size_t size = last - first;
    std::sort(first, last, bind3rd(_less<DataType, N>(this), split));
    return *(first + size / 2);
}

template<class DataType, unsigned N>
KdNode<DataType, N>* KdTree<DataType, N>::createKdTree(std::vector<DataType> *first, std::vector<DataType> *last, int split)
{
    if (first == last)
        return NULL;
    std::size_t size = last - first;
    KdNode<DataType, N>* newNode = new KdNode<DataType, N>;
    newNode->_split = split;
    newNode->_data = getMedium(first, last, split);  // also sorts [first, last) by dimension split
    newNode->_left = createKdTree(first, first + size / 2, (split + 1) % N);
    newNode->_right = createKdTree(first + size / 2 + 1, last, (split + 1) % N);
    return newNode;
}

template<class DataType, unsigned N>
void KdTree<DataType, N>::createKdTree(const std::vector<DataType> *dataset, int size)
{
    // Release any previously built tree
    delete _head;
    delete[] _copydata;
    _copydata = new std::vector<DataType>[size];
    std::copy(dataset, dataset + size, _copydata);
    _head = createKdTree(_copydata, _copydata + size, 0);
}

III. Client implementation

// kd_tree.cpp : entry point of the console application
#include <cmath>
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>
#include "KdTree.hpp"

class MyKdTree : public KdTree<double, 2>
{
public:
    // Euclidean distance between two points
    virtual double getDist(const std::vector<double> &first, const std::vector<double> &second)
    {
        double sum = 0;
        for (std::size_t i = 0; i < first.size(); ++i)
        {
            sum += std::pow(first[i] - second[i], 2);
        }
        return std::sqrt(sum);
    }
    // Distance between two coordinates (used for the splitting-hyperplane test)
    virtual double getDist(const double &first, const double &second)
    {
        return std::fabs(first - second);
    }
    virtual bool less(const double &first, const double &second) const
    {
        return first < second;
    }
};

void DoubleKdTree();

int main()
{
    DoubleKdTree();
    return 0;
}

void DoubleKdTree()
{
    // Create the data set
    double dataset[6][2] = {
        {2.0, 3.0}, {5.0, 4.0}, {9.0, 6.0},
        {4.0, 7.0}, {8.0, 1.0}, {7.0, 2.0}
    };
    MyKdTree myKdTree;
    std::vector<double> vDataSet[6];
    for (int i = 0; i < 6; ++i)
    {
        vDataSet[i] = std::vector<double>(dataset[i], dataset[i] + 2);
    }
    // Build the KD-tree
    myKdTree.createKdTree(vDataSet, 6);

    double data[2] = {0};
    int k = 1;
    while (true)
    {
        std::cout << "Enter a 2-D point (-1 to quit): ";
        if (!(std::cin >> data[0] >> data[1]) || data[0] == -1 || data[1] == -1)
            break;
        std::cout << "\nEnter k (number of nearest neighbors): ";
        if (!(std::cin >> k))
            break;
        std::vector<double> test(data, data + 2);
        std::vector<std::pair<double, std::vector<double> > > result = myKdTree.query(test, k);
        for (std::size_t i = 0; i < result.size(); ++i)
        {
            std::cout << "Distance: " << result[i].first << "\t";
            std::cout << "[" << result[i].second[0] << ", " << result[i].second[1] << "]" << std::endl;
        }
        std::cout << std::endl;
    }
}

