机器学习实战kNN之手写识别

来源:互联网 发布:西游记原著版本知乎 编辑:程序博客网 时间:2024/05/18 12:32

机器学习实战kNN之手写识别

转自 http://blog.csdn.net/wyb_009/article/details/9165371


kNN算法算是机器学习入门级绝佳的素材。书上是这样诠释的:“存在一个样本数据集合,也称作训练样本集,并且样本集中每个数据都有标签,即我们知道样本集中每一条数据与所属分类的对应关系。输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征比较,算法提取样本集中特征最相似数据(最近邻)的分类标签。一般来说,我们只选择样本数据集中前K个最相似的数据,这就是k-近邻算法中k的出处,通常k是不大于20的整数。最后,选择k个最相似数据中出现次数最多的分类,作为新数据的分类”。

优点:精度高、对异常值不敏感、无数据输入假定。

缺点:计算复杂度高、空间复杂度高。

适用数据范围:数值型或标称型。

算法的python实现:

[python] view plaincopyprint?
  1. def kNN(data, dataSet, dataLabel, k=3, similarity=sim_distance):    
  2.     scores = [(sim_distance(data, dataSet[i]), dataLabel[i]) for i in range(len(dataSet))]  
  3.     sortedScore = sorted(scores, key=lambda d: d[0], reverse=False)   
  4.     scores = sortedScore[0:k]  
  5.       
  6.     classCount = {}   
  7.     for score in scores:  
  8.         classCount[score[1]] = classCount.get(score[1], 0) + 1  
  9.       
  10.     sortedClassCount = sorted(classCount.items(), key=lambda d: d[1], reverse=True)  
  11.     return sortedClassCount[0][0]  
  12.           

下面分为几步骤来学习这个算法:

(1)准备数据

(2)测试算法

先介绍一个这个手写识别系统,简单起见,该系统只能识别数字0---9,需要识别的数字已经使用图形处理软件,处理成具有相同色彩和大小:32*32像素的黑白照片。目录trainingDigits中包含了大约2000个训练样本,目录testDigits中大约有900个测试样本。

第一步,准备数据:将图片数据转换成测试向量。这一步就是把我们32*32的二进制图像矩阵转换成1*1024的向量。

[python] view plaincopyprint?
  1. def img2vector(filename):  
  2.     vec = []  
  3.     file = open(filename)  
  4.     for i in range(32):  
  5.         line = file.readline()  
  6.         for j in range(32):  
  7.             vec.append(int(line[j]))  
  8.     return vec  

第二步,测试算法准确率,我们用trainingDigits目录下的样本做训练,来测试testDigits目录下的样本,来计算准确率。

[python] view plaincopyprint?
  1. def test():  
  2.     trainData, trainLabel = [], []  
  3.     trainFileList = os.listdir('digits/trainingDigits/')  
  4.     for filename in trainFileList:  
  5.         trainData.append(img2vector('digits/trainingDigits/%s' % filename))  
  6.         trainLabel.append(int(filename.split('_')[0]))  
  7.           
  8.     succCnt, failCnt = 00  
  9.     testFileList = os.listdir('digits/testDigits')  
  10.     for filename in testFileList:  
  11.         data = img2vector('digits/testDigits/%s' % filename)  
  12.         num = kNN(data, trainData, trainLabel)  
  13.         if num == int(filename.split('_')[0]):  
  14.             succCnt += 1  
  15.             print 'succ'  
  16.         else:  
  17.             failCnt += 1  
  18.             print 'fail'  
  19.               
  20.     print "error rate is : %f " % (failCnt/float(failCnt+succCnt))  

我这里测试,K取默认值3,错误率是0.013742,


不会上传文件,所以把代码贴在下面,测试数据在http://download.csdn.net/detail/wyb_009/5649337第二章下面

[python] view plaincopyprint?
  1. import os, math  
  2. def sim_distance(a, b):  
  3.     sum_of_squares = sum([pow(a[i]-b[i], 2for i in range(len(a))])    
  4.     return sum_of_squares   
  5.   
  6. def kNN(data, dataSet, dataLabel, k=3, similarity=sim_distance):    
  7.     scores = [(sim_distance(data, dataSet[i]), dataLabel[i]) for i in range(len(dataSet))]  
  8.     sortedScore = sorted(scores, key=lambda d: d[0], reverse=False)   
  9.     scores = sortedScore[0:k]  
  10.       
  11.     classCount = {}   
  12.     for score in scores:  
  13.         classCount[score[1]] = classCount.get(score[1], 0) + 1  
  14.       
  15.     sortedClassCount = sorted(classCount.items(), key=lambda d: d[1], reverse=True)  
  16.     return sortedClassCount[0][0]  
  17.           
  18. def img2vector(filename):  
  19.     vec = []  
  20.     file = open(filename)  
  21.     for i in range(32):  
  22.         line = file.readline()  
  23.         for j in range(32):  
  24.             vec.append(int(line[j]))  
  25.     return vec  
  26.           
  27. def test():  
  28.     trainData, trainLabel = [], []  
  29.     trainFileList = os.listdir('digits/trainingDigits/')  
  30.     for filename in trainFileList:  
  31.         trainData.append(img2vector('digits/trainingDigits/%s' % filename))  
  32.         trainLabel.append(int(filename.split('_')[0]))  
  33.     print "load train data ok"  
  34.       
  35.     succCnt, failCnt = 00  
  36.     testFileList = os.listdir('digits/testDigits')  
  37.     for filename in testFileList:  
  38.         data = img2vector('digits/testDigits/%s' % filename)  
  39.         num = kNN(data, trainData, trainLabel)  
  40.         if num == int(filename.split('_')[0]):  
  41.             succCnt += 1  
  42.             print 'succ'  
  43.         else:  
  44.             failCnt += 1  
  45.             print 'fail: kNN get %ld, real is %ls' %(num, int(filename.split('_')[0]))  
  46.               
  47.     print "error rate is : %f " % (failCnt/float(failCnt+succCnt))  
  48.       
  49. if __name__ == "__main__":  
  50.     test()