tenserflow实例之最近邻算法

来源:互联网 发布:怎样增加淘宝信誉度 编辑:程序博客网 时间:2024/06/13 06:19

下面是利用MNIST data做的一个最近邻分类

'''A nearest neighbor learning algorithm example using TensorFlow library.This example is using the MNIST database of handwritten digits(http://yann.lecun.com/exdb/mnist/)Author: Aymeric DamienProject: https://github.com/aymericdamien/TensorFlow-Examples/'''from __future__ import print_functionimport numpy as npimport tensorflow as tf# 导入 MNIST datafrom tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets("./MNIST_data/", one_hot=True)# 这里我们训练集选取5000,测试集选取200Xtr, Ytr = mnist.train.next_batch(5000) #5000 for training (nn candidates)Xte, Yte = mnist.test.next_batch(200) #200 for testing# tf设置占位符xtr = tf.placeholder("float", [None, 784])xte = tf.placeholder("float", [784])# 利用L1距离进行最近邻# 计算 L1距离distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices=1)# 预测: 取得距离最小的图片的指针pred = tf.arg_min(distance, 0)accuracy = 0.# 初始化所有变量init = tf.global_variables_initializer()# 创建对话with tf.Session() as sess:    sess.run(init)    # 循环测试数据    for i in range(len(Xte)):        # 得到最近邻        nn_index = sess.run(pred, feed_dict={xtr: Xtr, xte: Xte[i, :]})        # 读取最近邻的标签,比较与真实值是否一致        print("Test", i, "Prediction:", np.argmax(Ytr[nn_index]), \            "True Class:", np.argmax(Yte[i]))        # 计算精确度        if np.argmax(Ytr[nn_index]) == np.argmax(Yte[i]):            accuracy += 1./len(Xte)    print("Done!")    print("Accuracy:", accuracy)
0 0