tensorflow学习笔记二---k近邻分类器
来源:互联网 发布:淘宝手机收藏链接地址 编辑:程序博客网 时间:2024/06/10 01:34
使用Tensorflow实现k近邻分类器模型
1.k近邻模型的基本原理
- 距离度量
2.k值的选择
3 .分类决策规则
2.Tensorflow实现k近邻分类代码
- inference()-构建学习器模型前向预测过程(从输入到输出的计算图路径)
- evaluate()-在测试集数据上对模型的预测性能进行评估
- 此模型没有添加loss也没有train
3.计算步骤
- 算距离:给定测试样本的特征向量,计算他与训练集中每个样本特征向量的距离,得到一个一维张量
关于缩减求和:http://blog.csdn.net/flyfish1986/article/details/54646236
- 找近邻:圈定最近的k个训练样本作为测试样本近邻
- 作分类:根据k个近邻的归属主要类别,来对测试做主要分类
4.总结
- tensorflow实现k近邻算法主要有以下几个步骤
- 算距离:计算测试样本与每一个训练样本的距离,缩减求和后得到一个一维数组存储
- 找近邻:划定k值的大小,选取k个训练样本做为测试样本的近邻
- 做分类:根据k个近邻,对测试样本做分类(距离最小的索引)
- 做评估:与真实的标签进行比较,计算准确率
- 核心代码
import numpy as npimport osimport tensorflow as tf# 防止意外报错os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'# 导入mnist数据集from tensorflow.examples.tutorials.mnist import input_data# onehot标签标识一个长度为n的数组,只有一个元素是1,其他的都是0,用来表示mnist中标签数据mnist = input_data.read_data_sets("mnist_data/", one_hot=True)# 对mnist数据集做一个数量限制Xtrain,Ytrain=mnist.train.next_batch(5000)#使用5000个训练数据Xtest,Ytest=mnist.train.next_batch(200) # 使用200个测试数据print('Xtrain.shape: ', Xtrain.shape, ', Xtest.shape: ',Xtest.shape)print('Ytrain.shape: ', Ytrain.shape, ', Ytest.shape: ',Ytest.shape)# 计算图输入占位符#train 使用全部样本,test 逐个样本进行测试xtrain=tf.placeholder("float",[None,784])#图片训练集xtest=tf.placeholder("float",[784])#测试集#使用L1距离进行最近邻计算#计算L1距离distance=tf.reduce_sum(tf.abs(tf.add(xtrain,tf.negative(xtest))),axis=1)# 预测: 获得最小距离的索引 (根据最近邻的类标签进行判断)pred = tf.arg_min(distance, 0)#评估:判断给定的一条测试样本是否预测正确#评估正确率accuracy=0# 初始化节点init = tf.global_variables_initializer()#启动会话with tf.Session() as sess: sess.run(init) Ntest=len(Xtest)#测试样本的数量 for i in range(Ntest): # 获取当前测试样本的最近邻 nn_index = sess.run(pred, feed_dict={xtrain: Xtrain, xtest: Xtest[i, :]})#一个样本一个样本的输入 # 获得最近邻预测标签,然后与真实的类标签比较,由于是 one_hot 编码,所以要用 argmax 将类标取出 pred_class_label = np.argmax(Ytrain[nn_index]) true_class_label = np.argmax(Ytest[i]) print("Test", i, "Predicted Class Label:", pred_class_label, "True Class Label:", true_class_label) # 计算准确率 if pred_class_label == true_class_label: accuracy += 1 print("Done!") accuracy /= Ntest print("Accuracy:", accuracy)
训练结果准确率是0.925,使用的数据是mnist5000的训练数据和200的测试数据
关于这部分也就介绍到这,我只是代码的搬运工,嘿嘿嘿,把学到的东西分享出来,温故而知新。
总觉得很快乐。
持续更新,大家可以一起讨论哦。
阅读全文
0 0
- tensorflow学习笔记二---k近邻分类器
- 19、TensorFlow 实现最近邻分类器(K=1)
- 机器学习,k近邻分类器,python,
- 机器学习(二):分类算法之k-近邻算法
- 机器学习(二)k-近邻分类算法(kNN)
- 《机器学习实战》学习笔记-[1]-K近邻_第一个分类器
- 《机器学习实战》学习笔记(二、k-近邻算法)
- python K-近邻分类器
- K最近邻分类器
- 机器学习python,k近邻分类器,三维作图
- 机器学习笔记(二)——k-近邻算法
- 机器学习实战笔记之二(k-近邻算法)
- 机器学习笔记--K-近邻算法(二)
- sklearn学习笔记(二)——最近邻分类
- 图片分类-K近邻分类器
- 机器学习实战-k近邻分类
- Halcon学习之K最近邻分类
- 《机器学习实战》K近邻(KNN)分类
- 基于梯度下降法实现线性回归算法
- Javascript学习心得一
- easyUI_执行tab窗口中的方法,sources源码中查看不到js代码
- c#--实例选号器--实现打印、序列化方式保存
- 使用OB缓存实现静态化
- tensorflow学习笔记二---k近邻分类器
- ElementaryOS 0.4快速配置工具
- 菊花台
- Flume的安装及简单的使用(二)
- 写一个makefile
- Linux下安装网络软件的步骤
- 3,单例模式
- make (;区别 + 目标变量)
- Linux上安装wine qq的方法