pytorch 保存与加载模型

来源:互联网 发布:中标麒麟linux 编辑:程序博客网 时间:2024/06/15 20:21

懒得吐槽自己,折腾了半天。
需要 finetune vgg19_bn, 直接使用 model_zoo.

from torchvision import modelsimport torchmodel = models.vgg19_bn(pretrained=True)torch.save(model.state_dict(), 'vgg19_bn.pkl')

没什么问题,demo看起来一切正常。

由于离线,所以在使用的时候需要加载。

model = torch.load('vgg19_bn.pkl')for param in model.features.parameters():    param.requires_grad = False

好的,开始报错,

orderedDict没有 features属性。

由于pkl文件是序列化文件,把后缀名从.pkl换成.pt,还是一样。最后发现是保存的文件错了,与后缀名无关

from torchvision import modelsimport torchmodel = models.vgg19_bn(pretrained=True)torch.save(model, '1.pkl')torch.save(model, '2.pt')torch.save(model.state_dict(), '3.pkl')torch.save(model.state_dict(), '4.pt')model1 = torch.load('1.pkl')model2 = torch.load('2.pt')model3 = torch.load('3.pkl')model4 = torch.load('4.pt')print('model1 type is ',type(model1))print('model2 type is ',type(model2))print('model3 type is ',type(model3))print('model4 type is ',type(model4))

输出结果为:

model1 type is  <class 'torchvision.models.vgg.VGG'>model2 type is  <class 'torchvision.models.vgg.VGG'>model3 type is  <class 'collections.OrderedDict'>model4 type is  <class 'collections.OrderedDict'>

在加载的时候需要的是VGG文件,保存的内容错了。

原创粉丝点击