PyTorch学习系列(十五)——如何加载预训练模型?
来源:互联网 发布:androlua源码 编辑:程序博客网 时间:2024/05/21 13:09
PyTorch提供的预训练模型
PyTorch定义了几个常用模型,并且提供了预训练版本:
- AlexNet: AlexNet variant from the “One weird trick” paper.
- VGG: VGG-11, VGG-13, VGG-16, VGG-19 (with and without batch normalization)
- ResNet: ResNet-18, ResNet-34, ResNet-50, ResNet-101, ResNet-152
- SqueezeNet: SqueezeNet 1.0, and SqueezeNet 1.1
预训练模型可以通过设置pretrained=True来构建:
import torchvision.models as modelsresnet18 = models.resnet18(pretrained=True)vgg16 = models.vgg16(pretrained=True)alexnet = models.alexnet(pretrained=True)squeezenet = models.squeezenet1_0(pretrained=True)
预训练模型期望的输入是RGB图像的mini-batch:(batch_size, 3, H, W),并且H和W不能低于224。图像的像素值必须在范围[0,1]间,并且用均值mean=[0.485, 0.456, 0.406]和方差std=[0.229, 0.224, 0.225]进行归一化。
加载预训练模型
torch.nn.Module对象有函数static_dict()用于返回包含模块所有状态的字典,包括参数和缓存。键是参数名称或者缓存名称。
函数Module::load_state_dict(state_dict)用state_dict中的状态值更新模块的状态值。static_dict中的键应该和函数static_dict()返回的字典中的键完全一样。
下面给出加载预训练的模型的示例:
vgg16 = models.vgg16(pretrained=True)pretrained_dict = vgg16.state_dict()model_dict = model.state_dict()# 1. filter out unnecessary keyspretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}# 2. overwrite entries in the existing state dictmodel_dict.update(pretrained_dict) # 3. load the new state dictmodel.load_state_dict(model_dict)
阅读全文
0 0
- PyTorch学习系列(十五)——如何加载预训练模型?
- pytorch 如何加载部分预训练模型
- PyTorch学习系列(十四)——保存训练好的模型
- PyTorch学习系列(十)——如何在训练时固定一些层?
- PyTorch学习系列(十六)——如何使用cuda进行训练?
- Pytorch学习系列(八)——训练神经网络
- pytorch学习1:如何加载自己的训练数据
- PyTorch(7)——模型的训练和测试、保存和加载
- PyTorch学习之路(level1)——训练一个图像分类模型
- PyTorch学习—PyTorch是什么?
- PyTorch学习系列(一)——加载数据并生成batch数据
- pytorch学习笔记(十一):fine-tune 预训练的模型
- pytorch 模型的加载
- 莫烦PyTorch学习笔记(五)——模型的存取
- PyTorch学习总结(一)——查看模型中间结果
- 使用pytorch预训练模型分类与特征提取
- PyTorch(三)——使用训练好的模型测试自己图片
- pytorch学习笔记(五):保存和加载模型
- java的线程池
- 小甲鱼的一个任务那一课
- 利用spark的随机森林做票房预测
- Qt简易通用开发框架
- 算法作业22
- PyTorch学习系列(十五)——如何加载预训练模型?
- 算法谜题93 击中战舰
- 消息队列的使用场景
- CentOS7 安装SVN
- LR11添加多台负载机配置
- matpltlib画图
- cejs
- clipimg
- mdebug