TensorFlow tutorials (8): An introduction to the MNIST handwritten digit dataset

Source: Internet  Editor: 程序博客网  Time: 2024/06/08 08:58


Notice: All rights reserved. To repost, please contact the author and credit the source: http://blog.csdn.net/u013719780?viewmode=contents


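Machine learning experiments usually start from a standard dataset, so that the results can be compared with those obtained by other algorithms. In image classification, MNIST is one such benchmark dataset. The previous posts in this series all used MNIST; this post gives a brief introduction to it.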


MNIST

In [1]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
%matplotlib inline
print ("packs loaded")
packs loaded

Download and Extract MNIST dataset

In [2]:
print ("Download and Extract MNIST dataset")
mnist = input_data.read_data_sets('/tmp/data/', one_hot=True)
print ("")
print (" type of 'mnist' is %s" % (type(mnist)))
print (" number of train data is %d" % (mnist.train.num_examples))
print (" number of test data is %d" % (mnist.test.num_examples))

Download and Extract MNIST dataset
Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz

 type of 'mnist' is <class 'collections.Datasets'>
 number of train data is 55000
 number of test data is 10000
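The four files extracted above are gzip-compressed IDX files. As a rough illustration of what a loader has to do with them, here is a minimal NumPy sketch of an IDX parser. It is not TensorFlow's own reader: `parse_idx` and `load_idx_gz` are names made up for this example, and the final flattening and scaling to floats in [0, 1] is an assumption about how the pixel values are prepared. The self-check runs on a tiny synthetic file rather than on real MNIST data.

```python
import gzip
import struct

import numpy as np


def parse_idx(raw):
    """Parse the bytes of an already decompressed IDX file into an ndarray."""
    # Header: two zero bytes, a dtype code (0x08 = unsigned byte),
    # and the number of dimensions.
    zero, dtype_code, ndim = struct.unpack(">HBB", raw[:4])
    if zero != 0 or dtype_code != 0x08:
        raise ValueError("not an unsigned-byte IDX file")
    # Each dimension size is stored as a big-endian 32-bit integer.
    shape = struct.unpack(">" + "I" * ndim, raw[4:4 + 4 * ndim])
    data = np.frombuffer(raw, dtype=np.uint8, offset=4 + 4 * ndim)
    return data.reshape(shape)


def load_idx_gz(path):
    """Hypothetical helper: read a file such as train-images-idx3-ubyte.gz."""
    with gzip.open(path, "rb") as f:
        return parse_idx(f.read())


# Self-check on a synthetic "image file": 2 images of 3x3 pixels.
pixels = np.arange(18, dtype=np.uint8)
fake_file = struct.pack(">HBBIII", 0, 0x08, 3, 2, 3, 3) + pixels.tobytes()
images = parse_idx(fake_file)
print(images.shape)

# Assumed preprocessing: flatten each image and scale pixels to [0, 1].
flat = images.reshape(len(images), -1).astype(np.float32) / 255.0
print(flat.shape, flat.dtype)
```

On the real files, the training image file holds 60000 images. `read_data_sets` reports 55000 because, by default, it sets 5000 of them aside as `mnist.validation`.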
In [3]:
# What does the data of MNIST look like?
print ("What does the data of MNIST look like?")
trainimg   = mnist.train.images
trainlabel = mnist.train.labels
testimg    = mnist.test.images
testlabel  = mnist.test.labels
print ("")
print (" type of 'trainimg' is %s"    % (type(trainimg)))
print (" type of 'trainlabel' is %s"  % (type(trainlabel)))
print (" type of 'testimg' is %s"     % (type(testimg)))
print (" type of 'testlabel' is %s"   % (type(testlabel)))
print (" shape of 'trainimg' is %s"   % (trainimg.shape,))
print (" shape of 'trainlabel' is %s" % (trainlabel.shape,))
print (" shape of 'testimg' is %s"    % (testimg.shape,))
print (" shape of 'testlabel' is %s"  % (testlabel.shape,))

What does the data of MNIST look like?

 type of 'trainimg' is <type 'numpy.ndarray'>
 type of 'trainlabel' is <type 'numpy.ndarray'>
 type of 'testimg' is <type 'numpy.ndarray'>
 type of 'testlabel' is <type 'numpy.ndarray'>
 shape of 'trainimg' is (55000, 784)
 shape of 'trainlabel' is (55000, 10)
 shape of 'testimg' is (10000, 784)
 shape of 'testlabel' is (10000, 10)
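The shapes above follow from two encodings: each 28 x 28 image is flattened into a row of 784 values, and with `one_hot=True` each label becomes a row of 10 values with a single 1 at the position of the digit. The sketch below illustrates both on small synthetic arrays that merely stand in for `mnist.train.images` and `mnist.train.labels`; the variable names and the random pixel values are assumptions made for the example.

```python
import numpy as np

# Synthetic stand-ins for the MNIST arrays (same layout, far fewer rows).
rng = np.random.RandomState(0)
n = 4
images = rng.rand(n, 784).astype(np.float32)   # each row: a flattened 28x28 image
digits = np.array([3, 0, 9, 3])                # integer class labels
labels = np.eye(10, dtype=np.float32)[digits]  # one-hot rows, shape (n, 10)

# Undo the flattening: row i becomes a 28x28 matrix.
grid = images.reshape(-1, 28, 28)
# Undo the one-hot encoding: the position of the single 1 is the digit.
decoded = labels.argmax(axis=1)

print(grid.shape)
print(labels[0])
print(decoded)
```

This is the same pair of operations (`np.reshape` to 28 x 28 and `np.argmax` on the label row) that the plotting cell in this post applies to one sample at a time.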
In [4]:
# What does the training data look like?
print ("What does the training data look like?")
nsample = 5
randidx = np.random.randint(trainimg.shape[0], size=nsample)
for i in randidx:
    curr_img   = np.reshape(trainimg[i, :], (28, 28)) # 28 by 28 matrix
    curr_label = np.argmax(trainlabel[i, :]) # Label
    plt.matshow(curr_img, cmap=plt.get_cmap('gray'))
    plt.title("" + str(i) + "th Training Data "
              + "Label is " + str(curr_label))
    print ("" + str(i) + "th Training Data "
           + "Label is " + str(curr_label))

What does the training data look like?
12118th Training Data Label is 5
46324th Training Data Label is 8
33th Training Data Label is 4
36491th Training Data Label is 3
6910th Training Data Label is 3
In [5]:
# Batch Learning?
print ("Batch Learning? ")
batch_size = 100
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
print ("type of 'batch_xs' is %s" % (type(batch_xs)))
print ("type of 'batch_ys' is %s" % (type(batch_ys)))
print ("shape of 'batch_xs' is %s" % (batch_xs.shape,))
print ("shape of 'batch_ys' is %s" % (batch_ys.shape,))

Batch Learning? 
type of 'batch_xs' is <type 'numpy.ndarray'>
type of 'batch_ys' is <type 'numpy.ndarray'>
shape of 'batch_xs' is (100, 784)
shape of 'batch_ys' is (100, 10)
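To make the idea of drawing consecutive mini-batches more concrete, here is a minimal sketch of an epoch-based batch iterator. It is not the implementation behind `mnist.train.next_batch`; `BatchIterator` and its attributes are hypothetical names, and the policy of skipping leftover samples at the end of an epoch is a simplifying assumption.

```python
import numpy as np


class BatchIterator(object):
    """Hypothetical helper: serves shuffled mini-batches, reshuffling each epoch."""

    def __init__(self, images, labels, seed=0):
        assert len(images) == len(labels)
        self.images, self.labels = images, labels
        self.rng = np.random.RandomState(seed)
        self.order = self.rng.permutation(len(images))
        self.pos = 0
        self.epochs_completed = 0

    def next_batch(self, batch_size):
        # Too few unseen samples left: skip the leftovers, reshuffle,
        # and start a new epoch.
        if self.pos + batch_size > len(self.order):
            self.epochs_completed += 1
            self.order = self.rng.permutation(len(self.images))
            self.pos = 0
        idx = self.order[self.pos:self.pos + batch_size]
        self.pos += batch_size
        return self.images[idx], self.labels[idx]


# Toy data: 10 samples with 784 features each and one-hot labels.
x = np.arange(10, dtype=np.float32).reshape(10, 1) * np.ones((1, 784), dtype=np.float32)
y = np.eye(10, dtype=np.float32)
it = BatchIterator(x, y)
xs, ys = it.next_batch(5)
xs2, ys2 = it.next_batch(5)
print(xs.shape, ys.shape)

# Within one epoch every sample is served exactly once.
seen = np.concatenate([ys.argmax(axis=1), ys2.argmax(axis=1)])
print(sorted(seen.tolist()))
```

Because the order is a permutation, no sample repeats within an epoch, which is the main difference from drawing indices independently at random.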
In [6]:
# Get Random Batch with 'np.random.randint'
print ("5. Get Random Batch with 'np.random.randint'")
randidx   = np.random.randint(trainimg.shape[0], size=batch_size)
batch_xs2 = trainimg[randidx, :]
batch_ys2 = trainlabel[randidx, :]
print ("type of 'batch_xs2' is %s" % (type(batch_xs2)))
print ("type of 'batch_ys2' is %s" % (type(batch_ys2)))
print ("shape of 'batch_xs2' is %s" % (batch_xs2.shape,))
print ("shape of 'batch_ys2' is %s" % (batch_ys2.shape,))

5. Get Random Batch with 'np.random.randint'
type of 'batch_xs2' is <type 'numpy.ndarray'>
type of 'batch_ys2' is <type 'numpy.ndarray'>
shape of 'batch_xs2' is (100, 784)
shape of 'batch_ys2' is (100, 10)
In [7]:
randidx
Out[7]:
array([51472, 13751, 33562, 23281,  8489, 48481,  7799, 30307, 37366,
       25312, 46149, 49712,  5083, 52853, 29819, 36444, 34829,  8769,
       39518, 54911,  6720, 43675, 41703, 35594,  9300, 14474, 33318,
       14808, 53456, 41978,  8047, 34524, 30978, 53455, 42119, 22660,
       30329, 27169, 53798,  2125, 41759, 38951,  1438, 33511, 38784,
       15822, 16785,  9229,  1216, 19569,  3116, 22172, 14766, 16153,
        1707, 20899,  9087, 21263, 24853, 27784, 38324, 29287, 21828,
       34511, 26340, 39194, 38272, 34238, 28050, 29294, 42672, 18696,
       17796, 48147, 41841, 47077,  5925, 48237, 30605,  9169, 11260,
        9155, 39346, 41049, 11342,   536,  5927, 11155, 40424, 33583,
       38991, 16569, 34801,   870, 20546, 25061, 17601,  4521, 24359,  4613])
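One property of `np.random.randint` is worth keeping in mind: each index is drawn independently, so the same training example can appear more than once in a batch. The sketch below contrasts this with `choice(..., replace=False)`, which draws distinct indices. The seed and variable names are assumptions for the example, and the sizes are taken from the cells above.

```python
import numpy as np

rng = np.random.RandomState(0)
n_train, batch_size = 55000, 100

# Same call pattern as in the cell above: independent draws, repeats possible.
with_repl = rng.randint(n_train, size=batch_size)
# Alternative: draw distinct indices.
without_repl = rng.choice(n_train, size=batch_size, replace=False)

print(len(np.unique(with_repl)), "distinct indices out of", batch_size)
print(len(np.unique(without_repl)), "distinct indices out of", batch_size)
```

With 100 draws out of 55000, repeats are uncommon but not impossible, so either approach is usually acceptable for a quick experiment.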
