Udacity Deep Learning 任务 1: notMNIST
来源:互联网 发布:怎样输入单变量数据 编辑:程序博客网 时间:2024/06/06 07:05
预处理 notMNIST 数据,并在此基础上训练一个简单的逻辑回归模型。
导入python库:
# These are all the modules we'll be using later. Make sure you can import them# before proceeding further.from __future__ import print_functionimport matplotlib.pyplot as pltimport numpy as npimport osimport sysimport tarfilefrom IPython.display import display, Imagefrom scipy import ndimagefrom sklearn.linear_model import LogisticRegressionfrom six.moves.urllib.request import urlretrievefrom six.moves import cPickle as pickle# Config the matplotlib backend as plotting inline in IPython%matplotlib inline
若导入出错,请先下载并安装相应的 Python 库。Windows 下可参考:《Windows下各种机器学习python库安装》。
数据集下载:
url = 'https://commondatastorage.googleapis.com/books1000/'
last_percent_reported = None
data_root = '.'  # Change me to store data elsewhere


def download_progress_hook(count, blockSize, totalSize):
    """A hook to report the progress of a download.

    This is mostly intended for users with slow internet connections.
    Reports every 5% change in download progress and prints a dot for every
    other percentage step. Meant to be passed as ``reporthook`` to
    ``urlretrieve``.

    Args:
        count: number of blocks transferred so far.
        blockSize: size of one block, in bytes.
        totalSize: total size of the file in bytes; may be 0 or negative
            when the server does not report a size.
    """
    global last_percent_reported
    if totalSize <= 0:
        # Unknown or empty size: a percentage cannot be computed (and
        # dividing would raise ZeroDivisionError), so report nothing.
        return
    # The final block is usually only partly filled, so cap at 100.
    percent = min(int(count * blockSize * 100 / totalSize), 100)

    if last_percent_reported != percent:
        if percent % 5 == 0:
            sys.stdout.write("%s%%" % percent)
        else:
            sys.stdout.write(".")
        sys.stdout.flush()
        last_percent_reported = percent


def maybe_download(filename, expected_bytes, force=False):
    """Download a file if not present, and make sure it's the right size.

    Args:
        filename: name of the file, relative to both ``url`` and
            ``data_root``.
        expected_bytes: exact size in bytes the file on disk must have.
        force: download even if the destination file already exists.

    Returns:
        The path of the verified file inside ``data_root``.

    Raises:
        Exception: if the size of the file on disk differs from
            ``expected_bytes``.
    """
    dest_filename = os.path.join(data_root, filename)
    if force or not os.path.exists(dest_filename):
        print('Attempting to download:', filename)
        # The return value (local path, headers) is not needed; the
        # destination path is already known.
        urlretrieve(url + filename, dest_filename,
                    reporthook=download_progress_hook)
        print('\nDownload Complete!')
    statinfo = os.stat(dest_filename)
    if statinfo.st_size == expected_bytes:
        print('Found and verified', dest_filename)
    else:
        raise Exception(
            'Failed to verify ' + dest_filename +
            '. Can you get to it with a browser?')
    return dest_filename


train_filename = maybe_download('notMNIST_large.tar.gz', 247336696)
test_filename = maybe_download('notMNIST_small.tar.gz', 8458043)
Found and verified notMNIST_large.tar.gz
Found and verified notMNIST_small.tar.gz
也可以复制网址直接下载。
解压下载的文件:
num_classes = 10
np.random.seed(133)


def maybe_extract(filename, force=False):
    """Extract a ``.tar.gz`` archive unless its folder already exists.

    Args:
        filename: path of the ``.tar.gz`` archive.
        force: extract even if the target folder is already present.

    Returns:
        Sorted list of the sub-directories found in the extracted folder,
        one per class.

    Raises:
        Exception: if the number of sub-directories differs from
            ``num_classes``.
    """
    # Strip both extensions: 'x.tar.gz' -> 'x.tar' -> 'x'.
    root = os.path.splitext(os.path.splitext(filename)[0])[0]
    if os.path.isdir(root) and not force:
        # You may override by setting force=True.
        print('%s already present - Skipping extraction of %s.' % (root, filename))
    else:
        print('Extracting data for %s. This may take a while. Please wait.' % root)
        sys.stdout.flush()
        # The context manager closes the archive even if extraction fails.
        # NOTE: extractall() trusts the member paths stored in the archive;
        # only use it on archives from a trusted source.
        with tarfile.open(filename) as tar:
            tar.extractall(data_root)
    data_folders = [
        os.path.join(root, d) for d in sorted(os.listdir(root))
        if os.path.isdir(os.path.join(root, d))]
    if len(data_folders) != num_classes:
        raise Exception(
            'Expected %d folders, one per class. Found %d instead.' % (
                num_classes, len(data_folders)))
    print(data_folders)
    return data_folders


train_folders = maybe_extract(train_filename)
test_folders = maybe_extract(test_filename)
['notMNIST_large/A', 'notMNIST_large/B', 'notMNIST_large/C', 'notMNIST_large/D', 'notMNIST_large/E', 'notMNIST_large/F', 'notMNIST_large/G', 'notMNIST_large/H', 'notMNIST_large/I', 'notMNIST_large/J']
['notMNIST_small/A', 'notMNIST_small/B', 'notMNIST_small/C', 'notMNIST_small/D', 'notMNIST_small/E', 'notMNIST_small/F', 'notMNIST_small/G', 'notMNIST_small/H', 'notMNIST_small/I', 'notMNIST_small/J']
问题一: 显示图片数据:
import random
import matplotlib.image as mpimg


def plot_samples(data_folders, sample_size, title=None):
    """Show a grid of randomly chosen images, one row per folder.

    Args:
        data_folders: list of directories containing image files.
        sample_size: number of images to draw per folder (grid columns).
        title: optional figure title.
    """
    fig = plt.figure()
    if title:
        fig.suptitle(title, fontsize=16, fontweight='bold')
    for row, folder in enumerate(data_folders):
        image_files = os.listdir(folder)
        # random.sample raises ValueError when asked for more items than
        # exist; a short folder simply leaves the rest of its row empty.
        image_sample = random.sample(
            image_files, min(sample_size, len(image_files)))
        for col, image_name in enumerate(image_sample):
            image_file = os.path.join(folder, image_name)
            # Subplot numbers are 1-based and run row by row.
            ax = fig.add_subplot(len(data_folders), sample_size,
                                 sample_size * row + col + 1)
            pixels = mpimg.imread(image_file)
            ax.imshow(pixels)
            ax.set_axis_off()
    plt.show()


plot_samples(train_folders, 10, 'Train Folders')
plot_samples(test_folders, 10, 'Test Folders')
数据处理:
image_size = 28  # Pixel width and height.
pixel_depth = 255.0  # Number of levels per pixel.


def load_letter(folder, min_num_images):
    """Read every image of one letter folder into a single float32 array.

    Pixel values are shifted by half of ``pixel_depth`` and divided by
    ``pixel_depth``. Files that raise IOError while being read are reported
    and skipped.

    Args:
        folder: directory holding the image files of one class.
        min_num_images: minimum number of images that must load.

    Returns:
        Array of shape (number of loaded images, image_size, image_size).

    Raises:
        Exception: if an image does not have the expected shape, or if
            fewer than ``min_num_images`` images were loaded.
    """
    file_names = os.listdir(folder)
    expected_shape = (image_size, image_size)
    # Allocated for the worst case; trimmed below to what actually loaded.
    dataset = np.ndarray(shape=(len(file_names),) + expected_shape,
                         dtype=np.float32)
    print(folder)
    loaded = 0
    for name in file_names:
        image_file = os.path.join(folder, name)
        try:
            # NOTE(review): ndimage.imread may be unavailable in recent
            # SciPy releases - confirm against the installed version.
            raw = ndimage.imread(image_file).astype(float)
            image_data = (raw - pixel_depth / 2) / pixel_depth
            if image_data.shape != expected_shape:
                raise Exception('Unexpected image shape: %s' % str(image_data.shape))
            dataset[loaded, :, :] = image_data
            loaded += 1
        except IOError as e:
            print('Could not read:', image_file, ':', e, '- it\'s ok, skipping.')

    dataset = dataset[0:loaded, :, :]
    if loaded < min_num_images:
        raise Exception('Many fewer images than expected: %d < %d' %
                        (loaded, min_num_images))

    print('Full dataset tensor:', dataset.shape)
    print('Mean:', np.mean(dataset))
    print('Standard deviation:', np.std(dataset))
    return dataset


def maybe_pickle(data_folders, min_num_images_per_class, force=False):
    """Pickle each class folder to ``<folder>.pickle`` unless it exists.

    Args:
        data_folders: list of class directories.
        min_num_images_per_class: forwarded to ``load_letter``.
        force: re-create pickles that are already present.

    Returns:
        List of pickle file names, one per folder, in input order. A name
        is listed even when saving that file failed.
    """
    dataset_names = []
    for folder in data_folders:
        set_filename = folder + '.pickle'
        dataset_names.append(set_filename)

        already_pickled = os.path.exists(set_filename) and not force
        if already_pickled:
            # You may override by setting force=True.
            print('%s already present - Skipping pickling.' % set_filename)
            continue

        print('Pickling %s.' % set_filename)
        dataset = load_letter(folder, min_num_images_per_class)
        try:
            with open(set_filename, 'wb') as f:
                pickle.dump(dataset, f, pickle.HIGHEST_PROTOCOL)
        except Exception as e:
            # Best effort: report the failure and go on to the next folder.
            print('Unable to save data to', set_filename, ':', e)

    return dataset_names


train_datasets = maybe_pickle(train_folders, 45000)
test_datasets = maybe_pickle(test_folders, 1800)
notMNIST_large/A
Could not read: notMNIST_large/A/Um9tYW5hIEJvbGQucGZi.png : cannot identify image file - it's ok, skipping.
Could not read: notMNIST_large/A/RnJlaWdodERpc3BCb29rSXRhbGljLnR0Zg==.png : cannot identify image file - it's ok, skipping.
Could not read: notMNIST_large/A/SG90IE11c3RhcmQgQlROIFBvc3Rlci50dGY=.png : cannot identify image file - it's ok, skipping.
Full dataset tensor: (52909, 28, 28)
Mean: -0.12848
Standard deviation: 0.425576
notMNIST_large/B
Could not read: notMNIST_large/B/TmlraXNFRi1TZW1pQm9sZEl0YWxpYy5vdGY=.png : cannot identify image file - it's ok, skipping.
Full dataset tensor: (52911, 28, 28)
Mean: -0.00755947
Standard deviation: 0.417272
notMNIST_large/C
Full dataset tensor: (52912, 28, 28)
Mean: -0.142321
Standard deviation: 0.421305
notMNIST_large/D
Could not read: notMNIST_large/D/VHJhbnNpdCBCb2xkLnR0Zg==.png : cannot identify image file - it's ok, skipping.
Full dataset tensor: (52911, 28, 28)
Mean: -0.0574553
Standard deviation: 0.434072
notMNIST_large/E
Full dataset tensor: (52912, 28, 28)
Mean: -0.0701406
Standard deviation: 0.42882
notMNIST_large/F
Full dataset tensor: (52912, 28, 28)
Mean: -0.125914
Standard deviation: 0.429645
notMNIST_large/G
Full dataset tensor: (52912, 28, 28)
Mean: -0.0947771
Standard deviation: 0.421674
notMNIST_large/H
Full dataset tensor: (52912, 28, 28)
Mean: -0.0687667
Standard deviation: 0.430344
notMNIST_large/I
Full dataset tensor: (52912, 28, 28)
Mean: 0.0307405
Standard deviation: 0.449686
notMNIST_large/J
Full dataset tensor: (52911, 28, 28)
Mean: -0.153479
Standard deviation: 0.397169
notMNIST_small/A
Could not read: notMNIST_small/A/RGVtb2NyYXRpY2FCb2xkT2xkc3R5bGUgQm9sZC50dGY=.png : cannot identify image file - it's ok, skipping.
Full dataset tensor: (1872, 28, 28)
Mean: -0.132588
Standard deviation: 0.445923
notMNIST_small/B
Full dataset tensor: (1873, 28, 28)
Mean: 0.00535619
Standard deviation: 0.457054
notMNIST_small/C
Full dataset tensor: (1873, 28, 28)
Mean: -0.141489
Standard deviation: 0.441056
notMNIST_small/D
Full dataset tensor: (1873, 28, 28)
Mean: -0.0492094
Standard deviation: 0.460477
notMNIST_small/E
Full dataset tensor: (1873, 28, 28)
Mean: -0.0598952
Standard deviation: 0.456146
notMNIST_small/F
Could not read: notMNIST_small/F/Q3Jvc3NvdmVyIEJvbGRPYmxpcXVlLnR0Zg==.png : cannot identify image file - it's ok, skipping.
Full dataset tensor: (1872, 28, 28)
Mean: -0.118148
Standard deviation: 0.451134
notMNIST_small/G
Full dataset tensor: (1872, 28, 28)
Mean: -0.092519
Standard deviation: 0.448468
notMNIST_small/H
Full dataset tensor: (1872, 28, 28)
Mean: -0.0586729
Standard deviation: 0.457387
notMNIST_small/I
Full dataset tensor: (1872, 28, 28)
Mean: 0.0526481
Standard deviation: 0.472657
notMNIST_small/J
Full dataset tensor: (1872, 28, 28)
Mean: -0.15167
Standard deviation: 0.449521
问题二:显示处理后的图片数据
def load_and_display_pickle(datasets, sample_size, title=None):
    """Plot the first images of each pickled dataset and count its images.

    One grid row is drawn per pickle file, showing its first
    ``sample_size`` entries. The per-file counts are passed to
    ``balance_check`` before the figure is shown.

    Args:
        datasets: list of pickle file paths.
        sample_size: number of images to show per file (grid columns).
        title: optional figure title.

    Returns:
        List with the number of images in each pickle file, in input order.
    """
    fig = plt.figure()
    if title:
        fig.suptitle(title, fontsize=16, fontweight='bold')
    num_of_images = []
    for row, pickle_file in enumerate(datasets):
        # NOTE: pickle.load can execute arbitrary code; only load pickle
        # files that were produced locally.
        with open(pickle_file, 'rb') as f:
            data = pickle.load(f)
        print('Total images in', pickle_file, ':', len(data))

        # assumes `data` is sliceable and yields 2-D images - TODO confirm
        for index, image in enumerate(data[:sample_size]):
            # Subplot numbers are 1-based and run row by row.
            ax = fig.add_subplot(len(datasets), sample_size,
                                 sample_size * row + index + 1)
            ax.imshow(image)
            ax.set_axis_off()

        num_of_images.append(len(data))

    balance_check(num_of_images)
    plt.show()
    return num_of_images
问题三:均衡校验
def generate_fake_label(sizes):
    """Build a label vector in which class ``i`` is repeated ``sizes[i]`` times.

    Args:
        sizes: list with the number of samples of each class.

    Returns:
        int32 array of length ``sum(sizes)``.
    """
    labels = np.empty(sum(sizes), dtype=np.int32)
    start = 0
    for label, size in enumerate(sizes):
        # One slice assignment per class instead of one write per sample.
        labels[start:start + size] = label
        start += size
    return labels


def _plot_label_hist(axis, labels, title):
    """Draw a per-class histogram of ``labels`` on ``axis`` with letter ticks."""
    bins = np.arange(labels.min(), labels.max() + 2)
    axis.hist(labels, bins=bins)
    # Set tick positions and tick texts in two calls: before Matplotlib 3.5
    # the second positional argument of set_xticks is `minor`, not labels.
    axis.set_xticks((bins[:-1] + bins[1:]) / 2)
    # Class 0 is 'A', class 1 is 'B', ...; derived from the bins so the
    # number of texts always matches the number of ticks.
    axis.set_xticklabels([chr(ord("A") + int(k)) for k in bins[:-1]])
    axis.set_title(title)


def plot_balance():
    """Plot class histograms of the global ``train_labels`` and ``test_labels``."""
    fig, ax = plt.subplots(1, 2)
    _plot_label_hist(ax[0], train_labels, "Training data")
    _plot_label_hist(ax[1], test_labels, "Test data")
    plt.show()


def mean(numbers):
    """Return the arithmetic mean of ``numbers``, or 0.0 if it is empty."""
    return float(sum(numbers)) / max(len(numbers), 1)


def balance_check(sizes):
    """Print, for each class size, whether it is within 10% of the mean size.

    Args:
        sizes: list with the number of samples of each class.
    """
    mean_val = mean(sizes)
    print('mean of # images :', mean_val)
    for i in sizes:
        if abs(i - mean_val) > 0.1 * mean_val:
            print("Too much or less images")
        else:
            print("Well balanced", i)


test_labels = generate_fake_label(load_and_display_pickle(test_datasets, 10, 'Test Datasets'))
train_labels = generate_fake_label(load_and_display_pickle(train_datasets, 10, 'Train Datasets'))
plot_balance()
未完
- Udacity Deep Learning 任务 1: notMNIST
- udacity deep learning lesson中notMNIST数据集下载错误的解决
- Udacity深度学习DeepLearning课程作业1—notMnist
- Udacity Deep Learning课程作业(一)
- Udacity Deep Learning课程作业(二)
- Udacity Deep Learning课程作业(三)
- Udacity Deep Learning课程作业(四)
- Udacity Deep Learning课程作业(五)
- Udacity Deep Learning课程作业(六)
- Udacity深度学习(google)笔记(1)——notmnist
- 初次接触Deep Learning 任务1的解读
- Udacity作业——TensorFlow notMNIST代码及输出结果——Udacity学习笔记
- 【Deep Learning】1、AutoEncoder
- Deep Learning(1)
- Deep Learning笔记1
- Deep learning ( 1 )
- Deep Learning Notes 1
- 《Deep Learning》(1)-介绍
- 用Eclipse 开发Dynamic Web Project应用程序
- VS2010出现无法打开源文件 "stdafx.h"问题
- 调用腾讯优图开放平台进行人脸识别-Java调用API实现
- Gson User Guide
- 抽象类和接口
- Udacity Deep Learning 任务 1: notMNIST
- scala学习-Linux命令行运行jar包传入main方法参数
- hdu6143-多校8&三种方法-组合数|递推|容斥-Killer Names
- Apache Struts 2.3.x Showcase
- 【HTCVR】VRTK插件模块功能分析之传送移动(二)
- Ambari——大数据平台的搭建利器之进阶篇
- BIM可视化的作用_国建融科合创
- 为什么离不开stackoverflow
- TCP/IP详解--TCP连接中TIME_WAIT状态过多