This commit is contained in:
esenke
2025-12-08 22:16:31 +08:00
commit 01adcfdf60
305 changed files with 50879 additions and 0 deletions

View File

@@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_assigner import BaseAssigner
from .hungarian_assigner import HungarianAssigner
from .match_cost import ClassificationCost, CrossEntropyLossCost, DiceCost
__all__ = [
'BaseAssigner',
'HungarianAssigner',
'ClassificationCost',
'CrossEntropyLossCost',
'DiceCost',
]

View File

@@ -0,0 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Optional
from mmengine.structures import InstanceData
class BaseAssigner(metaclass=ABCMeta):
"""Base assigner that assigns masks to ground truth class labels."""
@abstractmethod
def assign(self,
pred_instances: InstanceData,
gt_instances: InstanceData,
gt_instances_ignore: Optional[InstanceData] = None,
**kwargs):
"""Assign masks to either a ground truth class label or a negative
label."""

View File

@@ -0,0 +1,86 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Union
import torch
from mmengine import ConfigDict
from mmengine.structures import InstanceData
from scipy.optimize import linear_sum_assignment
from torch.cuda.amp import autocast
from mmseg.registry import TASK_UTILS
from .base_assigner import BaseAssigner
@TASK_UTILS.register_module()
class HungarianAssigner(BaseAssigner):
"""Computes one-to-one matching between prediction masks and ground truth.
This class uses bipartite matching-based assignment to computes an
assignment between the prediction masks and the ground truth. The
assignment result is based on the weighted sum of match costs. The
Hungarian algorithm is used to calculate the best matching with the
minimum cost. The prediction masks that are not matched are classified
as background.
Args:
match_costs (ConfigDict|List[ConfigDict]): Match cost configs.
"""
def __init__(
self, match_costs: Union[List[Union[dict, ConfigDict]], dict,
ConfigDict]
) -> None:
if isinstance(match_costs, dict):
match_costs = [match_costs]
elif isinstance(match_costs, list):
assert len(match_costs) > 0, \
'match_costs must not be a empty list.'
self.match_costs = [
TASK_UTILS.build(match_cost) for match_cost in match_costs
]
def assign(self, pred_instances: InstanceData, gt_instances: InstanceData,
**kwargs):
"""Computes one-to-one matching based on the weighted costs.
This method assign each query prediction to a ground truth or
background. The assignment first calculates the cost for each
category assigned to each query mask, and then uses the
Hungarian algorithm to calculate the minimum cost as the best
match.
Args:
pred_instances (InstanceData): Instances of model
predictions. It includes "masks", with shape
(n, h, w) or (n, l), and "cls", with shape (n, num_classes+1)
gt_instances (InstanceData): Ground truth of instance
annotations. It includes "labels", with shape (k, ),
and "masks", with shape (k, h, w) or (k, l).
Returns:
matched_quiery_inds (Tensor): The indexes of matched quieres.
matched_label_inds (Tensor): The indexes of matched labels.
"""
# compute weighted cost
cost_list = []
with autocast(enabled=False):
for match_cost in self.match_costs:
cost = match_cost(
pred_instances=pred_instances, gt_instances=gt_instances)
cost_list.append(cost)
cost = torch.stack(cost_list).sum(dim=0)
device = cost.device
# do Hungarian matching on CPU using linear_sum_assignment
cost = cost.detach().cpu()
if linear_sum_assignment is None:
raise ImportError('Please run "pip install scipy" '
'to install scipy first.')
matched_quiery_inds, matched_label_inds = linear_sum_assignment(cost)
matched_quiery_inds = torch.from_numpy(matched_quiery_inds).to(device)
matched_label_inds = torch.from_numpy(matched_label_inds).to(device)
return matched_quiery_inds, matched_label_inds

View File

@@ -0,0 +1,231 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from typing import Union
import torch
import torch.nn.functional as F
from mmengine.structures import InstanceData
from torch import Tensor
from mmseg.registry import TASK_UTILS
class BaseMatchCost:
"""Base match cost class.
Args:
weight (Union[float, int]): Cost weight. Defaults to 1.
"""
def __init__(self, weight: Union[float, int] = 1.) -> None:
self.weight = weight
@abstractmethod
def __call__(self, pred_instances: InstanceData,
gt_instances: InstanceData, **kwargs) -> Tensor:
"""Compute match cost.
Args:
pred_instances (InstanceData): Instances of model predictions.
It often includes "labels" and "scores".
gt_instances (InstanceData): Ground truth of instance
annotations. It usually includes "labels".
Returns:
Tensor: Match Cost matrix of shape (num_preds, num_gts).
"""
pass
@TASK_UTILS.register_module()
class ClassificationCost(BaseMatchCost):
"""ClsSoftmaxCost.
Args:
weight (Union[float, int]): Cost weight. Defaults to 1.
Examples:
>>> from mmseg.models.assigners import ClassificationCost
>>> import torch
>>> self = ClassificationCost()
>>> cls_pred = torch.rand(4, 3)
>>> gt_labels = torch.tensor([0, 1, 2])
>>> factor = torch.tensor([10, 8, 10, 8])
>>> self(cls_pred, gt_labels)
tensor([[-0.3430, -0.3525, -0.3045],
[-0.3077, -0.2931, -0.3992],
[-0.3664, -0.3455, -0.2881],
[-0.3343, -0.2701, -0.3956]])
"""
def __init__(self, weight: Union[float, int] = 1) -> None:
super().__init__(weight=weight)
def __call__(self, pred_instances: InstanceData,
gt_instances: InstanceData, **kwargs) -> Tensor:
"""Compute match cost.
Args:
pred_instances (InstanceData): "scores" inside is
predicted classification logits, of shape
(num_queries, num_class).
gt_instances (InstanceData): "labels" inside should have
shape (num_gt, ).
Returns:
Tensor: Match Cost matrix of shape (num_preds, num_gts).
"""
assert hasattr(pred_instances, 'scores'), \
"pred_instances must contain 'scores'"
assert hasattr(gt_instances, 'labels'), \
"gt_instances must contain 'labels'"
pred_scores = pred_instances.scores
gt_labels = gt_instances.labels
pred_scores = pred_scores.softmax(-1)
cls_cost = -pred_scores[:, gt_labels]
return cls_cost * self.weight
@TASK_UTILS.register_module()
class DiceCost(BaseMatchCost):
"""Cost of mask assignments based on dice losses.
Args:
pred_act (bool): Whether to apply sigmoid to mask_pred.
Defaults to False.
eps (float): Defaults to 1e-3.
naive_dice (bool): If True, use the naive dice loss
in which the power of the number in the denominator is
the first power. If False, use the second power that
is adopted by K-Net and SOLO. Defaults to True.
weight (Union[float, int]): Cost weight. Defaults to 1.
"""
def __init__(self,
pred_act: bool = False,
eps: float = 1e-3,
naive_dice: bool = True,
weight: Union[float, int] = 1.) -> None:
super().__init__(weight=weight)
self.pred_act = pred_act
self.eps = eps
self.naive_dice = naive_dice
def _binary_mask_dice_loss(self, mask_preds: Tensor,
gt_masks: Tensor) -> Tensor:
"""
Args:
mask_preds (Tensor): Mask prediction in shape (num_queries, *).
gt_masks (Tensor): Ground truth in shape (num_gt, *)
store 0 or 1, 0 for negative class and 1 for
positive class.
Returns:
Tensor: Dice cost matrix in shape (num_queries, num_gt).
"""
mask_preds = mask_preds.flatten(1)
gt_masks = gt_masks.flatten(1).float()
numerator = 2 * torch.einsum('nc,mc->nm', mask_preds, gt_masks)
if self.naive_dice:
denominator = mask_preds.sum(-1)[:, None] + \
gt_masks.sum(-1)[None, :]
else:
denominator = mask_preds.pow(2).sum(1)[:, None] + \
gt_masks.pow(2).sum(1)[None, :]
loss = 1 - (numerator + self.eps) / (denominator + self.eps)
return loss
def __call__(self, pred_instances: InstanceData,
gt_instances: InstanceData, **kwargs) -> Tensor:
"""Compute match cost.
Args:
pred_instances (InstanceData): Predicted instances which
must contain "masks".
gt_instances (InstanceData): Ground truth which must contain
"mask".
Returns:
Tensor: Match Cost matrix of shape (num_preds, num_gts).
"""
assert hasattr(pred_instances, 'masks'), \
"pred_instances must contain 'masks'"
assert hasattr(gt_instances, 'masks'), \
"gt_instances must contain 'masks'"
pred_masks = pred_instances.masks
gt_masks = gt_instances.masks
if self.pred_act:
pred_masks = pred_masks.sigmoid()
dice_cost = self._binary_mask_dice_loss(pred_masks, gt_masks)
return dice_cost * self.weight
@TASK_UTILS.register_module()
class CrossEntropyLossCost(BaseMatchCost):
"""CrossEntropyLossCost.
Args:
use_sigmoid (bool): Whether the prediction uses sigmoid
of softmax. Defaults to True.
weight (Union[float, int]): Cost weight. Defaults to 1.
"""
def __init__(self,
use_sigmoid: bool = True,
weight: Union[float, int] = 1.) -> None:
super().__init__(weight=weight)
self.use_sigmoid = use_sigmoid
def _binary_cross_entropy(self, cls_pred: Tensor,
gt_labels: Tensor) -> Tensor:
"""
Args:
cls_pred (Tensor): The prediction with shape (num_queries, 1, *) or
(num_queries, *).
gt_labels (Tensor): The learning label of prediction with
shape (num_gt, *).
Returns:
Tensor: Cross entropy cost matrix in shape (num_queries, num_gt).
"""
cls_pred = cls_pred.flatten(1).float()
gt_labels = gt_labels.flatten(1).float()
n = cls_pred.shape[1]
pos = F.binary_cross_entropy_with_logits(
cls_pred, torch.ones_like(cls_pred), reduction='none')
neg = F.binary_cross_entropy_with_logits(
cls_pred, torch.zeros_like(cls_pred), reduction='none')
cls_cost = torch.einsum('nc,mc->nm', pos, gt_labels) + \
torch.einsum('nc,mc->nm', neg, 1 - gt_labels)
cls_cost = cls_cost / n
return cls_cost
def __call__(self, pred_instances: InstanceData,
gt_instances: InstanceData, **kwargs) -> Tensor:
"""Compute match cost.
Args:
pred_instances (:obj:`InstanceData`): Predicted instances which
must contain ``masks``.
gt_instances (:obj:`InstanceData`): Ground truth which must contain
``masks``.
Returns:
Tensor: Match Cost matrix of shape (num_preds, num_gts).
"""
assert hasattr(pred_instances, 'masks'), \
"pred_instances must contain 'masks'"
assert hasattr(gt_instances, 'masks'), \
"gt_instances must contain 'masks'"
pred_masks = pred_instances.masks
gt_masks = gt_instances.masks
if self.use_sigmoid:
cls_cost = self._binary_cross_entropy(pred_masks, gt_masks)
else:
raise NotImplementedError
return cls_cost * self.weight