图片分类-最近邻分类(1)

来源:互联网 发布:软件测试周期算法举例 编辑:程序博客网 时间:2024/05/06 11:31

作为我们学习的第一个方法,我们会开发一个叫最近邻分类器的东西。这个分类器与卷积神经网络(CNN)没有什么关系并且在实际中很少用到。但是他会让我们对图片分类问题的基础算法有一个初步的了解。
如图,分类数据集为:CIFAR-10。一个非常有趣的图片分类数据集是CIFAR-10数据集。这个数据集包含60000个32*32像素的小图片。每个图片都被标记成十种类别中的一种(例如“airplane, automobile, bird, etc”)。这60000个图片被分为训练集和测试集。训练集有50000个图片,测试集有10000个图片。在如下的图片中我们可以看到10个从十种类型中选取的随机图片样本:
这里写图片描述
左图:从CIFAR-10数据集中获取的样本图片。右图:第一行展示了一些测试图片并且他们最相近邻的类型由于像素不同。
假如现在我们有50000个CIFAR-10数据集的训练集(每个种类5000张图片),并且我们希望编辑剩余的10000张图片。最近邻算法就会拿一个测试图片,然后让他与训练集中的每一张图片对比。并且预测和他最相近的训练集。在如上的图片中右图你可以看到一个样本经过这样处理后的结果。请注意,在10个示例中,只有3个示例中检索了同一个类的图像,而在其他7个示例中,情况并非如此。例如,在第八列中最近邻马头的图片是一个红色的车,大概是因为黑色背景吧。在这种情况下,这个马的图片将被归在汽车类中。
你可能注意到了我们留了很多不明确的细节例如我们怎么比较这两个有32*32*3的图片数据。最简单的比较方式就是一个像素一个像素的叠加并且添加所有不同。其他方面,有两个图片并且表现他们的向量I1,I2,一个很合理的选择就是比较他们的L1距离

这里总和可以包含所有元素。以下是过程的可视化:
这里写图片描述
一个两个图片像素比较转化为L1距离的比较(对于每一个颜色来说)。这两个图片用减法运算并且把所有的不同都相加到一个数上。如果两个图片是相等的那么他们的结果会是0。但是如果不同他们的值会非常大。
让我们来看看我们怎么实现分类器的代码。首先,我们把CIFAR-10数据以4个数组到内存中:训练集/训练标记和测试集/测试标记。在以下的代码中,Xtr(大小:50000*32*32*3)把所有的图片都保存着并且会有一个一维的数组Ytr分配(长度50000)保存训练集(从0到9):

import pickleimport  NearestNeighbor as NNimport numpy as npdef unpickle(file):    with open(file, 'rb') as fo:        dict = pickle.load(fo, encoding='bytes')        #print (dict)        X = dict[b'data']        Y = dict[b'labels']        Y = np.array(Y)    return X,Ydef load_CIFAR10(dir):    Xtr = []    Ytr = []    for i in range(1, 6):        x, y = unpickle("%sdata_batch_%d" % (dir,i))        Xtr.extend(x)        Ytr.extend(y.tolist())    Xtr = np.array(Xtr)    Ytr = np.array(Ytr)    Xte, Yte = unpickle("%stest_batch"%(dir))    return Xtr, Ytr, Xte, Yteif __name__ == '__main__':    Xtr, Ytr, Xte, Yte = load_CIFAR10('data/')    # flatten out all images to be one-dimensional    Xtr_rows = Xtr.reshape(Xtr.shape[0], 32 * 32 * 3)  # Xtr_rows becomes 50000 x 3072    Xte_rows = Xte.reshape(Xte.shape[0], 32 * 32 * 3)  # Xte_rows becomes 10000 x 3072

现在我们把全部图片伸展成列,这里是我们怎么训练和评估一个分类器:

nn = NearestNeighbor() # create a Nearest Neighbor classifier classnn.train(Xtr_rows, Ytr) # train the classifier on the training images and labelsYte_predict = nn.predict(Xte_rows) # predict labels on the test images# and now print the classification accuracy, which is the average number# of examples that are correctly predicted (i.e. label matches)print 'accuracy: %f' % ( np.mean(Yte_predict == Yte) )

注意到作为一个评价标准,用精确度(通过记录错误的预测所占比例)是很常见的。注意所有的飞雷奇我们都会创建成满足这个通用API:他们有一个函数train(X,y)函数来学习数据和标记。本质上来说,这个类应该建立一些种类的标记模型并且提供他们如何从数据中预测。当然,我们遗忘了一些事情,例如真实的分类器。这里有一个肩带的基于L1距离的最近邻算法分类器的代码:

import numpy as npclass NearestNeighbor(object):  def __init__(self):    pass  def train(self, X, y):    """ X is N x D where each row is an example. Y is 1-dimension of size N """    # the nearest neighbor classifier simply remembers all the training data    self.Xtr = X    self.ytr = y  def predict(self, X):    """ X is N x D where each row is an example we wish to predict label for """    num_test = X.shape[0]    # lets make sure that the output type matches the input type    Ypred = np.zeros(num_test, dtype = self.ytr.dtype)    # loop over all test rows    for i in xrange(num_test):      # find the nearest training image to the i'th test image      # using the L1 distance (sum of absolute value differences)      distances = np.sum(np.abs(self.Xtr - X[i,:]), axis = 1)      min_index = np.argmin(distances) # get the index with smallest distance      Ypred[i] = self.ytr[min_index] # predict the label of the nearest example    return Ypred

如果你可以跑起来这个代码,你会发现这个分类器在CIFAR-10上只有38.6%的准确度。这比随机选取好的到了(随机选取只有10%的准确度由于只有十类),但是离人类还差得远(评估的有94%)或者通过状态集艺术卷积神经网络与95%的准确率(最近kaggle的leaderboard)
距离选择.
这里有很多其他的计算两个向量之间的距离函数。一个比较公认的就是L2距离。它在集合论中解释维欧氏距离。距离函数为

换一句话说,目前我们可以计算像素之间的差异,但是这次是以平方的方式,把他们加起来再求他们的平方根。再numpy中,用代码可以表示如下:

distances = np.sqrt(np.sum(np.square(self.Xtr - X[i,:]), axis = 1))

注意:这里我们用了np.sqrt函数,但是再实际中最近邻函数我们会省略平方根操作,因为它是一个单调函数。这样的话,它只是计算了绝对距离并且保留了排序,所以有没有用平方根都一样。如果你用这个距离再数据集上跑可以获得35.4%的准确率(略低于L1距离)。
L1 vs L2.
考虑两种方法的对比还是很有趣。特别的,L2距离相比L1距离来说再计算两个向量距离方面更加的斤斤计较。因此,L2距离再中等规模的分期上表现的比大规模上更好。L1和L2距离(L1/L2规则再比较图片上有相同的效应)是最常用的p-norm特殊方式。

原创粉丝点击