《机器学习实战》之K—近邻算法实现手写体数字识别

来源:互联网 发布:java接口自动化测试 编辑:程序博客网 时间:2024/06/04 18:12

一、问题描述

  • 主要程序为kNN.py,在主程序中,包含函数:
    • classify0(inX, dataSet, labels, k):实现分类
    • file2matrix(filename):将文本文件转换为矩阵,本例中没有用到
    • autoNum(dataSet):这个函数没有用到,作用是实现均值归一化
    • img2vector(filename):将图片转化为向量,图片大小是32*32,转化后的向量为1*1024,
    • Detect_Test():数字识别和错误率计算函数

二、各函数的代码

  • classify0(inX, dataSet, labels, k)
# -*- coding=utf-8 -*-from os import listdirfrom numpy import *import operatordef classify0(inX, dataSet, labels, k):    dataSetSize = dataSet.shape[0]    diffMat = tile(inX, (dataSetSize,1)) - dataSet    sqDiffMat = diffMat**2    sqDistances = sqDiffMat.sum(axis=1)    distances = sqDistances**0.5    sortedDistIndicies = distances.argsort()    classCount={}    for i in range(k):        voteIlabel = labels[sortedDistIndicies[i]]        classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)    return sortedClassCount[0][0]
  • img2vector(filename)
def img2vector(filename):             #将图片转化为向量    returnVect=zeros((1,1024))    fr=open(filename)    for i in range(32):        lineStr=fr.readline()        for j in range(32):            returnVect[0,32*i+j]=int(lineStr[j])    return returnVect
  • Detect_Test()
#手写体识别函数&错误率检测函数def Detect_Test():    hwLabels=[]    trainingFileList=listdir('E:/PythonApplication/kNN/trainingDigits')     #获取目录内容    m=len(trainingFileList)                                                 #获取文件的个数    trainingMat=zeros((m,1024))               #创建一个矩阵,m行1024列,用来存储转化后的数字向量    for i in range(m):        fileNameStr=trainingFileList[i]       #从文件名解析分类数字        fileStr=fileNameStr.split('.')[0]       #用[0]操作符保证操作的是数字矩阵中的每一列而不是每一行        classNumStr=int(fileStr.split('_')[0])  #同上,因为文件名的形式是0_4.txt,所以为了得到分类标签,可以只取‘_’前面的数字        hwLabels.append(classNumStr)        trainingMat[i,:]=img2vector('E:/PythonApplication/kNN/trainingDigits/%s' % fileNameStr)    testFileList=listdir('E:/PythonApplication/kNN/testDigits')    errCount=0          #计算识别错误率    mTest=len(testFileList)    for i in range(mTest):        fileNameStr=testFileList[i]        fileStr=fileNameStr.split('.')[0]        classNumStr=int(fileStr.split('_')[0])        vectorOfTest=img2vector('E:/PythonApplication/kNN/testDigits/%s' % fileNameStr)        result=classify0(vectorOfTest,trainingMat,hwLabels,3)        print 'the classfiler came back with %d the real is:%d' % (result,classNumStr),'\t',i        if  (result!=classNumStr):            errCount += 1.0    print '错误个数:',errCount    print '错误率:',errCount/float(mTest)

三、补充

  • 本次实验使用的IDE为PyCharm社区版;
  • 实验数据中原始图片是以二进制格式存储的,分为训练集和测试集,整体结构图如下:
    数据文件预览结构

其中单独的一张图片存储内容如下:

数字0的存储形式

  • 数据集下载:数据集+源码下载。
  • 写到最后:moulei007@gmail.com
阅读全文
0 0
原创粉丝点击