# Source code for easycv.hooks.yolox_mode_switch_hook

# Copyright (c) Alibaba, Inc. and its affiliates.
from mmcv.parallel import is_module_wrapper
from mmcv.runner.hooks import Hook

from .registry import HOOKS


@HOOKS.register_module()
class YOLOXModeSwitchHook(Hook):
    """Switch the mode of YOLOX during training.

    For the last ``no_aug_epochs`` epochs of training, this hook tells the
    training dataset to skip the configured pipeline transforms (mosaic,
    random affine and mixup by default) and enables the additional L1 loss
    in the model head (``model.head.use_l1``).

    Args:
        no_aug_epochs (int): The number of final training epochs during
            which data augmentation is disabled and the L1 loss is enabled.
            Default: 15.
        skip_type_keys (Sequence[str]): Type names of the pipeline
            transforms to skip once the switch happens.
            Default: ('MMMosaic', 'MMRandomAffine', 'MMMixUp').
        **kwargs: Accepted for config compatibility; ignored.
    """

    def __init__(self,
                 no_aug_epochs=15,
                 skip_type_keys=('MMMosaic', 'MMRandomAffine', 'MMMixUp'),
                 **kwargs):
        super(YOLOXModeSwitchHook, self).__init__()
        self.no_aug_epochs = no_aug_epochs
        self.skip_type_keys = skip_type_keys
        # Whether this hook instance has already applied the switch. The
        # switch condition below holds for every remaining epoch, so this
        # flag keeps the switch to a single application.
        self._has_switched = False

    def before_train_epoch(self, runner):
        """Close mosaic and mixup augmentation and switch to use L1 loss.

        The switch is applied once, at the first epoch for which
        ``epoch + 1 >= max_epochs - no_aug_epochs``. Using ``>=`` instead of
        ``==`` ensures the switch still happens when training is resumed
        from a checkpoint saved after the switch epoch.

        Args:
            runner: The training runner. Its ``epoch``, ``max_epochs``,
                ``data_loader``, ``model`` and ``logger`` attributes are read.
        """
        epoch = runner.epoch
        switch_epoch = runner.max_epochs - self.no_aug_epochs
        if self._has_switched or (epoch + 1) < switch_epoch:
            return

        train_loader = runner.data_loader
        model = runner.model
        # Reach the underlying model when wrapped (e.g. by a parallel wrapper).
        if is_module_wrapper(model):
            model = model.module

        runner.logger.info('No mosaic and mixup aug now!')
        train_loader.dataset.update_skip_type_keys(self.skip_type_keys)
        if getattr(train_loader, 'persistent_workers', False):
            # Persistent workers keep the dataset copy they were started
            # with and would never see the updated skip keys. Dropping the
            # cached iterator makes torch's DataLoader start fresh workers
            # (with the updated dataset) on its next ``iter()`` call.
            train_loader._iterator = None

        runner.logger.info('Add additional L1 loss now!')
        model.head.use_l1 = True
        self._has_switched = True