TF-day1 MINIST识别数字
来源:互联网 发布:男士双肩背包推荐 知乎 编辑:程序博客网 时间:2024/06/16 13:43
当我们开始学习编程的时候,第一件事往往是学习打印”Hello World”。就好比编程入门有Hello World,机器学习入门有MNIST
主要步骤
- 获取数据
- 建立模型
- 定义 tensor,variable:X,W,b
- 定义损失函数,优化器:cross-entropy,gradient descent
- 训练模型:loop,batch
- 评价:准确率
一.获取MINIST 数据集
- 数据来自于http://yann.lecun.com/exdb/mnist/
- 数据分为train, validate, test三部分
from tensorflow.examples.tutorials.mnist import input_dataminist = input_data.read_data_sets("MINIST_data/",one_hot=True)print(minist)##Datasets(train=<tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet object at 0x7f02cbc0cd30>, validation=<tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet object at 0x7f02df872518>, test=<tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet object at 0x7f02e1188a90>)
print(np.shape(minist.test.images))print(np.shape(minist.test.labels))print(np.shape(minist.train.images))print(np.shape(minist.train.labels))print(np.shape(minist.validation.images))print(np.shape(minist.validation.labels))###对应数据集(10000, 784)(10000, 10)(55000, 784)(55000, 10)(5000, 784)(5000, 10)
- 可以看到训练数据是(10000,784) ,对应的标签是(10000,10)
每一张图片包含28*28个像素点,故用一个数字数组来表示这张图片,我们把这个数组展开成一个向量,长度是 28x28 = 784. - 相对应的MNIST数据集的标签是介于0到9的数字,标签数据是”one-hot vectors”。 一个one-hot向量除了某一位的数字是1以外其余各维度数字都是0。比如,标签0将表示成([1,0,0,0,0,0,0,0,0,0,0])
目标:给了 X 后,预测它的 label 是属于 0~9 类中的哪一类
如果想要看数据属于多类中的哪一类,首先可以想到用 softmax 来做。
二.建立模型
softmax regression 有两步:
- 把 input 转化为某类的 evidence
把 evidence 转化为 probabilities
1. 把 input 转化为某类的 evidence
- 某一类的 evidence 就是像素强度的加权求和,再加上此类的 bias。
- 如果某个 pixel 可以作为一个 evidence证明图片不属于此类,则 weight 为负,否则的话 weight 为正。 下图中,红色代表负值,蓝色代表正值:
2. 把 evidence 转化为 probabilities
简单看,softmax 就是把 input 先做指数,再做一下归一:
归一的作用:好理解,就是转化成概率的性质
为什么要取指数:《常用激活函数比较》http://www.jianshu.com/p/22d9720dbf1a
第一个原因是要模拟 max 的行为,所以要让大的更大。
第二个原因是需要一个可导的函数。
用图片表示:
用公式表示:
代码表示:
##实现回归模型y = tf.nn.softmax(tf.matmul(x,w) + b)
三.定义 tensor 和 variable:
x = tf.placeholder(tf.float32,[None,784])w = tf.Variable(tf.zeros([784,10]))b = tf.Variable(tf.zeros([10]))
三.定义损失函数\优化器
这里采用成本函数是“交叉熵”(cross-entropy)。
y 是预测的概率分布, y’ 是实际的分布(我们输入的one-hot vector)。
##训练模型y_ = tf.placeholder("float",[None,10])cross_entropy = -tf.reduce_sum(y_*tf.log(y))
然后用 backpropagation, 且 gradient descent 作为优化器,来训练模型,使得 loss 达到最小:
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
四.训练模型
for _ in range(1000): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
五.评价模型
###评估模型correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))print(sess.run(accuracy,feed_dict={x:minist.test.images,y_: minist.test.labels}))
完整代码:
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.# =============================================================================="""Functions for downloading and reading MNIST data."""from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport gzipimport osimport tempfileimport numpy as npfrom six.moves import urllibfrom six.moves import xrange # pylint: disable=redefined-builtinimport tensorflow as tffrom tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets##导入数据from tensorflow.examples.tutorials.mnist import input_dataminist = input_data.read_data_sets("MINIST_data/",one_hot=True)##实现回归模型x = tf.placeholder(tf.float32,[None,784])w = tf.Variable(tf.zeros([784,10]))b = tf.Variable(tf.zeros([10]))y = tf.nn.softmax(tf.matmul(x,w) + b)##训练模型y_ = tf.placeholder("float",[None,10])cross_entropy = -tf.reduce_sum(y_*tf.log(y))train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)init = tf.initialize_all_variables()sess = tf.Session()sess.run(init)for i in range(1000): batch_xs, batch_ys = minist.train.next_batch(100) sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})###评估模型correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))print(sess.run(accuracy,feed_dict={x:minist.test.images,y_: minist.test.labels}))
- TF-day1 MINIST识别数字
- python tensorflow 使用minist数据集实现手写数字识别
- TF-day3 mnist识别数字
- keras---minist手写识别
- tensorflow08 《TensorFlow实战Google深度学习框架》笔记-05-01minist数字识别问题code
- libsvm Minist Hog 手写体识别
- libsvm Minist Hog 手写体识别
- KNN(k-NearestNeighbor)识别minist数据集
- deeplearning----利用逻辑回归分类MINIST数字
- 用CNN及MLP等方法识别minist数据集
- 使用CNN实现手写体识别(minist库)
- TF卡经常不能识别
- tf-Mnist手写字体识别
- 数字识别
- 数字识别
- 数字识别
- 数字识别
- Tensorflow MINIST数据模型的训练,保存,恢复和手写字体识别
- Android Wi-Fi 5G Only时Wi-Fi不可用
- [GNU/Linux] 自己实现ls
- 822C
- Storm初识
- python学习笔记(十三)标准库heapq
- TF-day1 MINIST识别数字
- 单点登录-CAS介绍
- IO流学习小结
- 分页的快捷方式
- 【JS】【个人学习小记】输出当天为星期几的快速方法
- react 如何获取节点内容
- CCS中如何新建Platform以及调用
- 低功耗蓝牙cc2541学习笔记之UART-3-协议栈uart发送 实验
- 分享一个Linux性能调优/诊断网站