KNN 源码

来源:互联网 编辑:程序博客网 时间:2024/05/29 04:56
#!/usr/bin/python
# coding=utf-8
#
# k-nearest-neighbour (KNN) classifier plus two demo applications:
#   * a dating-preference predictor trained on a tab-separated text file
#     (datingTestSet.txt by default), and
#   * a handwritten-digit recogniser trained on directories of 32x32 text
#     bitmaps (trainingDigits/ and testDigits/ by default).
# Written to run under both Python 2 and Python 3: only single-argument
# print(...) calls, dict.items(), and a raw_input/input shim are used.
from numpy import *
from os import listdir
import operator

try:
    # Plotting is optional: only showPic() needs matplotlib, so the
    # classifier and test harnesses stay usable without it installed.
    import matplotlib
    import matplotlib.pyplot as plt
except ImportError:
    matplotlib = None
    plt = None


def classify0(inX, dataSet, labels, k):
    """Classify inX by majority vote among its k nearest training vectors.

    inX     -- the vector to classify (a single row, same width as dataSet).
    dataSet -- 2-D numpy array of training vectors, one per row.
    labels  -- sequence of labels; labels[i] belongs to dataSet[i].
    k       -- number of nearest neighbours that vote.

    Distance is Euclidean.  Returns the label with the most votes.  On
    Python 3.7+ a tie between labels is won by the label whose first vote
    came from the nearer neighbour (dicts keep insertion order and sorted()
    is stable).
    """
    # Broadcasting subtracts inX from every training row at once; this is
    # equivalent to tile(inX, (n, 1)) - dataSet without the explicit copy.
    diffMat = asarray(inX) - dataSet
    distances = (diffMat ** 2).sum(axis=1) ** 0.5
    # Training-row indices ordered from nearest to farthest.
    sortedDistIndicies = distances.argsort()
    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    # items() instead of the Python-2-only iteritems().
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createDataSet():
    """Return a tiny 2-D toy training set: four points labelled 'A' or 'B'."""
    group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels


def file2matrix(filename):
    """Load the dating data set from a tab-separated text file.

    Every non-blank line holds three numeric feature columns followed by a
    class label.  The label is either one of the strings 'didntLike',
    'smallDoses', 'largeDoses' (encoded as 1, 2, 3) or an integer, which is
    used as-is.

    Returns (returnMat, classLabelVector): an (n, 3) float array of features
    and a list of n integer labels, row-aligned.

    Raises ValueError on an unrecognised label.  Previously such rows were
    silently skipped, which misaligned the labels with the matrix rows.
    """
    labelCodes = {'didntLike': 1, 'smallDoses': 2, 'largeDoses': 3}
    rows = []
    classLabelVector = []
    # Read the file once, and close it deterministically.
    with open(filename) as fr:
        for lineNo, line in enumerate(fr, 1):
            line = line.strip()
            if not line:
                continue  # tolerate blank lines, e.g. a trailing newline
            listFromLine = line.split('\t')
            label = listFromLine[-1]
            if label in labelCodes:
                code = labelCodes[label]
            elif label.lstrip('-').isdigit():
                code = int(label)
            else:
                raise ValueError('%s:%d: unknown class label %r'
                                 % (filename, lineNo, label))
            rows.append([float(v) for v in listFromLine[0:3]])
            classLabelVector.append(code)
    returnMat = zeros((len(rows), 3))
    for index, row in enumerate(rows):
        returnMat[index, :] = row
    return returnMat, classLabelVector


def showPic():
    """Scatter-plot feature columns 1 and 2 of datingTestSet.txt.

    Marker size and colour both scale with the integer class label.
    Raises ImportError if matplotlib is not available.
    """
    if plt is None:
        raise ImportError('showPic() requires matplotlib, '
                          'which is not installed')
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ma, la = file2matrix('datingTestSet.txt')
    ax.scatter(ma[:, 1], ma[:, 2], 15.0 * array(la), 15.0 * array(la))
    plt.show()


def autoNorm(dataSet):
    """Min-max scale every column of dataSet into the range [0, 1].

    Returns (normDataSet, ranges, minVals), where
    normDataSet = (dataSet - minVals) / ranges.  Scale a new sample with
    the same ranges and minVals before classifying it against normDataSet.

    A column whose values are all equal has a true range of 0.  Its range
    is reported as 1 instead, so the column scales to 0 rather than NaN
    (0/0 would poison every distance computed from it).
    """
    # Work in floats so integer input is not floor-divided under Python 2.
    dataSet = asarray(dataSet, dtype=float)
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = where(maxVals - minVals == 0, 1.0, maxVals - minVals)
    normDataSet = (dataSet - minVals) / ranges
    return normDataSet, ranges, minVals


def datingClassTest(hoRatio=0.10, k=4, filename='datingTestSet.txt'):
    """Measure the classifier's error rate on the dating data set.

    The first hoRatio fraction of rows (10% by default) is the test set.
    The remaining rows are the training set.  Each prediction and the
    overall error rate are printed, and the error rate is also returned.

    Raises ValueError when hoRatio selects no test rows.
    """
    datingDataMat, datingLabels = file2matrix(filename)
    normMat, ranges, minVals = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m * hoRatio)
    if numTestVecs == 0:
        raise ValueError('no test vectors: %d rows with hoRatio %r'
                         % (m, hoRatio))
    errorCount = 0.0
    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i, :],
                                     normMat[numTestVecs:m, :],
                                     datingLabels[numTestVecs:m], k)
        print("the classifier came back with: %d, the real answer is: %d"
              % (classifierResult, datingLabels[i]))
        if classifierResult != datingLabels[i]:
            errorCount += 1.0
    errorRate = errorCount / float(numTestVecs)
    print("the total error rate is : %f" % errorRate)
    return errorRate


def classifyPerson(filename='datingTestSet.txt'):
    """Prompt for one person's three features and print the predicted liking."""
    try:
        readLine = raw_input  # Python 2
    except NameError:
        readLine = input  # Python 3
    resultList = ['not at all', 'in small doses', 'in large doses']
    percentTats = float(readLine("percentage of time spent playing video games?"))
    ffMiles = float(readLine("frequent fliter miles earned per year?"))
    iceCream = float(readLine("liters of ice cream consumed per year?"))
    datingDataMat, datingLabels = file2matrix(filename)
    normMat, ranges, minVals = autoNorm(datingDataMat)
    # NOTE(review): assumes the file's feature columns are ordered
    # (flier miles, game-time %, ice cream) -- confirm against the data.
    inArr = array([ffMiles, percentTats, iceCream])
    # Scale the query with the training set's own min/range.
    classifierResult = classify0((inArr - minVals) / ranges,
                                 normMat, datingLabels, 3)
    print("You will probably like this person: %s"
          % resultList[classifierResult - 1])


def img2vector(filename):
    """Flatten a 32x32 text bitmap into a 1x1024 float row vector.

    Reads the first 32 lines and the first 32 characters of each.  Every
    character must be a digit (0/1 for the digit images).
    """
    returnVec = zeros((1, 1024))
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVec[0, 32 * i + j] = int(lineStr[j])
    return returnVec


def _digitFromFileName(fileNameStr):
    """Return the class digit encoded as the prefix of a name like '7_45.txt'."""
    return int(fileNameStr.split('.')[0].split('_')[0])


def handwritingClassTest(trainingDir='trainingDigits', testDir='testDigits', k=3):
    """Measure the classifier's error rate on the handwritten-digit images.

    Every file in trainingDir and testDir is a 32x32 bitmap whose class
    digit is the file-name prefix before '_'.  Each prediction and the
    totals are printed, and the error rate is also returned.

    Raises ValueError when testDir contains no files.
    """
    # Sorted so that runs, and therefore distance-tie outcomes, are
    # repeatable regardless of the directory listing order.
    trainingFileList = sorted(listdir(trainingDir))
    m = len(trainingFileList)
    hwLabel = []
    trainingMat = zeros((m, 1024))
    for i, fileNameStr in enumerate(trainingFileList):
        hwLabel.append(_digitFromFileName(fileNameStr))
        trainingMat[i, :] = img2vector('%s/%s' % (trainingDir, fileNameStr))

    testFileList = sorted(listdir(testDir))
    mTest = len(testFileList)
    if mTest == 0:
        raise ValueError('no test files found in %r' % testDir)
    errorCount = 0.0
    for fileNameStr in testFileList:
        classNumStr = _digitFromFileName(fileNameStr)
        vectorTest = img2vector('%s/%s' % (testDir, fileNameStr))
        # Pixels are already 0/1, so no normalisation is needed.
        classifierResult = classify0(vectorTest, trainingMat, hwLabel, k)
        print("the classifier came back with: %d, the real answer is: %d"
              % (classifierResult, classNumStr))
        if classifierResult != classNumStr:
            errorCount += 1.0
    errorRate = errorCount / float(mTest)
    print("\nthe total number of error is: %d" % errorCount)
    print("\nthe total error rate is: %f" % errorRate)
    return errorRate

原创粉丝点击