CNTK API文档翻译(13)——CIFAR-10数据准备
来源:互联网 发布:python pyqt5教程 编辑:程序博客网 时间:2024/05/21 02:53
本教程将展示如何为CNTK里的深度学习算法准备图像数据集。CIFAR-10数据集是一个常用从8000万张小图片中标记一部分而成的图像分类数据集,,由Alex Krizhevsky、Vinod Nair和 Geoffrey Hinton收集整理。
CIFAR-10数据集不包含在CNTK中,不过可以非常容易的在互联网上下载并转换成CNTK支持的格式。
本期教程和下期教程完成一下工作:
- 本期:熟悉CIFAR-10数据集,并将它转换成CNTK支持的格式。
- 下期:图像理解教程。
如果你对CIFAR-10数据集可以干什么感兴趣,可以看看Rodrigo Benenson的博客,上面有相关领域的最新技术和算法。博客地址:http://rodrigob.github.io/are_we_there_yet/build/classification_datasets_results.html#43494641522d3130
注:下载数据根据网络状况可能花费十到十五分钟,请耐心等待。
# Use a function definition from future version (say 3.x from 2.7 interpreter)from __future__ import print_function from PIL import Imageimport getoptimport numpy as npimport pickle as cpimport osimport shutilimport structimport sysimport tarfileimport xml.etree.cElementTree as etimport xml.dom.minidomtry: from urllib.request import urlretrieve except ImportError: from urllib import urlretrieve
数据下载
CIFAR-10数据集包含600000个 32×32的彩色图片,其中50000张是训练图片,10000张是测试图片。图片被被分成10类,每类6000张,这10类分别是:飞机、开车、鸟、猫、鹿、狗、蛙、马、船和卡车。
# CIFAR Image dataimgSize = 32numFeature = imgSize * imgSize * 3
我们先创建一些辅助函数来下载CIFAR数据。下载的文档包含文件data_batch_1、data_batch_2、…data_batch_5以及test_batch,每个文件都是Python cPickle模块持久化保存的对象。要将他们用到CNTK中我们需要做如下三步操作:
readBatch:解包Python持久化文件
loadData:将数据加载成训练或测试对象
saveTxt:如函数名描述,将用来测试和训练的特征值和标记值保存成文本文件
def readBatch(src): with open(src, 'rb') as f: if sys.version_info[0] < 3: d = cp.load(f) else: d = cp.load(f, encoding='latin1') data = d['data'] feat = data res = np.hstack((feat, np.reshape(d['labels'], (len(d['labels']), 1)))) return res.astype(np.int)def loadData(src): print ('Downloading ' + src) fname, h = urlretrieve(src, './delete.me') print ('Done.') try: print ('Extracting files...') with tarfile.open(fname) as tar: tar.extractall() print ('Done.') print ('Preparing train set...') trn = np.empty((0, numFeature + 1), dtype=np.int) for i in range(5): batchName = './cifar-10-batches-py/data_batch_{0}'.format(i + 1) trn = np.vstack((trn, readBatch(batchName))) print ('Done.') print ('Preparing test set...') tst = readBatch('./cifar-10-batches-py/test_batch') print ('Done.') finally: os.remove(fname) return (trn, tst)def saveTxt(filename, ndarray): with open(filename, 'w') as f: labels = list(map(' '.join, np.eye(10, dtype=np.uint).astype(str))) for row in ndarray: row_str = row.astype(str) label_str = labels[row[-1]] feature_str = ' '.join(row_str[:-1]) f.write('|labels {} |features {}\n'.format(label_str, feature_str))
除了将图片保存成文本格式之外,我们还需要将其保存成png格式。我们的任务是图像理解,我们还需要保存每张图片的意思,saveImage和saveMean函数就是做如下工作的。
def saveImage(fname, data, label, mapFile, regrFile, pad, **key_parms): # data in CIFAR-10 dataset is in CHW format. pixData = data.reshape((3, imgSize, imgSize)) if ('mean' in key_parms): key_parms['mean'] += pixData if pad > 0: pixData = np.pad(pixData, ((0, 0), (pad, pad), (pad, pad)), mode='constant', constant_values=128) img = Image.new('RGB', (imgSize + 2 * pad, imgSize + 2 * pad)) pixels = img.load() for x in range(img.size[0]): for y in range(img.size[1]): pixels[x, y] = (pixData[0][y][x], pixData[1][y][x], pixData[2][y][x]) img.save(fname) mapFile.write("%s\t%d\n" % (fname, label)) # compute per channel mean and store for regression example channelMean = np.mean(pixData, axis=(1,2)) regrFile.write("|regrLabels\t%f\t%f\t%f\n" % (channelMean[0]/255.0, channelMean[1]/255.0, channelMean[2]/255.0))def saveMean(fname, data): root = et.Element('opencv_storage') et.SubElement(root, 'Channel').text = '3' et.SubElement(root, 'Row').text = str(imgSize) et.SubElement(root, 'Col').text = str(imgSize) meanImg = et.SubElement(root, 'MeanImg', type_id='opencv-matrix') et.SubElement(meanImg, 'rows').text = '1' et.SubElement(meanImg, 'cols').text = str(imgSize * imgSize * 3) et.SubElement(meanImg, 'dt').text = 'f' et.SubElement(meanImg, 'data').text = ' '.join(['%e' % n for n in np.reshape(data, (imgSize * imgSize * 3))]) tree = et.ElementTree(root) tree.write(fname) x = xml.dom.minidom.parse(fname) with open(fname, 'w') as f: f.write(x.toprettyxml(indent = ' '))
saveTrainImages和saveTestImages函数只是简单的循环使用上面的函数,保存下全部的数据。
def saveTrainImages(filename, foldername): if not os.path.exists(foldername): os.makedirs(foldername) data = {} # mean is in CHW format. dataMean = np.zeros((3, imgSize, imgSize)) with open('train_map.txt', 'w') as mapFile: with open('train_regrLabels.txt', 'w') as regrFile: for ifile in range(1, 6): with open(os.path.join('./cifar-10-batches-py', 'data_batch_' + str(ifile)), 'rb') as f: if sys.version_info[0] < 3: data = cp.load(f) else: data = cp.load(f, encoding='latin1') for i in range(10000): fname = os.path.join(os.path.abspath(foldername), ('%05d.png' % (i + (ifile - 1) * 10000))) saveImage(fname, data['data'][i, :], data['labels'][i], mapFile, regrFile, 4, mean=dataMean) dataMean = dataMean / (50 * 1000) saveMean('CIFAR-10_mean.xml', dataMean)def saveTestImages(filename, foldername): if not os.path.exists(foldername): os.makedirs(foldername) with open('test_map.txt', 'w') as mapFile: with open('test_regrLabels.txt', 'w') as regrFile: with open(os.path.join('./cifar-10-batches-py', 'test_batch'), 'rb') as f: if sys.version_info[0] < 3: data = cp.load(f) else: data = cp.load(f, encoding='latin1') for i in range(10000): fname = os.path.join(os.path.abspath(foldername), ('%05d.png' % i)) saveImage(fname, data['data'][i, :], data['labels'][i], mapFile, regrFile, 0)# URLs for the train image and labels dataurl_cifar_data = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'# Paths for saving the text filesdata_dir = './data/CIFAR-10/'train_filename = data_dir + '/Train_cntk_text.txt'test_filename = data_dir + '/Test_cntk_text.txt'train_img_directory = data_dir + '/Train'test_img_directory = data_dir + '/Test'root_dir = os.getcwd()if not os.path.exists(data_dir): os.makedirs(data_dir)try: os.chdir(data_dir) trn, tst= loadData('http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz') print ('Writing train text file...') saveTxt(r'./Train_cntk_text.txt', trn) print ('Done.') print ('Writing test text file...') saveTxt(r'./Test_cntk_text.txt', tst) print ('Done.') print ('Converting train data to png images...') saveTrainImages(r'./Train_cntk_text.txt', 'train') print ('Done.') print ('Converting test data to png images...') saveTestImages(r'./Test_cntk_text.txt', 'test') print ('Done.')finally: os.chdir("../..")
欢迎扫码关注我的微信公众号获取最新文章
- CNTK API文档翻译(13)——CIFAR-10数据准备
- CNTK API文档翻译(4)——MNIST数据加载
- CNTK API文档翻译(12)——CNTK进阶
- CNTK API文档翻译(10)——使用LSTM预测时间序列数据
- CNTK API文档翻译(5)——对MNIST数据使用逻辑回归
- CNTK API文档翻译(6)——对MNIST数据使用多层感知机
- CNTK API文档翻译(7)——对MNIST数据使用卷积神经网络
- CNTK API文档翻译(9)——使用自编码器压缩MNIST数据
- CNTK API文档翻译(17)——多对多神经网络处理文本数据(1)
- CNTK API文档翻译(18)——多对多神经网络处理文本数据(2)
- CNTK API文档翻译(20)——GAN处理MSIST数据基础
- CNTK API文档翻译(21)——深度卷积GAN处理MSIST数据基础
- CNTK API文档翻译(17)——多对多神经网络处理文本数据(1)
- CNTK API文档翻译(1)——使用数列
- CNTK API文档翻译(2)——逻辑回归
- CNTK API文档翻译(3)——前馈神经网络
- CNTK API文档翻译(14)——实验图像识别
- CNTK API文档翻译(15)——自然语言理解
- okHttp 添加动态的 超时时间 处理
- 从一百个数中找不存在的数
- 【转】angularJS的兄弟controller之间如何正确的通信
- 问题:删除volume报错
- 多态&多态对象模型
- CNTK API文档翻译(13)——CIFAR-10数据准备
- JavaScript函数
- 颜色特征识别—识别红色,黄色,绿色,蓝色排针的数量
- Excel在统计分析中的应用—第二章—描述性统计-Part3-偏度(偏斜度和矩偏度系数)
- [USACO3.1]丑数 Humble Numbers
- git pull提示 not-fast-forward
- 东北往事:黑道风云20年-有声全集
- jsp开发中cannot resolve taglib with uri的解决方法
- 设计模式(9)装饰模式--结构型