init
This commit is contained in:
8
finetune/mmseg/structures/__init__.py
Normal file
8
finetune/mmseg/structures/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .sampler import BasePixelSampler, OHEMPixelSampler, build_pixel_sampler
|
||||
from .seg_data_sample import SegDataSample
|
||||
|
||||
__all__ = [
|
||||
'SegDataSample', 'BasePixelSampler', 'OHEMPixelSampler',
|
||||
'build_pixel_sampler'
|
||||
]
|
||||
6
finetune/mmseg/structures/sampler/__init__.py
Normal file
6
finetune/mmseg/structures/sampler/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .base_pixel_sampler import BasePixelSampler
|
||||
from .builder import build_pixel_sampler
|
||||
from .ohem_pixel_sampler import OHEMPixelSampler
|
||||
|
||||
__all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler']
|
||||
13
finetune/mmseg/structures/sampler/base_pixel_sampler.py
Normal file
13
finetune/mmseg/structures/sampler/base_pixel_sampler.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
|
||||
class BasePixelSampler(metaclass=ABCMeta):
|
||||
"""Base class of pixel sampler."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def sample(self, seg_logit, seg_label):
|
||||
"""Placeholder for sample function."""
|
||||
14
finetune/mmseg/structures/sampler/builder.py
Normal file
14
finetune/mmseg/structures/sampler/builder.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
from mmseg.registry import TASK_UTILS
|
||||
|
||||
PIXEL_SAMPLERS = TASK_UTILS
|
||||
|
||||
|
||||
def build_pixel_sampler(cfg, **default_args):
|
||||
"""Build pixel sampler for segmentation map."""
|
||||
warnings.warn(
|
||||
'``build_pixel_sampler`` would be deprecated soon, please use '
|
||||
'``mmseg.registry.TASK_UTILS.build()`` ')
|
||||
return TASK_UTILS.build(cfg, default_args=default_args)
|
||||
85
finetune/mmseg/structures/sampler/ohem_pixel_sampler.py
Normal file
85
finetune/mmseg/structures/sampler/ohem_pixel_sampler.py
Normal file
@@ -0,0 +1,85 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .base_pixel_sampler import BasePixelSampler
|
||||
from .builder import PIXEL_SAMPLERS
|
||||
|
||||
|
||||
@PIXEL_SAMPLERS.register_module()
|
||||
class OHEMPixelSampler(BasePixelSampler):
|
||||
"""Online Hard Example Mining Sampler for segmentation.
|
||||
|
||||
Args:
|
||||
context (nn.Module): The context of sampler, subclass of
|
||||
:obj:`BaseDecodeHead`.
|
||||
thresh (float, optional): The threshold for hard example selection.
|
||||
Below which, are prediction with low confidence. If not
|
||||
specified, the hard examples will be pixels of top ``min_kept``
|
||||
loss. Default: None.
|
||||
min_kept (int, optional): The minimum number of predictions to keep.
|
||||
Default: 100000.
|
||||
"""
|
||||
|
||||
def __init__(self, context, thresh=None, min_kept=100000):
|
||||
super().__init__()
|
||||
self.context = context
|
||||
assert min_kept > 1
|
||||
self.thresh = thresh
|
||||
self.min_kept = min_kept
|
||||
|
||||
def sample(self, seg_logit, seg_label):
|
||||
"""Sample pixels that have high loss or with low prediction confidence.
|
||||
|
||||
Args:
|
||||
seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W)
|
||||
seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W)
|
||||
|
||||
Returns:
|
||||
torch.Tensor: segmentation weight, shape (N, H, W)
|
||||
"""
|
||||
with torch.no_grad():
|
||||
assert seg_logit.shape[2:] == seg_label.shape[2:]
|
||||
assert seg_label.shape[1] == 1
|
||||
seg_label = seg_label.squeeze(1).long()
|
||||
batch_kept = self.min_kept * seg_label.size(0)
|
||||
valid_mask = seg_label != self.context.ignore_index
|
||||
seg_weight = seg_logit.new_zeros(size=seg_label.size())
|
||||
valid_seg_weight = seg_weight[valid_mask]
|
||||
if self.thresh is not None:
|
||||
seg_prob = F.softmax(seg_logit, dim=1)
|
||||
|
||||
tmp_seg_label = seg_label.clone().unsqueeze(1)
|
||||
tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0
|
||||
seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1)
|
||||
sort_prob, sort_indices = seg_prob[valid_mask].sort()
|
||||
|
||||
if sort_prob.numel() > 0:
|
||||
min_threshold = sort_prob[min(batch_kept,
|
||||
sort_prob.numel() - 1)]
|
||||
else:
|
||||
min_threshold = 0.0
|
||||
threshold = max(min_threshold, self.thresh)
|
||||
valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
|
||||
else:
|
||||
if not isinstance(self.context.loss_decode, nn.ModuleList):
|
||||
losses_decode = [self.context.loss_decode]
|
||||
else:
|
||||
losses_decode = self.context.loss_decode
|
||||
losses = 0.0
|
||||
for loss_module in losses_decode:
|
||||
losses += loss_module(
|
||||
seg_logit,
|
||||
seg_label,
|
||||
weight=None,
|
||||
ignore_index=self.context.ignore_index,
|
||||
reduction_override='none')
|
||||
|
||||
# faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa
|
||||
_, sort_indices = losses[valid_mask].sort(descending=True)
|
||||
valid_seg_weight[sort_indices[:batch_kept]] = 1.
|
||||
|
||||
seg_weight[valid_mask] = valid_seg_weight
|
||||
|
||||
return seg_weight
|
||||
92
finetune/mmseg/structures/seg_data_sample.py
Normal file
92
finetune/mmseg/structures/seg_data_sample.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmengine.structures import BaseDataElement, PixelData
|
||||
|
||||
|
||||
class SegDataSample(BaseDataElement):
|
||||
"""A data structure interface of MMSegmentation. They are used as
|
||||
interfaces between different components.
|
||||
|
||||
The attributes in ``SegDataSample`` are divided into several parts:
|
||||
|
||||
- ``gt_sem_seg``(PixelData): Ground truth of semantic segmentation.
|
||||
- ``pred_sem_seg``(PixelData): Prediction of semantic segmentation.
|
||||
- ``seg_logits``(PixelData): Predicted logits of semantic segmentation.
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> import numpy as np
|
||||
>>> from mmengine.structures import PixelData
|
||||
>>> from mmseg.structures import SegDataSample
|
||||
|
||||
>>> data_sample = SegDataSample()
|
||||
>>> img_meta = dict(img_shape=(4, 4, 3),
|
||||
... pad_shape=(4, 4, 3))
|
||||
>>> gt_segmentations = PixelData(metainfo=img_meta)
|
||||
>>> gt_segmentations.data = torch.randint(0, 2, (1, 4, 4))
|
||||
>>> data_sample.gt_sem_seg = gt_segmentations
|
||||
>>> assert 'img_shape' in data_sample.gt_sem_seg.metainfo_keys()
|
||||
>>> data_sample.gt_sem_seg.shape
|
||||
(4, 4)
|
||||
>>> print(data_sample)
|
||||
<SegDataSample(
|
||||
|
||||
META INFORMATION
|
||||
|
||||
DATA FIELDS
|
||||
gt_sem_seg: <PixelData(
|
||||
|
||||
META INFORMATION
|
||||
img_shape: (4, 4, 3)
|
||||
pad_shape: (4, 4, 3)
|
||||
|
||||
DATA FIELDS
|
||||
data: tensor([[[1, 1, 1, 0],
|
||||
[1, 0, 1, 1],
|
||||
[1, 1, 1, 1],
|
||||
[0, 1, 0, 1]]])
|
||||
) at 0x1c2b4156460>
|
||||
) at 0x1c2aae44d60>
|
||||
|
||||
>>> data_sample = SegDataSample()
|
||||
>>> gt_sem_seg_data = dict(sem_seg=torch.rand(1, 4, 4))
|
||||
>>> gt_sem_seg = PixelData(**gt_sem_seg_data)
|
||||
>>> data_sample.gt_sem_seg = gt_sem_seg
|
||||
>>> assert 'gt_sem_seg' in data_sample
|
||||
>>> assert 'sem_seg' in data_sample.gt_sem_seg
|
||||
"""
|
||||
|
||||
@property
|
||||
def gt_sem_seg(self) -> PixelData:
|
||||
return self._gt_sem_seg
|
||||
|
||||
@gt_sem_seg.setter
|
||||
def gt_sem_seg(self, value: PixelData) -> None:
|
||||
self.set_field(value, '_gt_sem_seg', dtype=PixelData)
|
||||
|
||||
@gt_sem_seg.deleter
|
||||
def gt_sem_seg(self) -> None:
|
||||
del self._gt_sem_seg
|
||||
|
||||
@property
|
||||
def pred_sem_seg(self) -> PixelData:
|
||||
return self._pred_sem_seg
|
||||
|
||||
@pred_sem_seg.setter
|
||||
def pred_sem_seg(self, value: PixelData) -> None:
|
||||
self.set_field(value, '_pred_sem_seg', dtype=PixelData)
|
||||
|
||||
@pred_sem_seg.deleter
|
||||
def pred_sem_seg(self) -> None:
|
||||
del self._pred_sem_seg
|
||||
|
||||
@property
|
||||
def seg_logits(self) -> PixelData:
|
||||
return self._seg_logits
|
||||
|
||||
@seg_logits.setter
|
||||
def seg_logits(self, value: PixelData) -> None:
|
||||
self.set_field(value, '_seg_logits', dtype=PixelData)
|
||||
|
||||
@seg_logits.deleter
|
||||
def seg_logits(self) -> None:
|
||||
del self._seg_logits
|
||||
Reference in New Issue
Block a user