Tensorflow深度学习之七:再谈mnist手写数字识别程序

来源:互联网 发布:怎么写销售数据分析表? 编辑:程序博客网 时间:2024/06/08 01:28

之前学习的第一个深度学习的程序就是mnist手写字体的识别,那个时候对于很多概念不是很理解,现在回过头再看当时的代码,理解了很多,现将加了注释的代码贴上,与大家分享。(本人还是在学习Tensorflow的初始阶段,如果有什么地方理解有误,还请大家不吝指出。)

from tensorflow.examples.tutorials.mnist import input_data# 下载mnist数据集至当前目录下的MNIST_data文件夹,并读取数据。mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)# 输出训练集,测试集,验证集图片的shape和标签的shape。print(mnist.train.images.shape, mnist.train.labels.shape)print(mnist.test.images.shape, mnist.test.labels.shape)print(mnist.validation.images.shape, mnist.validation.labels.shape)import tensorflow as tf# 默认会话。sess = tf.InteractiveSession()# 定义一个placeholder,用于保存输入的图片的信息。# 由于图片中的数值是0~1之间的浮点数,所以x的数据类型也应是tf.float32。# 第二个参数表示x的维度,其中None表示不限制输入的数量,之后的参数便是输入的数据的维度,# 这里的784表示输入的是一个长度为784的一维向量。x = tf.placeholder(tf.float32, [None, 784])# 定义权重变量,该变量是一个784x10的矩阵,这里将初始权重全部赋值为0。W = tf.Variable(tf.zeros([784, 10]))# 定义偏置值变量,该变量是一个由10个元素组成的向量,同样这里的偏置值变量的初始值也全部被赋值为0。b = tf.Variable(tf.zeros([10]))# 定义softmax层,输入是一个10个元素组成的向量。# tf.matmul是用于矩阵相乘的函数。y = tf.nn.softmax(tf.matmul(x, W)+b)# 再次定义一个placeholder,用于保存真实的图片标签。y_ = tf.placeholder(tf.float32, [None, 10])# 定义交叉熵,这是本程序需要使用的loss函数,我们的目的是使得这个loss函数尽可能的小。cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))# 我们使用梯度下降的方法来优化参数,这里把学习率设置为0.5,我们需要优化的函数是cross_entropy,即我们的loss函数。train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)# 初始化所有的全局变量。tf.global_variables_initializer().run()# 从这里开始训练我们建立好的模型。# 从上面的公式中可以看出,我们建立的是一个全连接的模型,本质上是对矩阵乘法的优化。# 我们训练1000次。for i in range(1000):    # 使用mnist自带的方法随机产生100个数据。    batch_xs, batch_ys = mnist.train.next_batch(100)    # 将这100个数据分别feed给上面我们定义的两个placeholder,由于训练模型。    train_step.run({x:batch_xs,y_:batch_ys})# 建立评估模型正确预测的Graph。correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))# 定义准确率的计算公式。accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))# 将测试集数据传递给两个placeholder,然后执行上述定义的准确率的公式,最后输出准确率的值。print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))

运行结果如下:(因为使用了梯度下降的方法,因此每一次运行的结果或有不同,一般结果在0.92左右)

Extracting MNIST_data/train-images-idx3-ubyte.gzExtracting MNIST_data/train-labels-idx1-ubyte.gzExtracting MNIST_data/t10k-images-idx3-ubyte.gzExtracting MNIST_data/t10k-labels-idx1-ubyte.gz(55000, 784) (55000, 10)(10000, 784) (10000, 10)(5000, 784) (5000, 10)0.9184
阅读全文
0 0
原创粉丝点击