PyTorch Basics (9): Optimizers


Author: Tyan
Blog: noahsnail.com | CSDN | 简书

This post shows how to use PyTorch optimizers by comparing SGD, SGD with momentum, RMSprop, and Adam on a simple regression task.
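For reference, these are the textbook update rules behind the four optimizers compared below (standard formulations added here for context, not from the original post; $\theta_t$ are the parameters, $g_t$ the gradient, and $\eta$ the learning rate):

$$
\begin{aligned}
\textbf{SGD:}\quad & \theta_{t+1} = \theta_t - \eta\, g_t \\
\textbf{Momentum:}\quad & v_{t+1} = \mu\, v_t + g_t, \qquad \theta_{t+1} = \theta_t - \eta\, v_{t+1} \\
\textbf{RMSprop:}\quad & s_{t+1} = \alpha\, s_t + (1 - \alpha)\, g_t^2, \qquad \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{s_{t+1}} + \epsilon}\, g_t \\
\textbf{Adam:}\quad & m_{t+1} = \beta_1 m_t + (1 - \beta_1)\, g_t, \quad v_{t+1} = \beta_2 v_t + (1 - \beta_2)\, g_t^2, \quad \theta_{t+1} = \theta_t - \frac{\eta\, \hat{m}_{t+1}}{\sqrt{\hat{v}_{t+1}} + \epsilon}
\end{aligned}
$$

where $\hat{m}$ and $\hat{v}$ are the bias-corrected moment estimates.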

```python
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.utils.data as Data

# Hyperparameters
LR = 0.01
BATCH_SIZE = 32
EPOCH = 10

# Generate data: y = x^2 plus Gaussian noise with standard deviation 0.1
x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)
y = x.pow(2) + 0.1 * torch.normal(torch.zeros(x.size()))

# Plot the data
plt.scatter(x.numpy(), y.numpy())
plt.show()
```

[Figure: scatter plot of the generated training data]

```python
# Wrap the tensors in a dataset
dataset = Data.TensorDataset(x, y)

# Define the data loader
loader = Data.DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

# Define the network
class Net(torch.nn.Module):
    def __init__(self, n_features, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_features, n_hidden)
        self.predict = torch.nn.Linear(n_hidden, n_output)

    def forward(self, x):
        x = F.relu(self.hidden(x))
        y = self.predict(x)
        return y
```
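As a quick sanity check (my addition, not in the original post), one mini-batch can be pulled from the loader directly; each batch should have shape `[BATCH_SIZE, 1]`:

```python
# Illustrative only: fetch one mini-batch to verify the loader's output shapes.
batch_x, batch_y = next(iter(loader))
print(batch_x.shape, batch_y.shape)  # torch.Size([32, 1]) torch.Size([32, 1])
```

Note that with `num_workers > 0`, the training code should be run under an `if __name__ == '__main__':` guard on platforms that spawn worker processes (Windows, macOS).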
```python
# Build one network per optimizer so the comparison is fair
net_SGD = Net(1, 10, 1)
net_Momentum = Net(1, 10, 1)
net_RMSprop = Net(1, 10, 1)
net_Adam = Net(1, 10, 1)

# One optimizer of each kind
opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)
opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.9)
opt_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr=LR, alpha=0.9)
opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))

nets = [net_SGD, net_Momentum, net_RMSprop, net_Adam]
optimizers = [opt_SGD, opt_Momentum, opt_RMSprop, opt_Adam]

# Loss function
loss_func = torch.nn.MSELoss()

# Loss history for each optimizer
loss_SGD = []
loss_Momentum = []
loss_RMSprop = []
loss_Adam = []
losses = [loss_SGD, loss_Momentum, loss_RMSprop, loss_Adam]

# Training
for epoch in range(EPOCH):
    for step, (batch_x, batch_y) in enumerate(loader):
        for net, optimizer, loss_history in zip(nets, optimizers, losses):
            # Predict on the batch
            prediction = net(batch_x)
            # Compute the loss
            loss = loss_func(prediction, batch_y)
            # Clear gradients from the previous step
            optimizer.zero_grad()
            # Backpropagate
            loss.backward()
            # Update the parameters
            optimizer.step()
            # Record the loss
            loss_history.append(loss.item())
```
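Optimizers such as Momentum, RMSprop, and Adam carry internal state (velocity and moment estimates), so when checkpointing a model mid-training the optimizer state should be saved as well. A minimal sketch, assuming an arbitrary file name `checkpoint.pt`:

```python
# Illustrative: save model and optimizer state together so training can resume.
torch.save({
    'model': net_Adam.state_dict(),
    'optimizer': opt_Adam.state_dict(),
}, 'checkpoint.pt')

# Later: restore both before continuing training.
state = torch.load('checkpoint.pt')
net_Adam.load_state_dict(state['model'])
opt_Adam.load_state_dict(state['optimizer'])
```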
```python
# Plot the loss curves
labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']
for i, loss_history in enumerate(losses):
    plt.plot(loss_history, label=labels[i])
plt.legend(loc='best')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.ylim((0, 0.2))
plt.show()
```

[Figure: training loss curves for SGD, Momentum, RMSprop, and Adam]
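The raw per-step losses are quite noisy. A simple moving average (my addition; the window size of 20 is arbitrary) makes the four curves easier to compare:

```python
import numpy as np

# Illustrative: smooth each loss history with a moving average before plotting.
def smooth(values, window=20):
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode='valid')

for label, loss_history in zip(labels, losses):
    plt.plot(smooth(loss_history), label=label)
plt.legend(loc='best')
plt.xlabel('Steps')
plt.ylabel('Loss (moving average)')
plt.show()
```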