cs231n assignment (1.1): kNN classifier
Source: Internet   Editor: 程序博客网   Time: 2024/06/07 05:15
The kNN classifier exercise aims to help you understand the basic Image Classification pipeline and cross-validation, and to gain proficiency in writing efficient, vectorized code.
1. Data preparation
The data come from the CIFAR-10 image database, which has 10 image classes: 'plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'. Each image is 32×32×3 (height × width × RGB channels).
Basic information about the dataset (every array is a numpy.ndarray):
Training data shape: (50000, 32, 32, 3)
Training labels shape: (50000,)
Test data shape: (10000, 32, 32, 3)
Test labels shape: (10000,)
Some sample images:
[Figure: example CIFAR-10 training images from each class]
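To keep things simple, only 5000 training images and 500 test images are used, and every image in both sets is flattened into a one-dimensional vector, so that:
X_train.shape: (5000, 3072)   y_train.shape: (5000,)
X_test.shape: (500, 3072)   y_test.shape: (500,)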
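Below is a minimal NumPy sketch of the subsampling and flattening step. It assumes, as the text implies, that the first 5000 training images and the first 500 test images are kept. The helper name subsample_and_flatten is hypothetical, and zero-filled arrays (smaller than the full dataset) stand in for the real CIFAR-10 data.

```python
import numpy as np

def subsample_and_flatten(X_train, y_train, X_test, y_test,
                          num_training=5000, num_test=500):
    """Hypothetical helper: keep the first num_training / num_test examples
    and flatten each (32, 32, 3) image into one row of 3072 values."""
    X_train = X_train[:num_training].reshape(num_training, -1)
    y_train = y_train[:num_training]
    X_test = X_test[:num_test].reshape(num_test, -1)
    y_test = y_test[:num_test]
    return X_train, y_train, X_test, y_test

# Zero-filled stand-ins with CIFAR-10 image shapes (not real images)
X_tr, y_tr, X_te, y_te = subsample_and_flatten(
    np.zeros((6000, 32, 32, 3), dtype=np.uint8), np.zeros(6000, dtype=np.int64),
    np.zeros((1000, 32, 32, 3), dtype=np.uint8), np.zeros(1000, dtype=np.int64))
print(X_tr.shape, y_tr.shape, X_te.shape, y_te.shape)
# (5000, 3072) (5000,) (500, 3072) (500,)
```

If your arrays are uint8, cast them to float (for example with X.astype(np.float64)) before computing distances, because squaring uint8 values wraps around.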
2. Implementing the kNN algorithm
2.1 The first step of kNN is to compute the L2 distance, i.e. the Euclidean distance, between two images.
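The formula is:
$$d_2(I_1, I_2) = \sqrt{\sum_p \left(I_1^p - I_2^p\right)^2}$$
where $I_1^p$ and $I_2^p$ are the values of the two images at position $p$, and the sum runs over all 32×32×3 = 3072 values.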
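Here is a tiny worked example of the formula. It uses two made-up "images" of three values each, first computed term by term and then with NumPy's built-in norm.

```python
import numpy as np

# Two toy "images" with three pixel values each (illustrative values only)
I1 = np.array([1.0, 2.0, 3.0])
I2 = np.array([4.0, 6.0, 3.0])

# d2 = sqrt(sum_p (I1[p] - I2[p])^2) = sqrt(9 + 16 + 0) = 5
d2 = np.sqrt(np.sum((I1 - I2) ** 2))
print(d2)                        # 5.0
print(np.linalg.norm(I1 - I2))   # 5.0, same value via the built-in norm
```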
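The assignment implements this computation three ways: with two loops, with one loop, and fully vectorized with no loops. Only the no-loop vectorized version is shown here.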
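For comparison, here is a sketch of the two-loop and one-loop variants, written as standalone functions. The names and signatures are hypothetical. In the assignment these are methods of the classifier that read self.X_train.

```python
import numpy as np

def distances_two_loops(X, X_train):
    """One Python-level iteration per (test, train) pair: the slowest version."""
    num_test, num_train = X.shape[0], X_train.shape[0]
    dists = np.zeros((num_test, num_train))
    for i in range(num_test):
        for j in range(num_train):
            dists[i, j] = np.sqrt(np.sum((X[i] - X_train[j]) ** 2))
    return dists

def distances_one_loop(X, X_train):
    """One iteration per test point; broadcasting covers all training rows."""
    num_test = X.shape[0]
    dists = np.zeros((num_test, X_train.shape[0]))
    for i in range(num_test):
        # X_train - X[i] subtracts test row i from every training row at once
        dists[i, :] = np.sqrt(np.sum((X_train - X[i]) ** 2, axis=1))
    return dists

rng = np.random.default_rng(0)
X_train = rng.random((20, 12))
X_test = rng.random((4, 12))
print(np.allclose(distances_two_loops(X_test, X_train),
                  distances_one_loop(X_test, X_train)))  # True
```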
Code for the vectorized L2 distance:
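```python
import numpy as np

def compute_distances_no_loops(self, X):
    """
    Compute the distance between each test point in X and each training point
    in self.X_train using no explicit loops.

    Input / Output: Same as compute_distances_two_loops
    """
    # Squared norm of every test row, shape (num_test, 1)
    M1 = np.sum(X ** 2, axis=1, keepdims=True)
    # Squared norm of every training row, shape (num_train,)
    M2 = np.sum(self.X_train ** 2, axis=1)
    # -2 * (dot product of every test/training pair), shape (num_test, num_train)
    M3 = np.multiply(np.dot(X, self.X_train.T), -2)
    # Broadcasting: (num_test, 1) + (num_train,) + (num_test, num_train)
    M = M3 + M1 + M2
    # Clip tiny negative values caused by floating-point round-off before sqrt
    dists = np.sqrt(np.maximum(M, 0))
    return dists
```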
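The key to understanding this code is to expand the squared L2 formula and think of it in matrix form:
$$d_2(I_1, I_2)^2 = \sum_p \left(I_1^p\right)^2 + \sum_p \left(I_2^p\right)^2 - 2\sum_p I_1^p I_2^p$$
Applied to every pair at once, where $X$ holds one test image per row and $X_{\text{train}}$ holds one training image per row:
$$\left(D^2\right)_{ij} = \sum_p X_{ip}^2 + \sum_p \left(X_{\text{train}}\right)_{jp}^2 - 2\left(X X_{\text{train}}^\top\right)_{ij}$$
The first term is M1 (shape (num_test, 1)), the second is M2 (shape (num_train,)), and the third is M3 (shape (num_test, num_train)). Broadcasting adds them into a (num_test, num_train) matrix. Once the shape of each matrix and the overall operation are clear, the code above is not hard to write.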
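To see why the expansion matters, the sketch below compares it with the most direct no-loop approach, which broadcasts to a (num_test, num_train, D) array of differences. Both give the same distances. The direct version, however, needs memory proportional to num_test × num_train × D: for 500 × 5000 × 3072 float64 values, that is about 61 GB. The expansion only ever builds (num_test, num_train) matrices. The function names are hypothetical.

```python
import numpy as np

def distances_broadcast(X, X_train):
    # diff has shape (num_test, num_train, D): simple but memory-hungry
    diff = X[:, None, :] - X_train[None, :, :]
    return np.sqrt(np.sum(diff ** 2, axis=2))

def distances_expanded(X, X_train):
    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y, evaluated for all pairs at once
    sq_test = np.sum(X ** 2, axis=1, keepdims=True)   # (num_test, 1)
    sq_train = np.sum(X_train ** 2, axis=1)           # (num_train,)
    cross = X @ X_train.T                             # (num_test, num_train)
    sq = sq_test + sq_train - 2 * cross               # broadcasts to (num_test, num_train)
    return np.sqrt(np.maximum(sq, 0))                 # clip round-off negatives

rng = np.random.default_rng(1)
X_train = rng.random((30, 3072))
X_test = rng.random((5, 3072))
print(np.allclose(distances_broadcast(X_test, X_train),
                  distances_expanded(X_test, X_train)))  # True
```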
2.2 Prediction function
```python
import numpy as np

def predict_labels(self, dists, k=1):
    """
    Given a matrix of distances between test points and training points,
    predict a label for each test point.

    Inputs:
    - dists: A numpy array of shape (num_test, num_train) where dists[i, j]
      gives the distance between the ith test point and the jth training point.

    Returns:
    - y: A numpy array of shape (num_test,) containing predicted labels for the
      test data, where y[i] is the predicted label for the test point X[i].
    """
    num_test = dists.shape[0]
    y_pred = np.zeros(num_test)
    for i in range(num_test):
        # Labels of the k training points closest to test point i, nearest first
        idx = np.argsort(dists[i, :])
        closest_y = self.y_train[idx[:k]]
        # Majority vote: count how often each neighbor label appears
        count = 0
        label = 0
        for j in closest_y:
            tmp = 0
            for kk in closest_y:
                tmp += (kk == j)
            # Strict '>' keeps, on a tie, the label whose nearest occurrence comes first
            if tmp > count:
                count = tmp
                label = j
        y_pred[i] = label
    return y_pred
```
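The nested voting loop above breaks ties in favor of the label whose neighbor is nearest. If ties should instead go to the smaller label (as far as I can tell, this is what the assignment's starter comments ask for, so check your copy), np.bincount plus np.argmax does the vote in one step, because argmax returns the first maximum. This is a standalone sketch. predict_labels_bincount is a hypothetical name, and it takes y_train as an argument instead of reading self.y_train.

```python
import numpy as np

def predict_labels_bincount(dists, y_train, k=1):
    """Majority vote over the k nearest neighbors; ties go to the smaller label."""
    num_test = dists.shape[0]
    y_pred = np.zeros(num_test, dtype=y_train.dtype)
    for i in range(num_test):
        nearest = np.argsort(dists[i])[:k]        # indices of the k closest training points
        votes = np.bincount(y_train[nearest])     # votes[c] = number of neighbors with label c
        y_pred[i] = np.argmax(votes)              # first (smallest) label among the maxima
    return y_pred

# Toy example: 2 test points, 4 training points with labels 2, 1, 1, 0
dists = np.array([[0.1, 0.2, 0.3, 5.0],
                  [4.0, 0.5, 0.4, 0.3]])
y_train = np.array([2, 1, 1, 0])
print(predict_labels_bincount(dists, y_train, k=2))  # [1 0]
print(predict_labels_bincount(dists, y_train, k=3))  # [1 1]
```

With k = 2, the first test point gets one vote each for labels 2 and 1. The bincount version picks 1, whereas the nested loop above would pick 2 because that neighbor is nearer.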
3. Cross-validation
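Cross-validation is needed because the test set must not be used for validation. The test set is precious: it measures how well the final model generalizes, so it should not be touched while tuning. The hyperparameters (here, k) still have to be tuned, so we tune them on the training set instead.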
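For every value of k the procedure is the same. Split the training set into five folds, then take one fold as the validation data and the other four as the training data. Repeat this five times so that each fold is used for validation once. The result is five accuracies for each value of k.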
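The fold bookkeeping can be checked on a toy array of ten samples (illustrative data only). np.array_split cuts it into five folds of two. In each round one fold is held out while np.concatenate joins the remaining four, so every sample lands in the validation set exactly once.

```python
import numpy as np

X = np.arange(20).reshape(10, 2)   # ten toy samples, two features each
y = np.arange(10)                  # toy labels 0..9, one per sample
num_folds = 5
X_folds = np.array_split(X, num_folds)
y_folds = np.array_split(y, num_folds)

held_out = []
for i in range(num_folds):
    # Fold i is the validation set; the other folds are joined for training
    X_tr = np.concatenate(X_folds[:i] + X_folds[i + 1:])
    y_tr = np.concatenate(y_folds[:i] + y_folds[i + 1:])
    X_val, y_val = X_folds[i], y_folds[i]
    held_out.append(y_val)
    print(f"round {i}: validation labels {y_val.tolist()}, training size {len(y_tr)}")
```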
```python
import numpy as np

num_folds = 5
k_choices = [1, 3, 5, 8, 10, 12, 15, 20, 50, 100]

# Split the training data into num_folds (nearly) equal folds
X_train_folds = np.array_split(X_train, num_folds)
y_train_folds = np.array_split(y_train, num_folds)

# k_to_accuracies[k] holds the num_folds validation accuracies for that k
k_to_accuracies = {}
for k in k_choices:
    k_to_accuracies[k] = np.zeros(num_folds)
    for i in range(num_folds):
        # Fold i validates; the remaining folds are joined into one training set
        Training_Data = np.concatenate(X_train_folds[:i] + X_train_folds[i + 1:])
        Training_Label = np.concatenate(y_train_folds[:i] + y_train_folds[i + 1:])
        Validation_Data = X_train_folds[i]
        Validation_Label = y_train_folds[i]

        # classifier is the KNearestNeighbor instance created earlier in the notebook
        classifier.train(Training_Data, Training_Label)
        yte_Pre = classifier.predict(Validation_Data, k=k)
        num_correct = np.sum(yte_Pre == Validation_Label)
        accuracy = float(num_correct) / len(Validation_Label)
        k_to_accuracies[k][i] = accuracy
```
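This produces a chart of the cross-validation accuracies for each k:
[Figure: cross-validation accuracy vs. k]
At around …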
Got 147 / 500 correct => accuracy: 0.294000
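Putting the pieces together, the sketch below runs the whole procedure on synthetic data. It uses 5-fold cross-validation to choose k, retrains on the full training set with that k, scores the test set, and prints a line in the same format as the result above. The data are Gaussian blobs standing in for CIFAR-10, and knn_predict, choose_k and make_blobs are hypothetical helpers, so the printed accuracy says nothing about CIFAR-10.

```python
import numpy as np

def knn_predict(X_train, y_train, X, k):
    """Minimal kNN: vectorized L2 distances, then a majority vote
    (ties go to the smaller label via bincount + argmax)."""
    sq = (np.sum(X ** 2, axis=1, keepdims=True)
          + np.sum(X_train ** 2, axis=1) - 2 * X @ X_train.T)
    dists = np.sqrt(np.maximum(sq, 0))
    nearest = np.argsort(dists, axis=1)[:, :k]       # (num_test, k) neighbor indices
    return np.array([np.bincount(y_train[row]).argmax() for row in nearest])

def choose_k(X_train, y_train, k_choices, num_folds=5):
    """Return the k with the highest mean cross-validation accuracy."""
    X_folds = np.array_split(X_train, num_folds)
    y_folds = np.array_split(y_train, num_folds)
    mean_acc = {}
    for k in k_choices:
        accs = []
        for i in range(num_folds):
            X_tr = np.concatenate(X_folds[:i] + X_folds[i + 1:])
            y_tr = np.concatenate(y_folds[:i] + y_folds[i + 1:])
            pred = knn_predict(X_tr, y_tr, X_folds[i], k)
            accs.append(np.mean(pred == y_folds[i]))
        mean_acc[k] = np.mean(accs)
    return max(mean_acc, key=mean_acc.get), mean_acc

def make_blobs(rng, centers, n):
    """Draw n noisy points around randomly chosen centers; returns (X, labels)."""
    labels = rng.integers(0, len(centers), size=n)
    return centers[labels] + rng.normal(size=(n, centers.shape[1])), labels

# Synthetic 3-class data in 10 dimensions, a stand-in for CIFAR-10
rng = np.random.default_rng(0)
centers = rng.normal(0, 3, size=(3, 10))
X_train, y_train = make_blobs(rng, centers, 500)
X_test, y_test = make_blobs(rng, centers, 100)

best_k, mean_acc = choose_k(X_train, y_train, [1, 3, 5, 8, 10])
# Retrain on the whole training set with the chosen k, then score the test set
y_pred = knn_predict(X_train, y_train, X_test, best_k)
num_correct = int(np.sum(y_pred == y_test))
print('Got %d / %d correct => accuracy: %f'
      % (num_correct, len(y_test), num_correct / len(y_test)))
```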