Keras训练辅助工具及优化工具

来源:互联网 发布:python 爬虫多进程 编辑:程序博客网 时间:2024/06/03 23:44

原文:http://blog.csdn.net/zzulp/article/details/76591341

1 Callbacks

Callbacks提供了一系列的类,用于在训练过程中被回调,从而实现对训练过程进行观察和干涉。除了库提供的一些类,用户也可以自定义类。下面列举比较有用的回调类。

类名作用构造函数ModelCheckpoint用于在epoch间保存要模型ModelCheckpoint(filepath, monitor=’val_loss’, save_best_only=False, save_weights_only=False, mode=’auto’, period=1)EarlyStopping当early stop被激活(如发现loss相比上一个epoch训练没有下降),则经过patience个epoch后停止训练。EarlyStopping(monitor=’val_loss’, patience=0, mode=’auto’)TensorBoard生成tb需要的日志TensorBoard(log_dir=’./logs’, histogram_freq=0, write_graph=True, write_images=False, embeddings_freq=0, embeddings_layer_names=None, embeddings_metadata=None)ReduceLROnPlateau当指标变化小时,减少学习率ReduceLROnPlateau(monitor=’val_loss’, factor=0.1, patience=10, mode=’auto’, epsilon=0.0001, cooldown=0, min_lr=0)

示例:

from keras.callbacks import ModelCheckpointmodel = Sequential()model.add(Dense(10, input_dim=784, kernel_initializer='uniform'))model.add(Activation('softmax'))model.compile(loss='categorical_crossentropy', optimizer='rmsprop')checkpointer = ModelCheckpoint(filepath="/tmp/weights.h5", save_best_only=True)tensbrd = TensorBoard(logdir='path/of/log')model.fit(X_train, Y_train, batch_size=128, callbacks=[checkpointer,tensbrd])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

PS:加入tensorboard回调类后,就可以使用tensorflow的tensorboard命令行来打开可视化web服务了。

2 Application

本模块提供了基于image-net预训练好的图像模型,方便我们进行迁移学习使用。初次使用时,模型权重数据会下载到~/.keras/models目录下。

图像模型说明构造函数InceptionV3 InceptionV3(include_top=True, weights=’imagenet’,input_tensor=None,input_shape=None,pooling=None,classes=1000)ResNet50 ResNet50(include_top=True, weights=’imagenet’,input_tensor=None,input_shape=None,pooling=None,classes=1000)VGG19 VGG19(include_top=True, weights=’imagenet’,input_tensor=None,input_shape=None,pooling=None,classes=1000)VGG16 VGG16(include_top=True, weights=’imagenet’,input_tensor=None,input_shape=None,pooling=None,classes=1000)Xception Xception(include_top=True, weights=’imagenet’,input_tensor=None,input_shape=None,pooling=None, classes=1000)

参数说明

参数说明include_top是否保留顶层的全连接网络, False为只要bottleneckweights‘imagenet’代表加载预训练权重, None代表随机初始化input_tensor可填入Keras tensor作为模型的图像输出tensorinput_shape长为3的tuple,指明输入图片的shape,图片的宽高必须大于197pooling特征提取网络的池化方式。None代表不池化,最后一个卷积层的输出为4D张量。‘avg’代表全局平均池化,‘max’代表全局最大值池化classes图片分类的类别数,当include_top=True weight=None时可用

关于迁移学习,可以参考这篇文章:如何在极小数据集上实现图像分类。里面介绍了通过图像变换以及使用已有模型并fine-tune新分类器的过程。

3 模型可视化

utils包中提供了plot_model函数,用来将一个model以图像的形式展现出来。此功能依赖pydot-ng与graphviz。 
pip install pydot-ng graphviz

from keras.utils import plot_modelmodel = keras.applications.InceptionV3()plot_model(model, to_file='model.png')
  • 1
  • 2
  • 3
  • 1
  • 2
  • 3


原创粉丝点击