【十四】机器学习之路——K-近邻算法实战

来源:互联网 发布:汕头新天音响淘宝店 编辑:程序博客网 时间:2024/06/05 20:20

使用k-近邻算法识别手写数字
  一个星期没有更新博客了,最近在看K-近邻算法和决策树,学习《机器学习实战》K-近邻算法里的实战问题代码时遇到了些问题,经过几天的硬啃,终于完成了代码。话不多说,下面一起看一下如何用K-近邻算法实现识别手写数字。[例子与代码摘自《机器学习实战》]

  简单起见,该算法只能识别0~9的数字。这里识别算法首先,咱们将数字的图像使用图形处理软件,处理成相同大小:宽高均为32像素的黑白图像。图像存储为文本格式。如下图所示:


  现在咱们手头有训练集TrainingDigits大约2000个例子,0~9中每个数字大约200个左右的例子,测试集TestDigits大约有900个左右的例子,0~9每个数字大约100个;[数据取自《机器学习实战 第2章 k-近邻算法》],每个数据命名格式如下图所示:(后面代码里会根据这个命名格式来读取相应的数字)

  OK,数据准备好了,现在可以大干一场了,我们怎么利用K近邻算法来实现数字识别呢?还记得K-近邻算法的思路吗?如果忘记的同学可以参考上一篇博客机器学习之路——k-近邻算法(KNN)。在这个数字识别问题里同样,我们的处理思路如下:

1. 首先计算测试集里的数据与训练集里数据的距离差。
2. 计算好测试点与训练集里样本点的距离后,将结果从小到大进行排序;
3. 选取距离最近的k个点,确定这k个点数据所在分类的出现频率;
4. 选择频率最高的分类作为预测数据的分类输出;
其实以上四个步骤咱们上一篇博客里已经定义了函数实现了:

def classify0(inX, dataSet, labels, k):

  咱们在这个实战问题里其实重点就是怎么把这些文本里的数据变成classify0()函数可以处理的数据。

  之前classify0()函数里计算距离的思路是将测试数据一个1*2的矩阵数组(x,y)利用tile扩展成m*2的数组(这里m为训练集数据总数),然后同训练集进行相减求平方再开根号得到距离,同样这里我们首先先将每个测试集样本数据先转化为一个1*n的矩阵数组(n代表特征值个数)。先看这段代码如何实现:

#定义数据处理的函数,将训练集转化为1*1024的矩阵def img2vector(filename):    returnVect = zeros((1,1024))#先定义一个空矩阵数组,下面里用for循环将测试样本读入该数组中    fr = open(filename)    for i in range (32):#因为每个数字样本在文本里是32*32的矩阵数组,一行一行来赋值,所以需要两个for循环嵌套        lineStr = fr.readline()#readline()依次读取每一行,readlines()是输出文件共多少行,注意区分        for j in range(32):#读取第一行后,将第一行里的32个数据赋值给returnVect前32个数,依次类推,最终将32*32=1024个数据全部赋值完毕            returnVect[0,32*i+j] = int(lineStr[j])    return returnVect#返回最终的数据

  完成了测试集数据处理后,现在就要处理训练集数据了,思路相同,将所有训练集数据放入一个m*1024的矩阵数组里,m为训练集样本个数,然后对应的将每个样本的分类Labels放入一个1*m的矩阵数组里,【如果这点不太懂的话可以参考我上一篇KNN算法介绍的博客,链接在上面】完成训练集的处理后,就可以利用classify0()函数对测试集数据进行分类了。好了,一起看下这段代码怎么实现。代码有点长,我会一句一句的注释,让大家更容易理解,涉及到相关的python内置函数后面有对应的链接供学习。

#定义手写数字识别函数,并计算其错误率def handwritingClassTest():    hwLabels = []#将训练集的数据对应的Label即数字用一个list容器存储,classify0()函数输入要用到    trainingFileList = listdir('这里填训练集所在文件夹地址') #输出文件夹里所有训练集的文件名称与后缀,用于读取训练集里每个文件对应的数字即Label    m = len(trainingFileList)#计算下训练集共有多少组数据    trainingMat = zeros((m,1024))#构造一个m*1024的矩阵数组存储训练集里的所有数据,每一行是一组数据    for i in range(m):#通过for循环将m组数据对应的label赋值到hwLabels中去        fileNameStr = trainingFileList[i]        fileStr = fileNameStr.split('.')[0]        classNumStr = int(fileStr.split('_')[0])        hwLabels.append(classNumStr)        trainingMat[i,:] = img2vector('这里填训练集所在文件夹地址%s' % fileNameStr)#将训练集里数据赋值到trainingMat    testFileList = listdir('这里填测试集所在文件夹地址')            errorCount = 0.0#识别函数识别结果是错误的个数    mTest = len(testFileList)#计算测试集数据个数    #以下代码将mTest个测试集里的数据依次进行识别,并输出识别的结果    for i in range(mTest):        fileNameStr = testFileList[i]        fileStr = fileNameStr.split('.')[0]        classNumStr = int(fileStr.split('_')[0])#读取测试集数据对应的真实数字是多少即测试集数据的label        vectorUnderTest = img2vector('这里填测试集所在文件夹地址%s' % fileNameStr)#将测试集里数据进行img2vector转换        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)#利用分类函数来进行手写数据识别,并输出识别的结果,和对应的实际结果        print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)        if (classifierResult != classNumStr): errorCount += 1.0    print "\nthe total number of errors is: %d" % errorCount    print "\nthe total error rate is: %f" % (errorCount/float(mTest))#计算错误率并输出

最终代码运行的结果如下所示,由于用了两个for循环嵌套,导致运行的时间有点久。测试集共946组数据,识别错误11个,错误率1.1628%。


  以上介绍的就是利用k-近邻算法实现的手写数字识别,代码里涉及到的一些具体函数如下,如有不懂的同学可以点进链接进行学习。

read()、readline()、readlines()函数区别
listdir()函数用法
split()函数用法

  好了,今天就讲到这里,欢迎大家多多交流,如果这篇博客对你有帮助,请动动手指帮我点个赞,谢谢!