pytorch - finetune the network and adjust the learning rate



1. finetune: different learning rates for the pretrained layers and the new fc layer

import torch

# collect the ids of the new fc layer's parameters so they can be excluded below
ignored_params = list(map(id, model.fc.parameters()))
# every other parameter (the pretrained backbone) goes into base_params
base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())

# two parameter groups: the backbone trains at opt.lr * 0.1 (the default lr below),
# the freshly initialized fc layer at the full opt.lr
optimizer = torch.optim.SGD([
            {'params': base_params},
            {'params': model.fc.parameters(), 'lr': opt.lr}
        ], lr=opt.lr*0.1, momentum=0.9)
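
For context, a self-contained sketch of the setup the snippet above assumes (the ResNet-18 backbone, the literal learning rate, and num_classes are placeholders, not from the original):

import torch
import torch.nn as nn
from torchvision import models

lr = 0.01          # stand-in for opt.lr
num_classes = 10   # stand-in for the number of classes in the new task

model = models.resnet18(pretrained=True)                  # ImageNet-pretrained backbone
model.fc = nn.Linear(model.fc.in_features, num_classes)   # fresh, randomly initialized head

# same grouping as above: pretrained layers at lr * 0.1, the new fc layer at the full lr
ignored_params = list(map(id, model.fc.parameters()))
base_params = [p for p in model.parameters() if id(p) not in ignored_params]
optimizer = torch.optim.SGD(
    [{'params': base_params},
     {'params': model.fc.parameters(), 'lr': lr}],
    lr=lr * 0.1, momentum=0.9)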



2. adjust the learning rate manually


def adjust_lr(optimizer, e, base_lr=base_lr, step_size=step_size):
    # recompute the lr from scratch: multiply base_lr by 0.1 every step_size epochs
    # (the defaults base_lr and step_size are globals defined elsewhere in the script)
    _lr = base_lr * 0.1 ** (e // step_size)
    for param_group in optimizer.param_groups:
        param_group['lr'] = _lr
    return optimizer

def adjust_lr(optimizer, e, step_size=step_size, decay=0.1):
    # alternative: decay the current lr in place, once every step_size epochs
    if e == 0 or e % step_size:
        return optimizer
    for param_group in optimizer.param_groups:
        param_group['lr'] *= decay
    return optimizer
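
Either version is called once per epoch inside the training loop. A minimal usage sketch of the first version, assuming placeholder values for base_lr / step_size / epochs and placeholder train / valid functions:

base_lr = 0.01     # assumed initial learning rate
step_size = 30     # assumed decay interval in epochs
epochs = 90        # assumed total number of epochs

optimizer = torch.optim.SGD(model.parameters(), lr=base_lr, momentum=0.9)

for e in range(epochs):
    optimizer = adjust_lr(optimizer, e, base_lr=base_lr, step_size=step_size)
    train(e)   # placeholder: one epoch of training
    valid()    # placeholder: validation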


3. lr_scheduler


scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)  # e.g. StepLR; any scheduler class from optim.lr_scheduler works the same way


scheduler.step()  # call once per epoch to update the learning rate


for e in range(epochs):
    train(e)           # one training epoch
    valid()            # validation
    scheduler.step()   # in PyTorch >= 1.1, step the scheduler after the epoch's optimizer steps
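
Concretely, StepLR reproduces the step decay from section 2, and the other classes in optim.lr_scheduler follow the same pattern. A sketch with assumed hyperparameter values:

import torch.optim as optim
from torch.optim import lr_scheduler

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# multiply the lr by 0.1 every 30 epochs, equivalent to the manual step decay above
scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# or decay at hand-picked epochs instead of a fixed interval
# scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[30, 60], gamma=0.1)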





---------------------------------------- references ----------------------------------------

1. https://discuss.pytorch.org/t/how-to-perform-finetuning-in-pytorch/419/7

2. https://discuss.pytorch.org/t/adaptive-learning-rate/320/4
