机器学习入门--MNIST(一)
来源:互联网 发布:java try的用法 编辑:程序博客网 时间:2024/05/29 18:26
最近在看机器学习TensorFlow,就像其他任何一门语言(当然机器学习不仅仅是语言)都有一个"hello world",可以说MNIST是机器学习的"hello world"。
极客学院有TensorFlow的官方文档(中文版),里面对MNIST做了详细的介绍,比如模型的建立、原理等等(网址:http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_beginners.html)。
我在根据这个文档做成Demo的时候,遇到了一些坑,特整理记录一下,为后面掉坑里同仁提供一些参考。
其实,总结一下,MNIST Demo做成需要如下几个步骤:
1.MINST数据集下载
下载网址:http://yann.lecun.com/exdb/mnist/
下载四个压缩包:
train-images-idx3-ubyte.gz: training set images (9912422 bytes) train-labels-idx1-ubyte.gz: training set labels (28881 bytes) t10k-images-idx3-ubyte.gz: test set images (1648877 bytes) t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)
我上传了下载好的压缩包,地址:http://download.csdn.net/download/sarsscofy/10130439
不知道为啥资源分不能设置成0 T_T,如果可以上官网还是去官网下吧.
2.下载自动安装和下载MINST数据集的Python代码,并存储成文件
1)网址:https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/tutorials/mnist/input_data.py
具体Python代码如下:
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.# =============================================================================="""Functions for downloading and reading MNIST data."""from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport gzipimport osimport tempfileimport numpyfrom six.moves import urllibfrom six.moves import xrange # pylint: disable=redefined-builtinimport tensorflow as tffrom tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
2)新建一个文件,直接拷贝代码上述路径下打开的Python代码,并存储成input_data.py,注意存储的格式为utf-8。
3.由于后续文件中还需要引用读取下载的文件和代码,所以上述内容最好放在一个目录下(即可以看成是一个Module),并分类。
下图是我的结构:
其中:
MNIST_data文件夹 存放的是第一步中下载的四个MNIST数据集文件。
__init__.py 是空文件。
input_data.py 是自动安装和下载MINST数据集的Python代码
mnist_softmax1.py 是主程序。
具体的下面说。
4.新建文件,我命名为mnist_softmax1.py,即MNIST用Softmax 回归来进行训练的代码,具体如下:
#!/usr/bin/env python3 # -*- coding: utf-8 -*- import input_dataimport tensorflow as tf#x不是一个特定的值,而是一个占位符placeholder,我们在TensorFlow运行计算时输入这个值。#我们希望能够输入任意数量的MNIST图像,每一张图展平成784维的向量。#我们用2维的浮点数张量来表示这些图,这个张量的形状是[None,784 ]。#(这里的None表示此张量的第一个维度可以是任何长度的。)mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)#权重值,初始值全为0x = tf.placeholder(tf.float32, [None, 784])#偏置量,初始值全为0W = tf.Variable(tf.zeros([784,10]))b = tf.Variable(tf.zeros([10]))#建立模型,y是匹配的概率#tf.matmul(x,W)表示x乘以W#y是预测,y_是实际y = tf.nn.softmax(tf.matmul(x,W) + b)#为计算交叉熵,添加的placeholdery_ = tf.placeholder("float", [None,10])#交叉熵cross_entropy = -tf.reduce_sum(y_*tf.log(y))#用梯度下降算法(gradient descent algorithm)以0.01的学习速率最小化交叉熵train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)#初始化我们创建的变量init = tf.global_variables_initializer()#在Session里面启动模型sess = tf.Session()sess.run(init)#训练模型#循环的每个步骤中,都会随机抓取训练数据中的100个批处理数据点,然后用这些数据点作为参数替换之前的占位符来运行train_step#即:使用的随机梯度下降训练方法for i in range(1000): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})#-------------------模型评估----------------------#判断预测标签和实际标签是否匹配 #tf.argmax 找出某个tensor对象在某一维上的其数据最大值所在的索引值correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))#计算所学习到的模型在测试数据集上面的正确率 print (sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
我是根据官方文档的思路一行代码一行代码拷贝的,建议新学的XDJM也这样,可以理解更深刻些。
需要注意的是:
1)读取MNIST数据集相对路径要跟自己实际定义的存放路径一致:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
”MNIST_data“是存放MNIST数据集的文件夹名称
2)初始化方法不同,我用的TesorFlow是0.12版本,官方文档中的方法”tf.initialize_all_variables()“已废弃,换成”tf.global_variables_initializer()“。
5.cmd打开命令行窗口,进入demo路径,执行主程序即可看到结果。
当然这个结果不是既定的,多运行几次得到的结果可能不一致:
6.FAQ
1)错误 UnicodeDecodeError: 'utf-8' codec can'tdecode byte 0xb2 in position 168: invalidstart byte
具体如下:
Traceback (most recent call last): File "mnist_softmax1.py", line 4, in <module> File "E:\CInfos\tensortflow\testcode\mnist\input_data.py", line 28, in <module> import tensorflow as tf.......UnicodeDecodeError: 'utf-8' codec can't decode byte 0xb2 in position 168: invalid start byte
input_data.py 、mnist_softmax1.py 存储格式为utf-8即可。
(其他的想起来再追加)
7.参考资料
http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_beginners.html
http://blog.csdn.net/willduan1/article/details/52024254
-----转载请说明出处,谢谢~
- MNIST机器学习入门(一)
- 机器学习入门--MNIST(一)
- MNIST 机器学习入门 (一)
- MNIST机器学习入门
- MNIST机器学习入门
- MNIST机器学习入门
- MNIST机器学习入门
- MNIST机器学习入门
- MNIST机器学习入门
- MNIST机器学习入门
- MNIST机器学习入门
- 机器学习入门--MNIST(二)
- tensorflow MNIST机器学习入门
- tensorflow- MNIST机器学习入门
- tensorflow-MNIST机器学习入门
- TensorFLow学习(一),Mnist入门
- <二>、TensorFlow之MNIST机器学习入门(1)
- Tensorflow MNIST机器学习入门 分类学习
- sql常用语句
- MyBatis之自查询使用递归实现 N级联动效果(两种实现方式)
- Python学习_我该怎么使用字典
- 浅谈Spring注解
- LeetCode.55 Jump Game
- 机器学习入门--MNIST(一)
- IBM服务器引导盘serverguide 下载
- Somethings about the coding in Python
- UBoot:ENTRY等宏的展开,CPSR寄存器的设置(stat.S)
- sql排序问题
- Android音频子系统,音频流的回放(四)
- 【Java】【线程同步】sleep,join,yield,synchronized,wait,notify
- HDU 1029
- elasticsearch开发学习