Python实现KNN算法手写识别数字

来源:互联网 发布:python 画流程图 编辑:程序博客网 时间:2024/04/27 12:40

本文实现用KNN算法实现手写识别数字功能。
语言:Python
训练材料:手写数字素材32*32像素

from numpy import *import osfrom os import listdirimport operator#将文件32*32转成1*1024def img2vector(filename):    vect=zeros((1,1024))    f=open(filename)    for i in range(32):        line=f.readline()        for j in range(32):            vect[0,32*i+j]=int(line[j])    return vectdef dict2list(dic:dict):    #''' 将字典转化为列表 '''    keys = dic.keys()    vals = dic.values()    lst = [(key, val) for key, val in zip(keys, vals)]#zip是一个可迭代对象    return lst#inputvector:输入的用于测试的向量#trainDataSet:训练的样本集#labels:标签#k:k邻近的个数def knntest(inputvector,trainDataSet,labels,k):    datasetsize=trainDataSet.shape[0]    #tile(a,[2,3]) ([a a a],[a,a,a])用第一个参数来构造    #这里用输入向量来构造一个1024行 1列的矩阵,刚好和训练矩阵同样大小    diffmat=tile(inputvector,(datasetsize,1))-trainDataSet    #求平方和    #每个元素都平方    sqdiffmat=diffmat**2    #按行求和    sqdistance=sqdiffmat.sum(axis=1)    #平方根,得到的是一个一维的矩阵    distance=sqdistance**0.5    #按照从低到高排序    #argsort函数排列后得到的是按下标进行排列的矩阵,    #在原先distance中的下标按距离最近排列 argsort函数返回的是数组值从小到大的索引值    sortdistance=distance.argsort()    classcout={}#用来存储key(标签)value(标签出现的次数,选取次数最大的前几个数,找到其标签)    #依次取出最近的样本数据    for i in range(k):        #记样本的类别        votelabel=labels[sortdistance[i]]        #统计每个标签的次数        classcout[votelabel]=classcout.get(votelabel,0)+1#获取votelabel键对应的值,无返回默认    #print("*************")    #print(classcout)    #classcout.iteritems()在Python3中取消了,key=lambda x:x[0](按第0个元素排序)字典排序,按照value来排序,返回键    sortclasscount=sorted(dict2list(classcout),key=operator.itemgetter(1),reverse=True)    #返回出现频次最高的类别    return sortclasscount[0][0]#手写识别def handwritingClassTest():    print(os.getcwd())    #将训练数据存储到一个矩阵中1024维,并存储对应的标签    handlabel=[]    trainName=listdir(r'digits\trainingDigits')    trainNum=len(trainName)    trainNumpy = zeros((trainNum,1024))    #print("trainNum=%d"%trainNum)    #对文件名进行分析,训练文本对应的标签    for i in range(trainNum):        filename=trainName[i]#文件名        filestr=filename.split('.')[0]#不带后缀的文件名        filelabel=int(filestr.split('_')[0])#文件的标签        #将标签添加至handlabel中        handlabel.append(filelabel)        trainNumpy[i,:]=img2vector(r'digits\trainingDigits\%s'%filename)#转成1024    #print(handlabel[:20])    testfilelist=listdir(r'digits\testDigits')    errornum=0    testnum=len(testfilelist)    errfile=[]    #将每一个测试样本放入训练集中使用KNN进行测试    for i in range(testnum):        testfilename=testfilelist[i]        testfilestr=testfilename.split('.')[0]        testfilelabel=int(testfilestr.split('_')[0])#实际的数字标签        #将测试样本1024        testvector=img2vector(r'digits\testDigits\%s'%testfilename)        #进行测试        #print("-----------")        result=knntest(testvector,trainNumpy,handlabel,3)        print("test value is %d, real value is %d"%(result,testfilelabel))        if(result!=testfilelabel):            errornum+=1            errfile.append(testfilename)    print("the num of error is %d"%errornum)    print("the right rate of test is %f "%(1-errornum/float(testnum)))    print("the error of file are ")    count=0    for i in range(len(errfile)):        if(count==9):            print()        print(errfile[i]+' ',end="")        count+=1def main():    #path=os.getcwd()    handwritingClassTest()if __name__=='__main__':    main();

转载自k-近邻算法实现手写数字识别系统
并自身进行了测试。

0 0
原创粉丝点击