TensorFlow & MNIST & CNN

来源:互联网 发布:福建农林大学网络教学 编辑:程序博客网 时间:2024/06/07 03:07




# Restore a trained softmax-regression MNIST classifier from a checkpoint,
# show a few test digits one at a time with their predicted and true labels,
# then print predictions and accuracy over the full MNIST test set.
#
# Written against the TensorFlow 1.x graph API (placeholders, Session, Saver)
# and the bundled `tensorflow.examples.tutorials.mnist` loader, which TF 2.x
# no longer ships. print() is only ever called with a single argument, so the
# output is the same under Python 2 and Python 3.
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

DATA_DIR = './1/'  # directory passed to read_data_sets for the MNIST files
CHECKPOINT_PATH = './SoftmaxSaver/model.ckpt'  # checkpoint of the trained model
NUM_PREVIEW = 10  # how many test digits to display individually
IMAGE_SIDE = 28  # each flattened input row is IMAGE_SIDE * IMAGE_SIDE pixels
NUM_CLASSES = 10  # digits 0-9

mnist = input_data.read_data_sets(DATA_DIR, one_hot=True)

# W and b are created unnamed and in this order, so TF auto-names them. The
# default Saver restores variables by name, so this must match the graph that
# wrote CHECKPOINT_PATH (NOTE(review): confirm against the training script).
x = tf.placeholder(tf.float32, [None, IMAGE_SIDE * IMAGE_SIDE])  # flattened images
W = tf.Variable(tf.zeros([IMAGE_SIDE * IMAGE_SIDE, NUM_CLASSES]))
b = tf.Variable(tf.zeros([NUM_CLASSES]))
y = tf.nn.softmax(tf.matmul(x, W) + b)  # per-row class probabilities
y_ = tf.placeholder(tf.float32, [None, NUM_CLASSES])  # one-hot true labels

x1 = tf.argmax(y, 1)  # predicted class index per row
y1 = tf.argmax(y_, 1)  # true class index per row
correct_prediction = tf.equal(x1, y1)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

saver = tf.train.Saver()

with tf.Session() as sess:
    # Restoring assigns every saved variable, so no separate initializer is run.
    saver.restore(sess, CHECKPOINT_PATH)

    for _ in range(NUM_PREVIEW):
        bat_x, bat_y = mnist.test.next_batch(1)
        print(len(bat_x[0]))
        img = np.reshape(bat_x, (IMAGE_SIDE, IMAGE_SIDE))
        cv2.imshow("s", img)
        cv2.waitKey()  # no timeout: wait for a key press before continuing
        # sess.run returns a length-1 array for a 1-image batch; take the
        # scalar explicitly rather than relying on implicit array->scalar
        # conversion inside %-formatting.
        print("forecast %d" % sess.run(x1, feed_dict={x: bat_x})[0])
        print("accurate %d" % sess.run(y1, feed_dict={y_: bat_y})[0])
    cv2.destroyAllWindows()

    print(sess.run(x1, feed_dict={x: mnist.test.images}))
    print(sess.run(accuracy,
                   feed_dict={x: mnist.test.images, y_: mnist.test.labels}))