近邻搜索之制高点树(VP-Tree)

来源:互联网 发布:2016 淘宝最近怎么了 编辑:程序博客网 时间:2024/04/29 17:18

引子

近邻搜索是一种很基础又相当重要的操作,除了信息检索以外,还被广泛用于计算机视觉、机器学习等领域,如何快速有效地做近邻查询一直是一项热门的研究。较早提出的方法多基于空间划分(Space Partition),最具有代表性的如kd-tree(kdt)、球树等。本篇将介绍基于空间划分方法中的一种:制高点树(Vantage Point Tree,vpt)。它最初在1993年提出,比kdt稍晚,提供了一个不一样的建树思路。

VPT结构

和kdt一样,vpt也是一类二叉树,不同的是在每个节点的划分策略。略微回顾一下kdt,它在每个节点选择一个维度,根据数据点在该维度上的大小将数据均分为二。而在vpt中,首先从节点中选择一个数据点(可随机选)作为制高点(vp),然后算出其它点到vp的距离大小,最后根据该距离大小将数据点均分为二。建树算法如下:

  1. 选择某数据点v作为vp
  2. 计算其它点{Xi}到v的距离{Di}
  3. 求出{Di}的中值M,到v的距离小于M的数据点分给左子树,其余数据点(距离不小于M)分给右子树
  4. 递归地建立左子树和右子树
这里提供一个简单的例子如图,框中为平面上的点,其中红框为选中的vp,根据其它点到vp的距离进行了子树划分。

VPT查询算法


vpt查询是准确(exact)近邻查询,较适合范围查询,也可方便地扩展为k近邻查询。

进行近邻查询时,假定查询点为q,当前的制高点为v,距离中值为M,则有如下策略搜索到q点距离小于r的点集:

(1)  若 dist(q,v)+r≥M,递归地搜索右子树(球外区域)

(2)  若 dist(q,v)-r≤M,递归地搜索左子树(球内区域)

为了方便写公式,用图片文字来进行证明,其实就是简单的三角形不等式的应用。

简易实现代码


最后上点干货,一个简易c++实现如下:
#ifndef _VPTREE_HEADER_#define _VPTREE_HEADER_#include <stdlib.h>#include <algorithm>#include <vector>#include <stdio.h>#include <queue>#include <limits>//#include "fnn.h"template<typename T, double (*distance)( const T&, const T& ), int (*getId)(const T&)>class VpTree{public:    VpTree() : _root(0) {}    ~VpTree() {        delete _root;    }    void create( const std::vector<T>& items ) {        delete _root;        _items = items;        _root = buildFromPoints(0, items.size());    }    void search( const T& target, int k, std::vector<T>* results,         std::vector<double>* distances)     {        std::priority_queue<HeapItem> heap;        _tau = std::numeric_limits<double>::max();        search( _root, target, k, heap );        results->clear(); distances->clear();        while( !heap.empty() ) {            results->push_back( _items[heap.top().index] );            distances->push_back( heap.top().dist );            heap.pop();        }        std::reverse( results->begin(), results->end() );        std::reverse( distances->begin(), distances->end() );printf("vp search dist = %f\n",distances->at(0));brute(target);    }void search(const T& target,std::vector<T>* results,std::vector<double>* distances){        int idx;double min = 1.0e+10;for(int i=0;i<_items.size();i++){double dist = distance( _items[i], target );if(dist<min){min=dist;idx = i;}}results->push_back(_items[idx]);distances->push_back(min);}int range_search(const T& target, double range, int *list, int &listnum){int hit = 0;for(int i=0;i<_items.size();i++){double dist = distance( _items[i], target );//debug here/*if(getId(_items[i])==4){printf("vp dist=%f range=%f\n",dist,range);}*///-debugif(dist<=range){  //inside, need to check//list[listnum++] = getId(_items[i]);int id = getId(_items[i]);list[id] = 1;listnum++;hit++;}}/*_tau = range;rsearch( _root, target, hit, list, listnum);*/return hit;}private:    std::vector<T> _items;    double _tau;    struct Node     {        int index;        double 
threshold;        Node* left;        Node* right;        Node() :            index(0), threshold(0.), left(0), right(0) {}        ~Node() {            delete left;            delete right;        }    }* _root;    struct HeapItem {        HeapItem( int index, double dist) :            index(index), dist(dist) {}        int index;        double dist;        bool operator<( const HeapItem& o ) const {            return dist < o.dist;           }    };    struct DistanceComparator    {        const T& item;        DistanceComparator( const T& item ) : item(item) {}        bool operator()(const T& a, const T& b) {            return distance( item, a ) < distance( item, b );        }    };    Node* buildFromPoints( int lower, int upper )    {        if ( upper == lower ) {            return NULL;        }        Node* node = new Node();        node->index = lower;        if ( upper - lower > 1 ) {            // choose an arbitrary point and move it to the start            int i = (int)((double)rand() / RAND_MAX * (upper - lower - 1) ) + lower;            std::swap( _items[lower], _items[i] );            int median = ( upper + lower ) / 2;            // partitian around the median distance            std::nth_element(                 _items.begin() + lower + 1,                 _items.begin() + median,                _items.begin() + upper,                DistanceComparator( _items[lower] ));            // what was the median?            
node->threshold = distance( _items[lower], _items[median] );            node->index = lower;            node->left = buildFromPoints( lower + 1, median );            node->right = buildFromPoints( median, upper );        }        return node;    }double brute(const T& target){double min = 1.0e+10;for(int i=0;i<_items.size();i++){double dist = distance( _items[i], target );if(dist<min)min=dist;}return min;//printf("vp brute dist = %f\n",min);}void rsearch(Node* node, const T& target, int & counter, int *list, int &listnum){if ( node == NULL ) return;double dist = distance( _items[node->index], target );if ( dist < _tau ) {counter++;//list[ listnum++ ] = getId(_items[node->index]);list[getId(_items[node->index])] = 1;}if ( node->left == NULL && node->right == NULL ) {            return;        }        if ( dist < node->threshold ) {            if ( dist - _tau <= node->threshold ) {                rsearch( node->left, target, counter, list, listnum);            }            if ( dist + _tau >= node->threshold ) {                rsearch( node->right, target, counter, list, listnum );            }        } else {            if ( dist + _tau >= node->threshold ) {                rsearch( node->right, target, counter, list, listnum );            }            if ( dist - _tau <= node->threshold ) {                rsearch( node->left, target, counter, list, listnum);            }        }}    void search( Node* node, const T& target, int k,                 std::priority_queue<HeapItem>& heap )    {        if ( node == NULL ) return;        double dist = distance( _items[node->index], target );        //printf("dist=%g tau=%gn", dist, _tau );        if ( dist < _tau ) {            if ( heap.size() == k ) heap.pop();            heap.push( HeapItem(node->index, dist) );            if ( heap.size() == k ) _tau = heap.top().dist;        }        if ( node->left == NULL && node->right == NULL ) {            return;        }        if ( dist < node->threshold ) {            if ( 
dist - _tau <= node->threshold ) {                search( node->left, target, k, heap );            }            if ( dist + _tau >= node->threshold ) {                search( node->right, target, k, heap );            }        } else {            if ( dist + _tau >= node->threshold ) {                search( node->right, target, k, heap );            }            if ( dist - _tau <= node->threshold ) {                search( node->left, target, k, heap );            }        }    }};#endif

原创粉丝点击