# Source code for easycv.hooks.swav_hook

# Copyright (c) Alibaba, Inc. and its affiliates.
import os

import torch
from mmcv.runner import Hook, get_dist_info

from .registry import HOOKS


@HOOKS.register_module
class SWAVHook(Hook):
    """Hook that manages the per-rank feature queue of a SwAV model.

    The hook attaches ``queue`` and ``queue_path`` attributes to
    ``runner.model.module``. Before the run it restores a previously dumped
    queue (if one exists for this rank), before each training epoch it
    allocates the queue once the configured start epoch is reached, and
    after each training epoch it dumps the queue back to disk.

    Args:
        gpu_batch_size (int): Batch size per GPU. Together with the world
            size it determines the multiple the queue length is rounded
            down to.
        dump_path (str): Directory where the per-rank queue files
            (``queue<rank>.pth``) are stored. Created if missing.
        **kwargs: Accepted for config compatibility and ignored.
    """

    def __init__(self, gpu_batch_size=32, dump_path='data/', **kwargs):
        self.dump_path = dump_path
        # Resolved in ``before_run`` from the model config.
        self.queue_length = None
        self.rank, self.world_size = get_dist_info()
        self.batch_size = gpu_batch_size
        # ``exist_ok=True`` instead of an ``os.path.exists`` pre-check: every
        # rank constructs this hook, and a check-then-create sequence lets
        # one rank fail with FileExistsError when another rank creates the
        # directory in between.
        os.makedirs(self.dump_path, exist_ok=True)

    def before_run(self, runner):
        """Restore a dumped queue for this rank and compute the queue length.

        Args:
            runner: The runner; ``runner.model.module`` must expose a
                ``config`` mapping containing ``'queue_length'``.
        """
        module = runner.model.module
        module.queue = None
        module.queue_path = os.path.join(
            self.dump_path, 'queue' + str(self.rank) + '.pth')

        if os.path.isfile(module.queue_path):
            # NOTE: ``torch.load`` unpickles the file, so ``dump_path`` must
            # only contain files written by ``after_train_epoch``.
            module.queue = torch.load(module.queue_path)['queue']

        # The queue needs to be divisible by the global batch size, so round
        # the configured length down to the nearest multiple.
        self.queue_length = module.config['queue_length']
        self.queue_length -= self.queue_length % (
            self.batch_size * self.world_size)

    def before_train_epoch(self, runner):
        """Allocate the queue once the configured start epoch is reached.

        The queue is a zero tensor on the current CUDA device with shape
        ``(len(crops_for_assign), queue_length // world_size, feat_dim)``.
        Nothing happens if the queue is disabled (length of zero), the start
        epoch has not been reached, or a queue already exists.

        Args:
            runner: The runner; provides ``epoch`` and ``model.module``.
        """
        module = runner.model.module
        if (self.queue_length > 0
                and runner.epoch >= module.config['epoch_queue_starts']
                and module.queue is None):
            module.queue = torch.zeros(
                len(module.config['crops_for_assign']),
                self.queue_length // self.world_size,
                module.feat_dim,
            ).cuda()

    def after_train_epoch(self, runner):
        """Dump the current queue (if any) to this rank's queue file.

        Args:
            runner: The runner; provides ``model.module``.
        """
        module = runner.model.module
        if module.queue is not None:
            torch.save({'queue': module.queue}, module.queue_path)