《统计学习方法》-KNN笔记和python源码
来源:互联网 发布:装修论坛 淘宝 编辑:程序博客网 时间:2024/04/30 20:55
K近邻法
K近邻法(k-nearest neighbor,k-NN)是一种基本分类与回归方法。
k近邻法实际上利用训练数据集对特征向量空间经行划分,并作为其分类的“模型”。
1.算法:
输入:训练数据集T,其中的实例类别已定。
输出:实例x的所属的类y。
分类时,对新的实例,根据k个最近邻的训练实例的类别,通过多数表决等方式经行预测。
(1)根据给定的距离度量,在训练数据集T中找出与x最近的k个点,涵盖这k个点的x的邻域记作N(x)。
(2)在N(x)中根据分类决策规则决定x的类别y。
2.距离度量方法
(1)欧几里得距离:
(2)皮尔逊距离:
3.k值的选择
如果选择较小的k值,就相当于用较小的领域中的训练实例经行预测,“学习”的近似误差会减小,但缺点估计误差会增大,预测实例对近邻的实例点会非常敏感。
反之亦然。
k-NN的实现:kd树
最简单的实现方法是采用线性扫描,计算耗时巨大。
采用kd树,kd树是二叉树,表示对k维空间的一个划分。构造kd树不断地用垂直于坐标轴的超平面将k维空间划分,构造一系列的k维超矩形区域。
1.构造:
输入:k维数据集T={x1,x2,x3,...xn}
输出:kd树
(1)开始:构造根节点,根节点对应于包含T的k维空间的超矩形区域。
选择xl为坐标轴,以T中所有实例的xl坐标的中位数为切分点,将根节点对应的超矩形区域切分为两个子域。切分由通过切分点并与坐标轴xl垂直的超平面实现。
由根节点生成深度为1的左右子结点:左结点对应于坐标xl小于切分点的子区域,右子结点对应于坐标xl大于切分点的区域。
将落在切分超平面的实例点保存在根结点。
(2)重复:对深度为j的结点,选择xl为切分的周坐标,l=j(modk)+1,以该结点的区域中的所有实例的xl坐标的中位数为切分点,将该结点对应的超矩形区域切分为两个子区域,切分由通过切分点并且与坐标轴xl垂直的超平面实现。
由根节点生成深度为1的左右子结点:左结点对应于坐标xl小于切分点的子区域,右子结点对应于坐标xl大于切分点的区域。
将落在切分超平面的实例点保存在该结点。
(3)直到两个子域没有实例存在时停止。从而形成kd树的区域划分。
2.kd树搜索
# coding=utf-8# author=altmanclass BinaryTree(object): ''' 创建结点 ''' class __node(object): def __init__(self, value, k,left=None, right=None): self.value = value self.left = left self.right = right self.s = k def getValue(self): return self.value def setValue(self, value): self.value = value def getLeft(self): return self.left def getRight(self): return self.right def setLeft(self, newLeft): self.left = newLeft def setRight(self, newRight): self.right = newRight def getS(self): return self.s def __iter__(self): if self.left != None: for elem in self.left: yield elem yield self.value if self.right != None: for elem in self.right: yield elem ''' 创建根 ''' def __init__(self,length): self.length = length self.root = None def insert(self, value): k = 0 length = self.length def __insert(k,root, value): index = k%length k +=1 if root == None: return BinaryTree.__node(value,index) if value[index] < root.getValue()[index]: root.setLeft(__insert(k,root.getLeft(), value)) else: root.setRight(__insert(k,root.getRight(), value)) return root self.root = __insert(k,self.root,value) def __iter__(self): if self.root != None: return self.root.__iter__() else: return [].__iter__()def main(): passif __name__ == '__main__': main()构建和查询
import numpy as npimport binarayTree as btimport copy as cpimport stack as stdef sim_distance(item1,item2): diff = (item1-item2)**2 sum_diff = np.sum(diff) sqrt = sum_diff**0.5 return sqrt#递归插入def insertRecursively(k,tree,testArray,length,start,stop): if start>=stop: return middleIndex = (start+stop)//2 count = k%length tmp = testArray[start:stop,count] #排序 sortedId = tmp.argsort() nextArray = cp.deepcopy(testArray) for i,x in enumerate(sortedId): nextArray[i+start] = testArray[x+start] value = (nextArray[middleIndex]) tree.insert(value) k +=1 insertRecursively(k,tree,nextArray,length,start,middleIndex) insertRecursively(k,tree,nextArray,length,middleIndex+1,stop)#创建kd树def makeTree(tree,testArray): k = 0 length = testArray.shape[1] insertRecursively(k,tree,testArray,length,0,len(testArray))#寻找当前最近点def findNode(tree,goal,length): root = tree.root k = 0 value = root.getValue() #最小距离 max_distance = 0.0 min_distance = 0.0 #通过栈保存搜索路径 path = st.Stack() while True: index = k%length value = root.getValue() path.push(root) k +=1 if goal[index]<root.getValue()[index]: if root.getLeft()!=None: root = root.getLeft() else: max_distance = sim_distance(goal,value) nearest = value break else: if root.getRight()!=None: root = root.getRight() else: max_distance = sim_distance(goal,value) nearest = value break min_distance = cp.deepcopy(max_distance) path.pop() while not path.isEmpty(): print(nearest) back_point = path.pop() index = back_point.getS() value = back_point.getValue() tmp_dis = sim_distance(goal[index],value[index]) #判断进入子结点 if tmp_dis <= max_distance: kd_point = None if goal[index] < value[index]: kd_point = back_point.getRight() if kd_point != None: path.push(kd_point) else: kd_point = back_point.getLeft() if kd_point != None: path.push(kd_point) #判断是否与当前结点,距离更近 tmp_dis = sim_distance(goal,value) if min_distance >= tmp_dis: min_distance = tmp_dis nearest = value print(nearest)def main(): testNum = [2,3,5,4,9,6,4,7,8,1,7,2] goal = np.array([7,2]) testArray = np.reshape(testNum,(6,2)) tree = bt.BinaryTree(2) makeTree(tree,testArray) findNode(tree,goal,len(goal))if __name__ == '__main__': main()
- 《统计学习方法》-KNN笔记和python源码
- 《统计学习方法》-感知机笔记和python源码
- 《统计学习方法》-朴素贝叶斯法笔记和python源码
- 《统计学习方法》-逻辑回归笔记和python源码
- 统计学习方法笔记五---KNN(K近邻)
- 《统计学习方法》-支持向量机SVM学习笔记和python源码
- 统计学习方法读书笔记-knn
- 统计学习方法:KNN
- 统计学习方法---KNN(K近邻)
- 李航《统计学习方法》第三章——用Python实现KNN算法(MNIST数据集)
- 统计学习方法(3)——KNN,KD树及其Python实现
- 统计学习方法笔记--第一章统计学习方法概论
- 统计学习方法笔记1--统计学习方法概论
- 《统计学习方法》阅读笔记
- [笔记]统计学习方法
- 《统计学习方法》学习笔记
- 统计学习方法~笔记1
- 统计学习方法笔记
- 秒杀系统架构分析与实战
- JayRock:JSON and JSON_RPC for .Net
- ViewCompat的作用
- 使用Jenkins搭建持续集成服务
- BZOJ 3576: [Hnoi2014]江南乐
- 《统计学习方法》-KNN笔记和python源码
- Eclipse打包时出现export aborted because fatal lint errors
- Hibernate配置文件与映射文件详解
- windows server 2008防火墙阻止局域网不能访问解决方案
- NTFS和FAT
- cookie 和session 的区别详解
- 有block回调的UIButton和Alert
- triangle of success
- 虚拟机安装xp系统