cs231n作业一之 在cifar-10上实现KNN

来源:互联网 发布:服务器端 语言 python 编辑:程序博客网 时间:2024/06/04 23:30
在网上也看了好多代码,但是都运行不了,要不是代码问题,要不就是路径问题,本代码已经成功运行过,主要有两个问题需要注意
1.python2和python3的版本不兼容,有一些语法不同,需要改
2.关于导入cifar-10数据库,我是用的在cs231n熵下载的数据库,我直接吧数据库托到和你写代码存储的根目录的位置就可以了
import pickle as pimport matplotlib.pyplot as pltimport numpy as np# NearestNeighbor classclass NearestNeighbor(object):    def __init__(self):        pass    def train(self, X, y):        """ X is N x D where each row is an example. Y is 1-dimension of size N """        # the nearest neighbor classifier simply remembers all the training data        self.Xtr = X        self.ytr = y    def predict(self, X):        """ X is N x D where each row is an example we wish to predict label for """        num_test = X.shape[0]        # lets make sure that the output type matches the input type        Ypred = np.zeros(num_test, dtype=self.ytr.dtype)        # loop over all test rows        for i in range(num_test):            # find the nearest training image to the i'th test image            # using the L1 distance (sum of absolute value differences)            distances = np.sum(np.abs(self.Xtr - X[i, :]), axis=1)            min_index = np.argmin(distances)  # get the index with smallest distance            Ypred[i] = self.ytr[min_index]  # predict the label of the nearest example        return Ypreddef load_CIFAR_batch(filename):    """ load single batch of cifar """    with open(filename, 'rb')as f:        datadict = p.load(f, encoding='latin1')        X = datadict['data']        Y = datadict['labels']        Y = np.array(Y)  # 字典里载入的Y是list类型,把它变成array类型        return X, Ydef load_CIFAR_Labels(filename):    with open(filename, 'rb') as f:        label_names = p.load(f, encoding='latin1')        names = label_names['label_names']        return names# load datalabel_names = load_CIFAR_Labels("cifar-10-batches-py/batches.meta")imgX1, imgY1 = load_CIFAR_batch("cifar-10-batches-py/data_batch_1")imgX2, imgY2 = load_CIFAR_batch("cifar-10-batches-py/data_batch_2")imgX3, imgY3 = load_CIFAR_batch("cifar-10-batches-py/data_batch_3")imgX4, imgY4 = load_CIFAR_batch("cifar-10-batches-py/data_batch_4")imgX5, imgY5 = load_CIFAR_batch("cifar-10-batches-py/data_batch_5")Xte_rows, Yte = load_CIFAR_batch("cifar-10-batches-py/test_batch")Xtr_rows = np.concatenate((imgX1, imgX2, imgX3, imgX4, imgX5))Ytr_rows = np.concatenate((imgY1, imgY2, imgY3, imgY4, imgY5))nn = NearestNeighbor()  # create a Nearest Neighbor classifier classnn.train(Xtr_rows[:1000,:], Ytr_rows[:1000])  # train the classifier on the training images and labelsYte_predict = nn.predict(Xte_rows[:100,:])  # predict labels on the test images# and now print the classification accuracy, which is the average number# of examples that are correctly predicted (i.e. label matches)print('accuracy: %f' % (np.mean(Yte_predict == Yte[:100])))# show a pictureimage=imgX1[6,0:1024].reshape(32,32)print(image.shape)plt.imshow(image,cmap=plt.cm.gray)plt.axis('off')    #去除图片边上的坐标轴plt.show()image=imgX2[6,0:1024].reshape(32,32)print(image.shape)plt.imshow(image,cmap=plt.cm.gray)plt.axis('off')    #去除图片边上的坐标轴plt.show()image=imgX3[6,0:1024].reshape(32,32)print(image.shape)plt.imshow(image,cmap=plt.cm.gray)plt.axis('off')    #去除图片边上的坐标轴plt.show()image=imgX4[6,0:1024].reshape(32,32)print(image.shape)plt.imshow(image,cmap=plt.cm.gray)plt.axis('off')    #去除图片边上的坐标轴plt.show()image=imgX5[6,0:1024].reshape(32,32)print(image.shape)plt.imshow(image,cmap=plt.cm.gray)plt.axis('off')    #去除图片边上的坐标轴plt.show()image=imgX6[6,0:1024].reshape(32,32)print(image.shape)plt.imshow(image,cmap=plt.cm.gray)plt.axis('off')    #去除图片边上的坐标轴plt.show()
运行结果:可以得到knn这个分类器的准确率是24%,还可以得到每一类最后的最相近的图片

原创粉丝点击