# Source code for easycv.core.optimizer.lars

# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
from torch.optim import *  # noqa: F401,F403
from torch.optim.optimizer import Optimizer, required

from easycv.framework.errors import ValueError


class LARS(Optimizer):
    r"""Implements layer-wise adaptive rate scaling for SGD.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups. A group may set ``lars_exclude=True`` to skip
            the layer-wise scaling (its local learning rate is then 1).
        lr (float): base learning rate (\gamma_0)
        momentum (float, optional): momentum factor (default: 0) ("m")
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
            ("\beta")
        dampening (float, optional): dampening for momentum (default: 0)
        eta (float, optional): LARS coefficient (default: 0.001)
        nesterov (bool, optional): enables Nesterov momentum (default: False)

    Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg.
    Large Batch Training of Convolutional Networks:
        https://arxiv.org/abs/1708.03888

    For every non-excluded parameter ``w`` with gradient ``g`` the local
    learning rate is ``eta * ||w|| / (||g|| + weight_decay * ||w||)``. If that
    denominator is exactly zero, the update direction ``g + weight_decay * w``
    is itself zero, and a local learning rate of 1 is used instead of
    dividing by zero.

    Example:
        >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9,
        >>>                  weight_decay=1e-4, eta=1e-3)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()
    """

    def __init__(self,
                 params,
                 lr=required,
                 momentum=0,
                 dampening=0,
                 weight_decay=0,
                 eta=0.001,
                 nesterov=False):
        if lr is not required and lr < 0.0:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        if momentum < 0.0:
            raise ValueError('Invalid momentum value: {}'.format(momentum))
        if weight_decay < 0.0:
            raise ValueError(
                'Invalid weight_decay value: {}'.format(weight_decay))
        if eta < 0.0:
            raise ValueError('Invalid LARS coefficient value: {}'.format(eta))

        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
            eta=eta)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError(
                'Nesterov momentum requires a momentum and zero dampening')

        super(LARS, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(LARS, self).__setstate__(state)
        # States pickled without a 'nesterov' entry default to plain momentum.
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs autograd enabled.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            eta = group['eta']
            nesterov = group['nesterov']
            lr = group['lr']
            lars_exclude = group.get('lars_exclude', False)

            for p in group['params']:
                if p.grad is None:
                    continue

                d_p = p.grad

                if lars_exclude:
                    local_lr = 1.
                else:
                    weight_norm = torch.norm(p).item()
                    grad_norm = torch.norm(d_p).item()
                    # Compute local learning rate (trust ratio) for this layer
                    denom = grad_norm + weight_decay * weight_norm
                    if denom != 0:
                        local_lr = eta * weight_norm / denom
                    else:
                        # Zero gradient and (no weight decay or zero weights):
                        # the update direction below is all zeros, so the
                        # scale is irrelevant. Avoid float division by zero.
                        local_lr = 1.

                actual_lr = local_lr * lr
                d_p = d_p.add(p, alpha=weight_decay).mul(actual_lr)

                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        # First step: seed the buffer with the scaled update.
                        buf = param_state['momentum_buffer'] = \
                            torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                p.add_(-d_p)

        return loss