tensorflow 学习笔记(4)-basic_example

来源:互联网 发布:淘宝天天特价十元包邮 编辑:程序博客网 时间:2024/06/06 01:56

basic_example : nearest neighbor algorithm

# -*- coding: utf-8 -*-"""Created on Tue Jun 20 19:26:25 2017@author: wu"""# 引入模块from __future__ import print_functionimport tensorflow as tfimport numpy as np#下载mnist数据集from tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets("MNIST_data/", one_hot = True)#获取训练数据和测试数据Xtr, Ytr = mnist.train.next_batch(5000)Xte, Yte = mnist.train.next_batch(200)#TensorFlow的数据图的输入xtr = tf.placeholder("float", [None, 784])xte = tf.placeholder("float", [784])#use L1 distancedistance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices = 1)pred  = tf.arg_min(distance, 0)accuracy  = 0.init = tf.global_variables_initializer()#Launch graphwith tf.Session() as sess:    sess.run(init)    for i in range(len(Xte)):#测试集图片迭代测试        nn_index = sess.run(pred, feed_dict = {xtr: Xtr, xte: Xte[i, :]})        print("Test", i, "Prediction: ", np.argmax(Ytr[nn_index]),"True class: ",np.argmax(Yte[i]))        #如果测试集图片和训练集图片一样,计算精确率        if np.argmax(Ytr[nn_index]) == np.argmax(Yte[i]):            accuracy += 1./len(Xte)    print("Done!")    print("Accuracy: ",accuracy)

部分运行结果截图

这里写图片描述

原创粉丝点击