Source code for easycv.hooks.throughput_hook

# Copyright (c) Alibaba, Inc. and its affiliates.
import time

from mmcv.runner.hooks import Hook
from torch import distributed as dist

from easycv.hooks.registry import HOOKS
from easycv.utils.dist_utils import get_dist_info


@HOOKS.register_module()
class ThroughputHook(Hook):
    """Log the running-average training throughput in samples per second.

    After every training iteration the hook writes a single value under the
    ``'avg throughput'`` key of ``runner.log_buffer``. The value is::

        (iterations completed after warmup) * batch_size * world_size
        -----------------------------------------------------------
                 seconds elapsed since timing was started

    Counters are reset at the beginning of every training epoch, so the
    average covers the current epoch only.

    Args:
        warmup_iters (int): Number of leading iterations of each epoch to
            exclude from the measurement, useful when the first few steps
            are slow to initialise. While warming up the logged throughput
            is 0. Defaults to 0.
        **kwargs: Accepted for configuration compatibility and ignored.
    """

    def __init__(self, warmup_iters=0, **kwargs) -> None:
        self.warmup_iters = warmup_iters
        self._iter_count = 0
        self._start = False
        # Initialised here so that ``after_train_iter`` never hits a missing
        # attribute when it runs before ``before_train_epoch`` has been
        # called and before the warmup threshold has started the timer.
        self._start_time = time.time()

    def _reset(self):
        """Restart the timer and clear the iteration counter."""
        self._start_time = time.time()
        self._iter_count = 0
        self._start = False

    def before_train_epoch(self, runner):
        """Reset the measurement at the start of each training epoch."""
        self._reset()

    def before_train_iter(self, runner):
        """Start the timer once exactly ``warmup_iters`` iterations are done."""
        if not self._start and self._iter_count == self.warmup_iters:
            self._start_time = time.time()
            self._start = True

    def after_train_iter(self, runner):
        """Update ``runner.log_buffer`` with the average throughput so far."""
        self._iter_count += 1
        key = 'avg throughput'

        batch_size = runner.data_loader.batch_size
        _, world_size = get_dist_info()
        total_batch_size = batch_size * world_size

        # The LoggerHook will average the log_buffer of the latest interval,
        # but we want to use the total time to calculate the throughput,
        # so we delete the historical buffers of the key to ensure that
        # the value printed each time is the total historical average
        if key in runner.log_buffer.val_history:
            runner.log_buffer.val_history[key] = []
            runner.log_buffer.n_history[key] = []

        total_time = time.time() - self._start_time
        measured_iters = max(0, self._iter_count - self.warmup_iters)
        if total_time > 0:
            throughput = measured_iters * total_batch_size / total_time
        else:
            # Two clock readings can be equal on a coarse timer, and the wall
            # clock may even step backwards; report 0 instead of raising
            # ZeroDivisionError or logging a negative rate.
            throughput = 0.0

        runner.log_buffer.update({key: throughput}, count=self._iter_count)