kNN分类原理以及python实现手写数字分类

来源:互联网 发布:php final 编辑:程序博客网 时间:2024/05/22 14:04

kNN算法分类原理

K最近邻(kNN,k-Nearest Neighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。

kNN算法的核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。 kNN方法在类别决策时,只与极少量的相邻样本有关。由于kNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,kNN方法较其他方法更为适合。
这里写图片描述
比如上面这个图,我们有两类数据,分别是蓝色方块和红色三角形,他们分布在一个上图的二维中间中。那么假如我们有一个绿色圆圈这个数据,需要判断这个数据是属于蓝色方块这一类,还是与红色三角形同类。怎么做呢?我们先把离这个绿色圆圈最近的几个点找到,因为我们觉得离绿色圆圈最近的才对它的类别有判断的帮助。那到底要用多少个来判断呢?这个个数就是k了。如果k=3,就表示我们选择离绿色圆圈最近的3个点来判断,由于红色三角形所占比例为2/3,所以我们认为绿色圆是和红色三角形同类。如果k=5,由于蓝色四方形比例为3/5,因此绿色圆被赋予蓝色四方形类。从这里可以看到,k的值还是很重要的。
其算法描述如下:

1)计算已知类别数据集中的点与当前点之间的距离;

2)按照距离递增次序排序;

3)选取与当前点距离最小的k个点;

4)确定前k个点所在类别的出现频率;

5)返回前k个点出现频率最高的类别作为当前点的预测分类。

Python实现
这里我们首先实现一个简单一点的数据集,这个数据集很简单,是由二维空间上的4个点构成的一个矩阵,如表2.1所示。
这里写图片描述
其中前两个点构成一个类别A,后两个点构成一个类别B。我们用Python把这4个点在坐标系中绘制出来。代码如下:

# -*- coding: utf-8 -*-###实现在坐标系中把数据点展现出来from numpy import *import numpy as npimport matplotlib.pyplot as plt###产生训练数据集,共4个点def createDataSet():    group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])    labels = ['A','A','B','B'] #这4个点所对应的类别,即贴上标签    return group,labels##testdata = [0.2,0.2]  #测试数据集,一个点###绘制图形dataSet,labels = createDataSet()fig = plt.figure()ax = fig.add_subplot(111)indx = 0for point in dataSet:    if labels[indx] == 'A':        ax.scatter(point[0],point[1],c = 'blue',marker = 'o',linewidths=0,s=300)        plt.annotate("("+str(point[0])+","+str(point[1])+")",xy = (point[0],point[1]))    else:        ax.scatter(point[0],point[1],c = 'red',marker = '^',linewidths=0,s=300)        plt.annotate("("+str(point[0])+","+str(point[1])+")",xy = (point[0],point[1]))    indx +=1###将测试数据集(需要分类的那个点)也绘制出来#ax.scatter(testdata[0],testdata[1],c = 'green',marker = '^',linewidths=0,s=300)#plt.annotate("("+str(testdata[0])+","+str(testdata[1])+")",xy = (testdata[0],testdata[1]))plt.show()

绘制的图形如下图所示:
这里写图片描述
从图中可以明显的看见由4个点构成的训练集。A类—蓝色圆圈,B类—红色三角形。

下面给出测试集,只有一个点。我们把它加入刚才的矩阵中去,如表2.2所示:
这里写图片描述

接下来也将这个点加入到之前的图中,代码上边已经给出,只需把14,28,29行签的注释去掉就行了。这样画出的图如下所示:
这里写图片描述
新加入的点为绿色三角形,从距离上来看,它应该属于红色的那一类,即B类。

接下来以欧氏距离度量,任意给定一个点以判定它属于哪一类,完整代码如下:
我们新建一个kNN.py文件:

# -*- coding: utf-8 -*-from numpy import *import operator#创建测试数据集以及所属类别def createDataSet():    group = array([[1.0,0.9],[1.0,1.0],[0.1,0.2],[0.0,0.1]])    labels = ['A','A','B','B']    return group,labels#分类函数def kNNClassify(newInput,dataSet,labels,k):    #计算给定的数据与训练数据集中所有数据的欧式距离    numSamples = dataSet.shape[0] #得到训练数据集的行数,即计算有多少个训练样本,这里为4    diff = tile(newInput,(numSamples,1)) - dataSet #将测试数据在行方向扩展并计算其与训练集之差    squaredDiff = diff ** 2                        #平方    squaredDist = sum(squaredDiff,axis = 1)        #求和    distance = squaredDist ** 0.5                 #开方    sortedDistIndices = argsort(distance) #以升序对得到的距离进行排序    #对于给定的k,计算距离最近的前k个训练样本属于哪个类    classCount = {}    for i in range(k):        voteLabel = labels[sortedDistIndices[i]]        classCount[voteLabel] = classCount.get(voteLabel,0) + 1    #计算这k个训练样本所属类别最大的那一类并返回结果    maxCount = 0    for key,value in classCount.items():        if value > maxCount:            maxCount = value            maxIndex = key    return maxIndex

再建一个测试testkNN.py文件

# -*- coding: utf-8 -*-import kNN0from numpy import *dataSet,labels = kNN0.createDataSet()testX = array([0.2,0.2])k=3outputLabel = kNN0.kNNClassify(testX,dataSet,labels,3)print("你的输入是:",testX,"被分类为:",outputLabel,"类")

运行testkNN.py文件,得到输出为:

你的输入是: [0.2 0.2] 被分类为:B

有了以上的知识最为基础,接下来我们用kNN来分类一个大点的数据库,包括数据维度比较大和样本数比较多的数据库。这里我们用到一个手写数字的数据库,可以到这里下载。这个数据库包括数字0-9的手写体。每个数字大约有200个样本。每个样本保持在一个txt文件中。手写体图像本身的大小是32x32的二值图,转换到txt文件保存后,内容也是32x32个数字,0或者1,如下:
这里写图片描述
这里我们还是新建一个kNN.py脚本文件,文件里面包含四个函数,一个用来生成将每个样本的txt文件转换为对应的一个向量,一个用来加载整个数据库,一个实现kNN分类算法。最后就是实现这个加载,测试的函数。

# -*- coding: utf-8 -*-from numpy import *  import operator  import os  # kNN分类  def kNNClassify(newInput, dataSet, labels, k):      numSamples = dataSet.shape[0]     ## step 1: 计算欧式距离      diff = tile(newInput, (numSamples, 1)) - dataSet     squaredDiff = diff ** 2     squaredDist = sum(squaredDiff, axis = 1)     distance = squaredDist ** 0.5      ## step 2: 对得到的欧式距离进行排序      sortedDistIndices = argsort(distance)      classCount = {} # 创建一个字典储存分类结果      for i in range(k):          ## step 3: 选择距离最小的前k个          voteLabel = labels[sortedDistIndices[i]]          ## step 4: 计算标签出现的次数,即类别        classCount[voteLabel] = classCount.get(voteLabel, 0) + 1      ## step 5: 返回最大的类别,即返回结果      maxCount = 0      for key, value in classCount.items():          if value > maxCount:              maxCount = value              maxIndex = key      return maxIndex   # 将图像转换为一维矩阵  def  img2vector(filename):      rows = 32      cols = 32      imgVector = zeros((1, rows * cols))       fileIn = open(filename)      for row in range(rows):          lineStr = fileIn.readline()          for col in range(cols):              imgVector[0, row * 32 + col] = int(lineStr[col])      return imgVector  # 加载训练数据集  def loadDataSet():      ## step 1: 获取训练数据集      print ("获取训练数据集...")      dataSetDir = './digits/'      trainingFileList = os.listdir(dataSetDir + 'trainingDigits')     numSamples = len(trainingFileList)      train_x = zeros((numSamples, 1024))      train_y = []      for i in range(numSamples):        filename = trainingFileList[i]        # train_x 为所有的训练数据集的样本构成的矩阵,这里为1934*1024,即1934个数据,每个数据1024个点        train_x[i, :] = img2vector(dataSetDir + 'trainingDigits/%s' % filename)           # 根据样本名得到标签,如样本"1_18.txt"          label = int(filename.split('_')[0]) # 则这里返回 1,说明这个样本属于标签(类别)1        train_y.append(label)      ## step 2: 获取测试数据集      print ("获取测试数据集..." )     testingFileList = os.listdir(dataSetDir + 'testDigits')     numSamples = len(testingFileList)      test_x = zeros((numSamples, 1024))      test_y = []      for i in range(numSamples):        filename = testingFileList[i]          # 这里test_x与上边的train_x一样的,这里是946*1024,共946个测试数据        test_x[i, :] = img2vector(dataSetDir + 'testDigits/%s' % filename)           # 根据样本名得到其所属标签,如1_18.txt        label = int(filename.split('_')[0]) # 得到 1          test_y.append(label)     return train_x, train_y, test_x, test_y  # 测试手写数字类  def testHandWritingClass():      ## step 1: 加载数据      print ("step 1: 加载数据..." )     train_x, train_y, test_x, test_y = loadDataSet()      ## step 2: 训练...      print ("step 2: 训练..." )     pass      ## step 3: 测试...      print ("step 3: 测试..."  )    numTestSamples = test_x.shape[0]      matchCount = 0      for i in range(numTestSamples):          predict = kNNClassify(test_x[i], train_x, train_y, 3)          if predict == test_y[i]:              matchCount += 1      accuracy = float(matchCount) / numTestSamples      ## step 4: 显示结果      print ("step 4: 显示结果..." )     print ('分类准确率为: %.2f%%' % (accuracy * 100)  )

然后在命令行中输入以下代码测试

import kNN  kNN.testHandWritingClass()  

输出结果如下:

step 1: 加载数据...获取训练数据集...获取测试数据集...step 2: 训练...step 3: 测试...step 4: 显示结果...分类准确率为: 98.94%

本文代码来自郑捷著的机器学习算法原理与编程实践以及机器学习算法与Python实践之(一)k近邻(KNN) 并做了相应的修改。
工具:anaconda(Python3.6),IDE: spyder
参考:百度百科knn

原创粉丝点击