PyTorch基本用法(七)——模型的保存与加载
来源:互联网 发布:java九九乘法表上三角 编辑:程序博客网 时间:2024/06/16 16:42
文章作者:Tyan
博客:noahsnail.com | CSDN | 简书
本文主要是关于PyTorch的一些用法。
import torchimport matplotlib.pyplot as pltfrom torch.autograd import Variable# 生成数据x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim = 1)y = x.pow(2) + 0.2 * torch.rand(x.size())# 变为Variablex, y = Variable(x), Variable(y)# 定义网络net = torch.nn.Sequential( torch.nn.Linear(1, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1))print net
Sequential ( (0): Linear (1 -> 10) (1): ReLU () (2): Linear (10 -> 1))
# 选择优化方法optimizer = torch.optim.SGD(net.parameters(), lr = 0.5)# 选择损失函数loss_func = torch.nn.MSELoss()# 训练网络for i in xrange(1000): # 对x进行预测 prediction = net(x) # 计算损失 loss = loss_func(prediction, y) # 每次迭代清空上一次的梯度 optimizer.zero_grad() # 反向传播 loss.backward() # 更新梯度 optimizer.step()plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw = 5)plt.text(0.5, 0, 'Loss=%.4f' % loss.data[0], fontdict={'size': 10, 'color': 'red'})plt.show()
# 保存训练的模型# 保存整个网络和参数torch.save(net, 'net.pkl')# 重新加载模型net = torch.load('net.pkl')# 用新加载的模型进行预测prediction = net(x)plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw = 5)plt.show()
# 只保存网络的参数, 官方推荐的方式torch.save(net.state_dict(), 'net_params.pkl')# 定义网络net = torch.nn.Sequential( torch.nn.Linear(1, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1))# 加载网络参数net.load_state_dict(torch.load('net_params.pkl'))# 用新加载的参数进行预测prediction = net(x)plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw = 5)plt.show()
阅读全文
0 0
- PyTorch基本用法(七)——模型的保存与加载
- pytorch 保存与加载模型
- scikit-learn的基本用法(八)——模型保存与加载
- scikit-learn的基本用法——模型保存与加载
- 【pytorch】模型的搭建保存加载
- PyTorch(7)——模型的训练和测试、保存和加载
- pytorch 模型的加载
- PyTorch基本用法(二)——Variable
- PyTorch基本用法(四)——回归
- PyTorch基本用法(五)——分类
- PyTorch学习系列(十四)——保存训练好的模型
- pytorch学习笔记(五):保存和加载模型
- pytorch学习笔记(五):保存和加载模型
- 170719 Keras 模型的保存与加载
- tensorflow的基本用法(十)——保存神经网络参数和加载神经网络参数
- PyTorch基本用法(一)——Numpy,Torch对比
- PyTorch基本用法(三)——激活函数
- PyTorch基本用法(六)——快速搭建网络
- 机器学习-python的工作目录
- MATLAB修改默认工作路径
- 包含min函数的栈
- Java 从网页指定url获取图片并压缩到本地
- spring
- PyTorch基本用法(七)——模型的保存与加载
- Java JVM,JDK,JRE简介
- 《android日常bug系列》ViewPager不能左右滑动,原来竟是因为它...
- hdu 5883 The Best Path
- Mac 使用记录
- hdu 1715 大菲波数 (斐波那契数列 大数问题)
- java课程学习一:hello world
- ajax异步前后端
- Java设计模式之——简单工厂模式(静态工厂模式)