PyTorch笔记5-save和load神经网络
来源:互联网 发布:日语扫描翻译软件 编辑:程序博客网 时间:2024/06/05 15:32
本系列笔记为莫烦PyTorch视频教程笔记 github源码
概要
用 PyTorch 训练好神经网络(NN)后,如何保存以便下次要用的时候直接提取使用即可,下面举栗
import torchfrom torch.autograd import Variableimport torch.nn.functional as F # activation functionimport matplotlib.pyplot as plttorch.manual_seed(1) # torch seed%matplotlib inline
# fake data# unsqueeze set shape, otherwise (100,)x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # shape(100, 1)y = x.pow(2) + 0.2*torch.rand(x.size())# only the variable can be trainedx, y = Variable(x), Variable(y)# build NNclass Net(torch.nn.Module): def __init__(self, n_feature, n_hidden, n_output): super(Net, self).__init__() self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer self.prediction = torch.nn.Linear(n_hidden, n_output) # output layer def forward(self, x): x = F.relu(self.hidden(x)) # activation func for hidden layer x = self.prediction(x) return xnet1 = Net(1, 10, 1)print('net1: \n', net1)
net1: Net ( (hidden): Linear (1 -> 10) (prediction): Linear (10 -> 1))
loss_func = torch.nn.MSELoss()optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)for epoch in range(100): prediction = net1(x) loss = loss_func1(prediction, y) optimizer.zero_grad() # clear gradient for next train loss.backward() optimizer.step()
保存神经网络
下面用两种方式来保存
net1 为保存整个网络
net2 只保存网络中的参数(速度快,占内存少)
torch.save(net1, './NNPkl/net.pkl')torch.save(net1.state_dict(), './NNPkl/net_parms.pkl')
/Users/yangjiahua/pytorch-test/lib/python3.6/site-packages/torch/serialization.py:147: UserWarning: Couldn't retrieve source code for container of type Net. It won't be checked for correctness upon loading. "type " + obj.__name__ + ". It won't be checked "
提取神经网络
也有两种方式提取:提取整个网络以及提取网络参数
net2 为提取整个神经网络
net3 为提取神经网络参数,注意,该方式需要先建立一个跟所提取神经网络参数一样的网络架构,然后再赋予参数
net2 = torch.load('./NNPkl/net.pkl')net3 = Net(1, 10, 1) # first build same NN as net1net3.load_state_dict(torch.load('./NNPkl/net_parms.pkl'))
可视化比较
画图查看提取保存中的网络跟原来训练的是否一致
plt.figure(1, figsize=(10, 3))plt.subplot(131)plt.title('net1')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)plt.subplot(132)plt.title('net2')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), net2(x).data.numpy(), 'r-', lw=5)plt.subplot(133)plt.title('net3')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), net3(x).data.numpy(), 'r-', lw=5)
[<matplotlib.lines.Line2D at 0x10d9479e8>]
阅读全文
0 0
- PyTorch笔记5-save和load神经网络
- pytorch-save and load models
- save和load
- matlab中save和load
- matlab 的load和save
- numpy.load和numpy.save
- Tensorflow的save和load
- Julia : 再谈HDF5 的save 和 load
- Docker image批量save和load
- sklearn 模型持久化,save和load
- PyTorch笔记4-快速构建神经网络(NN)
- PyTorch上搭建简单神经网络实现回归和分类
- pytorch学习笔记(七):pytorch hook 和 关于pytorch backward过程的理解
- pytorch学习笔记(七):pytorch hook 和 关于pytorch backward过程的理解
- [初学笔记] fopen fclose fprintf fileparts, load & save,whos & struct
- BMP格式以及用纯C实现Load和Save
- .NET下枚举类型的Save和Load分析
- BMP格式以及用纯C实现Load和Save
- VS2015 MATLAB混合编程之COM组件
- 共享内存
- 文章标题
- KindEditor编辑器使用
- CSdn测试
- PyTorch笔记5-save和load神经网络
- 作业帮-将json数组里面的每一个对象的value取出生成与之对应的二维数组
- 串口 SWD Jtag
- 约瑟夫环(约瑟夫问题) 采用循环单链表实现
- [ARC066F]Contest with Drinks Hard
- JavaWeb开发模式一:JSP+JavaBean
- python3.5调用face++
- on、where、having的区别
- Java并发编程:Lock