KD树算法

来源:互联网 发布:java适配器模式例子 编辑:程序博客网 时间:2024/05/20 09:23

与传统的KNN算法比较我感觉慢很多,我的姿势是不是不对

kd树

import numpy as npfrom numpy import *class KDNode():    """    KDNode    point:该节点的样本点    split:用于判断分割的维度(属性)    left:左节点    right:右节点    """    def __init__(self, point=None, split=None, left=None, right=None):        self.point = point        self.split = split        self.left = left        self.right = rightclass KDTree():    """    KD树    KDNode:kd-tree的节点    dimensions:数据的纬度    right:节点的右子节点    left:节点的左子节点    curr_axis:当前需要切分的纬度    next_axis:下一次需要切分的纬度    """    def __init__(self,data=None):        """        采用递归的方式创建树        """        def createNode(split=None, data_set=None):            """            递归创建节点            input:(1)split:分割的维度(2)data_set:需要分割的样本点集合            output:KDnode            """            if len(data_set) == 0:                return None # 数据集为空,作为递归的停止条件            # 按照split对data_set进行排序,找到split维度中的中位数            data_set = list(data_set)            data_set.sort(key=lambda x: x[split]) # 按照split维的数据大小排序            data_set = np.array(data_set)            median = len(data_set) // 2 # 不用python自带的median函数,我返回的是median的位置所在的索引            # data_set[median]就是这个节点的样本点            # split是这个节点的分割维度            # data_set[:median]样本节点左半部分 data_set[median-1:]            print("------------",median)            print('data_set[:median]',data_set[:median])            print('data_set[median+1:]',data_set[median+1:])            return KDNode(data_set[median],split,createNode(maxVar(data_set[:median]),data_set[:median]),createNode(maxVar(data_set[median+1:]),data_set[median+1:]))        def maxVar(data_set=None):            """            计算样本集的最大方差维度            input:data_set样本集            output:split:最大方差的维度,作为createNode的输入值            """            if len(data_set)==0:                return 0            print("======",len(data_set))            data_mean = np.mean(data_set,0) # 按照列求平均值            print(data_mean)            mean_differ = data_set - data_mean # 求均值差            data_var = np.sum(mean_differ ** 2, axis=0)/len(data_set) # 求方差,差反映数据的分散特征,方差的数值越大,那么数据的分散程度越大            re = np.where(data_var == np.max(data_var)) # 寻找方差最大的位置            print("re:",re)            return re[0][0] # 方差最大的维数        # print(data)        self.root = createNode(maxVar(data),data)def computeDist(pt1,pt2):    """    计算两个点之间的距离    点的类型是N维的    """    sum = 0.0    for i in range(len(pt1)):        sum = sum + (pt1[i] - pt2[i]) ** 2    return np.math.sqrt(sum)def preOrder(root):    """    前序遍历KD树    """    print(root.point)    if root.left:        preOrder(root.left)    if root.right:        preOrder(root.right)def updateNN(min_dist_array=None, tmp_dist=0.0, NN=None, tmp_point=None, k=1):    """    更新近邻点和对应的最小距离的集合    min_dist_array为最小距离的集合    NN为邻近点的集合    tmp_dist和tmp_point分别是需要更新到min_dist_array,NN里的近邻点和距离    """    # 如果距离更小就更新min_dist_array    if tmp_dist <= np.min(min_dist_array):        # 删除最大距离和对应的节点        for i in range(k-1,0,-1):            min_dist_array[i] = min_dist_array[i-1]            NN[i] = NN[i-1]        min_dist_array[0] = tmp_dist        NN[0] = tmp_point        return NN,min_dist_array    for i in range(k) :        if (min_dist_array[i] <= tmp_dist) and (min_dist_array[i+1] >= tmp_dist) :            #tmp_dist在min_dist_array的第i位和第i+1位之间,则插入到i和i+1之间,并把最后一位给剔除掉            for j in range(k-1,i,-1) : #range反向取值                min_dist_array[j] = min_dist_array[j-1]                NN[j] = NN[j-1]            min_dist_array[i+1] = tmp_dist            NN[i+1] = tmp_point            break    return NN,min_dist_arraydef searchKDTree(KDTree=None, target_point=None, k=1):    """    搜索KD树    input:KDTree:kd树;target_point:目标点;k:距离目标点最近的k个点的k值    output:k_arrayList,距离目标点最近的k个点的集合数组    """    if k==0 : return None    tempNode = KDTree.root # 从更节点出发    NN = [tempNode.point] * k #定义最邻近点集合,k个元素,按照距离远近,由近到远。初始化为k个根节点    min_dist_array = [float("inf")] * k#定义近邻点与目标点距离的集合.初始化为无穷大    nodeList = []    def buildSearchPath(tempNode=None, nodeList=None,min_dist_array=None,NN=None,target_point=None):        """        此方法是用来建立以tempNode为根节点,以下所有节点的查找路径,并将它们存放到nodeList中        nodeList为一系列节点的顺序组合,按此先后顺序搜索最邻近点        tempNode为"根节点",即以它为根节点,查找它以下所有的节点(空间)        """        while tempNode:            nodeList.append(tempNode)            split = tempNode.split            point = tempNode.point            tmp_dist = computeDist(point,target_point)            if tmp_dist < np.max(min_dist_array):                NN,min_dist_array = updateNN(min_dist_array,tmp_dist,NN,point,k)# 更新最小距离和最近邻近点            if target_point[split] <= point[split]:#如果目标点当前维的值小于等于切分点的当前维坐标值,移动到左节点                tempNode = tempNode.left            else:                tempNode.right        return NN,min_dist_array    # 建立查找路径    NN,min_dist_array = buildSearchPath(tempNode,nodeList,min_dist_array, NN, target_point)    # 回溯查找    while nodeList:        back_node = nodeList.pop()        split = back_node.split        point = back_node.point        #判断是否需要进入父节点搜素        #如果当前纬度,目标点减实例点大于最小距离,就没必要进入父节点搜素了        #因为目标点到切割超平面的距离很大,那邻近点肯定不在那个切割的空间里,即没必要进入那个空间搜素了        if not abs(target_point[split] - point[split]) >= np.max(min_dist_array):            if target_point[split] <= point[split]: # 在右侧                tempNode = back_node.right            else:                tempNode = back_node.left # 在左侧            if tempNode:                NN,min_dist_array = buildSearchPath(tempNode,nodeList,min_dist_array, NN, target_point)    return NN,min_dist_arraydef classify0(inX, dataSet, labels, k):    '''    k近邻算法的分类器    input:    inX:目标点    dataSet:训练点集合    labels:训练点对应的标签    k:k值    这个方法的目的:已知训练点dataSet和对应的标签labels,确定目标点inX对应的labels    '''     kd = KDTree(dataSet)#构建dataSet的kd树    NN,min_dist_array = searchKDTree(kd, inX, k)#搜索kd树,返回最近的k个点的集合NN,和对应的距离min_dist_array    dataSet = dataSet.tolist()    voteIlabels = []    #多数投票法则确定inX的标签,为防止边界处分类不准的情况,以距离的倒数为权重,即距离越近,权重越大,越该认为inX是属于该类    for i in range(k) :        #找到每个近邻点对应的标签        nni = list(NN[i])        voteIlabels.append(labels[dataSet.index(nni)])#     #开始记数,加权重的方法#     uniques = np.unique(voteIlabels)#     counts = [0.0] * len(uniques)#     for i in range(len(voteIlabels)) :#         for j in range(len(uniques)) :#             if voteIlabels[i] == uniques[j] :#                 counts[j] = counts[j] + uniques[j] / min_dist_array[i] #权重为距离的倒数#                 break    #开始记数,不加权重的方法    uniques, counts = np.unique(voteIlabels, return_counts=True)    return uniques[np.argmax(counts)]# 处理文件数据def file2matrix(filename):    fr = open(filename) # 打开文件    arrayOlines = fr.readlines() #读取文件    numbersOfLines = len(arrayOlines) # 文件有多少行    returnMat = zeros((numbersOfLines,3)) # 创建0矩阵    classLabelVector = [] # 标签集合    index = 0    for line in arrayOlines:        line = line.strip()#移除字符串头尾的空格        listFromLine = line.split('\t')        returnMat[index,:] = listFromLine[0:3] # 取前三个数据然后给切片赋值        classLabelVector.append(int(listFromLine[-1])) # 最后一个是标签        index += 1    return returnMat,classLabelVector# 归一化特征值def autoNorm(dataSet):    minVals = dataSet.min(0)    maxVals = dataSet.max(0)    ranges = maxVals - minVals    m = dataSet.shape[0]    normDataSet = dataSet - tile(minVals,(m,1))    normDataSet = normDataSet/tile(ranges,(m,1))    return normDataSet,ranges,minValsdef datingClassTest():    hoRatio = 0.1 # 测试样本的比例    datingDataMat,datingLabels = file2matrix('datingTestSet2.txt') # 载入数据    normMat,ranges,minVals = autoNorm(datingDataMat) # 归一化处理    m = normMat.shape[0]    numTestVecs = int(m*hoRatio) # 获取测试样本    errorCount = 0.0    print(type(datingDataMat))    print(type(datingLabels))    for i in range(numTestVecs):        classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]))        if (classifierResult != datingLabels[i]): errorCount += 1.0    print("the total error rate is: %f" % (errorCount/float(numTestVecs)))    print(errorCount)if __name__ == "__main__":    # test()    # test2()    datingClassTest()
原创粉丝点击