Keras Basics: Fine-Tuning a Pretrained Model (ResNet)
Source: Internet  Editor: 程序博客网  Time: 2024/06/07 18:34
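When studying deep learning, we often have limited compute or a small training set but still want good, stable results. Models that have already been trained, such as AlexNet, GoogLeNet, VGGNet, and ResNet, can help a great deal. So how do we fine-tune such a pretrained model to improve accuracy? In this post we use a pretrained ResNet50 network; the model was trained on the ImageNet dataset to classify 1000 object categories.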
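The steps are as follows:
1. Download the ResNet50 weights without the fully connected layers (resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5) to the local machine;
2. Define the ResNet50 network architecture;
3. Load the pretrained weights into the architecture we defined;
4. Replace the fully connected layers to suit our classification task;
5. Optionally, unfreeze the last few blocks as needed and then train with a very low learning rate. We train only the last block because the training set is small while ResNet50 has many layers: training all of them is unlikely to go well and tends to overfit. Also, since fine-tuning starts from an already trained model, weight updates should stay small so that the pretrained features are not destroyed.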
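Step 5 can be sketched roughly as follows. This is an illustration under assumptions, not the author's code: the helper name `freeze_early_stages` is hypothetical, and it relies on the layer-naming scheme used in this post (`conv1`, `bn_conv1`, `res2a_branch2a`, `bn2a_branch2a`, ...) to tell the stages apart. The stand-in layer objects exist only so the snippet runs without Keras; on a real Keras model the same loop sets the standard `trainable` attribute of each layer.

```python
from types import SimpleNamespace


def freeze_early_stages(model, first_trainable_stage=5):
    """Freeze conv1 and every ResNet stage before `first_trainable_stage`.

    Assumes layers are named 'res<stage>...' / 'bn<stage>...' as in this post.
    Layers with other names (for example a new Dense head) stay trainable.
    Returns the names of the layers left trainable.
    """
    frozen = ['conv1', 'bn_conv1']
    for stage in range(2, first_trainable_stage):
        frozen += ['res%d' % stage, 'bn%d' % stage]
    frozen = tuple(frozen)
    for layer in model.layers:
        layer.trainable = not layer.name.startswith(frozen)
    return [layer.name for layer in model.layers if layer.trainable]


# Stand-in for a Keras model (hypothetical, for illustration only):
# any object with a .layers list whose items have .name and .trainable.
names = ['conv1', 'bn_conv1', 'res2a_branch2a', 'bn2a_branch2a',
         'res5c_branch2c', 'bn5c_branch2c', 'dense_1']
toy_model = SimpleNamespace(
    layers=[SimpleNamespace(name=n, trainable=True) for n in names])

kept = freeze_early_stages(toy_model, first_trainable_stage=5)
print(kept)

# With a real Keras model, compile after changing `trainable` and use a
# small learning rate (1e-5 is an assumed value, not from the post), e.g.:
#   model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-5),
#                 loss='categorical_crossentropy', metrics=['accuracy'])
```

Only the stage 5 layers and the new classification layer remain trainable here; everything earlier keeps its pretrained weights.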
Step 1: Download the weights
Link: click here
Step 2: Define the ResNet50 architecture
    from keras.layers import (Input, Add, Dense, Activation, ZeroPadding2D,
                              BatchNormalization, Flatten, Conv2D,
                              AveragePooling2D, MaxPooling2D)
    from keras.models import Model


    def identity_block(X, f, filters, stage, block):
        # Name prefixes for the layers in this block
        conv_name_base = 'res' + str(stage) + block + '_branch'
        bn_name_base = 'bn' + str(stage) + block + '_branch'

        # Retrieve filters
        F1, F2, F3 = filters

        # Save the input value; it is added back to the main path later
        X_shortcut = X

        # First component of main path
        X = Conv2D(filters=F1, kernel_size=(1, 1), strides=(1, 1), padding='valid', name=conv_name_base + '2a')(X)
        X = BatchNormalization(axis=3, name=bn_name_base + '2a')(X)
        X = Activation('relu')(X)

        # Second component of main path
        X = Conv2D(filters=F2, kernel_size=(f, f), strides=(1, 1), padding='same', name=conv_name_base + '2b')(X)
        X = BatchNormalization(axis=3, name=bn_name_base + '2b')(X)
        X = Activation('relu')(X)

        # Third component of main path
        X = Conv2D(filters=F3, kernel_size=(1, 1), strides=(1, 1), padding='valid', name=conv_name_base + '2c')(X)
        X = BatchNormalization(axis=3, name=bn_name_base + '2c')(X)

        # Final step: add the shortcut to the main path, then apply ReLU
        X = Add()([X, X_shortcut])
        X = Activation('relu')(X)

        return X


    def convolutional_block(X, f, filters, stage, block, s=2):
        # Name prefixes for the layers in this block
        conv_name_base = 'res' + str(stage) + block + '_branch'
        bn_name_base = 'bn' + str(stage) + block + '_branch'

        # Retrieve filters
        F1, F2, F3 = filters

        # Save the input value
        X_shortcut = X

        ##### MAIN PATH #####
        # First component of main path
        X = Conv2D(F1, (1, 1), strides=(s, s), padding='valid', name=conv_name_base + '2a')(X)
        X = BatchNormalization(axis=3, name=bn_name_base + '2a')(X)
        X = Activation('relu')(X)

        # Second component of main path
        X = Conv2D(F2, (f, f), strides=(1, 1), padding='same', name=conv_name_base + '2b')(X)
        X = BatchNormalization(axis=3, name=bn_name_base + '2b')(X)
        X = Activation('relu')(X)

        # Third component of main path
        X = Conv2D(F3, (1, 1), strides=(1, 1), padding='valid', name=conv_name_base + '2c')(X)
        X = BatchNormalization(axis=3, name=bn_name_base + '2c')(X)

        ##### SHORTCUT PATH #####
        X_shortcut = Conv2D(F3, (1, 1), strides=(s, s), padding='valid', name=conv_name_base + '1')(X_shortcut)
        X_shortcut = BatchNormalization(axis=3, name=bn_name_base + '1')(X_shortcut)

        # Final step: add the shortcut to the main path, then apply ReLU
        X = Add()([X, X_shortcut])
        X = Activation('relu')(X)

        return X


    def ResNet50(input_shape=(64, 64, 3), classes=30):
        # Note: `classes` is not used here; the classification layer is added in Step 4.
        # Define the input as a tensor with shape input_shape
        X_input = Input(input_shape)

        # Zero-padding
        X = ZeroPadding2D((3, 3))(X_input)

        # Stage 1
        X = Conv2D(64, (7, 7), strides=(2, 2), name='conv1')(X)
        X = BatchNormalization(axis=3, name='bn_conv1')(X)
        X = Activation('relu')(X)
        X = MaxPooling2D((3, 3), strides=(2, 2))(X)

        # Stage 2
        X = convolutional_block(X, f=3, filters=[64, 64, 256], stage=2, block='a', s=1)
        X = identity_block(X, 3, [64, 64, 256], stage=2, block='b')
        X = identity_block(X, 3, [64, 64, 256], stage=2, block='c')

        # Stage 3
        X = convolutional_block(X, f=3, filters=[128, 128, 512], stage=3, block='a', s=2)
        X = identity_block(X, 3, [128, 128, 512], stage=3, block='b')
        X = identity_block(X, 3, [128, 128, 512], stage=3, block='c')
        X = identity_block(X, 3, [128, 128, 512], stage=3, block='d')

        # Stage 4
        X = convolutional_block(X, f=3, filters=[256, 256, 1024], stage=4, block='a', s=2)
        X = identity_block(X, 3, [256, 256, 1024], stage=4, block='b')
        X = identity_block(X, 3, [256, 256, 1024], stage=4, block='c')
        X = identity_block(X, 3, [256, 256, 1024], stage=4, block='d')
        X = identity_block(X, 3, [256, 256, 1024], stage=4, block='e')
        X = identity_block(X, 3, [256, 256, 1024], stage=4, block='f')

        # Stage 5
        X = convolutional_block(X, f=3, filters=[512, 512, 2048], stage=5, block='a', s=2)
        X = identity_block(X, 3, [512, 512, 2048], stage=5, block='b')
        X = identity_block(X, 3, [512, 512, 2048], stage=5, block='c')

        # Average pooling
        X = AveragePooling2D((2, 2), strides=(2, 2))(X)

        # Output layer
        X = Flatten()(X)

        model = Model(inputs=X_input, outputs=X, name='ResNet50')

        return model
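To see what this network outputs, it helps to trace the spatial size of the feature map through the layers. The sketch below is plain arithmetic and not part of the original code; the helper names `valid_out` and `trace_resnet50` are made up for illustration. It uses the usual output-size rule for 'valid' padding, floor((n - k) / s) + 1, and the fact that the 3x3 convolutions use 'same' padding with stride 1 and therefore keep the size unchanged.

```python
def valid_out(n, k, s):
    """Output size of a conv/pool layer with 'valid' padding."""
    return (n - k) // s + 1


def trace_resnet50(size):
    """Return the spatial size after each part of the network above."""
    trace = {}
    n = size + 2 * 3                 # ZeroPadding2D((3, 3))
    n = valid_out(n, 7, 2)           # conv1: 7x7 kernel, stride 2
    trace['conv1'] = n
    n = valid_out(n, 3, 2)           # MaxPooling2D((3, 3), strides=(2, 2))
    trace['maxpool'] = n
    # Each stage starts with a convolutional_block whose first 1x1
    # convolution (and shortcut) uses stride s; identity blocks keep the size.
    for stage, s in [(2, 1), (3, 2), (4, 2), (5, 2)]:
        n = valid_out(n, 1, s)
        trace['stage%d' % stage] = n
    n = valid_out(n, 2, 2)           # AveragePooling2D((2, 2), strides=(2, 2))
    trace['avgpool'] = n
    trace['flatten'] = n * n * 2048  # stage 5 ends with 2048 channels
    return trace


result = trace_resnet50(224)
print(result)
```

For a 224x224 input the feature map shrinks to 7x7 after stage 5 and to 3x3 after average pooling, so Flatten yields 3 * 3 * 2048 = 18432 features per image.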
Step 3: Load the pretrained weights
    base_model = ResNet50(input_shape=(224, 224, 3), classes=30)
    base_model.load_weights('resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5')
Step 4: Replace the fully connected layer
    X = base_model.output
    predictions = Dense(30, activation='softmax')(X)
    model = Model(inputs=base_model.input, outputs=predictions)
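The new `Dense(30, activation='softmax')` layer turns 30 raw scores into a probability distribution over the 30 classes. As a rough illustration of what softmax computes (plain Python with made-up scores for three classes; this is not how Keras implements it internally):

```python
import math


def softmax(scores):
    """Convert raw scores into probabilities that sum to 1."""
    m = max(scores)                           # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


scores = [2.0, 1.0, 0.1]                      # made-up scores for 3 classes
probs = softmax(scores)
predicted_class = probs.index(max(probs))     # index of the most likely class
print(probs, predicted_class)
```

The class with the highest score also gets the highest probability, which is the class the model predicts.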
Step 5: Compile and train
    model.compile(optimizer='Adam', loss='categorical_crossentropy', metrics=['accuracy'])
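`categorical_crossentropy` expects the labels in one-hot form, so `Y_train` and `Y_val` need one column per class. The sketch below, in plain Python with made-up numbers, shows one-hot encoding of an integer label and the per-sample loss, -sum(y_i * log(p_i)). The helper names are illustrative only; in Keras itself, `keras.utils.to_categorical` performs the encoding.

```python
import math


def one_hot(label, num_classes):
    """Encode an integer class label as a one-hot list."""
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Per-sample cross-entropy between a one-hot target and predicted probabilities."""
    return -sum(t * math.log(max(p, eps)) for t, p in zip(y_true, y_pred))


y_true = one_hot(2, num_classes=4)            # class 2 of 4
good = [0.05, 0.05, 0.85, 0.05]               # confident and correct
bad = [0.70, 0.10, 0.10, 0.10]                # confident and wrong
loss_good = categorical_crossentropy(y_true, good)
loss_bad = categorical_crossentropy(y_true, bad)
print(loss_good, loss_bad)
```

Only the probability assigned to the true class enters the loss, so a confident wrong prediction is penalised much more heavily than a confident correct one.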
    from keras.callbacks import EarlyStopping

    es = EarlyStopping(monitor='val_loss', patience=1)
    model.fit(x=X_train, y=Y_train, epochs=20, batch_size=32, validation_data=(X_val, Y_val), callbacks=[es])
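With `patience=1`, training stops as soon as the validation loss fails to improve for one epoch, so the 20 epochs act as an upper bound. The following is a simplified model of that rule in plain Python (not the Keras source; the function name and the loss values are made up):

```python
def stopping_epoch(val_losses, patience=1):
    """Return the 0-based epoch at which training would stop, or None."""
    best = float('inf')
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss       # new best validation loss, reset the counter
            wait = 0
        else:
            wait += 1         # one more epoch without improvement
            if wait >= patience:
                return epoch
    return None


losses = [0.90, 0.70, 0.72, 0.60]
stop_p1 = stopping_epoch(losses, patience=1)
stop_p2 = stopping_epoch(losses, patience=2)
print(stop_p1, stop_p2)
```

In this made-up run, `patience=1` stops at the third epoch and never reaches the better fourth one, while `patience=2` tolerates the single bad epoch. A small patience saves time but is sensitive to noisy validation loss.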