
"""
Mostly copy-paste from
https://github.com/facebookresearch/mae/blob/main/models_mae.py
"""

from functools import partial

import torch
import torch.nn as nn
from timm.models.vision_transformer import Block, PatchEmbed

from easycv.models.utils import get_2d_sincos_pos_embed
from ..registry import BACKBONES


@BACKBONES.register_module
class MaskedAutoencoderViT(nn.Module):
    """Masked Autoencoder with VisionTransformer backbone.

    MaskedAutoencoderViT is mostly the same as vit_tranformer_dynamic, but
    adds a random_masking func. A MaskedAutoencoderViT model can be loaded
    by vit_tranformer_dynamic.

    Args:
        img_size (int): input image size
        patch_size (int): patch size
        in_chans (int): input image channels
        embed_dim (int): feature dimensions
        depth (int): number of encoder layers
        num_heads (int): number of parallel attention heads
        mlp_ratio (float): mlp ratio
        norm_layer: type of normalization layer
    """
    def __init__(
            self,
            img_size=224,
            patch_size=16,
            in_chans=3,
            embed_dim=1024,
            depth=24,
            num_heads=16,
            mlp_ratio=4.,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()

        self.patch_size = patch_size
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans,
                                      embed_dim)
        self.num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, embed_dim),
            requires_grad=False)  # fixed sin-cos embedding

        self.blocks = nn.ModuleList([
            Block(
                embed_dim,
                num_heads,
                mlp_ratio,
                qkv_bias=True,
                norm_layer=norm_layer) for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)
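    # Note (not in the original source): with the defaults above
    # (embed_dim=1024, depth=24, num_heads=16), this builds the encoder of
    # MAE's ViT-Large; 224 / 16 = 14 patches per side gives 196 patch tokens
    # plus one cls token.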
    def init_weights(self):
        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        torch.nn.init.normal_(self.cls_token, std=.02)

        # initialize (and freeze) pos_embed with the fixed sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1],
            int(self.patch_embed.num_patches**.5),
            cls_token=True)
        self.pos_embed.data.copy_(
            torch.from_numpy(pos_embed).float().unsqueeze(0))

        for m in self.modules():
            if isinstance(m, nn.Linear):
                # we use xavier_uniform following official JAX ViT:
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
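    # Shape note (a sketch of the call above, not in the original source):
    # for the default 224/16 setup, get_2d_sincos_pos_embed(1024, 14,
    # cls_token=True) returns a numpy array of shape (197, 1024) -- in the
    # reference MAE implementation the first row is zeros for the cls token,
    # followed by the 196 fixed sin-cos patch embeddings -- which is copied
    # into self.pos_embed[0].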
    def random_masking(self, x, mask_ratio):
        """Perform per-sample random masking by per-sample shuffling.

        Per-sample shuffling is done by argsort of random noise.

        Args:
            x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(
            noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(
            x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore
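    # Worked example of random_masking's bookkeeping (an illustrative sketch,
    # not in the original source; N=1, L=4, mask_ratio=0.5, so len_keep=2):
    #   noise       = [[0.7, 0.1, 0.9, 0.3]]
    #   ids_shuffle = [[1, 3, 0, 2]]  # patch indices sorted by ascending noise
    #   ids_restore = [[2, 0, 3, 1]]  # inverse permutation of ids_shuffle
    #   ids_keep    = [[1, 3]]        # the two lowest-noise patches survive
    #   mask        = [[1, 0, 1, 0]]  # after unshuffle: 0 = kept, 1 = removed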
    def forward(self, x, mask_ratio):
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * (1 - mask_ratio)
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore
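
# A minimal smoke test (an illustrative sketch, not part of the library):
# builds a small encoder and checks the output shapes. The hyperparameters
# are chosen only to keep the example fast; they assume timm's Block accepts
# (dim, num_heads, mlp_ratio) positionally, as the constructor above does.
if __name__ == '__main__':
    model = MaskedAutoencoderViT(
        img_size=224, patch_size=16, embed_dim=128, depth=2, num_heads=4)
    model.init_weights()

    imgs = torch.randn(2, 3, 224, 224)
    latent, mask, ids_restore = model(imgs, mask_ratio=0.75)

    # 196 patches, of which int(196 * 0.25) = 49 stay visible, plus cls token
    assert latent.shape == (2, 50, 128)
    assert mask.shape == (2, 196)
    assert ids_restore.shape == (2, 196)
    print('ok:', latent.shape, mask.shape, ids_restore.shape)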