# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import sys

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import get_dist_info

from easycv.framework.errors import KeyError, NotImplementedError, ValueError
from easycv.utils.preprocess_function import (gaussianBlurDynamic,
                                              randomGrayScale, solarize)
from .. import builder
from ..base import BaseModel
from ..registry import MODELS

class MultiCropWrapper(nn.Module):
    """Perform forward pass separately on each resolution input.

    The inputs corresponding to a single resolution are clubbed and a single
    forward is run on the same resolution inputs. Hence we do several forward
    passes = number of different resolutions used. We then concatenate all the
    output features and run the head forward on these concatenated features.

    Args:
        backbone (nn.Module): feature extractor. Its ``fc`` and ``head``
            attributes are overwritten with ``nn.Identity`` in place.
        head (nn.Module): module applied to the concatenated features.
    """

    def __init__(self, backbone, head):
        super(MultiCropWrapper, self).__init__()
        # disable layers dedicated to ImageNet labels classification
        backbone.fc, backbone.head = nn.Identity(), nn.Identity()
        self.backbone = backbone
        self.head = head
        # used to flatten backbones that return feature maps instead of vectors
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Run the backbone per resolution group and the head on all features.

        Args:
            x (Tensor | list[Tensor]): one crop batch or a list of crop
                batches. Consecutive entries sharing the same size in their
                last dimension are concatenated and forwarded together.

        Returns:
            The output of ``head`` applied to the features of all crops,
            concatenated along dim 0 in input order.
        """
        # convert to list
        if not isinstance(x, list):
            x = [x]

        # End offsets (exclusive) of each run of same-resolution crops.
        last_dims = torch.tensor([inp.shape[-1] for inp in x])
        run_lengths = torch.unique_consecutive(
            last_dims, return_counts=True)[1]
        idx_crops = torch.cumsum(run_lengths, 0).tolist()

        start_idx = 0
        features = []
        for end_idx in idx_crops:
            _out = self.backbone(torch.cat(x[start_idx:end_idx]))
            # Some backbones (e.g. XCiT) return a tuple/list; keep the first
            # element only.
            if isinstance(_out, (tuple, list)):
                _out = _out[0]
            # some backbones do not contain an avgpool
            if _out.dim() > 2:
                _out = self.avgpool(_out).view(_out.size(0), -1)
            features.append(_out)
            start_idx = end_idx

        # Concatenate once, then run the head on the concatenated features.
        return self.head(torch.cat(features))
class DINOLoss(nn.Module):
    """Cross-entropy loss between centered teacher and student outputs.

    Args:
        out_dim (int): width of the ``center`` buffer (last dim of the
            network outputs).
        ncrops (int): number of chunks the student output is split into.
        warmup_teacher_temp (float): teacher temperature at epoch 0.
        teacher_temp (float): teacher temperature after the warm-up.
        warmup_teacher_temp_epochs (int): number of warm-up epochs over which
            the temperature is linearly interpolated.
        nepochs (int): total number of epochs covered by the schedule.
        device: device on which the ``center`` buffer is created.
        student_temp (float): divisor applied to the student output.
        center_momentum (float): EMA factor used to update ``center``.
    """

    def __init__(self,
                 out_dim,
                 ncrops,
                 warmup_teacher_temp,
                 teacher_temp,
                 warmup_teacher_temp_epochs,
                 nepochs,
                 device,
                 student_temp=0.1,
                 center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.ncrops = ncrops
        center = torch.zeros(1, out_dim).to(device)
        self.register_buffer('center', center)
        # we apply a warm up for the teacher temperature because
        # a too high temperature makes the training unstable at the beginning
        self.teacher_temp_schedule = np.concatenate(
            (np.linspace(warmup_teacher_temp, teacher_temp,
                         warmup_teacher_temp_epochs),
             np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp))

    def forward(self, student_output, teacher_output, epoch):
        """Cross-entropy between softmax outputs of the teacher and student.

        Args:
            student_output (Tensor): student logits for all crops, stacked
                along dim 0; split into ``ncrops`` chunks.
            teacher_output (Tensor): teacher logits, stacked along dim 0;
                split into 2 chunks.
            epoch (int): index into the teacher temperature schedule.

        Returns:
            Tensor: loss averaged over every (teacher view, student view)
            pair whose indices differ.
        """
        student_out = student_output / self.student_temp
        student_out = student_out.chunk(self.ncrops)

        # teacher centering and sharpening
        temp = self.teacher_temp_schedule[epoch]
        teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
        teacher_out = teacher_out.detach().chunk(2)

        total_loss = 0
        n_loss_terms = 0
        for iq, q in enumerate(teacher_out):
            for v, student_view in enumerate(student_out):
                if v == iq:
                    # we skip cases where student and teacher operate on the
                    # same view
                    continue
                loss = torch.sum(
                    -q * F.log_softmax(student_view, dim=-1), dim=-1)
                total_loss += loss.mean()
                n_loss_terms += 1
        total_loss /= n_loss_terms
        # The center is refreshed after the loss so the current batch is
        # centered with the statistics of the previous ones.
        self.update_center(teacher_output)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        """Update center used for teacher output.

        The batch mean of ``teacher_output`` (summed across processes when
        more than one is reported by ``get_dist_info``) is blended into
        ``self.center`` with an exponential moving average.
        """
        batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
        _, world_size = get_dist_info()
        if world_size > 1:
            dist.all_reduce(batch_center)
        batch_center = batch_center / (len(teacher_output) * world_size)

        # ema update
        self.center = self.center * self.center_momentum + batch_center * (
            1 - self.center_momentum)
def has_batchnorms(model):
    """Return True if ``model`` or any of its submodules is a batch-norm layer.

    Args:
        model (nn.Module): module tree to inspect.

    Returns:
        bool: True when at least one ``BatchNorm1d/2d/3d`` or
        ``SyncBatchNorm`` instance is found, False otherwise.
    """
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
                nn.SyncBatchNorm)
    return any(isinstance(module, bn_types) for module in model.modules())
def get_params_groups(model):
    """Split trainable parameters into weight-decayed and non-decayed groups.

    Args:
        model (nn.Module): module whose ``named_parameters()`` are grouped.

    Returns:
        list[dict]: two optimizer parameter groups. The first holds the
        regularized parameters; the second holds biases and 1-D parameters
        and sets ``weight_decay`` to 0. Parameters with
        ``requires_grad == False`` are left out of both groups.
    """
    regularized = []
    not_regularized = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # we do not regularize biases nor Norm parameters
        if name.endswith('.bias') or len(param.shape) == 1:
            not_regularized.append(param)
        else:
            regularized.append(param)
    return [{
        'params': regularized
    }, {
        'params': not_regularized,
        'weight_decay': 0.
    }]
@MODELS.register_module
class DINO(BaseModel):
    """DINO model: a student network trained against a momentum teacher."""

    def __init__(self,
                 backbone,
                 train_preprocess=[],
                 neck=None,
                 config=None,
                 pretrained=None):
        """Init DINO.

        Args:
            backbone: backbone config to build vision backbone. Note that
                ``backbone['drop_path_rate']`` is overwritten in place with
                ``config['drop_path_rate']`` before the student is built.
            train_preprocess: [gaussBlur, mixUp, solarize]; stored only.
            neck: neck config to build the DINO heads. Its ``use_bn`` and
                ``norm_last_layer`` keys are overwritten in place from
                ``config`` before the student head is built.
            config: DINO parameter config. Must provide ``use_bn_in_head``
                and ``norm_last_layer``; other keys are read with defaults.
            pretrained: unused here.
        """
        super(DINO, self).__init__()

        self.config = config
        self.train_preprocess = train_preprocess  # we dont need it
        self.use_tfrecord_input = self.config.get('use_tfrecord_input', False)

        # dino has 3 augment pipeline, if use_tfrecord_input == True, use this
        self.train_preprocess_t1 = {
            'randomGrayScale': {
                'apply_prob': 0.2
            },
            'gaussianBlurDynamic': {
                'apply_prob': 1.0
            }
        }
        self.train_preprocess_t2 = {
            'randomGrayScale': {
                'apply_prob': 0.2
            },
            'gaussianBlurDynamic': {
                'apply_prob': 0.1
            },
            'solarize': {
                'threshold': 0.5,
                'apply_prob': 0.2
            }
        }
        self.train_preprocess_s = {
            'randomGrayScale': {
                'apply_prob': 0.2
            },
            'gaussianBlurDynamic': {
                'apply_prob': 0.5
            }
        }
        self.preprocess_key_map = {
            'randomGrayScale': randomGrayScale,
            'gaussianBlurDynamic': gaussianBlurDynamic,
            'solarize': solarize
        }

        # build model backbone
        teacher = builder.build_backbone(backbone)
        # vit based model, students should assign drop_path_rate
        if backbone.get('drop_path_rate', None) is not None:
            print('drop_path_rate : ', self.config['drop_path_rate'])
            backbone['drop_path_rate'] = self.config['drop_path_rate']
        student = builder.build_backbone(backbone)
        self.backbone = student

        # build model neck: the teacher head uses the neck config as given,
        # the student head additionally applies the config overrides.
        self.tneck = builder.build_neck(neck)
        neck['use_bn'] = config['use_bn_in_head']
        neck['norm_last_layer'] = config['norm_last_layer']
        sneck = builder.build_neck(neck)

        # combine with multi-crop infer wrapper
        # NOTE(review): teacher is paired with `tneck` and student with
        # `sneck`, matching their names; the previous wiring was swapped, so
        # the config overrides only reached the frozen teacher head.
        self.teacher = MultiCropWrapper(teacher, self.tneck)
        self.student = MultiCropWrapper(student, sneck)

        if has_batchnorms(student):
            self.student = nn.SyncBatchNorm.convert_sync_batchnorm(
                self.student)
            self.teacher = nn.SyncBatchNorm.convert_sync_batchnorm(
                self.teacher)

        self.init_no_ddp_teacher = False
        self.cur_epoch = 0

    def get_params_groups(self):
        """Return the student's optimizer parameter groups.

        Delegates to the module-level ``get_params_groups`` so the grouping
        rule (no weight decay on biases and 1-D parameters) lives in one
        place.
        """
        return get_params_groups(self.student)

    def init_weights(self, pretrained=None):
        """Always raises: weights are initialized by the backbone and neck."""
        # TODO: unify the use of init_weight
        raise ValueError('Dino `init_weights` has done in backbone and neck')

    def init_before_train(self):
        """Copy student weights into the teacher, freeze it, build the loss."""
        # assign teacher model (unwrap a `.module` wrapper when present)
        if hasattr(self.teacher, 'module'):
            self.teacher_without_ddp = self.teacher.module
        else:
            self.teacher_without_ddp = self.teacher
        self.teacher_without_ddp.load_state_dict(self.student.state_dict())

        # the teacher is never updated through gradients
        for p in self.teacher.parameters():
            p.requires_grad = False

        self.dino_loss = DINOLoss(
            out_dim=self.config.get('out_dim', 65536),
            # local crops plus the 2 global views
            ncrops=int(self.config.get('local_crops_number', 8)) + 2,
            warmup_teacher_temp=self.config.get('warmup_teacher_temp', 0.04),
            teacher_temp=self.config.get('teacher_temp', 0.4),
            warmup_teacher_temp_epochs=self.config.get(
                'warmup_teacher_temp_epochs', 10),
            nepochs=self.config.get('epochs', 100),
            # device of the last teacher parameter visited above
            device=p.device)
        return

    def momentum_update_key_encoder(self, m=0.999):
        """EMA update of the teacher parameters from the student.

        Args:
            m (float): momentum; each teacher parameter becomes
                ``m * teacher + (1 - m) * student``.
        """
        with torch.no_grad():
            for param_q, param_k in zip(
                    self.student.parameters(),
                    self.teacher_without_ddp.parameters()):
                param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)

    def _apply_preprocess(self, img, pipeline):
        """Apply each op of ``pipeline`` (name -> kwargs) to ``img`` in order."""
        for key, kwargs in pipeline.items():
            img = self.preprocess_key_map[key](img, **kwargs)
        return img

    def forward_train(self, inputs):
        """Compute the DINO loss for one batch of multi-crop views.

        Teacher initialization and the momentum update are not done here;
        they are expected to be driven from outside (see DINOHook).

        Args:
            inputs (list): crop batches; the first two entries are the
                global views fed to the teacher, all entries are fed to the
                student.

        Returns:
            dict: ``{'loss': loss}``.
        """
        self.student.train()

        # tensor operate to support some data aug
        if self.use_tfrecord_input:
            # DINO DATAAUG SET
            input_list = [
                # data aug for view 0 (teacher1)
                self._apply_preprocess(inputs[0], self.train_preprocess_t1),
                # data aug for view 1 (teacher2)
                self._apply_preprocess(inputs[1], self.train_preprocess_t2),
            ]
            # data aug for view 2~x (students)
            input_list.extend(
                self._apply_preprocess(img, self.train_preprocess_s)
                for img in inputs[2:])
        else:
            input_list = inputs

        teacher_output = self.teacher(
            input_list[:2])  # only the 2 global views pass through the teacher
        student_output = self.student(input_list)

        loss = self.dino_loss(student_output, teacher_output, self.cur_epoch)
        if not math.isfinite(loss.item()):
            # The built-in print has no `force` keyword; passing it raised
            # TypeError before the process could exit cleanly.
            print('Loss is {}, stopping training'.format(loss.item()))
            sys.exit(1)

        # NOTE(review): nothing visible in this class creates `this_loss`, so
        # this branch only runs if something outside sets it — confirm.
        if hasattr(self, 'this_loss'):
            self.this_loss = loss
            self.count = 1

        losses = dict()
        losses['loss'] = loss
        return losses

    def forward_test(self, img, **kwargs):
        """Not implemented; returns None."""
        pass

    def forward_feature(self, img, **kwargs):
        """Forward backbone

        Returns:
            x (torch.Tensor): feature tensor
        """
        # TODO: fix extract feature
        # NOTE(review): `self.avg_pool` is not defined anywhere in this
        # class; unless the base class provides it this raises
        # AttributeError. `forward(mode='extract')` does not call this.
        return_dict = {}
        x = self.student(img)
        return_dict['backbone'] = x
        if hasattr(self, 'tneck') and self.tneck is not None:
            feature = self.tneck([self.avg_pool(i) for i in x])[0]
        else:
            feature = self.avg_pool(x[-1])
        return_dict['neck'] = feature
        return return_dict

    def forward(self,
                img,
                gt_label=None,
                mode='train',
                extract_list=['neck'],
                **kwargs):
        """Dispatch on ``mode``.

        Args:
            img: input forwarded to the selected method.
            gt_label: unused by the implemented modes.
            mode (str): 'train', 'test' or 'extract'.
            extract_list: unused while 'extract' is not implemented.

        Raises:
            NotImplementedError: for ``mode == 'extract'``.
            KeyError: for any other unknown mode.
        """
        if mode == 'train':
            return self.forward_train(img, **kwargs)
        elif mode == 'test':
            return self.forward_test(img, **kwargs)
        elif mode == 'extract':
            # feature extraction is not supported yet (see forward_feature)
            raise NotImplementedError()
        else:
            raise KeyError('No such mode: {}'.format(mode))