Source code for easycv.hooks.ema_hook

# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import math

import torch
from mmcv.runner import Hook

from easycv.utils import dist_utils, py_util
from .registry import HOOKS


class ModelEMA:
    """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models

    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

    A smoothed version of the weights is necessary for some training schemes to perform well.
    This class is sensitive to where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    In YOLOv5s, EMA helps increase mAP from 0.27 to 0.353.
    """
    def __init__(self, model, decay=0.9999, updates=0):
        # Create EMA
        self.model = copy.deepcopy(
            model.module if dist_utils.is_parallel(model) else model).eval()  # FP32 EMA
        # if next(model.parameters()).device.type != 'cpu':
        #     self.model.half()  # FP16 EMA
        self.updates = updates  # number of EMA updates
        # decay exponential ramp (to help early epochs)
        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
        for p in self.model.parameters():
            p.requires_grad_(False)
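    # A hedged note, not in the original source: with the ramp above and
    # decay=0.9999, the effective factor d = 0.9999 * (1 - exp(-updates / 2000))
    # is roughly 0.049 after 100 updates, 0.632 after 2000, and 0.993 after
    # 10000, so the average tracks the raw weights closely early in training
    # and stiffens as updates accumulate.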
    def update(self, model):
        # Update EMA parameters
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            msd = model.module.state_dict() if dist_utils.is_parallel(
                model) else model.state_dict()  # model state_dict
            for k, v in self.model.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1. - d) * msd[k].detach()
    def update_attr(self,
                    model,
                    include=(),
                    exclude=('process_group', 'reducer')):
        # Update EMA attributes
        py_util.copy_attr(self.model, model, include, exclude)
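A minimal sketch of driving ModelEMA by hand, outside the EMAHook below. The toy model, optimizer, and loop are illustrative only and assume easycv is importable:

import torch
import torch.nn as nn

net = nn.Linear(4, 2)
ema = ModelEMA(net, decay=0.9999)
opt = torch.optim.SGD(net.parameters(), lr=0.1)

for step in range(100):
    x = torch.randn(8, 4)
    loss = net(x).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema.update(net)  # fold the fresh weights into the moving average

eval_model = ema.model  # evaluate with the smoothed FP32 copy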
@HOOKS.register_module
class EMAHook(Hook):
    """ Hook to carry out Exponential Moving Average """
    def __init__(self, decay=0.9999, copy_model_attr=()):
        """
        Args:
            decay: decay rate for the exponential moving average
            copy_model_attr: attributes to copy from the original model to the EMA model
        """
        self.decay = decay
        self._copy_model_attr = copy_model_attr
        self._init_updates = False
    def before_run(self, runner):
        runner.ema = ModelEMA(runner.model, decay=self.decay)
    def before_train_epoch(self, runner):
        if not self._init_updates:
            runner.ema.updates = runner.iter
            self._init_updates = True
    def after_train_iter(self, runner):
        runner.ema.update(runner.model)
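One plausible way to attach the hook to an mmcv-style runner, assuming EMAHook is exported from easycv.hooks and that model, optimizer, and logger are already constructed; wherever EasyCV builds hooks from the HOOKS registry, a config entry such as dict(type='EMAHook', decay=0.9999) should work equivalently:

from mmcv.runner import EpochBasedRunner

from easycv.hooks import EMAHook  # assumed export path

runner = EpochBasedRunner(
    model, optimizer=optimizer, work_dir='./work_dir', logger=logger)
runner.register_hook(EMAHook(decay=0.9999))  # mmcv runners expose register_hook

# After training, the smoothed weights live on runner.ema.model.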