Ubuntu下Tensorflow加载MNIST数据集(数据下载和读取)

来源:互联网 发布:剑灵不知火舞捏脸数据 编辑:程序博客网 时间:2024/06/05 04:56

作为小白初入茅庐,根据tensorflow中文版的参考文档,以及自己平时学习的体会,总结了在ubuntu下进行深度学习的过程。

作为每个学习深度学习的人来说,测试手写数字的训练应该是每个人都会经历的过程。这里我先将自己在学习过程中参考的一些paper在这里一一列出,大家可以先参看这些paper以后会有更好的了解。
[ TensorFlow中文社区]

[ MNIST数据库]

[ Softmax函数的基本原理]

[ Keras中文文档社区]

[MNIST函数不同性能分类函数的对比]

好了可以开始啦
此教程是在ubuntu下安装cpu版的tensorflow,用于做简单的深度学习的训练,(当然也可以GPU 加速版的,自己本人曾经配置成功过,现在愿意再提那些伤感的岁月)。
关于如何配置tensorflow有很多相关的教程,这里不再一一叙述,只是想大家推荐一些我自己认为比较好的网址,供大家参考:
[1]:http://blog.csdn.net/zhaoyu106/article/details/52793183
[2]:http://blog.csdn.net/u010789558/article/details/51867648
[3]: http://textminingonline.com/dive-into-tensorflow-part-iii-gtx-1080-ubuntu16-04-cuda8-0-cudnn5-0-tensorflow
[4]:http://m.blog.csdn.net/article/details?id=52658965
[5]:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/g3doc/get_started/os_setup.md#installing-from-sources
[6]:http://www.tensorfly.cn/tfdoc/get_started/os_setup.html
[7]:http://ramhiser.com/2016/01/05/installing-tensorflow-on-an-aws-ec2-instance-with-gpu-support/
[8]:http://blog.csdn.net/u012436149/article/details/52554176
[9]:http://m.blog.csdn.net/article/details?id=51999566

以上的参考文档是笔者在配置tensorflow过程中,所参考的一些博客和官方文档,这里想要提示大家的是在配置的过程中一定要注意版本的兼容性,此处有很多坑,笔者就掉进去过很多次。

笔者这里是基于pip安装的tensorflow,大家可以根据自己的需要来完成安装,也推荐大家基于Anaconda去安装tensorflow,省去了很多附加包的安装。
在linux下测试自己是否安装成功,有时间会更新自己的安装流程。
Tensorflow参考安装流程

在linux控制台下输入python进行测试$ python>>> import tensorflow as tf>>> hello = tf.constant('Hello, TensorFlow!')>>> sess = tf.Session()>>> print sess.run(hello)Hello, TensorFlow!>>> a = tf.constant(10)>>> b = tf.constant(32)>>> print sess.run(a+b)42

linux控制台下测试


一、获取mnist数据

首先可以从前面给出的MNIST数据库来获取手写数字分类识别,为了快速导入数据集,可以翻墙或者直接下载四个数据集包放在自己定义的文件中,tensorflow提供专门用于下载mnist数据的函数input_data.py文件,其直接包含在tensorflow中,无需下载安装。

import tensorflow.examples.tutorials.mnist.input_datamnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

执行完以后,会在自己的新建的文件中自动生成一个MNIST_data文件夹
这里写图片描述

之后将下载的四个压缩文件放到该文件夹下

文件 内容 t10k-images-idx3-ubyte.gz 测试集图片 10000张图片 t10k-labels-idx1-ubyte.gz 测试集图片对应的数字标签 train-images-idx3-ubyte.gz 训练集图片 55000张训练图片,其中包含5000张验证图片 train-labels-idx1-ubyte.gz 训练集图片对应的数字标签

input_data文件会调用一个maybe_download函数,确保数据下载成功。这个函数还会判断数据是否已经下载,如果已经下载好了,就不再重复下载。

下载下来的数据集被分三个子集:5.5W行的训练数据集(mnist.train),5千行的验证数据集(mnist.validation)和1W行的测试数据集(mnist.test)。因为每张图片为28x28的黑白图片,所以每行为784维的向量。

每个子集都由两部分组成:图片部分(images)和标签部分(labels), 我们可以用下面的代码来查看 :

print(mnist.train.images.shape)print mnist.train.labels.shapeprint mnist.validation.images.shapeprint mnist.validation.labels.shapeprint mnist.test.images.shapeprint mnist.test.labels.shape

执行的结果是(笔者是在jupyter notebook中实现测试,具体如何使用jupyter notebook以后会进行更新)

Extracting MNIST_data/train-images-idx3-ubyte.gzExtracting MNIST_data/train-labels-idx1-ubyte.gzExtracting MNIST_data/t10k-images-idx3-ubyte.gzExtracting MNIST_data/t10k-labels-idx1-ubyte.gz(55000, 784)(55000, 10)(5000, 784)(5000, 10)(10000, 784)(10000, 10)

同时,在进行深度学习的过程中,也会用到一些公认的测试数据,如CVS数据,可以直接导入:

import tensorflow.contrib.learn.python.learn.datasets.base as baseiris_data,iris_label=base.load_iris()house_data,house_label=base.load_boston()

如cifar10数据:

import tensorflow.models.image.cifar10.cifar10 as cifar10cifar10.maybe_download_and_extract()images, labels = cifar10.distorted_inputs()print imagesprint labels

[1]http://colah.github.io/posts/2014-10-Visualizing-MNIST/
[2]http://colah.github.io/posts/2015-08-Backprop/
[3]http://colah.github.io/posts/2015-09-Visual-Information/
[4] http://math.stackexchange.com/
[5] https://github.com/jmcmanus/pagedown-extra
[6] http://meta.math.stackexchange.com/questions/5020/mathjax-basic-tutorial-and-quick-reference
[7] http://bramp.github.io/js-sequence-diagrams/
[8] http://adrai.github.io/flowchart.js/
[9] https://github.com/benweet/stackedit

原创粉丝点击