# Source code for easycv.models.loss.mse_loss

# Copyright (c) OpenMMLab. All rights reserved.
# Adapt from https://github.com/open-mmlab/mmpose/blob/master/mmpose/models/losses/mse_loss.py

import torch.nn as nn

from ..registry import LOSSES


@LOSSES.register_module()
class JointsMSELoss(nn.Module):
    """MSE loss for heatmaps.

    Each joint's heatmap is flattened and compared with its target through
    ``nn.MSELoss`` (mean reduction). The per-joint losses are averaged over
    joints and then scaled by ``loss_weight``.

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self, use_target_weight=False, loss_weight=1.):
        super().__init__()
        self.criterion = nn.MSELoss()
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Args:
            output (torch.Tensor): Predicted heatmaps. The first two dims are
                read as (batch_size, num_joints) and all remaining dims are
                flattened, e.g. (N, K, H, W).
            target (torch.Tensor): Target heatmaps, reshaped the same way as
                ``output``.
            target_weight (torch.Tensor, optional): Per-joint weights. The
                slice ``target_weight[:, idx]`` is multiplied into the
                flattened (N, P) heatmaps of joint ``idx``. It is presumably
                shaped (N, K, 1) so that the slice broadcasts; TODO confirm
                with callers. Only read when ``use_target_weight`` is True,
                so it may be omitted otherwise. Default: None.

        Returns:
            torch.Tensor: Mean of the per-joint MSE losses, multiplied by
                ``loss_weight``.

        Raises:
            ValueError: If ``use_target_weight`` is True and
                ``target_weight`` is None.
        """
        if self.use_target_weight and target_weight is None:
            raise ValueError(
                'target_weight is required when use_target_weight=True')

        batch_size = output.size(0)
        num_joints = output.size(1)

        # Split along the joint axis into tuples of K tensors of shape
        # (N, 1, P), where P is the flattened spatial size.
        heatmaps_pred = output.reshape(
            (batch_size, num_joints, -1)).split(1, 1)
        heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)

        loss = 0.
        for idx, (pred, gt) in enumerate(zip(heatmaps_pred, heatmaps_gt)):
            heatmap_pred = pred.squeeze(1)
            heatmap_gt = gt.squeeze(1)
            if self.use_target_weight:
                weight = target_weight[:, idx]
                # Weighting both operands scales the squared error by
                # weight**2. A zero weight removes that element's error from
                # the sum, but it still counts in MSELoss's mean denominator.
                loss += self.criterion(heatmap_pred * weight,
                                       heatmap_gt * weight)
            else:
                loss += self.criterion(heatmap_pred, heatmap_gt)

        return loss / num_joints * self.loss_weight