This commit is contained in:
esenke
2025-12-08 22:16:31 +08:00
commit 01adcfdf60
305 changed files with 50879 additions and 0 deletions

View File

@@ -0,0 +1,15 @@
from .uper_head import UPerHead
from .up_head import UPHead
__all__ = [
'UPerHead', 'UPHead'
]
type_mapping = {
'UPerHead': UPerHead,
'UPHead': UPHead
}
def build_head(type, **kwargs):
return type_mapping[type](**kwargs)

View File

@@ -0,0 +1,201 @@
import warnings
from abc import ABCMeta, abstractmethod
import torch
import torch.nn as nn
from mmcv.runner import BaseModule, auto_fp16
from mmseg.core import build_pixel_sampler
from mmseg.ops import resize
class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
"""Base class for BaseDecodeHead.
Args:
in_channels (int|Sequence[int]): Input channels.
channels (int): Channels after modules, before conv_seg.
num_classes (int): Number of classes.
out_channels (int): Output channels of conv_seg.
threshold (float): Threshold for binary segmentation in the case of
`out_channels==1`. Default: None.
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
conv_cfg (dict|None): Config of conv layers. Default: None.
norm_cfg (dict|None): Config of norm layers. Default: None.
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU')
in_index (int|Sequence[int]): Input feature index. Default: -1
input_transform (str|None): Transformation type of input features.
Options: 'resize_concat', 'multiple_select', None.
'resize_concat': Multiple feature maps will be resize to the
same size as first one and than concat together.
Usually used in FCN head of HRNet.
'multiple_select': Multiple feature maps will be bundle into
a list and passed into decode head.
None: Only one select feature map is allowed.
Default: None.
loss_decode (dict | Sequence[dict]): Config of decode loss.
The `loss_name` is property of corresponding loss function which
could be shown in training log. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_ce'.
e.g. dict(type='CrossEntropyLoss'),
[dict(type='CrossEntropyLoss', loss_name='loss_ce'),
dict(type='DiceLoss', loss_name='loss_dice')]
Default: dict(type='CrossEntropyLoss').
ignore_index (int | None): The label index to be ignored. When using
masked BCE loss, ignore_index should be set to None. Default: 255.
sampler (dict|None): The config of segmentation map sampler.
Default: None.
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
init_cfg (dict or list[dict], optional): Initialization config dict.
"""
def __init__(self,
in_channels,
channels,
*,
num_classes,
out_channels=None,
threshold=None,
dropout_ratio=0.1,
conv_cfg=None,
norm_cfg=None,
act_cfg=dict(type='ReLU'),
in_index=-1,
input_transform=None,
sampler=None,
align_corners=False,
init_cfg=dict(type='Normal',
std=0.01,
override=dict(name='conv_seg'))):
super(BaseDecodeHead, self).__init__(init_cfg)
self._init_inputs(in_channels, in_index, input_transform)
self.channels = channels
self.dropout_ratio = dropout_ratio
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.in_index = in_index
self.align_corners = align_corners
if out_channels is None:
if num_classes == 2:
warnings.warn('For binary segmentation, we suggest using'
'`out_channels = 1` to define the output'
'channels of segmentor, and use `threshold`'
'to convert seg_logist into a prediction'
'applying a threshold')
out_channels = num_classes
if out_channels != num_classes and out_channels != 1:
raise ValueError(
'out_channels should be equal to num_classes,'
'except binary segmentation set out_channels == 1 and'
f'num_classes == 2, but got out_channels={out_channels}'
f'and num_classes={num_classes}')
if out_channels == 1 and threshold is None:
threshold = 0.3
warnings.warn('threshold is not defined for binary, and defaults'
'to 0.3')
self.num_classes = num_classes
self.out_channels = out_channels
self.threshold = threshold
if sampler is not None:
self.sampler = build_pixel_sampler(sampler, context=self)
else:
self.sampler = None
self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
if dropout_ratio > 0:
self.dropout = nn.Dropout2d(dropout_ratio)
else:
self.dropout = None
self.fp16_enabled = False
def extra_repr(self):
"""Extra repr."""
s = f'input_transform={self.input_transform}, ' \
f'align_corners={self.align_corners}'
return s
def _init_inputs(self, in_channels, in_index, input_transform):
"""Check and initialize input transforms.
The in_channels, in_index and input_transform must match.
Specifically, when input_transform is None, only single feature map
will be selected. So in_channels and in_index must be of type int.
When input_transform
Args:
in_channels (int|Sequence[int]): Input channels.
in_index (int|Sequence[int]): Input feature index.
input_transform (str|None): Transformation type of input features.
Options: 'resize_concat', 'multiple_select', None.
'resize_concat': Multiple feature maps will be resize to the
same size as first one and than concat together.
Usually used in FCN head of HRNet.
'multiple_select': Multiple feature maps will be bundle into
a list and passed into decode head.
None: Only one select feature map is allowed.
"""
if input_transform is not None:
assert input_transform in ['resize_concat', 'multiple_select']
self.input_transform = input_transform
self.in_index = in_index
if input_transform is not None:
assert isinstance(in_channels, (list, tuple))
assert isinstance(in_index, (list, tuple))
assert len(in_channels) == len(in_index)
if input_transform == 'resize_concat':
self.in_channels = sum(in_channels)
else:
self.in_channels = in_channels
else:
assert isinstance(in_channels, int)
assert isinstance(in_index, int)
self.in_channels = in_channels
def _transform_inputs(self, inputs):
"""Transform inputs for decoder.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
Tensor: The transformed inputs
"""
if self.input_transform == 'resize_concat':
inputs = [inputs[i] for i in self.in_index]
upsampled_inputs = [
resize(input=x,
size=inputs[0].shape[2:],
mode='bilinear',
align_corners=self.align_corners) for x in inputs
]
inputs = torch.cat(upsampled_inputs, dim=1)
elif self.input_transform == 'multiple_select':
inputs = [inputs[i] for i in self.in_index]
else:
inputs = inputs[self.in_index]
return inputs
@auto_fp16()
@abstractmethod
def forward(self, inputs):
"""Placeholder of forward function."""
pass
def cls_seg(self, feat):
"""Classify each pixel."""
if self.dropout is not None:
feat = self.dropout(feat)
output = self.conv_seg(feat)
return output

View File

@@ -0,0 +1,60 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
class PPM(nn.ModuleList):
"""Pooling Pyramid Module used in PSPNet.
Args:
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
Module.
in_channels (int): Input channels.
channels (int): Channels after modules, before conv_seg.
conv_cfg (dict|None): Config of conv layers.
norm_cfg (dict|None): Config of norm layers.
act_cfg (dict): Config of activation layers.
align_corners (bool): align_corners argument of F.interpolate.
"""
def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
act_cfg, align_corners, **kwargs):
super(PPM, self).__init__()
self.pool_scales = pool_scales
self.align_corners = align_corners
self.in_channels = in_channels
self.channels = channels
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
for pool_scale in pool_scales:
self.append(
nn.Sequential(
nn.AdaptiveAvgPool2d(pool_scale),
ConvModule(
self.in_channels,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
**kwargs)))
def forward(self, x):
"""Forward function."""
ppm_outs = []
for ppm in self:
ppm_out = ppm(x)
ppm_out = ppm_out.to(torch.float32)
upsampled_ppm_out = resize(
ppm_out,
size=x.size()[2:],
mode='bilinear',
align_corners=self.align_corners)
upsampled_ppm_out = upsampled_ppm_out.to(torch.bfloat16)
ppm_outs.append(upsampled_ppm_out)
return ppm_outs

View File

@@ -0,0 +1,52 @@
import torch.nn as nn
from collections import OrderedDict
from mmcv.cnn.utils.weight_init import (kaiming_init, trunc_normal_)
from mmcv.runner import (CheckpointLoader, load_state_dict)
from mmseg.utils import get_root_logger
class UPHead(nn.Module):
def __init__(self, in_dim, out_dim, up_scale, init_cfg=None):
super().__init__()
self.decoder = nn.Sequential(
nn.Conv2d(in_channels=in_dim,
out_channels=up_scale**2 * out_dim,
kernel_size=1),
nn.PixelShuffle(up_scale),
)
self.init_cfg = init_cfg
self.apply(self._init_weights)
def _init_weights(self, m):
if (isinstance(self.init_cfg, dict)
and self.init_cfg.get('type') == 'Pretrained'):
logger = get_root_logger()
checkpoint = CheckpointLoader.load_checkpoint(
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
if 'state_dict' in checkpoint:
_state_dict = checkpoint['state_dict']
else:
_state_dict = checkpoint
state_dict = OrderedDict()
for k, v in _state_dict.items():
if k.startswith('backbone.'):
state_dict[k[9:]] = v
else:
state_dict[k] = v
print(f'loading weight: {self.init_cfg["checkpoint"]}')
load_state_dict(self, state_dict, strict=False, logger=logger)
else:
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv2d):
kaiming_init(m, mode='fan_in', bias=0.)
def forward(self, x):
x = self.decoder(x)
return x

View File

@@ -0,0 +1,130 @@
# coding: utf-8
# Copyright (c) Ant Group. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from .psp_head import PPM
class UPerHead(BaseDecodeHead):
"""Unified Perceptual Parsing for Scene Understanding.
This head is the implementation of `UPerNet
<https://arxiv.org/abs/1807.10221>`_.
Args:
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
Module applied on the last feature. Default: (1, 2, 3, 6).
"""
def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
super(UPerHead, self).__init__(
input_transform='multiple_select', **kwargs)
# PSP Module
self.psp_modules = PPM(
pool_scales,
self.in_channels[-1],
self.channels,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
align_corners=self.align_corners)
self.bottleneck = ConvModule(
self.in_channels[-1] + len(pool_scales) * self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
# FPN Module
self.lateral_convs = nn.ModuleList()
self.fpn_convs = nn.ModuleList()
for in_channels in self.in_channels[:-1]: # skip the top layer
l_conv = ConvModule(
in_channels,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
inplace=False)
fpn_conv = ConvModule(
self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
inplace=False)
self.lateral_convs.append(l_conv)
self.fpn_convs.append(fpn_conv)
self.fpn_bottleneck = ConvModule(
len(self.in_channels) * self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def psp_forward(self, inputs):
"""Forward function of PSP module."""
x = inputs[-1]
psp_outs = [x]
# breakpoint()
psp_outs.extend(self.psp_modules(x))
psp_outs = torch.cat(psp_outs, dim=1)
output = self.bottleneck(psp_outs)
return output
def forward(self, inputs):
"""Forward function."""
# breakpoint()
inputs = self._transform_inputs(inputs)
# build laterals
laterals = [
lateral_conv(inputs[i])
for i, lateral_conv in enumerate(self.lateral_convs)
]
laterals.append(self.psp_forward(inputs))
# build top-down path
used_backbone_levels = len(laterals)
for i in range(used_backbone_levels - 1, 0, -1):
prev_shape = laterals[i - 1].shape[2:]
laterals[i] = laterals[i].type(torch.float32)
laterals[i - 1] = laterals[i - 1] + resize(
laterals[i],
size=prev_shape,
mode='bilinear',
align_corners=self.align_corners)
# build outputs
fpn_outs = [
self.fpn_convs[i](laterals[i])
for i in range(used_backbone_levels - 1)
]
# append psp feature
fpn_outs.append(laterals[-1])
for i in range(used_backbone_levels - 1, 0, -1):
fpn_outs[i] = fpn_outs[i].type(torch.float32)
fpn_outs[i] = resize(
fpn_outs[i],
size=fpn_outs[0].shape[2:],
mode='bilinear',
align_corners=self.align_corners)
fpn_outs = torch.cat(fpn_outs, dim=1)
output = self.fpn_bottleneck(fpn_outs)
output = self.cls_seg(output)
return output