Source code for easycv.runner.ev_runner

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from distutils.version import LooseVersion

import torch
from mmcv.runner import EpochBasedRunner
from mmcv.runner.log_buffer import LogBuffer

from easycv.file import io
from easycv.framework.errors import RuntimeError, TypeError
from easycv.utils.checkpoint import load_checkpoint, save_checkpoint

if LooseVersion(torch.__version__) >= LooseVersion('1.6.0'):
    from torch.cuda import amp


class EVRunner(EpochBasedRunner):
    def __init__(self,
                 model,
                 batch_processor=None,
                 optimizer=None,
                 work_dir=None,
                 logger=None,
                 meta=None,
                 fp16_enable=False):
        """Epoch-based runner for easycv, with support for oss IO and file sync.

        Args:
            model (:obj:`torch.nn.Module`): The model to be run.
            batch_processor (callable): A callable method that processes a data
                batch. The interface of this method should be
                ``batch_processor(model, data, train_mode) -> dict``.
            optimizer (dict or :obj:`torch.optim.Optimizer`): Either an optimizer
                (in most cases) or a dict of optimizers (in models that require
                more than one optimizer, e.g., GAN).
            work_dir (str, optional): The working directory to save checkpoints
                and logs. Defaults to None.
            logger (:obj:`logging.Logger`): Logger used during training.
                Defaults to None. (The default value is just for backward
                compatibility.)
            meta (dict | None): A dict recording some important information, such
                as environment info and seed, which will be logged in the logger
                hook. Defaults to None.
            fp16_enable (bool): Whether to enable fp16 (mixed-precision) training.
        """
        super().__init__(model, batch_processor, optimizer, work_dir, logger,
                         meta)
        self.data_loader = None
        self.fp16_enable = fp16_enable
        self.visualization_buffer = LogBuffer()

        if self.fp16_enable and LooseVersion(
                torch.__version__) < LooseVersion('1.6.0'):
            # convert model to fp16
            self.model.half()
            # patch the normalization layers to make them work in fp32 mode
            from mmcv.runner.fp16_utils import patch_norm_fp32
            patch_norm_fp32(self.model)
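    # Illustrative construction (not part of the original module): a minimal
    # sketch of building the runner, assuming ``model``, ``optimizer``,
    # ``logger`` and ``work_dir`` are created elsewhere (hypothetical names):
    #
    #   runner = EVRunner(
    #       model,
    #       optimizer=optimizer,
    #       work_dir=work_dir,
    #       logger=logger,
    #       fp16_enable=True)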
    def run_iter(self, data_batch, train_mode, **kwargs):
        """Process a single iteration.

        Args:
            data_batch: Batch of dict of data.
            train_mode (bool): If set True, run the training step, else the
                validation step.
        """
        if self.batch_processor is not None:
            outputs = self.batch_processor(
                self.model, data_batch, train_mode=train_mode, **kwargs)
        elif train_mode:
            outputs = self.model.train_step(data_batch, self.optimizer,
                                            **kwargs)
        else:
            outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)

        if not isinstance(outputs, dict):
            raise TypeError('"batch_processor()" or "model.train_step()" '
                            'and "model.val_step()" must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'],
                                   outputs['num_samples'])

        self.outputs = outputs
    def train(self, data_loader, **kwargs):
        """Training process for one epoch, which iterates through all training
        data and calls hooks at different stages.

        Args:
            data_loader: Data loader object for training.
        """
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')

        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            # for pytorch 1.6 or later, use amp.autocast for mixed precision
            if self.fp16_enable and LooseVersion(
                    torch.__version__) >= LooseVersion('1.6.0'):
                with amp.autocast():
                    self.run_iter(data_batch, train_mode=True)
            else:
                self.run_iter(data_batch, train_mode=True)
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1
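    # Note (illustrative): ``train()`` and ``val()`` are normally not called
    # directly; the inherited ``EpochBasedRunner.run()`` dispatches to them
    # according to the workflow, e.g. with a hypothetical ``train_loader``:
    #
    #   runner.run([train_loader], workflow=[('train', 1)], max_epochs=100)
    #
    # Depending on the mmcv version, ``max_epochs`` may instead have to be set
    # on the runner before calling ``run()``.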
    def val(self, data_loader, **kwargs):
        """Validation step. Deprecated; use the evaluation hook instead."""
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')

        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            with torch.no_grad():
                self.run_iter(data_batch, train_mode=False)
            self.call_hook('after_val_iter')

        self.call_hook('after_val_epoch')
    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        """Save checkpoint to file.

        Args:
            out_dir: Directory where checkpoint files are saved.
            filename_tmpl (str, optional): Checkpoint filename pattern.
            save_optimizer (bool, optional): Whether to save the optimizer state.
            meta (dict, optional): Metadata to be saved in checkpoint.
        """
        # implements saving checkpoints to oss
        if meta is None:
            meta = dict(epoch=self.epoch + 1, iter=self.iter)
        else:
            meta.update(epoch=self.epoch + 1, iter=self.iter)

        filename = filename_tmpl.format(self.epoch + 1)
        filepath = os.path.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        if hasattr(self, 'ema'):
            model = self.ema.model
        else:
            model = self.model
        save_checkpoint(model, filepath, optimizer, meta)
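    # Illustrative call (hypothetical bucket path): since this method routes
    # through easycv's checkpoint utilities, ``out_dir`` may also be an oss
    # path when oss access has been configured via ``easycv.file.io``:
    #
    #   runner.save_checkpoint('oss://my-bucket/exp1/ckpts',
    #                          filename_tmpl='epoch_{}.pth')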
    def current_lr(self):
        """Get current learning rates.

        Returns:
            list[float] | dict[str, list[float]]: Current learning rates of all
            param groups. If the runner has a dict of optimizers, this method
            will return a dict.
        """
        if isinstance(self.optimizer, torch.optim.Optimizer):
            lr = sorted(
                [group['lr'] for group in self.optimizer.param_groups],
                reverse=True)  # avoid lr display error
        elif isinstance(self.optimizer, dict):
            lr = dict()
            for name, optim in self.optimizer.items():
                lr[name] = [group['lr'] for group in optim.param_groups]
        else:
            raise RuntimeError(
                'lr is not applicable because optimizer does not exist.')
        return lr
    # !! Notice: map_location should be 'cpu', otherwise the checkpoint may be
    # loaded onto a single GPU, which can cause OOM.
    def load_checkpoint(self,
                        filename,
                        map_location=torch.device('cpu'),
                        strict=False,
                        logger=None):
        """Load checkpoint from a file or URL.

        Args:
            filename (str): Accepts a local filepath, URL, ``torchvision://xxx``,
                ``open-mmlab://xxx`` or ``oss://xxx``. Please refer to
                ``docs/source/model_zoo.md`` for details.
            map_location (str): Same as :func:`torch.load`.
            strict (bool): Whether to allow different params for the model and
                checkpoint.
            logger (:mod:`logging.Logger` or None): The logger for error
                messages.

        Returns:
            dict or OrderedDict: The loaded checkpoint.
        """
        if torch.cuda.is_available():
            device_id = torch.cuda.current_device()
            map_location = lambda storage, loc: storage.cuda(device_id)

        return load_checkpoint(
            self.model,
            filename=filename,
            map_location=map_location,
            strict=strict,
            logger=logger)
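    # Illustrative calls (hypothetical paths): per the docstring, ``filename``
    # also accepts ``oss://``, ``torchvision://`` and ``open-mmlab://`` schemes:
    #
    #   runner.load_checkpoint('oss://my-bucket/exp1/ckpts/epoch_10.pth')
    #   runner.load_checkpoint('torchvision://resnet50')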
    def resume(self,
               checkpoint,
               resume_optimizer=True,
               map_location='default'):
        """Resume state dict from checkpoint.

        Args:
            checkpoint: Checkpoint path.
            resume_optimizer: Whether to resume the optimizer state.
            map_location (str): Same as :func:`torch.load`.
        """
        if map_location == 'default':
            if torch.cuda.is_available():
                device_id = torch.cuda.current_device()
                checkpoint = self.load_checkpoint(
                    checkpoint,
                    map_location=lambda storage, loc: storage.cuda(device_id))
            else:
                checkpoint = self.load_checkpoint(checkpoint)
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        if 'optimizer' in checkpoint and resume_optimizer:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
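
A minimal end-to-end sketch of how this runner is typically driven (illustrative only; ``model``, ``optimizer``, ``train_loader`` and ``logger`` are hypothetical objects assumed to be built elsewhere, and the checkpoint path is hypothetical):

    from easycv.runner.ev_runner import EVRunner

    runner = EVRunner(
        model,
        optimizer=optimizer,
        work_dir='./work_dirs/exp1',
        logger=logger,
        fp16_enable=True)
    # hook registration (lr/optimizer/checkpoint/log hooks) is omitted here;
    # see mmcv's ``register_training_hooks`` for the usual setup
    runner.resume('./work_dirs/exp1/epoch_10.pth')  # optional: resume training
    print(runner.current_lr())  # learning rate(s) of all param groups
    # passing max_epochs to run() is the older mmcv style; newer versions may warn
    runner.run([train_loader], workflow=[('train', 1)], max_epochs=100)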