利用keras(tensorflow) 做cnn mnist识别
来源:互联网 发布:sql if else 多条件 编辑:程序博客网 时间:2024/06/05 03:35
keras图像数据处理以及图像识别小例子
1、数据预处理
数据集请自行下载,数据不大,20来兆
数据具体如下所示:
格式为 要识别的数字.序号.jpg
数据预处理代码,我用的是tensorflow做后端的keras,所以输入维度为(样本量,高,宽,通道)
import osfrom PIL import Imageimport numpy as np#读取文件夹mnist下的42000张图片,图片为灰度图,所以为1通道,#如果是将彩色图作为输入,则将1替换为3,图像大小28*28def load_data(): data = np.empty((42000,1,28,28),dtype="float32") label = np.empty((42000,),dtype="uint8") imgs = os.listdir("d:/mnist") num = len(imgs) for i in range(num): img = Image.open("d:/mnist/"+imgs[i]) arr = np.asarray(img,dtype="float32") data[i,:,:,:] = arr label[i] = int(imgs[i].split('.')[0]) data = data.reshape(42000,28,28,1) return data,labeldata , label = load_data()
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
2、cnn大法
代码如下,很简单的用了几层卷积和池化,应该都能看懂,最后训练正确率95左右
from keras.preprocessing.image import ImageDataGeneratorfrom keras.models import Sequentialfrom keras.layers.core import Dense, Dropout, Activation, Flattenfrom keras.layers.advanced_activations import PReLUfrom keras.layers.convolutional import Convolution2D, MaxPooling2Dfrom keras.optimizers import SGD, Adadelta, Adagradfrom keras.utils import np_utils, generic_utilsfrom six.moves import range#加载数据data, label = load_data()print(data.shape[0], ' samples')#label为0~9共10个类别,keras要求格式为binary class matrices,转化一下,直接调用keras提供的这个函数label = np_utils.to_categorical(label, 10)train_data = data[:40000]train_labels = label[:40000]validation_labels = label[40000:]validation_data = data[40000:]################开始建立CNN模型################生成一个modelmodel = Sequential()#第一个卷积层,4个卷积核,每个卷积核大小5*5。1表示输入的图片的通道,灰度图为1通道。#border_mode可以是valid或者full,具体看这里说明:http://deeplearning.net/software/theano/library/tensor/nnet/conv.html#theano.tensor.nnet.conv.conv2d#激活函数用tanh#你还可以在model.add(Activation('tanh'))后加上dropout的技巧: model.add(Dropout(0.5))model.add(Convolution2D(4, 5, 5,input_shape=(28, 28,1)))model.add(Activation('relu'))model.add(MaxPooling2D(pool_size=(2, 2)))#第二个卷积层,8个卷积核,每个卷积核大小3*3。4表示输入的特征图个数,等于上一层的卷积核个数#激活函数用tanh#采用maxpooling,poolsize为(2,2)model.add(Convolution2D(8, 3, 3))model.add(Activation('relu'))model.add(MaxPooling2D(pool_size=(2, 2)))#第三个卷积层,16个卷积核,每个卷积核大小3*3#激活函数用tanh#采用maxpooling,poolsize为(2,2)model.add(Convolution2D(16, 3, 3))model.add(Activation('relu'))model.add(MaxPooling2D(pool_size=(2, 2)))#全连接层,先将前一层输出的二维特征图flatten为一维的。#Dense就是隐藏层。16就是上一层输出的特征图个数。4是根据每个卷积层计算出来的:(28-5+1)得到24,(24-3+1)/2得到11,(11-3+1)/2得到4#全连接有128个神经元节点,初始化方式为normalmodel.add(Flatten())model.add(Dense(128))model.add(Activation('relu'))model.add(Dropout(0.5))#Softmax分类,输出是10类别model.add(Dense(10))model.add(Activation('softmax'))##############开始训练模型###############使用SGD + momentum#model.compile里的参数loss就是损失函数(目标函数)model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])model.fit(train_data, train_labels, nb_epoch=10, batch_size=100, validation_data=(validation_data, validation_labels))json_string = model.to_json()open('d:/my_model_architecture.json','w').write(json_string)model.save_weights('d:/firsttry.h5')
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
3、结果
阅读全文
0 0
- 利用keras(tensorflow) 做cnn mnist识别
- 利用keras(tensorflow) 做cnn mnist识别
- Tensorflow学习笔记(二):利用CNN实现手写数字(mnist)识别
- TensorFlow学习笔记(3)----CNN识别MNIST手写数字
- CNN学习(三)—Tensorflow 进行MNIST手写体识别
- Tensorflow训练CNN网络识别mnist
- tensorflow中CNN对mnist识别
- tensorflow进行MNIST手写数字识别-CNN
- Tensorflow实战-CNN网络Mnist识别
- tensorflow学习之---CNN识别MNIST
- tensorflow之用CNN识别MNIST
- TensorFlow的cnn做mnist例子
- 【TensorFlow】MNIST(使用CNN)
- keras mnist cnn example
- Keras入门课2 -- 使用CNN识别mnist手写数字
- TensorFlow在MNIST中的应用 识别手写数字(OpenCV+TensorFlow+CNN)
- TensorFlow学习笔记(十五)TensorFLow 用mnist数据做CNN
- kaggle mnist tensorflow+keras
- AAPT2 error
- keras 识别Mnist
- 基于PyTorch的深度学习入门教程(七)——PyTorch重点综合实践
- 实验:java call so, 传入传出多参数
- shell基础知识梳理一
- 利用keras(tensorflow) 做cnn mnist识别
- 正规化--预防过拟合
- 图麟科技完成 2.5 亿元 A 轮融资,CV应用还有多少细分领域可供分割?
- Google 何时回归中国?这个问题也许根本就不存在
- 为什么说阿里云的云骨干网系云计算第三代技术标志 | 解读
- Mac OS上设置Django开发环境
- 继康宁之后,苹果又砸3.9亿美元投资iPhone X激光芯片厂商Finisar
- 请编程设计一个登陆界面,要求输入账号和密码(不考虑事件)
- 利用Java Swing技术设计一个鼠标点击速度比赛游戏程序。程序显示一个按钮和一个文本框,用户点击按钮,文本框显示鼠标点击次数。