MachineLearn
来源:互联网 发布:淘宝宝贝手机端连接 编辑:程序博客网 时间:2024/06/02 21:26
初步涉猎机器学习领域,经过昨天一天的调研,了解到机器学习分为:监督学习、半监督学习、不监督学习
其中监督学习(supervised learning):不仅把训练数据丢给计算机,而且还把分类的结果(数据具有的标签)也一并丢给计算机分析。 由于计算机在学习的过程中不仅有训练数据,而且有训练结果(标签),因此训练的效果通常不错。训练结束之后进行测试。
无监督学习(unsupervised learning):只给计算机训练数据,不给结果(标签),因此计算机无法准确地知道哪些数据具有哪些标签,只能凭借强大的计算能力分析数据的特征,从而得到一定的成果,通常是得到一些集合,集合内的数据在某些特征上相同或相似。
半监督学习(semi-supervised learning):给计算机大量训练数据与少量的分类结果(具有同一标签的集合)。
今天先来窥探下监督学习中最简单的邻近算法,或者说K最近邻(kNN,k-NearestNeighbor)分类算法。
简单介绍下KNN算法的步骤:
1. 准备数据,对数据进行预处理
2. 选用合适的数据结构存储训练数据和测试元组
3. 设定参数,如k
4.维护一个大小为k的的按距离由大到小的优先级队列,用于存储最近邻训练元组。随机从训练元组中选取k个元组作为初始的最近邻元组,分别计算测试元组到这k个元组的距离,将训练元组标号和距离存入优先级队列
5. 遍历训练元组集,计算当前训练元组与测试元组的距离,将所得距离L 与优先级队列中的最大距离Lmax
6. 进行比较。若L>=Lmax,则舍弃该元组,遍历下一个元组。若L < Lmax,删除优先级队列中最大距离的元组,将当前训练元组存入优先级队列。
7. 遍历完毕,计算优先级队列中k 个元组的多数类,并将其作为测试元组的类别。
8. 测试元组集测试完毕后计算误差率,继续设定不同的k值重新进行训练,最后取误差率最小的k 值
这里给出一个代码实现#coding=utf-8from numpy import *import operator#创建数据集和标签def createDataSet(): group = array([[1.0,0.9],[1.0,1.0],[0.1,0.2],[0.0,0.1]]) labels = ['A','A','B','B'] return group,labels#KNN 分类器def kNNClassify(newInput,dataSet,labels,k): #step1.求newInput 和 数据集的dataSet之间的欧式距离,进行排序 #shape[0] 是取第一维度,即行的数量 numSamples = dataSet.shape[0] #tile的作用是扩展newInput数组,即把newInput在行上扩展numSamples次,列扩展1次 diff = tile(newInput,(numSamples,1)) - dataSet #对于数组中的每个数进行求平方的运算 squaredDiff = diff ** 2 #把数组的每一行进行求合 squaredDist = sum(squaredDiff,axis = 1) #对数组中的每个数进行开方 distance = squaredDist ** 0.5 #对数组进行排序,返回的是每个数的index的从小到大排列 sortedDistIndices = argsort(distance) #step2.取排序的前最小k个数,寻找对应的label标签,存放在字典中 classCount = {} for i in range(k): voteLabel = labels[sortedDistIndices[i]] #字典中的get函数是获取当前key的value,如果不存在则设置默认值0 classCount[voteLabel] = classCount.get(voteLabel,0) + 1 print('classCount = ',classCount) maxCount = 0 #step3.在字典中找最大的分类 for key,value in classCount.items(): if value>maxCount: maxCount = value maxIndex = key return maxIndexif __name__ == '__main__': dataSet,labels = createDataSet() textX = array([1.2,1.0]) k = 3 outputLabel = kNNClassify(textX,dataSet,labels,3) print("Your input is:", textX, "and classified to class: ", outputLabel)