《统计学习方法》——kd树python实现

来源:互联网 发布:苏州淘宝店铺装修 编辑:程序博客网 时间:2024/06/17 13:05

kd树原理

之前看KNN时,确实发现这个计算量很大。因此有人提出了kd树算法,其作用是,当你需要求得与预测点最近的K个点时,这个算法可以达到O(logN)的时间复杂度(相当于搜索一颗二叉树的时间耗损).原理有一篇博文讲的十分精彩[这里写链接内容](http://blog.csdn.net/u010551621/article/details/44813299)

kd树python实现

这里给出的是kd树的建树、对预测点求得最近邻的k个点的python代码。本博文的代码是在(http://blog.csdn.net/u010551621/article/details/44813299)的基础上进行的修改,感谢其清晰的原理和代码表达。

kd树节点结构
一个树节点包括:

  1. 节点信息
  2. 被分割的维度
  3. 左孩子
  4. 右孩子

python代码如下

class KD_node(object):    #定义的kd树节点    def __init__(self, point = None, split = None, LL = None, RR = None):        #节点值        self.point = point;        #节点分割维度        self.split = split;        #节点左孩子        self.left = LL;        #节点右孩子        self.right = RR;

kd树建树
首先给出伪代码:

  1. 历遍所有维度,找到方差最大的维度
  2. 以这个维度上的点的数值进行排序,找到其中间点
  3. 以这个点为划分,递归建立左子树
  4. 以这个点为划分,递归建立右子树
  5. 当数据集内没有点时,退出函数

这里给出两个重要概念:

  1. 以方差最大维度为划分的维度:方差越大,代表着这个维度上的数据波动越大,代表着以这个维度划分数据,可以最广泛的把数据集分开
  2. 取中位点为划分点,有助有构造一个平衡二叉树,不至于出现二叉树有时候会出现的极端,即是一个父节点只有一个孩子节点,使树的深度大大加深,增加搜索的复杂度。

这里给出代码实现

def createKDTree(root, data_list):    length = len(data_list);    if length == 0:        return ;    dimension = len(data_list[0]);    max_var = 0;    split = 0;    for i in range(dimension):        ll = [];        for t in data_list:            ll.append(t[i]);        var = computerVariance(ll);        if var > max_var:            max_var = var;            split = i;    #以最大方差的点为维度,进行划分    data_list = sorted(data_list, key = lambda x : x[split]);    point = data_list[int(length / 2)];    root = KD_node(point,split);    #递归建立左子树    root.left = createKDTree(root.left, data_list[0:int(length / 2)]);    #递归建立右子树    root.right = createKDTree(root.right, data_list[int(length / 2) + 1 : length]);    return root;#计算方差def computerVariance(arraylist):    arraylist = array(arraylist);    for i in range(len(arraylist)):        arraylist[i] = float(arraylist[i]);    length = len(arraylist);    sum1 = arraylist.sum();    array2 = arraylist * arraylist;    sum2 = array2.sum();    mean = sum1 / length;    variance = sum2 / length - mean ** 2;    return variance;

查找K个最小值

具体思想如下:给定一个待预测节点,则历遍到最靠近该节点的kd树中的叶子节点。那如何找到最靠近该树的叶子节点呢:方法如下
  1. 若该节点是叶子节点,则返回
  2. 若不是叶子节点,则比较待预测节点与该节点被划分的维度上的值,若小于,则去其左子树
  3. 若不是叶子节点,则比较待预测节点与该节点被划分的维度上的值,若大于,则去其右子树

大致的思想和查找排序二叉树的节点类似。

接下来我们就要去找最小的K各节点了,具体思想如下:

我们用一个K大小的优先队列来存储K个节点的值

  1. 若队列的长度不满K个,则把当前节点入队,并且去该父节点的另外一个子节点比较。
  2. 若已经满了K个,则取距离最长的节点,计算其距离,设为K。在计算预测结点到该节点的父节点的所划分的维度的距离,设为d。如K>d,则去改父节点的另一个子节点查找。否则,继续回退到该节点的父节点的父节点

具体python代码如下:

#用于计算维度距离def computerDistance(pt1, pt2):    sum = 0.0;    for i in range(len(pt1)):        sum = sum + (pt1[i] - pt2[i]) ** 2;    return sum ** 0.5;#query中保存着最近k节点def findNN(root, query,k):    min_dist = computerDistance(query,root.point);    node_K = [];    nodeList = [];    temp_root = root;    #为了方便,在找到叶子节点同时,把所走过的父节点的距离都保存下来,下一次回溯访问就只需要访问子节点,不需要再访问一遍父节点。    while temp_root:        nodeList.append(temp_root);        dd = computerDistance(query,temp_root.point);        if len(node_K) < k:            node_K.append(dd);        else :            max_dist = max(node_K);            if dd < max_dist:                index = node_K.index(max_dist);                del(node_K[index]);                node_K.append(dd);        ss = temp_root.split;        #找到最靠近的叶子节点        if query[ss] <= temp_root.point[ss]:              temp_root = temp_root.left;         else:            temp_root = temp_root.right;    print('node_k :',node_K);    #回溯访问父节点    while nodeList:        back_point = nodeList.pop();        ss = back_point.split;        print('父亲节点 : ',back_point.point,'维度 :',back_point.split);        max_dist = max(node_K);        print(max_dist);        #若满足进入该父节点的另外一个子节点的条件        if  len(node_K) < k or abs(query[ss] - back_point.point[ss]) < max_dist :            #进入另外一个子节点            if query[ss] <= back_point.point[ss]:                temp_root = back_point.right;            else:                temp_root = back_point.left;            if temp_root:                nodeList.append(temp_root);                curDist = computerDistance(temp_root.point,query);                print('curDist :',curDist);                if max_dist > curDist and len(node_K) == k:                    index = node_K.index(max_dist);                    del(node_K[index]);                    node_K.append(curDist);                elif len(node_K) < k:                    node_K.append(curDist);    return node_K;
0 0
原创粉丝点击