KNN 源码
来源:互联网 发布:日本娱乐圈黑料 知乎 编辑:程序博客网 时间:2024/05/29 04:56
#!/usr/bin/python
# coding=utf-8
"""k-nearest-neighbours (kNN) classifier plus two small demos.

The module contains:

* ``classify0`` - the kNN classifier itself (Euclidean distance, majority
  vote).
* A "dating" demo (``file2matrix``, ``showPic``, ``autoNorm``,
  ``datingClassTest``, ``classifyPerson``) that works on a tab-separated
  text file with three numeric features and a class label per line.
* A handwritten-digit demo (``img2vector``, ``handwritingClassTest``) that
  works on directories of 32x32 text images made of ``0``/``1`` characters.

The code runs unchanged on Python 2 and Python 3.
"""
import operator
import os
from os import listdir

import numpy as np

try:
    _read_input = raw_input  # Python 2
except NameError:
    _read_input = input  # Python 3

# Textual class names found in the dating data file -> integer labels.
_DATING_LABELS = {'didntLike': 1, 'smallDoses': 2, 'largeDoses': 3}


def classify0(inX, dataSet, labels, k):
    """Classify ``inX`` by majority vote among its ``k`` nearest neighbours.

    Args:
        inX: feature vector to classify (1-D, or a single-row 2-D array).
        dataSet: training matrix with one sample per row.
        labels: training labels, aligned with the rows of ``dataSet``.
        k: number of nearest neighbours that vote. Values larger than the
            number of training samples are clamped to that number.

    Returns:
        The label that collected the most votes. On interpreters whose dicts
        preserve insertion order, a tie goes to the label whose first vote
        came from the closest sample.

    Raises:
        ValueError: if ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError('k must be at least 1, got %r' % (k,))
    dataSet = np.asarray(dataSet, dtype=float)
    # Broadcasting subtracts inX from every training row at once.
    diffMat = np.asarray(inX, dtype=float) - dataSet
    distances = np.sqrt((diffMat ** 2).sum(axis=1))
    # argsort yields row indices ordered from nearest to farthest.
    nearest = distances.argsort()[:k]
    classCount = {}
    for rowIndex in nearest:
        voteLabel = labels[rowIndex]
        classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
    return max(classCount.items(), key=operator.itemgetter(1))[0]


def createDataSet():
    """Return a tiny four-sample, two-class data set for quick experiments.

    Returns:
        ``(group, labels)``: a 4x2 array of points and the matching list of
        labels (``'A'`` or ``'B'``).
    """
    group = np.array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels


def _parseDatingLabel(text):
    """Convert the last field of a dating-data line into an integer label."""
    if text in _DATING_LABELS:
        return _DATING_LABELS[text]
    try:
        return int(text)
    except ValueError:
        raise ValueError('unrecognised class label: %r' % (text,))


def file2matrix(filename):
    """Load a tab-separated dating data file.

    Every non-blank line must hold at least three numeric fields followed by
    a class label in the last field. The label is either one of
    ``didntLike`` / ``smallDoses`` / ``largeDoses`` (mapped to 1 / 2 / 3) or
    an integer, which is used as-is.

    Args:
        filename: path of the file to read.

    Returns:
        ``(returnMat, classLabelVector)``: an ``(n, 3)`` float array with the
        first three fields of each line and a list of ``n`` integer labels.

    Raises:
        ValueError: if a feature is not numeric or a label is not recognised.
    """
    with open(filename) as fr:
        # Blank lines (e.g. a trailing newline at end of file) are ignored.
        rows = [line.strip().split('\t') for line in fr if line.strip()]
    returnMat = np.zeros((len(rows), 3))
    classLabelVector = []
    for index, fields in enumerate(rows):
        returnMat[index, :] = [float(value) for value in fields[0:3]]
        classLabelVector.append(_parseDatingLabel(fields[-1]))
    return returnMat, classLabelVector


def showPic():
    """Scatter-plot columns 1 and 2 of ``datingTestSet.txt``.

    Marker size and colour are both driven by the class label.
    """
    # Imported here so that the classifier itself works without matplotlib.
    import matplotlib.pyplot as plt

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ma, la = file2matrix('datingTestSet.txt')
    scaled = 15.0 * np.array(la)
    ax.scatter(ma[:, 1], ma[:, 2], scaled, scaled)
    plt.show()


def autoNorm(dataSet):
    """Rescale every column of ``dataSet`` linearly to the range [0, 1].

    Args:
        dataSet: 2-D numeric array, one sample per row.

    Returns:
        ``(normDataSet, ranges, minVals)`` where
        ``normDataSet == (dataSet - minVals) / ranges``. ``ranges`` is the
        per-column ``max - min``; for a constant column it is reported as 1
        so that dividing by it stays well defined (that column normalises
        to 0).
    """
    dataSet = np.asarray(dataSet, dtype=float)
    minVals = dataSet.min(0)
    ranges = dataSet.max(0) - minVals
    # A constant column would otherwise cause a division by zero.
    ranges = np.where(ranges == 0, 1.0, ranges)
    normDataSet = (dataSet - minVals) / ranges
    return normDataSet, ranges, minVals


def datingClassTest(hoRatio=0.10, k=4, filename='datingTestSet.txt'):
    """Measure the classifier's error rate on the dating data.

    The first ``hoRatio`` share of the rows is held out for testing and the
    remaining rows are used as training data.

    Args:
        hoRatio: fraction of the rows used as test vectors.
        k: number of neighbours passed to ``classify0``.
        filename: dating data file to load.

    Returns:
        The error rate as a float (0.0 when there are no test vectors).
    """
    datingDataMat, datingLabels = file2matrix(filename)
    normMat, _, _ = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m * hoRatio)
    errorCount = 0.0
    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :],
                                     datingLabels[numTestVecs:m], k)
        print("the classifier came back with: %d, the real answer is: %d"
              % (classifierResult, datingLabels[i]))
        if classifierResult != datingLabels[i]:
            errorCount += 1.0
    errorRate = errorCount / numTestVecs if numTestVecs else 0.0
    print("the total error rate is : %f" % errorRate)
    return errorRate


def classifyPerson():
    """Prompt for three feature values and print the predicted class.

    The answers are normalised with the minimum and range of the training
    data before being classified with ``k = 3``.
    """
    resultList = ['not at all', 'in small doses', 'in large doses']
    percentTats = float(_read_input(
        "percentage of time spent playing video games?"))
    ffMiles = float(_read_input("frequent fliter miles earned per year?"))
    iceCream = float(_read_input("liters of ice cream consumed per year?"))
    datingDataMat, datingLabels = file2matrix("datingTestSet.txt")
    normMat, ranges, minVals = autoNorm(datingDataMat)
    inArr = np.array([ffMiles, percentTats, iceCream])
    classifierResult = classify0((inArr - minVals) / ranges, normMat,
                                 datingLabels, 3)
    # Labels are 1-based while resultList is 0-based.
    print("You will probably like this person: %s"
          % resultList[classifierResult - 1])


def img2vector(filename):
    """Read a 32x32 text image of digit characters into a 1x1024 row vector.

    Args:
        filename: path of a text file whose first 32 lines each start with
            32 digit characters.

    Returns:
        A ``(1, 1024)`` float array holding the image row by row.
    """
    returnVec = np.zeros((1, 1024))
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVec[0, 32 * i + j] = int(lineStr[j])
    return returnVec


def _digitFromFileName(fileNameStr):
    """Return the integer before the first ``_`` in names like ``3_12.txt``."""
    return int(fileNameStr.split('.')[0].split('_')[0])


def handwritingClassTest(trainingDir='trainingDigits', testDir='testDigits',
                         k=3):
    """Measure the classifier's error rate on the handwritten-digit images.

    Every file in ``trainingDir`` becomes one training sample and every file
    in ``testDir`` is classified. The true label is taken from the file name
    (the integer before the first underscore). Pixel values are already 0/1,
    so no normalisation is applied.

    Args:
        trainingDir: directory containing the training images.
        testDir: directory containing the test images.
        k: number of neighbours passed to ``classify0``.

    Returns:
        The error rate as a float (0.0 when ``testDir`` is empty).
    """
    trainingFileList = listdir(trainingDir)
    hwLabel = []
    trainingMat = np.zeros((len(trainingFileList), 1024))
    for i, fileNameStr in enumerate(trainingFileList):
        hwLabel.append(_digitFromFileName(fileNameStr))
        trainingMat[i, :] = img2vector(os.path.join(trainingDir, fileNameStr))

    testFileList = listdir(testDir)
    mTest = len(testFileList)
    errorCount = 0.0
    for fileNameStr in testFileList:
        classNumStr = _digitFromFileName(fileNameStr)
        vectorTest = img2vector(os.path.join(testDir, fileNameStr))
        classifierResult = classify0(vectorTest, trainingMat, hwLabel, k)
        print("the classifier came back with: %d, the real answer is: %d"
              % (classifierResult, classNumStr))
        if classifierResult != classNumStr:
            errorCount += 1.0
    print("\nthe total number of error is: %d" % errorCount)
    errorRate = errorCount / mTest if mTest else 0.0
    print("\nthe total error rate is: %f" % errorRate)
    return errorRate
阅读全文
0 0
- KNN 源码
- KNN算法源码解析
- 机器学习实战源码KNN
- 《统计学习方法》-KNN笔记和python源码
- Shark源码分析(十):KNN算法
- knn
- knn
- KNN
- KNN
- KNN
- KNN
- KNN
- knn
- KNN
- knn
- kNN
- KNN
- KNN
- ShaderForge插件分享
- iOS最全面试题及答案
- solr 下载
- Spring XML配置--使用注解装配(@Autowired、@Inject、@Resource)
- 环境变量配置好了,tomcat启动依然闪退
- KNN 源码
- GitHub上README.md教程
- 如何让ThinkPHP的模板引擎达到最佳效率
- c#中(int)、int.Parse()、int.TryParse、Convert.ToInt32的区别
- Myeclipse破解工具破解方法(二)
- javascript 中的call 和apply
- javascript Object
- Android webView的cookie机制
- Spring之bean