keras下基于mnist数据集的cnn
来源:互联网 发布:外汇复盘软件 编辑:程序博客网 时间:2024/06/05 05:33
keras是一个支持theano和thsorflow为后端的深度学习框架,本实例以theano为后端,实现一个简单的cnn网络,通过这个我们也可以体会到cnn的强大之处,
首先要安装keras1.02,python2.7,下载mnist数据集于本地(由于在线下载一直失败)。
主程序如下:
import numpy as npnp.random.seed(1337) # for reproducibilityimport osfrom keras.datasets import mnistfrom keras.models import Sequentialfrom keras.layers.core import Dense, Dropout, Activation, Flattenfrom keras.layers.convolutional import Convolution2D, MaxPooling2Dfrom keras.utils import np_utilsbatch_size = 128nb_classes = 10nb_epoch = 12# input image dimensionsimg_rows, img_cols = 28, 28# number of convolutional filters to usenb_filters = 32# size of pooling area for max poolingnb_pool = 2# convolution kernel sizenb_conv = 3# the data, shuffled and split between train and test sets(X_train, y_train), (X_val, y_val), (X_test, y_test) = mnist.load_data()# Add the depth in the input. Only grayscale so depth is only one# see http://cs231n.github.io/convolutional-networks/#overviewX_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)X_test = X_test.reshape(X_test.shape[0], 1, img_rows, img_cols)# Make the value floats in [0;1] instead of int in [0;255]X_train = X_train.astype('float32')X_test = X_test.astype('float32')X_train /= 255X_test /= 255# Display the shapes to check if everything's okprint('X_train shape:', X_train.shape)print(X_train.shape[0], 'train samples')print(X_test.shape[0], 'test samples')# convert class vectors to binary class matrices (ie one-hot vectors)Y_train = np_utils.to_categorical(y_train, nb_classes)Y_test = np_utils.to_categorical(y_test, nb_classes)##############################################################################################model = Sequential()# For an explanation on conv layers see http://cs231n.github.io/convolutional-networks/#conv# By default the stride/subsample is 1# border_mode "valid" means no zero-padding.# If you want zero-padding add a ZeroPadding layer or, if stride is 1 use border_mode="same"model.add(Convolution2D(nb_filters, nb_conv, nb_conv,border_mode = 'valid',input_shape = (1,img_rows, img_cols),dim_ordering='th'))model.add(Activation('relu'))model.add(Convolution2D(nb_filters, nb_conv, nb_conv))model.add(Activation('relu'))# For an explanation on pooling layers see http://cs231n.github.io/convolutional-networks/#poolmodel.add(MaxPooling2D(pool_size=(nb_pool, nb_pool)))model.add(Dropout(0.25))# Flatten the 3D output to 1D tensor for a fully connected layer to accept the inputmodel.add(Flatten())model.add(Dense(128))model.add(Activation('relu'))model.add(Dropout(0.5))model.add(Dense(nb_classes)) # Last layer with one output per classmodel.add(Activation('softmax')) # We want a score simlar to a probability for each class################################################################################################ The function to optimize is the cross entropy between the true label and the output (softmax) of the model# We will use adadelta to do the gradient descent see http://cs231n.github.io/neural-networks-3/#adamodel.compile(loss='categorical_crossentropy', optimizer='adadelta', metrics=["accuracy"])# Make the model learnmodel.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch, verbose=1, validation_data=(X_test, Y_test))# Evaluate how the model does on the test setscore = model.evaluate(X_test, Y_test, verbose=0)print('Test score:', score[0])print('Test accuracy:', score[1])
还要改一个地方,就是修改mnist.load_data()函数,改变数据集的打开方式,同时设置mnist数据集路径
import gzipfrom ..utils.data_utils import get_filefrom six.moves import cPickleimport sysdef load_data(path='C:/Users/123/Desktop/mnist.pkl'): # path = get_file(path, origin='https://s3.amazonaws.com/img-datasets/mnist.pkl.gz') path = r'C:/Users/123/Desktop/mnist.pkl' if path.endswith('.gz'): f = gzip.open(path, 'rb') else: f = open(path, 'rb') f = open(path, 'rb') data = cPickle.load(f) f.close() return data
第三次时正确率已经达到80%多,设置的12次。由于程序比较占内存,我只运行了7次。结果正确率到92%左右。
阅读全文
0 0
- keras下基于mnist数据集的cnn
- 基于深度学习框架Keras的CNN分类Mnist
- Keras基于Cifar-10数据集的CNN实现
- keras mnist cnn example
- 基于MNIST数据集的深度学习库keras的学习
- 基于Keras的CNN框架
- 03-Keras之用MNIST数据集训练一个CNN
- 使用Keras搭建一个CNN处理MNIST数据
- TensorFlow系列(4)——基于MNIST数据集的CNN实现
- Keras-4 mnist With CNN
- 基于keras建立简单的CNN
- keras加载MNIST数据集方法
- tensorflow手册实现mnist数据集的CNN
- MNIST(二):基于CNN的mnist识别
- 基于Keras实现CNN
- 用keras实验mnist数据
- 【keras】解决 example 案例中 MNIST 数据集下载不了的问题
- DeepLearning (五) 基于Keras的CNN 训练cifar-10 数据库
- 51 Nod1428活动安排
- linux v4l2编程
- 我与python约个会:08.程序编程基础2~基本数据类型
- spring简介
- BOM.window对象
- keras下基于mnist数据集的cnn
- 首触树链剖分
- LeetCode[338]Counting Bits
- android 特色输入输出
- oracle 通用函数
- 我与python约个会:09.程序编程基础3~组合数据类型
- linux tail
- 快节奏多人在线游戏网络入门系列教程(4):爆头!滞后补偿
- Spark编译