Converting the MNIST Dataset to Image Format

Source: Internet  Editor: 程序博客网  Time: 2024/06/08 21:07

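The original MNIST dataset can be downloaded from http://yann.lecun.com/exdb/mnist/. It consists of four files: train-images-idx3-ubyte holds the 60000 training images, each grayscale image stored as a 1×784 vector (one byte per pixel); train-labels-idx1-ubyte holds the 60000 training labels; t10k-images-idx3-ubyte holds the 10000 test images, each also a 1×784 vector; t10k-labels-idx1-ubyte holds the test labels.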
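As a small illustration of the file layout described above, the following sketch reads only the header of an IDX file. The helper name `read_idx_header` and the in-memory sample are hypothetical; the sketch assumes the IDX convention that the 4-byte magic number is two zero bytes, a data-type code (0x08 for unsigned byte) and the number of dimensions, followed by one big-endian 32-bit size per dimension.

```python
import io
import struct


def read_idx_header(stream):
    """Return (dtype_code, dims) read from the start of an IDX stream."""
    # Magic number: two zero bytes, one data-type byte, one dimension-count byte.
    zero, dtype_code, ndim = struct.unpack('>HBB', stream.read(4))
    if zero != 0:
        raise ValueError('not an IDX file: first two bytes must be zero')
    # One big-endian unsigned 32-bit integer per dimension.
    dims = struct.unpack('>' + 'I' * ndim, stream.read(4 * ndim))
    return dtype_code, dims


# Synthetic stand-in for an idx3 image file: 2 blank images of 28x28 pixels.
fake = struct.pack('>IIII', 2051, 2, 28, 28) + bytes(2 * 28 * 28)
dtype_code, dims = read_idx_header(io.BytesIO(fake))
print(dtype_code, dims)
```

On the real files, the same helper could be called on a file opened with `open(path, 'rb')`; the image files should report three dimensions (count, 28, 28) and the label files one.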

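To prepare the data for convolution, each 1×784 vector in the original data is converted into an image with a resolution of 28×28. The code is given below.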

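The script generates two folders, train and test. Each contains ten subfolders named 0 to 9, and each subfolder holds the images of the corresponding digit.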
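Before the full script, the reshape step on its own can be sketched in plain Python. This is only an illustration under the assumption that pixels are stored row by row, so the pixel at (row, col) sits at flat index row*28 + col; the helper `to_rows` is hypothetical. NumPy's `reshape(28, 28)` with its default row-major order produces the same arrangement.

```python
SIDE = 28


def to_rows(vec, side=SIDE):
    """Split a flat row-major vector of side*side pixels into `side` rows."""
    if len(vec) != side * side:
        raise ValueError('expected %d pixels, got %d' % (side * side, len(vec)))
    return [vec[r * side:(r + 1) * side] for r in range(side)]


# Stand-in "image" whose pixel values equal their flat index.
flat = list(range(SIDE * SIDE))
grid = to_rows(flat)

# The pixel at (row=3, col=5) comes from flat index 3 * 28 + 5.
print(grid[3][5] == 3 * SIDE + 5)
```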


import os
import struct

import cv2
import numpy as np


def load_mnist(path, kind='train'):
    """Load MNIST images and labels of the given kind ('train' or 't10k') from `path`.

    Expects the uncompressed IDX files to be present in `path`.
    """
    labels_path = os.path.join(path, '%s-labels-idx1-ubyte' % kind)
    images_path = os.path.join(path, '%s-images-idx3-ubyte' % kind)
    # Label file: 8-byte big-endian header (magic number, item count), then one byte per label.
    with open(labels_path, 'rb') as lbpath:
        magic, n = struct.unpack('>II', lbpath.read(8))
        labels = np.fromfile(lbpath, dtype=np.uint8)
    # Image file: 16-byte big-endian header (magic, count, rows, cols), then one byte per pixel.
    with open(images_path, 'rb') as imgpath:
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784)
    return images, labels


def save_images(images, labels, out_dir):
    """Write each 784-vector as a 28x28 PNG to out_dir/<label>/<label>_<k>.png."""
    # cv2.imwrite does not create missing directories, so create the ten class folders first.
    for digit in range(10):
        os.makedirs(os.path.join(out_dir, str(digit)), exist_ok=True)
    count = np.zeros(10, dtype=int)  # number of images written so far for each digit
    for i in range(len(images)):  # range replaces Python 2's xrange
        label = int(labels[i])
        count[label] += 1
        filename = os.path.join(out_dir, str(label), '%d_%d.png' % (label, count[label]))
        img = images[i].reshape(28, 28)
        cv2.imwrite(filename, img)


X_train, y_train = load_mnist('', kind='train')
print('Rows: %d, columns: %d' % (X_train.shape[0], X_train.shape[1]))
X_test, y_test = load_mnist('', kind='t10k')
print('Rows: %d, columns: %d' % (X_test.shape[0], X_test.shape[1]))

save_images(X_train, y_train, './train')
save_images(X_test, y_test, './test')
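After the script has run, one way to check the result is to count the PNG files in each digit folder. The sketch below is a hypothetical helper, not part of the original script; it assumes the ./train layout described above and uses only the standard library.

```python
import os


def count_pngs(root):
    """Return {subfolder name: number of .png files} for each subfolder of root."""
    counts = {}
    for name in sorted(os.listdir(root)):
        sub = os.path.join(root, name)
        if os.path.isdir(sub):
            counts[name] = sum(1 for f in os.listdir(sub) if f.endswith('.png'))
    return counts


# Only report if the output folder exists, so the sketch also runs before conversion.
if os.path.isdir('./train'):
    counts = count_pngs('./train')
    print(counts)
    print('total:', sum(counts.values()))
```

If the conversion succeeded, the total for ./train should be 60000 and the total for ./test should be 10000, matching the image counts of the source files.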





