PyTorch基本用法(九)——优化器
来源:互联网 发布:监测网络流量的软件 编辑:程序博客网 时间:2024/05/24 00:29
文章作者:Tyan
博客:noahsnail.com | CSDN | 简书
本文主要是关于PyTorch的一些用法。
import torchimport matplotlib.pyplot as pltimport torch.nn.functional as Fimport torch.utils.data as Datafrom torch.autograd import Variable# 定义超参数LR = 0.01BATCH_SIZE = 32EPOCH = 10# 生成数据x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim = 1)y = x.pow(2) + 0.1 * torch.normal(torch.zeros(x.size()))# 绘制数据图像plt.scatter(x.numpy(), y.numpy())plt.show()
# 定义数据库dataset = Data.TensorDataset(data_tensor = x, target_tensor = y)# 定义数据加载器loader = Data.DataLoader(dataset = dataset, batch_size = BATCH_SIZE, shuffle = True, num_workers = 2)# 定义pytorch网络class Net(torch.nn.Module): def __init__(self, n_features, n_hidden, n_output): super(Net, self).__init__() self.hidden = torch.nn.Linear(n_features, n_hidden) self.predict = torch.nn.Linear(n_hidden, n_output) def forward(self, x): x = F.relu(self.hidden(x)) y = self.predict(x) return y
# 定义不同的优化器网络net_SGD = Net(1, 10, 1)net_Momentum = Net(1, 10, 1)net_RMSprop = Net(1, 10, 1)net_Adam = Net(1, 10, 1)# 选择不同的优化方法opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr = LR)opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr = LR, momentum = 0.9)opt_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr = LR, alpha = 0.9)opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr = LR, betas= (0.9, 0.99))nets = [net_SGD, net_Momentum, net_RMSprop, net_Adam]optimizers = [opt_SGD, opt_Momentum, opt_RMSprop, opt_Adam]# 选择损失函数loss_func = torch.nn.MSELoss()# 不同方法的lossloss_SGD = []loss_Momentum = []loss_RMSprop =[]loss_Adam = []# 保存所有losslosses = [loss_SGD, loss_Momentum, loss_RMSprop, loss_Adam]# 执行训练for epoch in xrange(EPOCH): for step, (batch_x, batch_y) in enumerate(loader): var_x = Variable(batch_x) var_y = Variable(batch_y) for net, optimizer, loss_history in zip(nets, optimizers, losses): # 对x进行预测 prediction = net(var_x) # 计算损失 loss = loss_func(prediction, var_y) # 每次迭代清空上一次的梯度 optimizer.zero_grad() # 反向传播 loss.backward() # 更新梯度 optimizer.step() # 保存loss记录 loss_history.append(loss.data[0])
# 画图labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']for i, loss_history in enumerate(losses): plt.plot(loss_history, label = labels[i])plt.legend(loc = 'best')plt.xlabel('Steps')plt.ylabel('Loss')plt.ylim((0, 0.2))plt.show()
阅读全文
0 0
- PyTorch基本用法(九)——优化器
- PyTorch基本用法(二)——Variable
- PyTorch基本用法(四)——回归
- PyTorch基本用法(五)——分类
- PyTorch基本用法(一)——Numpy,Torch对比
- PyTorch基本用法(三)——激活函数
- PyTorch基本用法(六)——快速搭建网络
- PyTorch基本用法(八)——批训练
- PyTorch基本用法(十)——卷积神经网络
- PyTorch基本用法(七)——模型的保存与加载
- PyTorch学习系列(九)——参数_定义
- PyTorch学习系列(九)——参数_初始化
- PyTorch学习—PyTorch是什么?
- PyTorch批训练及优化器比较
- 最优化学习笔记(九)——基本的共轭方向算法
- tensorflow的基本用法(九)——定义卷积神经网络训练MNIST
- matplotlib的基本用法(九)——绘制等高线图
- PyTorch(一)——数据处理
- 红帽6.7未注册使用yum源
- AngularJS 自定义过滤器
- Fragmentxml
- 面试题07:Count the smiley faces!
- 一步一步带你搭建后台管理系统之使用requirejs整合常用前端插件
- PyTorch基本用法(九)——优化器
- 洛谷Oj-奇怪的电梯-广度优先搜索
- LoginActivityQQ登陆
- Ai challenger 场景分类: train softmax using tfrecord
- spring之基于aspectj注解aop使用
- SQL Server中的STUFF函数的使用
- JAVA语言的三种技术架构
- HDU 1001 Sum Problem JAVA
- 批量提取文件名称