init
This commit is contained in:
366
finetune/mmseg/models/decode_heads/decode_head.py
Normal file
366
finetune/mmseg/models/decode_heads/decode_head.py
Normal file
@@ -0,0 +1,366 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.model import BaseModule
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.structures import build_pixel_sampler
|
||||
from mmseg.utils import ConfigType, SampleList
|
||||
from ..losses import accuracy
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
|
||||
"""Base class for BaseDecodeHead.
|
||||
|
||||
1. The ``init_weights`` method is used to initialize decode_head's
|
||||
model parameters. After segmentor initialization, ``init_weights``
|
||||
is triggered when ``segmentor.init_weights()`` is called externally.
|
||||
|
||||
2. The ``loss`` method is used to calculate the loss of decode_head,
|
||||
which includes two steps: (1) the decode_head model performs forward
|
||||
propagation to obtain the feature maps (2) The ``loss_by_feat`` method
|
||||
is called based on the feature maps to calculate the loss.
|
||||
|
||||
.. code:: text
|
||||
|
||||
loss(): forward() -> loss_by_feat()
|
||||
|
||||
3. The ``predict`` method is used to predict segmentation results,
|
||||
which includes two steps: (1) the decode_head model performs forward
|
||||
propagation to obtain the feature maps (2) The ``predict_by_feat`` method
|
||||
is called based on the feature maps to predict segmentation results
|
||||
including post-processing.
|
||||
|
||||
.. code:: text
|
||||
|
||||
predict(): forward() -> predict_by_feat()
|
||||
|
||||
Args:
|
||||
in_channels (int|Sequence[int]): Input channels.
|
||||
channels (int): Channels after modules, before conv_seg.
|
||||
num_classes (int): Number of classes.
|
||||
out_channels (int): Output channels of conv_seg. Default: None.
|
||||
threshold (float): Threshold for binary segmentation in the case of
|
||||
`num_classes==1`. Default: None.
|
||||
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
|
||||
conv_cfg (dict|None): Config of conv layers. Default: None.
|
||||
norm_cfg (dict|None): Config of norm layers. Default: None.
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU')
|
||||
in_index (int|Sequence[int]): Input feature index. Default: -1
|
||||
input_transform (str|None): Transformation type of input features.
|
||||
Options: 'resize_concat', 'multiple_select', None.
|
||||
'resize_concat': Multiple feature maps will be resized to the
|
||||
same size as the first one and then concatenated together.
|
||||
Usually used in FCN head of HRNet.
|
||||
'multiple_select': Multiple feature maps will be bundled into
|
||||
a list and passed into decode head.
|
||||
None: Only one selected feature map is allowed.
|
||||
Default: None.
|
||||
loss_decode (dict | Sequence[dict]): Config of decode loss.
|
||||
The `loss_name` is property of corresponding loss function which
|
||||
could be shown in training log. If you want this loss
|
||||
item to be included into the backward graph, `loss_` must be the
|
||||
prefix of the name. Defaults to 'loss_ce'.
|
||||
e.g. dict(type='CrossEntropyLoss'),
|
||||
[dict(type='CrossEntropyLoss', loss_name='loss_ce'),
|
||||
dict(type='DiceLoss', loss_name='loss_dice')]
|
||||
Default: dict(type='CrossEntropyLoss').
|
||||
ignore_index (int | None): The label index to be ignored. When using
|
||||
masked BCE loss, ignore_index should be set to None. Default: 255.
|
||||
sampler (dict|None): The config of segmentation map sampler.
|
||||
Default: None.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
"""
|
||||
|
||||
def __init__(self,
             in_channels,
             channels,
             *,
             num_classes,
             out_channels=None,
             threshold=None,
             dropout_ratio=0.1,
             conv_cfg=None,
             norm_cfg=None,
             act_cfg=dict(type='ReLU'),
             in_index=-1,
             input_transform=None,
             loss_decode=dict(
                 type='CrossEntropyLoss',
                 use_sigmoid=False,
                 loss_weight=1.0),
             ignore_index=255,
             sampler=None,
             align_corners=False,
             init_cfg=dict(
                 type='Normal', std=0.01, override=dict(name='conv_seg'))):
    """Validate the head configuration and build its losses and layers.

    Args:
        in_channels (int|Sequence[int]): Input channels.
        channels (int): Input channels of ``conv_seg``.
        num_classes (int): Number of classes (keyword-only).
        out_channels (int|None): Output channels of ``conv_seg``. When
            None it falls back to ``num_classes``.
        threshold (float|None): Stored as ``self.threshold``. When
            ``out_channels == 1`` and this is None, 0.3 is used.
        dropout_ratio (float): A ``nn.Dropout2d`` layer is created only
            when this is greater than 0.
        conv_cfg, norm_cfg, act_cfg: Stored on the instance unchanged.
        in_index (int|Sequence[int]): Input feature index.
        input_transform (str|None): 'resize_concat', 'multiple_select'
            or None.
        loss_decode (dict|Sequence[dict]): Loss config(s) passed to
            ``MODELS.build``.
        ignore_index (int|None): Stored as ``self.ignore_index``.
        sampler (dict|None): Pixel sampler config; no sampler is built
            when None.
        align_corners (bool): Stored as ``self.align_corners``.
        init_cfg (dict|list[dict]): Forwarded to the parent initializer.

    Raises:
        ValueError: If ``out_channels`` is neither ``num_classes`` nor 1.
        TypeError: If ``loss_decode`` is not a dict, list or tuple.
    """
    # NOTE(review): the dict defaults above are shared between calls;
    # this presumes nothing downstream mutates them -- confirm.
    super().__init__(init_cfg)
    self._init_inputs(in_channels, in_index, input_transform)
    self.channels = channels
    self.dropout_ratio = dropout_ratio
    self.conv_cfg = conv_cfg
    self.norm_cfg = norm_cfg
    self.act_cfg = act_cfg
    self.in_index = in_index

    self.ignore_index = ignore_index
    self.align_corners = align_corners

    if out_channels is None:
        if num_classes == 2:
            # Each fragment ends with a space so the implicitly
            # concatenated message reads as separate words.
            warnings.warn('For binary segmentation, we suggest using '
                          '`out_channels = 1` to define the output '
                          'channels of segmentor, and use `threshold` '
                          'to convert `seg_logits` into a prediction '
                          'applying a threshold')
        out_channels = num_classes

    if out_channels != num_classes and out_channels != 1:
        raise ValueError(
            'out_channels should be equal to num_classes, '
            'except binary segmentation set out_channels == 1 and '
            f'num_classes == 2, but got out_channels={out_channels} '
            f'and num_classes={num_classes}')

    if out_channels == 1 and threshold is None:
        threshold = 0.3
        warnings.warn('threshold is not defined for binary, and defaults '
                      'to 0.3')
    self.num_classes = num_classes
    self.out_channels = out_channels
    self.threshold = threshold

    if isinstance(loss_decode, dict):
        self.loss_decode = MODELS.build(loss_decode)
    elif isinstance(loss_decode, (list, tuple)):
        self.loss_decode = nn.ModuleList()
        for loss in loss_decode:
            self.loss_decode.append(MODELS.build(loss))
    else:
        # Two adjacent literals instead of a backslash continuation, which
        # would embed the source indentation in the message.
        raise TypeError('loss_decode must be a dict or sequence of dict, '
                        f'but got {type(loss_decode)}')

    if sampler is not None:
        self.sampler = build_pixel_sampler(sampler, context=self)
    else:
        self.sampler = None

    # 1x1 convolution mapping ``channels`` features to ``out_channels``.
    self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
    if dropout_ratio > 0:
        self.dropout = nn.Dropout2d(dropout_ratio)
    else:
        self.dropout = None
|
||||
|
||||
def extra_repr(self):
    """Summarise the head's key settings as one comma-separated string."""
    fields = (
        f'input_transform={self.input_transform}',
        f'ignore_index={self.ignore_index}',
        f'align_corners={self.align_corners}',
    )
    return ', '.join(fields)
|
||||
|
||||
def _init_inputs(self, in_channels, in_index, input_transform):
    """Check the input settings and store them on the instance.

    The in_channels, in_index and input_transform must match.
    When input_transform is None, a single feature map is selected, so
    in_channels and in_index must both be of type int.
    When input_transform is not None, in_channels and in_index must both
    be lists or tuples of the same length.

    Args:
        in_channels (int|Sequence[int]): Input channels.
        in_index (int|Sequence[int]): Input feature index.
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            'resize_concat': ``self.in_channels`` becomes the sum of
                ``in_channels``.
            'multiple_select': ``self.in_channels`` keeps the sequence
                as given.
            None: ``self.in_channels`` is the single int given.
    """
    multi_input = input_transform is not None
    if multi_input:
        assert input_transform in ['resize_concat', 'multiple_select']

    # Stored before the type checks below, so both attributes are already
    # set if one of those checks fails.
    self.input_transform = input_transform
    self.in_index = in_index

    if not multi_input:
        assert isinstance(in_channels, int)
        assert isinstance(in_index, int)
        self.in_channels = in_channels
        return

    assert isinstance(in_channels, (list, tuple))
    assert isinstance(in_index, (list, tuple))
    assert len(in_channels) == len(in_index)
    self.in_channels = (
        sum(in_channels)
        if input_transform == 'resize_concat' else in_channels)
|
||||
|
||||
def _transform_inputs(self, inputs):
    """Pick, and optionally merge, the feature maps this head consumes.

    Args:
        inputs (list[Tensor]): List of multi-level img features.

    Returns:
        Tensor | list[Tensor]: For 'resize_concat', the selected maps
        resized to the spatial size of the first selected map and
        concatenated along dim 1. For 'multiple_select', the list of
        selected maps. Otherwise the single map at ``self.in_index``.
    """
    mode = self.input_transform
    if mode not in ('resize_concat', 'multiple_select'):
        return inputs[self.in_index]

    selected = [inputs[idx] for idx in self.in_index]
    if mode == 'multiple_select':
        return selected

    resized = [
        resize(
            input=feat,
            size=selected[0].shape[2:],
            mode='bilinear',
            align_corners=self.align_corners) for feat in selected
    ]
    return torch.cat(resized, dim=1)
|
||||
|
||||
@abstractmethod
def forward(self, inputs):
    """Abstract forward pass; concrete decode heads must override it."""
|
||||
|
||||
def cls_seg(self, feat):
    """Apply dropout (when configured) and then ``conv_seg`` to ``feat``."""
    dropout = self.dropout
    if dropout is not None:
        feat = dropout(feat)
    return self.conv_seg(feat)
|
||||
|
||||
def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList,
         train_cfg: ConfigType) -> dict:
    """Run the forward pass and compute the training losses.

    Args:
        inputs (Tuple[Tensor]): List of multi-level img features.
        batch_data_samples (list[:obj:`SegDataSample`]): The seg
            data samples, handed to ``loss_by_feat`` unchanged.
        train_cfg (dict): The training config. Not read by this
            implementation.

    Returns:
        dict[str, Tensor]: a dictionary of loss components
    """
    logits = self.forward(inputs)
    return self.loss_by_feat(logits, batch_data_samples)
|
||||
|
||||
def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
            test_cfg: ConfigType) -> Tensor:
    """Run the forward pass and post-process the logits for inference.

    Args:
        inputs (Tuple[Tensor]): List of multi-level img features.
        batch_img_metas (list[dict]): Per-image meta info, handed to
            ``predict_by_feat`` unchanged. Each dict may contain:
            'img_shape', 'scale_factor', 'flip', 'img_path',
            'ori_shape', and 'pad_shape'.
            For details on the values of these keys see
            `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
        test_cfg (dict): The testing config. Not read by this
            implementation.

    Returns:
        Tensor: Outputs segmentation logits map.
    """
    logits = self.forward(inputs)
    return self.predict_by_feat(logits, batch_img_metas)
|
||||
|
||||
def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor:
    """Stack every sample's ``gt_sem_seg.data`` along a new dim 0."""
    return torch.stack(
        [sample.gt_sem_seg.data for sample in batch_data_samples], dim=0)
|
||||
|
||||
def loss_by_feat(self, seg_logits: Tensor,
                 batch_data_samples: SampleList) -> dict:
    """Compute the decode losses and accuracy for a batch of logits.

    Args:
        seg_logits (Tensor): The output from decode head forward function.
        batch_data_samples (List[:obj:`SegDataSample`]): The seg
            data samples; their ``gt_sem_seg.data`` is stacked into the
            label tensor.

    Returns:
        dict[str, Tensor]: One entry per distinct ``loss_name`` (losses
        sharing a name are summed) plus ``'acc_seg'``.
    """
    seg_label = self._stack_batch_gt(batch_data_samples)
    # Logits are resized to the label's trailing dims before any loss.
    # assumes stacked labels are (N, 1, H, W) -- TODO confirm
    seg_logits = resize(
        input=seg_logits,
        size=seg_label.shape[2:],
        mode='bilinear',
        align_corners=self.align_corners)

    # The sampler receives the label before dim 1 is squeezed out.
    seg_weight = (
        self.sampler.sample(seg_logits, seg_label)
        if self.sampler is not None else None)
    seg_label = seg_label.squeeze(1)

    if isinstance(self.loss_decode, nn.ModuleList):
        loss_modules = self.loss_decode
    else:
        loss_modules = [self.loss_decode]

    losses = dict()
    for loss_module in loss_modules:
        value = loss_module(
            seg_logits,
            seg_label,
            weight=seg_weight,
            ignore_index=self.ignore_index)
        name = loss_module.loss_name
        if name in losses:
            losses[name] += value
        else:
            losses[name] = value

    losses['acc_seg'] = accuracy(
        seg_logits, seg_label, ignore_index=self.ignore_index)
    return losses
|
||||
|
||||
def predict_by_feat(self, seg_logits: Tensor,
                    batch_img_metas: List[dict]) -> Tensor:
    """Resize a batch of seg_logits to the size given by the image metas.

    Only the first entry of ``batch_img_metas`` is consulted.

    Args:
        seg_logits (Tensor): The output from decode head forward function.
        batch_img_metas (list[dict]): Meta information of each image, e.g.,
            image size, scaling factor, etc.

    Returns:
        Tensor: Outputs segmentation logits map.
    """
    first_meta = batch_img_metas[0]
    # A ``torch.Size`` img_shape marks slide inference and takes priority
    # over ``pad_shape``.
    is_slide = isinstance(first_meta['img_shape'], torch.Size)
    if not is_slide and 'pad_shape' in first_meta:
        size = first_meta['pad_shape'][:2]
    else:
        size = first_meta['img_shape']

    return resize(
        input=seg_logits,
        size=size,
        mode='bilinear',
        align_corners=self.align_corners)
|
||||
Reference in New Issue
Block a user