keras/构建卷积神经网络识别mnist
来源:互联网 发布:php服务器配置 编辑:程序博客网 时间:2024/05/16 14:07
环境:Keras 2.04, python 2.7,GPU
使用深度学习框架keras,构建卷积神经网络识别手写数字,keras在构建神经网络方面比Tensorflow简单很多,而且Tensorflow也将keras作为其高级api
#coding:utf-8"""python 2.7keras 2.0.4"""from keras.utils import np_utilsfrom keras.models import Sequentialfrom keras.layers import Dense,Activation,Convolution2D,MaxPooling2D,Flattenfrom keras.optimizers import Adamfrom sklearn.metrics import confusion_matrix,classification_reportimport numpy as npimport input_dataimport datetimestart_time = datetime.datetime.now()#设置随机种子np.random.seed(1000)#数据格式转换#one_hot=False这里故意使y值为如下表示:(0000000000),目的是后面使用keras的np_utilsmnist = input_data.read_data_sets('mnist/',one_hot=False)#样本数,颜色通道,28行28列train_data=mnist.train.images.reshape(mnist.train.images.shape[0],1,28,28)#通过keras的np_utils将y值转为如下表示:(0000000000)train_labels = np_utils.to_categorical(mnist.train.labels,nb_classes=10)test_data = mnist.test.images.reshape(mnist.test.images.shape[0],1,28,28)test_labels = np_utils.to_categorical(mnist.test.labels,nb_classes=10)#构建模型model = Sequential()#卷积层,32个卷积核,每个卷积核大小5*5,采用same_padding的方式model.add(Convolution2D(nb_filter=32,nb_row=5,nb_col=5,border_mode='same',input_shape=(1,28,28)))#pooling层,采用same padding model.add(MaxPooling2D(pool_size=(2,2),border_mode='same'))model.add(Convolution2D(nb_filter=64,nb_row=5,nb_col=5,border_mode='same'))model.add(MaxPooling2D(pool_size=(2,2),border_mode='same'))#将数据展平model.add(Flatten())#全连接层model.add(Dense(1024))model.add(Activation('relu'))model.add(Dense(10))model.add(Activation('softmax'))#编译模型sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9,nesterov=True) model.compile(optimizer=Adam(lr = 0.001),loss='categorical_crossentropy',metrics=['accuracy'])#训练模型#shuffle就是是否把数据随机打乱之后再进行训练 # verbose是屏显进度条 # validation_split就是拿出百分之多少用来做交叉验证 model.fit(train_data,train_labels,nb_epoch=10,batch_size=50,shuffle=True,verbose=1,validation_split=0.3)#测试集结果c,acc = model.evaluate(test_data,test_labels,batch_size=50)#输出预测分类是0,1,2,3,4,5这种类型predictions = model.predict_classes(test_data,batch_size=50)#混淆矩阵print(confusion_matrix(mnist.test.labels,predictions))#reportprint(classification_report(mnist.test.labels,np.array(predictions)))#模型训练了多久end_time = datetime.datetime.now()total_time = (end_time - start_time).secondsprint('total time is:',total_time)
结果:
实验过程是在GPU上测试的,速度比cpu快很多,进行十轮训练用了582s
阅读全文
0 0
- keras/构建卷积神经网络识别mnist
- keras/构建卷积神经网络人脸识别
- keras构建卷积神经网络识别cifar10
- tensorflow笔记:卷积神经网络用于MNIST识别
- 通过mnist数字识别理解卷积神经网络
- tensorflow1.1/构建双向神经网络识别mnist
- keras:1)初体验-MLP神经网络实现MNIST手写识别
- 使用Keras构建神经网络进行Mnist手写字体分类
- tensorflow1.1/构建卷积神经网络识别文本
- keras 识别Mnist
- 卷积神经网络CNN——使用keras识别猫咪
- keras入门 利用卷积神经网络进行手写数字识别
- [TensorFlow]入门学习笔记(2)-卷积神经网络mnist手写识别
- tensorflow 卷积神经网络 LeNet-5模型 MNIST手写体数字识别
- TensorFlow实战-mnist手写数字识别(卷积神经网络)
- Tensorflow MNIST手写体识别多层卷积神经网络程序实现
- 使用tensorflow卷积神经网络实现mnist手写数字识别
- Tensorflow之 CNN卷积神经网络的MNIST手写数字识别
- python [:-1]
- Oracle 学习(三):pl/sql自动保存上次的窗口界面
- JavaScript基础
- chart.js插件生成折线图时数据普遍较大时Y轴数据不从0开始的解决办法
- 我的C程序设计语言学习日记#01
- keras/构建卷积神经网络识别mnist
- 561. Array Partition I
- [bzoj2199][Usaco2011 Jan]奶牛议会 2-sat
- 数据类型和类型转换
- 高仿App--元贝驾考(二)Dialog工具类
- C++10进制转16进制
- Maven入门指南③:坐标和依赖
- linux通过关键字查找
- linux 和window 双系统下无法显示win引导