Keras学习之六:训练辅助及优化工具

来源:互联网 发布:Windows无响应 编辑:程序博客网 时间:2024/06/06 00:05

1 Callbacks

Callbacks提供了一系列的类,用于在训练过程中被回调,从而实现对训练过程进行观察和干涉。除了库提供的一些类,用户也可以自定义类。下面列举比较有用的回调类。

类名 作用 构造函数 ModelCheckpoint 用于在epoch间保存要模型 ModelCheckpoint(filepath, monitor=’val_loss’, save_best_only=False, save_weights_only=False, mode=’auto’, period=1) EarlyStopping 当early stop被激活(如发现loss相比上一个epoch训练没有下降),则经过patience个epoch后停止训练。 EarlyStopping(monitor=’val_loss’, patience=0, mode=’auto’) TensorBoard 生成tb需要的日志 TensorBoard(log_dir=’./logs’, histogram_freq=0, write_graph=True, write_images=False, embeddings_freq=0, embeddings_layer_names=None, embeddings_metadata=None) ReduceLROnPlateau 当指标变化小时,减少学习率 ReduceLROnPlateau(monitor=’val_loss’, factor=0.1, patience=10, mode=’auto’, epsilon=0.0001, cooldown=0, min_lr=0)

示例:

from keras.callbacks import ModelCheckpointmodel = Sequential()model.add(Dense(10, input_dim=784, kernel_initializer='uniform'))model.add(Activation('softmax'))model.compile(loss='categorical_crossentropy', optimizer='rmsprop')checkpointer = ModelCheckpoint(filepath="/tmp/weights.h5", save_best_only=True)tensbrd = TensorBoard(logdir='path/of/log')model.fit(X_train, Y_train, batch_size=128, callbacks=[checkpointer,tensbrd])

PS:加入tensorboard回调类后,就可以使用tensorflow的tensorboard命令行来打开可视化web服务了。

2 Application

本模块提供了基于image-net预训练好的图像模型,方便我们进行迁移学习使用。初次使用时,模型权重数据会下载到~/.keras/models目录下。

图像模型 说明 构造函数 InceptionV3 InceptionV3(include_top=True, weights=’imagenet’,input_tensor=None,input_shape=None,pooling=None,classes=1000) ResNet50 ResNet50(include_top=True, weights=’imagenet’,input_tensor=None,input_shape=None,pooling=None,classes=1000) VGG19 VGG19(include_top=True, weights=’imagenet’,input_tensor=None,input_shape=None,pooling=None,classes=1000) VGG16 VGG16(include_top=True, weights=’imagenet’,input_tensor=None,input_shape=None,pooling=None,classes=1000) Xception Xception(include_top=True, weights=’imagenet’,input_tensor=None,input_shape=None,pooling=None, classes=1000)

参数说明

参数 说明 include_top 是否保留顶层的全连接网络, False为只要bottleneck weights ‘imagenet’代表加载预训练权重, None代表随机初始化 input_tensor 可填入Keras tensor作为模型的图像输出tensor input_shape 长为3的tuple,指明输入图片的shape,图片的宽高必须大于197 pooling 特征提取网络的池化方式。None代表不池化,最后一个卷积层的输出为4D张量。‘avg’代表全局平均池化,‘max’代表全局最大值池化 classes 图片分类的类别数,当include_top=True weight=None时可用

关于迁移学习,可以参考这篇文章:如何在极小数据集上实现图像分类。里面介绍了通过图像变换以及使用已有模型并fine-tune新分类器的过程。

3 模型可视化

utils包中提供了plot_model函数,用来将一个model以图像的形式展现出来。此功能依赖pydot-ng与graphviz。
pip install pydot-ng graphviz

from keras.utils import plot_modelmodel = keras.applications.InceptionV3()plot_model(model, to_file='model.png')
原创粉丝点击