TensorFlow学习(一)MNIST数据库例程执行
来源:互联网 发布:淘宝国内lolita品牌 编辑:程序博客网 时间:2024/06/07 05:19
来自tensorflow的中文文档官方例程。重新跑了一次并做些注解,希望能帮到一些人~
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
#数据输入
##不要怕麻烦,所有的问题都是可以解决;找找其他途径,不要在一颗书上吊死
##比如这里,网络上下载不了,就直接拷贝,怕路径不对,就看清楚代码,改好路径
mnist = input_data.read_data_sets('/home/wwb/文档/tensorflow/MNIST-data', one_hot=True)
mx = mnist.train.images
my_ = mnist.train.labels
#tensorflow 的一般流程,创建图,然后在 session 启动
#如果使用InteractiveSession,在运行图的时候,插入一些计算图;
#如果没使用此函数,需要在启动session前,构建整个计算图,然后启动计算图
#TensorFlow也是在Python外部完成其主要工作,但是进行了改进以避免这种开销。
#其并没有采用在Python外部独立运行某个耗时操作的方式,
#而是先让我们描述一个交互操作图,然后完全将其运行在Python外部。
sess = tf.InteractiveSession()
#构建具有一个线性层的softmaxi模型
##创建占位符 784为长×宽,10为分类
x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])
##定义w 和 b,一般用variable(使用前必须初始化),,这里初始化为0向量
##w 784个特征,10个输出(10个分类)
##b 是10维向量
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
##varibale初始化后,分配给变量,可一次性分配所有,如下
sess.run(tf.global_variables_initializer())
#类别预测和损失函数
##实现回归模型,把向量化后的图片 x 和权重w相乘,然后加b;在计算softmax值
y = tf.nn.softmax(tf.matmul(x,W) + b)
##使用最小化误差函数,目标类别和预测类别的交叉熵(计算的是整个minibatch)
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
#模型训练
##多种优化算法供选择,这里采用最速下降法
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
##返回的train_step为操作对象,可以反复运行 train_step
##每一步迭代,我们都会加载50个训练样本,然后执行一次train_step;
##并通过feed_dict将x 和 y_张量占位符用训练训练数据替代。
for i in range(1000):
batch = mnist.train.next_batch(50)
train_step.run(feed_dict={x: batch[0], y_: batch[1]})
#评估模型
##找到预测正确的标签
###tf.argmax 是一个非常有用的函数,它能给出某个tensor对象在某一维上的其数据最大值所在的索引值
###tf.argmax(y,1)分别代表tf.argmax(y_,1)预测值与正确值
###下面的函数,返回布尔
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
###将布尔值转化为浮点值,然后取平均,提高准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels})
阅读全文
0 0
- TensorFlow学习(一)MNIST数据库例程执行
- caffe学习笔记(一):MNIST例程
- TensorFlow学习笔记(一)---MNIST
- TensorFLow学习(一),Mnist入门
- Tensorflow学习笔记(对MNIST经典例程的)的代码运行
- TensorFlow学习笔记(一)MNIST手写字识别
- Tensorflow入门(MNIST学习)
- TensorFLow学习(一)---原生Windows安装TensorFlow,进行MNIST机器学习
- tensorflow mnist实战笔记(一)了解官方mnist数据格式
- Tensorflow入门二 mnist识别(一)
- tensorflow入门 ubuntu mnist cnn例程
- TensorFlow笔记之MNIST例程详解
- TensorFlow学习(二),深入MNIST
- tensorflow学习笔记(四):mnist
- TensorFlow个人学习(训练 MNIST 数据 )
- TensorFlow学习笔记(二)MNIST入门
- Tensorflow学习:MNIST 识别
- TensorFlow-mnist-学习
- CodeForces
- php 数组创建方法
- java多线程
- 【csapp】【微软面试题】有符号数到无符号数隐式转换
- Spring mvc 注解使用
- TensorFlow学习(一)MNIST数据库例程执行
- MIT18.06线性代数课程笔记5:矩阵转置,vector space以及subspace
- 树形结构的查找(二叉排序树-创建、查找、插入、删除)
- python 利用random生成验证码与MD5码加密过程
- Leetcode121.+Leetcode53. Kadane算法解决最大子数组问题
- java工程结构管理
- Python 线程,独立的线程空间(threading.local())
- Android移动开发-使用HttpClient访问被保护资源的实现
- java设计模式之单例模式