# TensorFlow simple example

# 来源:互联网 发布:立体照片制作软件 编辑:程序博客网 时间:2024/06/05 14:10
import mnist
from mnist import read_data_sets
import tensorflow as tf
from tensorflow import models
from tensorflow import python
from tensorflow import tensorboard
from tensorflow import tools

# Softmax-regression classifier trained on MNIST with plain gradient descent.
# The script builds the graph, trains for a fixed number of mini-batches and
# prints the accuracy measured on the test split.

# Load the dataset with one-hot encoded labels.
# NOTE: this rebinds the name `mnist` (previously the imported module) to the
# dataset object returned by read_data_sets; kept for backward compatibility.
mnist = read_data_sets('MNIST_data', one_hot=True)

# An InteractiveSession installs itself as the default session, which is what
# lets `Operation.run()` / `Tensor.eval()` below work without passing `sess`.
sess = tf.InteractiveSession()

# Inputs: a batch of 784-element feature vectors and 10-element label vectors
# (presumably flattened 28x28 images and one-hot digit classes).
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])

# Model parameters, initialised to zero.
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

# Predicted class probabilities.
y = tf.nn.softmax(tf.matmul(x, W) + b)

# Cross-entropy summed over the batch. The probabilities are clipped away
# from zero before the log: an underflowed softmax output would otherwise give
# log(0) = -inf and 0 * -inf = NaN, poisoning the loss and every gradient.
cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

# Initialise variables only after the whole graph has been built, so that any
# variable created during graph construction is covered as well.
sess.run(tf.initialize_all_variables())

# Train on 1000 mini-batches of 100 examples each.
for _ in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    train_step.run(feed_dict={x: batch_xs, y_: batch_ys})

# A prediction is correct when the most probable class equals the label.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Parenthesised single-argument print behaves identically on Python 2 and 3.
print('\naccuracy:')
print(accuracy.eval(feed_dict={x: mnist.test.images,
                               y_: mnist.test.labels}))
# 原创粉丝点击