# Copyright (c) Alibaba, Inc. and its affiliates.
import time
from mmcv.runner import Hook, get_dist_info
from .registry import HOOKS
@HOOKS.register_module
class TIMEHook(Hook):
    """Print how long each training iteration spends loading data and running.

    Two intervals are reported, in milliseconds, on rank 0 only:

    * "load data": from the end of the previous iteration (or the start of
      the run for the first iteration) until ``before_train_iter`` fires.
    * "infer model": from ``before_train_iter`` until ``after_train_iter``.
    """

    def __init__(self, end_momentum=1., **kwargs):
        """Set up the clock and the initial timing baseline.

        Args:
            end_momentum: Not used by this hook; kept so existing configs
                that pass it continue to work.
            **kwargs: Accepted and ignored.
        """
        self.rank, self.num_replicas = get_dist_info()
        # Use a monotonic clock: wall-clock time (time.time) can jump when
        # the system clock is adjusted, which would corrupt the intervals.
        self.now_time = lambda: int(round(time.perf_counter() * 1000))
        # Start both timestamps at "now" so the first reported interval is a
        # real duration instead of "milliseconds since the epoch", and so
        # after_train_iter never hits a missing attribute.
        self.end_infer = self.now_time()
        self.end_load = self.end_infer

    def before_run(self, runner):
        """Reset the baseline so setup done before training is not counted."""
        self.end_infer = self.now_time()
        self.end_load = self.end_infer

    def before_train_iter(self, runner):
        """Record the end of data loading and print the load time."""
        self.end_load = self.now_time()
        if self.rank == 0:
            print(self.rank,
                  ' load data need : %d ms' % (self.end_load - self.end_infer))

    def after_train_iter(self, runner):
        """Record the end of the iteration and print its running time."""
        self.end_infer = self.now_time()
        if self.rank == 0:
            print(
                self.rank,
                ' infer model need : %d ms' % (self.end_infer - self.end_load))