# Source code for easycv.models.selfsup.necks

# Copyright (c) Alibaba, Inc. and its affiliates.
from functools import partial

import torch
import torch.nn as nn
from packaging import version
from timm.models.vision_transformer import Block

from easycv.models.utils import get_2d_sincos_pos_embed
from ..registry import NECKS
from ..utils import _init_weights, build_norm_layer, trunc_normal_

@NECKS.register_module
class DINONeck(nn.Module):
    """DINO projection head: MLP -> L2 normalization -> weight-normed linear.

    Args:
        in_dim (int): Input feature dimension.
        out_dim (int): Output dimension of the last (weight-normalized) layer.
        use_bn (bool): Insert ``BatchNorm1d`` after every hidden linear layer.
        norm_last_layer (bool): Freeze the gain ``weight_g`` of the last
            layer (kept at 1), so only the weight direction is learned.
        nlayers (int): Number of linear layers in the MLP; clamped to >= 1.
        hidden_dim (int): Hidden dimension of the MLP.
        bottleneck_dim (int): Output dimension of the MLP, which is also the
            input dimension of the last layer.
    """

    def __init__(self,
                 in_dim,
                 out_dim,
                 use_bn=False,
                 norm_last_layer=True,
                 nlayers=3,
                 hidden_dim=2048,
                 bottleneck_dim=256):
        super().__init__()
        nlayers = max(nlayers, 1)
        if nlayers == 1:
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            layers = [nn.Linear(in_dim, hidden_dim)]
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
            for _ in range(nlayers - 2):
                layers.append(nn.Linear(hidden_dim, hidden_dim))
                if use_bn:
                    layers.append(nn.BatchNorm1d(hidden_dim))
                layers.append(nn.GELU())
            layers.append(nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = nn.Sequential(*layers)
        # ``last_layer`` is created after this call, so it keeps PyTorch's
        # default Linear initialization rather than the truncated normal.
        self.apply(self._init_weights)
        self.last_layer = nn.utils.weight_norm(
            nn.Linear(bottleneck_dim, out_dim, bias=False))
        # Start from unit gain, as in the reference DINO head. Without this,
        # weight_g equals the random row norms of the initial weight, and
        # when frozen (norm_last_layer=True) those arbitrary scales persist.
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        """Truncated-normal init for Linear weights; zero Linear biases."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Project ``x`` (last dimension ``in_dim``) to ``out_dim``.

        The MLP output is L2-normalized along the last dimension before the
        weight-normalized last layer is applied.
        """
        x = self.mlp(x)
        x = nn.functional.normalize(x, dim=-1, p=2)
        x = self.last_layer(x)
        return x
@NECKS.register_module
class MoBYMLP(nn.Module):
    """MLP neck used by MoBY: [fc-bn-relu] x (num_layers - 1), then fc.

    Args:
        in_channels (int): Input feature dimension.
        hid_channels (int): Hidden dimension.
        out_channels (int): Output dimension.
        num_layers (int): Total number of linear layers. With ``0`` both the
            hidden stack and the output layer are identities.
        with_avg_pool (bool): Global-average-pool 4D inputs before the MLP.
    """

    def __init__(self,
                 in_channels=256,
                 hid_channels=4096,
                 out_channels=256,
                 num_layers=2,
                 with_avg_pool=True):
        super(MoBYMLP, self).__init__()
        # hidden layers; the leading Identity keeps the Sequential non-empty
        # when num_layers <= 1
        linear_hidden = [nn.Identity()]
        for i in range(num_layers - 1):
            linear_hidden.append(
                nn.Linear(in_channels if i == 0 else hid_channels,
                          hid_channels))
            linear_hidden.append(nn.BatchNorm1d(hid_channels))
            linear_hidden.append(nn.ReLU(inplace=True))
        self.linear_hidden = nn.Sequential(*linear_hidden)

        self.linear_out = nn.Linear(
            in_channels if num_layers == 1 else hid_channels,
            out_channels) if num_layers >= 1 else nn.Identity()

        # Honor the constructor argument instead of forcing pooling on.
        self.with_avg_pool = with_avg_pool
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Project the first backbone output.

        Args:
            x (Sequence[Tensor]): Backbone outputs; only ``x[0]`` is used.

        Returns:
            list[Tensor]: Single-element list with the projected features.
        """
        x = x[0]
        if self.with_avg_pool and len(x.shape) == 4:
            bs = x.shape[0]
            x = self.avg_pool(x).view([bs, -1])
        x = self.linear_hidden(x)
        x = self.linear_out(x)
        return [x]

    def init_weights(self, init_linear='normal'):
        """Initialize weights with the shared ``_init_weights`` helper."""
        _init_weights(self, init_linear)
@NECKS.register_module
class NonLinearNeckSwav(nn.Module):
    """Non-linear neck for SwAV: fc-syncbn-relu-fc.

    Args:
        in_channels (int): Input feature dimension.
        hid_channels (int): Hidden dimension.
        out_channels (int): Output dimension.
        with_avg_pool (bool): Global-average-pool the input first.
        export (bool): Build a plain ``BN`` layer instead of ``SyncBN``; the
            2D features are then expanded to 4D around the norm call.
    """

    def __init__(self,
                 in_channels,
                 hid_channels,
                 out_channels,
                 with_avg_pool=True,
                 export=False):
        super(NonLinearNeckSwav, self).__init__()
        # SyncBN in torch < 1.4.0 does not accept 2D input.
        self.expand_for_syncbn = (
            version.parse(torch.__version__) < version.parse('1.4.0'))
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.export = export
        norm_type = 'BN' if export else 'SyncBN'
        self.bn0 = build_norm_layer(dict(type=norm_type), hid_channels)[1]
        self.fc0 = nn.Linear(in_channels, hid_channels)
        self.relu = nn.ReLU(inplace=True)
        self.fc1 = nn.Linear(hid_channels, out_channels)

    def _forward_syncbn(self, module, x):
        """Run ``module`` on 2D ``x``, temporarily viewing it as (N, C, 1, 1)
        when the old-SyncBN or export path requires 4D input."""
        assert x.dim() == 2
        needs_4d = self.expand_for_syncbn or self.export
        if not needs_4d:
            return module(x)
        out = module(x.unsqueeze(-1).unsqueeze(-1))
        return out.squeeze(-1).squeeze(-1)

    def init_weights(self, init_linear='normal'):
        """Initialize weights with the shared ``_init_weights`` helper."""
        _init_weights(self, init_linear)

    def forward(self, x):
        """Project the first of one or two backbone outputs.

        Returns:
            list[Tensor]: Single-element list with the projected features.
        """
        assert len(x) == 1 or len(x) == 2, 'Got: {}'.format(len(x))
        # ViT backbones can emit two outputs; only the first one is used.
        feat = x[0]
        if self.with_avg_pool:
            feat = self.avgpool(feat)
        feat = feat.view(feat.size(0), -1)
        hidden = self._forward_syncbn(self.bn0, self.fc0(feat))
        hidden = self.relu(hidden)
        return [self.fc1(hidden)]
@NECKS.register_module
class NonLinearNeckV0(nn.Module):
    """Non-linear neck used in ODC: fc-bn-relu-dropout-fc-relu.

    Args:
        in_channels (int): Input feature dimension.
        hid_channels (int): Hidden dimension.
        out_channels (int): Output dimension.
        sync_bn (bool): Use ``SyncBN`` instead of ``BatchNorm1d``.
        with_avg_pool (bool): Global-average-pool the input first.
    """

    def __init__(self,
                 in_channels,
                 hid_channels,
                 out_channels,
                 sync_bn=False,
                 with_avg_pool=True):
        super(NonLinearNeckV0, self).__init__()
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # SyncBN in torch < 1.4.0 does not accept 2D input.
        self.expand_for_syncbn = (
            version.parse(torch.__version__) < version.parse('1.4.0'))
        self.fc0 = nn.Linear(in_channels, hid_channels)
        # Non-affine norm with a small momentum for the running statistics.
        if sync_bn:
            norm_cfg = dict(type='SyncBN', momentum=0.001, affine=False)
            self.bn0 = build_norm_layer(norm_cfg, hid_channels)[1]
        else:
            self.bn0 = nn.BatchNorm1d(
                hid_channels, momentum=0.001, affine=False)
        self.fc1 = nn.Linear(hid_channels, out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.drop = nn.Dropout()
        self.sync_bn = sync_bn

    def init_weights(self, init_linear='normal'):
        """Initialize weights with the shared ``_init_weights`` helper."""
        _init_weights(self, init_linear)

    def _forward_syncbn(self, module, x):
        """Run ``module`` on 2D ``x``, expanding to (N, C, 1, 1) for old
        torch versions."""
        assert x.dim() == 2
        if not self.expand_for_syncbn:
            return module(x)
        out = module(x.unsqueeze(-1).unsqueeze(-1))
        return out.squeeze(-1).squeeze(-1)

    def forward(self, x):
        """Project the first of one or two backbone outputs.

        Returns:
            list[Tensor]: Single-element list with the projected features.
        """
        assert len(x) == 1 or len(x) == 2
        # ViT backbones can emit two outputs; only the first one is used.
        feat = x[0]
        if self.with_avg_pool:
            feat = self.avgpool(feat)
        feat = self.fc0(feat.view(feat.size(0), -1))
        if self.sync_bn:
            feat = self._forward_syncbn(self.bn0, feat)
        else:
            feat = self.bn0(feat)
        feat = self.drop(self.relu(feat))
        feat = self.relu(self.fc1(feat))
        return [feat]
@NECKS.register_module
class NonLinearNeckV1(nn.Module):
    """Non-linear neck used in MoCo v2: fc-relu-fc.

    Args:
        in_channels (int): Input feature dimension.
        hid_channels (int): Hidden dimension.
        out_channels (int): Output dimension.
        with_avg_pool (bool): Global-average-pool the input first.
    """

    def __init__(self,
                 in_channels,
                 hid_channels,
                 out_channels,
                 with_avg_pool=True):
        super(NonLinearNeckV1, self).__init__()
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        fc_in = nn.Linear(in_channels, hid_channels)
        fc_out = nn.Linear(hid_channels, out_channels)
        self.mlp = nn.Sequential(fc_in, nn.ReLU(inplace=True), fc_out)

    def init_weights(self, init_linear='normal'):
        """Initialize weights with the shared ``_init_weights`` helper."""
        _init_weights(self, init_linear)

    def forward(self, x):
        """Project the first backbone output.

        ViT backbones can emit two outputs; only ``x[0]`` is used.

        Returns:
            list[Tensor]: Single-element list with the projected features.
        """
        feat = x[0]
        if self.with_avg_pool:
            feat = self.avgpool(feat)
        flat = feat.view(feat.size(0), -1)
        return [self.mlp(flat)]
@NECKS.register_module
class NonLinearNeckV2(nn.Module):
    """Non-linear neck used in BYOL: fc-bn-relu-fc.

    Args:
        in_channels (int): Input feature dimension.
        hid_channels (int): Hidden dimension.
        out_channels (int): Output dimension.
        with_avg_pool (bool): Global-average-pool the input first.
    """

    def __init__(self,
                 in_channels,
                 hid_channels,
                 out_channels,
                 with_avg_pool=True):
        super(NonLinearNeckV2, self).__init__()
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        stages = [
            nn.Linear(in_channels, hid_channels),
            nn.BatchNorm1d(hid_channels),
            nn.ReLU(inplace=True),
            nn.Linear(hid_channels, out_channels),
        ]
        self.mlp = nn.Sequential(*stages)

    def init_weights(self, init_linear='normal'):
        """Initialize weights with the shared ``_init_weights`` helper."""
        _init_weights(self, init_linear)

    def forward(self, x):
        """Project the first of one or two backbone outputs.

        Returns:
            list[Tensor]: Single-element list with the projected features.
        """
        assert len(x) == 1 or len(x) == 2, 'Got: {}'.format(len(x))
        # ViT backbones can emit two outputs; only the first one is used.
        feat = x[0]
        if self.with_avg_pool:
            feat = self.avgpool(feat)
        flat = feat.view(feat.size(0), -1)
        return [self.mlp(flat)]
@NECKS.register_module
class NonLinearNeckSimCLR(nn.Module):
    """SimCLR non-linear neck.

    Structure: fc(no_bias)-bn(has_bias)-[relu-fc(no_bias)-bn(no_bias)].
    The substructure in [] can be repeated; in the SimCLR default setting it
    appears once.

    PyTorch cannot configure a norm layer with (weight=True, bias=False); it
    only offers ``affine``, which enables both weight and bias. The second
    BatchNorm therefore has a bias here, which differs from the official
    SimCLR implementation.

    SyncBatchNorm in pytorch<1.4.0 does not support 2D input, so the input
    is expanded to 4D with shape (N, C, 1, 1) on those versions. This
    workaround has not been thoroughly verified.

    Args:
        in_channels (int): Input channel number.
        hid_channels (int): Hidden channels.
        out_channels (int): Output channel number.
        num_layers (int): Number of fc layers; 2 in the SimCLR default
            setting.
        with_avg_pool (bool): Global-average-pool the input first.
    """

    def __init__(self,
                 in_channels,
                 hid_channels,
                 out_channels,
                 num_layers=2,
                 with_avg_pool=True):
        super(NonLinearNeckSimCLR, self).__init__()
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.expand_for_syncbn = (
            version.parse(torch.__version__) < version.parse('1.4.0'))
        self.relu = nn.ReLU(inplace=True)
        self.fc0 = nn.Linear(in_channels, hid_channels, bias=False)
        self.bn0 = build_norm_layer(dict(type='SyncBN'), hid_channels)[1]

        # Repeated relu-fc-bn stages are registered as fc{i}/bn{i}; only the
        # last one maps to out_channels.
        self.fc_names = []
        self.bn_names = []
        for idx in range(1, num_layers):
            width = out_channels if idx == num_layers - 1 else hid_channels
            fc_name = 'fc{}'.format(idx)
            bn_name = 'bn{}'.format(idx)
            self.add_module(fc_name,
                            nn.Linear(hid_channels, width, bias=False))
            self.add_module(bn_name,
                            build_norm_layer(dict(type='SyncBN'), width)[1])
            self.fc_names.append(fc_name)
            self.bn_names.append(bn_name)

    def init_weights(self, init_linear='normal'):
        """Initialize weights with the shared ``_init_weights`` helper."""
        _init_weights(self, init_linear)

    def _forward_syncbn(self, module, x):
        """Run ``module`` on 2D ``x``, expanding to (N, C, 1, 1) for old
        torch versions."""
        assert x.dim() == 2
        if not self.expand_for_syncbn:
            return module(x)
        out = module(x.unsqueeze(-1).unsqueeze(-1))
        return out.squeeze(-1).squeeze(-1)

    def forward(self, x):
        """Project the first of one or two backbone outputs.

        Returns:
            list[Tensor]: Single-element list with the projected features.
        """
        assert len(x) == 1 or len(x) == 2
        # ViT backbones can emit two outputs; only the first one is used.
        feat = x[0]
        if self.with_avg_pool:
            feat = self.avgpool(feat)
        feat = feat.view(feat.size(0), -1)
        feat = self._forward_syncbn(self.bn0, self.fc0(feat))
        for fc_name, bn_name in zip(self.fc_names, self.bn_names):
            feat = getattr(self, fc_name)(self.relu(feat))
            feat = self._forward_syncbn(getattr(self, bn_name), feat)
        return [feat]
@NECKS.register_module
class RelativeLocNeck(nn.Module):
    """Relative patch location neck: fc-bn-relu-dropout.

    Args:
        in_channels (int): Per-input channel number. The flattened input is
            expected to have ``in_channels * 2`` features (see ``self.fc``).
        out_channels (int): Output dimension.
        sync_bn (bool): Use ``SyncBN`` instead of ``BatchNorm1d``.
        with_avg_pool (bool): Global-average-pool the input first.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 sync_bn=False,
                 with_avg_pool=True):
        super(RelativeLocNeck, self).__init__()
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # SyncBN in torch < 1.4.0 does not accept 2D input.
        if version.parse(torch.__version__) < version.parse('1.4.0'):
            self.expand_for_syncbn = True
        else:
            self.expand_for_syncbn = False
        self.fc = nn.Linear(in_channels * 2, out_channels)
        if sync_bn:
            _, self.bn = build_norm_layer(
                dict(type='SyncBN', momentum=0.003), out_channels)
        else:
            self.bn = nn.BatchNorm1d(out_channels, momentum=0.003)
        self.relu = nn.ReLU(inplace=True)
        self.drop = nn.Dropout()
        self.sync_bn = sync_bn

    def init_weights(self, init_linear='normal'):
        """Initialize weights via ``_init_weights`` (std=0.005, bias=0.1)."""
        _init_weights(self, init_linear, std=0.005, bias=0.1)

    def _forward_syncbn(self, module, x):
        """Run ``module`` on 2D ``x``, expanding to (N, C, 1, 1) for old
        torch versions."""
        assert x.dim() == 2
        if self.expand_for_syncbn:
            x = module(x.unsqueeze(-1).unsqueeze(-1)).squeeze(-1).squeeze(-1)
        else:
            x = module(x)
        return x

    def forward(self, x):
        """Project the first of one or two backbone outputs.

        Returns:
            list[Tensor]: Single-element list with the projected features.
        """
        assert len(x) == 1 or len(x) == 2  # to fit vit model
        x = x[0]
        if self.with_avg_pool:
            x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        if self.sync_bn:
            x = self._forward_syncbn(self.bn, x)
        else:
            x = self.bn(x)
        x = self.relu(x)
        x = self.drop(x)
        return [x]
@NECKS.register_module
class MAENeck(nn.Module):
    """MAE decoder.

    Args:
        num_patches (int): Number of patches from the encoder.
        embed_dim (int): Encoder embedding dimension.
        patch_size (int): Encoder patch size.
        in_chans (int): Input image channels.
        decoder_embed_dim (int): Decoder embedding dimension.
        decoder_depth (int): Number of decoder layers.
        decoder_num_heads (int): Parallel attention heads.
        mlp_ratio (float): MLP ratio.
        norm_layer: Type of normalization layer.
    """

    def __init__(self,
                 num_patches,
                 embed_dim=768,
                 patch_size=16,
                 in_chans=3,
                 decoder_embed_dim=512,
                 decoder_depth=8,
                 decoder_num_heads=16,
                 mlp_ratio=4.,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6)):
        super().__init__()
        self.num_patches = num_patches
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        # Non-trainable positional embedding with one extra slot for the cls
        # token; it is filled with 2D sin-cos values in init_weights().
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, decoder_embed_dim),
            requires_grad=False)

        self.decoder_blocks = nn.ModuleList([
            Block(
                decoder_embed_dim,
                decoder_num_heads,
                mlp_ratio,
                qkv_bias=True,
                norm_layer=norm_layer) for _ in range(decoder_depth)
        ])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(
            decoder_embed_dim, patch_size**2 * in_chans, bias=True)

    def init_weights(self):
        """Init the mask token, sin-cos pos embed, Linear and LayerNorm."""
        torch.nn.init.normal_(self.mask_token, std=.02)

        # assumes num_patches is a perfect square (square patch grid)
        decoder_pos_embed = get_2d_sincos_pos_embed(
            self.decoder_pos_embed.shape[-1],
            int(self.num_patches**.5),
            cls_token=True)
        self.decoder_pos_embed.data.copy_(
            torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform Linear weights, zero biases, unit LayerNorm."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, ids_restore):
        """Decode visible tokens plus mask tokens into per-patch predictions.

        Args:
            x (Tensor): Encoder tokens, 3D, with the cls token at index 1 of
                dim 1 position 0 (``x[:, :1]``) followed by the kept patch
                tokens (presumably the visible patches — confirm with the
                encoder).
            ids_restore (Tensor): 2D indices used with ``torch.gather`` to
                put the patch tokens back in their original order.

        Returns:
            Tensor: Predictions with last dimension
            ``patch_size**2 * in_chans``; the cls token is removed.
        """
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens so the patch sequence reaches
        # ids_restore.shape[1] tokens (+1 accounts for the cls token in x)
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(
            x_,
            dim=1,
            index=ids_restore.unsqueeze(-1).repeat(1, 1,
                                                   x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x
@NECKS.register_module
class FastConvMAENeck(MAENeck):
    """Fast ConvMAE decoder (MAE decoder without a cls token).

    Args:
        num_patches (int): Number of patches from the encoder.
        embed_dim (int): Encoder embedding dimension.
        patch_size (int): Encoder patch size.
        in_channels (int): Input image channels.
        decoder_embed_dim (int): Decoder embedding dimension.
        decoder_depth (int): Number of decoder layers.
        decoder_num_heads (int): Parallel attention heads.
        mlp_ratio (float): MLP ratio.
        norm_layer: Type of normalization layer.
    """

    def __init__(self,
                 num_patches,
                 embed_dim=768,
                 patch_size=16,
                 in_channels=3,
                 decoder_embed_dim=512,
                 decoder_depth=8,
                 decoder_num_heads=16,
                 mlp_ratio=4.,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6)):
        super().__init__(
            num_patches=num_patches,
            embed_dim=embed_dim,
            patch_size=patch_size,
            in_chans=in_channels,
            decoder_embed_dim=decoder_embed_dim,
            decoder_depth=decoder_depth,
            decoder_num_heads=decoder_num_heads,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer)

        # Replace the parent's embedding: no cls-token slot here.
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, decoder_embed_dim),
            requires_grad=False)

    def init_weights(self):
        """Init the sin-cos pos embed, mask token, Linear and LayerNorm."""
        # assumes num_patches is a perfect square (square patch grid)
        decoder_pos_embed = get_2d_sincos_pos_embed(
            self.decoder_pos_embed.shape[-1],
            int(self.num_patches**.5),
            cls_token=False)
        self.decoder_pos_embed.data.copy_(
            torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        torch.nn.init.normal_(self.mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(super()._init_weights)

    def forward(self, x, ids_restore):
        """Decode visible tokens plus mask tokens into per-patch predictions.

        The batch is treated as four equal groups. Groups 2-4 are rolled
        along the token dimension by 49, 98 and 147 positions before
        unshuffling with a 4x-tiled ``ids_restore``.
        NOTE(review): the offset 49 presumably equals num_patches // 4
        (196 patches, four complementary masks) and the batch size is
        presumably a multiple of 4 — confirm against the encoder.

        Args:
            x (Tensor): 3D encoder tokens without a cls token.
            ids_restore (Tensor): 2D unshuffle indices for one group; it is
                tiled four times along dim 0 below.

        Returns:
            Tensor: Predictions with last dimension
            ``patch_size**2 * in_chans``.
        """
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0],
                                             ids_restore.shape[1] - x.shape[1],
                                             1)
        x_ = torch.cat([x, mask_tokens], dim=1)  # no cls token

        B = x_.shape[0]
        x_split1 = x_[:B // 4, :, :]
        x_split2 = torch.roll(x_[B // 4:B // 4 * 2, :, :], 49, 1)
        x_split3 = torch.roll(x_[B // 4 * 2:B // 4 * 3, :, :], 49 * 2, 1)
        x_split4 = torch.roll(x_[B // 4 * 3:, :, :], 49 * 3, 1)
        x_ = torch.cat([x_split1, x_split2, x_split3, x_split4])
        ids_restore = torch.cat(
            [ids_restore, ids_restore, ids_restore, ids_restore])
        x = torch.gather(
            x_,
            dim=1,
            index=ids_restore.unsqueeze(-1).repeat(1, 1,
                                                   x.shape[2]))  # unshuffle

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        return x