基于k近邻(KNN)的手写数字识别
来源:互联网 发布:淘宝订单关闭的原因 编辑:程序博客网 时间:2024/05/06 18:56
作者:faaronzheng 转载请注明出处!
最近在看《Machine Learning in Action》。k近邻算法这一章提供了不少例子,本着“Talk is cheap”的原则,我们用手写数字识别来实际测试一下。简单介绍一下k近邻算法(KNN):给定测试样本,基于某种距离度量找出训练集中与其最靠近的k个训练样本,然后基于这k个“邻居”的信息来进行预测。如下图所示:
x为测试样本,小黑点是一类样本,小红点是另一类样本。在测试样本x的周围画一个圈,这个圈就是依据某种距离度量画出的,可以看到我们选择的是5近邻。现在我们要做出一个预测,就是这个测试样本x是属于小黑点那一类还是小红点那一类呢?很简单,我们只要看看选中的近邻中哪一类样本多就把这类样本的标签赋给测试样本就可以了。图中自然就是小黑点,所以我们预测x是小黑点。
正文:
第一步:准备实验数据。Machine Learning in Action书中使用的是“手写数字光学识别”(Optical Recognition of Handwritten Digits)数据集中的数据,具体可以参考书中的相关介绍。每个样本都以txt文本形式保存,由32行32列的0/1字符组成。下图就是一个手写数字0的保存数据。可以看出,笔画所在的位置用1表示,空白处用0表示。
除此之外,为了能识别自己手写的数字,我们在原来实验的基础上添加画板的功能,使其能采集自己手写的数字并按照相同的格式保存下来。如下图所示,当点击CustomizeTestData后会出现一个画板,当我们在画板上写上数字后,按下ESC键保存图片并退出,接下来将保存的图片处理成我们想要的格式,就可以用算法对其进行预测了。画板的实现使用了pygame。
下面是画板功能的具体实现:
import pygame
from pygame.locals import *
import math
from sys import exit  # borrow exit() from the sys module so the program can quit

pygame.init()  # initialize pygame so the hardware is ready for use


class Brush():
    """Freehand brush that paints onto a pygame surface.

    Points between successive mouse positions are interpolated (see
    _get_points) so fast mouse movement still produces a continuous stroke.
    """

    def __init__(self, screen):
        self.screen = screen
        self.color = (0, 0, 0)
        self.size = 4
        self.drawing = False
        self.last_pos = None
        self.space = 1
        # style == False -> solid circle via pygame.draw.circle (see draw());
        # any other style  -> blit the PNG brush texture.
        self.style = False
        # convert_alpha() needs an active display mode; Painter sets one
        # before it constructs the Brush.
        self.brush = pygame.image.load("brush.png").convert_alpha()
        # The active PNG brush is the top-left sub-rectangle of the texture;
        # set_size() resizes it.
        self.brush_now = self.brush.subsurface((0, 0), (1, 1))

    def start_draw(self, pos):
        """Begin a stroke at ``pos``."""
        self.drawing = True
        self.last_pos = pos

    def end_draw(self):
        """Finish the current stroke."""
        self.drawing = False

    def set_brush_style(self, style):
        """Select solid-circle (False) or PNG-texture drawing."""
        print("* set brush style to %s" % (style,))
        self.style = style

    def get_brush_style(self):
        return self.style

    def get_current_brush(self):
        return self.brush_now

    def set_size(self, size):
        """Set the brush radius, clamped to the range [0.5, 32]."""
        if size < 0.5:
            size = 0.5
        elif size > 32:
            size = 32
        print("* set brush size to %s" % (size,))
        self.size = size
        # The texture side is the diameter; subsurface needs whole pixels.
        side = int(size * 2)
        self.brush_now = self.brush.subsurface((0, 0), (side, side))

    def get_size(self):
        return self.size

    def set_color(self, color):
        """Set the RGB colour and recolour the texture, keeping each pixel's alpha."""
        self.color = color
        for i in range(self.brush.get_width()):
            for j in range(self.brush.get_height()):
                self.brush.set_at((i, j), color + (self.brush.get_at((i, j)).a,))

    def get_color(self):
        return self.color

    def draw(self, pos):
        """Paint from the previous position to ``pos`` if a stroke is active."""
        if not self.drawing:
            return
        for p in self._get_points(pos):
            if self.style == False:
                # pygame.draw.circle needs an integer radius; keep it >= 1 so
                # the 0.5 minimum allowed by set_size still paints a pixel.
                pygame.draw.circle(self.screen, self.color, p, max(1, int(self.size)))
            else:
                self.screen.blit(self.brush_now, p)
        self.last_pos = pos

    def _get_points(self, pos):
        """Return the unique integer points on the segment last_pos -> pos."""
        x0, y0 = self.last_pos[0], self.last_pos[1]
        len_x = pos[0] - x0
        len_y = pos[1] - y0
        length = math.sqrt(len_x ** 2 + len_y ** 2)
        if length == 0:
            # The mouse did not move: only the start point, and no division by zero.
            return [(int(0.5 + x0), int(0.5 + y0))]
        step_x = len_x / length
        step_y = len_y / length
        points = [(x0, y0)]
        for _ in range(int(length)):
            points.append((points[-1][0] + step_x, points[-1][1] + step_y))
        # Round to the nearest pixel and drop duplicates.
        return list(set((int(0.5 + x), int(0.5 + y)) for x, y in points))


class Menu():
    """Minimal menu: holds a brush reference but has no buttons in this version."""

    def __init__(self, screen):
        self.screen = screen
        self.brush = None

    def set_brush(self, brush):
        self.brush = brush

    def click_button(self, pos):
        """Return True if ``pos`` hit a menu button.

        This menu has no buttons, so every click falls through to drawing.
        """
        return False

    def draw(self):
        """Render the menu; there is nothing to draw without buttons."""
        pass


class Painter():
    """A 100x100 drawing board; pressing ESC saves it to test.png and closes it."""

    def __init__(self):
        # pygame.init() is safe to call repeatedly; calling it here lets a new
        # board open after an earlier board's run() called pygame.quit().
        pygame.init()
        self.screen = pygame.display.set_mode((100, 100))
        pygame.display.set_caption("Painter")
        self.clock = pygame.time.Clock()
        self.brush = Brush(self.screen)
        self.menu = Menu(self.screen)
        self.menu.set_brush(self.brush)

    def run(self):
        """Run the event loop until the window is closed or ESC is pressed."""
        self.screen.fill((255, 255, 255))
        while True:
            self.clock.tick(30)  # cap the frame rate at 30 FPS
            for event in pygame.event.get():
                if event.type == QUIT:
                    pygame.quit()
                    # Leave immediately: the display is gone, so any further
                    # draw/update call would fail.
                    return
                elif event.type == KEYDOWN:
                    # ESC saves the drawing, then closes the board.
                    if event.key == K_ESCAPE:
                        fname = "test.png"
                        pygame.image.save(self.screen, fname)
                        pygame.quit()
                        return
                elif event.type == MOUSEBUTTONDOWN:
                    # x <= 74 is a cheap pre-check before asking the menu
                    # whether a button was hit; otherwise start a stroke.
                    if event.pos[0] <= 74 and self.menu.click_button(event.pos):
                        pass
                    else:
                        self.brush.start_draw(event.pos)
                elif event.type == MOUSEMOTION:
                    self.brush.draw(event.pos)
                elif event.type == MOUSEBUTTONUP:
                    self.brush.end_draw()
            self.menu.draw()
            pygame.display.update()
KNN算法——在我看来,KNN的关键是距离度量的选择,不同的距离度量会对最终结果产生较大的影响。首先将每个32×32的手写数字展开为一个1×1024的一维向量,然后计算测试样例(向量)与每个训练样本(向量)之间的距离并按距离排序,最后选出距离最近的k个样本进行投票,得到对测试样例的预测。
import pygame
from numpy import *
import operator
import os
from os import listdir
from Board import *

try:  # Python 2 module names, as originally written
    import Tkinter
    import tkFileDialog
    import tkMessageBox
except ImportError:  # Python 3 renamed the Tk modules
    import tkinter as Tkinter
    from tkinter import filedialog as tkFileDialog
    from tkinter import messagebox as tkMessageBox

try:  # classic PIL top-level module, as originally written
    import Image
except ImportError:  # Pillow only ships the PIL package
    from PIL import Image

# NOTE(review): `dot` is already provided by `from numpy import *`. If this
# script is itself saved as KNN.py, this line imports (and re-executes) the
# whole script a second time. Kept because the KNN module is not visible here.
from KNN import dot

pygame.init()

# Default locations of the Machine Learning in Action Ch02 digit data sets.
DEFAULT_TRAIN_DIR = "C:/Users/HP/Desktop/MLiA_SourceCode/machinelearninginaction/Ch02/trainingDigits"
DEFAULT_TEST_DIR = "C:/Users/HP/Desktop/MLiA_SourceCode/machinelearninginaction/Ch02/testDigits"


def _majorityVote(sortedDistIndicies, labels, k):
    """Return the most frequent label among the first k indices (nearest first).

    k is capped at the number of training samples; a ValueError is raised
    when no neighbour can vote (k < 1 or an empty training set).
    """
    if k > len(sortedDistIndicies):
        k = len(sortedDistIndicies)
    if k < 1:
        raise ValueError("KNN needs k >= 1 and a non-empty training set")
    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def classify0(inX, dataSet, labels, k):
    """Classify the row vector inX by a vote of its k nearest rows of dataSet.

    Distance is Euclidean; k controls how many nearest neighbours vote.
    """
    dataSetSize = dataSet.shape[0]
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet
    sqDistances = (diffMat ** 2).sum(axis=1)
    distances = sqDistances ** 0.5
    return _majorityVote(distances.argsort(), labels, k)


def classify1(inX, dataSet, labels, k):
    """Like classify0, but computes squared distances as row-wise dot products.

    Each row's squared distance is the diagonal entry of diffMat . diffMat.T;
    it is computed per row rather than via the full (m, m) product, whose
    2-D argsort cannot be used to index the label list.
    """
    dataSetSize = dataSet.shape[0]
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet
    sqDistances = (diffMat * diffMat).sum(axis=1)
    distances = sqrt(sqDistances)
    return _majorityVote(distances.argsort(), labels, k)


def img2vector(filename):
    """Read a 32x32 text digit of '0'/'1' characters into a (1, 1024) array."""
    returnVect = zeros((1, 1024))
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect


def _labelFromFileName(fileNameStr):
    """Return the digit label encoded as the prefix of e.g. '7_45.txt'."""
    fileStr = fileNameStr.split('.')[0]  # take off .txt
    return int(fileStr.split('_')[0])


def _loadTrainingSet(TrainDataPath):
    """Load every digit file in TrainDataPath into an (m, 1024) matrix plus labels."""
    trainingFileList = listdir(TrainDataPath)
    trainingMat = zeros((len(trainingFileList), 1024))
    hwLabels = []
    for i, fileNameStr in enumerate(trainingFileList):
        hwLabels.append(_labelFromFileName(fileNameStr))
        trainingMat[i, :] = img2vector(os.path.join(TrainDataPath, fileNameStr))
    return trainingMat, hwLabels


def handwritingClassTest(TrainDataPath, TestDataPath=DEFAULT_TEST_DIR, k=3):
    """Classify every file in TestDataPath with classify0; print results and error rate."""
    trainingMat, hwLabels = _loadTrainingSet(TrainDataPath)
    testFileList = listdir(TestDataPath)
    errorCount = 0.0
    mTest = len(testFileList)
    for fileNameStr in testFileList:
        classNumStr = _labelFromFileName(fileNameStr)
        vectorUnderTest = img2vector(os.path.join(TestDataPath, fileNameStr))
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, k)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
        if classifierResult != classNumStr:
            errorCount += 1.0
    print("\nthe total number of errors is: %d" % errorCount)
    if mTest:
        print("\nthe total error rate is: %f" % (errorCount / float(mTest)))
    else:
        # An empty test directory would otherwise divide by zero.
        print("\nno test files found in %s" % TestDataPath)


top = Tkinter.Tk()


def TrainDataCallBack():
    """Ask for a training directory, then run the batch test against it."""
    TrainDataPath = tkFileDialog.askdirectory()
    if not TrainDataPath:  # the dialog was cancelled
        return
    handwritingClassTest(TrainDataPath)


def CustomizeTestDataCallBack():
    """Open the drawing board so the user can draw a digit (saved to test.png)."""
    board = Painter()
    board.run()


def TestingCustomizeTestDataCallBack():
    """Downscale the user's drawing to 32x32 and classify it."""
    if not os.path.exists("test.png"):
        tkMessageBox.showwarning("KNN", "Please draw a digit first (CustomizeTestData).")
        return
    ResizePic()
    TransformArray()


TrainDataButton = Tkinter.Button(top, text="TrainData", command=TrainDataCallBack)
CustomizeTestDataButton = Tkinter.Button(top, text="CustomizeTestData", command=CustomizeTestDataCallBack)
TestingButton = Tkinter.Button(top, text="TestingCustomizeTestData", command=TestingCustomizeTestDataCallBack)


def ResizePic():
    """Overwrite test.png with a 32x32 version to match the training format."""
    im = Image.open("test.png")
    im_ss = im.resize((32, 32))
    im_ss.save("test.png")


def TransformArray():
    """Turn the 32x32 test.png into a (1, 1024) 0/1 vector and classify it.

    Any pixel that is not pure white counts as ink (1).
    """
    TestArray = zeros((1, 1024))
    # Normalise to RGB so getpixel always returns an (r, g, b) tuple,
    # whatever mode (RGBA, L, P, ...) the PNG was saved in.
    im = Image.open("test.png").convert("RGB")
    width, height = im.size
    for h in range(height):
        for w in range(width):
            if im.getpixel((w, h)) != (255, 255, 255):
                TestArray[0, 32 * h + w] = 1
    handwritingTesting(TestArray)


def handwritingTesting(TestArray, TrainDataPath=DEFAULT_TRAIN_DIR, k=100):
    """Classify one (1, 1024) vector with both classifiers and print the results."""
    trainingMat, hwLabels = _loadTrainingSet(TrainDataPath)
    classifierResult = classify0(TestArray, trainingMat, hwLabels, k)
    classifierResult1 = classify1(TestArray, trainingMat, hwLabels, k)
    print("the classifier came back with: %d" % classifierResult)
    print("the classifier came back with: %d" % classifierResult1)


TrainDataButton.pack()
CustomizeTestDataButton.pack()
TestingButton.pack()
top.mainloop()
源代码下载:faaron-KNN手写字识别
未完待续。。。
- 基于k近邻(KNN)的手写数字识别
- 基于K-近邻算法识别手写数字的实现
- 基于K-近邻算法的手写数字识别研究
- K近邻分类器(KNN)手写数字(MNIST)识别
- k近邻 - 手写数字识别
- 基于KNN的手写数字识别
- 基于python的手写数字识别(KNN算法)
- OpenCV手写数字字符识别(基于k近邻算法)
- OpenCV手写数字字符识别(基于k近邻算法)
- OpenCV手写数字字符识别(基于k近邻算法)
- OpenCV手写数字字符识别(基于k近邻算法)
- 机器学习实战k近邻算法(kNN)应用之手写数字识别代码解读
- 基于K近邻法的手写数字图像识别
- opencv 基于KNN的手写数字字符识别
- 基于SVM和KNN的手写数字的识别(分类)——小试牛刀篇
- KNN手写数字识别
- 手写识别系统(k-近邻算法)
- 银行卡号识别(三) --- 基于k最近邻的数字识别测
- 数据库性能优化
- maven 项目出现 java.lang.ClassNotFoundException: org.springframework.web.context.ContextLoaderListener
- 关于二进制,八进制,十进制,十六进制的转变
- 在Mac OS环境安装Composer
- 2016.3.16__CSS3渐变_倒影_过渡_2D变形_3D变形__第十天
- 基于k近邻(KNN)的手写数字识别
- [Chromium中文文档]跨平台开发的约定与模式
- 人的惯性思维
- Problem C: 判断字符串是否为回文
- URAL1017Staircases DP
- 指针常量
- 【LeetCode】88. Merge Sorted Array
- word 2013 参考文献插入及交叉引用的实现方法(转自百度经验)
- maven install时报错Failed to execute goal org.apache.maven.plugins:maven-compiler-plugin:2.3.2:compile