KNN算法(基于KD-Tree)

来源:互联网 发布:hive group 数据倾斜 编辑:程序博客网 时间:2024/06/14 18:54

输入数据可以自己调,设置要查的point和k值后返回最邻近的k个point

import numpy as npclass KD_Node:    def __init__(self, point=None, split=None, leftNode=None, rightNode=None):        """        :param point:数据点        :param split:分割维度        :param LL:左儿子        :param RR:右儿子        """        self.point = point        self.split = split        self.leftNode = leftNode        self.rightNode = rightNodedef setData():    group = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]    return groupdef createKDTree(group):    """    构造KD树    :param group: 数据集    :return:    """    length = len(group)    if length == 0:        return    dimension = len(group[0])    # 维度    split = 0    # 最大方差    max_var = 0    # 获得方差最大的维度    for i in range(dimension):        temp = []        for data in group:            temp.append(data[i])        var = cal_variance(temp)        if var > max_var:            max_var = var            split = i    group.sort(key=lambda x: x[split])    point = group[length // 2]    root = KD_Node(point, split)    root.leftNode = createKDTree(group[0: (length // 2)])    # 注意右儿子要加一    root.rightNode = createKDTree(group[(length // 2 + 1): length])    return rootdef cal_variance(dataList):    """    计算数据集的方差,D(x) = E(X^2) - [E(X)]^2    :param dataList:某一维度下的所有数据    :return:    """    length = len(dataList)    array = np.array(dataList)    sum1 = array.sum()    e1 = sum1 / length    e2 = (array * array).sum() / length    return e2 - e1 * e1def findKN(point, root, k):    """    kd树搜索函数,搜索最近的点    :param point:要查找的点    :param root:kd树的树根    :param k:返回邻近点的个数    :return:    """    if k < 1:        return    ret = [root.point]    max_dist = cal_dist(point, ret[-1])    nodeList = []    temp = root    while temp:        nodeList.append(temp)        dist = cal_dist(point, temp.point)        if len(ret) < k:            ret.append(temp.point)            sortP(ret, point)            if dist > max_dist:                max_dist = dist        elif dist < max_dist and len(ret) >= k:            # 去除最大            ret.pop()            ret.append(temp.point)            sortP(ret, point)            max_dist = cal_dist(point, ret[-1])        ss = temp.split        if point[ss] < temp.point[ss]:            temp = temp.leftNode        else:            temp = temp.rightNode    while nodeList:        bac_point = nodeList.pop()        ss = bac_point.split        # 判断是否要进入父节点的另一个子空间进行搜索        # 并不是判断距离就要进去,只要空间中有那个圈就要进去        if abs(point[ss] - bac_point.point[ss]) < max_dist:            if point[ss] >= bac_point.point[ss]:                temp = bac_point.leftNode            else:                temp = bac_point.rightNode            if temp:                nodeList.append(temp)                dist = cal_dist(temp.point, point)                if len(ret) < k:                    ret.append(temp.point)                    sortP(ret, point)                    if dist > max_dist:                        max_dist = dist                elif dist < max_dist and len(ret) >= k:                    ret.pop()                    ret.append(temp.point)                    sortP(ret, point)                    max_dist = cal_dist(point, ret[-1])    return retdef sortP(group, point):    for i in range(len(group) - 1):        for j in range(i + 1, len(group)):            if cal_dist(group[j], point) < cal_dist(group[i], point):                temp = group[j]                group[j] = group[i]                group[i] = temp    return groupdef cal_dist(point1, point2):    """    计算两个点的欧氏距离    :param point1:    :param point2:    :return:    """    ret = 0    for i in range(len(point1)):        ret += (point1[i] - point2[i]) ** 2    return ret ** (1 / 2)
原创粉丝点击