K-近邻kNN算法简介

来源:互联网 发布:阿里云腾讯云对比报告 编辑:程序博客网 时间:2024/04/29 19:49

1.K-近邻算法:存在一个带标签的训练数据集,对于待分类的样例,从训练数据集中选取与待分类样例距离最近的K个样例,由这K个样例投票表决待分类样例的类别。

2.优点:精度高、对异常值不敏感、无数据输入假定。缺点:计算复杂度高、空间复杂度高、无法给出任何数据的基础结构信息。

3.有k值的选择、距离度量及分类决策规则等三个基本要素。

4.计算样本距离前需要归一化数据,消除数据量级不一致问题。

5.k值的选择反映了对近似误差与估计误差之间的权衡,通常由交叉验证选择最优的k。


'''Created on 2017年2月7日kNN算法@author: Jakin Liu'''import numpy as npimport operatorimport matplotlib.pyplot as pltfrom os import listdir# K-近邻算法def classify0(inX, dataSet, labels, k):    # 获取训练样例数目,对待测样例构造差分矩阵    dataSetSize = dataSet.shape[0]    diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet        # 根据差分矩阵计算待测样例到所有训练样例的距离,并排序    sqDiffMat = diffMat ** 2    sqDistances = sqDiffMat.sum(axis = 1)    distances = sqDistances ** 0.5    sortedDistIndicies = distances.argsort()        # 对距离最近的前k个样例的类别进行计数    classCount = {}    for i in range(k):        voteIlabel = labels[sortedDistIndicies[i]]        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1        # 对类别计数进行排序,将计数最大的类别赋予待测样例    sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True)    return sortedClassCount[0][0]# 从文本文件中读取并解析原始数据def file2matrix(filename):    # 读取文本    fr = open(filename)    arrayOLines = fr.readlines()        # 解析和构造数据    numberOfLines = len(arrayOLines)    returnMat = np.zeros((numberOfLines, 3))    classLabelVector = []        index = 0    for line in arrayOLines:        line = line.strip()        listFromLine = line.split('\t')        returnMat[index, :] = listFromLine[0:3]        classLabelVector.append(int(listFromLine[-1]))        index += 1            return returnMat, classLabelVector# 对输入数据集进行归一化def autoNorm(dataSet):    minVals = dataSet.min(0)    maxVals = dataSet.max(0)    ranges = maxVals - minVals    normDataSet = dataSet    m = dataSet.shape[0]    normDataSet = normDataSet - np.tile(minVals, (m, 1))    normDataSet = normDataSet / np.tile(ranges, (m, 1))    return normDataSet, ranges, minVals# 针对约会数据集进行分类的测试def datingClassTest():    # 读入约会数据,并归一化    datingDataMat, datingLabels = file2matrix('Datas/datingTestSet2.txt')    normMat, ranges, minVals = autoNorm(datingDataMat)        # 分配训练集和测试集样例数目    hoRatio = 0.10    m = normMat.shape[0]        numTestVecs = int(m * hoRatio)        # 对测试集进行分类,并记录错分的样例数,计算和打印错分率    errorCount = 0.0    for i in range(numTestVecs):        classifierResult = classify0(normMat[i,:], normMat[numTestVecs:m,:],                                     datingLabels[numTestVecs:m], 3)        print('the classifier came back with: %d, the real answer is: %d'               % (classifierResult, datingLabels[i]))        if (classifierResult != datingLabels[i]):            errorCount += 1.0                print('the total error rate is: %f' % (errorCount / float(numTestVecs)))# 约会网站预测函数,交互输入测试样例def classifyPerson():    # 交互输入测试样例    percentTats = float(input('percentage of time spent playing video games?'))    ffMiles = float(input('frequent flier miles earned per year?'))    iceCream = float(input('liters of ice cream consumed per year?'))    inArr = np.array([ffMiles, percentTats, iceCream])        # 输入训练数据集,并归一化数据    datingDataMat, datingLabels = file2matrix('Datas/datingTestSet2.txt')    normMat, ranges, minVals = autoNorm(datingDataMat)        # 对输入样例进行分类测试    classifierResult = classify0((inArr - minVals) / ranges, normMat, datingLabels, 3)        # 打印输入样例的类别    resultList = ['not at all', 'in small doses', 'in large doses']    print('You will probably like this person: ', resultList[classifierResult - 1])# 读取手写数字图像的文本,并解析成一维行向量def img2vector(filename):    returnVect = np.zeros((1, 1024))    fr = open(filename)    for i in range(32):        lineStr = fr.readline()        for j in range(32):            returnVect[0, 32*i+j] = int(lineStr[j])    return returnVect# 手写数字识别系统的测试def handwritingClassTest():    # 读入并解析手写数字的训练数据    trainingFileList = listdir('Datas/trainingDigits')    m = len(trainingFileList)    trainingMat = np.zeros((m, 1024))    hwLabels = []    for i in range(m):        fileNameStr = trainingFileList[i]        fileStr = fileNameStr.split('.')[0]        classNumStr = int(fileStr.split('_')[0])        hwLabels.append(classNumStr)        trainingMat[i,:] = img2vector('Datas/trainingDigits/%s' % fileNameStr)        # 读入并解析手写数字的测试数据,并kNN分类        testFileList = listdir('Datas/testDigits')    errorCount = 0.0    mTest = len(testFileList)    for i in range(mTest):        fileNameStr = testFileList[i]        fileStr = fileNameStr.split('.')[0]        classNumStr = int(fileStr.split('_')[0])        vectorUnderTest = img2vector('Datas/testDigits/%s' % fileNameStr)                # kNN分类        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)        print('the classifier came back with: %d, the real answer is: %d'              % (classifierResult, classNumStr))                # 如果分类错误,计数加一        if (classifierResult != classNumStr):            errorCount += 1.0                # 打印分类错误数目和错分率    print('\nthe total number of errors is: %d' % errorCount)    print('\nthe total error rate is: %f' % (errorCount/float(mTest)))if __name__ == '__main__':    # 读取datingTestSet2.txt数据    datingDataMat, datingLabels = file2matrix('Datas/datingTestSet2.txt')    print(datingDataMat)    print(datingLabels)        # 展示约会数据集    fig = plt.figure()    ax = fig.add_subplot(111)    ax.scatter(datingDataMat[:,0], datingDataMat[:,1],               15.0 * np.array(datingLabels), 15.0 * np.array(datingLabels))    plt.show()        # 针对约会数据集进行分类的测试    datingClassTest()        # 约会网站预测函数,交互输入测试样例,并打印测试样例的类别    classifyPerson()        # 手写数字识别系统的测试    handwritingClassTest()    



参考:《机器学习实战》、《机器学习》

0 0