kNN算法识别手写数字(代码笔记)

来源:互联网 发布:wav音频分割软件 编辑:程序博客网 时间:2024/05/22 15:44

k-近邻算法,属于有监督分类算法。

思想:利用输入数据特征值和训练样本数据特征值之间的距离分类,挑出距离最小的k个训练样本的类别频率,作为预测的分类估计。

'''k-近邻算法是基于实例的学习1 使用时要保存全部的数据集,占存储空间2 要对每个训练数据计算距离值,实际使用时非常耗时'''import numpy as npimport operatordef classify0(x, dataSet, labels, k):    dataSetSize = dataSet.shape[0]    diffMat = np.tile(x, (dataSetSize,1)) - dataSet    sqDiff = diffMat**2    sqDist = sqDiff.sum(axis=1)    distances = sqDist**0.5  # 一行数据的平方根    sortedDistInd = distances.argsort()  # 向量元素从小到大对应的索引号    classCount = {}    for i in range(k):  # 前k个,也就是最近的k个; 统计类出现的频率        vLabel = labels[sortedDistInd[i]]          classCount[vLabel] = classCount.get(vLabel,0)+1    sortedClassCount = sorted(classCount.items(), # 转成dict_items:[(key1,cnt1),(key2,cnt2),..]                       key=operator.itemgetter(1), # 排序,依据tuple第二个元素;reverse,由大到小                       reverse=True)    return sortedClassCount[0][0]    def img2vec(filename):  # 32x32的矩阵数据转成向量    vec = np.zeros((1,1024))    fr = open(filename)  # (如果是txt文件的话)    for i in range(32):        lineStr = fr.readline()        for j in range(32):            vec[0, 32*i+j] = int(lineStr[j])    return vecdef handwritingClassify():    trainLab = []    trainFileList = listdir('trainingDigits')  # 训练数据目录    m = len(trainFileList)    trainMat = zeros((m,1024))  # 训练数据存成一个矩阵    for i in range(m):        filenameStr = trainFileList[i]        fileStr = filenameStr.split('.')[0]        classStr = int(fileStr.split('_')[0])        trainLab.append(classStr)        trainMat[i,:] = img2vec('trainingDigits/%s' % filenameStr)    #------------------------ 测试数据 -------------------------    errorCount = 0.0    testFileList = listdir('testDigits')  # 测试数据目录    n = len(testFileList)    for i in range(n):        filenameStr = testFileList[i]        fileStr = filenameStr.split('.')[0]        classStr = int(fileStr.split('_')[0])        vecTest = img2vec('testDigits/%s' % filenameStr)        classTest = classify0(vecTest, trainMat, trainLab, 3)  # 测试数据的直接分类        print("the classifier predicts : %d, the real is : %d" % (classTest,classStr))        if(classTest!=classStr):            error += 1.0    print("\n the total numbers of errors is: %d" % errorCount)    print("\n the total error rate is: %d" % (error/float(n)))


原创粉丝点击