《统计学习方法》-KNN笔记和python源码

来源:互联网 发布:装修论坛 淘宝 编辑:程序博客网 时间:2024/04/30 20:55

K近邻法

K近邻法(k-nearest neighbor,k-NN)是一种基本分类与回归方法。

k近邻法实际上利用训练数据集对特征向量空间经行划分,并作为其分类的“模型”。

1.算法:

输入:训练数据集T,其中的实例类别已定。

输出:实例x的所属的类y。

分类时,对新的实例,根据k个最近邻的训练实例的类别,通过多数表决等方式经行预测。

(1)根据给定的距离度量,在训练数据集T中找出与x最近的k个点,涵盖这k个点的x的邻域记作N(x)。

(2)在N(x)中根据分类决策规则决定x的类别y。

2.距离度量方法

(1)欧几里得距离:


(2)皮尔逊距离:



3.k值的选择

如果选择较小的k值,就相当于用较小的领域中的训练实例经行预测,“学习”的近似误差会减小,但缺点估计误差会增大,预测实例对近邻的实例点会非常敏感。

反之亦然。


k-NN的实现:kd树

最简单的实现方法是采用线性扫描,计算耗时巨大。

采用kd树,kd树是二叉树,表示对k维空间的一个划分。构造kd树不断地用垂直于坐标轴的超平面将k维空间划分,构造一系列的k维超矩形区域。

1.构造:

输入:k维数据集T={x1,x2,x3,...xn}

输出:kd树

(1)开始:构造根节点,根节点对应于包含T的k维空间的超矩形区域。

选择xl为坐标轴,以T中所有实例的xl坐标的中位数为切分点,将根节点对应的超矩形区域切分为两个子域。切分由通过切分点并与坐标轴xl垂直的超平面实现。

由根节点生成深度为1的左右子结点:左结点对应于坐标xl小于切分点的子区域,右子结点对应于坐标xl大于切分点的区域。

将落在切分超平面的实例点保存在根结点

(2)重复:对深度为j的结点,选择xl为切分的周坐标,l=j(modk)+1,以该结点的区域中的所有实例的xl坐标的中位数为切分点,将该结点对应的超矩形区域切分为两个子区域,切分由通过切分点并且与坐标轴xl垂直的超平面实现。

由根节点生成深度为1的左右子结点:左结点对应于坐标xl小于切分点的子区域,右子结点对应于坐标xl大于切分点的区域。

将落在切分超平面的实例点保存在该结点。

(3)直到两个子域没有实例存在时停止。从而形成kd树的区域划分。

2.kd树搜索

输入:已构造的kd树:目标点x;(辅助结构,最大堆)
输出:x的k近邻

(1)从根节点出发,递归地向下访问kd树,若目标x当前维的坐标小于切分点的坐标,则移动到左子结点,否则移动到右子结点,知道结点为叶节点为止。
(2)递归的向上回退,在每个节点进行以下操作:
检查该子结点的兄弟结点区域是否有比堆顶元素更近的点或堆容量未满。具体的,检查另一子结点对应的区域是否与以目标点为求心,以目标点与堆顶元素距离为半径的球体相交。
如果相交或容量未满,以另一子结点为根节点执行(1)。
(4)当回退到根节点时,搜索结束,堆中实例即为所求实例。
可以用堆栈保存搜索时的路径,回退时逐一访问。

二叉树
# coding=utf-8# author=altmanclass BinaryTree(object):    '''    创建结点    '''    class __node(object):        def __init__(self, value, k,left=None, right=None):            self.value = value            self.left = left            self.right = right            self.s = k        def getValue(self):            return self.value        def setValue(self, value):            self.value = value        def getLeft(self):            return self.left        def getRight(self):            return self.right        def setLeft(self, newLeft):            self.left = newLeft        def setRight(self, newRight):            self.right = newRight        def getS(self):            return self.s        def __iter__(self):            if self.left != None:                for elem in self.left:                    yield elem            yield self.value            if self.right != None:                for elem in self.right:                    yield elem    '''    创建根    '''    def __init__(self,length):        self.length = length        self.root = None    def insert(self, value):        k = 0        length = self.length        def __insert(k,root, value):            index = k%length            k +=1            if root == None:                return BinaryTree.__node(value,index)            if value[index] < root.getValue()[index]:                root.setLeft(__insert(k,root.getLeft(), value))            else:                root.setRight(__insert(k,root.getRight(), value))            return root        self.root = __insert(k,self.root,value)    def __iter__(self):        if self.root != None:            return self.root.__iter__()        else:            return [].__iter__()def main():    passif __name__ == '__main__':    main()
构建和查询
import numpy as npimport binarayTree as btimport copy as cpimport stack as stdef sim_distance(item1,item2):    diff = (item1-item2)**2    sum_diff = np.sum(diff)    sqrt = sum_diff**0.5    return sqrt#递归插入def insertRecursively(k,tree,testArray,length,start,stop):    if start>=stop:        return    middleIndex = (start+stop)//2    count = k%length    tmp = testArray[start:stop,count]    #排序    sortedId = tmp.argsort()    nextArray = cp.deepcopy(testArray)    for i,x in enumerate(sortedId):        nextArray[i+start] = testArray[x+start]    value = (nextArray[middleIndex])    tree.insert(value)    k +=1    insertRecursively(k,tree,nextArray,length,start,middleIndex)    insertRecursively(k,tree,nextArray,length,middleIndex+1,stop)#创建kd树def makeTree(tree,testArray):    k = 0    length = testArray.shape[1]    insertRecursively(k,tree,testArray,length,0,len(testArray))#寻找当前最近点def findNode(tree,goal,length):    root = tree.root    k = 0    value = root.getValue()    #最小距离    max_distance = 0.0    min_distance = 0.0    #通过栈保存搜索路径    path = st.Stack()    while True:        index = k%length        value = root.getValue()        path.push(root)        k +=1        if goal[index]<root.getValue()[index]:            if root.getLeft()!=None:                root = root.getLeft()            else:                max_distance = sim_distance(goal,value)                nearest = value                break        else:            if root.getRight()!=None:                root = root.getRight()            else:                max_distance = sim_distance(goal,value)                nearest = value                break    min_distance = cp.deepcopy(max_distance)    path.pop()    while not path.isEmpty():        print(nearest)        back_point = path.pop()        index = back_point.getS()        value = back_point.getValue()        tmp_dis = sim_distance(goal[index],value[index])        #判断进入子结点        if tmp_dis <= max_distance:            kd_point = None            if goal[index] < value[index]:                kd_point = back_point.getRight()                if kd_point != None:                    path.push(kd_point)            else:                kd_point = back_point.getLeft()                if kd_point != None:                    path.push(kd_point)        #判断是否与当前结点,距离更近        tmp_dis = sim_distance(goal,value)        if min_distance >= tmp_dis:            min_distance = tmp_dis            nearest = value    print(nearest)def main():    testNum = [2,3,5,4,9,6,4,7,8,1,7,2]    goal = np.array([7,2])    testArray = np.reshape(testNum,(6,2))    tree = bt.BinaryTree(2)    makeTree(tree,testArray)    findNode(tree,goal,len(goal))if __name__ == '__main__':    main()



0 0
原创粉丝点击