A KNN Classifier on CIFAR-10

Source: Internet | Editor: 程序博客网 | Time: 2024/06/05 22:42

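First, load the CIFAR-10 dataset.

- **Training set:** 50000 images. Each is a 32x32 three-channel color image, stored as a vector of 32x32x3 = 3072 values. The training images therefore form a (50000, 3072) matrix, with 50000 labels to go with them.
- **Test set:** 10000 images and 10000 labels.
- **Batches:** the training set is split into 5 batch files of 10000 images each. When loading, the 5 batches are stacked into one 50000 x 3072 training matrix, and the corresponding labels are collected into an array of length 50000.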
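The shapes described above can be checked with a small sketch. To run without the dataset, it substitutes hypothetical zero-filled stand-ins for `data_batch_1` to `data_batch_5`. It also relies on the CIFAR-10 row layout: the first 1024 values of a row are the red channel, the next 1024 green and the last 1024 blue, each plane stored row by row as 32x32.

```python
import numpy as np

# Hypothetical stand-ins for data_batch_1..5: each real batch holds a
# 10000 x 3072 uint8 array ('data') and a list of 10000 labels ('labels').
batches = [
    {"data": np.zeros((10000, 3072), dtype=np.uint8),
     "labels": [i % 10 for i in range(10000)]}
    for _ in range(5)
]

# Stack the five batches into one 50000 x 3072 matrix and one label vector.
dataTrain = np.vstack([b["data"] for b in batches])
labelTrain = np.concatenate([b["labels"] for b in batches])
print(dataTrain.shape, labelTrain.shape)  # (50000, 3072) (50000,)

# A row stores three 32x32 channel planes in the order R, G, B (1024 values each).
row = np.concatenate([np.full(1024, c, dtype=np.uint8) for c in (10, 20, 30)])
image = row.reshape(3, 32, 32).transpose(1, 2, 0)  # -> height x width x channel
print(image.shape, image[0, 0].tolist())  # (32, 32, 3) [10, 20, 30]
```

The labels must end up with 50000 entries, one per training row. The 10000-entry label array belongs to the test set.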

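The nearest neighbor algorithm used here simply stores all of the training data. At prediction time it works in two steps:

1. Compute the L1 distance (Manhattan distance, the sum of absolute differences) between the test image and every training image.
2. Take the label of the closest training image as the predicted label.

Because only the single nearest image is used, this is the k = 1 case of KNN.

The algorithm is inefficient, and prediction takes a very long time. Classifying all 10000 test images means comparing each one with all 50000 training images over 3072 values: about 10000 × 50000 × 3072 ≈ 1.5 × 10^12 absolute differences.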
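The rule above can be sketched on toy data as follows. The function name `knn_predict_l1` and the toy arrays are hypothetical.

- With `k=1` it reproduces the nearest-neighbor rule described above.
- With `k > 1` it takes a majority vote among the k closest training rows. This is the general KNN rule, which the code in this post does not use.

The inputs are cast to a signed type before subtracting because CIFAR-10 pixels are uint8.

```python
import numpy as np

def knn_predict_l1(Xtr, ytr, x, k=1):
    """Predict the label of one test vector x from its L1 distance to every
    training row. k=1 is the nearest-neighbor rule; k>1 takes a majority
    vote among the k closest rows (hypothetical helper for illustration)."""
    # Cast to a signed type so that differences of uint8 pixels can go negative.
    diffs = Xtr.astype(np.int32) - x.astype(np.int32)
    distances = np.sum(np.abs(diffs), axis=1)  # one L1 distance per training row
    nearest = np.argsort(distances)[:k]        # indices of the k smallest distances
    votes = np.bincount(ytr[nearest])          # count each label among the neighbors
    return int(np.argmax(votes))               # ties go to the smaller label

# Toy data: 4 "images" of 3 pixels each.
Xtr = np.array([[0, 0, 0], [10, 10, 10], [200, 200, 200], [12, 9, 12]], dtype=np.uint8)
ytr = np.array([0, 1, 2, 0])
x = np.array([11, 11, 11], dtype=np.uint8)  # L1 distances: 33, 3, 567, 4

print(knn_predict_l1(Xtr, ytr, x))       # 1 (nearest row is [10, 10, 10])
print(knn_predict_l1(Xtr, ytr, x, k=3))  # 0 (labels of the 3 nearest: 1, 0, 0)
```

The toy case shows that a larger k can change the prediction. The single closest row has label 1, but two of the three closest rows have label 0.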


import pickle

import numpy as np

# Input: directory holding the CIFAR-10 training and test batches
file_path = "E:/cifar-10-python/cifar-10-batches-py/"


class NearestNeighbor(object):
    def __init__(self):
        pass

    def train(self, X, y):
        """X is N x D where each row is an example. y is 1-dimensional of size N."""
        # The nearest neighbor classifier simply remembers all the training data.
        # Cast to a signed type: CIFAR-10 pixels are uint8, and subtracting two
        # uint8 arrays wraps around modulo 256 instead of going negative.
        self.Xtr = X.astype(np.int16)
        self.ytr = np.asarray(y)

    def predict(self, X):
        """X is N x D where each row is an example we wish to predict a label for."""
        num_test = X.shape[0]
        X = X.astype(np.int16)
        # Make sure that the output type matches the label type
        Ypred = np.zeros(num_test, dtype=self.ytr.dtype)
        # Loop over all test rows
        for i in range(num_test):
            # Find the nearest training image to the i-th test image
            # using the L1 distance (sum of absolute value differences)
            distances = np.sum(np.abs(self.Xtr - X[i, :]), axis=1)
            min_index = np.argmin(distances)  # index with the smallest distance
            Ypred[i] = self.ytr[min_index]  # label of the nearest example
        return Ypred


def unpickle(file):
    """Unpack one CIFAR-10 batch file and return its dictionary."""
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='latin1')
    return batch


def load_CIFAR10(file):
    """Load the dataset: data_batch_1..5 for training, test_batch for testing."""
    data_list, label_list = [], []
    for i in range(1, 6):
        dictTrain = unpickle(file + "data_batch_" + str(i))
        data_list.append(dictTrain['data'])
        label_list.append(dictTrain['labels'])
    # Stack the five 10000 x 3072 batches into one 50000 x 3072 matrix
    dataTrain = np.vstack(data_list)
    labelTrain = np.concatenate(label_list)

    dictTest = unpickle(file + "test_batch")
    dataTest = dictTest['data']
    labelTest = np.array(dictTest['labels'])
    return dataTrain, labelTrain, dataTest, labelTest


dataTrain, labelTrain, dataTest, labelTest = load_CIFAR10(file_path)
print(dataTrain.shape)
print(type(labelTrain))
print(dataTest.shape)
print(len(labelTest))

nn = NearestNeighbor()  # create a Nearest Neighbor classifier
nn.train(dataTrain[:50000, :], labelTrain[:50000])  # train on the training images and labels
labelTest_Predict = nn.predict(dataTest[:10000, :])  # predict labels on the test images
# Classification accuracy: the fraction of test examples whose predicted label matches
print('accuracy: %f' % np.mean(labelTest_Predict == labelTest[:10000]))
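The `data` arrays in the CIFAR-10 Python batches are uint8. NumPy arithmetic on unsigned 8-bit arrays wraps modulo 256 rather than going negative. As a result, `np.abs(a - b)` on the raw pixel matrices does not give the absolute difference unless the arrays are cast to a signed type first. The arrays in this small check are made up for illustration.

```python
import numpy as np

a = np.array([3, 200], dtype=np.uint8)
b = np.array([5, 100], dtype=np.uint8)

wrapped = np.abs(a - b)  # 3 - 5 wraps around to 254 in uint8
correct = np.abs(a.astype(np.int16) - b.astype(np.int16))

print(wrapped.tolist(), int(wrapped.sum()))  # [254, 100] 354
print(correct.tolist(), int(correct.sum()))  # [2, 100] 102
```

With wrapped differences, the summed "L1 distance" no longer measures how far apart two images are, so the nearest neighbor it picks can be the wrong one.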

[Figure: program output]

The accuracy is 24.92%, which is quite low, and prediction takes an extremely long time.
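One common way to reduce the prediction time is to replace the per-image Python loop with matrix operations. The sketch below swaps the metric to squared L2 (Euclidean) distance, which expands as ||a − b||² = ||a||² + ||b||² − 2a·b. That expansion lets a whole chunk of test images be handled with one matrix product.

Some caveats on this sketch:

- It is not the L1 method used above, and its accuracy on CIFAR-10 is not measured here.
- The function name `predict_l2_vectorized` and the `chunk` parameter are hypothetical.
- float64 is assumed, so the squared distances stay exact. The float64 copy of the full training matrix takes about 1.2 GB.

```python
import numpy as np

def predict_l2_vectorized(Xtr, ytr, Xte, chunk=1000):
    """1-NN prediction with squared L2 distance, computed chunk by chunk.

    Uses ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b so that each chunk of test
    rows needs a single (chunk x N) matrix product instead of a Python loop.
    """
    Xtr = Xtr.astype(np.float64)  # float64 keeps integer-valued distances exact
    Xte = Xte.astype(np.float64)
    ytr = np.asarray(ytr)
    train_sq = np.sum(Xtr ** 2, axis=1)  # ||b||^2 for every training row
    preds = np.empty(Xte.shape[0], dtype=ytr.dtype)
    for start in range(0, Xte.shape[0], chunk):
        block = Xte[start:start + chunk]
        test_sq = np.sum(block ** 2, axis=1)[:, None]  # ||a||^2 as a column
        d2 = test_sq + train_sq[None, :] - 2.0 * (block @ Xtr.T)
        preds[start:start + chunk] = ytr[np.argmin(d2, axis=1)]
    return preds

# Tiny usage example on toy data
Xtr_toy = np.array([[0, 0], [10, 10]], dtype=np.uint8)
ytr_toy = np.array([0, 1])
Xte_toy = np.array([[1, 2], [9, 8]], dtype=np.uint8)
print(predict_l2_vectorized(Xtr_toy, ytr_toy, Xte_toy).tolist())  # [0, 1]
```

The chunk size trades memory for speed. Each chunk allocates a (chunk x 50000) distance matrix, so the full 10000 x 50000 matrix never has to exist at once.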