Recognizing a compressed MNIST dataset with a kd-tree

Source: Internet · Site: 程序博客网 · 2024/06/06 09:26

In "Plain kNN on the MNIST dataset (code)", MNIST digits were recognized with brute-force kNN. Compared with a neural network, kNN is much slower: classifying 1000 images took 234 s. A kd-tree is a more efficient way to implement kNN: it stores the training samples in a binary tree, which makes searching for the k nearest neighbors much faster. The concrete algorithm is shown in the figures below.

[figures: kd-tree construction and k-nearest-neighbor search algorithm]
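As a minimal standalone illustration of what those figures describe (this is a simplified sketch, not the article's MNIST code, which follows below): the tree is built by splitting on the median along a dimension that cycles with depth, and a nearest-neighbor query descends to a leaf and then backtracks, crossing a splitting plane only when it could hide a closer point.

```python
import numpy as np

def build(points, depth=0):
    """Recursively split points on the median of one coordinate,
    cycling through dimensions by depth."""
    if not points:
        return None
    d = depth % len(points[0])               # dimension to split on
    points = sorted(points, key=lambda p: p[d])
    m = len(points) // 2                     # median index
    return {
        'point': points[m],
        'left':  build(points[:m], depth + 1),
        'right': build(points[m + 1:], depth + 1),
    }

def nearest(tree, target, depth=0, best=None):
    """Return the stored point closest to target (squared Euclidean)."""
    if tree is None:
        return best
    p = np.array(tree['point'])
    if best is None or np.sum((p - target) ** 2) < np.sum((np.array(best) - target) ** 2):
        best = tree['point']
    d = depth % len(tree['point'])
    # visit the side of the splitting plane that contains the target first
    near, far = ((tree['left'], tree['right']) if target[d] < tree['point'][d]
                 else (tree['right'], tree['left']))
    best = nearest(near, target, depth + 1, best)
    # backtrack: only cross the plane if a closer point could lie beyond it
    if (target[d] - tree['point'][d]) ** 2 < np.sum((np.array(best) - target) ** 2):
        best = nearest(far, target, depth + 1, best)
    return best
```

For example, on the points (2,3), (5,4), (9,6), (4,7), (8,1), (7,2), querying (9,2) returns (8,1).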

For MNIST, each image is 28*28, and any pixel with a gray value above 3 can be counted as part of a handwritten stroke. In "Plain kNN on the MNIST dataset (code)", binarizing the gray values therefore worked quite well. But a kd-tree splits regions at the median of each dimension, and if every dimension only takes the values 0 and 1, there is no way to split. So the binarized matrix is summed by row and by column, giving a new 56-dimensional feature vector. This amounts to a dimensionality reduction that also enlarges the range of values in each dimension. The downside is that the images' originally good translation invariance gets worse, so the recognition rate will probably drop. No matter; what matters here is the efficiency comparison.
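The feature extraction just described can be sketched in a few lines of numpy. The stroke pattern here is made up for illustration; in the article the images come from the MNIST files read by extract_images in the code below.

```python
import numpy as np

# Hypothetical 28x28 grayscale digit (values 0-255): a crude vertical stroke.
img = np.zeros((28, 28), dtype=np.uint8)
img[5:20, 10:14] = 200

binary = np.minimum(img, 1)              # binarize: any nonzero pixel -> 1
row_sums = binary.sum(axis=1)            # 28 values: ink pixels per row
col_sums = binary.sum(axis=0)            # 28 values: ink pixels per column
features = np.concatenate([row_sums, col_sums])  # the 56-dimensional feature
```

Each of the 56 entries can range from 0 to 28, so the median split now has something to work with.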
The code is below; the kd-tree construction and search are modeled on the code at http://www.hankcs.com/ml/k-nearest-neighbor-method.html.

# coding:utf-8
import numpy as np
import gzip
import itertools
from datetime import datetime


def _read32(bytestream):
    dt = np.dtype(np.uint32).newbyteorder('>')
    return np.frombuffer(bytestream.read(4), dtype=dt)[0]


def extract_images(input_file, is_value_binary, is_matrix):
    with gzip.open(input_file, 'rb') as zipf:
        magic = _read32(zipf)
        if magic != 2051:
            raise ValueError('Invalid magic number %d in MNIST image file: %s' % (magic, input_file))
        num_images = _read32(zipf)
        rows = _read32(zipf)
        cols = _read32(zipf)
        print(magic, num_images, rows, cols)
        buf = zipf.read(rows * cols * num_images)
        data = np.frombuffer(buf, dtype=np.uint8)
        # reshape to (num_images, rows, cols)
        data = data.reshape(num_images, rows, cols)
        # binarize
        data_value_binary = np.minimum(data, 1)
        # row sums go into the first 28 elements, column sums into the next 28
        # (if classification is poor, diagonal sums, determinants, etc. could be added)
        # one extra column so train_x can also store the label
        data_tidy = np.zeros((num_images, rows + cols + 1), dtype=np.uint32)
        for i in range(num_images):
            data_tidy[i, :rows] = np.sum(data_value_binary[i], axis=1)
            data_tidy[i, rows:(rows + cols)] = np.sum(data_value_binary[i].transpose(), axis=1)
        return data_tidy


# extract the labels, modeled on mnist.py in tensorflow
def extract_labels(input_file):
    with gzip.open(input_file, 'rb') as zipf:
        magic = _read32(zipf)
        if magic != 2049:
            raise ValueError('Invalid magic number %d in MNIST label file: %s' % (magic, input_file))
        num_items = _read32(zipf)
        buf = zipf.read(num_items)
        labels = np.frombuffer(buf, dtype=np.uint8)
        return labels


class node:
    def __init__(self, point, label):
        self.left = None
        self.right = None
        self.point = point
        # building the tree shuffles the sample order, so the label is stored in the tree too
        self.label = label
        self.parent = None

    def set_left(self, left):
        if left is None:
            return
        left.parent = self
        self.left = left

    def set_right(self, right):
        if right is None:
            return
        right.parent = self
        self.right = right


def median(lst):
    m = len(lst) // 2
    return lst[m], m


def build_kdtree(data, d):
    data = sorted(data, key=lambda x: x[next(d)])
    p, m = median(data)
    tree = node(p[:-1], p[-1])
    del data[m]
    # recursively place new nodes, which also recursively partitions the space
    if m > 0:
        tree.set_left(build_kdtree(data[:m], d))
    if len(data) > 1:
        tree.set_right(build_kdtree(data[m:], d))
    return tree


# squared Euclidean distance
def distance(a, b):
    diff = a - b
    squaredDiff = diff ** 2
    return np.sum(squaredDiff)


def search_kdtree(tree, d, target, k):
    den = next(d)
    # descend until no closer child exists
    if target[den] < tree.point[den]:
        if tree.left is not None:
            return search_kdtree(tree.left, d, target, k)
    else:
        if tree.right is not None:
            return search_kdtree(tree.right, d, target, k)

    # keep the k nearest nodes found so far up to date
    def update_best(t, best):
        if t is None:
            return
        label = t.label
        t = t.point
        d = distance(t, target)
        for i in range(k):
            if d < best[i][1]:
                for j in range(0, i):
                    best[j][1] = best[j + 1][1]
                    best[j][0] = best[j + 1][0]
                    best[j][2] = best[j + 1][2]
                best[i][1] = d
                best[i][0] = t
                best[i][2] = label

    # sentinel entries: huge distance, invalid label 10
    best = []
    for i in range(k):
        best.append([tree.point, 100000.0, 10])
    while tree.parent is not None:
        update_best(tree.parent.left, best)
        update_best(tree.parent.right, best)
        tree = tree.parent
    return best


def testHandWritingClass():
    ## step 1: load data
    print("step 1: load data...")
    train_x = extract_images('data/mnist/train_images', True, True)
    train_y = extract_labels('data/mnist/train_labels')
    test_x = extract_images('data/mnist/test_images', True, True)
    test_y = extract_labels('data/mnist/test_labels')
    l = min(train_x.shape[0], train_y.shape[0])
    rows = train_x.shape[1]
    # store the training labels in train_x so they go into the kd-tree together
    for i in range(l):
        train_x[i, -1] = train_y[i]
    densim = itertools.cycle(range(0, rows - 1))

    ## step 2: training
    print("step 2: build tree...")
    mnist_tree = build_kdtree(train_x, densim)

    ## step 3: testing
    print("step 3: testing...")
    a = datetime.now()
    numTestSamples = test_x.shape[0]
    matchCount = 0
    test_num = numTestSamples
    K = 3
    for i in range(test_num):
        best_k = search_kdtree(mnist_tree, densim, test_x[i, :-1], K)
        # majority vote over the k labels
        classCount = {}
        for j in range(K):
            voteLabel = best_k[j][2]
            classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
        maxCount = 0
        predict = 0
        for key, value in classCount.items():
            if value > maxCount:
                maxCount = value
                predict = key
        if predict == test_y[i]:
            matchCount += 1
        if i % 100 == 0:
            print("finished %d images" % i)
    accuracy = float(matchCount) / test_num
    b = datetime.now()
    print("total running time: %d seconds" % (b - a).seconds)

    ## step 4: show the result
    print("step 4: show the result...")
    print('The classify accuracy is: %.2f%%' % (accuracy * 100))


if __name__ == '__main__':
    testHandWritingClass()

As before, with k = 3, the results for classifying 10000 images are:

[figure: timing and accuracy screenshot]

The speed improved by a factor of 585. Part of that comes from each image shrinking from 784 dimensions to 56, but the kd-tree itself being faster than brute-force kNN is also beyond doubt.
Because of the compression of the data, the accuracy dropped considerably.
