init
This commit is contained in:
200
finetune/mmseg/models/segmentors/base.py
Normal file
200
finetune/mmseg/models/segmentors/base.py
Normal file
@@ -0,0 +1,200 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import List, Tuple
|
||||
|
||||
from mmengine.model import BaseModel
|
||||
from mmengine.structures import PixelData
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.structures import SegDataSample
|
||||
from mmseg.utils import (ForwardResults, OptConfigType, OptMultiConfig,
|
||||
OptSampleList, SampleList)
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
class BaseSegmentor(BaseModel, metaclass=ABCMeta):
|
||||
"""Base class for segmentors.
|
||||
|
||||
Args:
|
||||
data_preprocessor (dict, optional): Model preprocessing config
|
||||
for processing the input data. it usually includes
|
||||
``to_rgb``, ``pad_size_divisor``, ``pad_val``,
|
||||
``mean`` and ``std``. Default to None.
|
||||
init_cfg (dict, optional): the config to control the
|
||||
initialization. Default to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
data_preprocessor: OptConfigType = None,
|
||||
init_cfg: OptMultiConfig = None):
|
||||
super().__init__(
|
||||
data_preprocessor=data_preprocessor, init_cfg=init_cfg)
|
||||
|
||||
@property
|
||||
def with_neck(self) -> bool:
|
||||
"""bool: whether the segmentor has neck"""
|
||||
return hasattr(self, 'neck') and self.neck is not None
|
||||
|
||||
@property
|
||||
def with_auxiliary_head(self) -> bool:
|
||||
"""bool: whether the segmentor has auxiliary head"""
|
||||
return hasattr(self,
|
||||
'auxiliary_head') and self.auxiliary_head is not None
|
||||
|
||||
@property
|
||||
def with_decode_head(self) -> bool:
|
||||
"""bool: whether the segmentor has decode head"""
|
||||
return hasattr(self, 'decode_head') and self.decode_head is not None
|
||||
|
||||
@abstractmethod
|
||||
def extract_feat(self, inputs: Tensor) -> bool:
|
||||
"""Placeholder for extract features from images."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def encode_decode(self, inputs: Tensor, batch_data_samples: SampleList):
|
||||
"""Placeholder for encode images with backbone and decode into a
|
||||
semantic segmentation map of the same size as input."""
|
||||
pass
|
||||
|
||||
def forward(self,
|
||||
inputs: Tensor,
|
||||
data_samples: OptSampleList = None,
|
||||
mode: str = 'tensor') -> ForwardResults:
|
||||
"""The unified entry for a forward process in both training and test.
|
||||
|
||||
The method should accept three modes: "tensor", "predict" and "loss":
|
||||
|
||||
- "tensor": Forward the whole network and return tensor or tuple of
|
||||
tensor without any post-processing, same as a common nn.Module.
|
||||
- "predict": Forward and return the predictions, which are fully
|
||||
processed to a list of :obj:`SegDataSample`.
|
||||
- "loss": Forward and return a dict of losses according to the given
|
||||
inputs and data samples.
|
||||
|
||||
Note that this method doesn't handle neither back propagation nor
|
||||
optimizer updating, which are done in the :meth:`train_step`.
|
||||
|
||||
Args:
|
||||
inputs (torch.Tensor): The input tensor with shape (N, C, ...) in
|
||||
general.
|
||||
data_samples (list[:obj:`SegDataSample`]): The seg data samples.
|
||||
It usually includes information such as `metainfo` and
|
||||
`gt_sem_seg`. Default to None.
|
||||
mode (str): Return what kind of value. Defaults to 'tensor'.
|
||||
|
||||
Returns:
|
||||
The return type depends on ``mode``.
|
||||
|
||||
- If ``mode="tensor"``, return a tensor or a tuple of tensor.
|
||||
- If ``mode="predict"``, return a list of :obj:`DetDataSample`.
|
||||
- If ``mode="loss"``, return a dict of tensor.
|
||||
"""
|
||||
if mode == 'loss':
|
||||
return self.loss(inputs, data_samples)
|
||||
elif mode == 'predict':
|
||||
return self.predict(inputs, data_samples)
|
||||
elif mode == 'tensor':
|
||||
return self._forward(inputs, data_samples)
|
||||
else:
|
||||
raise RuntimeError(f'Invalid mode "{mode}". '
|
||||
'Only supports loss, predict and tensor mode')
|
||||
|
||||
@abstractmethod
|
||||
def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
|
||||
"""Calculate losses from a batch of inputs and data samples."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def predict(self,
|
||||
inputs: Tensor,
|
||||
data_samples: OptSampleList = None) -> SampleList:
|
||||
"""Predict results from a batch of inputs and data samples with post-
|
||||
processing."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _forward(self,
|
||||
inputs: Tensor,
|
||||
data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
|
||||
"""Network forward process.
|
||||
|
||||
Usually includes backbone, neck and head forward without any post-
|
||||
processing.
|
||||
"""
|
||||
pass
|
||||
|
||||
def postprocess_result(self,
|
||||
seg_logits: Tensor,
|
||||
data_samples: OptSampleList = None) -> SampleList:
|
||||
""" Convert results list to `SegDataSample`.
|
||||
Args:
|
||||
seg_logits (Tensor): The segmentation results, seg_logits from
|
||||
model of each input image.
|
||||
data_samples (list[:obj:`SegDataSample`]): The seg data samples.
|
||||
It usually includes information such as `metainfo` and
|
||||
`gt_sem_seg`. Default to None.
|
||||
Returns:
|
||||
list[:obj:`SegDataSample`]: Segmentation results of the
|
||||
input images. Each SegDataSample usually contain:
|
||||
|
||||
- ``pred_sem_seg``(PixelData): Prediction of semantic segmentation.
|
||||
- ``seg_logits``(PixelData): Predicted logits of semantic
|
||||
segmentation before normalization.
|
||||
"""
|
||||
batch_size, C, H, W = seg_logits.shape
|
||||
|
||||
if data_samples is None:
|
||||
data_samples = [SegDataSample() for _ in range(batch_size)]
|
||||
only_prediction = True
|
||||
else:
|
||||
only_prediction = False
|
||||
|
||||
for i in range(batch_size):
|
||||
if not only_prediction:
|
||||
img_meta = data_samples[i].metainfo
|
||||
# remove padding area
|
||||
if 'img_padding_size' not in img_meta:
|
||||
padding_size = img_meta.get('padding_size', [0] * 4)
|
||||
else:
|
||||
padding_size = img_meta['img_padding_size']
|
||||
padding_left, padding_right, padding_top, padding_bottom =\
|
||||
padding_size
|
||||
# i_seg_logits shape is 1, C, H, W after remove padding
|
||||
i_seg_logits = seg_logits[i:i + 1, :,
|
||||
padding_top:H - padding_bottom,
|
||||
padding_left:W - padding_right]
|
||||
|
||||
flip = img_meta.get('flip', None)
|
||||
if flip:
|
||||
flip_direction = img_meta.get('flip_direction', None)
|
||||
assert flip_direction in ['horizontal', 'vertical']
|
||||
if flip_direction == 'horizontal':
|
||||
i_seg_logits = i_seg_logits.flip(dims=(3, ))
|
||||
else:
|
||||
i_seg_logits = i_seg_logits.flip(dims=(2, ))
|
||||
|
||||
# resize as original shape
|
||||
i_seg_logits = resize(
|
||||
i_seg_logits,
|
||||
size=img_meta['ori_shape'],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners,
|
||||
warning=False).squeeze(0)
|
||||
else:
|
||||
i_seg_logits = seg_logits[i]
|
||||
|
||||
if C > 1:
|
||||
i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True)
|
||||
else:
|
||||
i_seg_logits = i_seg_logits.sigmoid()
|
||||
i_seg_pred = (i_seg_logits >
|
||||
self.decode_head.threshold).to(i_seg_logits)
|
||||
data_samples[i].set_data({
|
||||
'seg_logits':
|
||||
PixelData(**{'data': i_seg_logits}),
|
||||
'pred_sem_seg':
|
||||
PixelData(**{'data': i_seg_pred})
|
||||
})
|
||||
|
||||
return data_samples
|
||||
Reference in New Issue
Block a user