CNTK API文档翻译(13)——CIFAR-10数据准备

来源:互联网 发布:python pyqt5教程 编辑:程序博客网 时间:2024/05/21 02:53

本教程将展示如何为CNTK里的深度学习算法准备图像数据集。CIFAR-10是一个常用的图像分类数据集,它是从包含8000万张小图片的数据集中标记出的一个子集,由Alex Krizhevsky、Vinod Nair和Geoffrey Hinton收集整理。

CIFAR-10数据集不包含在CNTK中,不过可以非常容易的在互联网上下载并转换成CNTK支持的格式。

本期教程和下期教程将完成以下工作:

  • 本期:熟悉CIFAR-10数据集,并将它转换成CNTK支持的格式。
  • 下期:图像理解教程。

如果你对CIFAR-10数据集可以干什么感兴趣,可以看看Rodrigo Benenson的博客,上面有相关领域的最新技术和算法。博客地址:http://rodrigob.github.io/are_we_there_yet/build/classification_datasets_results.html#43494641522d3130

注:下载数据根据网络状况可能花费十到十五分钟,请耐心等待。

# Use a function definition from future version (say 3.x from 2.7 interpreter)
from __future__ import print_function

import getopt
import os
import pickle as cp
import shutil
import struct
import sys
import tarfile
import xml.dom.minidom

try:
    import xml.etree.cElementTree as et
except ImportError:  # cElementTree was removed in Python 3.9
    import xml.etree.ElementTree as et

try:
    from urllib.request import urlretrieve
except ImportError:
    from urllib import urlretrieve

import numpy as np
from PIL import Image

数据下载

CIFAR-10数据集包含60000张32×32的彩色图片,其中50000张是训练图片,10000张是测试图片。图片被分成10类,每类6000张,这10类分别是:飞机、汽车、鸟、猫、鹿、狗、蛙、马、船和卡车。

# CIFAR Image data
# Edge length, in pixels, of the square images handled by this script.
imgSize = 32
# Number of feature values per image: imgSize x imgSize pixels, 3 values each.
numFeature = imgSize * imgSize * 3

我们先创建一些辅助函数来下载CIFAR数据。下载的压缩包包含文件data_batch_1、data_batch_2、…、data_batch_5以及test_batch,每个文件都是用Python cPickle模块持久化保存的对象。要将它们用到CNTK中,我们需要做如下三步操作:

readBatch:解包Python持久化文件
loadData:将数据加载成训练或测试对象
saveTxt:如函数名描述,将用来测试和训练的特征值和标记值保存成文本文件

def readBatch(src):
    """Load one pickled batch file into a 2-D integer array.

    Args:
        src: Path of the pickled batch file. The unpickled object must be a
            mapping with a 2-D ``'data'`` entry and a ``'labels'`` sequence
            holding one label per data row.

    Returns:
        An integer array in which every row is a ``'data'`` row followed by
        its label as the last column.
    """
    with open(src, 'rb') as f:
        # NOTE(security): pickle can execute arbitrary code while loading;
        # only feed this function files from a trusted source.
        if sys.version_info[0] < 3:
            d = cp.load(f)
        else:
            # The batch files were written by Python 2; 'latin1' lets
            # Python 3 decode their byte strings.
            d = cp.load(f, encoding='latin1')
    labelColumn = np.reshape(d['labels'], (len(d['labels']), 1))
    res = np.hstack((d['data'], labelColumn))
    # ``np.int`` was only an alias of the builtin ``int`` and no longer
    # exists in NumPy >= 1.24, so the builtin is used directly.
    return res.astype(int)


def loadData(src):
    """Download the CIFAR python archive and build train/test arrays.

    The archive is saved as ``./delete.me`` in the current working directory,
    extracted there, and removed again whether or not processing succeeds.

    Args:
        src: URL of the ``.tar.gz`` archive. It must unpack to a
            ``cifar-10-batches-py`` directory containing ``data_batch_1`` ..
            ``data_batch_5`` and ``test_batch``.

    Returns:
        A ``(trn, tst)`` tuple of integer arrays in the layout produced by
        ``readBatch`` (features followed by the label column).
    """
    fname = './delete.me'
    try:
        print('Downloading ' + src)
        urlretrieve(src, fname)
        print('Done.')

        print('Extracting files...')
        with tarfile.open(fname) as tar:
            if hasattr(tarfile, 'data_filter'):
                # Refuse absolute paths, '..' components and special files
                # coming from the downloaded archive.
                tar.extractall(filter='data')
            else:
                # NOTE(security): older interpreters extract without any
                # path sanitising; only use trusted archives there.
                tar.extractall()
        print('Done.')

        print('Preparing train set...')
        # The empty seed row block keeps the original column-count check:
        # vstack fails if a batch is not numFeature + 1 columns wide.
        parts = [np.empty((0, numFeature + 1), dtype=int)]
        for i in range(5):
            batchName = './cifar-10-batches-py/data_batch_{0}'.format(i + 1)
            parts.append(readBatch(batchName))
        trn = np.vstack(parts)
        print('Done.')

        print('Preparing test set...')
        tst = readBatch('./cifar-10-batches-py/test_batch')
        print('Done.')
    finally:
        # Also cleans up a partially downloaded file.
        if os.path.exists(fname):
            os.remove(fname)
    return (trn, tst)


def saveTxt(filename, ndarray):
    """Write rows of features and labels as CNTK text-format lines.

    Each row becomes ``|labels <10 one-hot values> |features <values>``.

    Args:
        filename: Path of the text file to create (overwritten if present).
        ndarray: 2-D integer array whose last column is a class index in
            ``0..9`` and whose remaining columns are the feature values.
    """
    # Pre-render the ten possible one-hot label strings once.
    oneHot = [' '.join(cells) for cells in np.eye(10, dtype=int).astype(str)]
    with open(filename, 'w') as f:
        for row in ndarray:
            row_str = row.astype(str)
            label_str = oneHot[row[-1]]
            feature_str = ' '.join(row_str[:-1])
            f.write('|labels {} |features {}\n'.format(label_str, feature_str))

除了将图片保存成文本格式之外,我们还需要将其保存成png格式。由于我们的任务是图像理解,我们还需要计算并保存所有图片的均值(mean),saveImage和saveMean函数就是用来完成这些工作的。

def saveImage(fname, data, label, mapFile, regrFile, pad, **key_parms):
    """Save one image as a file and record it in the map/regression files.

    Args:
        fname: Output image path; PIL picks the format from the extension.
        data: Flat array of ``3 * imgSize * imgSize`` pixel values in CHW
            order. Values are assumed to fit 0..255 (they are stored as
            uint8) -- callers in this file pass raw CIFAR bytes.
        label: Integer class label written next to ``fname`` in ``mapFile``.
        mapFile: Writable text file; receives ``"<fname>\\t<label>\\n"``.
        regrFile: Writable text file; receives one ``|regrLabels`` line with
            the three per-channel means divided by 255.
        pad: Border width in pixels added on every side, filled with 128.
        **key_parms: Optional ``mean`` array of shape
            ``(3, imgSize, imgSize)``; the unpadded image is added to it
            in place so the caller can accumulate a running sum.
    """
    # data in CIFAR-10 dataset is in CHW format.
    pixData = data.reshape((3, imgSize, imgSize))
    if 'mean' in key_parms:
        # In-place add on purpose: the caller's array must see the update.
        key_parms['mean'] += pixData

    if pad > 0:
        pixData = np.pad(pixData, ((0, 0), (pad, pad), (pad, pad)),
                         mode='constant', constant_values=128)

    # PIL wants height x width x channel; building the image from the array
    # in one go replaces a per-pixel Python loop with the same result.
    hwc = np.ascontiguousarray(np.transpose(pixData, (1, 2, 0)), dtype=np.uint8)
    img = Image.fromarray(hwc)
    img.save(fname)
    mapFile.write("%s\t%d\n" % (fname, label))

    # compute per channel mean and store for regression example
    # NOTE: the mean is taken over the padded image, so the constant border
    # contributes whenever pad > 0.
    channelMean = np.mean(pixData, axis=(1, 2))
    regrFile.write("|regrLabels\t%f\t%f\t%f\n" % (
        channelMean[0] / 255.0, channelMean[1] / 255.0, channelMean[2] / 255.0))


def saveMean(fname, data):
    """Write a mean image as a pretty-printed ``opencv_storage`` XML file.

    Args:
        fname: Path of the XML file to create (overwritten if present).
        data: Array with ``3 * imgSize * imgSize`` values; it is flattened
            and written in ``%e`` notation inside ``MeanImg/data``.
    """
    valueCount = imgSize * imgSize * 3

    root = et.Element('opencv_storage')
    et.SubElement(root, 'Channel').text = '3'
    et.SubElement(root, 'Row').text = str(imgSize)
    et.SubElement(root, 'Col').text = str(imgSize)
    meanImg = et.SubElement(root, 'MeanImg', type_id='opencv-matrix')
    et.SubElement(meanImg, 'rows').text = '1'
    et.SubElement(meanImg, 'cols').text = str(valueCount)
    et.SubElement(meanImg, 'dt').text = 'f'
    et.SubElement(meanImg, 'data').text = ' '.join(
        '%e' % n for n in np.reshape(data, (valueCount,)))

    # Pretty-print in memory instead of writing the file, parsing it back
    # from disk and writing it a second time.
    pretty = xml.dom.minidom.parseString(et.tostring(root)).toprettyxml(indent='  ')
    with open(fname, 'w') as f:
        f.write(pretty)

saveTrainImages和saveTestImages函数只是简单的循环使用上面的函数,保存下全部的数据。

def _loadCifarBatch(path):
    """Unpickle one CIFAR batch file and return the stored mapping."""
    with open(path, 'rb') as f:
        # NOTE(security): pickle can execute arbitrary code while loading;
        # only use batch files from a trusted source.
        if sys.version_info[0] < 3:
            return cp.load(f)
        # The batch files were written by Python 2; 'latin1' lets Python 3
        # decode their byte strings.
        return cp.load(f, encoding='latin1')


def saveTrainImages(filename, foldername):
    """Convert the five training batches to image files plus index files.

    Reads ``./cifar-10-batches-py/data_batch_1`` .. ``data_batch_5`` relative
    to the current directory and writes, also relative to it:
    ``train_map.txt``, ``train_regrLabels.txt``, ``CIFAR-10_mean.xml`` and
    one ``NNNNN.png`` per image (padded by 4 pixels) inside ``foldername``.

    Args:
        filename: Unused; kept so existing callers keep working.
        foldername: Directory for the PNG files; created if missing.
    """
    if not os.path.exists(foldername):
        os.makedirs(foldername)
    outDir = os.path.abspath(foldername)

    # mean is in CHW format.
    dataMean = np.zeros((3, imgSize, imgSize))
    imgCount = 0
    with open('train_map.txt', 'w') as mapFile, \
            open('train_regrLabels.txt', 'w') as regrFile:
        for ifile in range(1, 6):
            batch = _loadCifarBatch(
                os.path.join('./cifar-10-batches-py', 'data_batch_' + str(ifile)))
            # Use the real batch length instead of assuming 10000 images.
            for i in range(len(batch['labels'])):
                fname = os.path.join(outDir, '%05d.png' % imgCount)
                saveImage(fname, batch['data'][i, :], batch['labels'][i],
                          mapFile, regrFile, 4, mean=dataMean)
                imgCount += 1

    # saveImage accumulated the pixel sum; divide by the number of images.
    dataMean = dataMean / imgCount
    saveMean('CIFAR-10_mean.xml', dataMean)


def saveTestImages(filename, foldername):
    """Convert the test batch to image files plus index files.

    Reads ``./cifar-10-batches-py/test_batch`` relative to the current
    directory and writes ``test_map.txt``, ``test_regrLabels.txt`` and one
    unpadded ``NNNNN.png`` per image inside ``foldername``.

    Args:
        filename: Unused; kept so existing callers keep working.
        foldername: Directory for the PNG files; created if missing.
    """
    if not os.path.exists(foldername):
        os.makedirs(foldername)
    outDir = os.path.abspath(foldername)

    with open('test_map.txt', 'w') as mapFile, \
            open('test_regrLabels.txt', 'w') as regrFile:
        batch = _loadCifarBatch(os.path.join('./cifar-10-batches-py', 'test_batch'))
        for i in range(len(batch['labels'])):
            fname = os.path.join(outDir, '%05d.png' % i)
            saveImage(fname, batch['data'][i, :], batch['labels'][i],
                      mapFile, regrFile, 0)


# URLs for the train image and labels data
url_cifar_data = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'

# Paths for saving the text files
data_dir = './data/CIFAR-10/'
train_filename = data_dir + '/Train_cntk_text.txt'
test_filename = data_dir + '/Test_cntk_text.txt'
train_img_directory = data_dir + '/Train'
test_img_directory = data_dir + '/Test'

# Remember where we started so the working directory can be restored.
root_dir = os.getcwd()

if not os.path.exists(data_dir):
    os.makedirs(data_dir)

try:
    # All helpers read and write relative to the data directory.
    os.chdir(data_dir)
    trn, tst = loadData(url_cifar_data)
    print('Writing train text file...')
    saveTxt(r'./Train_cntk_text.txt', trn)
    print('Done.')
    print('Writing test text file...')
    saveTxt(r'./Test_cntk_text.txt', tst)
    print('Done.')
    print('Converting train data to png images...')
    saveTrainImages(r'./Train_cntk_text.txt', 'train')
    print('Done.')
    print('Converting test data to png images...')
    saveTestImages(r'./Test_cntk_text.txt', 'test')
    print('Done.')
finally:
    # Return to the saved start directory rather than a hard-coded "../..",
    # which would move somewhere else if the chdir above had failed.
    os.chdir(root_dir)


欢迎扫码关注我的微信公众号获取最新文章
image

阅读全文
1 0
原创粉丝点击