CNTK API文档翻译(4)——MNIST数据加载
来源:互联网 发布:女神很少网络 编辑:程序博客网 时间:2024/05/29 03:22
本教程针对机器学习和CNTK新手,本教程的前提是你已经完成了本系列的第一个案例和第二个案例。在本教程中,我们将下载和预处理MNIST图像,以便用于建立不同的手书数字图像识别模型。在之后的三期教程中,我们会把第一期和第二期的方法用于本数据集,还会引入卷积神经网络来获取更好的表现。这是我们使用真实的数据进行训练和评估的第一个例子。
本小系列被分成了四个部分:
- 第一部分熟悉本教程中会被用到的MNIST数据集(MNIST数据集详情请看我的Python与人工神经网络第三期)
- 后面三个部分会使用不同类型的神经网络来处理MNIST数据
# Import the relevant modules to be used laterfrom __future__ import print_functionimport gzipimport matplotlib.image as mpimgimport matplotlib.pyplot as pltimport numpy as npimport osimport shutilimport structimport systry: from urllib.request import urlretrieve except ImportError: from urllib import urlretrieve# Config matplotlib for inline plotting%matplotlib inline
数据下载
我们需要把MNIST数据下载到本机。MNIST数据集是一个标准的手书图片,他被广泛用于训练和测试机器学习算法。数据集中包含60000个训练图片和10000个测试图片,每个图片大小是28*28像素,这个数据集能够很方便的在各种电脑上查看和训练。
# Functions to load MNIST images and unpack into train and test set.# - loadData reads image data and formats into a 28x28 long array# - loadLabels reads the corresponding labels data, 1 for each image# - load packs the downloaded image and labels data into a combined format to be read later by # CNTK text reader def loadData(src, cimg): print ('Downloading ' + src) gzfname, h = urlretrieve(src, './delete.me') print ('Done.') try: with gzip.open(gzfname) as gz: n = struct.unpack('I', gz.read(4)) # Read magic number. if n[0] != 0x3080000: raise Exception('Invalid file: unexpected magic number.') # Read number of entries. n = struct.unpack('>I', gz.read(4))[0] if n != cimg: raise Exception('Invalid file: expected {0} entries.'.format(cimg)) crow = struct.unpack('>I', gz.read(4))[0] ccol = struct.unpack('>I', gz.read(4))[0] if crow != 28 or ccol != 28: raise Exception('Invalid file: expected 28 rows/cols per image.') # Read data. res = np.fromstring(gz.read(cimg * crow * ccol), dtype = np.uint8) finally: os.remove(gzfname) return res.reshape((cimg, crow * ccol))def loadLabels(src, cimg): print ('Downloading ' + src) gzfname, h = urlretrieve(src, './delete.me') print ('Done.') try: with gzip.open(gzfname) as gz: n = struct.unpack('I', gz.read(4)) # Read magic number. if n[0] != 0x1080000: raise Exception('Invalid file: unexpected magic number.') # Read number of entries. n = struct.unpack('>I', gz.read(4)) if n[0] != cimg: raise Exception('Invalid file: expected {0} rows.'.format(cimg)) # Read labels. res = np.fromstring(gz.read(cimg), dtype = np.uint8) finally: os.remove(gzfname) return res.reshape((cimg, 1))def try_download(dataSrc, labelsSrc, cimg): data = loadData(dataSrc, cimg) labels = loadLabels(labelsSrc, cimg) return np.hstack((data, labels))
- 下载
# URLs for the train image and labels dataurl_train_image = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'url_train_labels = 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'num_train_samples = 60000print("Downloading train data")train = try_download(url_train_image, url_train_labels, num_train_samples)url_test_image = 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'url_test_labels = 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'num_test_samples = 10000print("Downloading test data")test = try_download(url_test_image, url_test_labels, num_test_samples)
- 查看/可视化
# Plot a random imagesample_number = 5001plt.imshow(train[sample_number,:-1].reshape(28,28), cmap="gray_r")plt.axis('off')print("Image Label: ", train[sample_number,-1])
保存图片
在本地文件夹中保存图片:保存图片时我们把图片降为成一个矢量(28*28的图片变成一个长784的数组)
标签使用一位有效编码(One-Hot Encoding,上图是3,会被标记成0001000000,守卫表示0,最后一位表示9).
# Save the data files into a format compatible with CNTK text readerdef savetxt(filename, ndarray): dir = os.path.dirname(filename) if not os.path.exists(dir): os.makedirs(dir) if not os.path.isfile(filename): print("Saving", filename ) with open(filename, 'w') as f: labels = list(map(' '.join, np.eye(10, dtype=np.uint).astype(str))) for row in ndarray: row_str = row.astype(str) label_str = labels[row[-1]] feature_str = ' '.join(row_str[:-1]) f.write('|labels {} |features {}\n'.format(label_str, feature_str)) else: print("File already exists", filename)
# Save the train and test files (prefer our default path for the data)data_dir = os.path.join("..", "Examples", "Image", "DataSets", "MNIST")if not os.path.exists(data_dir): data_dir = os.path.join("data", "MNIST")print ('Writing train text file...')savetxt(os.path.join(data_dir, "Train-28x28_cntk_text.txt"), train)print ('Writing test text file...')savetxt(os.path.join(data_dir, "Test-28x28_cntk_text.txt"), test)print('Done')
欢迎扫码关注我的微信公众号获取最新文章
阅读全文
0 0
- CNTK API文档翻译(4)——MNIST数据加载
- CNTK API文档翻译(5)——对MNIST数据使用逻辑回归
- CNTK API文档翻译(6)——对MNIST数据使用多层感知机
- CNTK API文档翻译(7)——对MNIST数据使用卷积神经网络
- CNTK API文档翻译(9)——使用自编码器压缩MNIST数据
- CNTK API文档翻译(12)——CNTK进阶
- CNTK API文档翻译(10)——使用LSTM预测时间序列数据
- CNTK API文档翻译(13)——CIFAR-10数据准备
- CNTK API文档翻译(17)——多对多神经网络处理文本数据(1)
- CNTK API文档翻译(18)——多对多神经网络处理文本数据(2)
- CNTK API文档翻译(20)——GAN处理MSIST数据基础
- CNTK API文档翻译(21)——深度卷积GAN处理MSIST数据基础
- CNTK API文档翻译(17)——多对多神经网络处理文本数据(1)
- CNTK API文档翻译(1)——使用数列
- CNTK API文档翻译(2)——逻辑回归
- CNTK API文档翻译(3)——前馈神经网络
- CNTK API文档翻译(14)——实验图像识别
- CNTK API文档翻译(15)——自然语言理解
- hdu 1400 (插头)
- 手机射频架构解析
- 继承,实现,关联,聚合,组合,依赖几种关系的介绍
- CPropertySheet的基本用法
- Leetcode 207 Course Schedule
- CNTK API文档翻译(4)——MNIST数据加载
- Android——广播初介绍
- 使用Gallery和ImageSwitcher制作图片浏览器
- 服务调用框架DataStrom
- String、StringBuffer与StringBuilder之间区别
- WMS系统开发总结-自定义菜单或者显示的表格名字
- A New Start
- POJ3686_The Windy's_最小费用流::最小权匹配
- 多台云服务器+Docker部署Ceph存储系统