init
21  finetune/mmseg/models/losses/__init__.py  Normal file
@@ -0,0 +1,21 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .accuracy import Accuracy, accuracy
from .boundary_loss import BoundaryLoss
from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
                                 cross_entropy, mask_cross_entropy)
from .dice_loss import DiceLoss
from .focal_loss import FocalLoss
from .huasdorff_distance_loss import HuasdorffDisstanceLoss
from .lovasz_loss import LovaszLoss
from .ohem_cross_entropy_loss import OhemCrossEntropy
from .silog_loss import SiLogLoss
from .tversky_loss import TverskyLoss
from .utils import reduce_loss, weight_reduce_loss, weighted_loss

__all__ = [
    'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
    'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
    'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss',
    'FocalLoss', 'TverskyLoss', 'OhemCrossEntropy', 'BoundaryLoss',
    'HuasdorffDisstanceLoss', 'SiLogLoss'
]
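Usage note (illustrative, not part of the commit): once this package is importable, the registered losses are referenced by their type names in decode-head configs; the entry below is a hypothetical example.

    # hypothetical `loss_decode` entry for a decode head config
    loss_decode = [
        dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
        dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4),
    ]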
92  finetune/mmseg/models/losses/accuracy.py  Normal file
@@ -0,0 +1,92 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn


def accuracy(pred, target, topk=1, thresh=None, ignore_index=None):
    """Calculate accuracy according to the prediction and target.

    Args:
        pred (torch.Tensor): The model prediction, shape (N, num_class, ...)
        target (torch.Tensor): The target of each prediction, shape (N, ...)
        ignore_index (int | None): The label index to be ignored. Default: None
        topk (int | tuple[int], optional): If the predictions in ``topk``
            match the target, the predictions will be regarded as
            correct ones. Defaults to 1.
        thresh (float, optional): If not None, predictions with scores under
            this threshold are considered incorrect. Defaults to None.

    Returns:
        float | tuple[float]: If the input ``topk`` is a single integer,
            the function will return a single float as accuracy. If
            ``topk`` is a tuple containing multiple integers, the
            function will return a tuple containing accuracies of
            each ``topk`` number.
    """
    assert isinstance(topk, (int, tuple))
    if isinstance(topk, int):
        topk = (topk, )
        return_single = True
    else:
        return_single = False

    maxk = max(topk)
    if pred.size(0) == 0:
        accu = [pred.new_tensor(0.) for i in range(len(topk))]
        return accu[0] if return_single else accu
    assert pred.ndim == target.ndim + 1
    assert pred.size(0) == target.size(0)
    assert maxk <= pred.size(1), \
        f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
    pred_value, pred_label = pred.topk(maxk, dim=1)
    # transpose to shape (maxk, N, ...)
    pred_label = pred_label.transpose(0, 1)
    correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label))
    if thresh is not None:
        # Only prediction values larger than thresh are counted as correct
        correct = correct & (pred_value > thresh).t()
    if ignore_index is not None:
        correct = correct[:, target != ignore_index]
    res = []
    eps = torch.finfo(torch.float32).eps
    for k in topk:
        # Avoid causing ZeroDivisionError when all pixels
        # of an image are ignored
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + eps
        if ignore_index is not None:
            total_num = target[target != ignore_index].numel() + eps
        else:
            total_num = target.numel() + eps
        res.append(correct_k.mul_(100.0 / total_num))
    return res[0] if return_single else res


class Accuracy(nn.Module):
    """Accuracy calculation module."""

    def __init__(self, topk=(1, ), thresh=None, ignore_index=None):
        """Module to calculate the accuracy.

        Args:
            topk (tuple, optional): The criterion used to calculate the
                accuracy. Defaults to (1,).
            thresh (float, optional): If not None, predictions with scores
                under this threshold are considered incorrect. Defaults to None.
        """
        super().__init__()
        self.topk = topk
        self.thresh = thresh
        self.ignore_index = ignore_index

    def forward(self, pred, target):
        """Forward function to calculate accuracy.

        Args:
            pred (torch.Tensor): Prediction of models.
            target (torch.Tensor): Target for each prediction.

        Returns:
            tuple[float]: The accuracies under different topk criterions.
        """
        return accuracy(pred, target, self.topk, self.thresh,
                        self.ignore_index)
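Usage sketch (illustrative, not part of the commit): a quick CPU check of `accuracy()` on random logits, assuming the `finetune/mmseg` package is on the Python path.

    import torch
    from mmseg.models.losses import accuracy

    logits = torch.randn(2, 19, 32, 32)           # (N, num_class, H, W)
    labels = torch.randint(0, 19, (2, 32, 32))    # (N, H, W)
    # returns top-1 accuracy as a percentage in [0, 100]
    print(accuracy(logits, labels, topk=1, ignore_index=255))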
62  finetune/mmseg/models/losses/boundary_loss.py  Normal file
@@ -0,0 +1,62 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from mmseg.registry import MODELS


@MODELS.register_module()
class BoundaryLoss(nn.Module):
    """Boundary loss.

    This function is modified from
    `PIDNet <https://github.com/XuJiacong/PIDNet/blob/main/utils/criterion.py#L122>`_.  # noqa
    Licensed under the MIT License.


    Args:
        loss_weight (float): Weight of the loss. Defaults to 1.0.
        loss_name (str): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_boundary'.
    """

    def __init__(self,
                 loss_weight: float = 1.0,
                 loss_name: str = 'loss_boundary'):
        super().__init__()
        self.loss_weight = loss_weight
        self.loss_name_ = loss_name

    def forward(self, bd_pre: Tensor, bd_gt: Tensor) -> Tensor:
        """Forward function.
        Args:
            bd_pre (Tensor): Predictions of the boundary head.
            bd_gt (Tensor): Ground truth of the boundary.

        Returns:
            Tensor: Loss tensor.
        """
        log_p = bd_pre.permute(0, 2, 3, 1).contiguous().view(1, -1)
        target_t = bd_gt.view(1, -1).float()

        pos_index = (target_t == 1)
        neg_index = (target_t == 0)

        weight = torch.zeros_like(log_p)
        pos_num = pos_index.sum()
        neg_num = neg_index.sum()
        sum_num = pos_num + neg_num
        weight[pos_index] = neg_num * 1.0 / sum_num
        weight[neg_index] = pos_num * 1.0 / sum_num

        loss = F.binary_cross_entropy_with_logits(
            log_p, target_t, weight, reduction='mean')

        return self.loss_weight * loss

    @property
    def loss_name(self):
        return self.loss_name_
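Usage sketch (illustrative, not part of the commit): BoundaryLoss expects boundary-head logits of shape (N, 1, H, W) and a binary boundary mask of shape (N, H, W).

    import torch
    from mmseg.models.losses import BoundaryLoss

    loss_fn = BoundaryLoss(loss_weight=1.0)
    bd_logits = torch.randn(2, 1, 64, 64)               # boundary-head logits
    bd_gt = torch.randint(0, 2, (2, 64, 64)).float()    # 0/1 boundary mask
    print(loss_fn(bd_logits, bd_gt))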
311  finetune/mmseg/models/losses/cross_entropy_loss.py  Normal file
@@ -0,0 +1,311 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmseg.registry import MODELS
from .utils import get_class_weight, weight_reduce_loss


def cross_entropy(pred,
                  label,
                  weight=None,
                  class_weight=None,
                  reduction='mean',
                  avg_factor=None,
                  ignore_index=-100,
                  avg_non_ignore=False):
    """cross_entropy. The wrapper function for :func:`F.cross_entropy`

    Args:
        pred (torch.Tensor): The prediction with shape (N, 1).
        label (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
            Default: None.
        class_weight (list[float], optional): The weight for each class.
            Default: None.
        reduction (str, optional): The method used to reduce the loss.
            Options are 'none', 'mean' and 'sum'. Default: 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Default: None.
        ignore_index (int): Specifies a target value that is ignored and
            does not contribute to the input gradients. When
            ``avg_non_ignore`` is ``True``, and the ``reduction`` is
            ``'mean'``, the loss is averaged over non-ignored targets.
            Default: -100.
        avg_non_ignore (bool): Whether the loss is only averaged over
            non-ignored targets. Default: False.
            `New in version 0.23.0.`
    """

    # class_weight is a manual rescaling weight given to each class.
    # If given, has to be a Tensor of size C element-wise losses
    loss = F.cross_entropy(
        pred,
        label,
        weight=class_weight,
        reduction='none',
        ignore_index=ignore_index)

    # apply weights and do the reduction
    # average loss over non-ignored elements
    # pytorch's official cross_entropy average loss over non-ignored elements
    # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660  # noqa
    if (avg_factor is None) and reduction == 'mean':
        if class_weight is None:
            if avg_non_ignore:
                avg_factor = label.numel() - (label
                                              == ignore_index).sum().item()
            else:
                avg_factor = label.numel()

        else:
            # the average factor should take the class weights into account
            label_weights = torch.stack([class_weight[cls] for cls in label
                                         ]).to(device=class_weight.device)

            if avg_non_ignore:
                label_weights[label == ignore_index] = 0
            avg_factor = label_weights.sum()

    if weight is not None:
        weight = weight.float()
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)

    return loss


def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
    """Expand onehot labels to match the size of prediction."""
    bin_labels = labels.new_zeros(target_shape)
    valid_mask = (labels >= 0) & (labels != ignore_index)
    inds = torch.nonzero(valid_mask, as_tuple=True)

    if inds[0].numel() > 0:
        if labels.dim() == 3:
            bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
        else:
            bin_labels[inds[0], labels[valid_mask]] = 1

    valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()

    if label_weights is None:
        bin_label_weights = valid_mask
    else:
        bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
        bin_label_weights = bin_label_weights * valid_mask

    return bin_labels, bin_label_weights, valid_mask


def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None,
                         class_weight=None,
                         ignore_index=-100,
                         avg_non_ignore=False,
                         **kwargs):
    """Calculate the binary CrossEntropy loss.

    Args:
        pred (torch.Tensor): The prediction with shape (N, 1).
        label (torch.Tensor): The learning label of the prediction.
            Note: In bce loss, label < 0 is invalid.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.
        ignore_index (int): The label index to be ignored. Default: -100.
        avg_non_ignore (bool): Whether the loss is only averaged over
            non-ignored targets. Default: False.
            `New in version 0.23.0.`

    Returns:
        torch.Tensor: The calculated loss
    """
    if pred.size(1) == 1:
        # For binary class segmentation, the shape of pred is
        # [N, 1, H, W] and that of label is [N, H, W].
        # As the ignore_index often set as 255, so the
        # binary class label check should mask out
        # ignore_index
        assert label[label != ignore_index].max() <= 1, \
            'For pred with shape [N, 1, H, W], its label must have at ' \
            'most 2 classes'
        pred = pred.squeeze(1)
    if pred.dim() != label.dim():
        assert (pred.dim() == 2 and label.dim() == 1) or (
            pred.dim() == 4 and label.dim() == 3), \
            'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
            'H, W], label shape [N, H, W] are supported'
        # `weight` returned from `_expand_onehot_labels`
        # has been treated for valid (non-ignore) pixels
        label, weight, valid_mask = _expand_onehot_labels(
            label, weight, pred.shape, ignore_index)
    else:
        # should mask out the ignored elements
        valid_mask = ((label >= 0) & (label != ignore_index)).float()
        if weight is not None:
            weight = weight * valid_mask
        else:
            weight = valid_mask
    # average loss over non-ignored and valid elements
    if reduction == 'mean' and avg_factor is None and avg_non_ignore:
        avg_factor = valid_mask.sum().item()

    loss = F.binary_cross_entropy_with_logits(
        pred, label.float(), pos_weight=class_weight, reduction='none')
    # do the reduction for the weighted loss
    loss = weight_reduce_loss(
        loss, weight, reduction=reduction, avg_factor=avg_factor)

    return loss


def mask_cross_entropy(pred,
                       target,
                       label,
                       reduction='mean',
                       avg_factor=None,
                       class_weight=None,
                       ignore_index=None,
                       **kwargs):
    """Calculate the CrossEntropy loss for masks.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the number
            of classes.
        target (torch.Tensor): The learning label of the prediction.
        label (torch.Tensor): ``label`` indicates the class label of the mask's
            corresponding object. It is used to select the mask of the class
            which the object belongs to when the mask prediction is not
            class-agnostic.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.
        ignore_index (None): Placeholder, to be consistent with other loss.
            Default: None.

    Returns:
        torch.Tensor: The calculated loss
    """
    assert ignore_index is None, 'BCE loss does not support ignore_index'
    # TODO: handle these two reserved arguments
    assert reduction == 'mean' and avg_factor is None
    num_rois = pred.size()[0]
    inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
    pred_slice = pred[inds, label].squeeze(1)
    return F.binary_cross_entropy_with_logits(
        pred_slice, target, weight=class_weight, reduction='mean')[None]


@MODELS.register_module()
class CrossEntropyLoss(nn.Module):
    """CrossEntropyLoss.

    Args:
        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
            instead of softmax. Defaults to False.
        use_mask (bool, optional): Whether to use mask cross entropy loss.
            Defaults to False.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to 'mean'.
        class_weight (list[float] | str, optional): Weight of each class. If in
            str format, read them from a file. Defaults to None.
        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
        loss_name (str, optional): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_ce'.
        avg_non_ignore (bool): Whether the loss is only averaged over
            non-ignored targets. Default: False.
            `New in version 0.23.0.`
    """

    def __init__(self,
                 use_sigmoid=False,
                 use_mask=False,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 loss_name='loss_ce',
                 avg_non_ignore=False):
        super().__init__()
        assert (use_sigmoid is False) or (use_mask is False)
        self.use_sigmoid = use_sigmoid
        self.use_mask = use_mask
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = get_class_weight(class_weight)
        self.avg_non_ignore = avg_non_ignore
        if not self.avg_non_ignore and self.reduction == 'mean':
            warnings.warn(
                'Default ``avg_non_ignore`` is False, if you would like to '
                'ignore the certain label and average loss over non-ignore '
                'labels, which is the same with PyTorch official '
                'cross_entropy, set ``avg_non_ignore=True``.')

        if self.use_sigmoid:
            self.cls_criterion = binary_cross_entropy
        elif self.use_mask:
            self.cls_criterion = mask_cross_entropy
        else:
            self.cls_criterion = cross_entropy
        self._loss_name = loss_name

    def extra_repr(self):
        """Extra repr."""
        s = f'avg_non_ignore={self.avg_non_ignore}'
        return s

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                ignore_index=-100,
                **kwargs):
        """Forward function."""
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.class_weight is not None:
            class_weight = cls_score.new_tensor(self.class_weight)
        else:
            class_weight = None
        # Note: for BCE loss, label < 0 is invalid.
        loss_cls = self.loss_weight * self.cls_criterion(
            cls_score,
            label,
            weight,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            avg_non_ignore=self.avg_non_ignore,
            ignore_index=ignore_index,
            **kwargs)
        return loss_cls

    @property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
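Usage sketch (illustrative, not part of the commit): in configs this class is normally referenced as dict(type='CrossEntropyLoss', ...) inside a decode head's loss_decode; calling it directly behaves the same way.

    import torch
    from mmseg.models.losses import CrossEntropyLoss

    loss_fn = CrossEntropyLoss(use_sigmoid=False, loss_weight=1.0)
    logits = torch.randn(2, 19, 32, 32)           # (N, C, H, W)
    labels = torch.randint(0, 19, (2, 32, 32))    # (N, H, W)
    print(loss_fn(logits, labels, ignore_index=255))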
202  finetune/mmseg/models/losses/dice_loss.py  Normal file
@@ -0,0 +1,202 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union

import torch
import torch.nn as nn

from mmseg.registry import MODELS
from .utils import weight_reduce_loss


def _expand_onehot_labels_dice(pred: torch.Tensor,
                               target: torch.Tensor) -> torch.Tensor:
    """Expand onehot labels to match the size of prediction.

    Args:
        pred (torch.Tensor): The prediction, has a shape (N, num_class, H, W).
        target (torch.Tensor): The learning label of the prediction,
            has a shape (N, H, W).

    Returns:
        torch.Tensor: The target after one-hot encoding,
            has a shape (N, num_class, H, W).
    """
    num_classes = pred.shape[1]
    one_hot_target = torch.clamp(target, min=0, max=num_classes)
    one_hot_target = torch.nn.functional.one_hot(one_hot_target,
                                                 num_classes + 1)
    one_hot_target = one_hot_target[..., :num_classes].permute(0, 3, 1, 2)
    return one_hot_target


def dice_loss(pred: torch.Tensor,
              target: torch.Tensor,
              weight: Union[torch.Tensor, None],
              eps: float = 1e-3,
              reduction: Union[str, None] = 'mean',
              naive_dice: Union[bool, None] = False,
              avg_factor: Union[int, None] = None,
              ignore_index: Union[int, None] = 255) -> float:
    """Calculate dice loss. Two forms of dice loss are supported:

        - the one proposed in `V-Net: Fully Convolutional Neural
            Networks for Volumetric Medical Image Segmentation
            <https://arxiv.org/abs/1606.04797>`_.
        - the dice loss in which the power of the number in the
            denominator is the first power instead of the second
            power.

    Args:
        pred (torch.Tensor): The prediction, has a shape (n, *)
        target (torch.Tensor): The learning label of the prediction,
            shape (n, *), same shape of pred.
        weight (torch.Tensor, optional): The weight of loss for each
            prediction, has a shape (n,). Defaults to None.
        eps (float): Avoid dividing by zero. Default: 1e-3.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
            Options are "none", "mean" and "sum".
        naive_dice (bool, optional): If false, use the dice
            loss defined in the V-Net paper, otherwise, use the
            naive dice loss in which the power of the number in the
            denominator is the first power instead of the second
            power. Defaults to False.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        ignore_index (int, optional): The label index to be ignored.
            Defaults to 255.
    """
    if ignore_index is not None:
        num_classes = pred.shape[1]
        pred = pred[:, torch.arange(num_classes) != ignore_index, :, :]
        target = target[:, torch.arange(num_classes) != ignore_index, :, :]
        assert pred.shape[1] != 0  # if the ignored index is the only class
    input = pred.flatten(1)
    target = target.flatten(1).float()
    a = torch.sum(input * target, 1)
    if naive_dice:
        b = torch.sum(input, 1)
        c = torch.sum(target, 1)
        d = (2 * a + eps) / (b + c + eps)
    else:
        b = torch.sum(input * input, 1) + eps
        c = torch.sum(target * target, 1) + eps
        d = (2 * a) / (b + c)

    loss = 1 - d
    if weight is not None:
        assert weight.ndim == loss.ndim
        assert len(weight) == len(pred)
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss


@MODELS.register_module()
class DiceLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=True,
                 activate=True,
                 reduction='mean',
                 naive_dice=False,
                 loss_weight=1.0,
                 ignore_index=255,
                 eps=1e-3,
                 loss_name='loss_dice'):
        """Compute dice loss.

        Args:
            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
                instead of softmax. Defaults to True.
            activate (bool): Whether to activate the predictions inside,
                this will disable the inside sigmoid operation.
                Defaults to True.
            reduction (str, optional): The method used
                to reduce the loss. Options are "none",
                "mean" and "sum". Defaults to 'mean'.
            naive_dice (bool, optional): If false, use the dice
                loss defined in the V-Net paper, otherwise, use the
                naive dice loss in which the power of the number in the
                denominator is the first power instead of the second
                power. Defaults to False.
            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
            ignore_index (int, optional): The label index to be ignored.
                Default: 255.
            eps (float): Avoid dividing by zero. Defaults to 1e-3.
            loss_name (str, optional): Name of the loss item. If you want this
                loss item to be included into the backward graph, `loss_` must
                be the prefix of the name. Defaults to 'loss_dice'.
        """

        super().__init__()
        self.use_sigmoid = use_sigmoid
        self.reduction = reduction
        self.naive_dice = naive_dice
        self.loss_weight = loss_weight
        self.eps = eps
        self.activate = activate
        self.ignore_index = ignore_index
        self._loss_name = loss_name

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                ignore_index=255,
                **kwargs):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction, has a shape (n, *).
            target (torch.Tensor): The label of the prediction,
                shape (n, *), same shape of pred.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction, has a shape (n,). Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Options are "none", "mean" and "sum".

        Returns:
            torch.Tensor: The calculated loss
        """
        one_hot_target = target
        if (pred.shape != target.shape):
            one_hot_target = _expand_onehot_labels_dice(pred, target)
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.activate:
            if self.use_sigmoid:
                pred = pred.sigmoid()
            elif pred.shape[1] != 1:
                # softmax does not work when there is only 1 class
                pred = pred.softmax(dim=1)
        loss = self.loss_weight * dice_loss(
            pred,
            one_hot_target,
            weight,
            eps=self.eps,
            reduction=reduction,
            naive_dice=self.naive_dice,
            avg_factor=avg_factor,
            ignore_index=self.ignore_index)

        return loss

    @property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.
        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
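Usage sketch (illustrative, not part of the commit): for multi-class logits, DiceLoss one-hot encodes the (N, H, W) label map internally, and with use_sigmoid=False it applies softmax before the dice computation.

    import torch
    from mmseg.models.losses import DiceLoss

    loss_fn = DiceLoss(use_sigmoid=False, ignore_index=255)
    logits = torch.randn(2, 19, 32, 32)
    labels = torch.randint(0, 19, (2, 32, 32))
    print(loss_fn(logits, labels))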
337  finetune/mmseg/models/losses/focal_loss.py  Normal file
@@ -0,0 +1,337 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/open-mmlab/mmdetection
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss

from mmseg.registry import MODELS
from .utils import weight_reduce_loss


# This method is used when cuda is not available
def py_sigmoid_focal_loss(pred,
                          target,
                          one_hot_target=None,
                          weight=None,
                          gamma=2.0,
                          alpha=0.5,
                          class_weight=None,
                          valid_mask=None,
                          reduction='mean',
                          avg_factor=None):
    """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the
            number of classes
        target (torch.Tensor): The learning label of the prediction with
            shape (N, C)
        one_hot_target (None): Placeholder. It should be None.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float | list[float], optional): A balanced form for Focal Loss.
            Defaults to 0.5.
        class_weight (list[float], optional): Weight of each class.
            Defaults to None.
        valid_mask (torch.Tensor, optional): A mask uses 1 to mark the valid
            samples and uses 0 to mark the ignored samples. Default: None.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    if isinstance(alpha, list):
        alpha = pred.new_tensor(alpha)
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    one_minus_pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * one_minus_pt.pow(gamma)

    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    final_weight = torch.ones(1, pred.size(1)).type_as(loss)
    if weight is not None:
        if weight.shape != loss.shape and weight.size(0) == loss.size(0):
            # For most cases, weight is of shape (N, ),
            # which means it does not have the second axis num_class
            weight = weight.view(-1, 1)
        assert weight.dim() == loss.dim()
        final_weight = final_weight * weight
    if class_weight is not None:
        final_weight = final_weight * pred.new_tensor(class_weight)
    if valid_mask is not None:
        final_weight = final_weight * valid_mask
    loss = weight_reduce_loss(loss, final_weight, reduction, avg_factor)
    return loss


def sigmoid_focal_loss(pred,
                       target,
                       one_hot_target,
                       weight=None,
                       gamma=2.0,
                       alpha=0.5,
                       class_weight=None,
                       valid_mask=None,
                       reduction='mean',
                       avg_factor=None):
    r"""A wrapper of cuda version `Focal Loss
    <https://arxiv.org/abs/1708.02002>`_.
    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the number
            of classes.
        target (torch.Tensor): The learning label of the prediction. Its shape
            should be (N, )
        one_hot_target (torch.Tensor): The learning label with shape (N, C)
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float | list[float], optional): A balanced form for Focal Loss.
            Defaults to 0.5.
        class_weight (list[float], optional): Weight of each class.
            Defaults to None.
        valid_mask (torch.Tensor, optional): A mask uses 1 to mark the valid
            samples and uses 0 to mark the ignored samples. Default: None.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    # Function.apply does not accept keyword arguments, so the decorator
    # "weighted_loss" is not applicable
    final_weight = torch.ones(1, pred.size(1)).type_as(pred)
    if isinstance(alpha, list):
        # _sigmoid_focal_loss doesn't accept alpha of list type. Therefore, if
        # a list is given, we set the input alpha as 0.5. This means setting
        # equal weight for foreground class and background class. By
        # multiplying the loss by 2, the effect of setting alpha as 0.5 is
        # undone. The alpha of type list is used to regulate the loss in the
        # post-processing process.
        loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(),
                                   gamma, 0.5, None, 'none') * 2
        alpha = pred.new_tensor(alpha)
        final_weight = final_weight * (
            alpha * one_hot_target + (1 - alpha) * (1 - one_hot_target))
    else:
        loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(),
                                   gamma, alpha, None, 'none')
    if weight is not None:
        if weight.shape != loss.shape and weight.size(0) == loss.size(0):
            # For most cases, weight is of shape (N, ),
            # which means it does not have the second axis num_class
            weight = weight.view(-1, 1)
        assert weight.dim() == loss.dim()
        final_weight = final_weight * weight
    if class_weight is not None:
        final_weight = final_weight * pred.new_tensor(class_weight)
    if valid_mask is not None:
        final_weight = final_weight * valid_mask
    loss = weight_reduce_loss(loss, final_weight, reduction, avg_factor)
    return loss


@MODELS.register_module()
class FocalLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=True,
                 gamma=2.0,
                 alpha=0.5,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 loss_name='loss_focal'):
        """`Focal Loss <https://arxiv.org/abs/1708.02002>`_
        Args:
            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
                instead of softmax. Defaults to True.
            gamma (float, optional): The gamma for calculating the modulating
                factor. Defaults to 2.0.
            alpha (float | list[float], optional): A balanced form for Focal
                Loss. Defaults to 0.5. When a list is provided, the length
                of the list should be equal to the number of classes.
                Please be careful that this parameter is not the
                class-wise weight but the weight of a binary classification
                problem. This binary classification problem regards the
                pixels which belong to one class as the foreground
                and the other pixels as the background, each element in
                the list is the weight of the corresponding foreground class.
                The value of alpha or each element of alpha should be a float
                in the interval [0, 1]. If you want to specify the class-wise
                weight, please use `class_weight` parameter.
            reduction (str, optional): The method used to reduce the loss into
                a scalar. Defaults to 'mean'. Options are "none", "mean" and
                "sum".
            class_weight (list[float], optional): Weight of each class.
                Defaults to None.
            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
            loss_name (str, optional): Name of the loss item. If you want this
                loss item to be included into the backward graph, `loss_` must
                be the prefix of the name. Defaults to 'loss_focal'.
        """
        super().__init__()
        assert use_sigmoid is True, \
            'AssertionError: Only sigmoid focal loss supported now.'
        assert reduction in ('none', 'mean', 'sum'), \
            "AssertionError: reduction should be 'none', 'mean' or " \
            "'sum'"
        assert isinstance(alpha, (float, list)), \
            'AssertionError: alpha should be of type float'
        assert isinstance(gamma, float), \
            'AssertionError: gamma should be of type float'
        assert isinstance(loss_weight, float), \
            'AssertionError: loss_weight should be of type float'
        assert isinstance(loss_name, str), \
            'AssertionError: loss_name should be of type str'
        assert isinstance(class_weight, list) or class_weight is None, \
            'AssertionError: class_weight must be None or of type list'
        self.use_sigmoid = use_sigmoid
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.class_weight = class_weight
        self.loss_weight = loss_weight
        self._loss_name = loss_name

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                ignore_index=255,
                **kwargs):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction with shape
                (N, C) where C = number of classes, or
                (N, C, d_1, d_2, ..., d_K) with K≥1 in the
                case of K-dimensional loss.
            target (torch.Tensor): The ground truth. If containing class
                indices, shape (N) where each value is 0≤targets[i]≤C−1,
                or (N, d_1, d_2, ..., d_K) with K≥1 in the case of
                K-dimensional loss. If containing class probabilities,
                same shape as the input.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used
                to override the original reduction method of the loss.
                Options are "none", "mean" and "sum".
            ignore_index (int, optional): The label index to be ignored.
                Default: 255
        Returns:
            torch.Tensor: The calculated loss
        """
        assert isinstance(ignore_index, int), \
            'ignore_index must be of type int'
        assert reduction_override in (None, 'none', 'mean', 'sum'), \
            "AssertionError: reduction should be 'none', 'mean' or " \
            "'sum'"
        assert pred.shape == target.shape or \
            (pred.size(0) == target.size(0) and
             pred.shape[2:] == target.shape[1:]), \
            "The shape of pred doesn't match the shape of target"

        original_shape = pred.shape

        # [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k]
        pred = pred.transpose(0, 1)
        # [C, B, d_1, d_2, ..., d_k] -> [C, N]
        pred = pred.reshape(pred.size(0), -1)
        # [C, N] -> [N, C]
        pred = pred.transpose(0, 1).contiguous()

        if original_shape == target.shape:
            # target with shape [B, C, d_1, d_2, ...]
            # transform it's shape into [N, C]
            # [B, C, d_1, d_2, ...] -> [C, B, d_1, d_2, ..., d_k]
            target = target.transpose(0, 1)
            # [C, B, d_1, d_2, ..., d_k] -> [C, N]
            target = target.reshape(target.size(0), -1)
            # [C, N] -> [N, C]
            target = target.transpose(0, 1).contiguous()
        else:
            # target with shape [B, d_1, d_2, ...]
            # transform it's shape into [N, ]
            target = target.view(-1).contiguous()
            valid_mask = (target != ignore_index).view(-1, 1)
            # avoid raising error when using F.one_hot()
            target = torch.where(target == ignore_index, target.new_tensor(0),
                                 target)

        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.use_sigmoid:
            num_classes = pred.size(1)
            if torch.cuda.is_available() and pred.is_cuda:
                if target.dim() == 1:
                    one_hot_target = F.one_hot(
                        target, num_classes=num_classes + 1)
                    if num_classes == 1:
                        one_hot_target = one_hot_target[:, 1]
                        target = 1 - target
                    else:
                        one_hot_target = one_hot_target[:, :num_classes]
                else:
                    one_hot_target = target
                    target = target.argmax(dim=1)
                    valid_mask = (target != ignore_index).view(-1, 1)
                calculate_loss_func = sigmoid_focal_loss
            else:
                one_hot_target = None
                if target.dim() == 1:
                    target = F.one_hot(target, num_classes=num_classes + 1)
                    if num_classes == 1:
                        target = target[:, 1]
                    else:
                        target = target[:, :num_classes]
                else:
                    valid_mask = (target.argmax(dim=1) != ignore_index).view(
                        -1, 1)
                calculate_loss_func = py_sigmoid_focal_loss

            loss_cls = self.loss_weight * calculate_loss_func(
                pred,
                target,
                one_hot_target,
                weight,
                gamma=self.gamma,
                alpha=self.alpha,
                class_weight=self.class_weight,
                valid_mask=valid_mask,
                reduction=reduction,
                avg_factor=avg_factor)

            if reduction == 'none':
                # [N, C] -> [C, N]
                loss_cls = loss_cls.transpose(0, 1)
                # [C, N] -> [C, B, d1, d2, ...]
                # original_shape: [B, C, d1, d2, ...]
                loss_cls = loss_cls.reshape(original_shape[1],
                                            original_shape[0],
                                            *original_shape[2:])
                # [C, B, d1, d2, ...] -> [B, C, d1, d2, ...]
                loss_cls = loss_cls.transpose(0, 1).contiguous()
        else:
            raise NotImplementedError
        return loss_cls

    @property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.
        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
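Usage sketch (illustrative, not part of the commit): on CPU tensors the forward falls back to `py_sigmoid_focal_loss`; the `sigmoid_focal_loss` path needs CUDA tensors and the mmcv op, and importing this module requires mmcv to be installed.

    import torch
    from mmseg.models.losses import FocalLoss

    loss_fn = FocalLoss(use_sigmoid=True, gamma=2.0, alpha=0.5)
    logits = torch.randn(2, 19, 32, 32)
    labels = torch.randint(0, 19, (2, 32, 32))
    print(loss_fn(logits, labels, ignore_index=255))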
160  finetune/mmseg/models/losses/huasdorff_distance_loss.py  Normal file
@@ -0,0 +1,160 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from https://github.com/JunMa11/SegWithDistMap/blob/
master/code/train_LA_HD.py (Apache-2.0 License)"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.ndimage import distance_transform_edt as distance
from torch import Tensor

from mmseg.registry import MODELS
from .utils import get_class_weight, weighted_loss


def compute_dtm(img_gt: Tensor, pred: Tensor) -> Tensor:
    """
    Compute the distance transform map of the foreground in the mask.
    Args:
        img_gt: Ground truth of the image, (b, h, w)
        pred: Predictions of the segmentation head after softmax, (b, c, h, w)

    Returns:
        output: the foreground Distance Map (SDM)
        dtm(x) = 0; x in segmentation boundary
                 inf|x-y|; x in segmentation
    """

    fg_dtm = torch.zeros_like(pred)
    out_shape = pred.shape
    for b in range(out_shape[0]):  # batch size
        for c in range(1, out_shape[1]):  # default 0 channel is background
            posmask = img_gt[b].byte()
            if posmask.any():
                posdis = distance(posmask)
                fg_dtm[b][c] = torch.from_numpy(posdis)

    return fg_dtm


@weighted_loss
def hd_loss(seg_soft: Tensor,
            gt: Tensor,
            seg_dtm: Tensor,
            gt_dtm: Tensor,
            class_weight=None,
            ignore_index=255) -> Tensor:
    """
    Compute Hausdorff distance loss for segmentation.
    Args:
        seg_soft: softmax results, shape=(b,c,x,y)
        gt: ground truth, shape=(b,x,y)
        seg_dtm: segmentation distance transform map, shape=(b,c,x,y)
        gt_dtm: ground truth distance transform map, shape=(b,c,x,y)

    Returns:
        output: hd_loss
    """
    assert seg_soft.shape[0] == gt.shape[0]
    total_loss = 0
    num_class = seg_soft.shape[1]
    if class_weight is not None:
        assert class_weight.ndim == num_class
    for i in range(1, num_class):
        if i != ignore_index:
            delta_s = (seg_soft[:, i, ...] - gt.float())**2
            s_dtm = seg_dtm[:, i, ...]**2
            g_dtm = gt_dtm[:, i, ...]**2
            dtm = s_dtm + g_dtm
            multiplied = torch.einsum('bxy, bxy->bxy', delta_s, dtm)
            hd_loss = multiplied.mean()
            if class_weight is not None:
                hd_loss *= class_weight[i]
            total_loss += hd_loss

    return total_loss / num_class


@MODELS.register_module()
class HuasdorffDisstanceLoss(nn.Module):
    """HuasdorffDisstanceLoss. This loss is proposed in `How Distance Transform
    Maps Boost Segmentation CNNs: An Empirical Study.

    <http://proceedings.mlr.press/v121/ma20b.html>`_.
    Args:
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
        class_weight (list[float] | str, optional): Weight of each class. If in
            str format, read them from a file. Defaults to None.
        loss_weight (float): Weight of the loss. Defaults to 1.0.
        ignore_index (int | None): The label index to be ignored. Default: 255.
        loss_name (str): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_huasdorff_disstance'.
    """

    def __init__(self,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 ignore_index=255,
                 loss_name='loss_huasdorff_disstance',
                 **kwargs):
        super().__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = get_class_weight(class_weight)
        self._loss_name = loss_name
        self.ignore_index = ignore_index

    def forward(self,
                pred: Tensor,
                target: Tensor,
                avg_factor=None,
                reduction_override=None,
                **kwargs) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): Predictions of the segmentation head. (B, C, H, W)
            target (Tensor): Ground truth of the image. (B, H, W)
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used
                to override the original reduction method of the loss.
                Options are "none", "mean" and "sum".
        Returns:
            Tensor: Loss tensor.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.class_weight is not None:
            class_weight = pred.new_tensor(self.class_weight)
        else:
            class_weight = None

        pred_soft = F.softmax(pred, dim=1)
        valid_mask = (target != self.ignore_index).long()
        target = target * valid_mask

        with torch.no_grad():
            gt_dtm = compute_dtm(target.cpu(), pred_soft)
            gt_dtm = gt_dtm.float()
            seg_dtm2 = compute_dtm(
                pred_soft.argmax(dim=1, keepdim=False).cpu(), pred_soft)
            seg_dtm2 = seg_dtm2.float()

        loss_hd = self.loss_weight * hd_loss(
            pred_soft,
            target,
            seg_dtm=seg_dtm2,
            gt_dtm=gt_dtm,
            reduction=reduction,
            avg_factor=avg_factor,
            class_weight=class_weight,
            ignore_index=self.ignore_index)
        return loss_hd

    @property
    def loss_name(self):
        return self._loss_name
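Usage sketch (illustrative, not part of the commit): the distance maps are computed with SciPy on the CPU inside `torch.no_grad()`, so the call works on CPU tensors as long as scipy is installed.

    import torch
    from mmseg.models.losses import HuasdorffDisstanceLoss

    loss_fn = HuasdorffDisstanceLoss(loss_weight=1.0, ignore_index=255)
    logits = torch.randn(2, 4, 32, 32)
    labels = torch.randint(0, 4, (2, 32, 32))
    print(loss_fn(logits, labels))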
99  finetune/mmseg/models/losses/kldiv_loss.py  Normal file
@@ -0,0 +1,99 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmseg.registry import MODELS


@MODELS.register_module()
class KLDivLoss(nn.Module):

    def __init__(self,
                 temperature: float = 1.0,
                 reduction: str = 'mean',
                 loss_name: str = 'loss_kld'):
        """Kullback-Leibler divergence Loss.

        <https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>

        Args:
            temperature (float, optional): Temperature param
            reduction (str, optional): The method to reduce the loss into a
                scalar. Default is "mean". Options are "none", "sum",
                and "mean"
        """

        assert isinstance(temperature, (float, int)), \
            'Expected temperature to be ' \
            f'float or int, but got {temperature.__class__.__name__} instead'
        assert temperature != 0., 'Temperature must not be zero'

        assert reduction in ['mean', 'none', 'sum'], \
            'Reduction must be one of the options ("mean", ' \
            f'"sum", "none"), but got {reduction}'

        super().__init__()
        self.temperature = temperature
        self.reduction = reduction
        self._loss_name = loss_name

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        """Forward function. Calculate KL divergence Loss.

        Args:
            input (Tensor): Logit tensor,
                the data type is float32 or float64.
                The shape is (N, C) where N is batchsize and C is number of
                channels.
                If there are more than 2 dimensions, shape is (N, C, D1, D2, ...
                Dk), k >= 1
            target (Tensor): Logit tensor,
                the data type is float32 or float64.
                input and target must have the same shape.

        Returns:
            (Tensor): Reduced loss.
        """
        assert isinstance(input, torch.Tensor), 'Expected input to ' \
            f'be Tensor, but got {input.__class__.__name__} instead'
        assert isinstance(target, torch.Tensor), 'Expected target to ' \
            f'be Tensor, but got {target.__class__.__name__} instead'

        assert input.shape == target.shape, 'Input and target ' \
            'must have same shape, ' \
            f'but got shapes {input.shape} and {target.shape}'

        input = F.softmax(input / self.temperature, dim=1)
        target = F.softmax(target / self.temperature, dim=1)

        loss = F.kl_div(input, target, reduction='none', log_target=False)
        loss = loss * self.temperature**2

        batch_size = input.shape[0]

        if self.reduction == 'sum':
            # Change view to calculate instance-wise sum
            loss = loss.view(batch_size, -1)
            return torch.sum(loss, dim=1)

        elif self.reduction == 'mean':
            # Change view to calculate instance-wise mean
            loss = loss.view(batch_size, -1)
            return torch.mean(loss, dim=1)

        return loss

    @property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.
        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
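Usage sketch (illustrative, not part of the commit): KLDivLoss is not re-exported in the `__init__.py` above, so it is imported from its module here; with reduction='mean' the result is one value per sample.

    import torch
    from mmseg.models.losses.kldiv_loss import KLDivLoss

    loss_fn = KLDivLoss(temperature=2.0, reduction='mean')
    student_logits = torch.randn(4, 19)
    teacher_logits = torch.randn(4, 19)
    print(loss_fn(student_logits, teacher_logits))   # shape (4,)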
323
finetune/mmseg/models/losses/lovasz_loss.py
Normal file
323
finetune/mmseg/models/losses/lovasz_loss.py
Normal file
@@ -0,0 +1,323 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
"""Modified from https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytor
|
||||
ch/lovasz_losses.py Lovasz-Softmax and Jaccard hinge loss in PyTorch Maxim
|
||||
Berman 2018 ESAT-PSI KU Leuven (MIT License)"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmengine.utils import is_list_of
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .utils import get_class_weight, weight_reduce_loss
|
||||
|
||||
|
||||
def lovasz_grad(gt_sorted):
|
||||
"""Computes gradient of the Lovasz extension w.r.t sorted errors.
|
||||
|
||||
See Alg. 1 in paper.
|
||||
"""
|
||||
p = len(gt_sorted)
|
||||
gts = gt_sorted.sum()
|
||||
intersection = gts - gt_sorted.float().cumsum(0)
|
||||
union = gts + (1 - gt_sorted).float().cumsum(0)
|
||||
jaccard = 1. - intersection / union
|
||||
if p > 1: # cover 1-pixel case
|
||||
jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
|
||||
return jaccard
|
||||
|
||||
|
||||
def flatten_binary_logits(logits, labels, ignore_index=None):
|
||||
"""Flattens predictions in the batch (binary case) Remove labels equal to
|
||||
'ignore_index'."""
|
||||
logits = logits.view(-1)
|
||||
labels = labels.view(-1)
|
||||
if ignore_index is None:
|
||||
return logits, labels
|
||||
valid = (labels != ignore_index)
|
||||
vlogits = logits[valid]
|
||||
vlabels = labels[valid]
|
||||
return vlogits, vlabels
|
||||
|
||||
|
||||
def flatten_probs(probs, labels, ignore_index=None):
|
||||
"""Flattens predictions in the batch."""
|
||||
if probs.dim() == 3:
|
||||
# assumes output of a sigmoid layer
|
||||
B, H, W = probs.size()
|
||||
probs = probs.view(B, 1, H, W)
|
||||
B, C, H, W = probs.size()
|
||||
probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, C) # B*H*W, C=P,C
|
||||
labels = labels.view(-1)
|
||||
if ignore_index is None:
|
||||
return probs, labels
|
||||
valid = (labels != ignore_index)
|
||||
vprobs = probs[valid.nonzero().squeeze()]
|
||||
vlabels = labels[valid]
|
||||
return vprobs, vlabels
|
||||
|
||||
|
||||
def lovasz_hinge_flat(logits, labels):
|
||||
"""Binary Lovasz hinge loss.
|
||||
|
||||
Args:
|
||||
logits (torch.Tensor): [P], logits at each prediction
|
||||
(between -infty and +infty).
|
||||
labels (torch.Tensor): [P], binary ground truth labels (0 or 1).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The calculated loss.
|
||||
"""
|
||||
if len(labels) == 0:
|
||||
# only void pixels, the gradients should be 0
|
||||
return logits.sum() * 0.
|
||||
signs = 2. * labels.float() - 1.
|
||||
errors = (1. - logits * signs)
|
||||
errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
|
||||
perm = perm.data
|
||||
gt_sorted = labels[perm]
|
||||
grad = lovasz_grad(gt_sorted)
|
||||
loss = torch.dot(F.relu(errors_sorted), grad)
|
||||
return loss
|
||||
|
||||
|
||||
def lovasz_hinge(logits,
|
||||
labels,
|
||||
classes='present',
|
||||
per_image=False,
|
||||
class_weight=None,
|
||||
reduction='mean',
|
||||
avg_factor=None,
|
||||
ignore_index=255):
|
||||
"""Binary Lovasz hinge loss.
|
||||
|
||||
Args:
|
||||
logits (torch.Tensor): [B, H, W], logits at each pixel
|
||||
(between -infty and +infty).
|
||||
labels (torch.Tensor): [B, H, W], binary ground truth masks (0 or 1).
|
||||
classes (str | list[int], optional): Placeholder, to be consistent with
|
||||
other loss. Default: None.
|
||||
per_image (bool, optional): If per_image is True, compute the loss per
|
||||
image instead of per batch. Default: False.
|
||||
class_weight (list[float], optional): Placeholder, to be consistent
|
||||
with other loss. Default: None.
|
||||
reduction (str, optional): The method used to reduce the loss. Options
|
||||
are "none", "mean" and "sum". This parameter only works when
|
||||
per_image is True. Default: 'mean'.
|
||||
avg_factor (int, optional): Average factor that is used to average
|
||||
the loss. This parameter only works when per_image is True.
|
||||
Default: None.
|
||||
ignore_index (int | None): The label index to be ignored. Default: 255.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The calculated loss.
|
||||
"""
|
||||
if per_image:
|
||||
loss = [
|
||||
lovasz_hinge_flat(*flatten_binary_logits(
|
||||
logit.unsqueeze(0), label.unsqueeze(0), ignore_index))
|
||||
for logit, label in zip(logits, labels)
|
||||
]
|
||||
loss = weight_reduce_loss(
|
||||
torch.stack(loss), None, reduction, avg_factor)
|
||||
else:
|
||||
loss = lovasz_hinge_flat(
|
||||
*flatten_binary_logits(logits, labels, ignore_index))
|
||||
return loss
|
||||
|
||||
|
||||
def lovasz_softmax_flat(probs, labels, classes='present', class_weight=None):
    """Multi-class Lovasz-Softmax loss.

    Args:
        probs (torch.Tensor): [P, C], class probabilities at each prediction
            (between 0 and 1).
        labels (torch.Tensor): [P], ground truth labels (between 0 and C - 1).
        classes (str | list[int], optional): Classes chosen to calculate loss.
            'all' for all classes, 'present' for classes present in labels, or
            a list of classes to average. Default: 'present'.
        class_weight (list[float], optional): The weight for each class.
            Default: None.

    Returns:
        torch.Tensor: The calculated loss.
    """
    if probs.numel() == 0:
        # only void pixels, the gradients should be 0
        return probs * 0.
    C = probs.size(1)
    losses = []
    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
    for c in class_to_sum:
        fg = (labels == c).float()  # foreground for class c
        if (classes == 'present' and fg.sum() == 0):
            continue
        if C == 1:
            if len(classes) > 1:
                raise ValueError('Sigmoid output possible only with 1 class')
            class_pred = probs[:, 0]
        else:
            class_pred = probs[:, c]
        errors = (fg - class_pred).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        loss = torch.dot(errors_sorted, lovasz_grad(fg_sorted))
        if class_weight is not None:
            loss *= class_weight[c]
        losses.append(loss)
    return torch.stack(losses).mean()


def lovasz_softmax(probs,
                   labels,
                   classes='present',
                   per_image=False,
                   class_weight=None,
                   reduction='mean',
                   avg_factor=None,
                   ignore_index=255):
    """Multi-class Lovasz-Softmax loss.

    Args:
        probs (torch.Tensor): [B, C, H, W], class probabilities at each
            prediction (between 0 and 1).
        labels (torch.Tensor): [B, H, W], ground truth labels (between 0 and
            C - 1).
        classes (str | list[int], optional): Classes chosen to calculate loss.
            'all' for all classes, 'present' for classes present in labels, or
            a list of classes to average. Default: 'present'.
        per_image (bool, optional): If per_image is True, compute the loss per
            image instead of per batch. Default: False.
        class_weight (list[float], optional): The weight for each class.
            Default: None.
        reduction (str, optional): The method used to reduce the loss. Options
            are "none", "mean" and "sum". This parameter only works when
            per_image is True. Default: 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. This parameter only works when per_image is True.
            Default: None.
        ignore_index (int | None): The label index to be ignored. Default: 255.

    Returns:
        torch.Tensor: The calculated loss.
    """

    if per_image:
        loss = [
            lovasz_softmax_flat(
                *flatten_probs(
                    prob.unsqueeze(0), label.unsqueeze(0), ignore_index),
                classes=classes,
                class_weight=class_weight)
            for prob, label in zip(probs, labels)
        ]
        loss = weight_reduce_loss(
            torch.stack(loss), None, reduction, avg_factor)
    else:
        loss = lovasz_softmax_flat(
            *flatten_probs(probs, labels, ignore_index),
            classes=classes,
            class_weight=class_weight)
    return loss


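
A minimal sketch of calling ``lovasz_softmax`` directly (illustrative only, not part of the committed file; module path assumed as above). The function expects per-pixel class probabilities, so logits are passed through softmax first:

# Illustrative sketch: multi-class Lovasz-Softmax on toy data.
import torch
import torch.nn.functional as F

from mmseg.models.losses.lovasz_loss import lovasz_softmax

logits = torch.randn(2, 4, 8, 8)             # (B, C, H, W) raw scores
labels = torch.randint(0, 4, (2, 8, 8))      # (B, H, W) class indices
labels[0, 0, 0] = 255                        # pixel skipped via ignore_index

probs = F.softmax(logits, dim=1)             # loss expects probabilities
loss = lovasz_softmax(probs, labels, classes='present', ignore_index=255)
print(loss)                                  # scalar tensor
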
@MODELS.register_module()
class LovaszLoss(nn.Module):
    """LovaszLoss.

    This loss is proposed in `The Lovasz-Softmax loss: A tractable surrogate
    for the optimization of the intersection-over-union measure in neural
    networks <https://arxiv.org/abs/1705.08790>`_.

    Args:
        loss_type (str, optional): Binary or multi-class loss.
            Default: 'multi_class'. Options are "binary" and "multi_class".
        classes (str | list[int], optional): Classes chosen to calculate loss.
            'all' for all classes, 'present' for classes present in labels, or
            a list of classes to average. Default: 'present'.
        per_image (bool, optional): If per_image is True, compute the loss per
            image instead of per batch. Default: False.
        reduction (str, optional): The method used to reduce the loss. Options
            are "none", "mean" and "sum". This parameter only works when
            per_image is True. Default: 'mean'.
        class_weight (list[float] | str, optional): Weight of each class. If in
            str format, read them from a file. Defaults to None.
        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
        loss_name (str, optional): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_lovasz'.
    """

    def __init__(self,
                 loss_type='multi_class',
                 classes='present',
                 per_image=False,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 loss_name='loss_lovasz'):
        super().__init__()
        assert loss_type in ('binary', 'multi_class'), "loss_type should be \
            'binary' or 'multi_class'."

        if loss_type == 'binary':
            self.cls_criterion = lovasz_hinge
        else:
            self.cls_criterion = lovasz_softmax
        assert classes in ('all', 'present') or is_list_of(classes, int)
        if not per_image:
            assert reduction == 'none', "reduction should be 'none' when \
                per_image is False."

        self.classes = classes
        self.per_image = per_image
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = get_class_weight(class_weight)
        self._loss_name = loss_name

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Forward function."""
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.class_weight is not None:
            class_weight = cls_score.new_tensor(self.class_weight)
        else:
            class_weight = None

        # if multi-class loss, transform logits to probs
        if self.cls_criterion == lovasz_softmax:
            cls_score = F.softmax(cls_score, dim=1)

        loss_cls = self.loss_weight * self.cls_criterion(
            cls_score,
            label,
            self.classes,
            self.per_image,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss_cls

    @property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
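
A minimal sketch of the registered module in use (illustrative only, not part of the committed file; assumes the package is importable as ``mmseg``). When ``per_image`` is False the constructor requires ``reduction='none'``, and raw logits are passed in because softmax is applied inside ``forward``:

# Illustrative sketch: LovaszLoss as an nn.Module.
import torch

from mmseg.models.losses import LovaszLoss

criterion = LovaszLoss(loss_type='multi_class', per_image=False,
                       reduction='none', loss_weight=1.0)
cls_score = torch.randn(2, 4, 8, 8)          # (B, C, H, W) raw logits
label = torch.randint(0, 4, (2, 8, 8))       # (B, H, W) class indices
loss = criterion(cls_score, label)           # scalar tensor

# Roughly equivalent config-style usage inside a decode head:
# loss_decode = dict(type='LovaszLoss', loss_type='multi_class',
#                    reduction='none', loss_weight=1.0)
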
94
finetune/mmseg/models/losses/ohem_cross_entropy_loss.py
Normal file
@@ -0,0 +1,94 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from mmseg.registry import MODELS


@MODELS.register_module()
class OhemCrossEntropy(nn.Module):
    """OhemCrossEntropy loss.

    This func is modified from
    `PIDNet <https://github.com/XuJiacong/PIDNet/blob/main/utils/criterion.py#L43>`_.  # noqa

    Licensed under the MIT License.

    Args:
        ignore_label (int): Labels to ignore when computing the loss.
            Default: 255
        thres (float, optional): The confidence threshold for hard example
            selection; predictions below it are treated as hard examples.
            If not specified, the hard examples will be pixels of top
            ``min_kept`` loss. Default: 0.7.
        min_kept (int, optional): The minimum number of predictions to keep.
            Default: 100000.
        loss_weight (float): Weight of the loss. Defaults to 1.0.
        class_weight (list[float] | str, optional): Weight of each class. If in
            str format, read them from a file. Defaults to None.
        loss_name (str): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_ohem'.
    """

    def __init__(self,
                 ignore_label: int = 255,
                 thres: float = 0.7,
                 min_kept: int = 100000,
                 loss_weight: float = 1.0,
                 class_weight: Optional[Union[List[float], str]] = None,
                 loss_name: str = 'loss_ohem'):
        super().__init__()
        self.thresh = thres
        self.min_kept = max(1, min_kept)
        self.ignore_label = ignore_label
        self.loss_weight = loss_weight
        self.loss_name_ = loss_name
        self.class_weight = class_weight

    def forward(self, score: Tensor, target: Tensor) -> Tensor:
        """Forward function.

        Args:
            score (Tensor): Predictions of the segmentation head.
            target (Tensor): Ground truth of the image.

        Returns:
            Tensor: Loss tensor.
        """
        # score: (N, C, H, W)
        pred = F.softmax(score, dim=1)
        if self.class_weight is not None:
            class_weight = score.new_tensor(self.class_weight)
        else:
            class_weight = None

        pixel_losses = F.cross_entropy(
            score,
            target,
            weight=class_weight,
            ignore_index=self.ignore_label,
            reduction='none').contiguous().view(-1)  # (N*H*W)
        mask = target.contiguous().view(-1) != self.ignore_label  # (N*H*W)

        tmp_target = target.clone()  # (N, H, W)
        tmp_target[tmp_target == self.ignore_label] = 0
        # gather the predicted probability of the target class
        # pred: (N, C, H, W) -> (N, 1, H, W)
        pred = pred.gather(1, tmp_target.unsqueeze(1))
        # flatten to (N*H*W), keep valid pixels and sort by confidence
        pred, ind = pred.contiguous().view(-1, )[mask].contiguous().sort()
        if pred.numel() > 0:
            min_value = pred[min(self.min_kept, pred.numel() - 1)]
        else:
            return score.new_tensor(0.0)
        threshold = max(min_value, self.thresh)

        pixel_losses = pixel_losses[mask][ind]
        pixel_losses = pixel_losses[pred < threshold]
        return self.loss_weight * pixel_losses.mean()

    @property
    def loss_name(self):
        return self.loss_name_
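
A minimal sketch of OHEM selection in use (illustrative only, not part of the committed file; assumes the package is importable as ``mmseg``). Only pixels whose target-class confidence falls below the threshold, or the ``min_kept`` hardest ones, contribute to the mean:

# Illustrative sketch: OHEM cross-entropy on toy predictions.
import torch

from mmseg.models.losses import OhemCrossEntropy

criterion = OhemCrossEntropy(thres=0.7, min_kept=10, ignore_label=255)
score = torch.randn(2, 4, 16, 16)            # (N, C, H, W) logits
target = torch.randint(0, 4, (2, 16, 16))    # (N, H, W) labels
target[:, 0, 0] = 255                        # ignored pixels
loss = criterion(score, target)              # scalar tensor
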
122
finetune/mmseg/models/losses/silog_loss.py
Normal file
@@ -0,0 +1,122 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union

import torch
import torch.nn as nn
from torch import Tensor

from mmseg.registry import MODELS
from .utils import weight_reduce_loss


def silog_loss(pred: Tensor,
               target: Tensor,
               weight: Optional[Tensor] = None,
               eps: float = 1e-4,
               reduction: Union[str, None] = 'mean',
               avg_factor: Optional[int] = None) -> Tensor:
    """Computes the Scale-Invariant Logarithmic (SI-Log) loss between
    prediction and target.

    Args:
        pred (Tensor): Predicted output.
        target (Tensor): Ground truth.
        weight (Optional[Tensor]): Optional weight to apply on the loss.
        eps (float): Epsilon value to avoid division and log(0).
        reduction (Union[str, None]): Specifies the reduction to apply to the
            output: 'mean', 'sum' or None.
        avg_factor (Optional[int]): Optional average factor for the loss.

    Returns:
        Tensor: The calculated SI-Log loss.
    """
    pred, target = pred.flatten(1), target.flatten(1)
    valid_mask = (target > eps).detach().float()

    diff_log = torch.log(target.clamp(min=eps)) - torch.log(
        pred.clamp(min=eps))

    valid_mask = (target > eps).detach() & (~torch.isnan(diff_log))
    diff_log[~valid_mask] = 0.0
    valid_mask = valid_mask.float()

    diff_log_sq_mean = (diff_log.pow(2) * valid_mask).sum(
        dim=1) / valid_mask.sum(dim=1).clamp(min=eps)
    diff_log_mean = (diff_log * valid_mask).sum(dim=1) / valid_mask.sum(
        dim=1).clamp(min=eps)

    loss = torch.sqrt(diff_log_sq_mean - 0.5 * diff_log_mean.pow(2))

    if weight is not None:
        weight = weight.float()

    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss


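
For intuition (illustrative sketch, not part of the committed file): with d = log(target) - log(pred) over valid pixels, the per-sample value computed above is sqrt(mean(d^2) - 0.5 * mean(d)^2), which can be checked by hand on a tiny tensor:

# Illustrative check of the SI-Log formula on one flattened sample.
import torch

pred = torch.tensor([[1.0, 2.0, 4.0]])
target = torch.tensor([[1.5, 2.0, 3.0]])

d = torch.log(target) - torch.log(pred)
manual = torch.sqrt((d ** 2).mean() - 0.5 * d.mean() ** 2)
# matches silog_loss(pred, target) up to the eps clamping above
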
@MODELS.register_module()
class SiLogLoss(nn.Module):
    """Compute SiLog loss.

    Args:
        reduction (str, optional): The method used
            to reduce the loss. Options are "none",
            "mean" and "sum". Defaults to 'mean'.
        loss_weight (float, optional): Weight of loss. Defaults to 1.0.
        eps (float): Avoid dividing by zero. Defaults to 1e-6.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must
            be the prefix of the name. Defaults to 'loss_silog'.
    """

    def __init__(self,
                 reduction='mean',
                 loss_weight=1.0,
                 eps=1e-6,
                 loss_name='loss_silog'):
        super().__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.eps = eps
        self._loss_name = loss_name

    def forward(
        self,
        pred,
        target,
        weight=None,
        avg_factor=None,
        reduction_override=None,
    ):

        assert pred.shape == target.shape, 'the shapes of pred ' \
            f'({pred.shape}) and target ({target.shape}) do not match'

        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        loss = self.loss_weight * silog_loss(
            pred,
            target,
            weight,
            eps=self.eps,
            reduction=reduction,
            avg_factor=avg_factor,
        )

        return loss

    @property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
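
A minimal sketch of the module in use on toy depth maps (illustrative only, not part of the committed file; assumes the package is importable as ``mmseg``). Both tensors must share the same shape and contain positive values:

# Illustrative sketch: SiLogLoss on toy depth maps.
import torch

from mmseg.models.losses import SiLogLoss

criterion = SiLogLoss(loss_weight=1.0)
pred = torch.rand(2, 1, 8, 8) * 10 + 0.1     # predicted depth, > 0
target = torch.rand(2, 1, 8, 8) * 10 + 0.1   # ground-truth depth, > 0
loss = criterion(pred, target)               # scalar tensor
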
137
finetune/mmseg/models/losses/tversky_loss.py
Normal file
@@ -0,0 +1,137 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from
https://github.com/JunMa11/SegLoss/blob/master/losses_pytorch/dice_loss.py#L333
(Apache-2.0 License)"""
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import get_class_weight, weighted_loss


@weighted_loss
def tversky_loss(pred,
                 target,
                 valid_mask,
                 alpha=0.3,
                 beta=0.7,
                 smooth=1,
                 class_weight=None,
                 ignore_index=255):
    assert pred.shape[0] == target.shape[0]
    total_loss = 0
    num_classes = pred.shape[1]
    for i in range(num_classes):
        if i != ignore_index:
            tversky_loss = binary_tversky_loss(
                pred[:, i],
                target[..., i],
                valid_mask=valid_mask,
                alpha=alpha,
                beta=beta,
                smooth=smooth)
            if class_weight is not None:
                tversky_loss *= class_weight[i]
            total_loss += tversky_loss
    return total_loss / num_classes


@weighted_loss
def binary_tversky_loss(pred,
                        target,
                        valid_mask,
                        alpha=0.3,
                        beta=0.7,
                        smooth=1):
    assert pred.shape[0] == target.shape[0]
    pred = pred.reshape(pred.shape[0], -1)
    target = target.reshape(target.shape[0], -1)
    valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)

    TP = torch.sum(torch.mul(pred, target) * valid_mask, dim=1)
    FP = torch.sum(torch.mul(pred, 1 - target) * valid_mask, dim=1)
    FN = torch.sum(torch.mul(1 - pred, target) * valid_mask, dim=1)
    tversky = (TP + smooth) / (TP + alpha * FP + beta * FN + smooth)

    return 1 - tversky


@LOSSES.register_module()
class TverskyLoss(nn.Module):
    """TverskyLoss. This loss is proposed in `Tversky loss function for image
    segmentation using 3D fully convolutional deep networks.
    <https://arxiv.org/abs/1706.05721>`_.

    Args:
        smooth (float): A float number to smooth loss, and avoid NaN error.
            Default: 1.
        class_weight (list[float] | str, optional): Weight of each class. If in
            str format, read them from a file. Defaults to None.
        loss_weight (float, optional): Weight of the loss. Default to 1.0.
        ignore_index (int | None): The label index to be ignored. Default: 255.
        alpha (float, in [0, 1]):
            The coefficient of false positives. Default: 0.3.
        beta (float, in [0, 1]):
            The coefficient of false negatives. Default: 0.7.
            Note: alpha + beta = 1.
        loss_name (str, optional): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_tversky'.
    """

    def __init__(self,
                 smooth=1,
                 class_weight=None,
                 loss_weight=1.0,
                 ignore_index=255,
                 alpha=0.3,
                 beta=0.7,
                 loss_name='loss_tversky'):
        super().__init__()
        self.smooth = smooth
        self.class_weight = get_class_weight(class_weight)
        self.loss_weight = loss_weight
        self.ignore_index = ignore_index
        assert (alpha + beta == 1.0), 'Sum of alpha and beta must be 1.0!'
        self.alpha = alpha
        self.beta = beta
        self._loss_name = loss_name

    def forward(self, pred, target, **kwargs):
        if self.class_weight is not None:
            class_weight = pred.new_tensor(self.class_weight)
        else:
            class_weight = None

        pred = F.softmax(pred, dim=1)
        num_classes = pred.shape[1]
        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_index).long()

        loss = self.loss_weight * tversky_loss(
            pred,
            one_hot_target,
            valid_mask=valid_mask,
            alpha=self.alpha,
            beta=self.beta,
            smooth=self.smooth,
            class_weight=class_weight,
            ignore_index=self.ignore_index)
        return loss

    @property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
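
A minimal sketch of the module in use (illustrative only, not part of the committed file; assumes the package is importable as ``mmseg``). Raw logits are passed in because softmax is applied inside ``forward``; note that with alpha = beta = 0.5 the Tversky index reduces to the Dice coefficient:

# Illustrative sketch: TverskyLoss on toy logits.
import torch

from mmseg.models.losses import TverskyLoss

criterion = TverskyLoss(alpha=0.3, beta=0.7, smooth=1, ignore_index=255)
pred = torch.randn(2, 4, 8, 8)               # (N, C, H, W) raw logits
target = torch.randint(0, 4, (2, 8, 8))      # (N, H, W) class indices
loss = criterion(pred, target)               # scalar tensor
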
129
finetune/mmseg/models/losses/utils.py
Normal file
@@ -0,0 +1,129 @@
# Copyright (c) OpenMMLab. All rights reserved.
import functools

import numpy as np
import torch
import torch.nn.functional as F
from mmengine.fileio import load


def get_class_weight(class_weight):
    """Get class weight for loss function.

    Args:
        class_weight (list[float] | str | None): If class_weight is a str,
            take it as a file name and read from it.
    """
    if isinstance(class_weight, str):
        # take it as a file path
        if class_weight.endswith('.npy'):
            class_weight = np.load(class_weight)
        else:
            # pkl, json or yaml
            class_weight = load(class_weight)

    return class_weight


def reduce_loss(loss, reduction) -> torch.Tensor:
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are "none", "mean" and "sum".

    Return:
        Tensor: Reduced loss tensor.
    """
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, elementwise_mean: 1, sum: 2
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    elif reduction_enum == 2:
        return loss.sum()


def weight_reduce_loss(loss,
                       weight=None,
                       reduction='mean',
                       avg_factor=None) -> torch.Tensor:
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights.
        reduction (str): Same as built-in losses of PyTorch.
        avg_factor (float): Average factor when computing the mean of losses.

    Returns:
        Tensor: Processed loss values.
    """
    # if weight is specified, apply element-wise weight
    if weight is not None:
        assert weight.dim() == loss.dim()
        if weight.dim() > 1:
            assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
        loss = loss * weight

    # if avg_factor is not specified, just reduce the loss
    if avg_factor is None:
        loss = reduce_loss(loss, reduction)
    else:
        # if reduction is mean, then average the loss by avg_factor
        if reduction == 'mean':
            # Avoid causing ZeroDivisionError when avg_factor is 0.0,
            # i.e., all labels of an image belong to ignore index.
            eps = torch.finfo(torch.float32).eps
            loss = loss.sum() / (avg_factor + eps)
        # if reduction is 'none', then do nothing, otherwise raise an error
        elif reduction != 'none':
            raise ValueError('avg_factor can not be used with reduction="sum"')
    return loss


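
A minimal sketch of how ``weight_reduce_loss`` is typically driven (illustrative only, not part of the committed file): element-wise losses are masked by a weight and averaged over the number of valid elements via ``avg_factor``:

# Illustrative: average a per-pixel loss over valid pixels only.
import torch

from mmseg.models.losses import weight_reduce_loss

pixel_loss = torch.tensor([0.5, 1.5, 2.0, 0.0])   # element-wise loss
weight = torch.tensor([1.0, 1.0, 1.0, 0.0])       # last element ignored
num_valid = weight.sum()                          # avg_factor = 3
loss = weight_reduce_loss(pixel_loss, weight, reduction='mean',
                          avg_factor=num_valid)
# (0.5 + 1.5 + 2.0) / (3 + eps) ~= 1.3333
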
def weighted_loss(loss_func):
    """Create a weighted version of a given loss function.

    To use this decorator, the loss function must have the signature like
    `loss_func(pred, target, **kwargs)`. The function only needs to compute
    element-wise loss without any reduction. This decorator will add weight
    and reduction arguments to the function. The decorated function will have
    the signature like `loss_func(pred, target, weight=None, reduction='mean',
    avg_factor=None, **kwargs)`.

    :Example:

    >>> import torch
    >>> @weighted_loss
    >>> def l1_loss(pred, target):
    >>>     return (pred - target).abs()

    >>> pred = torch.Tensor([0, 2, 3])
    >>> target = torch.Tensor([1, 1, 1])
    >>> weight = torch.Tensor([1, 0, 1])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.)
    >>> l1_loss(pred, target, reduction='none')
    tensor([1., 1., 2.])
    >>> l1_loss(pred, target, weight, avg_factor=2)
    tensor(1.5000)
    """

    @functools.wraps(loss_func)
    def wrapper(pred,
                target,
                weight=None,
                reduction='mean',
                avg_factor=None,
                **kwargs):
        # get element-wise loss
        loss = loss_func(pred, target, **kwargs)
        loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
        return loss

    return wrapper