K近邻法及手写数字识别系统(二)

来源:互联网 发布:ezzy付强 知乎 编辑:程序博客网 时间:2024/06/04 23:26

         在K近邻法及约会网站预测系统(一),简单的介绍了K近邻法的定义和三要素,并使用K近邻法实现了简单的约会网站预测系统。本文在此基础上,构建识别数字0~9的手写数字识别系统。

        手写数字识别系统的训练集中每个数字大约有200个样本,一共约2000个数据,测试数据集包括大概900个数据。其中每个数据为32*32的二进制图像矩阵,如下图表示数字0。为了复用前面网站约会系统的分类函数,需要将矩阵转换为1*1024的向量。通过img2vector()函数创建一个1*1024的Numpy数组,循环读取给定文件的前32行,并将每行的前32个字符存入数组,返回数组。


def img2vector(filename):          #将32*32的二进制矩阵转换为1*1024的向量       returnVect = zeros((1,1024))   #构建1*1024数组    fr = open(filename)    for i in range(32):            #读取前32行        lineStr = fr.readline()        for j in range(32):        #读取前32个字符            returnVect[0,32*i+j] = int(lineStr[j])    #存入数组    return returnVect


def handwriterClassTest():            #手写数字识别系统       hwLabels = []                     #存储数据类别    trainingFileList = listdir('trainingDigits')    #给定目录下的所有文件名,需要从os模块导入listdir函数,"trainingDigits":训练数据集的文件夹    m = len(trainingFileList)                       #训练数据集长度    trainingMat = zeros((m,1024))                   #训练数据集的矩阵    for i in range(m):                              #遍历所有数据        fileNameStr = trainingFileList[i]           #获得当前文件,文件名为0_0.txt  第一个0代表数字类别,第二个0代表当前类别的序号        fileStr = fileNameStr.split('.')[0]         #截取文件名  如:0_0.txt ->  0_0        classNumStr = int(fileStr.split('_')[0])    #从文件名获取类别           hwLabels.append(classNumStr)                #将当前类别加入数组        trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)  #将当前文件中的32*32矩阵转换成1*1024向量    testFileList = listdir('testDigits')            #测试数据集   "testDigits"为测试数据集的文件夹    errorCount = 0.0                                #记录错误数量           mTest = len(testFileList)                       #测试数据集数量    for i in range(mTest):                          #循环遍历测试数据        fileNameStr = testFileList[i]        fileStr = fileNameStr.split('.')[0]        classNumStr = int(fileStr.split('_')[0])            vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)    #这四行代码和上面的训练数据集的处理方法一致        classifierResult = classify0(vectorUnderTest,trainingMat,hwLabels,3)    #使用上一篇中的分类函数进行分类        print('the classify came back with %d,the real answer is:%d' % (classifierResult,classNumStr))   #打印每条数据的实际类别和预测类别        if (classifierResult != classNumStr):       #如果不相等,errorCount加1            errorCount += 1.0    print('the total number of errors is :%d' % errorCount)            #打印错误数    print('the total error rate is:%f' % (errorCount/float(mTest)))    #打印错误率

          可以看出,错误数为11个,错误率约为0.01。

end

原创粉丝点击