K-近邻算法实践

来源:互联网 发布:java new date 编辑:程序博客网 时间:2024/05/16 11:41

一、K-近邻算法概述 
k近邻算法(k-nearest neighbor,k-NN)是一种基本的分类与回归方法。简单的说,K-近邻算法采用测量不同特征值之间的距离方法进行分类。 
- 优点:进度高,对异常值不铭感、无数据输入假定。 
- 缺点:计算复杂度高、空间复杂度高。 
- 适用数据范围:数值型和标称型。

工作原理:给定一个训练数据集,对新的输入实例,在训练集中找到与该实例最邻近的K个实例,这k个实例的多数属于某个类,就把该输入实例分为这个类。

  • k-近邻算法的一般流程 
    (1)计算已知类别数据集中的点与当前点之间的距离; 
    (2)按照距离递增次序排序; 
    (3)选取与当前点距离最小的K个点 
    (4)确定前K个点所在类别的出现频率; 
    (5)返回前K个点出现频率最高的类别作为当前点的预测分
二、用python实现算法 
2.1 k-近邻算法
def classify0(inX, dataSet, labels, k):    dataSetSize = dataSet.shape[0] #得到数据集的行数    diffMat = tile(inX, (dataSetSize,1)) - dataSet #tile函数将inX复制dataSetSize份    sqDiffMat = diffMat**2    sqDistances = sqDiffMat.sum(axis=1)    distances = sqDistances**0.5    #计算距离    sortedDistIndicies = distances.argsort() #argsort函数返回的是数组值从小到大的索引值    classCount={}              for i in range(k):         voteIlabel = labels[sortedDistIndicies[i]] # 获取类别 #字典的get方法,查找classCount中是否包含voteIlabel,是则返回该值,不是则返回defValue,这里是0        # 其实这也就是计算K临近点中出现的类别的频率,以次数体现        classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 # 对字典中的类别出现次数进行排序,classCount中存储的事 key-value,其中key就是label,value就是出现的次数        # 所以key=operator.itemgetter(1)选中的事value,也就是对次数进行排序    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)#sortedClassCount[0][0]也就是排序后的次数最大的那个label    return sortedClassCount[0][0]

下面用2组数据来简单的测试一下这个算法
def createDataSet():    group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])    labels = ['A','A','B','B']    return group, labels

输出的结果为B,大家也可以改变输入[0,0]为其他值,测试程序的运行结果。
现在,我们已经构造出了第一个分类器,从这个分类器出发,我们在进行2个实例的练习。

三、用K-近邻算法改进约会网站配对效果
    1.问题描述
      比如你的朋友经常上约会网站寻找自己的约会对象,你的朋友选定约会对象的时候主要看重三点“每年飞行的旅程数”、“玩游戏所耗时间百分比”、“每个月看书的数目”,你阅人无数的朋友已经约会过很多个对象了,并且把这些对象分为三类“她是我喜欢的类型”、“一般喜欢”,“她不是我喜欢的类型”,经过无数次的约会之后,你的朋友心已经很累了,他想能否输入某人的“每年飞行的旅程数”、“玩游戏所耗时间百分比”、“每个月看书的数目”这三项数据,就能判断她是不是他喜欢的类型呢?
    2.爬下来的数据集(可以在机器学习实战书的官网上,下载源数据)

 
3.解决实际问题的一般步骤


下面按照这个步骤一步一步来完成
  1. 收集数据
   使用提供的txt文件。
      2.数据预处理
  ①必须将待处理的数据的格式改变为分类器可以接受的格式,将txt中的数据放到矩阵中存储。
def file2matrix(filename):    fr = open(filename)    numberOfLines = len(fr.readlines())         #get the number of lines in the file    returnMat = zeros((numberOfLines,3))        #prepare matrix to return    classLabelVector = []                       #prepare labels return       fr = open(filename)    index = 0    for line in fr.readlines():        line = line.strip()        listFromLine = line.split('\t')        returnMat[index,:] = listFromLine[0:3]        classLabelVector.append(int(listFromLine[-1]))        index += 1    return returnMat,classLabelVector
②数据的归一化(平衡各个特征间的权重)
  我们通常采用的方法是将数值归一化,如将取值范围处理到0到1或者-1到1之间。
  newValue=(oldValue-min)/(max-min)
这个公式可以将任意范围的特征值转化为0到1区间的值。 其中min,max分别是数据集中最小的特征值和最大特征值。
归一化函数autoNorm代码如下
def autoNorm(dataSet):    minVals = dataSet.min(0)    maxVals = dataSet.max(0)    ranges = maxVals - minVals    normDataSet = zeros(shape(dataSet))    m = dataSet.shape[0]    normDataSet = dataSet - tile(minVals, (m,1))    normDataSet = normDataSet/tile(ranges, (m,1))   #element wise divide    return normDataSet, ranges, minVals
   3.分析输入数据
   可以使用Matplotlib创建散点图直观感受。
   4.训练算法:此步骤不适合拥有K-近邻算法。
   5.测试算法
    用部分数据作为测试样本,在这个实验中我们用了90%的数据当作训练样本,10%的数据当作测试样本,来检查分类器的正确性。测试的指标是出错率,当预测的分类和实际的分类不一样时,记为一个错误。错误率=错误的总数/测试的样本数。
   测试算法的代码如下:
def datingClassTest():    hoRatio = 0.50      #hold out 10%    datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')       #load data setfrom file    normMat, ranges, minVals = autoNorm(datingDataMat)    m = normMat.shape[0]    numTestVecs = int(m*hoRatio)    errorCount = 0.0    for i in range(numTestVecs):        classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)        print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i])        if (classifierResult != datingLabels[i]): errorCount += 1.0    print "the total error rate is: %f" % (errorCount/float(numTestVecs))    print errorCount
分类器出来约会数据集的错误约为5%

  6.使用算法解决实际问题
def classifyPerson():    resultList=['not at all', 'in small doses', 'in large doess']    percenTats=float(raw_input("percentage of time spent playing video games?"))    ffMiles=float(raw_input("frequent flier miles earned per year?"))    iceCream=float(raw_input("liters of ice cream consumed per year?"))    datingDataMat,datingLabels=file2matrix('C:\Users\iris123\Desktop\datingTestSet2.txt')    normMat,ranges,minVals=autoNorm(datingDataMat)    inArr=array([ffMiles,percenTats,iceCream])    classifierResult=classify0((inArr-minVals)/ranges,normMat,datingLabels,3)    print "you will probably like this person:",resultList[classifierResult-1]


根据我们输入的数据,预测对约会对象的兴趣程度。
    
四、使用K-近邻算法识别手写字体
     总体步骤和上文类似,就不在赘述了。展示一下核心代码
def img2vector(filename):    returnVect = zeros((1,1024))    fr = open(filename)    for i in range(32):        lineStr = fr.readline()        for j in range(32):            returnVect[0,32*i+j] = int(lineStr[j])    return returnVectdef handwritingClassTest():    hwLabels = []    trainingFileList = listdir('trainingDigits')           #load the training set    m = len(trainingFileList)    trainingMat = zeros((m,1024))    for i in range(m):        fileNameStr = trainingFileList[i]        fileStr = fileNameStr.split('.')[0]     #take off .txt        classNumStr = int(fileStr.split('_')[0])        hwLabels.append(classNumStr)        trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)    testFileList = listdir('testDigits')        #iterate through the test set    errorCount = 0.0    mTest = len(testFileList)    for i in range(mTest):        fileNameStr = testFileList[i]        fileStr = fileNameStr.split('.')[0]     #take off .txt        classNumStr = int(fileStr.split('_')[0])        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)        print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)        if (classifierResult != classNumStr): errorCount += 1.0    print "\nthe total number of errors is: %d" % errorCount    print "\nthe total error rate is: %f" % (errorCount/float(mTest))

img2Vector把图片数据转化为向量数据。