《机器学习实战》--KNN

来源:互联网 发布:菜鸟网络总部 编辑:程序博客网 时间:2024/05/18 01:48

一、瞎扯

   先拉一下仇恨,这篇文章是在喝着走亲戚时带来的饮料,单曲循环着成龙版本的《拯救》的情况下完成的,哈哈,过年难免有些活的太潇洒,于是还是需要写些代码,看些书来收收心,另外新的一年开始了,也该对“懒”开刀了,准备养成写博客这一及其装逼的技能,祝各位同行新年快乐!(开始写的的时候还是大年初一,发布的时候过期了可别怪我)KNN是机器学习中最简单,最基础的算法之一,算法实现起来没什么难度,但是它的使用范围依然十分广泛,比如书中提到的电影的分类,婚姻网站的配偶分类,手写识别系统,和轨迹预处理中选择候选轨迹等。

二、KNN介绍

   2.1监督学习和非监督学习    监督学习:利用一组已知类别的样本调整分类器的参数,使其达到所要求性能的过程,也称为监督训练或有教师学习    举个例子,教婴儿学习的时候,过来一只鸡,就让他叫鸡,过来一只鸭子的时候就让他叫鸭子,教会了一会,随便找一些鸡和鸭子他就能分辨了。这就是监督学习    无监督学习:其中很重要的一类叫聚类    举个例子,过来两种动物(假设还是鸡和鸭子,但是婴儿不知道),然后他根据动物的相像程度把动物分成两群   2.2欧几里得距离   这个在论文中经常能见到,它是一个通常采用的距离定义,指在m维空间中两个点之间的真实距离,或者向量的自然长度(即该点到原点的距离)。在二维和三维空间中的欧氏距离就是两点之间的实际距离。   n维下的两个点的欧氏距离:    两个点 A = (a[1],a[2],…,a[n]) 和 B = (b[1],b[2],…,b[n]) 之间的距离 ρ(A,B)      定义为下面的公式:ρ(A,B) =√ [ ∑( a[i] - b[i] )^2 ] (i = 1,2,…,n)   2维下就是最常见的距离公式了 sqrt( (x1-x2)^2+(y1-y2)^2 )了   2.3KNN算法介绍   KNN属于监督学习,他需要一个训练集来训练,然后才能对后面给出的东西(测试集)进行分类。      通俗的过程:   1.通过训练集训练它,如上图,红色和蓝色就是训练集,他们的类别是已知的,于是现在来了一个未知的东西   (标为绿色),希望通过knn来给它分类,看应该是属于蓝色的还是红色的。   2.计算所有红色点和蓝色点到绿点的距离   3.排序找到最近的k个,上图中实线表示k取3(即取距离最近的三个点),同理虚线表示k取5   4.这k个里面哪个类别的东西多,就判断这个绿色点是属于哪一类的。如果上图看实线,红的比蓝的多那么   绿点画为红的,如果是虚线,则标为蓝的

三、代码

代码我自己用了C++风格的python写了一遍,也就是说不用矩阵计算,那些矩阵被我当作多维数组来用了,代码写的很粗糙,未优化,望见谅。

另外还要再说明一个知识点–归一化 
这个很常见,数学课上也讲过就是把数值缩小到0-1之间,这里是把训练集中的样本数据归一化,使用的公式是newValue=(oldValue-min)/(max-min)

  1  # -*- coding: UTF-8 -*-     2 from numpy import *  3 from math import *  4   5   6 #定义类  7 class Student(object):  8     9     def __init__(self, distance, label): 10         self.distance = distance 11         self.label = label 12  13 #KNN算法 14 def classify0(inX, dataSet, labels, k): 15     size=shape(dataSet) 16     line=size[0] 17     column=size[1] 18     # print (inX) 19     if(len(inX)!=column): 20         print("unequal!!") 21         return 22  23 #计算当前项inX与其余训练集的欧氏距离 24     sum=0.0000000000 25     disList=[] 26     for i in range(line): 27         for j in range(column): 28             sum+=(inX[j]-dataSet[i,j])**2 29         tmp=Student(sqrt(sum),labels[i]) 30         disList.append(tmp) 31         sum=0.0000000000 32 #排序:欧氏距离从小到大 33     disList.sort(lambda x,y:cmp(x.distance,y.distance))  34  35 #取k项判断分类 36     dict={} 37     index=0 38     for item in disList: 39         if(index==k): 40             break 41         index+=1 42         if(dict.has_key(item.label)): 43             dict[item.label]+=1 44         else: 45             dict[item.label]=1 46  47     dict=sorted(dict.iteritems(),key=lambda d:d[1],reverse=True) 48     #返回最可能的值 49     return dict[0][0] 50  51 # def classify0(inX, dataSet, labels, k): 52 #     dataSetSize = dataSet.shape[0] 53 #     diffMat = tile(inX, (dataSetSize,1)) - dataSet 54 #     sqDiffMat = diffMat**2 55 #     sqDistances = sqDiffMat.sum(axis=1) 56 #     distances = sqDistances**0.5 57 #     sortedDistIndicies = distances.argsort()      58 #     classCount={}           59 #     for i in range(k): 60 #         voteIlabel = labels[sortedDistIndicies[i]] 61 #         classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 62 #     sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True) 63 #     print sortedClassCount[0][0] 64 #     return sortedClassCount[0][0] 65  66  67 #读取数据 68 def file2matrix (filename): 69     fr=open(filename) 70     lines=fr.readlines() 71     totalCnt=len(lines) 72     totalCol=len(lines[0].strip().split('\t')) 73  74     resultMatrix=zeros((totalCnt,totalCol-1)) 75     labelMatrix=[] 76     index=0 77     for line in lines: 78         tmp=line.strip().split('\t') 79         resultMatrix[index,:]=tmp[0:totalCol-1] 80         labelMatrix.append((tmp[-1])) 81         index+=1 82     return resultMatrix,labelMatrix 83  84  85  86 # def file2matrix(filename): 87 #     fr=open(filename) 88 #     arrayOlines=fr.readlines() 89 #     numberOfLines=len(arrayOlines) 90 #     returnMat=zeros((numberOfLines,3)) 91 #     classLabelVector=[] 92 #     index=0 93 #     for line in arrayOlines: 94 #         line=line.strip() 95 #         listFromLine=line.split('\t') 96 #         returnMat[index,:]=listFromLine[0:3]; 97 #         classLabelVector.append((listFromLine[-1])) 98 #         index+=1 99 #     return returnMat,classLabelVector100 101 102 #归一化处理  newValue=(oldValue-min)/(max-min)103 def autoNorm(dataSet):104     size=shape(dataSet)105     line=size[0]106     column=size[1]107 108     min=zeros((1,column))109     max=zeros((1,column))110     # print min111     index=0112     for value in dataSet[0,:]:113         min[0,index]=value114         max[0,index]=value115         index+=1116 #求每一列的最小值和最大值117     for i in range(line):118         for j in range(column):119             if(i==0 and j==0): 120                 continue        121             if(dataSet[i,j]>max[0,j]):122                 max[0,j]=dataSet[i,j]123             if(dataSet[i,j]<min[0,j]):124                 min[0,j]=dataSet[i,j]125 126     ranges = max-min127     # print ranges128     result=zeros((line,column))129 130     for i in range(line):131         for j in range(column):132             result[i,j]=(dataSet[i,j]-min[0,j])/ranges[0,j]133 134     return result,ranges,min135     136 137 # def autoNorm(dataSet):138 #     minVals = dataSet.min(0)139 #     maxVals = dataSet.max(0)140 #     ranges = maxVals - minVals141 #     print ranges142 #     normDataSet = zeros(shape(dataSet))143 #     m = dataSet.shape[0]144 #     normDataSet = dataSet - tile(minVals, (m,1))145 #     normDataSet = normDataSet/tile(ranges, (m,1))   #element wise divide146 #     return normDataSet, ranges, minVals147 148 149 150 def datingClassTest():151     hoRatio = 0.50      #hold out 10%152     datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')       #load data setfrom file153     normMat, ranges, minVals = autoNorm(datingDataMat)154     m = normMat.shape[0]155     numTestVecs = int(m*hoRatio)156     errorCount = 0.0157     for i in range(numTestVecs):158         classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)159         print "the classifier came back with: %s, the real answer is: %s" % (classifierResult, datingLabels[i])160         if (classifierResult != datingLabels[i]): errorCount += 1.0161     print "the total error rate is: %f" % (errorCount/float(numTestVecs))162     print errorCount163 164 165 datingClassTest()

 

四、代码以及测试数据的下载

github:   https://github.com/wlmnzf/Machine-Learning-train/tree/master/KNN

五、感谢

1.《机器学习实战》这本书写的不错,值得学习2.百度百科提供的图片和解释也不能忘3.[《什么是无监督学习》 知乎](http://www.zhihu.com/question/23194489)4.感谢 网易云音乐 《拯救》-成龙  深夜相伴

 

  扫码或者搜索 “会打代码的扫地王大爷” 关注公众号


0 0
原创粉丝点击