A KNN Classifier on CIFAR-10
Source: Internet  Editor: 程序博客网  Time: 2024/06/05 22:42
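First, load the CIFAR-10 dataset. The training set contains 50000 images, each a 32x32 three-channel color image, so the training set consists of 50000 vectors of length 32x32x3 = 3072. The training images therefore form a (50000, 3072) matrix, accompanied by 50000 labels. The test set has 10000 images and 10000 labels. The training set is split into 5 batches; when reading the data, the 5 batches are loaded into a single 50000x3072 training matrix, and the corresponding labels are loaded into an array of length 50000.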
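The stacking step described above can be sketched on synthetic data. This is a minimal illustration, not the loader used in this post: the helper names `stack_batches` and `row_to_image` are made up for the example, and each fake batch has only 100 rows instead of 10000 to keep it light. It also assumes the row layout documented for the Python version of CIFAR-10 (1024 red values, then 1024 green, then 1024 blue, each in row-major order).

```python
import numpy as np

def stack_batches(batches):
    """Stack (data, labels) pairs into one data matrix and one label vector."""
    data = np.vstack([d for d, _ in batches])
    labels = np.concatenate([np.asarray(l) for _, l in batches])
    return data, labels

def row_to_image(row):
    """Turn one 3072-value row into a (height, width, channel) = 32x32x3 image."""
    return row.reshape(3, 32, 32).transpose(1, 2, 0)

# Synthetic stand-ins for the five training batches (random pixels, not real data).
# Real batches hold 10000 rows each; 100 rows are used here to keep the sketch small.
rng = np.random.default_rng(0)
rows_per_batch = 100
fake_batches = [
    (rng.integers(0, 256, size=(rows_per_batch, 3072), dtype=np.uint8),
     rng.integers(0, 10, size=rows_per_batch).tolist())
    for _ in range(5)
]

data, labels = stack_batches(fake_batches)
print(data.shape)                    # (500, 3072)
print(labels.shape)                  # (500,)
print(row_to_image(data[0]).shape)   # (32, 32, 3)
```

With the real batches the same stacking gives a (50000, 3072) matrix and a label vector of length 50000.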
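The nearest-neighbor algorithm here simply stores all of the training data. At prediction time it computes the L1 distance (Manhattan distance, the sum of absolute differences) between the test image and every training image, and takes the label of the training image with the smallest distance as the predicted label.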
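A tiny worked example of the L1 rule may help. The arrays below are invented for illustration. They also show one detail that matters for image data: CIFAR-10 pixels are stored as uint8, and subtracting uint8 arrays in NumPy wraps around modulo 256, so the differences should be taken in a signed type.

```python
import numpy as np

# Two made-up "training images" with three pixels each, and their labels.
train = np.array([[249, 249, 249],
                  [0, 0, 0]], dtype=np.uint8)
train_labels = np.array([1, 7])
query = np.array([250, 250, 250], dtype=np.uint8)

# Correct L1 distances: cast to a signed type before subtracting.
diff = train.astype(np.int16) - query.astype(np.int16)
distances = np.sum(np.abs(diff), axis=1)          # L1 distances: 3 and 750
predicted = train_labels[np.argmin(distances)]    # label 1 (row 0 is closest)

# Same computation on raw uint8: 249 - 250 wraps to 255, 0 - 250 wraps to 6.
wrapped = np.sum(np.abs(train - query), axis=1)   # 765 and 18
wrong = train_labels[np.argmin(wrapped)]          # label 7 (the far row wins)

print(distances, predicted)
print(wrapped, wrong)
```

In this example the wrapped arithmetic picks the training row that is actually farthest away, which is why the cast is worth making explicit.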
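The algorithm is inefficient: every test image has to be compared against all 50000 training images, so prediction takes a very long time.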
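To get a feel for the cost, the number of pixel comparisons can be counted directly, and the running time can be roughly extrapolated from a small sample. The sketch below uses synthetic data and a hypothetical helper `l1_nearest_indices`; the extrapolated time depends entirely on the machine and is only a rough estimate.

```python
import time
import numpy as np

def l1_nearest_indices(train, queries):
    """For each query row, return the index of the closest training row (L1)."""
    train = train.astype(np.int16)       # signed type avoids uint8 wraparound
    queries = queries.astype(np.int16)
    nearest = np.empty(len(queries), dtype=np.int64)
    for i, q in enumerate(queries):
        nearest[i] = np.argmin(np.sum(np.abs(train - q), axis=1))
    return nearest

n_train, n_test, dim = 50000, 10000, 3072
total_differences = n_train * n_test * dim
print("absolute differences for the full run:", total_differences)  # 1536000000000

# Time a small synthetic sample and scale up linearly.
rng = np.random.default_rng(0)
sample_train = rng.integers(0, 256, size=(2000, dim), dtype=np.uint8)
sample_queries = rng.integers(0, 256, size=(20, dim), dtype=np.uint8)

start = time.perf_counter()
sample_nearest = l1_nearest_indices(sample_train, sample_queries)
elapsed = time.perf_counter() - start

seconds_per_pair = elapsed / (len(sample_train) * len(sample_queries))
print("estimated seconds for the full run:", seconds_per_pair * n_train * n_test)
```

About 1.5 trillion absolute differences are needed for the full test set, which explains the long prediction time. Evaluating on a subset of the test images first is a cheap way to check the pipeline before committing to the full run.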
import pickle

import numpy as np

# Directory that contains the training and test batches
file_path = "E:/cifar-10-python/cifar-10-batches-py/"


class NearestNeighbor(object):
    def __init__(self):
        pass

    def train(self, X, y):
        """ X is N x D where each row is an example. Y is 1-dimension of size N """
        # the nearest neighbor classifier simply remembers all the training data.
        # CIFAR-10 pixels are stored as uint8; cast to a signed type so that the
        # subtraction in predict() cannot wrap around modulo 256
        self.Xtr = X.astype(np.int16)
        self.ytr = y

    def predict(self, X):
        """ X is N x D where each row is an example we wish to predict label for """
        num_test = X.shape[0]
        X = X.astype(np.int16)
        # lets make sure that the output type matches the input type
        Ypred = np.zeros(num_test, dtype=self.ytr.dtype)

        # loop over all test rows
        for i in range(num_test):
            # find the nearest training image to the i'th test image
            # using the L1 distance (sum of absolute value differences)
            distances = np.sum(np.abs(self.Xtr - X[i, :]), axis=1)
            min_index = np.argmin(distances)  # get the index with smallest distance
            Ypred[i] = self.ytr[min_index]  # predict the label of the nearest example

        return Ypred


def unpickle(file):
    """Unpack one batch file into a dictionary."""
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='latin1')
    return batch


def load_CIFAR10(file):
    """Load the five training batches and the test batch."""
    dictTrain = unpickle(file + "data_batch_1")
    dataTrain = dictTrain['data']
    labelTrain = np.array(dictTrain['labels'])
    # append batches 2 to 5 to the first one
    for i in range(2, 6):
        dictTrain = unpickle(file + "data_batch_" + str(i))
        dataTrain = np.vstack([dataTrain, dictTrain['data']])
        labelTrain = np.hstack([labelTrain, dictTrain['labels']])

    dictTest = unpickle(file + "test_batch")
    dataTest = dictTest['data']
    labelTest = np.array(dictTest['labels'])
    return dataTrain, labelTrain, dataTest, labelTest


dataTrain, labelTrain, dataTest, labelTest = load_CIFAR10(file_path)
print(dataTrain.shape)
print(type(labelTrain))
print(dataTest.shape)
print(len(labelTest))

nn = NearestNeighbor()  # create a Nearest Neighbor classifier class
nn.train(dataTrain[:50000, :], labelTrain[:50000])  # train the classifier on the training images and labels
labelTest_Predict = nn.predict(dataTest[:10000, :])  # predict labels on the test images
# and now print the classification accuracy, which is the average number
# of examples that are correctly predicted (i.e. label matches)
print('accuracy: %f' % (np.mean(labelTest_Predict == labelTest[:10000])))
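The measured accuracy was 24.92%, which is quite low, and prediction takes an extremely long time. Note that this figure was obtained with the differences computed on the raw uint8 pixel arrays, where subtraction wraps around modulo 256, so it probably understates what an L1 nearest-neighbor classifier can achieve on this data.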