Source code for easycv.models.heads.latent_pred_head

# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
from mmcv.cnn import normal_init

from .. import builder
from ..registry import HEADS


@HEADS.register_module
class LatentPredictHead(nn.Module):
    """Latent-prediction head scored by (negated) cosine similarity.

    A predictor network, built as a neck from the ``predictor`` config,
    maps the online features towards the target features. The loss is
    ``-2 * sum_i cos(pred_i, target_i)``, i.e. both tensors are
    L2-normalized along dim 1 and their elementwise products are summed
    over every element.
    """

    def __init__(self, predictor, size_average=True):
        """
        Args:
            predictor (dict): Config passed to ``builder.build_neck`` to
                create the prediction network.
            size_average (bool): When True, divide the summed loss by the
                batch size ``input.size(0)``.
        """
        super().__init__()
        self.predictor = builder.build_neck(predictor)
        self.size_average = size_average

    def init_weights(self, init_linear='normal'):
        """Delegate weight initialization to the predictor neck."""
        self.predictor.init_weights(init_linear=init_linear)

    def forward(self, input, target):
        """Compute the latent-prediction loss.

        Args:
            input (Tensor): NxC input features.
            target (Tensor): NxC target features.

        Returns:
            dict: ``{'loss': Tensor}`` holding the scalar loss.
        """
        normalize = nn.functional.normalize
        # The neck interface consumes and returns a sequence; only the
        # first output is used here.
        prediction = self.predictor([input])[0]
        # Sum over all elements == sum of per-row cosine similarities,
        # since both operands are unit-normalized along dim 1.
        cos_sum = (normalize(prediction, dim=1) * normalize(target, dim=1)).sum()
        loss = -2 * cos_sum
        if self.size_average:
            loss = loss / input.size(0)
        return {'loss': loss}
@HEADS.register_module
class LatentClsHead(nn.Module):
    """Classification head trained on pseudo-labels from target features.

    A single linear layer scores both the input and the target features.
    The argmax of the target's scores (computed without gradient) is used
    as the class label for a cross-entropy loss on the input's scores.
    """

    def __init__(self, predictor):
        """
        Args:
            predictor: Config object exposing ``in_channels`` and
                ``num_classes`` attributes, used to size the linear layer.
        """
        super().__init__()
        in_dim = predictor.in_channels
        out_dim = predictor.num_classes
        self.predictor = nn.Linear(in_dim, out_dim)
        self.criterion = nn.CrossEntropyLoss()

    def init_weights(self, init_linear='normal'):
        """Initialize the linear layer with mmcv ``normal_init`` (std=0.01).

        ``init_linear`` is accepted for interface compatibility but is not
        used by this head.
        """
        normal_init(self.predictor, std=0.01)

    def forward(self, input, target):
        """Compute the pseudo-label classification loss.

        Args:
            input (Tensor): NxC input features.
            target (Tensor): NxC target features; fed through the same
                linear layer, so C must equal ``in_channels``.

        Returns:
            dict: ``{'loss': Tensor}`` holding the cross-entropy loss.
        """
        logits = self.predictor(input)
        # Pseudo-labels come from the target branch and must not
        # propagate gradients back through the predictor.
        with torch.no_grad():
            target_logits = self.predictor(target)
            pseudo_labels = target_logits.argmax(dim=1)
        return {'loss': self.criterion(logits, pseudo_labels)}