kd树识别压缩有的mnist数据集
来源:互联网 发布:红叶知弦小说 编辑:程序博客网 时间:2024/06/06 09:26
在《一般knn算法识别mnist数据集(代码)》 中用一般的knn方法做了mnist识别。和神经网络方法比起来,knn慢很多,识别1000张图片需要234s。kd树更高效率实现knn的方法,它用二叉树来存储训练集中的样本,搜索k个近邻点时速度更快。具体算法如下:
对mnist,每张图片是28*28的,且灰度值大于3的就算是手写的痕迹了,因此,在一般knn算法识别mnist数据集(代码)中,通过二值化(one hot)灰度值的方式,效果还不错。但是,由于kd树通过各维的中值来切分区域,如果每一维的数值只有0和1,根本没办法切分。因此,将二值化后的矩阵分别按行求和,按列求和,得到新的56维的特征。相当于做了一个降维,同时,增大各维的取值范围。但是,这样处理后,图片原本良好的平移不变形变差了。识别率估计会降低。无妨,这里重要要的是效率的对比。
代码如下,其中,kd树的建立和搜索仿照了http://www.hankcs.com/ml/k-nearest-neighbor-method.html 中的代码。
代码如下:
#coding:utf-8import numpy as npimport gzipimport itertoolsfrom datetime import datetimedef _read32(bytestream): dt = np.dtype(np.uint32).newbyteorder('>') return np.frombuffer(bytestream.read(4), dtype=dt)[0]def extract_images(input_file, is_value_binary, is_matrix): with gzip.open(input_file, 'rb') as zipf: magic = _read32(zipf) if magic !=2051: raise ValueError('Invalid magic number %d in MNIST image file: %s' %(magic, input_file.name)) num_images = _read32(zipf) rows = _read32(zipf) cols = _read32(zipf) print magic, num_images, rows, cols buf = zipf.read(rows * cols * num_images) data = np.frombuffer(buf, dtype=np.uint8) #reshape成二维 data = data.reshape(num_images, rows, cols) #二值化 data_value_binary = np.minimum(data, 1) #按行相加,存到钱28个元素中,按列相加,存入后28个元素中 #如果分类效果不好,可再计算按对角线相加、行列式等 #多加了一列,以便train_x存储标签用。 data_tidy = np.zeros((num_images, rows + cols + 1), dtype=np.uint32) for i in range(num_images): data_tidy[i, :rows] = np.sum(data_value_binary[i], axis=1) data_tidy[i, rows:(rows+cols)] = (np.sum(data_value_binary[i].transpose(), axis=1)) return data_tidy#抽取标签#仿照tensorflow中mnist.py写的def extract_labels(input_file): with gzip.open(input_file, 'rb') as zipf: magic = _read32(zipf) if magic != 2049: raise ValueError('Invalid magic number %d in MNIST label file: %s' % (magic, input_file.name)) num_items = _read32(zipf) buf = zipf.read(num_items) labels = np.frombuffer(buf, dtype=np.uint8) return labelsclass node: def __init__(self, point, label): self.left = None self.right = None self.point = point self.label = label #由于按树存储的时候数据点顺序打乱了,这里将label也存进树里面。 self.parent = None pass def set_left(self, left): if left == None: pass left.parent = self self.left = left def set_right(self, right): if right == None: pass right.parent = self self.right = rightdef median(lst): m = len(lst) / 2 return lst[m], mdef build_kdtree(data, d): data = sorted(data, key=lambda x: x[d.next()]) p, m = median(data) tree = node(p[:-1], p[-1]) del data[m] #递归查询新节点该存放的位置,同时也递归的切分区域 if m > 0: tree.set_left(build_kdtree(data[:m], d)) if len(data) > 1: tree.set_right(build_kdtree(data[m:], d)) return tree#计算距离def distance(a, b): diff = a - b squaredDiff = diff ** 2 return np.sum(squaredDiff)def search_kdtree(tree, d, target, k): den = d.next() #直到搜索到不存在更近的节点才停止。 if target[den] < tree.point[den]: if tree.left != None: return search_kdtree(tree.left, d, target, k) else: if tree.right != None: return search_kdtree(tree.right, d, target, k) #持续更新距离最近的k个节点 def update_best(t, best): if t == None: return label = t.label t = t.point d = distance(t, target) for i in range(k): if d < best[i][1]: for j in range(0, i): best[j][1] = best[j+1][1] best[j][0] = best[j+1][0] best[j][2] = best[j+1][2] best[i][1] = d best[i][0] = t best[i][2] = label best = [] for i in range(k): best.append( [tree.point, 100000.0, 10] ) while (tree.parent != None): update_best(tree.parent.left, best) update_best(tree.parent.right, best) tree = tree.parent return bestdef testHandWritingClass(): ## step 1: load data print "step 1: load data..." train_x = extract_images('data/mnist/train_images', True, True) train_y = extract_labels('data/mnist/train_labels') test_x = extract_images('data/mnist/test_images', True, True) test_y = extract_labels('data/mnist/test_labels') l = min(train_x.shape[0], train_y.shape[0]) rows = train_x.shape[1] #将训练集的标签存到train_x中,一遍一同存储到kd树中。 for i in range(l): train_x[i, -1] = train_y[i] densim = itertools.cycle(range(0, rows-1)) ## step 2: training... print "step 2: build tree..." mnist_tree = build_kdtree(train_x, densim) ## step 3: testing print "step 3: testing..." a = datetime.now() numTestSamples = test_x.shape[0] matchCount = 0 test_num = numTestSamples K = 3 for i in xrange(test_num): best_k = search_kdtree(mnist_tree, densim, test_x[i, :-1], K) #计算数量最大的label。 classCount = {} for j in range(K): voteLabel = best_k[j][2] classCount[voteLabel] = classCount.get(voteLabel, 0) + 1 maxCount = 0 predict = 0 for key, value in classCount.items(): if value > maxCount: maxCount = value predict = key if predict == test_y[i]: matchCount += 1 if i % 100 == 0: print "完成%d张图片"%(i) accuracy = float(matchCount) / test_num b = datetime.now() print "一共运行了%d秒"%((b-a).seconds) ## step 4: show the result print "step 4: show the result..." print 'The classify accuracy is: %.2f%%' % (accuracy * 100)if __name__ == '__main__': testHandWritingClass()
同样,k=3, 识别10000张的结果如下:
速度快了585倍,这其中每张图片从784维变成了56维,速度有加快,但kd树本身比knn快也是确定无疑的。
由于数据的压缩,准确率下降了好多。
阅读全文
0 0
- kd树识别压缩有的mnist数据集
- 手写字体识别 --MNIST数据集
- 经典手写数字mnist数据集识别
- tensorflow实现MNIST数据集识别
- tensorflow mnist数据集手写字识别
- 识别MNIST数据集之(一):读取数据
- 识别MNIST数据集之(一):读取数据
- TensorFlow的MNIST数据识别
- 使用tensorflow对Mnist数据集进行字体识别
- tensorflow下对MNIST数据集进行识别的程序代码
- 识别MNIST数据集:用Python实现神经网络
- 一般knn算法识别mnist数据集(代码)
- 基于MNIST数据集实现车牌识别--初步演示版
- 机器学习算法之KNN识别mnist数据集
- Tensorflow入门-简单神经网络进行MNIST数据集识别
- tensorflow入门之mnist手写数据集识别
- tensorflow 使用CNN 进行mnist数据集识别
- caffe-mnist数据识别loss accuracy曲线
- Java 通过Xml导出Excel文件,Java Excel 导出工具类,Java导出Excel工具类
- 用python简单实现mysql数据同步到ElasticSearch
- Web单选下拉列表与多选下拉列表的清除
- 如何快速转载CSDN中的博客
- Android 中判断为空的简单语句
- kd树识别压缩有的mnist数据集
- k近邻算法(knn) 学习
- jetson tk1开发(1)-开箱
- 最全Pycharm教程(33)——使用Pycharm编写IPython Notebook文件
- SSH框架入门(1)——struts2(1)
- Cookie笔记
- HDOJ1045
- [Mysql] 防御和检查SQL注入攻击的手段
- 机器学习之K-means聚类算法