Pytorch框架下Finetune注意点
来源:互联网 发布:怎样加速手机网络 编辑:程序博客网 时间:2024/06/05 02:24
最近在参加AI challenger的比赛(虽然九月就开始的比赛,到11月才开始玩。。。)结局无所谓,就希望在过程中能学习一些东西,由于场景识别比赛的finetune模型权重都是torch下的,之前尝试了很多权重转化工具,但是发现基本上都不靠谱,所以比赛要继续做下去,只能转向Pytorch,花了整整两天时间来学习Pytorch,也有了较基础的一些收获,现在记录一下,也和有需要的人一起分享。
1、如何入手
其实入手的方法有很多,阅读Pytorch的toturial,然后手册方面,直接看pytorch的中文文档。然后解决一些基础的MNIST或者CIFAR10的问题,也可以快速上手,我选择的也就是利用项目,也就是比赛来快速学习一个工具的使用,但是比赛也需要的一个基础代码来快速学习对吧,我参考的就是这份代码戳我,这个是场景识别Places365数据库下的一个训练代码,写的很全,各种metrics的显示也写了,比官方toturials要全很多,有兴趣的可以看看。
2、迁移学习方面
(1)数据扩增
迁移学习一直是深度学习的重要技术方向,这次比赛我只要采用的也是迁移学习的方式。在迁移学习过程中,有需要设置不同层对应不同的学习率,这点keras没法实现,caffe则可以做到,但是caffe的使用太过于繁琐。而对于Pytorch而言,迁移学习需要的技术点,它都很完美的吻合了,Pytorch有很多的数据扩增方法,可以很方便的调用:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])train_datasets = datasets.ImageFolder(os.path.join(root_dir, train_dir), transforms.Compose([ transforms.RandomResizedCrop(224), #从原图像随机切割一张(224, 224)的图像 transforms.RandomHorizontalFlip(), #以0.5的概率水平翻转 transforms.RandomVerticalFlip(), #以0.5的概率垂直翻转 transforms.RandomRotation(10), #在(-10, 10)范围内旋转 transforms.ColorJitter(0.05, 0.05, 0.05, 0.05), #HSV以及对比度变化 transforms.ToTensor(), #把PIL.Image或者numpy.ndarray对象转化为tensor,并且是[0,1]范围,主要是除以255 normalize, ]))train_loader = torch.utils.data.DataLoader(train_datasets, batch_size=64, shuffle=True, num_workers=4, pin_memory=True)
上述代码中,pytorch首先通过datasets底下的ImageFolder对象设置图像目录路径,以及数据扩增的方式,上述数据扩增的解释如注释所示,更详细的解释可以看源码 ,在图片路径设置好之后再通过torch.utils.data下的Dataloader对象来设置数据的生成器,通过简单两步就可以设置好一个带有数据扩增的图像生成器。
在这里,有一个细节需要注意,ImageFolder对象下的图像路径应该符合以下格式:
dir/cat/0.jpgdir/cat/1.jpgdir/dog/0.jpgdir/dog/1.jpg
但是像我一开始,使用的目录格式是这样的:
dir/0/0.jpgdir/0/1.jpgdir/1/0.jpgdir/1/1.jpg
也就是使用数字代替标签,这个在caffe里也是很正常的写法,但是在之后的predict阶段就出现问题了,因为标签目录名并不是pytorch真正的类标签,而是类名,如何获取类标签呢?
train_datasets.class_to_idx
就可以获取类名和类标签的一个字典映射,然后通过以下代码:
idx_to_class = dict(zip(train_datasets.class_to_idx.values(), train_datasets.class_to_idx.keys()))
就可以通过这个字典来获得输出的top精度的标签所对应的类名。
(2)学习率设置
Pytorch下可以很方便的设置不同层对应不同学习率,比如说对于一个model可以如下设置:
optim.SGD([ {'params': model.base.parameters()}, {'params': model.classifier.parameters(), 'lr': 1e-3} ], lr=1e-2, momentum=0.9)
这里代码的含义是除了classifier以外的层的学习率都是1e-2。
(3)迁移学习模型添加新层
通常在迁移学习中,都是直接将最后一层的全连接层大小换成自己数据集类的大小,然后finetune,在这次比赛中,我发现这样的精度并不能提升到最大,因此,采用迁移的base模型来叠加MLP的获取更高的精度,如何在base模型之后,叠加MLP?代码如下:
class model_bn(nn.Module): def __init__(self, model, feature_size): super(model_bn, self).__init__() self.features = nn.Sequential(*list(model.children())[:-1]) self.num_ftrs = model.fc.in_features self.classifier = nn.Sequential( nn.BatchNorm1d(self.num_ftrs), nn.Dropout(0.5), nn.Linear(self.num_ftrs, feature_size), nn.BatchNorm1d(feature_size), nn.ELU(inplace=True), nn.Dropout(0.5), nn.Linear(feature_size, classes_num), ) def forward(self, x): x = self.features(x) x = x.view(x.size(0), -1) x = self.classifier(x) return x
在上述代码中,self.features = nn.Sequential(*list(model.children())[:-1]) 直接将全base模型中的最后一层全连接层去除,*是解包的操作,将list类型以元组的方式传递给nn.Sequential对象,之后添加dropout和bn层来获取更高的精度,这里还要注意view函数相当于numpy中的reshape函数,在这里的作用就是keras中的Flatten层,将输出从二维压成一维,在这过程中发现,貌似pytorch没有直接的全局池化层,需要自己定义操作,可以通过以下方式来进行:
import torch.nn.functional as Foutput = F.average_pool2d(input, kernel_size=input.size()[2:])
(4)杂七杂八的坑
要将模型放入cuda中,才能在gpu中优化运行,之前使用:
model = model.cuda()
无效,但是换成下面代码就可以运行了:
model = torch.nn.DataParallel(model).cuda()
在这么设置之后,又会出现一个问题就是优化器无法获取到对应的层,也就是说model.fc会报错,但是仔细观察不难发现,再将模型放入gpu之后,整体模型被module包裹,所以应该通过model.module.fc.parameters()来获取对应层的参数。还有就是这样的设置之后每个batch数据会被平分到所有的gpu上,如果要限制只在一个gpu上运行,则设置即可:
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
最后就是一个数据类型的转化:
(a)模型的输出是variable类型,通过.data获取到tensor类型变量,如果是cuda.tensor,可以通过.cpu()来移动至cpu上。
(b)tensor类型转化为numpy类型,使用:
b = a.numpy()
反之,使用:
a = torch.from_numpy(a)
最后还有一个将图像的目录划分成train和test的方法,这种用法对使用惯了Keras的人来说很常见,而在Pytorch方法中,这个有两种方式,要不一开始目录就划分好,要不然就是动态的通过Pytorch来划分,具体看这两个链接:
Pytorch Train Test Split
Train, Validation and Test Split for torchvision Datasets
- Pytorch框架下Finetune注意点
- pytorch finetune模型
- Torch7,Pytorch安装注意点
- PyTorch参数初始化和Finetune
- PyTorch参数初始化和Finetune
- ARC 下注意点
- ARC 下注意点
- caffe下的finetune训练
- scrapy框架图片下载注意点
- Xe7下编码注意点
- 操作公司框架的注意点
- Spring Roo框架开发注意点
- xUtils 框架的使用注意点
- ThinkPHP框架知识的注意点
- Django框架使用注意点-笔记小结
- caffe Resnet-50 finetune 所有代码+需要注意的地方
- caffe Resnet-50 finetune 所有代码+需要注意的地方
- 【Pytorch】Windows10下配置Pytorch环境
- WebUtils封装返回值
- 前端面试之模块化-1、模块的写法
- 关于最小生成树的一些性质
- RedisTemplate.java
- 浅谈php接收POST数据的三种方式
- Pytorch框架下Finetune注意点
- 100个网络基础知识Q&A
- 如何区分Android wrap_content和fill_parent的详细说明
- RequestContextListener作用(涨知识了,转载自己保存)
- 登陆注册模块解析
- 安装`whl`文件
- hibernate配置文件属性catalog
- 台大-林轩田老师-机器学习基石学习笔记8
- 判断当前数据库类型是mysql还是oracle