机器学习实战:KNN算法讲解

来源:互联网 发布:斗鱼鱼丸能用淘宝买 编辑:程序博客网 时间:2024/06/08 06:17

机器学习实战:KNN算法讲解

    KNN算法本章内容来至于《统计学习与方法》李航,《机器学习》周志华,以及《机器学习实战》Peter HarringTon,相互学习,不足之处请大家多多指教

    1.1 KNN算法的优缺点

    1.2 KNN算法的工作机制

    1.3 KNN算法的python实现

    1.4 我对KNN算法的理解

1.1 KNN算法的优缺点

      优点:KNN算法是分类数据最简单的算法,具有精度高,对异常值不明显,无数据输入假定的特点。

      缺点:KNN算法必须保存全部的数据,如果训练的数据集比较大,必须使用大量的存储空间,而且对每个数据距离计算,可能会比较耗时,KNN算法的另一个缺陷是无法给出任何数据的基础结构信息,无法知道实例样本和典型样本具有什么特征。

1.2 KNN算法的工作机制

【1】KNN算法:给定测试样本,基于某种距离度量找到训练集中最靠近的K个训练样本,然后基于这K个邻居的信息来进行预测,通常在分类任务重可以使用“投票法”,即选择这K个样本中出现最多的类别标记作为预测结果,在回归任务中可以使用平均分,将k个样本的实值输出标记的平均值作为预测值,或者是积极与距离远近进行加权平均或者加权投票,距离越近的样本权重越大。-----周志华 《机器学习P225页》

输入:训练数据T = {(x1,y1),(x2,y2),(x3,y3),(x4,y4),……(xn,yn)},

实例的类别y={c1,c2,c3……,cn},以及实例向量x。

输出:实例x所属的类别y。

算法过程:

(1)根据给定的距离度量,在训练集T中找出与实例x最近的K个点,涵盖这k个点的x的领域记为Nk(x)

(2)在Nk(x)中,根据分类决策规则,如多数表决,决定x的类别

(3)K近邻算法的特殊情况是K=1的情况,称为最近邻算法,对于输入的实例点,最近邻算法将训练数据集中与x最近点的类作为x的类

 

【2】距离度量包括LP距离,欧氏距离,曼哈顿距离,《统计学习与方法》

其中xi,xj的LP距离定义为:

 

当P=2时候,称为欧氏距离:

 

当P= 1时候,称为曼哈顿距离:

 

当P=无穷大时候,他是坐标的最大值:

  

【3】关于K值的选择对KNN算法的影响

如果选择较小的K值,就相当于用较小的领域中的训练实例进行预测,学习的近似误差会减小,只有输入实力和相似点的训练实例较近时候,才会对预测起结果,但学习的误差估计会增大,预测结果对对近邻的实例点非常敏感,如果近邻是噪声点就会出错,换句话说K值变小,会使得整体模型变得复杂,容易发生过拟合

如果K值比较大,就相当于用较大的领域中的训练实例进行预测,其优点是会减少学习的估计误差,但是缺点是学习的近似误差会增大,这时候与输入实例较远的点也会对训练实例起预测作用,使得预测发生错误K值的增大会使得整个模型变得更加简单。

在训练过程中,K值通常比较小,通常采用交叉验证法来选取合适的K值

    

1.3 KNN算法的python实现

参照机器学习实战的例子,使用KNN算法改进约会网站的配对效果

#!/usr/bin/python#-*- encoding:utf-8 -*-from numpy import  *import  numpy as npimport operatorimport matplotlib as mplimport matplotlib.pyplot as plt#添加Linux黑体字库,避免matplotlib显示中文乱码mpl.rcParams['font.sans-serif'] = [u'SimHei']mpl.rcParams['axes.unicode_minus'] = Falsedef createDataSet():    #训练的数据T    group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])    #分类的标签label    labels = ['A','A','B','B']    return  group,labelsdef classify0(inx,dataSet,labels,k):    dataSetSize = dataSet.shape[0]    print dataSet    diffMat = tile(inx,(dataSetSize,1)) - dataSet    print  diffMat    sqDiffMat = diffMat **2    sqDistances = sqDiffMat.sum(axis=1)    print 'sqDistances =',sqDistances    distances = sqDistances**0.5    sortedDistIndicies = distances.argsort()    classCount = {}    for i in range(k):        voteIlabel = labels[sortedDistIndicies[i]]        classCount[voteIlabel] = classCount.get(voteIlabel,0)+1    sortedClassCount = sorted(classCount.iteritems(),key = operator.itemgetter(1),reverse = True)    return sortedClassCount[0][0]#从文本数据中获得数据def file2matrix(filename):    fr = open(filename)    arrayOlines = fr.readlines();    numberOfLines = len(arrayOlines)    returnMat = zeros((numberOfLines,3))    classLabelVector = []    index = 0    for line in arrayOlines:        line = line.strip()        listFromLine = line.split('\t')        returnMat[index,:]=listFromLine[0:3]        classLabelVector.append(int(listFromLine[-1]))        index +=1    return returnMat,classLabelVector#归一化空间def autoNorm(dataSet):    minVals = dataSet.min(0)    maxVals = dataSet.max(0)    ranges = maxVals -minVals    normDataSet = zeros(shape(dataSet))    m = dataSet.shape[0]    normDataSet = dataSet-tile(minVals,(m,1))    normDataSet = normDataSet/tile(ranges,(m,1))    return normDataSet,ranges,minValsdef datingClassTest():    hoRatio = 0.10    datingDataMat,datingLabels = file2matrix('datingTestSet.txt')    normMat,ranges,minVals = autoNorm(datingDataMat)    m=normMat.shape[0]    numTestVecs = int(m*hoRatio)    errorCount = 0.0    for i in range(numTestVecs):        classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)        print 'the classifier came back with: %d ,the real answer is :%d'%(classifierResult,datingLabels[i])        if(classifierResult!= datingLabels[i]):errorCount+=1.0    print "the total error rate is :%f"%(errorCount/float(numTestVecs))if __name__ == "__main__":    # datingClassTest()    group,labels = createDataSet();    datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')    fig =plt.figure(facecolor='w')    ax = fig.add_subplot(111)    ax.scatter(datingDataMat[:,1],datingDataMat[:,2],15.0*array(datingLabels),15.0*array(datingLabels))    plt.xlabel(u"玩游戏所耗时间百分比",fontsize=14)    plt.ylabel(u"每周消耗的冰淇淋公斤升数",fontsize=14)    plt.title(u"约会网站KNN算法预测")    plt.show()


实验的结果:


 

代码技巧

1:归一化特征

 

2:使用多通道颜色显示不同的类别

 

代码调试过程中出现的bug

 

代码下载: