kNN在CIFAR10上的应用

来源:互联网 发布:centos 6.8 docker 编辑:程序博客网 时间:2024/06/05 06:13

1. 获取CIFAR10

CIFAR10是一个10分类的图片数据集,主页在这里,作者使用python版本的数据集。

2. 加载数据集

在主页上已有加载数据集的代码,数据集分成了5个训练用的batch和1个test batch,每个batch有10000张32x32x3的图片,还有一个batches.meta文件装着label对应的名字。


不妨贴出我的代码:

def load_data(root, batch):    ''' @brief: There are 5 batches and a test-batch        in ../datasets/cifar-10. 每个batch打开有key:['data,        labels, batch_label, filenames']        @param batch: batch-n/test-batch    '''    batch_path = os.path.join(root, batch)    with open(batch_path, 'rb') as f:        dataset = pickle.load(f)    return datasetdef load_label_names(root):    ''' @brief: 装载batches.meta,包含了label_names '''    meta_path = os.path.join(root, 'batches.meta')    with open(meta_path, 'rb') as f:        meta = pickle.load(f)    return meta['label_names']


在dataset这个dict里最有用的是data和labels两个key,分别对应10000x3072的图像数据和10000个标签。


3. kNN

kNN的思想是对需要确定类别的数据,在已知类别的数据集上找到与它距离最近的k个数据,根据这k个数据各自属于的类别对新数据的类别进行投票,少数服从多数。就像这样(百度百科贴过来的):


写成代码就像这样:

class kNN:    ''' 实现kNN分类器 '''    def __init__(self):        self.Xtr = None        self.Ytr = None    def __init__(self, X, Y):        self.Xtr = X        self.Ytr = Y    def train(self, X, Y):        self.Xtr = X        self.Ytr = Y    def predict(self, x, k=1):        distances = np.sum((self.Xtr - x)**2, axis=1)        k_labels = [self.Ytr[x] for x in np.argsort(distances)][:k]        u, counts = np.unique(k_labels, return_counts=True)        return u[np.argmax(counts)]


嘛,其实主要的东西都在predict里,这样写只是个套路。

kNN能够设置的参数就是两个,距离度量和k值,作者写的距离是欧几里得距离,就是相减平方加和,也可以用其他距离试试。

k值可以通过实验确定,作者先在一个batch上玩玩,将batch分为训练集、验证集和测试集,分割比例是7:2:1,先通过验证集确定一个较好的k值。

代码:

    acc_vs_k = []    knn = kNN(train_data, train_labels)    k_list = range(1,11)    acc_list = []    for k in k_list:        correct_num = 0        now = time.time()        for i in xrange(val_size):            pred_label = knn.predict(val_data[i], k)            true_label = val_labels[i]            if pred_label == true_label:                correct_num += 1        acc = correct_num * 1.0 / val_size        acc_list.append(acc)

把k值和accuracy对应的图显示出来


在1~10这个范围内最优的k值是5