Source code for easycv.models.loss.cross_entropy_loss

# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from easycv.framework.errors import ValueError
from easycv.models.builder import LOSSES
from easycv.models.loss.utils import weight_reduce_loss


def get_class_weight(class_weight):
    """Resolve a class-weight specification into its in-memory value.

    Args:
        class_weight (list[float] | str | None): Class weights, or a path
            to a file holding them. Anything that is not a ``str`` is
            handed back untouched.

    Returns:
        The input itself when it is not a string; otherwise the content
        loaded from the file (``np.load`` for ``.npy`` paths, ``mmcv.load``
        for every other path, e.g. pkl, json or yaml).
    """
    # Only strings are treated as file paths.
    if not isinstance(class_weight, str):
        return class_weight

    if class_weight.endswith('.npy'):
        return np.load(class_weight)

    # pkl, json or yaml
    return mmcv.load(class_weight)


def cross_entropy(pred,
                  label,
                  weight=None,
                  class_weight=None,
                  reduction='mean',
                  avg_factor=None,
                  ignore_index=-100,
                  avg_non_ignore=False):
    """Wrapper around :func:`F.cross_entropy` with sample-wise weighting.

    The element-wise loss is computed first and the weighting / reduction
    is delegated to ``weight_reduce_loss``. Optionally the mean is taken
    over the non-ignored targets only.

    Args:
        pred (torch.Tensor): The prediction, forwarded unchanged to
            :func:`F.cross_entropy`.
        label (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight. It is
            cast to float before use. Default: None.
        class_weight (torch.Tensor, optional): Manual rescaling weight for
            each class, forwarded as ``weight`` to
            :func:`F.cross_entropy`. Default: None.
        reduction (str, optional): The method used to reduce the loss.
            Options are 'none', 'mean' and 'sum'. Default: 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Default: None.
        ignore_index (int): Target value that is ignored and does not
            contribute to the input gradients. Default: -100.
        avg_non_ignore (bool): When True, ``reduction`` is ``'mean'`` and
            no ``avg_factor`` is given, the number of non-ignored targets
            is used as the average factor. Default: False.

    Returns:
        torch.Tensor: The loss returned by ``weight_reduce_loss``.
    """
    # Keep the loss element-wise here; the reduction happens further down.
    elementwise_loss = F.cross_entropy(
        pred,
        label,
        weight=class_weight,
        reduction='none',
        ignore_index=ignore_index)

    # PyTorch's own cross_entropy averages over non-ignored elements, see
    # https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660  # noqa
    use_non_ignored_count = (
        reduction == 'mean' and avg_non_ignore and avg_factor is None)
    if use_non_ignored_count:
        num_ignored = (label == ignore_index).sum().item()
        avg_factor = label.numel() - num_ignored

    sample_weight = None if weight is None else weight.float()

    return weight_reduce_loss(
        elementwise_loss,
        weight=sample_weight,
        reduction=reduction,
        avg_factor=avg_factor)


def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
    """Expand onehot labels to match the size of prediction."""
    bin_labels = labels.new_zeros(target_shape)
    valid_mask = (labels >= 0) & (labels != ignore_index)
    inds = torch.nonzero(valid_mask, as_tuple=True)

    if inds[0].numel() > 0:
        if labels.dim() == 3:
            bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
        else:
            bin_labels[inds[0], labels[valid_mask]] = 1

    valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()

    if label_weights is None:
        bin_label_weights = valid_mask
    else:
        bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
        bin_label_weights = bin_label_weights * valid_mask

    return bin_labels, bin_label_weights, valid_mask


def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None,
                         class_weight=None,
                         ignore_index=-100,
                         avg_non_ignore=False,
                         label_ceil=False,
                         **kwargs):
    """Calculate the binary CrossEntropy loss.

    Args:
        pred (torch.Tensor): The prediction logits. Supported layouts are
            the same shape as ``label``, ``[N, C]`` with label ``[N]``,
            ``[N, C, H, W]`` with label ``[N, H, W]``, and a single-channel
            prediction (size 1 on dim 1) whose channel axis is dropped
            before the comparison with ``label``.
        label (torch.Tensor): The learning label of the prediction.
            Note: In bce loss, label < 0 is invalid.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (torch.Tensor, optional): The weight for each class,
            forwarded as ``pos_weight`` to
            :func:`F.binary_cross_entropy_with_logits`.
        ignore_index (int): The label index to be ignored. Default: -100.
        avg_non_ignore (bool): The flag decides to whether the loss is
            only averaged over non-ignored targets. Default: False.
        label_ceil (bool): When True, every label element greater than 0
            is changed to 1. Default: False.
        **kwargs: Accepted for signature compatibility; unused here.

    Returns:
        torch.Tensor: The calculated loss

    Raises:
        AssertionError: If a single-channel prediction comes with labels
            greater than 1, or if the prediction / label ranks are not one
            of the supported combinations.
    """
    if pred.dim() > 1 and pred.shape[1] == 1:
        # For binary class segmentation, the shape of pred is
        # [N, 1, H, W] and that of label is [N, H, W].
        # As the ignore_index often set as 255, so the
        # binary class label check should mask out
        # ignore_index
        kept_labels = label[label != ignore_index]
        # ``max()`` raises on an empty tensor, so the check is skipped when
        # every label is ignored.
        assert kept_labels.numel() == 0 or kept_labels.max() <= 1, \
            'For pred with shape [N, 1, H, W], its label must have at ' \
            'most 2 classes'
        # Drop only the channel axis. A bare ``squeeze()`` would also
        # remove a batch or spatial axis of size 1 (e.g. batch size 1) and
        # break the shape match with ``label`` below.
        pred = pred.squeeze(1)
    if pred.dim() != label.dim():
        assert (pred.dim() == 2 and label.dim() == 1) or (
                pred.dim() == 4 and label.dim() == 3), \
            'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
            'H, W], label shape [N, H, W] are supported'
        # `weight` returned from `_expand_onehot_labels`
        # has been treated for valid (non-ignore) pixels
        label, weight, valid_mask = _expand_onehot_labels(
            label, weight, pred.shape, ignore_index)
    else:
        # should mask out the ignored elements
        valid_mask = ((label >= 0) & (label != ignore_index)).float()
        if weight is not None:
            weight = weight * valid_mask
        else:
            weight = valid_mask
    if label_ceil:
        # binarize soft labels: anything in (0, 1] becomes 1
        label = label.gt(0.0).type(label.dtype)
    # average loss over non-ignored and valid elements
    if reduction == 'mean' and avg_factor is None and avg_non_ignore:
        avg_factor = valid_mask.sum().item()

    loss = F.binary_cross_entropy_with_logits(
        pred, label.float(), pos_weight=class_weight, reduction='none')
    # do the reduction for the weighted loss
    loss = weight_reduce_loss(
        loss, weight, reduction=reduction, avg_factor=avg_factor)

    return loss


def mask_cross_entropy(pred,
                       target,
                       label,
                       reduction='mean',
                       avg_factor=None,
                       class_weight=None,
                       ignore_index=None,
                       **kwargs):
    """Calculate the CrossEntropy loss for masks.

    For every sample the prediction channel given by ``label`` is selected
    and compared against ``target`` with a mean-reduced binary cross
    entropy on logits.

    Args:
        pred (torch.Tensor): The prediction; dim 0 indexes the samples and
            dim 1 the classes.
        target (torch.Tensor): The learning target for the selected
            predictions.
        label (torch.Tensor): Class label of the object each mask belongs
            to. It picks the class channel of ``pred`` for each sample.
        reduction (str, optional): Reserved; must be "mean".
        avg_factor (int, optional): Reserved; must be None.
        class_weight (torch.Tensor, optional): Forwarded as ``weight`` to
            :func:`F.binary_cross_entropy_with_logits`.
        ignore_index (None): Placeholder, to be consistent with other loss.
            Must be None.
        **kwargs: Accepted for signature compatibility; unused here.

    Returns:
        torch.Tensor: The calculated loss, with a leading axis of size 1.
    """
    assert ignore_index is None, 'BCE loss does not support ignore_index'
    # TODO: handle these two reserved arguments
    assert reduction == 'mean' and avg_factor is None

    roi_count = pred.size(0)
    roi_index = torch.arange(
        0, roi_count, dtype=torch.long, device=pred.device)
    # One class channel per sample, chosen by its label.
    selected = pred[roi_index, label].squeeze(1)

    mean_loss = F.binary_cross_entropy_with_logits(
        selected, target, weight=class_weight, reduction='mean')
    return mean_loss[None]


@LOSSES.register_module()
class CrossEntropyLoss(nn.Module):
    """CrossEntropyLoss.

    Args:
        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
            instead of softmax (selects ``binary_cross_entropy``).
            Defaults to False.
        use_mask (bool, optional): Whether to use mask cross entropy loss.
            Cannot be combined with ``use_sigmoid``. Defaults to False.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to 'mean'.
        class_weight (list[float] | str, optional): Weight of each class.
            If in str format, read them from a file. Defaults to None.
        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must
            be the prefix of the name. Defaults to 'loss_ce'.
        avg_non_ignore (bool): The flag decides to whether the loss is
            only averaged over non-ignored targets. Default: False.
        label_ceil (bool): When use bce and set label_ceil=True, it will
            make elements belong to (0, 1] in label change to 1.
            Default: False.
    """

    def __init__(self,
                 use_sigmoid=False,
                 use_mask=False,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 loss_name='loss_ce',
                 avg_non_ignore=False,
                 label_ceil=False):
        super().__init__()
        # sigmoid and mask modes are mutually exclusive
        assert (use_sigmoid is False) or (use_mask is False)
        self.use_sigmoid = use_sigmoid
        if label_ceil and not use_sigmoid:
            # ``ValueError`` is the name imported at the top of this file.
            raise ValueError(
                '‘label_ceil’ is supported only when ‘use_sigmoid’ is true. '
                'If not use bce, please set ‘label_ceil’=False')
        self.use_mask = use_mask
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = get_class_weight(class_weight)
        self.avg_non_ignore = avg_non_ignore
        if self.reduction == 'mean' and not self.avg_non_ignore:
            warnings.warn(
                'Default ``avg_non_ignore`` is False, if you would like to '
                'ignore the certain label and average loss over non-ignore '
                'labels, which is the same with PyTorch official '
                'cross_entropy, set ``avg_non_ignore=True``.')

        # Pick the functional criterion once, at construction time.
        if self.use_sigmoid:
            criterion = binary_cross_entropy
        elif self.use_mask:
            criterion = mask_cross_entropy
        else:
            criterion = cross_entropy
        self.cls_criterion = criterion

        self._loss_name = loss_name
        self.label_ceil = label_ceil

    def extra_repr(self):
        """Extra repr."""
        return f'avg_non_ignore={self.avg_non_ignore}'

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                ignore_index=-100,
                **kwargs):
        """Forward function.

        Args:
            cls_score (torch.Tensor): The prediction passed to the
                selected criterion.
            label (torch.Tensor): The learning label of the prediction.
            weight (torch.Tensor, optional): Sample-wise loss weight.
            avg_factor (int, optional): Average factor that is used to
                average the loss.
            reduction_override (str, optional): One of None, 'none',
                'mean', 'sum'. When given, it replaces the reduction
                configured at construction.
            ignore_index (int): The label index to be ignored.
                Default: -100.
            **kwargs: Extra keyword arguments forwarded to the criterion.

        Returns:
            torch.Tensor: The criterion output scaled by ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction

        class_weight = None
        if self.class_weight is not None:
            # match dtype and device of the prediction
            class_weight = cls_score.new_tensor(self.class_weight)

        # Note: for BCE loss, label < 0 is invalid.
        # Only the sigmoid (BCE) criterion receives ``label_ceil``.
        # NOTE(review): with ``use_mask=True`` the default
        # ``ignore_index=-100`` is forwarded to ``mask_cross_entropy``,
        # which asserts ``ignore_index is None`` - confirm how that mode is
        # meant to be called.
        bce_only = {'label_ceil': self.label_ceil} if self.use_sigmoid else {}
        loss_cls = self.loss_weight * self.cls_criterion(
            cls_score,
            label,
            weight,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            avg_non_ignore=self.avg_non_ignore,
            ignore_index=ignore_index,
            **bce_only,
            **kwargs)
        return loss_cls

    @property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss
        items by simple sum operation. In addition, if you want this loss
        item to be included into the backward graph, `loss_` must be the
        prefix of the name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
@LOSSES.register_module()
class AsymmetricLoss(nn.Module):
    """Asymmetric loss for multi-label targets.

    Positive and negative terms of a sigmoid cross entropy are weighted by
    separate focusing exponents, and the negative probability is shifted
    by a clipping margin.

    Args:
        gamma_neg (float): Focusing exponent applied where the target is 0.
            Defaults to 4.
        gamma_pos (float): Focusing exponent applied where the target is 1.
            Defaults to 1.
        clip (float | None): Margin added to the negative probability
            (result clamped to at most 1). Disabled when None or not
            positive. Defaults to 0.05.
        eps (float): Lower clamp applied to probabilities before the
            logarithm. Defaults to 1e-8.
        disable_torch_grad_focal_loss (bool): When True, the focusing
            weight is computed with autograd disabled, so no gradient
            flows through it. Defaults to True.
    """

    def __init__(self,
                 gamma_neg=4,
                 gamma_pos=1,
                 clip=0.05,
                 eps=1e-8,
                 disable_torch_grad_focal_loss=True):
        super(AsymmetricLoss, self).__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

    def forward(self, x, y):
        """Compute the asymmetric loss.

        Args:
            x (torch.Tensor): Input logits.
            y (torch.Tensor): Targets (multi-label binarized vector).

        Returns:
            torch.Tensor: The negated sum of the weighted log-likelihood
            over all elements (no averaging is applied).
        """
        # Calculating Probabilities
        x_sigmoid = torch.sigmoid(x)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid

        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)

        # Basic CE calculation
        los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
        loss = los_pos + los_neg

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            # Use the context-manager form so the caller's grad mode is
            # restored afterwards. Calling ``set_grad_enabled(True)``
            # unconditionally would switch autograd back on inside an
            # enclosing ``torch.no_grad()`` block, and would be skipped
            # entirely if an exception were raised in between.
            track_grad = (
                torch.is_grad_enabled()
                and not self.disable_torch_grad_focal_loss)
            with torch.set_grad_enabled(track_grad):
                pt0 = xs_pos * y
                pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
                pt = pt0 + pt1
                one_sided_gamma = (
                    self.gamma_pos * y + self.gamma_neg * (1 - y))
                one_sided_w = torch.pow(1 - pt, one_sided_gamma)
            loss = loss * one_sided_w

        return -loss.sum()