tensorflow学习:mnist图片分类,并打印预测精度

来源:互联网 发布:知乎查看自己的匿名 编辑:程序博客网 时间:2024/04/30 23:13

使用softmax对mnist图片分类,并获取预测的准确度

# Train a single-layer softmax classifier on MNIST (TensorFlow 1.x graph mode)
# and print the test-set accuracy every 40 training steps.
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)

# Flattened 28x28 images and their one-hot labels (10 classes).
x_data = tf.placeholder(tf.float32, [None, 28 * 28])
y_data = tf.placeholder(tf.float32, [None, 10])

# One dense layer with 10 output units. The raw logits are kept separate from
# the softmax probabilities so the loss can be computed from the logits.
logits = tf.layers.dense(x_data, 10)
prediction = tf.nn.softmax(logits)

# Mean softmax cross-entropy over the batch. Computing it from the logits
# avoids log(0) = -inf / NaN, which the hand-written
# -sum(y * log(softmax(x))) form hits when a probability underflows to zero.
cross_entropy = tf.losses.softmax_cross_entropy(
    onehot_labels=y_data, logits=logits)
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

# The accuracy op is built once, at graph-construction time. Creating these
# ops inside computer_accuracy() would add new nodes to the graph on every
# evaluation call.
correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(y_data, 1))
accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


def computer_accuracy(x_input, y_input):
    """Return the classification accuracy on the given examples.

    Uses the module-level session ``sess``, so it must be called inside the
    ``with tf.Session() as sess:`` block below.

    Args:
        x_input: Batch of flattened images fed to ``x_data``.
        y_input: Matching one-hot labels fed to ``y_data``.

    Returns:
        The fraction of examples whose predicted class (argmax of the
        softmax output) equals the labelled class, as a float32 scalar.
    """
    return sess.run(accuracy_op, feed_dict={x_data: x_input, y_data: y_input})


init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for step in range(1001):
        batch_x_data, batch_y_data = mnist.train.next_batch(1000)
        sess.run(train_step,
                 feed_dict={x_data: batch_x_data, y_data: batch_y_data})
        if step % 40 == 0:
            accuracy = computer_accuracy(mnist.test.images, mnist.test.labels)
            print(step, accuracy)

结果:

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
0 0.3994
40 0.8812
80 0.8961
120 0.9021
160 0.9066
200 0.9091
240 0.9114
280 0.9137
320 0.9133
360 0.9147
400 0.9159
440 0.9169
480 0.9179
520 0.9179
560 0.9193
600 0.9193
640 0.9181
680 0.9192
720 0.9206
760 0.9201
800 0.9209
840 0.92
880 0.9205
920 0.9211
960 0.9203
1000 0.9201


阅读全文
0 0