Recognizing the MNIST dataset with plain KNN (code)
Source: Internet | Editor: 程序博客网 | Time: 2024/06/09 20:58
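I originally wanted to try KNN with a kd tree, but a dataset like MNIST is hard to split dimension by dimension. After printing the data, it looks as if any pixel with a gray value above 3 is part of a handwritten stroke, while most pixels are 0. That makes splitting each dimension at its median impractical. So I'll first run plain (brute-force) KNN to see its accuracy and speed. Later I'll find a way to transform MNIST so I can try a kd tree. The basic idea of KNN is covered in the post "k-means、GMM聚类、KNN原理概述" (an overview of k-means, GMM clustering and KNN), and a more complete explanation is at http://www.hankcs.com/ml/k-nearest-neighbor-method.html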
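A kd-tree node usually splits its points at the median of one coordinate. The sketch below is my reading of why that breaks down here. It uses ten made-up 4-pixel "images" (synthetic data, not MNIST), and `median_split` is a hypothetical helper. When most values in a pixel column are blank, the median is 0 and the "halves" become lopsided.

```python
from statistics import median


def median_split(points, dim):
    """Split points at the median of coordinate `dim`, as one kd-tree node would."""
    m = median(p[dim] for p in points)
    left = [p for p in points if p[dim] <= m]   # everything at or below the median
    right = [p for p in points if p[dim] > m]   # everything strictly above it
    return m, left, right


# Ten synthetic 4-pixel "images": pixel 0 carries ink in only two of them.
points = [
    (0, 0, 7, 0), (0, 0, 0, 0), (0, 5, 0, 0), (0, 0, 0, 0), (0, 0, 0, 9),
    (0, 0, 0, 0), (0, 0, 0, 0), (0, 0, 0, 0), (200, 0, 0, 0), (255, 0, 0, 0),
]
m, left, right = median_split(points, 0)
print(m, len(left), len(right))   # 0.0 8 2
```

Instead of a 5/5 split, eight points (all blank in pixel 0) land on one side. Every other pixel column here behaves the same way.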
Below is the code for recognizing MNIST with KNN. It downloads and extracts MNIST, runs the KNN test, and measures the time taken to classify 1000 test images.
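The extraction step checks the magic numbers 2051 (images) and 2049 (labels) at the start of each file. In the IDX format used by MNIST, the header is a big-endian 32-bit magic number. Its third byte is the element type (0x08 = unsigned byte) and its fourth byte is the number of dimensions. One big-endian 32-bit size per dimension follows. Here is a minimal sketch of decoding that header with Python's `struct` module, as an alternative to the `np.frombuffer` approach in the code. The function `parse_idx_header` is hypothetical and expects already un-gzipped bytes.

```python
import struct


def parse_idx_header(raw):
    """Decode an IDX header from un-gzipped bytes: (magic, type code, dims)."""
    (magic,) = struct.unpack('>I', raw[:4])   # big-endian uint32
    type_code = (magic >> 8) & 0xFF           # third byte: element type
    ndim = magic & 0xFF                       # fourth byte: number of dimensions
    dims = struct.unpack('>%dI' % ndim, raw[4:4 + 4 * ndim])
    return magic, type_code, dims


# 2051 == 0x00000803: unsigned bytes, 3 dimensions (count, rows, cols)
print(parse_idx_header(struct.pack('>IIII', 2051, 60000, 28, 28)))  # (2051, 8, (60000, 28, 28))
# 2049 == 0x00000801: unsigned bytes, 1 dimension (count)
print(parse_idx_header(struct.pack('>II', 2049, 10000)))            # (2049, 8, (10000,))
```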
# coding: utf-8
import gzip
import os
import urllib.request
from datetime import datetime

import numpy as np

SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
TEST_LABELS = 't10k-labels-idx1-ubyte.gz'


# Download an MNIST file, modeled on base.py in TensorFlow.
# If the download fails, fetch the four .gz files manually and save them
# under the names used in the main block.
def maybe_download(filename, path, source_url):
    if not os.path.exists(path):
        os.makedirs(path)
    filepath = os.path.join(path, filename)
    if not os.path.exists(filepath):
        urllib.request.urlretrieve(source_url, filepath)
    return filepath


# Read one big-endian 32-bit unsigned integer; used for the magic number,
# the number of images and the image size. Modeled on TensorFlow's mnist.py.
def _read32(bytestream):
    dt = np.dtype(np.uint32).newbyteorder('>')
    return int(np.frombuffer(bytestream.read(4), dtype=dt)[0])


# Extract images. If is_value_binary is True, gray values are binarized
# (any nonzero pixel becomes 1). If is_matrix is True, each image is
# flattened into one row of a matrix; otherwise a (num, rows, cols) tensor
# is returned. Modeled on TensorFlow's mnist.py.
def extract_images(input_file, is_value_binary, is_matrix):
    with gzip.open(input_file, 'rb') as zipf:
        magic = _read32(zipf)
        if magic != 2051:
            raise ValueError('Invalid magic number %d in MNIST image file: %s'
                             % (magic, input_file))
        num_images = _read32(zipf)
        rows = _read32(zipf)
        cols = _read32(zipf)
        print(magic, num_images, rows, cols)
        buf = zipf.read(rows * cols * num_images)
        data = np.frombuffer(buf, dtype=np.uint8)
        if is_matrix:
            data = data.reshape(num_images, rows * cols)
        else:
            data = data.reshape(num_images, rows, cols)
        if is_value_binary:
            return np.minimum(data, 1)
        return data


# Extract labels. Modeled on TensorFlow's mnist.py.
def extract_labels(input_file):
    with gzip.open(input_file, 'rb') as zipf:
        magic = _read32(zipf)
        if magic != 2049:
            raise ValueError('Invalid magic number %d in MNIST label file: %s'
                             % (magic, input_file))
        num_items = _read32(zipf)
        buf = zipf.read(num_items)
        return np.frombuffer(buf, dtype=np.uint8)


# Plain KNN: compute the Euclidean distance from newInput to every training
# image at once, take the k closest images, and return the label that
# occurs most often among them.
# Adapted from http://blog.csdn.net/zouxy09/article/details/16955347
def kNNClassify(newInput, dataSet, labels, k):
    # Cast the query to a signed type: uint8 - uint8 would wrap around.
    # uint8 dataSet minus int32 newInput is promoted to int32.
    newInput = newInput.reshape(1, -1).astype(np.int32)
    # Broadcasting subtracts newInput from every row of dataSet
    # (same result as tiling newInput numSamples times).
    diff = dataSet - newInput
    squaredDist = np.sum(diff ** 2, axis=1)  # sum over each row
    distance = squaredDist ** 0.5
    sortedDistIndices = np.argsort(distance)

    classCount = {}  # label -> number of votes
    for i in range(k):
        # take the k nearest images and count how often each label occurs
        voteLabel = labels[sortedDistIndices[i]]
        classCount[voteLabel] = classCount.get(voteLabel, 0) + 1

    # return the label with the most votes
    maxCount = 0
    maxIndex = 0
    for key, value in classCount.items():
        if value > maxCount:
            maxCount = value
            maxIndex = key
    return maxIndex


# Main routine: load the images, then classify the handwritten digits.
# Adapted from http://blog.csdn.net/zouxy09/article/details/16955347
def testHandWritingClass():
    ## step 1: load data
    print("step 1: load data...")
    train_x = extract_images('data/mnist/train_images', True, True)
    train_y = extract_labels('data/mnist/train_labels')
    test_x = extract_images('data/mnist/test_images', True, True)
    test_y = extract_labels('data/mnist/test_labels')

    ## step 2: training (KNN has no training phase; it just keeps the data)
    print("step 2: training...")

    ## step 3: testing on the first tenth of the test set (1000 images)
    print("step 3: testing...")
    a = datetime.now()
    numTestSamples = test_x.shape[0]
    matchCount = 0
    test_num = numTestSamples // 10
    for i in range(test_num):
        predict = kNNClassify(test_x[i], train_x, train_y, 3)
        if predict == test_y[i]:
            matchCount += 1
        if i % 100 == 0:
            print("Finished %d images" % i)
    accuracy = float(matchCount) / test_num
    b = datetime.now()
    print("Total running time: %d seconds" % (b - a).total_seconds())

    ## step 4: show the result
    print("step 4: show the result...")
    print('The classify accuracy is: %.2f%%' % (accuracy * 100))


if __name__ == '__main__':
    # Each archive is saved under a short name; gzip.open still reads it.
    maybe_download('train_images', 'data/mnist', SOURCE_URL + TRAIN_IMAGES)
    maybe_download('train_labels', 'data/mnist', SOURCE_URL + TRAIN_LABELS)
    maybe_download('test_images', 'data/mnist', SOURCE_URL + TEST_IMAGES)
    maybe_download('test_labels', 'data/mnist', SOURCE_URL + TEST_LABELS)
    testHandWritingClass()
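The two loops at the end of `kNNClassify` (count the votes, then find the largest count) can be collapsed with `collections.Counter`. This is a sketch, and `majority_vote` is a hypothetical helper name. `Counter.most_common` keeps equal counts in first-seen order. So if the neighbor labels are passed nearest-first, a tie goes to the closest neighbor. That matches what the original dict loop does under Python 3.7+, where dicts preserve insertion order.

```python
from collections import Counter


def majority_vote(neighbor_labels):
    """Most common label among the k nearest neighbors, listed nearest-first."""
    # most_common(1) returns [(label, count)]; ties keep first-seen order
    return Counter(neighbor_labels).most_common(1)[0][0]


print(majority_vote([3, 5, 3]))  # 3
print(majority_vote([7, 2, 9]))  # 7 (three-way tie, nearest neighbor wins)
```

Inside `kNNClassify`, the voting section would then reduce to `return majority_vote(labels[sortedDistIndices[:k]].tolist())`.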
The output is as follows:
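step 1: load data...
2051 60000 28 28
2051 10000 28 28
step 2: training...
step 3: testing...
Finished 0 images
Finished 100 images
Finished 200 images
Finished 300 images
Finished 400 images
Finished 500 images
Finished 600 images
Finished 700 images
Finished 800 images
Finished 900 images
Total running time: 234 seconds
step 4: show the result...
The classify accuracy is: 96.20%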
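Classifying 1000 test images took 234 seconds, because each test image is compared against all 60,000 training images. That is slower than a simple CNN. The accuracy of 96.2% (measured on the first 1000 of the 10,000 test images) beats only softmax regression, which reaches just 92%. A multilayer perceptron can reach 98% accuracy, trains quickly, and is even faster at test time.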
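Each call to `kNNClassify` computes 60,000 distances and fully sorts them, one test image at a time. A common alternative, not what this post does, is to compute the distances for a whole batch of test images with one matrix product. This uses the expansion ||a - b||^2 = ||a||^2 - 2 a·b + ||b||^2. The k smallest are then picked with `np.argpartition` instead of a full sort. The function name `knn_predict_batch` and the default `batch_size=200` are my own choices for this sketch.

```python
from collections import Counter

import numpy as np


def knn_predict_batch(test_x, train_x, train_y, k=3, batch_size=200):
    """Brute-force KNN for many test rows at once (sketch)."""
    train = train_x.astype(np.float64)        # float avoids uint8 wrap-around
    train_sq = np.sum(train ** 2, axis=1)     # ||b||^2 for every training row
    predictions = []
    for start in range(0, len(test_x), batch_size):
        test = test_x[start:start + batch_size].astype(np.float64)
        test_sq = np.sum(test ** 2, axis=1)   # ||a||^2 for every test row
        # squared distances for every (test, train) pair in this batch
        sq_dist = test_sq[:, None] - 2.0 * (test @ train.T) + train_sq[None, :]
        # move the k smallest distances of each row into the first k columns
        nearest = np.argpartition(sq_dist, k - 1, axis=1)[:, :k]
        # order those k neighbors nearest-first so vote ties favor the closest
        nearest_d = np.take_along_axis(sq_dist, nearest, axis=1)
        nearest = np.take_along_axis(nearest, np.argsort(nearest_d, axis=1), axis=1)
        for row in nearest:
            predictions.append(Counter(train_y[row].tolist()).most_common(1)[0][0])
    return np.array(predictions)
```

With the arrays from `testHandWritingClass`, usage would be `pred = knn_predict_batch(test_x[:1000], train_x, train_y, k=3)` followed by `np.mean(pred == test_y[:1000])`. The heavy work becomes a matrix product, so actual speed depends mostly on the BLAS library NumPy uses. On full MNIST the float64 copy of `train_x` takes about 376 MB, and each 200-image batch allocates distance matrices of about 96 MB apiece. Lower `batch_size` if memory is tight. Binary pixels produce many exactly equal distances, and ties may be broken differently than by `np.argsort` in the original, so a few predictions can differ.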