PyTorch Usage (4): Training a Network
Source: Internet | Published by: 程序员过关 | Editor: 程序博客网 | Date: 2024/06/05 15:15
1. Loading the data
# Data
print('==> Preparing data..')
trainset = MyDataset(path_file=pathFile, list_file=trainList, numJoints=6, type=False)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=8)
testset = MyDataset(path_file=pathFile, list_file=testList, numJoints=6, type=False)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=8)
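MyDataset here is the custom dataset class from the earlier parts of this series; its implementation is not shown. Any PyTorch dataset only needs `__len__` and `__getitem__`, each returning one (input, target) pair, and DataLoader handles the batching. A minimal, hypothetical stand-in (the class name, tensor shapes, and sizes below are illustrative, not from the original code):

```python
import torch
from torch.utils.data import Dataset, DataLoader

# Hypothetical stand-in for MyDataset: a dataset only needs __len__ and
# __getitem__, each returning one (input, target) pair.
class ToyJointDataset(Dataset):
    def __init__(self, num_samples=100, num_joints=6):
        self.inputs = torch.randn(num_samples, 3, 64, 64)        # fake images
        self.targets = torch.randn(num_samples, num_joints * 2)  # (x, y) per joint

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

# DataLoader batches, shuffles, and (with num_workers > 0) prefetches in parallel.
loader = DataLoader(ToyJointDataset(), batch_size=32, shuffle=True, num_workers=0)
images, targets = next(iter(loader))
print(images.shape, targets.shape)  # torch.Size([32, 3, 64, 64]) torch.Size([32, 12])
```

Each iteration of the loader yields one stacked batch, which is exactly what the training loop in section 5 consumes.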
2. Loading the network and setting up the GPU
use_cuda = torch.cuda.is_available()
best_loss = float('inf')  # best test loss
start_epoch = 0  # start from epoch 0 or the last checkpointed epoch

# Model
net = MyNet()
if use_cuda:
    net = torch.nn.DataParallel(net, device_ids=[0, 1])  # GPU ids
    net.cuda()
    cudnn.benchmark = True
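DataParallel replicates the model on the listed GPUs, splits each input batch along dimension 0, and gathers the outputs back on the first device. A small runnable sketch of the same pattern (the two-layer net below is a hypothetical stand-in for MyNet, and `device_ids` is left at its default so the sketch also runs on machines with fewer than two GPUs):

```python
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn

# Tiny stand-in for MyNet (hypothetical), just to show the wrapping pattern.
net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 12))

use_cuda = torch.cuda.is_available()
if use_cuda:
    # DataParallel splits each batch along dim 0 across the visible GPUs
    # and gathers the outputs on the first device.
    net = torch.nn.DataParallel(net)
    net.cuda()
    cudnn.benchmark = True  # let cuDNN autotune kernels for fixed input sizes

out = net(torch.randn(4, 8))
print(out.shape)  # torch.Size([4, 12])
```

Note that `cudnn.benchmark = True` only helps when input sizes stay constant across iterations; with varying sizes the autotuning re-runs and can slow things down.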
3. Setting the learning rate and other hyperparameters
optimizer = optim.SGD(net.parameters(), lr=0.000001, momentum=0.9, weight_decay=1e-4)
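The hyperparameters passed to SGD are stored per parameter group inside the optimizer, which is what the learning-rate schedule in section 6 relies on. A quick check (the one-layer model is a hypothetical stand-in; any module's parameters work):

```python
import torch.nn as nn
import torch.optim as optim

# Stand-in model (hypothetical); any module's parameters work here.
net = nn.Linear(4, 2)
optimizer = optim.SGD(net.parameters(), lr=0.000001, momentum=0.9, weight_decay=1e-4)

# SGD keeps lr, momentum, and weight_decay in each entry of param_groups.
group = optimizer.param_groups[0]
print(group['lr'], group['momentum'], group['weight_decay'])  # 1e-06 0.9 0.0001
```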
4. Setting the loss
criterion = nn.MSELoss()
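By default `nn.MSELoss` averages the squared element-wise error over every element of the prediction, which suits the joint-coordinate regression here. A tiny worked example with hand-picked values:

```python
import torch
import torch.nn as nn

# nn.MSELoss averages (pred - target)**2 over all elements by default.
criterion = nn.MSELoss()
pred = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
target = torch.tensor([[1.0, 0.0], [3.0, 2.0]])
loss = criterion(pred, target)
print(loss.item())  # (0 + 4 + 0 + 4) / 4 = 2.0
```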
5. Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()  # put the network in training mode
    train_loss = 0
    # batch_idx is the batch index; (input, output) is one batch of data
    for batch_idx, (input, output) in enumerate(trainloader):
        if use_cuda:
            input = input.cuda()
            output = output.cuda()
        # wrap the tensors in Variables
        input = Variable(input)
        output = Variable(output)
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward pass
        preds = net(input)
        # compute the loss
        loss = criterion(preds, output)
        # backward pass
        loss.backward()
        # update the parameters
        optimizer.step()
        train_loss += loss.data[0]
        print('%.3f %.3f' % (loss.data[0], train_loss / (batch_idx + 1)))
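The loop above uses the pre-0.4 PyTorch API: `Variable` wrapping and `loss.data[0]` for reading a scalar. On PyTorch 0.4 and later, `Variable` was merged into `Tensor` and `loss.data[0]` is replaced by `loss.item()`. A minimal modernized sketch of the same update step (the one-layer net and random batch are hypothetical stand-ins for MyNet and the real data):

```python
import torch
import torch.nn as nn
import torch.optim as optim

net = nn.Linear(8, 2)  # stand-in for MyNet (hypothetical)
optimizer = optim.SGD(net.parameters(), lr=0.000001, momentum=0.9, weight_decay=1e-4)
criterion = nn.MSELoss()

inputs = torch.randn(32, 8)   # one fake batch
targets = torch.randn(32, 2)

optimizer.zero_grad()         # no Variable wrapping needed on >= 0.4
loss = criterion(net(inputs), targets)
loss.backward()
optimizer.step()
print(loss.item())            # read a scalar loss with .item(), not .data[0]
```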
6. Setting a learning-rate schedule
An optimizer manages its parameters through param_groups. Each entry of param_groups stores a parameter group together with its learning rate, momentum, and so on, so you can change a group's learning rate by assigning to param_group['lr']. The example below halves the learning rate every 10 epochs:
def adjust_learning_rate(optimizer, epoch):
    '''Multiply the learning rate by 0.5 every 10 epochs.'''
    if epoch % 10 == 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * 0.5
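The schedule can be verified on a throwaway optimizer (the one-layer model and starting rate are illustrative):

```python
import torch.nn as nn
import torch.optim as optim

def adjust_learning_rate(optimizer, epoch):
    '''Multiply the learning rate by 0.5 every 10 epochs.'''
    if epoch % 10 == 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * 0.5

optimizer = optim.SGD(nn.Linear(4, 2).parameters(), lr=0.8)
for epoch in range(1, 21):              # epochs 10 and 20 trigger the decay
    adjust_learning_rate(optimizer, epoch)
print(optimizer.param_groups[0]['lr'])  # 0.8 * 0.5 * 0.5 = 0.2
```

One subtlety: if the training loop starts at epoch 0, `0 % 10 == 0` halves the rate on the very first call; starting the count at 1 (or checking `epoch > 0`) avoids that.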
Full code:
import os
import os.path as osp
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from MyNet2 import MyNet
from draw_graph import make_dot
from datagen import MyDataset

# set paths
rootPath = '/home/ggy/disk1/ggy/code/Adjacent-frames/YouTube_Pose'
txtPath = osp.join(rootPath, 'TemporalNet', 'data', 'txt')
pathFile = osp.join(txtPath, 'path.txt')
trainList = osp.join(txtPath, 'train.txt')
testList = osp.join(txtPath, 'test.txt')

use_cuda = torch.cuda.is_available()
best_loss = float('inf')  # best test loss
start_epoch = 0  # start from epoch 0 or the last checkpointed epoch

# Data
print('==> Preparing data..')
trainset = MyDataset(path_file=pathFile, list_file=trainList, numJoints=6, type=False)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=8)
testset = MyDataset(path_file=pathFile, list_file=testList, numJoints=6, type=False)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=8)

# Model
net = MyNet()
if use_cuda:
    net = torch.nn.DataParallel(net, device_ids=[0, 1])
    net.cuda()
    cudnn.benchmark = True

base_lr = 0.000001
optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=1e-4)

# loss
criterion = nn.MSELoss()

# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    for batch_idx, (input, output) in enumerate(trainloader):
        if use_cuda:
            input = input.cuda()
            output = output.cuda()
        input = Variable(input)
        output = Variable(output)
        optimizer.zero_grad()
        preds = net(input)
        loss = criterion(preds, output)
        loss.backward()
        optimizer.step()
        train_loss += loss.data[0]
        print('%.3f %.3f' % (loss.data[0], train_loss / (batch_idx + 1)))

def test(epoch):
    print('\nTest')
    net.eval()
    test_loss = 0
    # fixed: iterate the test set here, not the training set
    for batch_idx, (input, output) in enumerate(testloader):
        if use_cuda:
            input = input.cuda()
            output = output.cuda()
        input = Variable(input)
        output = Variable(output)
        preds = net(input)
        loss = criterion(preds, output)
        test_loss += loss.data[0]
        print('%.3f %.3f' % (loss.data[0], test_loss / (batch_idx + 1)))
    # Save checkpoint.
    global best_loss
    test_loss /= len(testloader)
    if test_loss < best_loss:
        print('Saving..')
        state = {
            'net': net.module.state_dict(),
            'loss': test_loss,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.pth')
        best_loss = test_loss

def adjust_learning_rate(optimizer, epoch):
    '''Multiply the learning rate by 0.5 every 50 epochs.'''
    if epoch % 50 == 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * 0.5

for epoch in range(start_epoch, start_epoch + 200):
    train(epoch)
    adjust_learning_rate(optimizer, epoch)
    if epoch % 10 == 0:
        test(epoch)