Keras-2: Keras MNIST


Keras MNIST

Here we will use Keras to build a deep learning network that recognizes the MNIST handwritten-digit dataset.

  • This article draws on keras-mnist-tutorial
  • The code is organized into three parts:
    1. Data preparation
    2. Model building
    3. Training and optimization

Let's get started.

First, import the modules we need:

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers.core import Dense, Activation, Dropout
from keras.utils import np_utils
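Note: these imports target the standalone Keras 2 package this tutorial was written against. If you are on tf.keras instead, the equivalents are roughly the following (with to_categorical imported directly rather than through np_utils):

from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation, Dropout
from tensorflow.keras.utils import to_categorical  # replaces np_utils.to_categorical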

Data Preparation

We load the data through Keras's built-in mnist dataset, normalize it, and then flatten each two-dimensional image into a one-dimensional vector to serve as the network's input.

Load the MNIST dataset. Each sample is a 28×28 matrix, and there are 60,000 training samples and 10,000 test samples.

(X_train, y_train), (X_test, y_test) = mnist.load_data()
print(X_train.shape)
print(X_test.shape)
(60000, 28, 28)
(10000, 28, 28)

Let's plot a few of the samples to see what they look like:

for i in range(9):
    plt.subplot(3, 3, i+1)
    plt.imshow(X_train[i], cmap='gray', interpolation='none')
    plt.title("Class {}".format(y_train[i]))

(Figure: the first nine training images with their class labels)

Flatten the two-dimensional images into one-dimensional vectors:

X_train = X_train.reshape(len(X_train), -1)
X_test = X_test.reshape(len(X_test), -1)
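A quick shape check confirms that each 28×28 image is now a 784-dimensional vector:

print(X_train.shape)  # (60000, 784)
print(X_test.shape)   # (10000, 784)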

Next, normalize the data. The original pixel values lie in [0, 255]; we normalize them so they cluster around 0. There are many valid normalization schemes; pick whichever you prefer.

# uint8 cannot represent negative numbers, so convert to float first
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train = (X_train - 127) / 127
X_test = (X_test - 127) / 127
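For comparison, another common scheme (an alternative sketch, not from the original; use one scheme or the other, not both) simply scales pixels to [0, 1]:

# Alternative: scale raw pixel values to [0, 1] instead of roughly [-1, 1]
X_train = X_train / 255.0
X_test = X_test / 255.0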

Next, one-hot encode the labels:

nb_classes = 10
y_train = np_utils.to_categorical(y_train, nb_classes)
y_test = np_utils.to_categorical(y_test, nb_classes)
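To see what the encoding does, feed a single label through to_categorical: the digit 5 becomes a 10-dimensional vector with a 1 at index 5.

# One-hot encoding of a single label
print(np_utils.to_categorical([5], 10))
# [[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]]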

Building the Network

The data is ready, so now we build the network. It has three layers, all fully connected, and looks roughly like this:
(Figure: network model)

One new thing here is Dropout, a technique for preventing overfitting; see the documentation for the Dropout layer for details.

model = Sequential()
model.add(Dense(512, input_shape=(784,), kernel_initializer='he_normal'))
model.add(Activation('relu'))
model.add(Dropout(0.2))

model.add(Dense(512, kernel_initializer='he_normal'))
model.add(Activation('relu'))
model.add(Dropout(0.2))

model.add(Dense(nb_classes))
model.add(Activation('softmax'))
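Before compiling, it can help to inspect the architecture with model.summary(); the parameter counts are easy to verify by hand:

model.summary()
# Dense(512) on 784 inputs: 784*512 + 512 = 401,920 params
# Dense(512):               512*512 + 512 = 262,656 params
# Dense(10):                512*10  + 10  =   5,130 params
# Total: 669,706 trainable parameters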

OK! The model is built. Next we configure the learning process by compiling it:

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

Now let's train:

model.fit(X_train, y_train, epochs=20, batch_size=64, verbose=1, validation_split=0.05)
Train on 57000 samples, validate on 3000 samples
Epoch 1/20
57000/57000 [==============================] - 19s 327us/step - loss: 0.0811 - acc: 0.9761 - val_loss: 0.0702 - val_acc: 0.9810
Epoch 2/20
57000/57000 [==============================] - 19s 328us/step - loss: 0.0752 - acc: 0.9772 - val_loss: 0.0720 - val_acc: 0.9813
Epoch 3/20
57000/57000 [==============================] - 19s 331us/step - loss: 0.0687 - acc: 0.9788 - val_loss: 0.0670 - val_acc: 0.9830
Epoch 4/20
57000/57000 [==============================] - 20s 350us/step - loss: 0.0667 - acc: 0.9794 - val_loss: 0.0755 - val_acc: 0.9810
Epoch 5/20
57000/57000 [==============================] - 20s 353us/step - loss: 0.0688 - acc: 0.9794 - val_loss: 0.0671 - val_acc: 0.9820
Epoch 6/20
57000/57000 [==============================] - 20s 346us/step - loss: 0.0639 - acc: 0.9807 - val_loss: 0.0744 - val_acc: 0.9790
Epoch 7/20
57000/57000 [==============================] - 20s 342us/step - loss: 0.0626 - acc: 0.9805 - val_loss: 0.0685 - val_acc: 0.9837
Epoch 8/20
57000/57000 [==============================] - 21s 365us/step - loss: 0.0669 - acc: 0.9796 - val_loss: 0.0988 - val_acc: 0.9757
Epoch 9/20
57000/57000 [==============================] - 20s 345us/step - loss: 0.0605 - acc: 0.9819 - val_loss: 0.0769 - val_acc: 0.9833
Epoch 10/20
57000/57000 [==============================] - 19s 338us/step - loss: 0.0592 - acc: 0.9820 - val_loss: 0.0576 - val_acc: 0.9870
Epoch 11/20
57000/57000 [==============================] - 19s 336us/step - loss: 0.0600 - acc: 0.9822 - val_loss: 0.0689 - val_acc: 0.9847
Epoch 12/20
57000/57000 [==============================] - 20s 345us/step - loss: 0.0625 - acc: 0.9813 - val_loss: 0.0689 - val_acc: 0.9843
Epoch 13/20
57000/57000 [==============================] - 20s 346us/step - loss: 0.0573 - acc: 0.9829 - val_loss: 0.0679 - val_acc: 0.9853
Epoch 14/20
57000/57000 [==============================] - 19s 342us/step - loss: 0.0555 - acc: 0.9833 - val_loss: 0.0642 - val_acc: 0.9850
Epoch 15/20
57000/57000 [==============================] - 20s 359us/step - loss: 0.0571 - acc: 0.9831 - val_loss: 0.0779 - val_acc: 0.9833
Epoch 16/20
57000/57000 [==============================] - 21s 361us/step - loss: 0.0564 - acc: 0.9831 - val_loss: 0.0610 - val_acc: 0.9867
Epoch 17/20
57000/57000 [==============================] - 20s 354us/step - loss: 0.0574 - acc: 0.9834 - val_loss: 0.0669 - val_acc: 0.9867
Epoch 18/20
57000/57000 [==============================] - 20s 353us/step - loss: 0.0526 - acc: 0.9848 - val_loss: 0.0863 - val_acc: 0.9830
Epoch 19/20
57000/57000 [==============================] - 20s 349us/step - loss: 0.0548 - acc: 0.9832 - val_loss: 0.0726 - val_acc: 0.9847
Epoch 20/20
57000/57000 [==============================] - 20s 352us/step - loss: 0.0512 - acc: 0.9845 - val_loss: 0.0735 - val_acc: 0.9860
<keras.callbacks.History at 0x2904822cd30>
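fit returns the History object shown above. If you capture it when running the cell (e.g. history = model.fit(...)), you can plot the learning curves. A minimal sketch, assuming the variable name history:

# Re-run training with the return value captured, then plot loss curves
history = model.fit(X_train, y_train, epochs=20, batch_size=64,
                    verbose=1, validation_split=0.05)
plt.plot(history.history['loss'], label='train loss')
plt.plot(history.history['val_loss'], label='val loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()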

Training is done; let's evaluate on the test set:

loss, accuracy = model.evaluate(X_test, y_test)
print('Test loss:', loss)
print('Accuracy:', accuracy)
10000/10000 [==============================] - 1s 80us/step
Test loss: 0.0864374790877
Accuracy: 0.9817
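As a final sanity check on individual predictions, a small sketch: model.predict returns one softmax vector per sample, and argmax recovers the predicted digit.

# Compare predicted vs. true classes for the first nine test images
preds = model.predict(X_test[:9])
print(np.argmax(preds, axis=1))       # predicted digits
print(np.argmax(y_test[:9], axis=1))  # true digits (y_test is one-hot)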