机器学习(1)-K-近邻算法(KNN)

来源:互联网 发布:条形码 连接数据库 编辑:程序博客网 时间:2024/06/01 23:32

优缺点和适用范围

  • 优点:精度高、对异常值不敏感、无数据输入假定。
  • 缺点:计算复杂度高、空间复杂度高。
  • 适用数据范围:数值型和标称型(离散型数据,变量的结果只在有限目标集中取值)。

原理/数学推理过程

  • 存在数据集,且每个数据存在标签,输入没有标签的数据后,计算该数据到所有其他已知类别数据的距离,排序,并取最近的k个(k<20),选择k个数据中出现次数最多的类别作为输入数据的分类

过程代码实现

  • 收集数据:可以使用任何方法。
  • 准备数据:距离计算所需要的数值,最好是结构化的数据格式。
  • 分析数据:可以使用任何方法。
  • 训练算法:此步骤不适用于k近邻算法。
  • 测试算法:计算错误率。
  • 使用算法:首先需要输入样本数据和结构化的输出结果,然后运行k近邻算法判定输入数据分别属于哪个分类,最后应用对计算出的分类执行后续的处理。

数据和源码

  • 最简单的例子
import numpy as np import operatordef createDataSet():    ground = np.array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])    labels = ['A', 'A', 'B', 'B']    return ground, labels# 分类算法def classify0(inX,dataSet,labels,k):    # 数据长度    dataSetSize = dataSet.shape[0]    # 计算inX点与其他所有点的距离,tile方法把inX点修改为矩阵    diffMat = np.tile(inX, (dataSetSize,1)) - dataSet    # 计算平方    sqDiffMat = diffMat ** 2    # 求和    sqDistances = sqDiffMat.sum(axis = 1)    # 求根号    distance = sqDistances **0.5    # 排序,并取其index存值    sortedDistIndicies = distance.argsort()    classCount = {}    for i in range(k):        voteLabel = labels[sortedDistIndicies[i]]        # 有则+1,无则生成一个        classCount[voteLabel] = classCount.get(voteLabel,0) + 1    # 字典的排序    sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)    # python2下使用classCount.iteritems()代替classCount.items()    return sortedClassCount[0][0]if __name__=='__main__':    group, labels = createDataSet()    print(classify0([3,3] ,group, labels, 3))
  • 改进约会网站配对效果
import operatorimport numpy as npimport matplotlibimport matplotlib.pyplot as pltdef createDataSet():    ground = np.array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])    labels = ['A', 'A', 'B', 'B']    return ground, labels# 分类算法def classify0(inX, dataSet, labels, k):      # 数据长度    dataSetSize = dataSet.shape[0]    # 计算inX点与其他所有点的激励,tile方法把inX点修改为矩阵    diffMat = np.tile(inX, (dataSetSize,1)) - dataSet    # 计算平方    sqDiffMat = diffMat ** 2    # 求和    sqDistances = sqDiffMat.sum(axis = 1)    # 求根号    distance = sqDistances **0.5    # 排序,并取其index存值    sortedDistIndicies = distance.argsort()    classCount = {}    for i in range(k):        voteLabel = labels[sortedDistIndicies[i]]        # 有则+1,无则生成一个        classCount[voteLabel] = classCount.get(voteLabel,0) + 1    # 字典的排序    sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)    return sortedClassCount[0][0]# 文本转换为数据def file2matrix(filename):    fr = open(filename)    arrayOlines = fr.readlines()    numberOfLines = len(arrayOlines)    # 初始化0矩阵    returnMat = np.zeros((numberOfLines,3))    classLabelVector = []    index = 0     for line in arrayOlines:        # strip() 方法用于移除字符串头尾指定的字符(默认为空格)。        line = line.strip()        listFromLine = line.split('\t')        returnMat[index,:] = listFromLine[0:3]        # 取倒数第一个数        classLabelVector.append(int(listFromLine[-1]))        index += 1    return returnMat, classLabelVector# 归一化操作def autoNorm(dataSet):    minVals = dataSet.min(0)    maxVals = dataSet.max(0)    ranges = maxVals-minVals    normDataSet = np.zeros(np.shape(dataSet))    m = dataSet.shape[0]    normDataSet = dataSet-np.tile(minVals,(m,1))    normDataSet = normDataSet/np.tile(ranges,(m,1))    return normDataSet,ranges,minVals# 测试函数def datingClassTest():    # 取前面10%的数据用作测试数据    hoRatio =0.10    datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')    normMat, ranges, minVals = autoNorm(datingDataMat)    m = normMat.shape[0]    numTestVecs = int(m*hoRatio)    errorCount = 0.0    for i in range(numTestVecs):        classifirerResult = classify0(normMat[i,:], normMat[numTestVecs:m,:], datingLabels[numTestVecs:m],3)        print('判断类别是%d,正确答案是%d' % (classifirerResult, datingLabels[i]))        if(classifirerResult != datingLabels[i]):            errorCount +=1.0        pass    pass    print('错误的个数是%d,错误率是%f' % (errorCount,errorCount/numTestVecs))if __name__=='__main__':    # group, labels = createDataSet()    # classify0([3,3] ,group, labels, 3)    # datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')    # datingDataMat,ranges,minVals = autoNorm(datingDataMat)    # fig = plt.figure()    # ax = fig.add_subplot(111)    # ax.scatter(datingDataMat[:,0],datingDataMat[:,1],5.0*np.array(datingLabels),15.0*np.array(datingLabels))    # plt.show()    datingClassTest()
  • 数据
    链接: https://pan.baidu.com/s/1pLVzIcF 密码: ay45