【pytorch】模型的搭建保存加载
来源:互联网 发布:奶茶店如何做网络 编辑:程序博客网 时间:2024/06/07 01:30
使用pytorch进行网络模型的搭建、保存与加载,是非常快速、方便的。
那么,如何进行参数初始化呢?使用 torch.nn.init ,如下:
搭建ConvNet
所有的网络都要继承torch.nn.Module,然后在构造函数中使用torch.nn中的提供的接口定义layer的属性,最后,在forward函数中将各个layer连接起来。
下面,以LeNet为例:
class LeNet(nn.Module): def __init__(self): super(LeNet, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16*5*5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = F.relu(self.conv1(x)) x = F.max_pool2d(x, 2) x = F.relu(self.conv2(x)) x = F.max_pool2d(x, 2) x = x.view(x.size(0), -1) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) out = self.fc3(x) return out
这样一来,我们就搭建好了网络模型,是不是很简洁明了呢?此外,还可以使用torch.nn.Sequential,更方便进行模块化的定义,如下:
class LeNetSeq(nn.Module): def __init__(self): super(LeNetSeq, self).__init__() self.conv = nn.Sequential( nn.Conv2d(3, 6, 5), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(6, 16, 5), nn.ReLU(), nn.MaxPool2d(2), ) self.fc = nn.Sequential( nn.Linear(16*5*5, 120), nn.ReLU(), nn.Linear(120, 84), nn.ReLU(), nn.Linear(84, 10) ) def forward(self, x): x = self.conv(x) x = out.view(x.size(0), -1) out = self.fc(x) return out
Module有很多属性,可以查看权重、参数等等;如下:
net = lenet.LeNet()print(net)for param in net.parameters(): print(type(param.data), param.size()) print(list(param.data)) print(net.state_dict().keys())#参数的keysfor key in net.state_dict():#模型参数 print key, 'corresponds to', list(net.state_dict()[key])
那么,如何进行参数初始化呢?使用 torch.nn.init ,如下:
def initNetParams(net): '''Init net parameters.''' for m in net.modules(): if isinstance(m, nn.Conv2d): init.xavier_uniform(m.weight) if m.bias: init.constant(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): init.constant(m.weight, 1) init.constant(m.bias, 0) elif isinstance(m, nn.Linear): init.normal(m.weight, std=1e-3) if m.bias: init.constant(m.bias, 0)initNetParams(net)
保存ConvNet
使用torch.save()对网络结构和模型参数的保存,有两种保存方式:
- 保存整个神经网络的的结构信息和模型参数信息,save的对象是网络net;
- 保存神经网络的训练模型参数,save的对象是net.state_dict()。
torch.save(net1, 'net.pkl') # 保存整个神经网络的结构和模型参数 torch.save(net1.state_dict(), 'net_params.pkl') # 只保存神经网络的模型参数
加载ConvNet
对应上面两种保存方式,重载方式也有两种。
- 对应第一种完整网络结构信息,重载的时候通过torch.load(‘.pth’)直接初始化新的神经网络对象即可。
- 对应第二种只保存模型参数信息,需要首先导入对应的网络,通过net.load_state_dict(torch.load('.pth'))完成模型参数的重载。
在网络比较大的时候,第一种方法会花费较多的时间,所占的存储空间也比较大。
# 保存和加载整个模型 torch.save(model_object, 'model.pth') model = torch.load('model.pth') # 仅保存和加载模型参数 torch.save(model_object.state_dict(), 'params.pth') model_object.load_state_dict(torch.load('params.pth'))
相关代码可以查看:tfygg/pytorch-tutorials
阅读全文
0 0
- 【pytorch】模型的搭建保存加载
- pytorch 保存与加载模型
- pytorch 模型的加载
- PyTorch基本用法(七)——模型的保存与加载
- PyTorch(7)——模型的训练和测试、保存和加载
- pytorch学习笔记(五):保存和加载模型
- pytorch学习笔记(五):保存和加载模型
- pytorch 如何加载部分预训练模型
- PyTorch学习系列(十四)——保存训练好的模型
- 170719 Keras 模型的保存与加载
- PyTorch快速搭建神经网络及其保存提取方法
- [DL]基于Pytorch的seq2seq模型
- tensorflow保存 和 加载模型
- mxnet模型保存和加载
- tensorflow 模型保存与加载
- Tensorflow 保存和加载模型
- TensorFlow 模型保存与加载
- tensorflow保存和加载模型
- MongoDB使用笔记
- 苹果IOS开发者账号总结
- python3小项目——爬取智联招聘信息(二)
- MTK 开发调试方法
- 中文编程尝试
- 【pytorch】模型的搭建保存加载
- Apache Commons 工具类介绍及简单使用
- ES6 Promise 用法
- 苹果面试8大难题及答案
- 移动端与PHP服务端接口通信流程设计(基础版)
- Delphi 跨语言环境 乱码问题
- struct hostent *host = gethostbyname2([hostName UTF8String], AF_INET);
- 如何解决GET请求中文乱码问题?
- Java for Web学习笔记(六五):Controller替代Servlet(7)上传和下载(自定义View)