Source code for easycv.hooks.dino_hook

# Copyright (c) Alibaba, Inc. and its affiliates.
import time

import numpy as np
import torch
from mmcv.runner import Hook, get_dist_info

from .registry import HOOKS


def cosine_scheduler(base_value,
                     final_value,
                     epochs,
                     niter_per_ep,
                     warmup_epochs=0,
                     start_warmup_value=0):
    '''Build a per-iteration schedule: an optional linear warmup from
    `start_warmup_value` to `base_value`, followed by a cosine anneal from
    `base_value` to `final_value`. The returned array has one entry per
    training iteration, `epochs * niter_per_ep` in total.
    '''
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_epochs > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value,
                                      warmup_iters)

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = final_value + 0.5 * (base_value - final_value) * (
        1 + np.cos(np.pi * iters / len(iters)))

    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == epochs * niter_per_ep
    return schedule
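

# A minimal usage sketch of `cosine_scheduler` (illustrative only; the epoch
# and iteration counts below are hypothetical, not taken from this module):
#
#     momentum_schedule = cosine_scheduler(
#         base_value=0.996,  # initial teacher EMA momentum
#         final_value=1.0,   # annealed towards 1.0 over training
#         epochs=100,        # assumed total training epochs
#         niter_per_ep=500)  # assumed iterations per epoch
#     assert len(momentum_schedule) == 100 * 500
#     # momentum_schedule[i] is the value used at global iteration i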


@HOOKS.register_module
class DINOHook(Hook):
    '''Hook for DINO training: schedules the teacher EMA momentum and the
    weight decay with cosine schedules, and logs a running average loss.
    '''

    def __init__(self,
                 momentum_teacher=0.996,
                 weight_decay=0.04,
                 weight_decay_end=0.4,
                 **kwargs):
        '''
        Args:
            momentum_teacher (float): initial EMA momentum of the teacher,
                annealed towards 1 over training.
            weight_decay (float): initial weight decay.
            weight_decay_end (float): final weight decay.
        '''
        self.momentum_teacher = momentum_teacher
        self.weight_decay = weight_decay
        self.weight_decay_end = weight_decay_end

    def before_run(self, runner):
        # call model init
        runner.model.module.init_before_train()
        try:
            self.rank, self.world_size = get_dist_info()
        except Exception:
            self.rank = 0
            self.world_size = 1

        max_progress = runner.max_epochs
        self.epoch_length = len(runner.data_loader[0])
        # per-iteration cosine schedules: the teacher momentum anneals from
        # `momentum_teacher` to 1, the weight decay from `weight_decay` to
        # `weight_decay_end`
        self.momentum_schedule = cosine_scheduler(self.momentum_teacher, 1,
                                                  max_progress,
                                                  self.epoch_length)
        self.wd_schedule = cosine_scheduler(self.weight_decay,
                                            self.weight_decay_end,
                                            max_progress, self.epoch_length)
        self.optimizer = runner.optimizer

        # running loss statistics, reset at the start of every epoch
        runner.model.module.this_loss = 0
        runner.model.module.count = 0
        self.epoch_total_loss = 0
        self.count = 0

    def before_train_iter(self, runner):
        cur_iter = runner.iter

        # setup weight decay
        for i, param_group in enumerate(self.optimizer.param_groups):
            if i == 0:  # only the first group is regularized
                param_group['weight_decay'] = self.wd_schedule[cur_iter]

        # call model ema
        if cur_iter > 0:
            runner.model.module.momentum_update_key_encoder(
                self.momentum_schedule[cur_iter])
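
    # `momentum_update_key_encoder(m)` is defined on the model rather than in
    # this hook. A sketch of the standard DINO EMA rule it is assumed to
    # apply (`student`/`teacher` are hypothetical names for the two branches):
    #
    #     with torch.no_grad():
    #         for param_q, param_k in zip(student.parameters(),
    #                                     teacher.parameters()):
    #             param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)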

    def after_train_iter(self, runner):
        # aggregate the per-rank loss statistics across all workers
        if self.world_size > 1:
            t = torch.tensor(
                [runner.model.module.count, runner.model.module.this_loss],
                dtype=torch.float64,
                device='cuda')
            torch.distributed.barrier()
            torch.distributed.all_reduce(t)
            t = t.tolist()
            self.count += int(t[0])
            self.epoch_total_loss += t[1]
        else:
            self.count += int(runner.model.module.count)
            self.epoch_total_loss += runner.model.module.this_loss

        # log the running averages from rank 0 every 10 iterations
        if runner.iter % 10 == 0 and self.rank == 0:
            print(' wd : %.4f momentum : %.4f total_avg_loss : %.4f' %
                  (self.wd_schedule[runner.iter],
                   self.momentum_schedule[runner.iter],
                   self.epoch_total_loss / self.count))

    def before_train_epoch(self, runner):
        # reset epoch loss
        self.epoch_total_loss = 0
        self.count = 0

        torch.cuda.empty_cache()
        # Make sure `torch.cuda.empty_cache` is done and all cache is cleaned
        time.sleep(3)

        runner.model.module.cur_epoch = runner.epoch
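

# A hypothetical config entry enabling this hook, assuming the mmcv-style
# `custom_hooks` mechanism; the values mirror the defaults above:
#
#     custom_hooks = [
#         dict(
#             type='DINOHook',
#             momentum_teacher=0.996,
#             weight_decay=0.04,
#             weight_decay_end=0.4)
#     ]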