Source code for easycv.models.selfsup.mae

import torch

from easycv.framework.errors import KeyError
from .. import builder
from ..base import BaseModel
from ..registry import MODELS


@MODELS.register_module
class MAE(BaseModel):
    """Masked Autoencoder (MAE) self-supervised model.

    The encoder (built from ``backbone``) is called with ``mask_ratio`` and
    returns latent tokens, a per-patch mask and restore indices. The decoder
    (built from ``neck``) turns the latent tokens back into per-patch pixel
    predictions. Training minimises the mean squared error on the removed
    (masked) patches only.
    """

    def __init__(self,
                 backbone,
                 neck,
                 mask_ratio=0.75,
                 norm_pix_loss=True,
                 **kwargs):
        """
        Args:
            backbone (dict): Config of the encoder. The built module must
                expose ``patch_size``, ``num_patches`` and ``init_weights``.
            neck (dict): Config of the decoder. ``num_patches`` is filled in
                from the encoder on a copy; the caller's config is not
                modified.
            mask_ratio (float): Forwarded to the encoder on every training
                step.
            norm_pix_loss (bool): If True, each target patch is normalised by
                its own mean and variance before the loss is computed.
            **kwargs: Accepted for config compatibility and ignored.
        """
        super().__init__()
        self.mask_ratio = mask_ratio
        self.norm_pix_loss = norm_pix_loss

        self.encoder = builder.build_backbone(backbone)
        self.patch_size = self.encoder.patch_size

        # Copy before injecting num_patches so a config shared with other
        # code (e.g. reused to build another model) is left untouched.
        neck = neck.copy()
        neck['num_patches'] = self.encoder.num_patches
        self.decoder = builder.build_neck(neck)

        self.init_weights()

    def init_weights(self):
        """Initialise encoder and decoder weights via their own init_weights."""
        self.encoder.init_weights()
        self.decoder.init_weights()

    def patchify(self, imgs):
        """Split square images into flattened, non-overlapping patches.

        Args:
            imgs: (N, C, H, W) tensor with H == W and H divisible by
                ``self.patch_size``.

        Returns:
            x: (N, L, patch_size**2 * C), where L = (H // patch_size)**2.
                Each patch is flattened in (patch-row, patch-col, channel)
                order.
        """
        p = self.patch_size
        n, c = imgs.shape[0], imgs.shape[1]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0, (
            'patchify expects square images whose side is divisible by '
            'patch_size, got shape {} with patch_size {}'.format(
                tuple(imgs.shape), p))
        h = w = imgs.shape[2] // p

        x = imgs.reshape(shape=(n, c, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(n, h * w, p**2 * c))
        return x

    def forward_loss(self, imgs, pred, mask):
        """Compute the reconstruction loss on removed patches.

        Args:
            imgs: (N, C, H, W) input images.
            pred: (k * N, L, p*p*C) decoder predictions. A batch dimension
                that is a multiple k of N is allowed (ConvMAE); the target is
                then tiled k times along the batch dimension.
            mask: per-patch mask that must broadcast against
                ``pred.shape[:2]``; 0 is keep, 1 is remove.

        Returns:
            Scalar tensor: the per-patch MSE averaged over removed patches.
            NOTE: this is NaN (0 / 0) if ``mask`` marks no patch as removed.
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        # adapt to ConvMAE: predictions may cover several copies of the batch
        assert pred.shape[0] % target.shape[0] == 0, (
            'pred batch size {} is not a multiple of image batch size '
            '{}'.format(pred.shape[0], target.shape[0]))
        target = torch.cat([target] * (pred.shape[0] // target.shape[0]))

        loss = (pred - target)**2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward_train(self, img, **kwargs):
        """Run one masked-reconstruction step.

        Returns:
            dict: ``{'loss': loss}`` as produced by ``forward_loss``.
        """
        latent, mask, ids_restore = self.encoder(
            img, mask_ratio=self.mask_ratio)
        pred = self.decoder(latent, ids_restore)
        loss = self.forward_loss(img, pred, mask)
        return dict(loss=loss)

    def forward_test(self, img, **kwargs):
        """Not implemented for MAE; returns None."""
        pass

    def forward(self, img, mode='train', **kwargs):
        """Dispatch to ``forward_train`` or ``forward_test`` based on ``mode``.

        Raises:
            KeyError: if ``mode`` is neither 'train' nor 'test'.
        """
        if mode == 'train':
            return self.forward_train(img, **kwargs)
        elif mode == 'test':
            return self.forward_test(img, **kwargs)
        else:
            raise KeyError('No such mode: {}'.format(mode))