TensorFlow实战—mnist手写数字识别
来源:互联网 发布:儿童电脑编程 编辑:程序博客网 时间:2024/05/16 13:54
# -*- coding: utf-8 -*-"""Created on Fri Aug 18 13:19:20 2017@author: zx"""from tensorflow.examples.tutorials.mnist import input_dataimport tensorflow as tfmnist = input_data.read_data_sets("MNIST_data", one_hot = True)x = tf.placeholder("float32",[None,784])w = tf.Variable(tf.zeros([784,10]))b = tf.Variable(tf.zeros([10]))y = tf.nn.softmax(tf.matmul(x,w) + b)y_ = tf.placeholder("float32",[None,10])cross_entropy = -tf.reduce_sum(y_*tf.log(y))train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)init = tf.initialize_all_variables()sess = tf.Session()sess.run(init)for i in range(1000): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))print (sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
它也包含每一张图片对应的标签,告诉我们这个是数字几。比如,上面这四张图片的标签分别是5,0,4,1。
在此教程中,我们将训练一个机器学习模型用于预测图片里面的数字。我们的目的不是要设计一个世界一流的复杂模型 -- 尽管我们会在之后给你源代码去实现一流的预测模型 -- 而是要介绍下如何使用TensorFlow。所以,我们这里会从一个很简单的数学模型开始,它叫做Softmax Regression。
对应这个教程的实现代码很短,而且真正有意思的内容只包含在三行代码里面。但是,去理解包含在这些代码里面的设计思想是非常重要的:TensorFlow工作流程和机器学习的基本概念。因此,这个教程会很详细地介绍这些代码的实现原理。
不多说,上代码。。。。
# -*- coding: utf-8 -*-"""Created on Fri Aug 18 13:19:20 2017@author: zx"""from tensorflow.examples.tutorials.mnist import input_dataimport tensorflow as tfmnist = input_data.read_data_sets("MNIST_data", one_hot = True)x = tf.placeholder("float32",[None,784])w = tf.Variable(tf.zeros([784,10]))b = tf.Variable(tf.zeros([10]))y = tf.nn.softmax(tf.matmul(x,w) + b)y_ = tf.placeholder("float32",[None,10])cross_entropy = -tf.reduce_sum(y_*tf.log(y))train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)init = tf.initialize_all_variables()sess = tf.Session()sess.run(init)for i in range(1000): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))print (sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
这个最终结果值应该大约是91%。
阅读全文
0 0
- TensorFlow实战—mnist手写数字识别
- tensorflow-mnist手写数字识别
- TensorFlow实战-mnist手写数字识别(卷积神经网络)
- 《学习Tensorflow》——MNIST手写数字识别
- tensorflow入门实践例子—MNIST手写数字识别
- 基于tensorflow的MNIST手写数字识别
- 基于tensorflow的MNIST手写数字识别
- Tensorflow 实现 MNIST 手写数字识别
- 神经网络-tensorflow实现mnist手写数字识别
- tensorflow中mnist手写数字识别
- tensorflow中logistic识别mnist手写数字
- tensorflow中MLP识别mnist手写数字
- tensorflow构建RNN识别mnist手写数字
- TensorFlow学习---实现mnist手写数字识别
- tensorflow进行MNIST手写数字识别-CNN
- tensorflow进行MNIST手写数字识别-LSTM
- 训练Tensorflow识别手写数字 mnist
- TensorFlow笔记之一:MNIST手写数字识别
- Android UI 自动化测试之UiSelector
- pdf转jpg格式转换器怎么用
- Windows 生物统计框架结构简介(WBF) (指纹识别技术)
- [SMOJ1894]战争
- 博弈--类似Bash--hdu1517 A Multiplication Game
- TensorFlow实战—mnist手写数字识别
- Java网络与线程
- 转的 侵立删 java内存模型
- BFS DFS算法,和动态规划
- 黑苹果的日记 ---缓存和一些图片视频的处理
- hdu 6143 Killer Names
- int main(int argc, char* argv[])问题(2)-传字符串
- [codevs2959]阶乘质因数分解
- Android SharedPreferencesHelper简单封装