Tensorflow深度学习之七:再谈mnist手写数字识别程序
来源:互联网 发布:怎么写销售数据分析表? 编辑:程序博客网 时间:2024/06/08 01:28
之前学习的第一个深度学习的程序就是mnist手写字体的识别,那个时候对于很多概念不是很理解,现在回过头再看当时的代码,理解了很多,现将加了注释的代码贴上,与大家分享。(本人还是在学习Tensorflow的初始阶段,如果有什么地方理解有误,还请大家不吝指出。)
from tensorflow.examples.tutorials.mnist import input_data# 下载mnist数据集至当前目录下的MNIST_data文件夹,并读取数据。mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)# 输出训练集,测试集,验证集图片的shape和标签的shape。print(mnist.train.images.shape, mnist.train.labels.shape)print(mnist.test.images.shape, mnist.test.labels.shape)print(mnist.validation.images.shape, mnist.validation.labels.shape)import tensorflow as tf# 默认会话。sess = tf.InteractiveSession()# 定义一个placeholder,用于保存输入的图片的信息。# 由于图片中的数值是0~1之间的浮点数,所以x的数据类型也应是tf.float32。# 第二个参数表示x的维度,其中None表示不限制输入的数量,之后的参数便是输入的数据的维度,# 这里的784表示输入的是一个长度为784的一维向量。x = tf.placeholder(tf.float32, [None, 784])# 定义权重变量,该变量是一个784x10的矩阵,这里将初始权重全部赋值为0。W = tf.Variable(tf.zeros([784, 10]))# 定义偏置值变量,该变量是一个由10个元素组成的向量,同样这里的偏置值变量的初始值也全部被赋值为0。b = tf.Variable(tf.zeros([10]))# 定义softmax层,输入是一个10个元素组成的向量。# tf.matmul是用于矩阵相乘的函数。y = tf.nn.softmax(tf.matmul(x, W)+b)# 再次定义一个placeholder,用于保存真实的图片标签。y_ = tf.placeholder(tf.float32, [None, 10])# 定义交叉熵,这是本程序需要使用的loss函数,我们的目的是使得这个loss函数尽可能的小。cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))# 我们使用梯度下降的方法来优化参数,这里把学习率设置为0.5,我们需要优化的函数是cross_entropy,即我们的loss函数。train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)# 初始化所有的全局变量。tf.global_variables_initializer().run()# 从这里开始训练我们建立好的模型。# 从上面的公式中可以看出,我们建立的是一个全连接的模型,本质上是对矩阵乘法的优化。# 我们训练1000次。for i in range(1000): # 使用mnist自带的方法随机产生100个数据。 batch_xs, batch_ys = mnist.train.next_batch(100) # 将这100个数据分别feed给上面我们定义的两个placeholder,由于训练模型。 train_step.run({x:batch_xs,y_:batch_ys})# 建立评估模型正确预测的Graph。correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))# 定义准确率的计算公式。accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))# 将测试集数据传递给两个placeholder,然后执行上述定义的准确率的公式,最后输出准确率的值。print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))
运行结果如下:(因为使用了梯度下降的方法,因此每一次运行的结果或有不同,一般结果在0.92左右)
Extracting MNIST_data/train-images-idx3-ubyte.gzExtracting MNIST_data/train-labels-idx1-ubyte.gzExtracting MNIST_data/t10k-images-idx3-ubyte.gzExtracting MNIST_data/t10k-labels-idx1-ubyte.gz(55000, 784) (55000, 10)(10000, 784) (10000, 10)(5000, 784) (5000, 10)0.9184
阅读全文
0 0
- Tensorflow深度学习之七:再谈mnist手写数字识别程序
- Tensorflow深度学习之八:再探CNN解决mnist手写数字识别问题
- TensorFlow学习---实现mnist手写数字识别
- tensorflow-mnist手写数字识别
- Tensorflow深度学习笔记(五)--手写数字识别-MNIST数据测试
- TensorFlow学习笔记(3)----CNN识别MNIST手写数字
- 《学习Tensorflow》——MNIST手写数字识别
- TensorFlow学习笔记(二)MNIST手写数字识别
- Tensorflow之 CNN卷积神经网络的MNIST手写数字识别
- TensorFlow MNIST 手写数字识别之过拟合
- TensorFlow 卷积神经网络之MNIST 手写数字识别
- TensorFlow学习笔记之源码分析(2)----手写数字识别mnist example
- 深度学习- 用Torch实现MNIST手写数字识别
- 深度学习笔记5torch实现mnist手写数字识别
- 基于tensorflow的MNIST手写数字识别
- 基于tensorflow的MNIST手写数字识别
- Tensorflow 实现 MNIST 手写数字识别
- 神经网络-tensorflow实现mnist手写数字识别
- 第一次的博客
- AbstractFeatureExtractor API
- c++动态内存管理
- android 水波纹
- 立下个flag,从今天开始我要写有深度的文章了,也要好好学习,考研去了
- Tensorflow深度学习之七:再谈mnist手写数字识别程序
- Linux学习-服务器B挂载服务器A的文件夹
- Hotel POJ 3667
- mysql数据库的备份
- ocp Oracle.1z0-053 711 ---19
- Heap Sort | out-of-place and in-place implementation
- php的一些易错点
- 《强化学习》学习笔记整理与提炼
- 1162-------C语言实验——保留字母(字符串)