esenke
2025-12-08 22:16:31 +08:00
commit 01adcfdf60
305 changed files with 50879 additions and 0 deletions

View File

@@ -0,0 +1,21 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .accuracy import Accuracy, accuracy
from .boundary_loss import BoundaryLoss
from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
cross_entropy, mask_cross_entropy)
from .dice_loss import DiceLoss
from .focal_loss import FocalLoss
from .huasdorff_distance_loss import HuasdorffDisstanceLoss
from .lovasz_loss import LovaszLoss
from .ohem_cross_entropy_loss import OhemCrossEntropy
from .silog_loss import SiLogLoss
from .tversky_loss import TverskyLoss
from .utils import reduce_loss, weight_reduce_loss, weighted_loss
__all__ = [
'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss',
'FocalLoss', 'TverskyLoss', 'OhemCrossEntropy', 'BoundaryLoss',
'HuasdorffDisstanceLoss', 'SiLogLoss'
]
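In a typical mmsegmentation config, these registered losses are referenced by name rather than imported directly; a hedged sketch of a decode-head entry:
loss_decode = dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)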

View File

@@ -0,0 +1,92 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
def accuracy(pred, target, topk=1, thresh=None, ignore_index=None):
"""Calculate accuracy according to the prediction and target.
Args:
pred (torch.Tensor): The model prediction, shape (N, num_class, ...)
target (torch.Tensor): The target of each prediction, shape (N, ...)
ignore_index (int | None): The label index to be ignored. Default: None
topk (int | tuple[int], optional): If the predictions in ``topk``
matches the target, the predictions will be regarded as
correct ones. Defaults to 1.
thresh (float, optional): If not None, predictions with scores under
this threshold are considered incorrect. Defaults to None.
Returns:
float | tuple[float]: If the input ``topk`` is a single integer,
the function will return a single float as accuracy. If
``topk`` is a tuple containing multiple integers, the
function will return a tuple containing accuracies of
each ``topk`` number.
"""
assert isinstance(topk, (int, tuple))
if isinstance(topk, int):
topk = (topk, )
return_single = True
else:
return_single = False
maxk = max(topk)
if pred.size(0) == 0:
accu = [pred.new_tensor(0.) for i in range(len(topk))]
return accu[0] if return_single else accu
assert pred.ndim == target.ndim + 1
assert pred.size(0) == target.size(0)
assert maxk <= pred.size(1), \
f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
pred_value, pred_label = pred.topk(maxk, dim=1)
# transpose to shape (maxk, N, ...)
pred_label = pred_label.transpose(0, 1)
correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label))
if thresh is not None:
# Only prediction values larger than thresh are counted as correct
correct = correct & (pred_value > thresh).t()
if ignore_index is not None:
correct = correct[:, target != ignore_index]
res = []
eps = torch.finfo(torch.float32).eps
for k in topk:
# Avoid causing ZeroDivisionError when all pixels
# of an image are ignored
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + eps
if ignore_index is not None:
total_num = target[target != ignore_index].numel() + eps
else:
total_num = target.numel() + eps
res.append(correct_k.mul_(100.0 / total_num))
return res[0] if return_single else res
class Accuracy(nn.Module):
"""Accuracy calculation module."""
def __init__(self, topk=(1, ), thresh=None, ignore_index=None):
"""Module to calculate the accuracy.
Args:
topk (tuple, optional): The criterion used to calculate the
accuracy. Defaults to (1,).
thresh (float, optional): If not None, predictions with scores
under this threshold are considered incorrect. Defaults to None.
ignore_index (int | None): The label index to be ignored.
Defaults to None.
"""
super().__init__()
self.topk = topk
self.thresh = thresh
self.ignore_index = ignore_index
def forward(self, pred, target):
"""Forward function to calculate accuracy.
Args:
pred (torch.Tensor): Prediction of models.
target (torch.Tensor): Target for each prediction.
Returns:
tuple[float]: The accuracies under different topk criterions.
"""
return accuracy(pred, target, self.topk, self.thresh,
self.ignore_index)
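A minimal usage sketch for both forms, assuming the package is importable as mmseg.models.losses:
import torch
from mmseg.models.losses import Accuracy, accuracy

pred = torch.tensor([[0.1, 0.7, 0.1, 0.1], [0.8, 0.1, 0.05, 0.05], [0.2, 0.2, 0.5, 0.1]])
target = torch.tensor([1, 0, 255])  # 255 marks an ignored pixel
# functional form: returns a percentage tensor
print(accuracy(pred, target, topk=1, ignore_index=255))
# module form, as used inside decode heads
print(Accuracy(topk=(1, ), ignore_index=255)(pred, target))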

View File

@@ -0,0 +1,62 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from mmseg.registry import MODELS
@MODELS.register_module()
class BoundaryLoss(nn.Module):
"""Boundary loss.
This function is modified from
`PIDNet <https://github.com/XuJiacong/PIDNet/blob/main/utils/criterion.py#L122>`_. # noqa
Licensed under the MIT License.
Args:
loss_weight (float): Weight of the loss. Defaults to 1.0.
loss_name (str): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_boundary'.
"""
def __init__(self,
loss_weight: float = 1.0,
loss_name: str = 'loss_boundary'):
super().__init__()
self.loss_weight = loss_weight
self.loss_name_ = loss_name
def forward(self, bd_pre: Tensor, bd_gt: Tensor) -> Tensor:
"""Forward function.
Args:
bd_pre (Tensor): Predictions of the boundary head.
bd_gt (Tensor): Ground truth of the boundary.
Returns:
Tensor: Loss tensor.
"""
log_p = bd_pre.permute(0, 2, 3, 1).contiguous().view(1, -1)
target_t = bd_gt.view(1, -1).float()
pos_index = (target_t == 1)
neg_index = (target_t == 0)
weight = torch.zeros_like(log_p)
pos_num = pos_index.sum()
neg_num = neg_index.sum()
sum_num = pos_num + neg_num
weight[pos_index] = neg_num * 1.0 / sum_num
weight[neg_index] = pos_num * 1.0 / sum_num
loss = F.binary_cross_entropy_with_logits(
log_p, target_t, weight, reduction='mean')
return self.loss_weight * loss
@property
def loss_name(self):
return self.loss_name_
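A hedged usage sketch: BoundaryLoss consumes boundary-head logits of shape (N, 1, H, W) and a binary boundary map of shape (N, H, W):
import torch
from mmseg.models.losses import BoundaryLoss

loss_fn = BoundaryLoss(loss_weight=1.0)
bd_pre = torch.randn(2, 1, 16, 16)        # raw boundary logits
bd_gt = torch.randint(0, 2, (2, 16, 16))  # binary ground truth
print(loss_fn(bd_pre, bd_gt))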

View File

@@ -0,0 +1,311 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmseg.registry import MODELS
from .utils import get_class_weight, weight_reduce_loss
def cross_entropy(pred,
label,
weight=None,
class_weight=None,
reduction='mean',
avg_factor=None,
ignore_index=-100,
avg_non_ignore=False):
"""cross_entropy. The wrapper function for :func:`F.cross_entropy`
Args:
pred (torch.Tensor): The prediction with shape (N, 1).
label (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
Default: None.
class_weight (list[float], optional): The weight for each class.
Default: None.
reduction (str, optional): The method used to reduce the loss.
Options are 'none', 'mean' and 'sum'. Default: 'mean'.
avg_factor (int, optional): Average factor that is used to average
the loss. Default: None.
ignore_index (int): Specifies a target value that is ignored and
does not contribute to the input gradients. When
``avg_non_ignore`` is ``True`` and the ``reduction`` is
``'mean'``, the loss is averaged over non-ignored targets.
Default: -100.
avg_non_ignore (bool): The flag decides whether the loss is
only averaged over non-ignored targets. Default: False.
`New in version 0.23.0.`
"""
# class_weight is a manual rescaling weight given to each class.
# If given, has to be a Tensor of size C element-wise losses
loss = F.cross_entropy(
pred,
label,
weight=class_weight,
reduction='none',
ignore_index=ignore_index)
# apply weights and do the reduction
# average loss over non-ignored elements
# pytorch's official cross_entropy average loss over non-ignored elements
# refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
if (avg_factor is None) and reduction == 'mean':
if class_weight is None:
if avg_non_ignore:
avg_factor = label.numel() - (label
== ignore_index).sum().item()
else:
avg_factor = label.numel()
else:
# the average factor should take the class weights into account
label_weights = torch.stack([class_weight[cls] for cls in label
]).to(device=class_weight.device)
if avg_non_ignore:
label_weights[label == ignore_index] = 0
avg_factor = label_weights.sum()
if weight is not None:
weight = weight.float()
loss = weight_reduce_loss(
loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
return loss
def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
"""Expand onehot labels to match the size of prediction."""
bin_labels = labels.new_zeros(target_shape)
valid_mask = (labels >= 0) & (labels != ignore_index)
inds = torch.nonzero(valid_mask, as_tuple=True)
if inds[0].numel() > 0:
if labels.dim() == 3:
bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
else:
bin_labels[inds[0], labels[valid_mask]] = 1
valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
if label_weights is None:
bin_label_weights = valid_mask
else:
bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
bin_label_weights = bin_label_weights * valid_mask
return bin_labels, bin_label_weights, valid_mask
def binary_cross_entropy(pred,
label,
weight=None,
reduction='mean',
avg_factor=None,
class_weight=None,
ignore_index=-100,
avg_non_ignore=False,
**kwargs):
"""Calculate the binary CrossEntropy loss.
Args:
pred (torch.Tensor): The prediction with shape (N, 1).
label (torch.Tensor): The learning label of the prediction.
Note: In bce loss, label < 0 is invalid.
weight (torch.Tensor, optional): Sample-wise loss weight.
reduction (str, optional): The method used to reduce the loss.
Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
ignore_index (int): The label index to be ignored. Default: -100.
avg_non_ignore (bool): The flag decides whether the loss is
only averaged over non-ignored targets. Default: False.
`New in version 0.23.0.`
Returns:
torch.Tensor: The calculated loss
"""
if pred.size(1) == 1:
# For binary class segmentation, the shape of pred is
# [N, 1, H, W] and that of label is [N, H, W].
# As ignore_index is often set to 255, the
# binary class label check should mask out
# ignore_index
assert label[label != ignore_index].max() <= 1, \
'For pred with shape [N, 1, H, W], its label must have at ' \
'most 2 classes'
pred = pred.squeeze(1)
if pred.dim() != label.dim():
assert (pred.dim() == 2 and label.dim() == 1) or (
pred.dim() == 4 and label.dim() == 3), \
'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
'H, W], label shape [N, H, W] are supported'
# `weight` returned from `_expand_onehot_labels`
# has been treated for valid (non-ignore) pixels
label, weight, valid_mask = _expand_onehot_labels(
label, weight, pred.shape, ignore_index)
else:
# should mask out the ignored elements
valid_mask = ((label >= 0) & (label != ignore_index)).float()
if weight is not None:
weight = weight * valid_mask
else:
weight = valid_mask
# average loss over non-ignored and valid elements
if reduction == 'mean' and avg_factor is None and avg_non_ignore:
avg_factor = valid_mask.sum().item()
loss = F.binary_cross_entropy_with_logits(
pred, label.float(), pos_weight=class_weight, reduction='none')
# do the reduction for the weighted loss
loss = weight_reduce_loss(
loss, weight, reduction=reduction, avg_factor=avg_factor)
return loss
def mask_cross_entropy(pred,
target,
label,
reduction='mean',
avg_factor=None,
class_weight=None,
ignore_index=None,
**kwargs):
"""Calculate the CrossEntropy loss for masks.
Args:
pred (torch.Tensor): The prediction with shape (N, C), C is the number
of classes.
target (torch.Tensor): The learning label of the prediction.
label (torch.Tensor): ``label`` indicates the class label of the mask's
corresponding object. This will be used to select the mask of the
class to which the object belongs when the mask prediction is not
class-agnostic.
reduction (str, optional): The method used to reduce the loss.
Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
ignore_index (None): Placeholder, to be consistent with other loss.
Default: None.
Returns:
torch.Tensor: The calculated loss
"""
assert ignore_index is None, 'BCE loss does not support ignore_index'
# TODO: handle these two reserved arguments
assert reduction == 'mean' and avg_factor is None
num_rois = pred.size()[0]
inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
pred_slice = pred[inds, label].squeeze(1)
return F.binary_cross_entropy_with_logits(
pred_slice, target, weight=class_weight, reduction='mean')[None]
@MODELS.register_module()
class CrossEntropyLoss(nn.Module):
"""CrossEntropyLoss.
Args:
use_sigmoid (bool, optional): Whether the prediction uses sigmoid
instead of softmax. Defaults to False.
use_mask (bool, optional): Whether to use mask cross entropy loss.
Defaults to False.
reduction (str, optional): The method used to reduce the loss. Defaults to 'mean'.
Options are "none", "mean" and "sum".
class_weight (list[float] | str, optional): Weight of each class. If in
str format, read them from a file. Defaults to None.
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
loss_name (str, optional): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_ce'.
avg_non_ignore (bool): The flag decides whether the loss is
only averaged over non-ignored targets. Default: False.
`New in version 0.23.0.`
"""
def __init__(self,
use_sigmoid=False,
use_mask=False,
reduction='mean',
class_weight=None,
loss_weight=1.0,
loss_name='loss_ce',
avg_non_ignore=False):
super().__init__()
assert (use_sigmoid is False) or (use_mask is False)
self.use_sigmoid = use_sigmoid
self.use_mask = use_mask
self.reduction = reduction
self.loss_weight = loss_weight
self.class_weight = get_class_weight(class_weight)
self.avg_non_ignore = avg_non_ignore
if not self.avg_non_ignore and self.reduction == 'mean':
warnings.warn(
'Default ``avg_non_ignore`` is False. If you would like to '
'ignore certain labels and average the loss over '
'non-ignored labels, which matches the official PyTorch '
'cross_entropy, set ``avg_non_ignore=True``.')
if self.use_sigmoid:
self.cls_criterion = binary_cross_entropy
elif self.use_mask:
self.cls_criterion = mask_cross_entropy
else:
self.cls_criterion = cross_entropy
self._loss_name = loss_name
def extra_repr(self):
"""Extra repr."""
s = f'avg_non_ignore={self.avg_non_ignore}'
return s
def forward(self,
cls_score,
label,
weight=None,
avg_factor=None,
reduction_override=None,
ignore_index=-100,
**kwargs):
"""Forward function."""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if self.class_weight is not None:
class_weight = cls_score.new_tensor(self.class_weight)
else:
class_weight = None
# Note: for BCE loss, label < 0 is invalid.
loss_cls = self.loss_weight * self.cls_criterion(
cls_score,
label,
weight,
class_weight=class_weight,
reduction=reduction,
avg_factor=avg_factor,
avg_non_ignore=self.avg_non_ignore,
ignore_index=ignore_index,
**kwargs)
return loss_cls
@property
def loss_name(self):
"""Loss Name.
This function must be implemented and will return the name of this
loss function. This name will be used to combine different loss items
by simple sum operation. In addition, if you want this loss item to be
included into the backward graph, `loss_` must be the prefix of the
name.
Returns:
str: The name of this loss item.
"""
return self._loss_name
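A hedged usage sketch for the standard softmax branch, with the loss averaged over non-ignored pixels only:
import torch
from mmseg.models.losses import CrossEntropyLoss

loss_fn = CrossEntropyLoss(use_sigmoid=False, avg_non_ignore=True)
logits = torch.randn(2, 4, 8, 8)
labels = torch.randint(0, 4, (2, 8, 8))
labels[0, :2, :] = 255  # region excluded from the average
print(loss_fn(logits, labels, ignore_index=255))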

View File

@@ -0,0 +1,202 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union
import torch
import torch.nn as nn
from mmseg.registry import MODELS
from .utils import weight_reduce_loss
def _expand_onehot_labels_dice(pred: torch.Tensor,
target: torch.Tensor) -> torch.Tensor:
"""Expand onehot labels to match the size of prediction.
Args:
pred (torch.Tensor): The prediction, has a shape (N, num_class, H, W).
target (torch.Tensor): The learning label of the prediction,
has a shape (N, H, W).
Returns:
torch.Tensor: The target after one-hot encoding,
has a shape (N, num_class, H, W).
"""
num_classes = pred.shape[1]
one_hot_target = torch.clamp(target, min=0, max=num_classes)
one_hot_target = torch.nn.functional.one_hot(one_hot_target,
num_classes + 1)
one_hot_target = one_hot_target[..., :num_classes].permute(0, 3, 1, 2)
return one_hot_target
def dice_loss(pred: torch.Tensor,
target: torch.Tensor,
weight: Union[torch.Tensor, None],
eps: float = 1e-3,
reduction: Union[str, None] = 'mean',
naive_dice: Union[bool, None] = False,
avg_factor: Union[int, None] = None,
ignore_index: Union[int, None] = 255) -> float:
"""Calculate dice loss, there are two forms of dice loss is supported:
- the one proposed in `V-Net: Fully Convolutional Neural
Networks for Volumetric Medical Image Segmentation
<https://arxiv.org/abs/1606.04797>`_.
- the dice loss in which the power of the number in the
denominator is the first power instead of the second
power.
Args:
pred (torch.Tensor): The prediction, has a shape (n, *)
target (torch.Tensor): The learning label of the prediction,
shape (n, *), same shape of pred.
weight (torch.Tensor, optional): The weight of loss for each
prediction, has a shape (n,). Defaults to None.
eps (float): Avoid dividing by zero. Default: 1e-3.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'.
Options are "none", "mean" and "sum".
naive_dice (bool, optional): If false, use the dice
loss defined in the V-Net paper, otherwise, use the
naive dice loss in which the power of the number in the
denominator is the first power instead of the second
power. Defaults to False.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
ignore_index (int, optional): The label index to be ignored.
Defaults to 255.
"""
if ignore_index is not None:
num_classes = pred.shape[1]
pred = pred[:, torch.arange(num_classes) != ignore_index, :, :]
target = target[:, torch.arange(num_classes) != ignore_index, :, :]
assert pred.shape[1] != 0 # if the ignored index is the only class
input = pred.flatten(1)
target = target.flatten(1).float()
a = torch.sum(input * target, 1)
if naive_dice:
b = torch.sum(input, 1)
c = torch.sum(target, 1)
d = (2 * a + eps) / (b + c + eps)
else:
b = torch.sum(input * input, 1) + eps
c = torch.sum(target * target, 1) + eps
d = (2 * a) / (b + c)
loss = 1 - d
if weight is not None:
assert weight.ndim == loss.ndim
assert len(weight) == len(pred)
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
@MODELS.register_module()
class DiceLoss(nn.Module):
def __init__(self,
use_sigmoid=True,
activate=True,
reduction='mean',
naive_dice=False,
loss_weight=1.0,
ignore_index=255,
eps=1e-3,
loss_name='loss_dice'):
"""Compute dice loss.
Args:
use_sigmoid (bool, optional): Whether the prediction uses sigmoid
instead of softmax. Defaults to True.
activate (bool): Whether to apply the activation (sigmoid or
softmax) to the predictions inside this loss. If False, the
predictions are assumed to be activated already. Defaults to True.
reduction (str, optional): The method used
to reduce the loss. Options are "none",
"mean" and "sum". Defaults to 'mean'.
naive_dice (bool, optional): If false, use the dice
loss defined in the V-Net paper, otherwise, use the
naive dice loss in which the power of the number in the
denominator is the first power instead of the second
power. Defaults to False.
loss_weight (float, optional): Weight of loss. Defaults to 1.0.
ignore_index (int, optional): The label index to be ignored.
Default: 255.
eps (float): Avoid dividing by zero. Defaults to 1e-3.
loss_name (str, optional): Name of the loss item. If you want this
loss item to be included into the backward graph, `loss_` must
be the prefix of the name. Defaults to 'loss_dice'.
"""
super().__init__()
self.use_sigmoid = use_sigmoid
self.reduction = reduction
self.naive_dice = naive_dice
self.loss_weight = loss_weight
self.eps = eps
self.activate = activate
self.ignore_index = ignore_index
self._loss_name = loss_name
def forward(self,
pred,
target,
weight=None,
avg_factor=None,
reduction_override=None,
ignore_index=255,
**kwargs):
"""Forward function.
Args:
pred (torch.Tensor): The prediction, has a shape (n, *).
target (torch.Tensor): The label of the prediction,
shape (n, *), same shape of pred.
weight (torch.Tensor, optional): The weight of loss for each
prediction, has a shape (n,). Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Options are "none", "mean" and "sum".
Returns:
torch.Tensor: The calculated loss
"""
one_hot_target = target
if (pred.shape != target.shape):
one_hot_target = _expand_onehot_labels_dice(pred, target)
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if self.activate:
if self.use_sigmoid:
pred = pred.sigmoid()
elif pred.shape[1] != 1:
# softmax does not work when there is only 1 class
pred = pred.softmax(dim=1)
loss = self.loss_weight * dice_loss(
pred,
one_hot_target,
weight,
eps=self.eps,
reduction=reduction,
naive_dice=self.naive_dice,
avg_factor=avg_factor,
ignore_index=self.ignore_index)
return loss
@property
def loss_name(self):
"""Loss Name.
This function must be implemented and will return the name of this
loss function. This name will be used to combine different loss items
by simple sum operation. In addition, if you want this loss item to be
included into the backward graph, `loss_` must be the prefix of the
name.
Returns:
str: The name of this loss item.
"""
return self._loss_name
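A hedged usage sketch for the multi-class case; hard labels of shape (N, H, W) are one-hot encoded internally before the dice computation:
import torch
from mmseg.models.losses import DiceLoss

loss_fn = DiceLoss(use_sigmoid=False, naive_dice=False, ignore_index=255)
logits = torch.randn(2, 4, 8, 8)  # softmax is applied inside (activate=True)
labels = torch.randint(0, 4, (2, 8, 8))
print(loss_fn(logits, labels))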

View File

@@ -0,0 +1,337 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/open-mmlab/mmdetection
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss
from mmseg.registry import MODELS
from .utils import weight_reduce_loss
# This method is used when cuda is not available
def py_sigmoid_focal_loss(pred,
target,
one_hot_target=None,
weight=None,
gamma=2.0,
alpha=0.5,
class_weight=None,
valid_mask=None,
reduction='mean',
avg_factor=None):
"""PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.
Args:
pred (torch.Tensor): The prediction with shape (N, C), C is the
number of classes
target (torch.Tensor): The learning label of the prediction with
shape (N, C)
one_hot_target (None): Placeholder. It should be None.
weight (torch.Tensor, optional): Sample-wise loss weight.
gamma (float, optional): The gamma for calculating the modulating
factor. Defaults to 2.0.
alpha (float | list[float], optional): A balanced form for Focal Loss.
Defaults to 0.5.
class_weight (list[float], optional): Weight of each class.
Defaults to None.
valid_mask (torch.Tensor, optional): A mask uses 1 to mark the valid
samples and uses 0 to mark the ignored samples. Default: None.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
"""
if isinstance(alpha, list):
alpha = pred.new_tensor(alpha)
pred_sigmoid = pred.sigmoid()
target = target.type_as(pred)
one_minus_pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
focal_weight = (alpha * target + (1 - alpha) *
(1 - target)) * one_minus_pt.pow(gamma)
loss = F.binary_cross_entropy_with_logits(
pred, target, reduction='none') * focal_weight
final_weight = torch.ones(1, pred.size(1)).type_as(loss)
if weight is not None:
if weight.shape != loss.shape and weight.size(0) == loss.size(0):
# For most cases, weight is of shape (N, ),
# which means it does not have the second axis num_class
weight = weight.view(-1, 1)
assert weight.dim() == loss.dim()
final_weight = final_weight * weight
if class_weight is not None:
final_weight = final_weight * pred.new_tensor(class_weight)
if valid_mask is not None:
final_weight = final_weight * valid_mask
loss = weight_reduce_loss(loss, final_weight, reduction, avg_factor)
return loss
def sigmoid_focal_loss(pred,
target,
one_hot_target,
weight=None,
gamma=2.0,
alpha=0.5,
class_weight=None,
valid_mask=None,
reduction='mean',
avg_factor=None):
r"""A wrapper of cuda version `Focal Loss
<https://arxiv.org/abs/1708.02002>`_.
Args:
pred (torch.Tensor): The prediction with shape (N, C), C is the number
of classes.
target (torch.Tensor): The learning label of the prediction. It's shape
should be (N, )
one_hot_target (torch.Tensor): The learning label with shape (N, C)
weight (torch.Tensor, optional): Sample-wise loss weight.
gamma (float, optional): The gamma for calculating the modulating
factor. Defaults to 2.0.
alpha (float | list[float], optional): A balanced form for Focal Loss.
Defaults to 0.5.
class_weight (list[float], optional): Weight of each class.
Defaults to None.
valid_mask (torch.Tensor, optional): A mask uses 1 to mark the valid
samples and uses 0 to mark the ignored samples. Default: None.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
"""
# Function.apply does not accept keyword arguments, so the decorator
# "weighted_loss" is not applicable
final_weight = torch.ones(1, pred.size(1)).type_as(pred)
if isinstance(alpha, list):
# _sigmoid_focal_loss doesn't accept alpha of list type. Therefore, if
# a list is given, we set the input alpha as 0.5. This means setting
# equal weight for foreground class and background class. By
# multiplying the loss by 2, the effect of setting alpha as 0.5 is
# undone. The alpha of type list is used to regulate the loss in the
# post-processing process.
loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(),
gamma, 0.5, None, 'none') * 2
alpha = pred.new_tensor(alpha)
final_weight = final_weight * (
alpha * one_hot_target + (1 - alpha) * (1 - one_hot_target))
else:
loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(),
gamma, alpha, None, 'none')
if weight is not None:
if weight.shape != loss.shape and weight.size(0) == loss.size(0):
# For most cases, weight is of shape (N, ),
# which means it does not have the second axis num_class
weight = weight.view(-1, 1)
assert weight.dim() == loss.dim()
final_weight = final_weight * weight
if class_weight is not None:
final_weight = final_weight * pred.new_tensor(class_weight)
if valid_mask is not None:
final_weight = final_weight * valid_mask
loss = weight_reduce_loss(loss, final_weight, reduction, avg_factor)
return loss
@MODELS.register_module()
class FocalLoss(nn.Module):
def __init__(self,
use_sigmoid=True,
gamma=2.0,
alpha=0.5,
reduction='mean',
class_weight=None,
loss_weight=1.0,
loss_name='loss_focal'):
"""`Focal Loss <https://arxiv.org/abs/1708.02002>`_
Args:
use_sigmoid (bool, optional): Whether the prediction uses sigmoid
instead of softmax. Only sigmoid is currently supported.
Defaults to True.
gamma (float, optional): The gamma for calculating the modulating
factor. Defaults to 2.0.
alpha (float | list[float], optional): A balanced form for Focal
Loss. Defaults to 0.5. When a list is provided, the length
of the list should be equal to the number of classes.
Please be careful that this parameter is not the
class-wise weight but the weight of a binary classification
problem. This binary classification problem regards the
pixels which belong to one class as the foreground
and the other pixels as the background, each element in
the list is the weight of the corresponding foreground class.
The value of alpha or each element of alpha should be a float
in the interval [0, 1]. If you want to specify the class-wise
weight, please use `class_weight` parameter.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'. Options are "none", "mean" and
"sum".
class_weight (list[float], optional): Weight of each class.
Defaults to None.
loss_weight (float, optional): Weight of loss. Defaults to 1.0.
loss_name (str, optional): Name of the loss item. If you want this
loss item to be included into the backward graph, `loss_` must
be the prefix of the name. Defaults to 'loss_focal'.
"""
super().__init__()
assert use_sigmoid is True, \
'AssertionError: Only sigmoid focal loss supported now.'
assert reduction in ('none', 'mean', 'sum'), \
"AssertionError: reduction should be 'none', 'mean' or " \
"'sum'"
assert isinstance(alpha, (float, list)), \
'AssertionError: alpha should be of type float or list'
assert isinstance(gamma, float), \
'AssertionError: gamma should be of type float'
assert isinstance(loss_weight, float), \
'AssertionError: loss_weight should be of type float'
assert isinstance(loss_name, str), \
'AssertionError: loss_name should be of type str'
assert isinstance(class_weight, list) or class_weight is None, \
'AssertionError: class_weight must be None or of type list'
self.use_sigmoid = use_sigmoid
self.gamma = gamma
self.alpha = alpha
self.reduction = reduction
self.class_weight = class_weight
self.loss_weight = loss_weight
self._loss_name = loss_name
def forward(self,
pred,
target,
weight=None,
avg_factor=None,
reduction_override=None,
ignore_index=255,
**kwargs):
"""Forward function.
Args:
pred (torch.Tensor): The prediction with shape
(N, C) where C = number of classes, or
(N, C, d_1, d_2, ..., d_K) with K≥1 in the
case of K-dimensional loss.
target (torch.Tensor): The ground truth. If containing class
indices, shape (N) where each value is 0 <= targets[i] <= C-1,
or (N, d_1, d_2, ..., d_K) with K≥1 in the case of
K-dimensional loss. If containing class probabilities,
same shape as the input.
weight (torch.Tensor, optional): The weight of loss for each
prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to
average the loss. Defaults to None.
reduction_override (str, optional): The reduction method used
to override the original reduction method of the loss.
Options are "none", "mean" and "sum".
ignore_index (int, optional): The label index to be ignored.
Default: 255
Returns:
torch.Tensor: The calculated loss
"""
assert isinstance(ignore_index, int), \
'ignore_index must be of type int'
assert reduction_override in (None, 'none', 'mean', 'sum'), \
"AssertionError: reduction should be 'none', 'mean' or " \
"'sum'"
assert pred.shape == target.shape or \
(pred.size(0) == target.size(0) and
pred.shape[2:] == target.shape[1:]), \
"The shape of pred doesn't match the shape of target"
original_shape = pred.shape
# [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k]
pred = pred.transpose(0, 1)
# [C, B, d_1, d_2, ..., d_k] -> [C, N]
pred = pred.reshape(pred.size(0), -1)
# [C, N] -> [N, C]
pred = pred.transpose(0, 1).contiguous()
if original_shape == target.shape:
# target with shape [B, C, d_1, d_2, ...]
# transform its shape into [N, C]
# [B, C, d_1, d_2, ...] -> [C, B, d_1, d_2, ..., d_k]
target = target.transpose(0, 1)
# [C, B, d_1, d_2, ..., d_k] -> [C, N]
target = target.reshape(target.size(0), -1)
# [C, N] -> [N, C]
target = target.transpose(0, 1).contiguous()
else:
# target with shape [B, d_1, d_2, ...]
# transform its shape into [N, ]
target = target.view(-1).contiguous()
valid_mask = (target != ignore_index).view(-1, 1)
# avoid raising error when using F.one_hot()
target = torch.where(target == ignore_index, target.new_tensor(0),
target)
reduction = (
reduction_override if reduction_override else self.reduction)
if self.use_sigmoid:
num_classes = pred.size(1)
if torch.cuda.is_available() and pred.is_cuda:
if target.dim() == 1:
one_hot_target = F.one_hot(
target, num_classes=num_classes + 1)
if num_classes == 1:
one_hot_target = one_hot_target[:, 1]
target = 1 - target
else:
one_hot_target = one_hot_target[:, :num_classes]
else:
one_hot_target = target
target = target.argmax(dim=1)
valid_mask = (target != ignore_index).view(-1, 1)
calculate_loss_func = sigmoid_focal_loss
else:
one_hot_target = None
if target.dim() == 1:
target = F.one_hot(target, num_classes=num_classes + 1)
if num_classes == 1:
target = target[:, 1]
else:
# keep the real classes only, dropping the extra
# one-hot column created for ignore_index
target = target[:, :num_classes]
else:
valid_mask = (target.argmax(dim=1) != ignore_index).view(
-1, 1)
calculate_loss_func = py_sigmoid_focal_loss
loss_cls = self.loss_weight * calculate_loss_func(
pred,
target,
one_hot_target,
weight,
gamma=self.gamma,
alpha=self.alpha,
class_weight=self.class_weight,
valid_mask=valid_mask,
reduction=reduction,
avg_factor=avg_factor)
if reduction == 'none':
# [N, C] -> [C, N]
loss_cls = loss_cls.transpose(0, 1)
# [C, N] -> [C, B, d1, d2, ...]
# original_shape: [B, C, d1, d2, ...]
loss_cls = loss_cls.reshape(original_shape[1],
original_shape[0],
*original_shape[2:])
# [C, B, d1, d2, ...] -> [B, C, d1, d2, ...]
loss_cls = loss_cls.transpose(0, 1).contiguous()
else:
raise NotImplementedError
return loss_cls
@property
def loss_name(self):
"""Loss Name.
This function must be implemented and will return the name of this
loss function. This name will be used to combine different loss items
by simple sum operation. In addition, if you want this loss item to be
included into the backward graph, `loss_` must be the prefix of the
name.
Returns:
str: The name of this loss item.
"""
return self._loss_name
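A hedged usage sketch; on CPU tensors the forward falls back to py_sigmoid_focal_loss, while CUDA tensors use the mmcv op:
import torch
from mmseg.models.losses import FocalLoss

loss_fn = FocalLoss(gamma=2.0, alpha=0.5)
logits = torch.randn(2, 4, 8, 8)
labels = torch.randint(0, 4, (2, 8, 8))
print(loss_fn(logits, labels, ignore_index=255))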

View File

@@ -0,0 +1,160 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from https://github.com/JunMa11/SegWithDistMap/blob/
master/code/train_LA_HD.py (Apache-2.0 License)"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.ndimage import distance_transform_edt as distance
from torch import Tensor
from mmseg.registry import MODELS
from .utils import get_class_weight, weighted_loss
def compute_dtm(img_gt: Tensor, pred: Tensor) -> Tensor:
"""
Compute the distance transform map of the foreground in the mask.
Args:
img_gt: Ground truth of the image, (b, h, w)
pred: Predictions of the segmentation head after softmax, (b, c, h, w)
Returns:
output: the foreground distance transform map (DTM), defined as
dtm(x) = 0 if x lies on the segmentation boundary, and
dtm(x) = inf_{y in boundary} |x - y| if x lies inside the segmentation
"""
fg_dtm = torch.zeros_like(pred)
out_shape = pred.shape
for b in range(out_shape[0]): # batch size
for c in range(1, out_shape[1]): # default 0 channel is background
posmask = img_gt[b].byte()
if posmask.any():
posdis = distance(posmask)
fg_dtm[b][c] = torch.from_numpy(posdis)
return fg_dtm
@weighted_loss
def hd_loss(seg_soft: Tensor,
gt: Tensor,
seg_dtm: Tensor,
gt_dtm: Tensor,
class_weight=None,
ignore_index=255) -> Tensor:
"""
Compute Hausdorff distance loss for segmentation.
Args:
seg_soft: softmax results, shape=(b,c,x,y)
gt: ground truth, shape=(b,x,y)
seg_dtm: segmentation distance transform map, shape=(b,c,x,y)
gt_dtm: ground truth distance transform map, shape=(b,c,x,y)
class_weight (Tensor, optional): Weight of each class. Default: None.
ignore_index (int): The label index to be ignored. Default: 255.
Returns:
output: hd_loss
"""
assert seg_soft.shape[0] == gt.shape[0]
total_loss = 0
num_class = seg_soft.shape[1]
if class_weight is not None:
assert len(class_weight) == num_class
for i in range(1, num_class):
if i != ignore_index:
delta_s = (seg_soft[:, i, ...] - gt.float())**2
s_dtm = seg_dtm[:, i, ...]**2
g_dtm = gt_dtm[:, i, ...]**2
dtm = s_dtm + g_dtm
multiplied = torch.einsum('bxy, bxy->bxy', delta_s, dtm)
hd_loss = multiplied.mean()
if class_weight is not None:
hd_loss *= class_weight[i]
total_loss += hd_loss
return total_loss / num_class
@MODELS.register_module()
class HuasdorffDisstanceLoss(nn.Module):
"""HuasdorffDisstanceLoss. This loss is proposed in `How Distance Transform
Maps Boost Segmentation CNNs: An Empirical Study.
<http://proceedings.mlr.press/v121/ma20b.html>`_.
Args:
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'.
class_weight (list[float] | str, optional): Weight of each class. If in
str format, read them from a file. Defaults to None.
loss_weight (float): Weight of the loss. Defaults to 1.0.
ignore_index (int | None): The label index to be ignored. Default: 255.
loss_name (str): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_huasdorff_disstance'.
"""
def __init__(self,
reduction='mean',
class_weight=None,
loss_weight=1.0,
ignore_index=255,
loss_name='loss_huasdorff_disstance',
**kwargs):
super().__init__()
self.reduction = reduction
self.loss_weight = loss_weight
self.class_weight = get_class_weight(class_weight)
self._loss_name = loss_name
self.ignore_index = ignore_index
def forward(self,
pred: Tensor,
target: Tensor,
avg_factor=None,
reduction_override=None,
**kwargs) -> Tensor:
"""Forward function.
Args:
pred (Tensor): Predictions of the segmentation head. (B, C, H, W)
target (Tensor): Ground truth of the image. (B, H, W)
avg_factor (int, optional): Average factor that is used to
average the loss. Defaults to None.
reduction_override (str, optional): The reduction method used
to override the original reduction method of the loss.
Options are "none", "mean" and "sum".
Returns:
Tensor: Loss tensor.
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if self.class_weight is not None:
class_weight = pred.new_tensor(self.class_weight)
else:
class_weight = None
pred_soft = F.softmax(pred, dim=1)
valid_mask = (target != self.ignore_index).long()
target = target * valid_mask
with torch.no_grad():
gt_dtm = compute_dtm(target.cpu(), pred_soft)
gt_dtm = gt_dtm.float()
seg_dtm2 = compute_dtm(
pred_soft.argmax(dim=1, keepdim=False).cpu(), pred_soft)
seg_dtm2 = seg_dtm2.float()
loss_hd = self.loss_weight * hd_loss(
pred_soft,
target,
seg_dtm=seg_dtm2,
gt_dtm=gt_dtm,
reduction=reduction,
avg_factor=avg_factor,
class_weight=class_weight,
ignore_index=self.ignore_index)
return loss_hd
@property
def loss_name(self):
return self._loss_name
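A hedged usage sketch; the distance transform maps are computed with SciPy on the CPU, so no CUDA is required:
import torch
from mmseg.models.losses import HuasdorffDisstanceLoss

loss_fn = HuasdorffDisstanceLoss(loss_weight=1.0, ignore_index=255)
logits = torch.randn(2, 3, 16, 16)
labels = torch.randint(0, 3, (2, 16, 16))
print(loss_fn(logits, labels))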

View File

@@ -0,0 +1,99 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmseg.registry import MODELS
@MODELS.register_module()
class KLDivLoss(nn.Module):
def __init__(self,
temperature: float = 1.0,
reduction: str = 'mean',
loss_name: str = 'loss_kld'):
"""Kullback-Leibler divergence Loss.
<https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>
Args:
temperature (float, optional): Temperature parameter used to soften
the logits. Defaults to 1.0.
reduction (str, optional): The method to reduce the loss into a
scalar. Default is "mean". Options are "none", "sum",
and "mean"
loss_name (str, optional): Name of the loss item. Defaults to
'loss_kld'.
"""
assert isinstance(temperature, (float, int)), \
'Expected temperature to be ' \
f'float or int, but got {temperature.__class__.__name__} instead'
assert temperature != 0., 'Temperature must not be zero'
assert reduction in ['mean', 'none', 'sum'], \
'Reduction must be one of the options ("mean", ' \
f'"sum", "none"), but got {reduction}'
super().__init__()
self.temperature = temperature
self.reduction = reduction
self._loss_name = loss_name
def forward(self, input: torch.Tensor, target: torch.Tensor):
"""Forward function. Calculate KL divergence Loss.
Args:
input (Tensor): Logit tensor,
the data type is float32 or float64.
The shape is (N, C) where N is batchsize and C is number of
channels.
If there are more than 2 dimensions, the shape is (N, C, D1, D2,
..., Dk), k >= 1
target (Tensor): Logit tensor,
the data type is float32 or float64.
input and target must be with the same shape.
Returns:
(Tensor): Reduced loss.
"""
assert isinstance(input, torch.Tensor), 'Expected input to ' \
f'be Tensor, but got {input.__class__.__name__} instead'
assert isinstance(target, torch.Tensor), 'Expected target to ' \
f'be Tensor, but got {target.__class__.__name__} instead'
assert input.shape == target.shape, 'Input and target ' \
'must have same shape, ' \
f'but got shapes {input.shape} and {target.shape}'
# F.kl_div expects log-probabilities as input when log_target=False
input = F.log_softmax(input / self.temperature, dim=1)
target = F.softmax(target / self.temperature, dim=1)
loss = F.kl_div(input, target, reduction='none', log_target=False)
loss = loss * self.temperature**2
batch_size = input.shape[0]
if self.reduction == 'sum':
# Change view to calculate instance-wise sum
loss = loss.view(batch_size, -1)
return torch.sum(loss, dim=1)
elif self.reduction == 'mean':
# Change view to calculate instance-wise mean
loss = loss.view(batch_size, -1)
return torch.mean(loss, dim=1)
return loss
@property
def loss_name(self):
"""Loss Name.
This function must be implemented and will return the name of this
loss function. This name will be used to combine different loss items
by simple sum operation. In addition, if you want this loss item to be
included into the backward graph, `loss_` must be the prefix of the
name.
Returns:
str: The name of this loss item.
"""
return self._loss_name
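A hedged usage sketch built through the registry (this class is not re-exported in the package __init__ shown above); with reduction='mean' the result is a per-instance mean of shape (N,):
import torch
from mmseg.registry import MODELS

kld = MODELS.build(dict(type='KLDivLoss', temperature=2.0, reduction='mean'))
student = torch.randn(4, 10)  # student logits
teacher = torch.randn(4, 10)  # teacher logits
print(kld(student, teacher))  # shape (4,)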

View File

@@ -0,0 +1,323 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytor
ch/lovasz_losses.py Lovasz-Softmax and Jaccard hinge loss in PyTorch Maxim
Berman 2018 ESAT-PSI KU Leuven (MIT License)"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.utils import is_list_of
from mmseg.registry import MODELS
from .utils import get_class_weight, weight_reduce_loss
def lovasz_grad(gt_sorted):
"""Computes gradient of the Lovasz extension w.r.t sorted errors.
See Alg. 1 in paper.
"""
p = len(gt_sorted)
gts = gt_sorted.sum()
intersection = gts - gt_sorted.float().cumsum(0)
union = gts + (1 - gt_sorted).float().cumsum(0)
jaccard = 1. - intersection / union
if p > 1: # cover 1-pixel case
jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
return jaccard
def flatten_binary_logits(logits, labels, ignore_index=None):
"""Flattens predictions in the batch (binary case) Remove labels equal to
'ignore_index'."""
logits = logits.view(-1)
labels = labels.view(-1)
if ignore_index is None:
return logits, labels
valid = (labels != ignore_index)
vlogits = logits[valid]
vlabels = labels[valid]
return vlogits, vlabels
def flatten_probs(probs, labels, ignore_index=None):
"""Flattens predictions in the batch."""
if probs.dim() == 3:
# assumes output of a sigmoid layer
B, H, W = probs.size()
probs = probs.view(B, 1, H, W)
B, C, H, W = probs.size()
probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, C)  # (B*H*W, C) = (P, C)
labels = labels.view(-1)
if ignore_index is None:
return probs, labels
valid = (labels != ignore_index)
vprobs = probs[valid.nonzero().squeeze()]
vlabels = labels[valid]
return vprobs, vlabels
def lovasz_hinge_flat(logits, labels):
"""Binary Lovasz hinge loss.
Args:
logits (torch.Tensor): [P], logits at each prediction
(between -infty and +infty).
labels (torch.Tensor): [P], binary ground truth labels (0 or 1).
Returns:
torch.Tensor: The calculated loss.
"""
if len(labels) == 0:
# only void pixels, the gradients should be 0
return logits.sum() * 0.
signs = 2. * labels.float() - 1.
errors = (1. - logits * signs)
errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
perm = perm.data
gt_sorted = labels[perm]
grad = lovasz_grad(gt_sorted)
loss = torch.dot(F.relu(errors_sorted), grad)
return loss
def lovasz_hinge(logits,
labels,
classes='present',
per_image=False,
class_weight=None,
reduction='mean',
avg_factor=None,
ignore_index=255):
"""Binary Lovasz hinge loss.
Args:
logits (torch.Tensor): [B, H, W], logits at each pixel
(between -infty and +infty).
labels (torch.Tensor): [B, H, W], binary ground truth masks (0 or 1).
classes (str | list[int], optional): Placeholder, to be consistent with
other loss. Default: 'present'.
per_image (bool, optional): If per_image is True, compute the loss per
image instead of per batch. Default: False.
class_weight (list[float], optional): Placeholder, to be consistent
with other loss. Default: None.
reduction (str, optional): The method used to reduce the loss. Options
are "none", "mean" and "sum". This parameter only works when
per_image is True. Default: 'mean'.
avg_factor (int, optional): Average factor that is used to average
the loss. This parameter only works when per_image is True.
Default: None.
ignore_index (int | None): The label index to be ignored. Default: 255.
Returns:
torch.Tensor: The calculated loss.
"""
if per_image:
loss = [
lovasz_hinge_flat(*flatten_binary_logits(
logit.unsqueeze(0), label.unsqueeze(0), ignore_index))
for logit, label in zip(logits, labels)
]
loss = weight_reduce_loss(
torch.stack(loss), None, reduction, avg_factor)
else:
loss = lovasz_hinge_flat(
*flatten_binary_logits(logits, labels, ignore_index))
return loss
def lovasz_softmax_flat(probs, labels, classes='present', class_weight=None):
"""Multi-class Lovasz-Softmax loss.
Args:
probs (torch.Tensor): [P, C], class probabilities at each prediction
(between 0 and 1).
labels (torch.Tensor): [P], ground truth labels (between 0 and C - 1).
classes (str | list[int], optional): Classes chosen to calculate loss.
'all' for all classes, 'present' for classes present in labels, or
a list of classes to average. Default: 'present'.
class_weight (list[float], optional): The weight for each class.
Default: None.
Returns:
torch.Tensor: The calculated loss.
"""
if probs.numel() == 0:
# only void pixels, the gradients should be 0
return probs * 0.
C = probs.size(1)
losses = []
class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
for c in class_to_sum:
fg = (labels == c).float() # foreground for class c
if (classes == 'present' and fg.sum() == 0):
continue
if C == 1:
if len(classes) > 1:
raise ValueError('Sigmoid output possible only with 1 class')
class_pred = probs[:, 0]
else:
class_pred = probs[:, c]
errors = (fg - class_pred).abs()
errors_sorted, perm = torch.sort(errors, 0, descending=True)
perm = perm.data
fg_sorted = fg[perm]
loss = torch.dot(errors_sorted, lovasz_grad(fg_sorted))
if class_weight is not None:
loss *= class_weight[c]
losses.append(loss)
return torch.stack(losses).mean()
def lovasz_softmax(probs,
labels,
classes='present',
per_image=False,
class_weight=None,
reduction='mean',
avg_factor=None,
ignore_index=255):
"""Multi-class Lovasz-Softmax loss.
Args:
probs (torch.Tensor): [B, C, H, W], class probabilities at each
prediction (between 0 and 1).
labels (torch.Tensor): [B, H, W], ground truth labels (between 0 and
C - 1).
classes (str | list[int], optional): Classes chosen to calculate loss.
'all' for all classes, 'present' for classes present in labels, or
a list of classes to average. Default: 'present'.
per_image (bool, optional): If per_image is True, compute the loss per
image instead of per batch. Default: False.
class_weight (list[float], optional): The weight for each class.
Default: None.
reduction (str, optional): The method used to reduce the loss. Options
are "none", "mean" and "sum". This parameter only works when
per_image is True. Default: 'mean'.
avg_factor (int, optional): Average factor that is used to average
the loss. This parameter only works when per_image is True.
Default: None.
ignore_index (int | None): The label index to be ignored. Default: 255.
Returns:
torch.Tensor: The calculated loss.
"""
if per_image:
loss = [
lovasz_softmax_flat(
*flatten_probs(
prob.unsqueeze(0), label.unsqueeze(0), ignore_index),
classes=classes,
class_weight=class_weight)
for prob, label in zip(probs, labels)
]
loss = weight_reduce_loss(
torch.stack(loss), None, reduction, avg_factor)
else:
loss = lovasz_softmax_flat(
*flatten_probs(probs, labels, ignore_index),
classes=classes,
class_weight=class_weight)
return loss
@MODELS.register_module()
class LovaszLoss(nn.Module):
"""LovaszLoss.
This loss is proposed in `The Lovasz-Softmax loss: A tractable surrogate
for the optimization of the intersection-over-union measure in neural
networks <https://arxiv.org/abs/1705.08790>`_.
Args:
loss_type (str, optional): Binary or multi-class loss.
Default: 'multi_class'. Options are "binary" and "multi_class".
classes (str | list[int], optional): Classes chosen to calculate loss.
'all' for all classes, 'present' for classes present in labels, or
a list of classes to average. Default: 'present'.
per_image (bool, optional): If per_image is True, compute the loss per
image instead of per batch. Default: False.
reduction (str, optional): The method used to reduce the loss. Options
are "none", "mean" and "sum". This parameter only works when
per_image is True. Default: 'mean'.
class_weight (list[float] | str, optional): Weight of each class. If in
str format, read them from a file. Defaults to None.
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
loss_name (str, optional): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_lovasz'.
"""
def __init__(self,
loss_type='multi_class',
classes='present',
per_image=False,
reduction='mean',
class_weight=None,
loss_weight=1.0,
loss_name='loss_lovasz'):
super().__init__()
assert loss_type in ('binary', 'multi_class'), "loss_type should be \
'binary' or 'multi_class'."
if loss_type == 'binary':
self.cls_criterion = lovasz_hinge
else:
self.cls_criterion = lovasz_softmax
assert classes in ('all', 'present') or is_list_of(classes, int)
if not per_image:
assert reduction == 'none', "reduction should be 'none' when \
per_image is False."
self.classes = classes
self.per_image = per_image
self.reduction = reduction
self.loss_weight = loss_weight
self.class_weight = get_class_weight(class_weight)
self._loss_name = loss_name
def forward(self,
cls_score,
label,
weight=None,
avg_factor=None,
reduction_override=None,
**kwargs):
"""Forward function."""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if self.class_weight is not None:
class_weight = cls_score.new_tensor(self.class_weight)
else:
class_weight = None
# if multi-class loss, transform logits to probs
if self.cls_criterion == lovasz_softmax:
cls_score = F.softmax(cls_score, dim=1)
loss_cls = self.loss_weight * self.cls_criterion(
cls_score,
label,
self.classes,
self.per_image,
class_weight=class_weight,
reduction=reduction,
avg_factor=avg_factor,
**kwargs)
return loss_cls
@property
def loss_name(self):
"""Loss Name.
This function must be implemented and will return the name of this
loss function. This name will be used to combine different loss items
by simple sum operation. In addition, if you want this loss item to be
included into the backward graph, `loss_` must be the prefix of the
name.
Returns:
str: The name of this loss item.
"""
return self._loss_name
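A hedged usage sketch for the multi-class variant; note the constructor requires reduction='none' whenever per_image=False:
import torch
from mmseg.models.losses import LovaszLoss

loss_fn = LovaszLoss(loss_type='multi_class', per_image=False, reduction='none')
logits = torch.randn(2, 4, 8, 8)  # softmax is applied inside for multi_class
labels = torch.randint(0, 4, (2, 8, 8))
print(loss_fn(logits, labels))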

View File

@@ -0,0 +1,94 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from mmseg.registry import MODELS
@MODELS.register_module()
class OhemCrossEntropy(nn.Module):
"""OhemCrossEntropy loss.
This func is modified from
`PIDNet <https://github.com/XuJiacong/PIDNet/blob/main/utils/criterion.py#L43>`_. # noqa
Licensed under the MIT License.
Args:
ignore_label (int): Labels to ignore when computing the loss.
Default: 255
thres (float, optional): The threshold for hard example selection.
Predictions with confidence below it are regarded as hard
examples. If not specified, the hard examples will be the pixels
with the top ``min_kept`` losses. Default: 0.7.
min_kept (int, optional): The minimum number of predictions to keep.
Default: 100000.
loss_weight (float): Weight of the loss. Defaults to 1.0.
class_weight (list[float] | str, optional): Weight of each class. If in
str format, read them from a file. Defaults to None.
loss_name (str): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_ohem'.
"""
def __init__(self,
ignore_label: int = 255,
thres: float = 0.7,
min_kept: int = 100000,
loss_weight: float = 1.0,
class_weight: Optional[Union[List[float], str]] = None,
loss_name: str = 'loss_ohem'):
super().__init__()
self.thresh = thres
self.min_kept = max(1, min_kept)
self.ignore_label = ignore_label
self.loss_weight = loss_weight
self.loss_name_ = loss_name
self.class_weight = class_weight
def forward(self, score: Tensor, target: Tensor) -> Tensor:
"""Forward function.
Args:
score (Tensor): Predictions of the segmentation head.
target (Tensor): Ground truth of the image.
Returns:
Tensor: Loss tensor.
"""
# score: (N, C, H, W)
pred = F.softmax(score, dim=1)
if self.class_weight is not None:
class_weight = score.new_tensor(self.class_weight)
else:
class_weight = None
pixel_losses = F.cross_entropy(
score,
target,
weight=class_weight,
ignore_index=self.ignore_label,
reduction='none').contiguous().view(-1) # (N*H*W)
mask = target.contiguous().view(-1) != self.ignore_label # (N*H*W)
tmp_target = target.clone() # (N, H, W)
tmp_target[tmp_target == self.ignore_label] = 0
# pred: (N, C, H, W) -> (N, 1, H, W)
pred = pred.gather(1, tmp_target.unsqueeze(1))
# pred: (N, 1, H, W) -> (num_valid,), ind: (num_valid,)
pred, ind = pred.contiguous().view(-1, )[mask].contiguous().sort()
if pred.numel() > 0:
min_value = pred[min(self.min_kept, pred.numel() - 1)]
else:
return score.new_tensor(0.0)
threshold = max(min_value, self.thresh)
pixel_losses = pixel_losses[mask][ind]
pixel_losses = pixel_losses[pred < threshold]
return self.loss_weight * pixel_losses.mean()
@property
def loss_name(self):
return self.loss_name_
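A hedged usage sketch; only pixels whose predicted ground-truth probability falls below the selection threshold contribute to the mean:
import torch
from mmseg.models.losses import OhemCrossEntropy

loss_fn = OhemCrossEntropy(thres=0.7, min_kept=100, ignore_label=255)
logits = torch.randn(2, 4, 16, 16)
labels = torch.randint(0, 4, (2, 16, 16))
print(loss_fn(logits, labels))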

View File

@@ -0,0 +1,122 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union
import torch
import torch.nn as nn
from torch import Tensor
from mmseg.registry import MODELS
from .utils import weight_reduce_loss
def silog_loss(pred: Tensor,
target: Tensor,
weight: Optional[Tensor] = None,
eps: float = 1e-4,
reduction: Union[str, None] = 'mean',
avg_factor: Optional[int] = None) -> Tensor:
"""Computes the Scale-Invariant Logarithmic (SI-Log) loss between
prediction and target.
Args:
pred (Tensor): Predicted output.
target (Tensor): Ground truth.
weight (Optional[Tensor]): Optional weight to apply on the loss.
eps (float): Epsilon value to avoid division and log(0).
reduction (Union[str, None]): Specifies the reduction to apply to the
output: 'mean', 'sum' or None.
avg_factor (Optional[int]): Optional average factor for the loss.
Returns:
Tensor: The calculated SI-Log loss.
"""
pred, target = pred.flatten(1), target.flatten(1)
diff_log = torch.log(target.clamp(min=eps)) - torch.log(
pred.clamp(min=eps))
valid_mask = (target > eps).detach() & (~torch.isnan(diff_log))
diff_log[~valid_mask] = 0.0
valid_mask = valid_mask.float()
diff_log_sq_mean = (diff_log.pow(2) * valid_mask).sum(
dim=1) / valid_mask.sum(dim=1).clamp(min=eps)
diff_log_mean = (diff_log * valid_mask).sum(dim=1) / valid_mask.sum(
dim=1).clamp(min=eps)
loss = torch.sqrt(diff_log_sq_mean - 0.5 * diff_log_mean.pow(2))
if weight is not None:
weight = weight.float()
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
@MODELS.register_module()
class SiLogLoss(nn.Module):
"""Compute SiLog loss.
Args:
reduction (str, optional): The method used
to reduce the loss. Options are "none",
"mean" and "sum". Defaults to 'mean'.
loss_weight (float, optional): Weight of loss. Defaults to 1.0.
eps (float): Avoid dividing by zero and log(0). Defaults to 1e-6.
loss_name (str, optional): Name of the loss item. If you want this
loss item to be included into the backward graph, `loss_` must
be the prefix of the name. Defaults to 'loss_silog'.
"""
def __init__(self,
reduction='mean',
loss_weight=1.0,
eps=1e-6,
loss_name='loss_silog'):
super().__init__()
self.reduction = reduction
self.loss_weight = loss_weight
self.eps = eps
self._loss_name = loss_name
def forward(
self,
pred,
target,
weight=None,
avg_factor=None,
reduction_override=None,
):
assert pred.shape == target.shape, 'the shapes of pred ' \
f'({pred.shape}) and target ({target.shape}) do not match'
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
loss = self.loss_weight * silog_loss(
pred,
target,
weight,
eps=self.eps,
reduction=reduction,
avg_factor=avg_factor,
)
return loss
@property
def loss_name(self):
"""Loss Name.
This function must be implemented and will return the name of this
loss function. This name will be used to combine different loss items
by simple sum operation. In addition, if you want this loss item to be
included into the backward graph, `loss_` must be the prefix of the
name.
Returns:
str: The name of this loss item.
"""
return self._loss_name
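A hedged usage sketch for depth-style regression; non-positive target values are treated as invalid and masked out by the eps check:
import torch
from mmseg.models.losses import SiLogLoss

loss_fn = SiLogLoss()
pred = torch.rand(2, 1, 8, 8) * 10 + 0.1
target = torch.rand(2, 1, 8, 8) * 10
target[0, 0, :2, :] = 0.0  # invalid depth, excluded from the loss
print(loss_fn(pred, target))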

View File

@@ -0,0 +1,137 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from
https://github.com/JunMa11/SegLoss/blob/master/losses_pytorch/dice_loss.py#L333
(Apache-2.0 License)"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmseg.registry import MODELS
from .utils import get_class_weight, weighted_loss
@weighted_loss
def tversky_loss(pred,
target,
valid_mask,
alpha=0.3,
beta=0.7,
smooth=1,
class_weight=None,
ignore_index=255):
assert pred.shape[0] == target.shape[0]
total_loss = 0
num_classes = pred.shape[1]
for i in range(num_classes):
if i != ignore_index:
tversky_loss = binary_tversky_loss(
pred[:, i],
target[..., i],
valid_mask=valid_mask,
alpha=alpha,
beta=beta,
smooth=smooth)
if class_weight is not None:
tversky_loss *= class_weight[i]
total_loss += tversky_loss
return total_loss / num_classes
@weighted_loss
def binary_tversky_loss(pred,
target,
valid_mask,
alpha=0.3,
beta=0.7,
smooth=1):
assert pred.shape[0] == target.shape[0]
pred = pred.reshape(pred.shape[0], -1)
target = target.reshape(target.shape[0], -1)
valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)
TP = torch.sum(torch.mul(pred, target) * valid_mask, dim=1)
FP = torch.sum(torch.mul(pred, 1 - target) * valid_mask, dim=1)
FN = torch.sum(torch.mul(1 - pred, target) * valid_mask, dim=1)
tversky = (TP + smooth) / (TP + alpha * FP + beta * FN + smooth)
return 1 - tversky
@MODELS.register_module()
class TverskyLoss(nn.Module):
"""TverskyLoss. This loss is proposed in `Tversky loss function for image
segmentation using 3D fully convolutional deep networks.
<https://arxiv.org/abs/1706.05721>`_.
Args:
smooth (float): A float number to smooth loss, and avoid NaN error.
Default: 1.
class_weight (list[float] | str, optional): Weight of each class. If in
str format, read them from a file. Defaults to None.
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
ignore_index (int | None): The label index to be ignored. Default: 255.
alpha (float, in [0, 1]):
The coefficient of false positives. Default: 0.3.
beta (float, in [0, 1]):
The coefficient of false negatives. Default: 0.7.
Note: alpha + beta = 1.
loss_name (str, optional): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_tversky'.
"""
def __init__(self,
smooth=1,
class_weight=None,
loss_weight=1.0,
ignore_index=255,
alpha=0.3,
beta=0.7,
loss_name='loss_tversky'):
super().__init__()
self.smooth = smooth
self.class_weight = get_class_weight(class_weight)
self.loss_weight = loss_weight
self.ignore_index = ignore_index
assert (alpha + beta == 1.0), 'Sum of alpha and beta must be 1.0!'
self.alpha = alpha
self.beta = beta
self._loss_name = loss_name
def forward(self, pred, target, **kwargs):
if self.class_weight is not None:
class_weight = pred.new_tensor(self.class_weight)
else:
class_weight = None
pred = F.softmax(pred, dim=1)
num_classes = pred.shape[1]
one_hot_target = F.one_hot(
torch.clamp(target.long(), 0, num_classes - 1),
num_classes=num_classes)
valid_mask = (target != self.ignore_index).long()
loss = self.loss_weight * tversky_loss(
pred,
one_hot_target,
valid_mask=valid_mask,
alpha=self.alpha,
beta=self.beta,
smooth=self.smooth,
class_weight=class_weight,
ignore_index=self.ignore_index)
return loss
@property
def loss_name(self):
"""Loss Name.
This function must be implemented and will return the name of this
loss function. This name will be used to combine different loss items
by simple sum operation. In addition, if you want this loss item to be
included into the backward graph, `loss_` must be the prefix of the
name.
Returns:
str: The name of this loss item.
"""
return self._loss_name
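A hedged usage sketch with the default weighting, where alpha penalizes false positives and beta false negatives (alpha + beta must equal 1):
import torch
from mmseg.models.losses import TverskyLoss

loss_fn = TverskyLoss(alpha=0.3, beta=0.7, ignore_index=255)
logits = torch.randn(2, 4, 8, 8)  # softmax is applied inside
labels = torch.randint(0, 4, (2, 8, 8))
print(loss_fn(logits, labels))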

View File

@@ -0,0 +1,129 @@
# Copyright (c) OpenMMLab. All rights reserved.
import functools
import numpy as np
import torch
import torch.nn.functional as F
from mmengine.fileio import load
def get_class_weight(class_weight):
"""Get class weight for loss function.
Args:
class_weight (list[float] | str | None): If class_weight is a str,
take it as a file name and read from it.
"""
if isinstance(class_weight, str):
# take it as a file path
if class_weight.endswith('.npy'):
class_weight = np.load(class_weight)
else:
# pkl, json or yaml
class_weight = load(class_weight)
return class_weight
def reduce_loss(loss, reduction) -> torch.Tensor:
"""Reduce loss as specified.
Args:
loss (Tensor): Elementwise loss tensor.
reduction (str): Options are "none", "mean" and "sum".
Return:
Tensor: Reduced loss tensor.
"""
reduction_enum = F._Reduction.get_enum(reduction)
# none: 0, elementwise_mean:1, sum: 2
if reduction_enum == 0:
return loss
elif reduction_enum == 1:
return loss.mean()
elif reduction_enum == 2:
return loss.sum()
def weight_reduce_loss(loss,
weight=None,
reduction='mean',
avg_factor=None) -> torch.Tensor:
"""Apply element-wise weight and reduce loss.
Args:
loss (Tensor): Element-wise loss.
weight (Tensor): Element-wise weights.
reduction (str): Same as built-in losses of PyTorch.
avg_factor (float): Average factor when computing the mean of losses.
Returns:
Tensor: Processed loss values.
"""
# if weight is specified, apply element-wise weight
if weight is not None:
assert weight.dim() == loss.dim()
if weight.dim() > 1:
assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
loss = loss * weight
# if avg_factor is not specified, just reduce the loss
if avg_factor is None:
loss = reduce_loss(loss, reduction)
else:
# if reduction is mean, then average the loss by avg_factor
if reduction == 'mean':
# Avoid causing ZeroDivisionError when avg_factor is 0.0,
# i.e., all labels of an image belong to ignore index.
eps = torch.finfo(torch.float32).eps
loss = loss.sum() / (avg_factor + eps)
# if reduction is 'none', then do nothing, otherwise raise an error
elif reduction != 'none':
raise ValueError('avg_factor can not be used with reduction="sum"')
return loss
def weighted_loss(loss_func):
"""Create a weighted version of a given loss function.
To use this decorator, the loss function must have the signature like
`loss_func(pred, target, **kwargs)`. The function only needs to compute
element-wise loss without any reduction. This decorator will add weight
and reduction arguments to the function. The decorated function will have
the signature like `loss_func(pred, target, weight=None, reduction='mean',
avg_factor=None, **kwargs)`.
:Example:
>>> import torch
>>> @weighted_loss
>>> def l1_loss(pred, target):
>>> return (pred - target).abs()
>>> pred = torch.Tensor([0, 2, 3])
>>> target = torch.Tensor([1, 1, 1])
>>> weight = torch.Tensor([1, 0, 1])
>>> l1_loss(pred, target)
tensor(1.3333)
>>> l1_loss(pred, target, weight)
tensor(1.)
>>> l1_loss(pred, target, reduction='none')
tensor([1., 1., 2.])
>>> l1_loss(pred, target, weight, avg_factor=2)
tensor(1.5000)
"""
@functools.wraps(loss_func)
def wrapper(pred,
target,
weight=None,
reduction='mean',
avg_factor=None,
**kwargs):
# get element-wise loss
loss = loss_func(pred, target, **kwargs)
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
return wrapper