KD tree算法(1)-简介&构建KD tree

来源:互联网 发布:weka java api文档 编辑:程序博客网 时间:2024/06/07 14:57

KD tree算法是KNN(K-nearest neighbor)实现的重要算法之一,下面我们先简单介绍一些KNN的知识,然后开始我们KD tree的讲解。

KNN分类算法

KNN是一种简单的分类方法:

分类时,对新的实例,根据k个最近邻的训练实例的类别,通过多数表决等方式进行预测。

以上是《统计学习方法》中对KNN的解释,简单明了。那么问题就在于如何快速有效找到K个近邻就是该算法的关键了。关于KNN算法的详细知识,请看另一篇文章《统计学习方法之K近邻法》,下面我们来简单学习一下如何快速找到K个近邻的样本。

KD Tree

为了对训练数据进行快速k近邻搜索,我们使用特殊的数据结构存储训练数据-kd Tree方法。

kd Tree是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构。kd树是二叉树,表示对k维空间的一个划分(partition),构造kd Tree相当于奖k维空间划分,构成一系列的k维超矩形区域。

以上是《统计学习方法》对KD Tree的解释,同样简单明了。
KD Tree方法即将空间根据坐标不断划分,划分成k块,如下图所示:

这里写图片描述

  1. 构建 KD-Tree
    这里写图片描述
    这里写图片描述

    以上是《统计学习方法》书中对构造平衡kd Tree的算法描述。我们进行分析。
    1.1 为什么要构造“平衡“kd Tree?
    学过快速排序的读者们都知道,当我们取得的K值刚好保证左右两边的数的个数相等时,快排算法的时间复杂度最低,而当K值取到最大或最小时算法时间复杂度最高。为此提出了随机快速排序算法。以此为基础,我们可以分析知道为什么要构造“平衡“kd Tree了。很简单,为了提高搜索效率。
    为了更加简单明了的解释,我们据如下例子:
    1 2 3 4 5 6 7 8 9
    我们将以上数字分为两部分:
    1) >=5;
    2) <5;
    为了找到数字3,如果我们知道3比5小,那么我们只需要搜索5左边的四个数字即可。
    如果我们将以上数字分成以下两部分:
    1) >=9;
    2) <9;
    我们知道3比9小,那么我们需要搜索9左边的8个数字,这样的效率明显没有上一种的效率高。
    2.2 如何构造平衡kd Tree?
    中位数法。

    中位数:将一组数据按大小顺序排列,处在最中间位置的一个数叫做这组数据的中位数 。

    这就是中位数的定义,初中的知识了。那么如何快速找到中位数呢?先排序后查找?还是找到前n/2个数然后得到中位数呢?显然效率不高。我们利用快速排序 的思想进行求解。先给出代码:

private KDNode findMiD(int begin, int end, int flag) {        if (begin >= end) {return null;}        KDNode lastNode = set.get(end-1);//得到该树节点样本节点集最后一个节点        int dia = flag%(KDNode.dimension);//得到该树节点分割的维数        float keyValue = lastNode.get(dia);//得到分割点//一趟快排算法中的交换部分        int LastSmall = begin-1;        for (int i=begin; i<end-1; i++) {            if (set.get(i).get(dia) < keyValue) {                exchange(++LastSmall, i);            }        }        exchange(end-1, ++LastSmall);//快排算法中的交换部分结束//如果该分割点正好为中位数,则返回该位置的样本节点,如果中位数点小于分割点,则说明中位数在前半部分,故递归搜索前半部分,否则搜索后半部分。        if (midPos == LastSmall)            return set.get(midPos);        else if (midPos < LastSmall)            return findMiD(begin, LastSmall, flag);        else            return findMiD(LastSmall+1, end, flag);    }

这是java实现的代码。我们先不管各种node是什么意思,先来看看思想是什么。
这里写图片描述
根据上图及注释应该很好理解了。
下面我们来说一下我的关于KD tree的设计思想。
结构图:
结构图

样本信息:
KDNode

class KDNode {//每一个数据    public static int dimension = 1;//存储维度信息,每个数据维度相同,故使用static。但我还不知道如何保证一旦设置不需修改。    float[] coordinate;//存储每一维度的数值    KDNode() {        coordinate = new float[1];    }//默认数据为一维    KDNode(int dimension) {        this.dimension = dimension;        coordinate = new float[dimension];    }    void set(int pos, float val) { coordinate[pos] = val; }    float get(int pos) { return coordinate[pos]; }    @Override    public String toString() {        String s = " ";        for (int i=0; i<dimension; i++)            s += ("  " + coordinate[i]);        return s;    }}

比较简单,一目了然。

样本集:
KDnode

class KDNodeSet {//用来存节点的集合    ArrayList<KDNode> set;//样本集    int midPos;//中位数位置    KDNodeSet() {        set = new ArrayList<KDNode>();    }    KDNodeSet(KDNodeSet Nodeset) {        this.set = Nodeset.set;    }//拷贝构造函数    void add(KDNode node) {        set.add(node);    }//添加    KDNode findMiD(int flag) {        midPos = set.size()/2 ;        return findMiD(0, set.size(), flag);    }//提供给外部的找中位数的方法,flag表示维度    private KDNode findMiD(int begin, int end, int flag) {        if (begin >= end) {return null;}        KDNode lastNode = set.get(end-1);        int dia = flag%(KDNode.dimension);        float keyValue = lastNode.get(dia);        int LastSmall = begin-1;        for (int i=begin; i<end-1; i++) {            if (set.get(i).get(dia) < keyValue) {                exchange(++LastSmall, i);            }        }        exchange(end-1, ++LastSmall);        if (midPos == LastSmall)            return set.get(midPos);        else if (midPos < LastSmall)            return findMiD(begin, LastSmall, flag);        else            return findMiD(LastSmall+1, end, flag);    }    KDNodeSet findLeft() {        return getSubSet(0, midPos);    }//返回左子集    KDNodeSet findRight() {        return getSubSet(midPos+1, set.size());    }//返回右子集    KDNodeSet getSubSet(int begin, int end) {        KDNodeSet subSet = new KDNodeSet();        for (int i=begin; i<end; i++)            subSet.add(set.get(i));        return subSet;    }//返回子集

KD Tree的每个节点:
KDTreeNode

class KDTreeNode {//KD树的每一个节点,保存有中位数的值,在该层有的集合和父子节点引用    int flag = 0;    KDNode value;    KDTreeNode father;    KDNodeSet set;    KDTreeNode left;    KDNodeSet leftSet;    KDTreeNode right;    KDNodeSet rightSet;    KDTreeNode(KDTreeNode father, KDNodeSet set) {        this.father = father;        this.set = set;        run();    }    KDTreeNode(KDTreeNode father, KDNodeSet set, int flag) {        this.father = father;        this.set = set;        this.flag = flag;        run();    }    void run() {        if (set.set.size() == 0) return;        value = set.findMiD(flag);        leftSet = set.findLeft();        rightSet = set.findRight();        left = new KDTreeNode(this, leftSet, flag+1);        right = new KDTreeNode(this, rightSet, flag+1);    }    KDTreeNode getFather() {return father;}    KDTreeNode getLeft() {return left;}    KDTreeNode getRight() {return right;}}

KD-Tree:
KDTree

public class KDTree {    KDNodeSet set;    KDTreeNode root;    KDTree() {        set = new KDNodeSet();    }    void addNode(KDNode node) {        set.add(node);    }    void BuildTree() {        root = new KDTreeNode(null, set);//已经建好的KDTree    }    void find(KDNode node) {        KDTreeNode position = find(root,node);    }    private KDTreeNode find(KDTreeNode Tnode, KDNode node) {        int dia = Tnode.flag%KDNode.dimension;        if (Tnode == null) { return Tnode.getFather();//如果找到叶子节点还未找到,那就把他父节点设为最近邻        } else if (node.get(dia) < Tnode.value.get(dia)) {            return find(Tnode.getLeft(), node);        }else            return find(Tnode.getRight(), node);    }

测试用例:

    public static void main(String[] args) {        KDTree tree = new KDTree();        KDNode node1 = new KDNode(3);        node1.set(0,3);        node1.set(1,2);        node1.set(2,5);        tree.addNode(node1);         node1 = new KDNode(3);        node1.set(0,4);        node1.set(1,5);        node1.set(2,1);        tree.addNode(node1);         node1 = new KDNode(3);        node1.set(0,2);        node1.set(1,9);        node1.set(2,0);        tree.addNode(node1);         node1 = new KDNode(3);        node1.set(0,7);        node1.set(1,4);        node1.set(2,5);        tree.addNode(node1);         node1 = new KDNode(3);        node1.set(0,1);        node1.set(1,2);        node1.set(2,5);        tree.addNode(node1);         node1 = new KDNode(3);        node1.set(0,3);        node1.set(1,5);        node1.set(2,5);        tree.addNode(node1);         node1 = new KDNode(3);        node1.set(0,3);        node1.set(1,2);        node1.set(2,8);        tree.addNode(node1);         node1 = new KDNode(3);        node1.set(0,3);        node1.set(1,2);        node1.set(2,1);        tree.addNode(node1);         node1 = new KDNode(3);        node1.set(0,4);        node1.set(1,6);        node1.set(2,5);        tree.addNode(node1);         node1 = new KDNode(3);        node1.set(0,3);        node1.set(1,1);        node1.set(2,14);        tree.addNode(node1);        tree.BuildTree();    }
原创粉丝点击