python读取MNIST

来源:互联网 发布:google tensorflow 编辑:程序博客网 时间:2024/05/22 17:11

MNIST的数据结构:
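MNIST的数据结构(原图缺失,以下按代码的解析方式说明):图像文件以 4 个大端 32 位无符号整数开头(magic、图像数、行数、列数),其后每个像素占 1 个无符号字节;标签文件以 2 个大端 32 位无符号整数开头(magic、标签数),其后每个标签占 1 个无符号字节。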

import os
import struct

import numpy as np

# Magic numbers that open the two IDX file kinds used by MNIST.
_IMAGE_MAGIC = 2051  # 0x00000803: unsigned-byte data, 3 dimensions
_LABEL_MAGIC = 2049  # 0x00000801: unsigned-byte data, 1 dimension


def _read_idx_payload(path, header_fmt, expected_magic):
    """Read an IDX file and return ``(header_fields, payload_bytes_view_args)``.

    Returns the unpacked header fields (without the magic number), the raw
    file contents and the byte offset at which the payload starts.

    Raises:
        ValueError: if the file is shorter than its header or the magic
            number is not ``expected_magic``.
    """
    with open(path, 'rb') as fh:
        buf = fh.read()

    header_size = struct.calcsize(header_fmt)
    if len(buf) < header_size:
        raise ValueError('%s: file too short for an IDX header' % path)

    magic, *dims = struct.unpack_from(header_fmt, buf, 0)
    if magic != expected_magic:
        raise ValueError('%s: bad magic number %d (expected %d)'
                         % (path, magic, expected_magic))
    return dims, buf, header_size


def _read_idx_images(path):
    """Return the images in *path* as a ``(numImgs, numRows*numCols)`` array."""
    (num_imgs, num_rows, num_cols), buf, offset = _read_idx_payload(
        path, '>IIII', _IMAGE_MAGIC)

    count = num_imgs * num_rows * num_cols
    if len(buf) - offset < count:
        raise ValueError('%s: truncated, expected %d pixel bytes' % (path, count))

    # frombuffer reads the bytes directly instead of materialising one Python
    # int per pixel.  The cast to the default integer dtype keeps the result
    # identical to what np.array() over a tuple of ints produced before, and
    # also yields a writable copy (frombuffer views are read-only).
    pixels = np.frombuffer(buf, dtype=np.uint8, count=count, offset=offset)
    return pixels.reshape((num_imgs, num_rows * num_cols)).astype(np.int_)


def _read_idx_labels(path, num_classes):
    """Return the labels in *path* one-hot encoded as ``(numLabs, num_classes)``."""
    (num_labs,), buf, offset = _read_idx_payload(path, '>II', _LABEL_MAGIC)

    if len(buf) - offset < num_labs:
        raise ValueError('%s: truncated, expected %d label bytes' % (path, num_labs))

    labs = np.frombuffer(buf, dtype=np.uint8, count=num_labs, offset=offset)
    y = np.zeros((num_labs, num_classes))
    # One fancy-indexed assignment sets column labs[i] of row i to 1.
    y[np.arange(num_labs), labs] = 1
    return y


def read_MNIST_train(data_dir=None, num_classes=10):
    """Load the MNIST training images and labels from IDX files.

    Args:
        data_dir: Directory holding ``train-images.idx3-ubyte`` and
            ``train-labels.idx1-ubyte``.  Defaults to ``./MNIST_data``
            relative to the current working directory.
        num_classes: Width of the one-hot label matrix (default 10).

    Returns:
        A tuple ``(X, y)``: ``X`` is an integer array with one flattened
        image per row, shape ``(numImgs, numRows * numCols)``; ``y`` is a
        float array of shape ``(numLabs, num_classes)`` with a single 1 per
        row at the label's index.

    Raises:
        OSError: if a data file cannot be opened.
        ValueError: if a file has the wrong magic number or is truncated.
        IndexError: if a label value is not smaller than ``num_classes``.
    """
    if data_dir is None:
        data_dir = os.path.join(os.curdir, 'MNIST_data')

    X = _read_idx_images(os.path.join(data_dir, 'train-images.idx3-ubyte'))
    y = _read_idx_labels(os.path.join(data_dir, 'train-labels.idx1-ubyte'),
                         num_classes)
    return X, y


def main():
    """Print the label of one training sample and display its image."""
    # Imported here so the reader above can be used without a plotting backend.
    import matplotlib.pyplot as plt

    X, y = read_MNIST_train()

    # Print the label of the n-th image and show the image itself.
    n = 100
    label = np.where(y[n, :] == 1)[0]
    print(label)

    im = X[n, :].reshape(28, 28)
    fig = plt.figure()
    fig.add_subplot(111)
    plt.imshow(im, cmap='gray')
    plt.show()


if __name__ == '__main__':
    main()

最终得到 60000×784 的图像矩阵 X(每行是一幅展平后的 28×28 图像),以及对应的 60000×10 的 one-hot 标签矩阵 y(每行仅在标签对应的列为 1)。

0 0
原创粉丝点击