TensorFlow tutorials (8): Introduction to the MNIST handwritten digit dataset

Source: Internet. Editor: 程序博客网. Time: 2024/06/08 08:58

Notice: All rights reserved. To repost, please contact the author and credit the source: http://blog.csdn.net/u013719780?viewmode=contents



In [1]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data  # TensorFlow 1.x only; not shipped with TF 2.x
%matplotlib inline
print ("packs loaded")
packs loaded

Download and Extract MNIST dataset

In [2]:
print ("Download and Extract MNIST dataset")
# Downloads the four .gz files into /tmp/data/ if they are missing, then loads them;
# one_hot=True stores each label as a length-10 one-hot vector
mnist = input_data.read_data_sets('/tmp/data/', one_hot=True)
print ("")
print (" type of 'mnist' is %s" % (type(mnist)))
print (" number of train data is %d" % (mnist.train.num_examples))
print (" number of test data is %d" % (mnist.test.num_examples))
Download and Extract MNIST dataset
Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
 type of 'mnist' is <class 'collections.Datasets'>
 number of train data is 55000
 number of test data is 10000
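MNIST ships 60,000 training images, yet the loader reports 55,000. As far as I know, `read_data_sets` holds out the first 5,000 training images as a validation split by default (`validation_size=5000`, exposed as `mnist.validation`). With its default `float32` dtype it also rescales pixels from 0..255 to [0, 1].

`tensorflow.examples.tutorials.mnist` is not part of TensorFlow 2.x. There, `tf.keras.datasets.mnist.load_data()` returns the raw 60,000 uint8 images of shape (28, 28) with integer labels. The helper below, `to_tutorial_layout`, is hypothetical and not a TensorFlow function. It is a sketch of converting such raw arrays into the same layout (flattened, scaled, one-hot, validation held out). Synthetic data stands in for the real download.

```python
import numpy as np

def to_tutorial_layout(images, labels, validation_size=5000, num_classes=10):
    """Hypothetical helper (illustrative sketch, not TensorFlow code).

    images: uint8 array of shape (N, 28, 28); labels: int array of shape (N,).
    Returns (train_x, train_y), (val_x, val_y) in the tutorial's layout.
    """
    # Flatten each 28x28 image into one 784-long row, then scale 0..255 to [0, 1].
    x = images.reshape(len(images), -1).astype(np.float32)
    x /= 255.0
    # One-hot encode: row i has a 1 in column labels[i] and 0 elsewhere.
    y = np.eye(num_classes)[labels]
    # Hold out the first `validation_size` examples as a validation split.
    validation = (x[:validation_size], y[:validation_size])
    train = (x[validation_size:], y[validation_size:])
    return train, validation

# Synthetic stand-in for the raw data; with the real 60000 training images and
# validation_size=5000, the training split would have 55000 rows.
rng = np.random.default_rng(0)
raw_images = rng.integers(0, 256, size=(1000, 28, 28), dtype=np.uint8)
raw_labels = rng.integers(0, 10, size=1000)
(train_x, train_y), (val_x, val_y) = to_tutorial_layout(raw_images, raw_labels,
                                                        validation_size=100)
print(train_x.shape, train_y.shape)  # (900, 784) (900, 10)
```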
In [3]:
# What does the data of MNIST look like?
print ("What does the data of MNIST look like?")
trainimg   = mnist.train.images   # (num_examples, 784): one flattened image per row
trainlabel = mnist.train.labels   # (num_examples, 10): one-hot labels
testimg    = mnist.test.images
testlabel  = mnist.test.labels
print ("")
print (" type of 'trainimg' is %s"    % (type(trainimg)))
print (" type of 'trainlabel' is %s"  % (type(trainlabel)))
print (" type of 'testimg' is %s"     % (type(testimg)))
print (" type of 'testlabel' is %s"   % (type(testlabel)))
print (" shape of 'trainimg' is %s"   % (trainimg.shape,))
print (" shape of 'trainlabel' is %s" % (trainlabel.shape,))
print (" shape of 'testimg' is %s"    % (testimg.shape,))
print (" shape of 'testlabel' is %s"  % (testlabel.shape,))
What does the data of MNIST look like?
 type of 'trainimg' is <type 'numpy.ndarray'>
 type of 'trainlabel' is <type 'numpy.ndarray'>
 type of 'testimg' is <type 'numpy.ndarray'>
 type of 'testlabel' is <type 'numpy.ndarray'>
 shape of 'trainimg' is (55000, 784)
 shape of 'trainlabel' is (55000, 10)
 shape of 'testimg' is (10000, 784)
 shape of 'testlabel' is (10000, 10)
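These shapes mean each image is stored as one flat row of 784 = 28 x 28 pixel values. Each label is a length-10 one-hot vector, because `one_hot=True` was passed. The short NumPy sketch below shows both conventions and how to undo them. It uses made-up labels and a stand-in image row rather than real MNIST data.

```python
import numpy as np

# One-hot labels: digit d becomes a length-10 vector with a 1 at index d.
digits = np.array([5, 0, 4, 1, 9])          # made-up example labels
one_hot = np.eye(10)[digits]                # shape (5, 10)
print(one_hot[0])                           # [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
decoded = np.argmax(one_hot, axis=1)        # back to array([5, 0, 4, 1, 9])

# Flattened images: 784 values in row-major order, so pixel (r, c) of the
# 28x28 image sits at position 28*r + c of the flat row.
row = np.arange(784, dtype=np.float32)      # stand-in for one image row
img = row.reshape(28, 28)
flat_again = img.reshape(-1)                # the same 784 values, same order
```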
In [4]:
# What does the training data look like?
print ("What does the training data look like?")
nsample = 5
randidx = np.random.randint(trainimg.shape[0], size=nsample)  # 5 random row indices
for i in randidx:
    curr_img   = np.reshape(trainimg[i, :], (28, 28))  # 784-vector -> 28 by 28 matrix
    curr_label = np.argmax(trainlabel[i, :])            # one-hot vector -> digit label
    plt.matshow(curr_img, cmap=plt.get_cmap('gray'))
    plt.title("" + str(i) + "th Training Data "
              + "Label is " + str(curr_label))
    print ("" + str(i) + "th Training Data "
           + "Label is " + str(curr_label))
What does the training data look like?
12118th Training Data Label is 5
46324th Training Data Label is 8
33th Training Data Label is 4
36491th Training Data Label is 3
6910th Training Data Label is 3
[Figure: the sampled training digits rendered in grayscale with plt.matshow, each titled with its index and label as printed above]
In [5]:
# Batch Learning?
print ("Batch Learning? ")
batch_size = 100
batch_xs, batch_ys = mnist.train.next_batch(batch_size)  # next 100 training examples and labels
print ("type of 'batch_xs' is %s" % (type(batch_xs)))
print ("type of 'batch_ys' is %s" % (type(batch_ys)))
print ("shape of 'batch_xs' is %s" % (batch_xs.shape,))
print ("shape of 'batch_ys' is %s" % (batch_ys.shape,))
Batch Learning? 
type of 'batch_xs' is <type 'numpy.ndarray'>
type of 'batch_ys' is <type 'numpy.ndarray'>
shape of 'batch_xs' is (100, 784)
shape of 'batch_ys' is (100, 10)
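`next_batch` walks through the training set in batches rather than drawing each batch independently. As I understand the TF1 `DataSet` class, it shuffles the examples at the start of each epoch and then hands out consecutive slices. Every example is therefore seen once per epoch, and a batch that crosses an epoch boundary mixes the tail of the old order with the head of a new shuffle.

The class below, `EpochBatcher`, is a hypothetical and simplified sketch of that behaviour, not TensorFlow's code. It runs on toy data.

```python
import numpy as np

class EpochBatcher:
    """Hypothetical sketch of epoch-based mini-batching, loosely modelled on
    the TF1 DataSet.next_batch behaviour; not TensorFlow's implementation."""

    def __init__(self, images, labels, seed=None):
        assert len(images) == len(labels)
        self.images, self.labels = images, labels
        self.n = len(images)
        self.rng = np.random.default_rng(seed)
        self.perm = self.rng.permutation(self.n)  # visiting order for this epoch
        self.pos = 0                              # next unused position in self.perm
        self.epochs_completed = 0

    def next_batch(self, batch_size):
        idx = []
        while len(idx) < batch_size:
            if self.pos == self.n:                # epoch exhausted: reshuffle, start over
                self.perm = self.rng.permutation(self.n)
                self.pos = 0
                self.epochs_completed += 1
            take = min(batch_size - len(idx), self.n - self.pos)
            idx.extend(self.perm[self.pos:self.pos + take].tolist())
            self.pos += take
        idx = np.array(idx, dtype=np.int64)
        return self.images[idx], self.labels[idx]

# Toy data: 10 "images", each labelled (one-hot) with its own index.
x = np.arange(10 * 784, dtype=np.float32).reshape(10, 784)
y = np.eye(10)
batcher = EpochBatcher(x, y, seed=0)
seen = np.concatenate([batcher.next_batch(5)[1].argmax(axis=1) for _ in range(2)])
print(sorted(seen.tolist()))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

The two batches of 5 together cover every toy example exactly once, which is the property that distinguishes this scheme from independent random draws.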
In [6]:
# Get Random Batch with 'np.random.randint'
print ("Get Random Batch with 'np.random.randint'")
# randint draws each index independently (with replacement), so an index may repeat
randidx   = np.random.randint(trainimg.shape[0], size=batch_size)
batch_xs2 = trainimg[randidx, :]    # rows of the chosen examples
batch_ys2 = trainlabel[randidx, :]  # their matching one-hot labels
print ("type of 'batch_xs2' is %s" % (type(batch_xs2)))
print ("type of 'batch_ys2' is %s" % (type(batch_ys2)))
print ("shape of 'batch_xs2' is %s" % (batch_xs2.shape,))
print ("shape of 'batch_ys2' is %s" % (batch_ys2.shape,))
Get Random Batch with 'np.random.randint'
type of 'batch_xs2' is <type 'numpy.ndarray'>
type of 'batch_ys2' is <type 'numpy.ndarray'>
shape of 'batch_xs2' is (100, 784)
shape of 'batch_ys2' is (100, 10)
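The cell above samples with `np.random.randint`, i.e. with replacement, so one batch can contain the same example twice. When distinct examples are wanted, sampling without replacement is a common alternative.

The helper below, `random_batch`, is a hypothetical name and not a TensorFlow API. It is a small sketch built on NumPy's `Generator.choice`. Synthetic arrays stand in for `trainimg` and `trainlabel`.

```python
import numpy as np

def random_batch(images, labels, batch_size, replace=True, rng=None):
    """Illustrative helper (not a TensorFlow API): draw a random mini-batch.

    replace=True behaves like np.random.randint (an index may repeat);
    replace=False returns batch_size distinct examples.
    """
    rng = np.random.default_rng() if rng is None else rng
    idx = rng.choice(len(images), size=batch_size, replace=replace)
    # Fancy indexing copies the selected rows and keeps images and labels aligned.
    return images[idx], labels[idx], idx

# Synthetic stand-in: 1000 flattened images with one-hot labels.
rng = np.random.default_rng(1)
images = rng.random((1000, 784), dtype=np.float32)
labels = np.eye(10)[rng.integers(0, 10, size=1000)]
xb, yb, idx = random_batch(images, labels, 100, replace=False, rng=rng)
print(xb.shape, yb.shape, len(np.unique(idx)))  # (100, 784) (100, 10) 100
```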
In [7]:
array([51472, 13751, 33562, 23281,  8489, 48481,  7799, 30307, 37366,
       25312, 46149, 49712,  5083, 52853, 29819, 36444, 34829,  8769,
       39518, 54911,  6720, 43675, 41703, 35594,  9300, 14474, 33318,
       14808, 53456, 41978,  8047, 34524, 30978, 53455, 42119, 22660,
       30329, 27169, 53798,  2125, 41759, 38951,  1438, 33511, 38784,
       15822, 16785,  9229,  1216, 19569,  3116, 22172, 14766, 16153,
        1707, 20899,  9087, 21263, 24853, 27784, 38324, 29287, 21828,
       34511, 26340, 39194, 38272, 34238, 28050, 29294, 42672, 18696,
       17796, 48147, 41841, 47077,  5925, 48237, 30605,  9169, 11260,
        9155, 39346, 41049, 11342,   536,  5927, 11155, 40424, 33583,
       38991, 16569, 34801,   870, 20546, 25061, 17601,  4521, 24359,  4613])
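Assume the 100 indices above are the `randidx` drawn with `np.random.randint`. Because those draws are made with replacement, some index could appear twice; `np.unique(randidx).size < 100` would reveal it.

The chance follows the birthday problem: P(repeat) = 1 - prod_{i=0}^{k-1} (1 - i/n), which is approximately 1 - exp(-k(k-1)/(2n)). For k = 100 draws from n = 55000 examples, that is roughly 8.6%. The worked calculation below is a sketch; `prob_any_repeat` is a hypothetical helper name.

```python
import math

def prob_any_repeat(n, k):
    """Probability that k draws with replacement from n items contain a repeat."""
    p_all_distinct = 1.0
    for i in range(k):
        p_all_distinct *= (n - i) / n   # i-th draw avoids the i indices already taken
    return 1.0 - p_all_distinct

p = prob_any_repeat(55000, 100)
approx = 1.0 - math.exp(-100 * 99 / (2 * 55000))  # birthday-problem approximation
print(round(p, 3), round(approx, 3))              # 0.086 0.086
```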
