Classifying the MNIST dataset with softmax in TensorFlow

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("/root/data/", one_hot=True)

# Hyperparameters
learning_rate = 0.01
training_epochs = 25
batch_size = 100
display_step = 1

# Graph inputs
x = tf.placeholder(tf.float32, [None, 784])  # mnist data image of shape 28*28=784
y = tf.placeholder(tf.float32, [None, 10])   # 0-9 digits recognition => 10 classes

# Model parameters
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

pred = tf.nn.softmax(tf.matmul(x, W) + b)  # Softmax

# Cross-entropy cost and gradient-descent optimizer
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.initialize_all_variables()

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(training_epochs):
        avg_cost = 0.
        total_batch = int(mnist.train.num_examples / batch_size)
        # Loop over all mini-batches and accumulate the average cost
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs,
                                                          y: batch_ys})
            avg_cost += c / total_batch
        if (epoch + 1) % display_step == 0:
            print "Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost)
    print "Optimization Finished!"

    # Evaluate accuracy on the first 3000 test images
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print "Accuracy:", accuracy.eval({x: mnist.test.images[:3000], y: mnist.test.labels[:3000]})
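
One caveat with this cost: tf.log(pred) goes to -inf whenever the softmax output underflows to zero, which can surface as NaN losses during training. A common workaround is to let TensorFlow compute the softmax and the cross-entropy together from the raw logits. A minimal sketch, assuming the same TF 1.x graph and variable names as the listing above:

logits = tf.matmul(x, W) + b
pred = tf.nn.softmax(logits)  # keep pred for the accuracy check at the end
# Fused, numerically safer cross-entropy computed on the logits
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))

The rest of the training loop and the accuracy evaluation stay unchanged.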
