init
This commit is contained in:
205
finetune/mmseg/utils/mask_classification.py
Normal file
205
finetune/mmseg/utils/mask_classification.py
Normal file
@@ -0,0 +1,205 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from mmcv.ops import point_sample
|
||||
from mmengine.structures import InstanceData
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import TASK_UTILS
|
||||
from mmseg.utils import ConfigType, SampleList
|
||||
|
||||
|
||||
def seg_data_to_instance_data(ignore_index: int,
                              batch_data_samples: SampleList):
    """Convert semantic-segmentation ground truth into the instance-style
    paradigm expected by mask-classification heads.

    Args:
        ignore_index (int): Label value excluded from the ground truth.
        batch_data_samples (List[SegDataSample]): Data samples that carry
            ``gt_sem_seg``.

    Returns:
        List[InstanceData]: One entry per image. Each holds ``labels`` with
        shape (num_gt, ) — the unique class ids present in that image — and
        binary ``masks`` with shape (num_gt, h, w).
    """
    batch_gt_instances = []

    for sample in batch_data_samples:
        sem_seg = sample.gt_sem_seg.data
        present = torch.unique(
            sem_seg, sorted=False, return_inverse=False, return_counts=False)
        # Drop the ignored label before building per-class masks.
        labels = present[present != ignore_index]

        if labels.numel() == 0:
            # No valid classes in this image: emit an empty (0, h, w) mask
            # tensor on the same device as the ground truth.
            height, width = sem_seg.shape[-2], sem_seg.shape[-1]
            masks = sem_seg.new_zeros((0, height, width)).long()
        else:
            # Each comparison yields a (1, h, w) bool map; squeeze the
            # leading channel dim after stacking.
            masks = torch.stack(
                [sem_seg == class_id for class_id in labels]).squeeze(1).long()

        batch_gt_instances.append(InstanceData(labels=labels, masks=masks))
    return batch_gt_instances
|
||||
|
||||
|
||||
class MatchMasks:
    """Match the predictions to category labels.

    Args:
        num_points (int): The number of sampled points used to compute the
            matching cost.
        num_queries (int): The number of prediction masks (queries).
        num_classes (int): The number of classes; also used as the
            background label assigned to unmatched queries.
        assigner (BaseAssigner): Config of the assigner that computes the
            matching. Must not be None.
    """

    def __init__(self,
                 num_points: int,
                 num_queries: int,
                 num_classes: int,
                 assigner: ConfigType = None):
        # Fix: the original message concatenated without a space, producing
        # "...train_cfgcannot be None".
        assert assigner is not None, \
            "'assigner' in decode_head.train_cfg cannot be None"
        assert num_points > 0, 'num_points should be a positive integer.'
        self.num_points = num_points
        self.num_queries = num_queries
        self.num_classes = num_classes
        self.assigner = TASK_UTILS.build(assigner)

    def get_targets(self, cls_scores: Tensor, mask_preds: Tensor,
                    batch_gt_instances: List[InstanceData]) -> Tuple:
        """Compute best mask matches for all images for a decoder layer.

        Note: the original annotations declared ``List[Tensor]``, but
        ``cls_scores.shape[0]`` below requires batched (stacked) tensors;
        the annotations were corrected accordingly (runtime-neutral).

        Args:
            cls_scores (Tensor): Mask score logits from a single decoder
                layer for all images, with shape (batch_size, num_queries,
                cls_out_channels).
            mask_preds (Tensor): Mask logits from a single decoder layer
                for all images, with shape (batch_size, num_queries, h, w).
            batch_gt_instances (List[InstanceData]): Each contains
                ``labels`` and ``masks``.

        Returns:
            tuple: A tuple containing the following targets.

                - labels (Tensor): Labels of all images,\
                    shape (batch_size, num_queries).
                - mask_targets (Tensor): Mask targets of all images\
                    concatenated along dim 0, shape (sum_num_gts, h, w).
                - mask_weights (Tensor): Mask weights of all images,\
                    shape (batch_size, num_queries).
                - avg_factor (int): Average factor that is used to
                    average the loss. `avg_factor` is usually equal
                    to the number of positive priors.
        """
        batch_size = cls_scores.shape[0]
        results = dict({
            'labels': [],
            'mask_targets': [],
            'mask_weights': [],
        })
        for i in range(batch_size):
            labels, mask_targets, mask_weights\
                = self._get_targets_single(cls_scores[i],
                                           mask_preds[i],
                                           batch_gt_instances[i])
            results['labels'].append(labels)
            results['mask_targets'].append(mask_targets)
            results['mask_weights'].append(mask_weights)

        # shape (batch_size, num_queries)
        labels = torch.stack(results['labels'], dim=0)
        # Fix: cat along dim 0 of per-image (num_gts_i, h, w) tensors yields
        # (sum_num_gts, h, w) — the original comment claimed
        # (batch_size, num_gts, h, w), which is wrong.
        mask_targets = torch.cat(results['mask_targets'], dim=0)
        # shape (batch_size, num_queries)
        mask_weights = torch.stack(results['mask_weights'], dim=0)

        # Total number of ground-truth instances across the batch, used to
        # normalize the loss.
        avg_factor = sum(
            [len(gt_instances.labels) for gt_instances in batch_gt_instances])

        res = (labels, mask_targets, mask_weights, avg_factor)

        return res

    def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor,
                            gt_instances: InstanceData) \
            -> Tuple[Tensor, Tensor, Tensor]:
        """Compute a set of best mask matches for one image.

        Args:
            cls_score (Tensor): Mask score logits from a single decoder layer
                for one image. Shape (num_queries, cls_out_channels).
            mask_pred (Tensor): Mask logits for a single decoder layer for one
                image. Shape (num_queries, h, w).
            gt_instances (:obj:`InstanceData`): It contains ``labels`` and
                ``masks``.

        Returns:
            tuple[Tensor]: A tuple containing the following for one image.

                - labels (Tensor): Labels of each image. \
                    shape (num_queries, ).
                - mask_targets (Tensor): Mask targets of each image. \
                    shape (num_gts, h, w).
                - mask_weights (Tensor): Mask weights of each image. \
                    shape (num_queries, ).
        """
        gt_labels = gt_instances.labels
        gt_masks = gt_instances.masks
        # When "gt_labels" is empty, classify all queries to background.
        if len(gt_labels) == 0:
            labels = gt_labels.new_full((self.num_queries, ),
                                        self.num_classes,
                                        dtype=torch.long)
            mask_targets = gt_labels
            mask_weights = gt_labels.new_zeros((self.num_queries, ))
            return labels, mask_targets, mask_weights
        # Sample the same random point set for predictions and ground truth
        # so the matching cost is computed on comparable values.
        num_queries = cls_score.shape[0]
        num_gts = gt_labels.shape[0]

        point_coords = torch.rand((1, self.num_points, 2),
                                  device=cls_score.device)
        # shape (num_queries, num_points)
        mask_points_pred = point_sample(
            mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1,
                                                        1)).squeeze(1)
        # shape (num_gts, num_points)
        gt_points_masks = point_sample(
            gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1,
                                                               1)).squeeze(1)

        sampled_gt_instances = InstanceData(
            labels=gt_labels, masks=gt_points_masks)
        sampled_pred_instances = InstanceData(
            scores=cls_score, masks=mask_points_pred)
        # Assign each ground-truth instance to its best-matching query.
        # Fix: renamed the misspelled local `matched_quiery_inds`.
        matched_query_inds, matched_label_inds = self.assigner.assign(
            pred_instances=sampled_pred_instances,
            gt_instances=sampled_gt_instances)
        # Unmatched queries get the background label (num_classes).
        labels = gt_labels.new_full((self.num_queries, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[matched_query_inds] = gt_labels[matched_label_inds]

        mask_weights = gt_labels.new_zeros((self.num_queries, ))
        mask_weights[matched_query_inds] = 1
        mask_targets = gt_masks[matched_label_inds]

        return labels, mask_targets, mask_weights
|
||||
Reference in New Issue
Block a user