# Source code for easycv.hooks.collate_hook

# Copyright (c) Alibaba, Inc. and its affiliates.
from timm.data import Mixup
from timm.data.mixup import mixup_target

from .registry import HOOKS


class BaseCollateHook:
    """Base hook wrapped around the dataloader's collate function.

    Subclass it when samples need processing before they are merged into a
    mini-batch of Tensor(s) (override :meth:`before_collate`) or after the
    merge (override :meth:`after_collate`). Both default implementations are
    identity functions that hand back their argument untouched.
    """

    def __init__(self) -> None:
        # No state to initialise; defer to object's initialiser.
        super().__init__()

    def before_collate(self, batch):
        """Identity hook run on the list of samples prior to collation."""
        return batch

    def after_collate(self, batch):
        """Identity hook run on the already-collated mini-batch."""
        return batch


@HOOKS.register_module()
class MixupCollateHook(BaseCollateHook):
    """Apply timm Mixup/CutMix to an already-collated mini-batch.

    Should be used after a list of samples has been merged into a mini-batch
    of Tensor(s): :meth:`after_collate` reads ``results['img']`` and
    ``results['gt_labels']`` and writes back the mixed images and the mixed
    (soft) targets.

    Args:
        **kwargs: Forwarded unchanged to ``timm.data.Mixup``.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.mixup = Mixup(**kwargs)

    def after_collate(self, results):
        """Mix the images and labels of a collated batch.

        Args:
            results (dict): Collated batch. ``results['img']`` must support
                ``.size()`` (presumably an (N, C, H, W) image tensor — TODO
                confirm against the dataset pipeline). ``results['gt_labels']``
                holds the labels handed to timm's target mixing.

        Returns:
            dict: The same ``results`` dict, with ``'img'`` and
            ``'gt_labels'`` replaced by the mixed images and mixed targets.

        Raises:
            AssertionError: If the batch size is odd.
        """
        samples = results['img']
        batch_size = samples.size()[0]
        # Kept as an ``assert`` so the exception type/message seen by callers
        # is unchanged. NOTE(review): asserts are stripped under ``python -O``.
        assert batch_size % 2 == 0, 'Batch size should be even when using this, but get {}'.format(
            batch_size)

        # Delegate to the public ``Mixup.__call__`` rather than re-implementing
        # it via the private ``_mix_*`` helpers and ``mixup_target``: the
        # latter's signature differs across timm versions (newer releases do
        # not accept a ``device`` argument), so calling it directly is fragile.
        mixed_samples, mixed_targets = self.mixup(samples, results['gt_labels'])

        results['img'] = mixed_samples
        results['gt_labels'] = mixed_targets
        return results