基于KNN的手写数字识别

来源:互联网 发布:数据完整性保护 编辑:程序博客网 时间:2024/05/18 12:01

代码如下:

# coding=utf-8
# k-nearest-neighbour (kNN) recognition of handwritten digits.
# Each digit image is a text file holding 32 lines of 32 '0'/'1' characters,
# and the file name encodes the true digit as "<digit>_<index>.txt".
import operator
import os
import time

import numpy as np

# Side length of the square text bitmaps read by img2vec.
_IMG_SIDE = 32


def createDataSet():
    """Return a tiny two-class toy data set for trying out the classifier.

    Returns:
        tuple: ``(group, labels)`` where ``group`` is a 4x2 float array of
        sample points and ``labels`` lists the class of each row.
    """
    group = np.array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels


def kNNclassify(inputX, dataSet, labels, k):
    """Classify ``inputX`` by majority vote among its k nearest samples.

    Args:
        inputX: Feature vector to classify (1-D, or a single row).
        dataSet: Training samples, one sample per row.
        labels: Training labels; ``labels[i]`` belongs to ``dataSet[i]``.
        k: Number of nearest neighbours that take part in the vote. Values
            larger than the number of samples use every sample.

    Returns:
        The label that received the most votes. On a tie, the label whose
        first vote came from the nearer neighbour wins.

    Raises:
        ValueError: If ``k`` is smaller than 1 or ``dataSet`` has no rows.
    """
    dataSet = np.asarray(dataSet, dtype=float)
    dataSetSize = dataSet.shape[0]
    if k < 1:
        raise ValueError("k must be at least 1, got %r" % (k,))
    if dataSetSize == 0:
        raise ValueError("dataSet must contain at least one sample")

    # Euclidean distance from inputX to every training row; broadcasting
    # replaces the explicit tile() of the input vector.
    diffMat = np.asarray(inputX, dtype=float) - dataSet
    distances = np.sqrt((diffMat ** 2).sum(axis=1))

    # A stable sort keeps equally distant samples in their original order,
    # which makes the result deterministic.
    sortedDistance = distances.argsort(kind='stable')

    classCount = {}
    for sampleIndex in sortedDistance[:k]:
        voteLabel = labels[sampleIndex]
        classCount[voteLabel] = classCount.get(voteLabel, 0) + 1

    # Pick the label with the highest vote count. Taking max() over the dict
    # itself would compare the labels, not their counts.
    return max(classCount.items(), key=operator.itemgetter(1))[0]


def img2vec(filename):
    """Read a 32x32 text bitmap into a 1x1024 row vector.

    Args:
        filename: Path of a text file whose first 32 lines each start with
            32 digit characters.

    Returns:
        numpy.ndarray: Array of shape ``(1, 1024)``; element ``32 * i + j``
        holds the digit found at line ``i``, column ``j``.

    Raises:
        IndexError: If a line has fewer than 32 characters.
        ValueError: If a character is not a digit.
    """
    returnVec = np.zeros((1, _IMG_SIDE * _IMG_SIDE))
    # The context manager closes the file even if parsing fails.
    with open(filename) as fr:
        for i in range(_IMG_SIDE):
            lineStr = fr.readline()
            for j in range(_IMG_SIDE):
                returnVec[0, _IMG_SIDE * i + j] = int(lineStr[j])
    return returnVec


def _labelFromFileName(fileName):
    """Return the digit encoded in a ``<digit>_<index>.<ext>`` file name."""
    fileStr = fileName.split('.')[0]
    return int(fileStr.split('_')[0])


def handwritingClassTest(trainingFloder, testFloder, K):
    """Train on one folder of digit files, test on another, print a report.

    Args:
        trainingFloder: Directory containing the training digit files.
        testFloder: Directory containing the test digit files.
        K: Number of neighbours passed to ``kNNclassify``.

    Returns:
        float or None: The error rate (errors / tests), or ``None`` when the
        test folder is empty and no rate can be computed.
    """
    trainingFileList = os.listdir(trainingFloder)
    trainingMat = np.zeros((len(trainingFileList), _IMG_SIDE * _IMG_SIDE))
    hwLabels = []
    for i, fileName in enumerate(trainingFileList):
        # The true class comes from the file name.
        hwLabels.append(_labelFromFileName(fileName))
        trainingMat[i, :] = img2vec(os.path.join(trainingFloder, fileName))

    testFileList = os.listdir(testFloder)
    mTest = len(testFileList)
    errorCount = 0
    for fileName in testFileList:
        classNumStr = _labelFromFileName(fileName)
        vectorUnderTest = img2vec(os.path.join(testFloder, fileName))
        classifierResult = kNNclassify(vectorUnderTest, trainingMat, hwLabels, K)
        if classifierResult != classNumStr:
            errorCount += 1

    print("\nthe total number of tests is: %d" % mTest)
    print("the total number of errors is: %d" % errorCount)
    if mTest == 0:
        # No samples were tested, so a rate would be a division by zero.
        return None

    errorRate = float(errorCount) / mTest
    print('the correct rate is; ', 1 - errorRate)
    print('the error rate is; ', errorRate)
    return errorRate


def main():
    """Run the digit test with K=3 on the default folders and time it."""
    # time.clock() was removed in Python 3.8; perf_counter() is the
    # replacement for measuring elapsed time.
    t1 = time.perf_counter()
    handwritingClassTest('trainingDigits', 'testDigits', 3)
    t2 = time.perf_counter()
    print('execute time:%ds' % (t2 - t1))


if __name__ == '__main__':
    main()
#k=1 98.62%

#k=2 98.52%

#k=3 97.56%

#k=4 96.30%

0 0
原创粉丝点击