cs231n作业一之 在cifar-10上实现KNN
来源:互联网 发布:服务器端 语言 python 编辑:程序博客网 时间:2024/06/04 23:30
在网上也看了好多代码,但是都运行不了,要不是代码问题,要不就是路径问题,本代码已经成功运行过,主要有两个问题需要注意
1.python2和python3的版本不兼容,有一些语法不同,需要改
2.关于导入cifar-10数据库,我是用的在cs231n熵下载的数据库,我直接吧数据库托到和你写代码存储的根目录的位置就可以了
import pickle as pimport matplotlib.pyplot as pltimport numpy as np# NearestNeighbor classclass NearestNeighbor(object): def __init__(self): pass def train(self, X, y): """ X is N x D where each row is an example. Y is 1-dimension of size N """ # the nearest neighbor classifier simply remembers all the training data self.Xtr = X self.ytr = y def predict(self, X): """ X is N x D where each row is an example we wish to predict label for """ num_test = X.shape[0] # lets make sure that the output type matches the input type Ypred = np.zeros(num_test, dtype=self.ytr.dtype) # loop over all test rows for i in range(num_test): # find the nearest training image to the i'th test image # using the L1 distance (sum of absolute value differences) distances = np.sum(np.abs(self.Xtr - X[i, :]), axis=1) min_index = np.argmin(distances) # get the index with smallest distance Ypred[i] = self.ytr[min_index] # predict the label of the nearest example return Ypreddef load_CIFAR_batch(filename): """ load single batch of cifar """ with open(filename, 'rb')as f: datadict = p.load(f, encoding='latin1') X = datadict['data'] Y = datadict['labels'] Y = np.array(Y) # 字典里载入的Y是list类型,把它变成array类型 return X, Ydef load_CIFAR_Labels(filename): with open(filename, 'rb') as f: label_names = p.load(f, encoding='latin1') names = label_names['label_names'] return names# load datalabel_names = load_CIFAR_Labels("cifar-10-batches-py/batches.meta")imgX1, imgY1 = load_CIFAR_batch("cifar-10-batches-py/data_batch_1")imgX2, imgY2 = load_CIFAR_batch("cifar-10-batches-py/data_batch_2")imgX3, imgY3 = load_CIFAR_batch("cifar-10-batches-py/data_batch_3")imgX4, imgY4 = load_CIFAR_batch("cifar-10-batches-py/data_batch_4")imgX5, imgY5 = load_CIFAR_batch("cifar-10-batches-py/data_batch_5")Xte_rows, Yte = load_CIFAR_batch("cifar-10-batches-py/test_batch")Xtr_rows = np.concatenate((imgX1, imgX2, imgX3, imgX4, imgX5))Ytr_rows = np.concatenate((imgY1, imgY2, imgY3, imgY4, imgY5))nn = NearestNeighbor() # create a Nearest Neighbor classifier classnn.train(Xtr_rows[:1000,:], Ytr_rows[:1000]) # train the classifier on the training images and labelsYte_predict = nn.predict(Xte_rows[:100,:]) # predict labels on the test images# and now print the classification accuracy, which is the average number# of examples that are correctly predicted (i.e. label matches)print('accuracy: %f' % (np.mean(Yte_predict == Yte[:100])))# show a pictureimage=imgX1[6,0:1024].reshape(32,32)print(image.shape)plt.imshow(image,cmap=plt.cm.gray)plt.axis('off') #去除图片边上的坐标轴plt.show()image=imgX2[6,0:1024].reshape(32,32)print(image.shape)plt.imshow(image,cmap=plt.cm.gray)plt.axis('off') #去除图片边上的坐标轴plt.show()image=imgX3[6,0:1024].reshape(32,32)print(image.shape)plt.imshow(image,cmap=plt.cm.gray)plt.axis('off') #去除图片边上的坐标轴plt.show()image=imgX4[6,0:1024].reshape(32,32)print(image.shape)plt.imshow(image,cmap=plt.cm.gray)plt.axis('off') #去除图片边上的坐标轴plt.show()image=imgX5[6,0:1024].reshape(32,32)print(image.shape)plt.imshow(image,cmap=plt.cm.gray)plt.axis('off') #去除图片边上的坐标轴plt.show()image=imgX6[6,0:1024].reshape(32,32)print(image.shape)plt.imshow(image,cmap=plt.cm.gray)plt.axis('off') #去除图片边上的坐标轴plt.show()
运行结果:可以得到knn这个分类器的准确率是24%,还可以得到每一类最后的最相近的图片
阅读全文
0 0
- cs231n作业一之 在cifar-10上实现KNN
- [CS231n-assignment2] Python从零实现的CNN在CIFAR-10上的实验报告
- cs231n作业一之实现SVM
- cs231n作业一之实现softmax
- cs231n作业1--KNN
- CS231n-assignment1(作业1)-knn
- caffe示例实现之1在CIFAR-10数据集上训练与测试Caffe
- cs231n课程作业assignment1(KNN)
- cs231n的第一次作业knn的问题
- CIFAR-10 在Caffe上训练学习
- CIFAR-10在caffe上进行训练
- CS231n作业笔记1.1: KNN中的距离矩阵vectorize的实现方法(无循环)
- 薛开宇学习笔记一之总结笔记(CIFAR-10 在 在 caffe 上进行训练与学习)--Linux语法总结
- TensorFlow应用之进阶版卷积神经网络CNN在CIFAR-10数据集上分类
- CIFAR-10驱动的KNN分类器
- cs231n knn
- CS231n+assignment1(作业一)
- [CS231n@Stanford] Assignment1-Q1 (python) KNN实现
- 文章标题
- Android 上传library到jcenter
- 第八届福建大学生程序设计竞赛-L Tic-Tac-Toe
- 太多switch case ,if else if
- cpu cache 学习记录
- cs231n作业一之 在cifar-10上实现KNN
- ThinkPHP 3.1.2 视图
- 使用ajax异步请求数据,并展示在html中
- 怎样退出终止App
- HDU 1576-A/B(扩展欧几里得算法)
- 详解js闭包
- gradle学习笔记(下)
- JSP--整合SiteMesh02
- 魔法工会