KNN之手写数字识别

来源:互联网 发布:java小项目开发实例 编辑:程序博客网 时间:2024/05/04 15:06

代码

  1. '''
  2. Created on Sep 16, 2010
  3. kNN: kNearest Neighbors
  4. Input:     inX: vector to compare to existing dataset (1xN)
  5.             dataSet: size m data set of known vectors (NxM)
  6.            labels: data set labels (1xM vector)
  7.             k: number of neighbors to use for comparison (should be an odd number)
  8.             
  9. Output:    the most popular class label
  10. @author: pbharrin
  11. '''
  12. from numpy import *
  13. import operator
  14. from os import listdir
  15. def classify0(inX, dataSet, labels, k):
  16.     dataSetSize = dataSet.shape[0]
  17.     diffMat = tile(inX, (dataSetSize,1)) - dataSet
  18.     sqDiffMat = diffMat**2
  19.     sqDistances = sqDiffMat.sum(axis=1)
  20.     distances = sqDistances**0.5
  21.     sortedDistIndicies = distances.argsort()     
  22.     classCount={}          
  23.     for i in range(k):
  24.         voteIlabel = labels[sortedDistIndicies[i]]
  25.         classCount[voteIlabel] =classCount.get(voteIlabel,0) + 1
  26.     sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
  27.     return sortedClassCount[0][0]
  28. def createDataSet():
  29.     group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
  30.     labels = ['A','A','B','B']
  31.     return group, labels
  32. def file2matrix(filename):
  33.     fr = open(filename)
  34.     numberOfLines = len(fr.readlines())          #get the number of lines in the file
  35.     returnMat = zeros((numberOfLines,3))         #prepare matrix to return
  36.     classLabelVector = []                       #prepare labels return   
  37.     fr = open(filename)
  38.     index = 0
  39.     for line in fr.readlines():
  40.         line = line.strip()
  41.         listFromLine = line.split('\t')
  42.         returnMat[index,:] =listFromLine[0:3]
  43.         classLabelVector.append(int(listFromLine[-1]))
  44.         index += 1
  45.     return returnMat,classLabelVector
  46.     
  47. def autoNorm(dataSet):
  48.     minVals = dataSet.min(0)
  49.     maxVals = dataSet.max(0)
  50.     ranges = maxVals - minVals
  51.     normDataSet = zeros(shape(dataSet))
  52.     m = dataSet.shape[0]
  53.     normDataSet = dataSet - tile(minVals, (m,1))
  54.     normDataSet = normDataSet/tile(ranges, (m,1))   #element wise divide
  55.     return normDataSet, ranges, minVals
  56.    
  57. def datingClassTest():
  58.     hoRatio = 0.50      #hold out 10%
  59.     datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')        #load data setfrom file
  60.     normMat, ranges, minVals =autoNorm(datingDataMat)
  61.     m = normMat.shape[0]
  62.     numTestVecs = int(m*hoRatio)
  63.     errorCount = 0.0
  64.     for i in range(numTestVecs):
  65.         classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)
  66.         print "the classifier came back with: %d,the real answer is: %d" % (classifierResult,datingLabels[i])
  67.         if (classifierResult !=datingLabels[i]): errorCount += 1.0
  68.     print "the total error rate is: %f" % (errorCount/float(numTestVecs))
  69.     print errorCount
  70.     
  71. def img2vector(filename):
  72.     returnVect = zeros((1,1024))
  73.     fr = open(filename)
  74.     for i in range(32):
  75.         lineStr = fr.readline()
  76.         for j in range(32):
  77.             returnVect[0,32*i+j] = int(lineStr[j])
  78.     return returnVect
  79. def handwritingClassTest():
  80.     hwLabels = []
  81.     trainingFileList = listdir('trainingDigits')            #load the training set
  82.     m = len(trainingFileList)
  83.     trainingMat = zeros((m,1024))
  84.     for i in range(m):
  85.         fileNameStr = trainingFileList[i]
  86.         fileStr = fileNameStr.split('.')[0]     #take off .txt
  87.         classNumStr = int(fileStr.split('_')[0])
  88.         hwLabels.append(classNumStr)
  89.         trainingMat[i,:] =img2vector('trainingDigits/%s' %fileNameStr)
  90.     testFileList = listdir('testDigits')        #iterate through the test set
  91.     errorCount = 0.0
  92.     mTest = len(testFileList)
  93.     for i in range(mTest):
  94.         fileNameStr = testFileList[i]
  95.         fileStr = fileNameStr.split('.')[0]     #take off .txt
  96.         classNumStr = int(fileStr.split('_')[0])
  97.         vectorUnderTest = img2vector('testDigits/%s' %fileNameStr)
  98.         classifierResult = classify0(vectorUnderTest,trainingMat, hwLabels, 3)
  99.         print "the classifier came back with: %d,the real answer is: %d" % (classifierResult,classNumStr)
  100. #        if (classifierResult != classNumStr):
  101. #            print "the classifier came back with: %s, the real answer is: %d" %(fileNameStr, classNumStr)
  102.         if (classifierResult !=classNumStr): errorCount += 1.0
  103.     print "\nthe total number of errorsis: %d" % errorCount
  104.     print "\nthe total error rate is:%f" % (errorCount/float(mTest))
  105.     
  106. handwritingClassTest()

    结果

    1. thetotal number of errorsis: 11
    2. the total error rate is:0.011628