TensorFlow实战:手写数字识别之K近邻
来源:互联网 发布:sql数据库管理工具 编辑:程序博客网 时间:2024/05/29 13:20
导语
自从Google发布了TensorFlow之后,作为一款开源的深度学习框架,在全世界范围内产生了巨大的影响力,如今在GitHub上深度学习框架居于第一名,且远远领先其他深度学习开源项目,并且也在工业界被大量运用。学习TensorFlow不仅可以加深对深度学习的理解,而且可以知道如何将深度学习这一门高深的学问用于实践当中。
TensorFlow就像其名字一样,由“tensor”和“flow”组成,“tensor”即“张量”的意思。框架的主要思想是先构建需要的计算图,图中每个定点表示一个操作,边表示张量之间的流向或依赖关系。当整个计算图构建完之后,启动计算图,系统会自动按照节点之间的依赖关系计算节点值,就能在需要的节点上获取数据。
本文并不打算详细介绍TensorFlow的原理,想要看原理的可以直接去官网。本文主要内容是用TensorFlow写一个入门级的算法K近邻实现手写数字识别MNIST。
加载数据
Keras提供了实现深度学习所需要的绝大部分函数库,可实现多种神经网络模型,并可加载多种数据集来评价模型的效果。下面的代码会自动加载数据,如果是第一次调用,数据会保存在你的hone目录下~/.keras/datasets/mnist.pkl.gz,大约15MB。
from keras.datasets import mnist(X_train, y_train), (X_test, y_test) = mnist.load_data()
对数据的维度进行reshape,原数据是28*28大小的图片,要将其展开成784长度的向量,便于计算样本间的距离。
num_pixels = X_train.shape[1] * X_train.shape[2]X_train = X_train.reshape(X_train.shape[0], num_pixels).astype('float32')X_test = X_test.reshape(X_test.shape[0], num_pixels).astype('float32')## 取一部分作为训练数据Xtr, Ytr = X_train[:5000],y_train[:5000]Xte, Yte = X_test,y_test
K近邻实现
计算训练集中的样本距离,采用L1距离,取距离最近的一个样本,将其标签赋值给测试样本
distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices=1)pred = tf.arg_min(distance, 0)
对测试集进行预测
accuracy = 0.# 初始化init = tf.global_variables_initializer()## 启动sessionwith tf.Session() as sess: sess.run(init) # 对每个测试样本进行预测 for i in range(len(Xte)): # 得到最近距离的样本 nn_index = sess.run(pred, feed_dict={xtr: Xtr, xte: Xte[i, :]}) # 输出预测结果 print("Test", i, "Prediction:", np.argmax(Ytr[nn_index]), "True Class:", np.argmax(Yte[i])) # 计算准确率 if np.argmax(Ytr[nn_index]) == np.argmax(Yte[i]): accuracy += 1./len(Xte)print("测试完成!")print("Accuracy:", accuracy)
小结
用最近邻处理MNIST问题可以取得较好的效果,需要调节的参数主要是近邻的数目(K),模型的效果相当程度上依赖K的取值。虽然过程很简单,但对于了解和熟悉TensorFlow也很有帮助,同时也可以用TensorFlow实现逻辑回归、线性回归等模型,后面会一一将其实现。
- TensorFlow实战:手写数字识别之K近邻
- 《机器学习实战二》K近邻学习之手写数字识别及检测识别错误率
- k近邻 - 手写数字识别
- 机器学习实战k近邻算法(kNN)应用之手写数字识别代码解读
- Python3:《机器学习实战》之k近邻算法(3)识别手写数字
- 【机器学习】k-近邻算法应用之手写数字识别
- 机器学习实战之k-近邻算法(6)---手写数字识别系统(0-9识别)
- Tensorflow实战之用softmax Regression识别手写数字
- K近邻算法(一) python实现,手写数字识别(from机器学习实战)
- 机器学习实战 --应用实例(k-近邻算法)-- 手写数字识别
- 《机器学习实战》第二章:k-近邻算法(3)手写数字识别
- 机器学习实战——使用K-近邻算法识别手写数字
- 『机器学习实战』使用 k-近邻算法识别手写数字
- TensorFlow实战—mnist手写数字识别
- TensorFlow实战(一)手写数字识别
- OpenCV手写数字字符识别(基于k近邻算法)
- 基于K-近邻算法识别手写数字的实现
- 基于k近邻(KNN)的手写数字识别
- Hibernate笔记
- 简单制作进度条,圆弧
- Java实现机器人的运动范围
- 关于Android的一些杂项
- jpg/png/psd/tiff图片格式详解
- TensorFlow实战:手写数字识别之K近邻
- 2017上半年总结
- volatile关键字解析
- PHP中String类型
- HTML:a标签中href属性总结
- #学志#项目进度04
- js数组去重的三种常用方法总结
- for(;;)和while(true)的区别
- 外观模式