Keras-4 mnist With CNN
来源:互联网 发布:宜昌网络辣妈高云 编辑:程序博客网 时间:2024/06/01 08:47
Keras mnist With CNN
这次,我们将在Keras下利用卷积神经网络(CNN)对mnist进行训练和预测
- 关于卷积神经网络,强烈推荐零基础入门深度学习(4) - 卷积神经网络,有详细解释和公式推导以及代码实现
- Keras中CNN的使用方法,推荐deep-learning-keras-tensorflow
- 完整代码下载: mnist_CNN
OK,废话不多说,让我们开始吧
from keras.models import Sequentialfrom keras.layers.core import Dense, Dropout, Flatten, Activationfrom keras.layers.convolutional import Conv2Dfrom keras.layers.pooling import MaxPooling2Dfrom keras.utils import np_utilsfrom keras.datasets import mnistimport numpy as npimport matplotlib.pyplot as plt
数据准备
导入数据
(x_train, y_train), (x_test, y_test) = mnist.load_data()
观察数据。训练数据共60000个,测试数据10000个,每个样本都是28*28的图像
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)
(60000, 28, 28) (60000,) (10000, 28, 28) (10000,)
展示数据
%matplotlib inlinedef plot_sample(X): plt.figure() plt.imshow(X, cmap='gray')
plot_sample(x_train[20])
注意!下面的内容非常重要
在图像的表示上,Theano和TensorFlow发生了分歧。Theano将100张大小为16*32的RGB图像,表示为
那我们在数据准备阶段就要将数据转换成相应的格式
from keras import backend as K
img_rows, img_cols = 28, 28if K.image_data_format() == 'channels_first': shape_ord = (1, img_rows, img_cols)else: shape_ord = (img_rows, img_cols, 1)
预处理
def preprocess_data(X): return X/255
x_train = x_train.reshape((x_train.shape[0],)+shape_ord)x_test = x_test.reshape((x_test.shape[0],)+shape_ord)x_train = x_train.astype('float')x_test = x_test.astype('float')x_train = preprocess_data(x_train)x_test = preprocess_data(x_test)
One-hoe 编码
nb_classes = 10y_train = np_utils.to_categorical(y_train, nb_classes)y_test = np_utils.to_categorical(y_test, nb_classes)
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)
(60000, 28, 28, 1) (60000, 10) (10000, 28, 28, 1) (10000, 10)
搭起我们的网络来
设置参数
kernel_size = (3,3)pool_size = (2,2)epochs = 3batch_size = 128nb_filters = 32
设置网络结构
def build_model(): model = Sequential() model.add(Conv2D(nb_filters, kernel_size=kernel_size, input_shape=shape_ord)) model.add(Activation('relu')) model.add(Conv2D(nb_filters//2, kernel_size=kernel_size)) model.add(Activation('relu')) model.add(MaxPooling2D(pool_size=pool_size)) model.add(Dropout(0.25)) model.add(Flatten()) model.add(Dense(128)) model.add(Activation('relu')) model.add(Dropout(0.5)) model.add(Dense(nb_classes)) model.add(Activation('softmax')) return model
编译和训练
model = build_model()
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=1, validation_split=0.05)
Train on 57000 samples, validate on 3000 samplesEpoch 1/357000/57000 [==============================] - 86s 2ms/step - loss: 0.3246 - acc: 0.9001 - val_loss: 0.0599 - val_acc: 0.9843Epoch 2/357000/57000 [==============================] - 84s 1ms/step - loss: 0.1161 - acc: 0.9649 - val_loss: 0.0461 - val_acc: 0.9887Epoch 3/357000/57000 [==============================] - 87s 2ms/step - loss: 0.0865 - acc: 0.9737 - val_loss: 0.0408 - val_acc: 0.9897<keras.callbacks.History at 0x1ecbce457b8>
测试结果
loss, acc = model.evaluate(x_test, y_test, verbose=0)print('Loss :', loss)print('Accuracy :', acc)
Loss : 0.0391287969164Accuracy : 0.9874
显示预测结果
x_test_org = x_test.reshape(x_test.shape[0], img_rows, img_cols) #为了显示图像而进行reshape
nb_predict = 10x_pred = x_test[:nb_predict]prediction = model.predict(x_pred)prediction = prediction.argmax(axis=1)plt.figure(figsize=(16,8))for i in range(nb_predict): plt.subplot(1, nb_predict, i+1) plt.imshow(x_test_org[i]) plt.text(0,-3,prediction[i], color='black') plt.axis('off')
噢耶!一个简单CNN网络就能有98%以上的准确率,CNN真棒。好的,以上就是Keras中如何使用卷积神经网络。
阅读全文
0 0
- Keras-4 mnist With CNN
- keras mnist cnn example
- Keras with R (CNN)
- Keras 深度学习框架Python Example:CNN/mnist
- 基于深度学习框架Keras的CNN分类Mnist
- 03-Keras之用MNIST数据集训练一个CNN
- keras下基于mnist数据集的cnn
- 利用keras(tensorflow) 做cnn mnist识别
- 使用Keras搭建一个CNN处理MNIST数据
- [深度学习框架] Keras上使用CNN进行mnist分类
- 利用keras(tensorflow) 做cnn mnist识别
- Keras入门课2 -- 使用CNN识别mnist手写数字
- Keras MNIST
- Keras-2 Keras Mnist
- kaggle mnist tensorflow+keras
- DCGAN+keras生成mnist
- keras 识别Mnist
- 基于Keras实现CNN
- mysql 进入shell模式
- C/C++求职宝典重点笔记整理
- ESB企业服务总线
- servlet
- Hadoop 新旧API对比
- Keras-4 mnist With CNN
- 在win10系统下安装ubuntu17.10以及基本配置
- 数据库管理,你值得拥有
- 大数定律与中心极限定律
- MS SQL Server 安装错误集合
- 关于v-for无法及时更新到页面上的解决方法
- Java--11-20
- 【java】Proactor模式
- PAT 1018