Udacity Deep Learning Assignment 1: notMNIST


Preprocess the notMNIST data, then train a simple logistic regression model on it.

Import the Python libraries:

# These are all the modules we'll be using later. Make sure you can import them
# before proceeding further.
from __future__ import print_function
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
import tarfile
from IPython.display import display, Image
from scipy import ndimage
from sklearn.linear_model import LogisticRegression
from six.moves.urllib.request import urlretrieve
from six.moves import cPickle as pickle

# Config the matplotlib backend as plotting inline in IPython
%matplotlib inline

If any import fails, download and install the corresponding Python library. On Windows, see: Installing machine-learning Python libraries on Windows.
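A minimal sketch of a programmatic import check, with module names taken from the import list above (this snippet is not part of the original assignment):

import importlib

# Try importing each required package and report any that are missing.
for module in ('numpy', 'scipy', 'matplotlib', 'sklearn', 'six', 'IPython'):
    try:
        importlib.import_module(module)
        print(module, 'OK')
    except ImportError as e:
        print(module, 'MISSING:', e)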

Download the dataset:

url = 'https://commondatastorage.googleapis.com/books1000/'
last_percent_reported = None
data_root = '.' # Change me to store data elsewhere

def download_progress_hook(count, blockSize, totalSize):
  """A hook to report the progress of a download. This is mostly intended for users with
  slow internet connections. Reports every 5% change in download progress.
  """
  global last_percent_reported
  percent = int(count * blockSize * 100 / totalSize)

  if last_percent_reported != percent:
    if percent % 5 == 0:
      sys.stdout.write("%s%%" % percent)
      sys.stdout.flush()
    else:
      sys.stdout.write(".")
      sys.stdout.flush()
    last_percent_reported = percent

def maybe_download(filename, expected_bytes, force=False):
  """Download a file if not present, and make sure it's the right size."""
  dest_filename = os.path.join(data_root, filename)
  if force or not os.path.exists(dest_filename):
    print('Attempting to download:', filename)
    filename, _ = urlretrieve(url + filename, dest_filename, reporthook=download_progress_hook)
    print('\nDownload Complete!')
  statinfo = os.stat(dest_filename)
  if statinfo.st_size == expected_bytes:
    print('Found and verified', dest_filename)
  else:
    raise Exception(
      'Failed to verify ' + dest_filename + '. Can you get to it with a browser?')
  return dest_filename

train_filename = maybe_download('notMNIST_large.tar.gz', 247336696)
test_filename = maybe_download('notMNIST_small.tar.gz', 8458043)

Found and verified notMNIST_large.tar.gz
Found and verified notMNIST_small.tar.gz
Alternatively, you can copy the URL and download the archives directly in a browser.

Extract the downloaded archives:

num_classes = 10
np.random.seed(133)

def maybe_extract(filename, force=False):
  root = os.path.splitext(os.path.splitext(filename)[0])[0]  # remove .tar.gz
  if os.path.isdir(root) and not force:
    # You may override by setting force=True.
    print('%s already present - Skipping extraction of %s.' % (root, filename))
  else:
    print('Extracting data for %s. This may take a while. Please wait.' % root)
    tar = tarfile.open(filename)
    sys.stdout.flush()
    tar.extractall(data_root)
    tar.close()
  data_folders = [
    os.path.join(root, d) for d in sorted(os.listdir(root))
    if os.path.isdir(os.path.join(root, d))]
  if len(data_folders) != num_classes:
    raise Exception(
      'Expected %d folders, one per class. Found %d instead.' % (
        num_classes, len(data_folders)))
  print(data_folders)
  return data_folders

train_folders = maybe_extract(train_filename)
test_folders = maybe_extract(test_filename)

['notMNIST_large/A', 'notMNIST_large/B', 'notMNIST_large/C', 'notMNIST_large/D', 'notMNIST_large/E', 'notMNIST_large/F', 'notMNIST_large/G', 'notMNIST_large/H', 'notMNIST_large/I', 'notMNIST_large/J']
['notMNIST_small/A', 'notMNIST_small/B', 'notMNIST_small/C', 'notMNIST_small/D', 'notMNIST_small/E', 'notMNIST_small/F', 'notMNIST_small/G', 'notMNIST_small/H', 'notMNIST_small/I', 'notMNIST_small/J']

Problem 1: Display samples of the image data:

import random
import matplotlib.image as mpimg

def plot_samples(data_folders, sample_size, title=None):
    fig = plt.figure()
    if title: fig.suptitle(title, fontsize=16, fontweight='bold')
    for folder in data_folders:
        image_files = os.listdir(folder)
        image_sample = random.sample(image_files, sample_size)
        for image in image_sample:
            image_file = os.path.join(folder, image)
            ax = fig.add_subplot(len(data_folders), sample_size,
                                 sample_size * data_folders.index(folder) +
                                 image_sample.index(image) + 1)
            image = mpimg.imread(image_file)
            ax.imshow(image)
            ax.set_axis_off()
    plt.show()

plot_samples(train_folders, 10, 'Train Folders')
plot_samples(test_folders, 10, 'Test Folders')

[Figure: 10 random sample images per letter from the Train Folders]

[Figure: 10 random sample images per letter from the Test Folders]

Process the data: load each letter folder into a 3D tensor of 28×28 grayscale images normalized to the range [-0.5, 0.5], then pickle each class to disk:

image_size = 28  # Pixel width and height.
pixel_depth = 255.0  # Number of levels per pixel.

def load_letter(folder, min_num_images):
  """Load the data for a single letter label."""
  image_files = os.listdir(folder)
  dataset = np.ndarray(shape=(len(image_files), image_size, image_size),
                       dtype=np.float32)
  print(folder)
  num_images = 0
  for image in image_files:
    image_file = os.path.join(folder, image)
    try:
      # Normalize pixel values to the range [-0.5, 0.5].
      # Note: scipy.ndimage.imread was removed in SciPy >= 1.2;
      # imageio.imread is a drop-in alternative there.
      image_data = (ndimage.imread(image_file).astype(float) -
                    pixel_depth / 2) / pixel_depth
      if image_data.shape != (image_size, image_size):
        raise Exception('Unexpected image shape: %s' % str(image_data.shape))
      dataset[num_images, :, :] = image_data
      num_images = num_images + 1
    except IOError as e:
      print('Could not read:', image_file, ':', e, '- it\'s ok, skipping.')
  dataset = dataset[0:num_images, :, :]
  if num_images < min_num_images:
    raise Exception('Many fewer images than expected: %d < %d' %
                    (num_images, min_num_images))
  print('Full dataset tensor:', dataset.shape)
  print('Mean:', np.mean(dataset))
  print('Standard deviation:', np.std(dataset))
  return dataset

def maybe_pickle(data_folders, min_num_images_per_class, force=False):
  dataset_names = []
  for folder in data_folders:
    set_filename = folder + '.pickle'
    dataset_names.append(set_filename)
    if os.path.exists(set_filename) and not force:
      # You may override by setting force=True.
      print('%s already present - Skipping pickling.' % set_filename)
    else:
      print('Pickling %s.' % set_filename)
      dataset = load_letter(folder, min_num_images_per_class)
      try:
        with open(set_filename, 'wb') as f:
          pickle.dump(dataset, f, pickle.HIGHEST_PROTOCOL)
      except Exception as e:
        print('Unable to save data to', set_filename, ':', e)
  return dataset_names

train_datasets = maybe_pickle(train_folders, 45000)
test_datasets = maybe_pickle(test_folders, 1800)

notMNIST_large/A
Could not read: notMNIST_large/A/Um9tYW5hIEJvbGQucGZi.png : cannot identify image file - it's ok, skipping.
Could not read: notMNIST_large/A/RnJlaWdodERpc3BCb29rSXRhbGljLnR0Zg==.png : cannot identify image file - it's ok, skipping.
Could not read: notMNIST_large/A/SG90IE11c3RhcmQgQlROIFBvc3Rlci50dGY=.png : cannot identify image file - it's ok, skipping.
Full dataset tensor: (52909, 28, 28)
Mean: -0.12848
Standard deviation: 0.425576
notMNIST_large/B
Could not read: notMNIST_large/B/TmlraXNFRi1TZW1pQm9sZEl0YWxpYy5vdGY=.png : cannot identify image file - it's ok, skipping.
Full dataset tensor: (52911, 28, 28)
Mean: -0.00755947
Standard deviation: 0.417272
notMNIST_large/C
Full dataset tensor: (52912, 28, 28)
Mean: -0.142321
Standard deviation: 0.421305
notMNIST_large/D
Could not read: notMNIST_large/D/VHJhbnNpdCBCb2xkLnR0Zg==.png : cannot identify image file - it's ok, skipping.
Full dataset tensor: (52911, 28, 28)
Mean: -0.0574553
Standard deviation: 0.434072
notMNIST_large/E
Full dataset tensor: (52912, 28, 28)
Mean: -0.0701406
Standard deviation: 0.42882
notMNIST_large/F
Full dataset tensor: (52912, 28, 28)
Mean: -0.125914
Standard deviation: 0.429645
notMNIST_large/G
Full dataset tensor: (52912, 28, 28)
Mean: -0.0947771
Standard deviation: 0.421674
notMNIST_large/H
Full dataset tensor: (52912, 28, 28)
Mean: -0.0687667
Standard deviation: 0.430344
notMNIST_large/I
Full dataset tensor: (52912, 28, 28)
Mean: 0.0307405
Standard deviation: 0.449686
notMNIST_large/J
Full dataset tensor: (52911, 28, 28)
Mean: -0.153479
Standard deviation: 0.397169
notMNIST_small/A
Could not read: notMNIST_small/A/RGVtb2NyYXRpY2FCb2xkT2xkc3R5bGUgQm9sZC50dGY=.png : cannot identify image file - it's ok, skipping.
Full dataset tensor: (1872, 28, 28)
Mean: -0.132588
Standard deviation: 0.445923
notMNIST_small/B
Full dataset tensor: (1873, 28, 28)
Mean: 0.00535619
Standard deviation: 0.457054
notMNIST_small/C
Full dataset tensor: (1873, 28, 28)
Mean: -0.141489
Standard deviation: 0.441056
notMNIST_small/D
Full dataset tensor: (1873, 28, 28)
Mean: -0.0492094
Standard deviation: 0.460477
notMNIST_small/E
Full dataset tensor: (1873, 28, 28)
Mean: -0.0598952
Standard deviation: 0.456146
notMNIST_small/F
Could not read: notMNIST_small/F/Q3Jvc3NvdmVyIEJvbGRPYmxpcXVlLnR0Zg==.png : cannot identify image file - it's ok, skipping.
Full dataset tensor: (1872, 28, 28)
Mean: -0.118148
Standard deviation: 0.451134
notMNIST_small/G
Full dataset tensor: (1872, 28, 28)
Mean: -0.092519
Standard deviation: 0.448468
notMNIST_small/H
Full dataset tensor: (1872, 28, 28)
Mean: -0.0586729
Standard deviation: 0.457387
notMNIST_small/I
Full dataset tensor: (1872, 28, 28)
Mean: 0.0526481
Standard deviation: 0.472657
notMNIST_small/J
Full dataset tensor: (1872, 28, 28)
Mean: -0.15167
Standard deviation: 0.449521

Problem 2: Display the processed image data

def load_and_display_pickle(datasets, sample_size, title=None):
    fig = plt.figure()
    if title: fig.suptitle(title, fontsize=16, fontweight='bold')
    num_of_images = []
    for pickle_file in datasets:
        with open(pickle_file, 'rb') as f:
            data = pickle.load(f)
            print('Total images in', pickle_file, ':', len(data))
            for index, image in enumerate(data):
                if index == sample_size: break
                ax = fig.add_subplot(len(datasets), sample_size,
                                     sample_size * datasets.index(pickle_file) + index + 1)
                ax.imshow(image)
                ax.set_axis_off()
            num_of_images.append(len(data))

    # balance_check is defined in the next cell (Problem 3); run that cell first.
    balance_check(num_of_images)
    plt.show()
    return num_of_images

Problem 3: Check that the classes are balanced

def generate_fake_label(sizes):
    labels = np.ndarray(sum(sizes), dtype=np.int32)
    start = 0
    end = 0
    for label, size in enumerate(sizes):
        start = end
        end += size
        for j in range(start, end):
            labels[j] = label
    return labels

def plot_balance():
    fig, ax = plt.subplots(1, 2)
    bins = np.arange(train_labels.min(), train_labels.max() + 2)
    ax[0].hist(train_labels, bins=bins)
    # set_xticklabels is used instead of passing labels to set_xticks,
    # which only accepts a labels argument in matplotlib >= 3.4.
    ax[0].set_xticks((bins[:-1] + bins[1:]) / 2)
    ax[0].set_xticklabels([chr(k) for k in range(ord("A"), ord("J") + 1)])
    ax[0].set_title("Training data")

    bins = np.arange(test_labels.min(), test_labels.max() + 2)
    ax[1].hist(test_labels, bins=bins)
    ax[1].set_xticks((bins[:-1] + bins[1:]) / 2)
    ax[1].set_xticklabels([chr(k) for k in range(ord("A"), ord("J") + 1)])
    ax[1].set_title("Test data")
    plt.show()

def mean(numbers):
    return float(sum(numbers)) / max(len(numbers), 1)

def balance_check(sizes):
    mean_val = mean(sizes)
    print('mean of # images :', mean_val)
    for i in sizes:
        # Flag any class more than 10% away from the mean class size.
        if abs(i - mean_val) > 0.1 * mean_val:
            print("Too many or too few images:", i)
        else:
            print("Well balanced", i)

test_labels = generate_fake_label(load_and_display_pickle(test_datasets, 10, 'Test Datasets'))
train_labels = generate_fake_label(load_and_display_pickle(train_datasets, 10, 'Train Datasets'))
plot_balance()
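The post stops before the final step promised at the top: training a simple logistic regression model on the pickled data. A minimal sketch of what that could look like; the make_flat_dataset helper and the per-class sample counts (1000 train, 100 test) are illustrative assumptions, not part of the original notebook:

# Illustrative sketch (not from the original post): flatten the pickled
# letter data and fit sklearn's LogisticRegression on a small subsample.
def make_flat_dataset(pickle_files, num_per_class):
    data, labels = [], []
    for label, pickle_file in enumerate(pickle_files):
        with open(pickle_file, 'rb') as f:
            letter_set = pickle.load(f)
        letter_set = letter_set[:num_per_class]
        # Flatten each 28x28 image into a 784-dimensional feature vector.
        data.append(letter_set.reshape(len(letter_set), -1))
        labels.append(np.full(len(letter_set), label, dtype=np.int32))
    return np.concatenate(data), np.concatenate(labels)

X_train, y_train = make_flat_dataset(train_datasets, 1000)
X_test, y_test = make_flat_dataset(test_datasets, 100)

model = LogisticRegression()
model.fit(X_train, y_train)
print('Test accuracy:', model.score(X_test, y_test))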

To be continued.