Keras实现一个简单的CNN的分类例子

来源:互联网 发布:淘宝千里眼数据准确吗 编辑:程序博客网 时间:2024/05/29 10:59

还是将keras样例库中的mnist中的数据集使用CNN进行分类。
注意引包的时候多了一些CNN需要的层。

import numpy as npnp.random.seed(1337)from keras.datasets import  mnistfrom keras.utils import np_utilsfrom keras.models import Sequentialfrom keras.layers import Dense,Activation,Convolution2D,MaxPooling2D,Flattenfrom keras.optimizers import Adam# download the mnist to the path '~/.keras/datasets/' if it is the first time to be called# X shape (60,000 28x28), y shape (10,000, )(X_train, y_train), (X_test, y_test) = mnist.load_data()# data pre-processing,-1 represents the number of samples;1 represents the num of channels,28&28 represents the length,width respectivelyX_train = X_train.reshape(-1,1,28,28)  # normalizeX_test = X_test.reshape(-1,1,28,28)    # normalizey_train = np_utils.to_categorical(y_train,nb_classes=10)y_test = np_utils.to_categorical(y_test, nb_classes=10)#build neural networkmodel=Sequential()model.add(Convolution2D(    nb_filter=32,    nb_col=5,    nb_row=5,    border_mode='same', #padding method    input_shape=(1,     #channels                 28,28) #length and width))model.add(Activation('relu'))model.add(MaxPooling2D(    pool_size=(2,2),    strides=(2,2),    border_mode='same', #padding method))//这是添加第二层神经网络,卷积层,激励函数,池化层model.add(Convolution2D(64,5,5,border_mode='same'))model.add(Activation('relu'))model.add(MaxPooling2D(pool_size=(2,2),border_mode='same'))//将经过池化层之后的三维特征,整理成一维。方便后面建立全链接层model.add(Flatten())//1024像素model.add(Dense(1024))model.add(Activation('relu'))//输出压缩到10维,因为有10个标记model.add(Dense(10))//使用softmax进行分类model.add(Activation('softmax'))# Another way to define your optimizeadam=Adam(lr=1e-4)model.compile(    loss='categorical_crossentropy',    optimizer=adam,    metrics=['accuracy'])print('\nTraining-----------')model.fit(X_train,y_train,nb_epoch=2,batch_size=32)print('\nTesting------------')loss,accuracy=model.evaluate(X_test,y_test)print('test loss: ', loss)print('test accuracy: ', accuracy)

运行结果如下:
这是训练之后的loss&accuracy,经过两轮完全训练
这里写图片描述

这是测试之后的
这里写图片描述

原创粉丝点击