Simple TensorFlow (6): A Simple Application on the MNIST Dataset

Source: Internet · Editor: 程序博客网 · Date: 2024/06/08 19:39

The full name of the MNIST database is the Mixed National Institute of Standards and Technology database. It is a handwritten digit database containing 60,000 training samples and 10,000 test samples, and it is a subset of the NIST database, suitable as training and test data for handwritten digit recognition. The files can be downloaded from the official site at http://yann.lecun.com/exdb/mnist/. They are not in a standard image format; the image data is stored in binary files, and each sample image is 28*28 pixels.
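As a sketch of what those binary files look like: each IDX image file begins with a 16-byte big-endian header (a magic number, the image count, and the row and column sizes), followed by raw uint8 pixel bytes. The parser below is an illustrative sketch, not part of the original post, and is demonstrated on a synthetic two-image buffer rather than on the real downloaded file:

```python
import struct

def parse_idx_images(data):
    # Header: four big-endian uint32s -- magic (2051 for images), count, rows, cols
    magic, count, rows, cols = struct.unpack(">IIII", data[:16])
    assert magic == 2051, "not an IDX image file"
    images = []
    offset = 16
    for _ in range(count):
        # Each image is rows*cols consecutive uint8 pixel bytes
        images.append(data[offset:offset + rows * cols])
        offset += rows * cols
    return count, rows, cols, images

# Synthetic file with two blank 28x28 images, for illustration only
header = struct.pack(">IIII", 2051, 2, 28, 28)
fake = header + bytes(2 * 28 * 28)
count, rows, cols, imgs = parse_idx_images(fake)
print(count, rows, cols, len(imgs[0]))  # 2 28 28 784
```

The same header layout (with magic number 2049 and no row/column fields) applies to the label files.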


Before use, load the dataset with mnist = input_data.read_data_sets('MNIST_data', one_hot=True). One-hot encoding originally comes from digital circuit design; here it can be understood as a row vector of 10 elements in which exactly one element is 1 and all the others are 0.
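As a quick illustration of what one-hot encoding produces (a plain-Python sketch; the real encoding is done internally by read_data_sets):

```python
def one_hot(label, depth=10):
    # A length-`depth` vector of zeros with a single 1 at index `label`
    vec = [0.0] * depth
    vec[label] = 1.0
    return vec

# The digit 3 becomes a 10-element vector with a 1 in position 3
print(one_hot(3))  # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```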

After loading, set the batch size and compute how many batches the training set divides into: batch_size = 100 and n_batch = mnist.train.num_examples // batch_size.
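With TensorFlow's default split, mnist.train holds 55,000 images (5,000 of the 60,000 training samples are set aside for validation), so the division above gives 550 full batches per epoch. A quick check under that assumed size:

```python
batch_size = 100
num_examples = 55000  # assumed size of mnist.train after TF's default validation split
n_batch = num_examples // batch_size  # integer division: number of full batches
print(n_batch)  # 550
```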

Then a simple network is enough to train and evaluate on this dataset:


Full code:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load the dataset
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# Batch settings
batch_size = 100
n_batch = mnist.train.num_examples // batch_size

# Two placeholders: flattened 28*28 images and one-hot labels
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])

# A simple single-layer network
W = tf.Variable(tf.zeros([784, 10]))
B = tf.Variable(tf.zeros([10]))
Result = tf.matmul(x, W) + B
prediction = tf.nn.softmax(Result)

# Loss function (mean squared error)
loss = tf.reduce_mean(tf.square(y - prediction))

# Optimizer
optimizer = tf.train.GradientDescentOptimizer(0.1)

# Minimize the loss
train = optimizer.minimize(loss)

# Variable initializer
init = tf.global_variables_initializer()

# Compare predicted and true classes
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))

# Accuracy
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Run the session
with tf.Session() as sess:
    sess.run(init)
    for step in range(51):
        for batch in range(n_batch):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            sess.run(train, feed_dict={x: batch_x, y: batch_y})
        acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
        print("Iter " + str(step) + ", Testing Accuracy " + str(acc))


The result is (as the numbers show, accuracy would keep climbing slowly if training continued):

Iter 0, Testing Accuracy 0.7426
Iter 1, Testing Accuracy 0.8336
Iter 2, Testing Accuracy 0.8603
Iter 3, Testing Accuracy 0.871
Iter 4, Testing Accuracy 0.8765
Iter 5, Testing Accuracy 0.8821
Iter 6, Testing Accuracy 0.8848
Iter 7, Testing Accuracy 0.8894
Iter 8, Testing Accuracy 0.8921
Iter 9, Testing Accuracy 0.8949
Iter 10, Testing Accuracy 0.8958
Iter 11, Testing Accuracy 0.8976
Iter 12, Testing Accuracy 0.8984
Iter 13, Testing Accuracy 0.8997
Iter 14, Testing Accuracy 0.9015
Iter 15, Testing Accuracy 0.9013
Iter 16, Testing Accuracy 0.9026
Iter 17, Testing Accuracy 0.9029
Iter 18, Testing Accuracy 0.9043
Iter 19, Testing Accuracy 0.9048
Iter 20, Testing Accuracy 0.9059
Iter 21, Testing Accuracy 0.9063
Iter 22, Testing Accuracy 0.9067
Iter 23, Testing Accuracy 0.9073
Iter 24, Testing Accuracy 0.9074
Iter 25, Testing Accuracy 0.9078
Iter 26, Testing Accuracy 0.9089
Iter 27, Testing Accuracy 0.9091
Iter 28, Testing Accuracy 0.9092
Iter 29, Testing Accuracy 0.9096
Iter 30, Testing Accuracy 0.9106
Iter 31, Testing Accuracy 0.911
Iter 32, Testing Accuracy 0.9111
Iter 33, Testing Accuracy 0.9115
Iter 34, Testing Accuracy 0.9122
Iter 35, Testing Accuracy 0.9125
Iter 36, Testing Accuracy 0.9124
Iter 37, Testing Accuracy 0.9134
Iter 38, Testing Accuracy 0.9127
Iter 39, Testing Accuracy 0.9135
Iter 40, Testing Accuracy 0.9133
Iter 41, Testing Accuracy 0.9139
Iter 42, Testing Accuracy 0.9141
Iter 43, Testing Accuracy 0.9143
Iter 44, Testing Accuracy 0.9147
Iter 45, Testing Accuracy 0.9152
Iter 46, Testing Accuracy 0.9154
Iter 47, Testing Accuracy 0.9157
Iter 48, Testing Accuracy 0.9157
Iter 49, Testing Accuracy 0.9158
Iter 50, Testing Accuracy 0.9159

