keras 多分类一些函数参数设置
来源:互联网 发布:db2分页查询sql 编辑:程序博客网 时间:2024/06/01 10:54
用Lenet-5 识别Mnist数据集为例子:
采用下载好的Mnist数据压缩包转换成PNG图片数据集,加载图片采用keras图像预处理模块中的ImageDataGenerator。
首先import所需要的模块
from keras.preprocessing.image import ImageDataGeneratorfrom keras.models import Modelfrom keras.layers import MaxPooling2D,Input,Convolution2Dfrom keras.layers import Dropout, Flatten, Densefrom keras import backend as K
定义图像数据信息及训练参数
img_width, img_height = 28, 28 train_data_dir = 'dataMnist/train' #train data directoryvalidation_data_dir = 'dataMnist/validation'# validation data directorynb_train_samples = 60000 nb_validation_samples = 10000epochs = 50 batch_size = 32
判断使用的后台
if K.image_dim_ordering() == 'th': input_shape = (3, img_width, img_height)else: input_shape = (img_width, img_height, 3)
网络模型定义
主要注意最后的输出层定义
比如Mnist数据集是要对0~9这10种手写字符进行分类,那么网络的输出层就应该输出一个10维的向量,10维向量的每一维代表该类别的预测概率,所以此处输出层的定义为:
x = Dense(10,activation=’softmax’)(x)
此处因为是多分类问题,Dense()的第一个参数代表输出层节点数,要输出10类则此项值为10,激活函数采用softmax,如果是二分类问题第一个参数可以是1,激活函数可选sigmoid
img_input=Input(shape=input_shape)x=Convolution2D(32, 3, 3, activation='relu', border_mode='same')(img_input)x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)x=Convolution2D(32,3,3,activation='relu',border_mode='same')(x)x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)x=Convolution2D(64,3,3,activation='relu',border_mode='same')(x)x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)x = Flatten(name='flatten')(x)x = Dense(64, activation='relu')(x)x= Dropout(0.5)(x)x = Dense(10,activation='softmax')(x)model=Model(img_input,x)model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['accuracy'])model.summary()
利用ImageDataGenerator传入图像数据集
注意用ImageDataGenerator的方法.flow_from_directory()加载图片数据流时,参数class_mode要设为‘categorical’,如果是二分类问题该值可设为‘binary’,另外要设置classes参数为10种类别数字所在文件夹的名字,以列表的形式传入。
train_datagen = ImageDataGenerator( rescale=1. / 255, shear_range=0.2, zoom_range=0.2, horizontal_flip=True)# this is the augmentation configuration we will use for testing:# only rescalingtest_datagen = ImageDataGenerator(rescale=1. / 255)train_generator = train_datagen.flow_from_directory( train_data_dir, target_size=(img_width, img_height), batch_size=batch_size, class_mode='categorical', #多分类问题设为'categorical' classes=['0','1','2','3','4','5','6','7','8','9'] #十种数字图片所在文件夹的名字 )validation_generator = test_datagen.flow_from_directory( validation_data_dir, target_size=(img_width, img_height), batch_size=batch_size, class_mode='categorical' )
训练和保存模型及权值
model.fit_generator( train_generator, samples_per_epoch=nb_train_samples, nb_epoch=epochs, validation_data=validation_generator, nb_val_samples=nb_validation_samples )model.save_weights('Mnist123weight.h5')model.save('Mnist123model.h5')
至此训练结束
图片预测
注意model.save()可以将模型以及权值一起保存,而model.save_weights()只保存了网络权值,此时如果要进行预测,必须定义有和训练出该权值所用的网络结构一模一样的一个网络。
此处利用keras.models中的load_model方法加载model.save()所保存的模型,以恢复网络结构和参数。
from keras.models import load_modelfrom keras.preprocessing.image import img_to_array, load_imgimport numpy as npclasses=['0','1','2','3','4','5','6','7','8','9']model=load_model('Mnist123model.h5')while True: img_addr=input('Please input your image address:') if img_addr=="exit": break else: img = load_img(img_addr, False, target_size=(28, 28)) x = img_to_array(img) / 255.0 x = np.expand_dims(x, axis=0) result = model.predict(x) ind=np.argmax(result,1) print('this is a ', classes[ind])
本文原创,转载请注明出处
阅读全文
0 0
- keras 多分类一些函数参数设置
- Keras中的多分类损失函数categorical_crossentropy
- keras + lstm 情感分类
- 利用keras进行分类
- keras -- 实现cifar10分类
- 使用Keras做猫狗分类
- Keras一些基本概念
- 使用Keras进行图像分类
- Keras:2.2搭建分类神经网络
- Keras classifier分类(二)
- keras搬砖系列-分类
- LCD一些参数设置
- LCD一些参数设置
- poi excel一些参数设置
- Hive一些参数设置
- fusioncharts 中的一些参数设置
- MySQLday01(语言分类 一些基本函数)
- Keras backens函数
- maven pom 文件报错
- 原生JS实现瀑布流
- 《Effective Java》(1~2)阅读笔记
- 【pandas使用遇到的问题】 have mixed types. Specify dtype option on import or set low_memory=False.
- win10安装redis及redis客户端使用方法
- keras 多分类一些函数参数设置
- 姚期智:中国金融科技发展的真正挑战是什么?如何解决? 本文作者:温晓桦2017-09-17 18:31 导语:“在金融科技里面,计算机科学的用途已经从台后走到了台中,对核心金融体系的运作上产生一定的
- android NFC通信初探
- 电子护照OCR识别并支持读取功能
- 购物车
- Python进阶----生成器@.@
- Java虚拟机详解——JVM常见问题总结
- VMware 12.5 pro下载+激活密钥 转载 2016年11月28日 00:06:02 270300 热门虚拟机软件VMware Workstation Pro现已更新至v12.5.2。12.0
- 一个菜鸟处理关于ajax向后台传值的问题