用tensorflow的slim模块快速实现mnist手写体识别分类

来源:互联网 发布:淘宝网上开店需要多少钱 编辑:程序博客网 时间:2024/04/29 08:30
import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_datafrom tensorflow.examples.tutorials.mnist import mnistimport tensorflow.contrib.slim as slimmnist=input_data.read_data_sets('../share/MNIST_DATA',one_hot=True)x=tf.placeholder("float",shape=[None,784])y_=tf.placeholder("float",shape=[None,10])#cast x to 3Dx_image=tf.reshape(x,[-1,28,28,1])#shape of x is [N,28,28,1]#conv layer1net=slim.conv2d(x_image,32,[5,5],scope='conv1')#shape of net is [N,28,28,32]net=slim.max_pool2d(net,[2,2],scope='pool1')#shape of net is [N,14,14,32]#conv layer2net=slim.conv2d(net,64,[5,5],scope='conv2')#shape of net is [N,14,14,64]net=slim.max_pool2d(net,[2,2],scope='pool2')#shape of net is [N,7,7,64]#reshape for full connectionnet=tf.reshape(net,[-1,7*7*64])#[N,7*7*64]#fc1net=slim.fully_connected(net,1024,scope='fc1')#shape of net is [N,1024]#dropout layerkeep_prob=tf.placeholder('float')net=tf.nn.dropout(net,keep_prob)#fc2net=slim.fully_connected(net,10,scope='fc2')#[N,10]#softmaxy=tf.nn.softmax(net)#[N,10]cross_entropy=-tf.reduce_sum(tf.multiply(y_,tf.log(y)))#y and _y have same shape.train_step=tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)correct_prediction=tf.equal(tf.argmax(y,axis=1),tf.argmax(y_,axis=1))#shape of correct_prediction is [N]accuracy=tf.reduce_mean(tf.cast(correct_prediction,'float'))init=tf.global_variables_initializer()with tf.Session() as sess:    sess.run(init)    for i in range(10000):        batch=mnist.train.next_batch(50)        if i%100==0:            train_accuracy=sess.run(accuracy,feed_dict={x:batch[0],y_:batch[1],keep_prob:1.0})            print('step %d,training accuracy  %g !!!!!!!'%(i,train_accuracy))        sess.run(train_step,feed_dict={x:batch[0],y_:batch[1],keep_prob:0.5})    total_accuracy=sess.run(accuracy,feed_dict={x:mnist.test.images,y_:mnist.test.labels,keep_prob:1.0})    print('test_accuracy  %s!!!!!!!'%(total_accuracy))

直接贴代码,代码没什么好说的了,我都做了注释了。主要是参考TensorFlow官网上的教程点击打开链接,但是使用了slim模块(slim介绍参考我的这个博客点击打开链接),于是大大缩小了代码量,也提高了代码的可读性,强烈推荐slim模块。当然如果对上述代码中的函数不熟悉的可直接去TensorFlow官网查看API手册,里面介绍得非常详尽。当然在调试代码时最重要的还是关注tensor的shape,于是我在每个tensor变量后都注释了shape,方便调试,也能提高程序的可读性。

后面得到的结果展示


阅读全文
0 0