PyTorch笔记5-save和load神经网络

来源:互联网 发布:日语扫描翻译软件 编辑:程序博客网 时间:2024/06/05 15:32

本系列笔记为莫烦PyTorch视频教程笔记 github源码

概要

用 PyTorch 训练好神经网络(NN)后,如何保存以便下次要用的时候直接提取使用即可,下面举栗

import torchfrom torch.autograd import Variableimport torch.nn.functional as F      # activation functionimport matplotlib.pyplot as plttorch.manual_seed(1)   # torch seed%matplotlib inline
# fake data# unsqueeze set shape, otherwise (100,)x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # shape(100, 1)y = x.pow(2) + 0.2*torch.rand(x.size())# only the variable can be trainedx, y = Variable(x), Variable(y)# build NNclass Net(torch.nn.Module):    def __init__(self, n_feature, n_hidden, n_output):        super(Net, self).__init__()        self.hidden = torch.nn.Linear(n_feature, n_hidden)     # hidden layer        self.prediction = torch.nn.Linear(n_hidden, n_output)  # output layer    def forward(self, x):        x = F.relu(self.hidden(x))    # activation func for hidden layer        x = self.prediction(x)        return xnet1 = Net(1, 10, 1)print('net1: \n', net1)
net1:  Net (  (hidden): Linear (1 -> 10)  (prediction): Linear (10 -> 1))
loss_func = torch.nn.MSELoss()optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)for epoch in range(100):    prediction = net1(x)    loss = loss_func1(prediction, y)    optimizer.zero_grad()     # clear gradient for next train    loss.backward()    optimizer.step()

保存神经网络

下面用两种方式来保存
net1 为保存整个网络
net2 只保存网络中的参数(速度快,占内存少)

torch.save(net1, './NNPkl/net.pkl')torch.save(net1.state_dict(), './NNPkl/net_parms.pkl')
/Users/yangjiahua/pytorch-test/lib/python3.6/site-packages/torch/serialization.py:147: UserWarning: Couldn't retrieve source code for container of type Net. It won't be checked for correctness upon loading.  "type " + obj.__name__ + ". It won't be checked "

提取神经网络

也有两种方式提取:提取整个网络以及提取网络参数
net2 为提取整个神经网络
net3 为提取神经网络参数,注意,该方式需要先建立一个跟所提取神经网络参数一样的网络架构,然后再赋予参数

net2 = torch.load('./NNPkl/net.pkl')net3 = Net(1, 10, 1)                 # first build same NN as net1net3.load_state_dict(torch.load('./NNPkl/net_parms.pkl'))

可视化比较

画图查看提取保存中的网络跟原来训练的是否一致

plt.figure(1, figsize=(10, 3))plt.subplot(131)plt.title('net1')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)plt.subplot(132)plt.title('net2')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), net2(x).data.numpy(), 'r-', lw=5)plt.subplot(133)plt.title('net3')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), net3(x).data.numpy(), 'r-', lw=5)
[<matplotlib.lines.Line2D at 0x10d9479e8>]

这里写图片描述

原创粉丝点击