esenke
2025-12-08 22:16:31 +08:00
commit 01adcfdf60
305 changed files with 50879 additions and 0 deletions

View File

@@ -0,0 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ann_head import ANNHead
from .apc_head import APCHead
from .aspp_head import ASPPHead
from .cc_head import CCHead
from .da_head import DAHead
from .ddr_head import DDRHead
from .dm_head import DMHead
from .dnl_head import DNLHead
from .dpt_head import DPTHead
from .ema_head import EMAHead
from .enc_head import EncHead
from .fcn_head import FCNHead
from .fpn_head import FPNHead
from .gc_head import GCHead
from .ham_head import LightHamHead
from .isa_head import ISAHead
from .knet_head import IterativeDecodeHead, KernelUpdateHead, KernelUpdator
from .lraspp_head import LRASPPHead
from .mask2former_head import Mask2FormerHead
from .maskformer_head import MaskFormerHead
from .nl_head import NLHead
from .ocr_head import OCRHead
from .pid_head import PIDHead
from .point_head import PointHead
from .psa_head import PSAHead
from .psp_head import PSPHead
from .san_head import SideAdapterCLIPHead
from .segformer_head import SegformerHead
from .segmenter_mask_head import SegmenterMaskTransformerHead
from .sep_aspp_head import DepthwiseSeparableASPPHead
from .sep_fcn_head import DepthwiseSeparableFCNHead
from .setr_mla_head import SETRMLAHead
from .setr_up_head import SETRUPHead
from .stdc_head import STDCHead
from .uper_head import UPerHead
from .vpd_depth_head import VPDDepthHead
__all__ = [
'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
    'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
    'SETRMLAHead', 'DPTHead', 'SegmenterMaskTransformerHead',
'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead',
'KernelUpdateHead', 'KernelUpdator', 'MaskFormerHead', 'Mask2FormerHead',
'LightHamHead', 'PIDHead', 'DDRHead', 'VPDDepthHead', 'SideAdapterCLIPHead'
]
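
All of the heads above are registered in the MODELS registry, so configs can build them by name. A minimal sketch of that pattern (the channel sizes and class count below are illustrative, not taken from this commit):

import torch
import mmseg.models  # noqa: F401 -- importing registers the heads
from mmseg.registry import MODELS

head_cfg = dict(
    type='FCNHead',    # any head name exported above works here
    in_channels=2048,  # illustrative backbone output channels
    channels=512,
    num_classes=19)
head = MODELS.build(head_cfg)
seg_logits = head([torch.rand(1, 2048, 32, 32)])  # -> (1, 19, 32, 32)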

View File

@@ -0,0 +1,245 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.registry import MODELS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .decode_head import BaseDecodeHead
class PPMConcat(nn.ModuleList):
"""Pyramid Pooling Module that only concat the features of each layer.
Args:
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
Module.
"""
def __init__(self, pool_scales=(1, 3, 6, 8)):
super().__init__(
[nn.AdaptiveAvgPool2d(pool_scale) for pool_scale in pool_scales])
def forward(self, feats):
"""Forward function."""
ppm_outs = []
for ppm in self:
ppm_out = ppm(feats)
ppm_outs.append(ppm_out.view(*feats.shape[:2], -1))
concat_outs = torch.cat(ppm_outs, dim=2)
return concat_outs
class SelfAttentionBlock(_SelfAttentionBlock):
"""Make a ANN used SelfAttentionBlock.
Args:
low_in_channels (int): Input channels of lower level feature,
which is the key feature for self-attention.
high_in_channels (int): Input channels of higher level feature,
which is the query feature for self-attention.
channels (int): Output channels of key/query transform.
out_channels (int): Output channels.
        share_key_query (bool): Whether to share projection weights between
            the key and query projections.
query_scale (int): The scale of query feature map.
key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
Module of key feature.
conv_cfg (dict|None): Config of conv layers.
norm_cfg (dict|None): Config of norm layers.
act_cfg (dict|None): Config of activation layers.
"""
def __init__(self, low_in_channels, high_in_channels, channels,
out_channels, share_key_query, query_scale, key_pool_scales,
conv_cfg, norm_cfg, act_cfg):
key_psp = PPMConcat(key_pool_scales)
if query_scale > 1:
query_downsample = nn.MaxPool2d(kernel_size=query_scale)
else:
query_downsample = None
super().__init__(
key_in_channels=low_in_channels,
query_in_channels=high_in_channels,
channels=channels,
out_channels=out_channels,
share_key_query=share_key_query,
query_downsample=query_downsample,
key_downsample=key_psp,
key_query_num_convs=1,
key_query_norm=True,
value_out_num_convs=1,
value_out_norm=False,
matmul_norm=True,
with_out=True,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
class AFNB(nn.Module):
"""Asymmetric Fusion Non-local Block(AFNB)
Args:
low_in_channels (int): Input channels of lower level feature,
which is the key feature for self-attention.
high_in_channels (int): Input channels of higher level feature,
which is the query feature for self-attention.
channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
query_scales (tuple[int]): The scales of query feature map.
Default: (1,)
key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
Module of key feature.
conv_cfg (dict|None): Config of conv layers.
norm_cfg (dict|None): Config of norm layers.
act_cfg (dict|None): Config of activation layers.
"""
def __init__(self, low_in_channels, high_in_channels, channels,
out_channels, query_scales, key_pool_scales, conv_cfg,
norm_cfg, act_cfg):
super().__init__()
self.stages = nn.ModuleList()
for query_scale in query_scales:
self.stages.append(
SelfAttentionBlock(
low_in_channels=low_in_channels,
high_in_channels=high_in_channels,
channels=channels,
out_channels=out_channels,
share_key_query=False,
query_scale=query_scale,
key_pool_scales=key_pool_scales,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
self.bottleneck = ConvModule(
out_channels + high_in_channels,
out_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None)
def forward(self, low_feats, high_feats):
"""Forward function."""
priors = [stage(high_feats, low_feats) for stage in self.stages]
context = torch.stack(priors, dim=0).sum(dim=0)
output = self.bottleneck(torch.cat([context, high_feats], 1))
return output
class APNB(nn.Module):
"""Asymmetric Pyramid Non-local Block (APNB)
Args:
in_channels (int): Input channels of key/query feature,
which is the key feature for self-attention.
channels (int): Output channels of key/query transform.
out_channels (int): Output channels.
query_scales (tuple[int]): The scales of query feature map.
Default: (1,)
key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
Module of key feature.
conv_cfg (dict|None): Config of conv layers.
norm_cfg (dict|None): Config of norm layers.
act_cfg (dict|None): Config of activation layers.
"""
def __init__(self, in_channels, channels, out_channels, query_scales,
key_pool_scales, conv_cfg, norm_cfg, act_cfg):
super().__init__()
self.stages = nn.ModuleList()
for query_scale in query_scales:
self.stages.append(
SelfAttentionBlock(
low_in_channels=in_channels,
high_in_channels=in_channels,
channels=channels,
out_channels=out_channels,
share_key_query=True,
query_scale=query_scale,
key_pool_scales=key_pool_scales,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
self.bottleneck = ConvModule(
2 * in_channels,
out_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
def forward(self, feats):
"""Forward function."""
priors = [stage(feats, feats) for stage in self.stages]
context = torch.stack(priors, dim=0).sum(dim=0)
output = self.bottleneck(torch.cat([context, feats], 1))
return output
@MODELS.register_module()
class ANNHead(BaseDecodeHead):
"""Asymmetric Non-local Neural Networks for Semantic Segmentation.
This head is the implementation of `ANNNet
<https://arxiv.org/abs/1908.07678>`_.
Args:
project_channels (int): Projection channels for Nonlocal.
query_scales (tuple[int]): The scales of query feature map.
Default: (1,)
key_pool_scales (tuple[int]): The pooling scales of key feature map.
Default: (1, 3, 6, 8).
"""
def __init__(self,
project_channels,
query_scales=(1, ),
key_pool_scales=(1, 3, 6, 8),
**kwargs):
super().__init__(input_transform='multiple_select', **kwargs)
assert len(self.in_channels) == 2
low_in_channels, high_in_channels = self.in_channels
self.project_channels = project_channels
self.fusion = AFNB(
low_in_channels=low_in_channels,
high_in_channels=high_in_channels,
out_channels=high_in_channels,
channels=project_channels,
query_scales=query_scales,
key_pool_scales=key_pool_scales,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.bottleneck = ConvModule(
high_in_channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.context = APNB(
in_channels=self.channels,
out_channels=self.channels,
channels=project_channels,
query_scales=query_scales,
key_pool_scales=key_pool_scales,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def forward(self, inputs):
"""Forward function."""
low_feats, high_feats = self._transform_inputs(inputs)
output = self.fusion(low_feats, high_feats)
output = self.dropout(output)
output = self.bottleneck(output)
output = self.context(output)
output = self.cls_seg(output)
return output
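
A minimal usage sketch for ANNHead. The head selects two feature maps via 'multiple_select' (a low-level key feature and a high-level query feature); the channel sizes below are illustrative ResNet-like values, not prescribed by this commit:

import torch
from mmseg.models.decode_heads import ANNHead

head = ANNHead(
    in_channels=[1024, 2048],  # [low_in_channels, high_in_channels]
    in_index=[2, 3],
    channels=512,
    project_channels=256,
    num_classes=19)
feats = [torch.rand(1, 256, 128, 128), torch.rand(1, 512, 64, 64),
         torch.rand(1, 1024, 32, 32), torch.rand(1, 2048, 16, 16)]
seg_logits = head(feats)  # -> (1, 19, 16, 16), at the high-level resolution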

View File

@@ -0,0 +1,159 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead
class ACM(nn.Module):
"""Adaptive Context Module used in APCNet.
Args:
pool_scale (int): Pooling scale used in Adaptive Context
Module to extract region features.
fusion (bool): Add one conv to fuse residual feature.
in_channels (int): Input channels.
channels (int): Channels after modules, before conv_seg.
conv_cfg (dict | None): Config of conv layers.
norm_cfg (dict | None): Config of norm layers.
act_cfg (dict): Config of activation layers.
"""
def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg,
norm_cfg, act_cfg):
super().__init__()
self.pool_scale = pool_scale
self.fusion = fusion
self.in_channels = in_channels
self.channels = channels
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.pooled_redu_conv = ConvModule(
self.in_channels,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.input_redu_conv = ConvModule(
self.in_channels,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.global_info = ConvModule(
self.channels,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.gla = nn.Conv2d(self.channels, self.pool_scale**2, 1, 1, 0)
self.residual_conv = ConvModule(
self.channels,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
if self.fusion:
self.fusion_conv = ConvModule(
self.channels,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def forward(self, x):
"""Forward function."""
pooled_x = F.adaptive_avg_pool2d(x, self.pool_scale)
# [batch_size, channels, h, w]
x = self.input_redu_conv(x)
# [batch_size, channels, pool_scale, pool_scale]
pooled_x = self.pooled_redu_conv(pooled_x)
batch_size = x.size(0)
# [batch_size, pool_scale * pool_scale, channels]
pooled_x = pooled_x.view(batch_size, self.channels,
-1).permute(0, 2, 1).contiguous()
# [batch_size, h * w, pool_scale * pool_scale]
affinity_matrix = self.gla(x + resize(
self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:])
).permute(0, 2, 3, 1).reshape(
batch_size, -1, self.pool_scale**2)
        affinity_matrix = torch.sigmoid(affinity_matrix)
# [batch_size, h * w, channels]
z_out = torch.matmul(affinity_matrix, pooled_x)
# [batch_size, channels, h * w]
z_out = z_out.permute(0, 2, 1).contiguous()
# [batch_size, channels, h, w]
z_out = z_out.view(batch_size, self.channels, x.size(2), x.size(3))
z_out = self.residual_conv(z_out)
z_out = F.relu(z_out + x)
if self.fusion:
z_out = self.fusion_conv(z_out)
return z_out
@MODELS.register_module()
class APCHead(BaseDecodeHead):
"""Adaptive Pyramid Context Network for Semantic Segmentation.
This head is the implementation of
`APCNet <https://openaccess.thecvf.com/content_CVPR_2019/papers/\
He_Adaptive_Pyramid_Context_Network_for_Semantic_Segmentation_\
CVPR_2019_paper.pdf>`_.
Args:
pool_scales (tuple[int]): Pooling scales used in Adaptive Context
Module. Default: (1, 2, 3, 6).
fusion (bool): Add one conv to fuse residual feature.
"""
def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, **kwargs):
super().__init__(**kwargs)
assert isinstance(pool_scales, (list, tuple))
self.pool_scales = pool_scales
self.fusion = fusion
acm_modules = []
for pool_scale in self.pool_scales:
acm_modules.append(
ACM(pool_scale,
self.fusion,
self.in_channels,
self.channels,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
self.acm_modules = nn.ModuleList(acm_modules)
self.bottleneck = ConvModule(
self.in_channels + len(pool_scales) * self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def forward(self, inputs):
"""Forward function."""
x = self._transform_inputs(inputs)
acm_outs = [x]
for acm_module in self.acm_modules:
acm_outs.append(acm_module(x))
acm_outs = torch.cat(acm_outs, dim=1)
output = self.bottleneck(acm_outs)
output = self.cls_seg(output)
return output
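
A minimal usage sketch for APCHead (values illustrative; the default in_index=-1 selects the last feature map):

import torch
from mmseg.models.decode_heads import APCHead

head = APCHead(
    in_channels=2048,
    channels=512,
    num_classes=19,
    pool_scales=(1, 2, 3, 6),
    fusion=True)
seg_logits = head([torch.rand(1, 2048, 32, 32)])  # -> (1, 19, 32, 32)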

View File

@@ -0,0 +1,122 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead
class ASPPModule(nn.ModuleList):
"""Atrous Spatial Pyramid Pooling (ASPP) Module.
Args:
dilations (tuple[int]): Dilation rate of each layer.
in_channels (int): Input channels.
channels (int): Channels after modules, before conv_seg.
conv_cfg (dict|None): Config of conv layers.
norm_cfg (dict|None): Config of norm layers.
act_cfg (dict): Config of activation layers.
"""
def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg,
act_cfg):
super().__init__()
self.dilations = dilations
self.in_channels = in_channels
self.channels = channels
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
for dilation in dilations:
self.append(
ConvModule(
self.in_channels,
self.channels,
1 if dilation == 1 else 3,
dilation=dilation,
padding=0 if dilation == 1 else dilation,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
def forward(self, x):
"""Forward function."""
aspp_outs = []
for aspp_module in self:
aspp_outs.append(aspp_module(x))
return aspp_outs
@MODELS.register_module()
class ASPPHead(BaseDecodeHead):
"""Rethinking Atrous Convolution for Semantic Image Segmentation.
This head is the implementation of `DeepLabV3
<https://arxiv.org/abs/1706.05587>`_.
Args:
dilations (tuple[int]): Dilation rates for ASPP module.
Default: (1, 6, 12, 18).
"""
def __init__(self, dilations=(1, 6, 12, 18), **kwargs):
super().__init__(**kwargs)
assert isinstance(dilations, (list, tuple))
self.dilations = dilations
self.image_pool = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
ConvModule(
self.in_channels,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
self.aspp_modules = ASPPModule(
dilations,
self.in_channels,
self.channels,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.bottleneck = ConvModule(
(len(dilations) + 1) * self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def _forward_feature(self, inputs):
"""Forward function for feature maps before classifying each pixel with
``self.cls_seg`` fc.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
feats (Tensor): A tensor of shape (batch_size, self.channels,
H, W) which is feature map for last layer of decoder head.
"""
x = self._transform_inputs(inputs)
aspp_outs = [
resize(
self.image_pool(x),
size=x.size()[2:],
mode='bilinear',
align_corners=self.align_corners)
]
aspp_outs.extend(self.aspp_modules(x))
aspp_outs = torch.cat(aspp_outs, dim=1)
feats = self.bottleneck(aspp_outs)
return feats
def forward(self, inputs):
"""Forward function."""
output = self._forward_feature(inputs)
output = self.cls_seg(output)
return output
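
A minimal usage sketch for ASPPHead (values illustrative; the dilations follow the common DeepLabV3 output-stride-8 setting rather than anything mandated by this commit):

import torch
from mmseg.models.decode_heads import ASPPHead

head = ASPPHead(
    in_channels=2048,
    channels=512,
    num_classes=21,
    dilations=(1, 12, 24, 36))
seg_logits = head([torch.rand(1, 2048, 64, 64)])  # -> (1, 21, 64, 64)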

View File

@@ -0,0 +1,62 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import List
from torch import Tensor
from mmseg.utils import ConfigType
from .decode_head import BaseDecodeHead
class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
"""Base class for cascade decode head used in
:class:`CascadeEncoderDecoder."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@abstractmethod
def forward(self, inputs, prev_output):
"""Placeholder of forward function."""
pass
    def loss(self, inputs: List[Tensor], prev_output: Tensor,
             batch_data_samples: List[dict], train_cfg: ConfigType) -> dict:
"""Forward function for training.
Args:
inputs (List[Tensor]): List of multi-level img features.
prev_output (Tensor): The output of previous decode head.
batch_data_samples (List[:obj:`SegDataSample`]): The seg
data samples. It usually includes information such
as `metainfo` and `gt_sem_seg`.
train_cfg (dict): The training config.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
seg_logits = self.forward(inputs, prev_output)
losses = self.loss_by_feat(seg_logits, batch_data_samples)
return losses
    def predict(self, inputs: List[Tensor], prev_output: Tensor,
                batch_img_metas: List[dict], test_cfg: ConfigType):
"""Forward function for testing.
Args:
inputs (List[Tensor]): List of multi-level img features.
prev_output (Tensor): The output of previous decode head.
            batch_img_metas (List[dict]): List of image meta info where each
                dict may also contain: 'img_shape', 'scale_factor', 'flip',
                'img_path', 'ori_shape', and 'pad_shape'.
For details on the values of these keys see
`mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
test_cfg (dict): The testing config.
Returns:
Tensor: Output segmentation map.
"""
seg_logits = self.forward(inputs, prev_output)
return self.predict_by_feat(seg_logits, batch_img_metas)
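
Concrete cascade heads only need to implement forward(inputs, prev_output); loss() and predict() above then reuse loss_by_feat/predict_by_feat inherited from BaseDecodeHead. A hypothetical minimal subclass to illustrate the contract (not part of this commit):

from mmseg.models.decode_heads.cascade_decode_head import \
    BaseCascadeDecodeHead

class ToyCascadeHead(BaseCascadeDecodeHead):  # hypothetical example
    def forward(self, inputs, prev_output):
        x = self._transform_inputs(inputs)
        # A real head would fuse ``prev_output`` into ``x`` here; this toy
        # version ignores it and assumes in_channels == channels so that
        # ``cls_seg`` can be applied directly.
        return self.cls_seg(x)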

View File

@@ -0,0 +1,43 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmseg.registry import MODELS
from .fcn_head import FCNHead
try:
from mmcv.ops import CrissCrossAttention
except ModuleNotFoundError:
CrissCrossAttention = None
@MODELS.register_module()
class CCHead(FCNHead):
"""CCNet: Criss-Cross Attention for Semantic Segmentation.
This head is the implementation of `CCNet
<https://arxiv.org/abs/1811.11721>`_.
Args:
        recurrence (int): Number of recurrences of the Criss-Cross Attention
            module. Default: 2.
"""
def __init__(self, recurrence=2, **kwargs):
if CrissCrossAttention is None:
raise RuntimeError('Please install mmcv-full for '
'CrissCrossAttention ops')
super().__init__(num_convs=2, **kwargs)
self.recurrence = recurrence
self.cca = CrissCrossAttention(self.channels)
def forward(self, inputs):
"""Forward function."""
x = self._transform_inputs(inputs)
output = self.convs[0](x)
for _ in range(self.recurrence):
output = self.cca(output)
output = self.convs[1](output)
if self.concat_input:
output = self.conv_cat(torch.cat([x, output], dim=1))
output = self.cls_seg(output)
return output
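
A minimal usage sketch for CCHead, assuming an mmcv build that ships the CrissCrossAttention op (values illustrative):

import torch
from mmseg.models.decode_heads import CCHead

head = CCHead(
    in_channels=2048,
    channels=512,
    num_classes=19,
    recurrence=2)
seg_logits = head([torch.rand(1, 2048, 32, 32)])  # -> (1, 19, 32, 32)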

View File

@@ -0,0 +1,184 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple
import torch
import torch.nn.functional as F
from mmcv.cnn import ConvModule, Scale
from torch import Tensor, nn
from mmseg.registry import MODELS
from mmseg.utils import SampleList, add_prefix
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .decode_head import BaseDecodeHead
class PAM(_SelfAttentionBlock):
"""Position Attention Module (PAM)
Args:
in_channels (int): Input channels of key/query feature.
channels (int): Output channels of key/query transform.
"""
def __init__(self, in_channels, channels):
super().__init__(
key_in_channels=in_channels,
query_in_channels=in_channels,
channels=channels,
out_channels=in_channels,
share_key_query=False,
query_downsample=None,
key_downsample=None,
key_query_num_convs=1,
key_query_norm=False,
value_out_num_convs=1,
value_out_norm=False,
matmul_norm=False,
with_out=False,
conv_cfg=None,
norm_cfg=None,
act_cfg=None)
self.gamma = Scale(0)
def forward(self, x):
"""Forward function."""
out = super().forward(x, x)
out = self.gamma(out) + x
return out
class CAM(nn.Module):
"""Channel Attention Module (CAM)"""
def __init__(self):
super().__init__()
self.gamma = Scale(0)
def forward(self, x):
"""Forward function."""
batch_size, channels, height, width = x.size()
proj_query = x.view(batch_size, channels, -1)
proj_key = x.view(batch_size, channels, -1).permute(0, 2, 1)
energy = torch.bmm(proj_query, proj_key)
energy_new = torch.max(
energy, -1, keepdim=True)[0].expand_as(energy) - energy
attention = F.softmax(energy_new, dim=-1)
proj_value = x.view(batch_size, channels, -1)
out = torch.bmm(attention, proj_value)
out = out.view(batch_size, channels, height, width)
out = self.gamma(out) + x
return out
@MODELS.register_module()
class DAHead(BaseDecodeHead):
"""Dual Attention Network for Scene Segmentation.
This head is the implementation of `DANet
<https://arxiv.org/abs/1809.02983>`_.
Args:
pam_channels (int): The channels of Position Attention Module(PAM).
"""
def __init__(self, pam_channels, **kwargs):
super().__init__(**kwargs)
self.pam_channels = pam_channels
self.pam_in_conv = ConvModule(
self.in_channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.pam = PAM(self.channels, pam_channels)
self.pam_out_conv = ConvModule(
self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.pam_conv_seg = nn.Conv2d(
self.channels, self.num_classes, kernel_size=1)
self.cam_in_conv = ConvModule(
self.in_channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.cam = CAM()
self.cam_out_conv = ConvModule(
self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.cam_conv_seg = nn.Conv2d(
self.channels, self.num_classes, kernel_size=1)
def pam_cls_seg(self, feat):
"""PAM feature classification."""
if self.dropout is not None:
feat = self.dropout(feat)
output = self.pam_conv_seg(feat)
return output
def cam_cls_seg(self, feat):
"""CAM feature classification."""
if self.dropout is not None:
feat = self.dropout(feat)
output = self.cam_conv_seg(feat)
return output
def forward(self, inputs):
"""Forward function."""
x = self._transform_inputs(inputs)
pam_feat = self.pam_in_conv(x)
pam_feat = self.pam(pam_feat)
pam_feat = self.pam_out_conv(pam_feat)
pam_out = self.pam_cls_seg(pam_feat)
cam_feat = self.cam_in_conv(x)
cam_feat = self.cam(cam_feat)
cam_feat = self.cam_out_conv(cam_feat)
cam_out = self.cam_cls_seg(cam_feat)
feat_sum = pam_feat + cam_feat
pam_cam_out = self.cls_seg(feat_sum)
return pam_cam_out, pam_out, cam_out
def predict(self, inputs, batch_img_metas: List[dict], test_cfg,
**kwargs) -> List[Tensor]:
"""Forward function for testing, only ``pam_cam`` is used."""
seg_logits = self.forward(inputs)[0]
return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs)
def loss_by_feat(self, seg_logit: Tuple[Tensor],
batch_data_samples: SampleList, **kwargs) -> dict:
"""Compute ``pam_cam``, ``pam``, ``cam`` loss."""
pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit
loss = dict()
loss.update(
add_prefix(
super().loss_by_feat(pam_cam_seg_logit, batch_data_samples),
'pam_cam'))
loss.update(
add_prefix(super().loss_by_feat(pam_seg_logit, batch_data_samples),
'pam'))
loss.update(
add_prefix(super().loss_by_feat(cam_seg_logit, batch_data_samples),
'cam'))
return loss
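
DAHead returns three logit maps: the fused pam_cam output plus the individual PAM and CAM outputs. Only pam_cam is used at test time, while all three are supervised during training. A minimal sketch (values illustrative):

import torch
from mmseg.models.decode_heads import DAHead

head = DAHead(
    in_channels=2048,
    channels=512,
    pam_channels=64,
    num_classes=19)
pam_cam_out, pam_out, cam_out = head([torch.rand(1, 2048, 32, 32)])
# each output: (1, 19, 32, 32)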

View File

@@ -0,0 +1,116 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union
import torch.nn as nn
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from torch import Tensor
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.losses import accuracy
from mmseg.models.utils import resize
from mmseg.registry import MODELS
from mmseg.utils import OptConfigType, SampleList
@MODELS.register_module()
class DDRHead(BaseDecodeHead):
"""Decode head for DDRNet.
Args:
in_channels (int): Number of input channels.
channels (int): Number of output channels.
num_classes (int): Number of classes.
norm_cfg (dict, optional): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict, optional): Config dict for activation layer.
Default: dict(type='ReLU', inplace=True).
"""
def __init__(self,
in_channels: int,
channels: int,
num_classes: int,
norm_cfg: OptConfigType = dict(type='BN'),
act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
**kwargs):
super().__init__(
in_channels,
channels,
num_classes=num_classes,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
**kwargs)
self.head = self._make_base_head(self.in_channels, self.channels)
self.aux_head = self._make_base_head(self.in_channels // 2,
self.channels)
self.aux_cls_seg = nn.Conv2d(
self.channels, self.out_channels, kernel_size=1)
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(
self,
inputs: Union[Tensor,
Tuple[Tensor]]) -> Union[Tensor, Tuple[Tensor]]:
if self.training:
c3_feat, c5_feat = inputs
x_c = self.head(c5_feat)
x_c = self.cls_seg(x_c)
x_s = self.aux_head(c3_feat)
x_s = self.aux_cls_seg(x_s)
return x_c, x_s
else:
x_c = self.head(inputs)
x_c = self.cls_seg(x_c)
return x_c
def _make_base_head(self, in_channels: int,
channels: int) -> nn.Sequential:
layers = [
ConvModule(
in_channels,
channels,
kernel_size=3,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
order=('norm', 'act', 'conv')),
build_norm_layer(self.norm_cfg, channels)[1],
build_activation_layer(self.act_cfg),
]
return nn.Sequential(*layers)
def loss_by_feat(self, seg_logits: Tuple[Tensor],
batch_data_samples: SampleList) -> dict:
loss = dict()
context_logit, spatial_logit = seg_logits
seg_label = self._stack_batch_gt(batch_data_samples)
context_logit = resize(
context_logit,
size=seg_label.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
spatial_logit = resize(
spatial_logit,
size=seg_label.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
seg_label = seg_label.squeeze(1)
loss['loss_context'] = self.loss_decode[0](context_logit, seg_label)
loss['loss_spatial'] = self.loss_decode[1](spatial_logit, seg_label)
loss['acc_seg'] = accuracy(
context_logit, seg_label, ignore_index=self.ignore_index)
return loss
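
DDRHead is two-branch only during training: forward() expects a (c3_feat, c5_feat) tuple in train mode and a single fused feature map in eval mode. A minimal sketch (channel sizes illustrative):

import torch
from mmseg.models.decode_heads import DDRHead

head = DDRHead(in_channels=128, channels=64, num_classes=19)
head.eval()  # single-input path; train mode would need the feature tuple
x = torch.rand(1, 128, 64, 64)  # fused feature from the context branch
seg_logits = head(x)  # -> (1, 19, 64, 64)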

View File

@@ -0,0 +1,366 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from abc import ABCMeta, abstractmethod
from typing import List, Tuple
import torch
import torch.nn as nn
from mmengine.model import BaseModule
from torch import Tensor
from mmseg.registry import MODELS
from mmseg.structures import build_pixel_sampler
from mmseg.utils import ConfigType, SampleList
from ..losses import accuracy
from ..utils import resize
class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
"""Base class for BaseDecodeHead.
1. The ``init_weights`` method is used to initialize decode_head's
model parameters. After segmentor initialization, ``init_weights``
is triggered when ``segmentor.init_weights()`` is called externally.
2. The ``loss`` method is used to calculate the loss of decode_head,
which includes two steps: (1) the decode_head model performs forward
propagation to obtain the feature maps (2) The ``loss_by_feat`` method
is called based on the feature maps to calculate the loss.
.. code:: text
loss(): forward() -> loss_by_feat()
3. The ``predict`` method is used to predict segmentation results,
which includes two steps: (1) the decode_head model performs forward
propagation to obtain the feature maps (2) The ``predict_by_feat`` method
is called based on the feature maps to predict segmentation results
including post-processing.
.. code:: text
predict(): forward() -> predict_by_feat()
Args:
in_channels (int|Sequence[int]): Input channels.
channels (int): Channels after modules, before conv_seg.
num_classes (int): Number of classes.
out_channels (int): Output channels of conv_seg. Default: None.
threshold (float): Threshold for binary segmentation in the case of
`num_classes==1`. Default: None.
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
conv_cfg (dict|None): Config of conv layers. Default: None.
norm_cfg (dict|None): Config of norm layers. Default: None.
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU')
in_index (int|Sequence[int]): Input feature index. Default: -1
input_transform (str|None): Transformation type of input features.
Options: 'resize_concat', 'multiple_select', None.
            'resize_concat': Multiple feature maps will be resized to the
                same size as the first one and then concatenated together.
                Usually used in FCN head of HRNet.
            'multiple_select': Multiple feature maps will be bundled into
                a list and passed into the decode head.
            None: Only one feature map will be selected.
Default: None.
loss_decode (dict | Sequence[dict]): Config of decode loss.
The `loss_name` is property of corresponding loss function which
could be shown in training log. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_ce'.
e.g. dict(type='CrossEntropyLoss'),
[dict(type='CrossEntropyLoss', loss_name='loss_ce'),
dict(type='DiceLoss', loss_name='loss_dice')]
Default: dict(type='CrossEntropyLoss').
ignore_index (int | None): The label index to be ignored. When using
masked BCE loss, ignore_index should be set to None. Default: 255.
sampler (dict|None): The config of segmentation map sampler.
Default: None.
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
init_cfg (dict or list[dict], optional): Initialization config dict.
"""
def __init__(self,
in_channels,
channels,
*,
num_classes,
out_channels=None,
threshold=None,
dropout_ratio=0.1,
conv_cfg=None,
norm_cfg=None,
act_cfg=dict(type='ReLU'),
in_index=-1,
input_transform=None,
loss_decode=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0),
ignore_index=255,
sampler=None,
align_corners=False,
init_cfg=dict(
type='Normal', std=0.01, override=dict(name='conv_seg'))):
super().__init__(init_cfg)
self._init_inputs(in_channels, in_index, input_transform)
self.channels = channels
self.dropout_ratio = dropout_ratio
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.in_index = in_index
self.ignore_index = ignore_index
self.align_corners = align_corners
if out_channels is None:
if num_classes == 2:
                warnings.warn('For binary segmentation, we suggest using '
                              '`out_channels = 1` to define the output '
                              'channels of the segmentor, and use '
                              '`threshold` to convert `seg_logits` into '
                              'a prediction by applying a threshold')
out_channels = num_classes
if out_channels != num_classes and out_channels != 1:
            raise ValueError(
                'out_channels should be equal to num_classes, except for '
                'binary segmentation where out_channels == 1 and '
                f'num_classes == 2, but got out_channels={out_channels} '
                f'and num_classes={num_classes}')
if out_channels == 1 and threshold is None:
threshold = 0.3
            warnings.warn('threshold is not defined for binary '
                          'segmentation, and defaults to 0.3')
self.num_classes = num_classes
self.out_channels = out_channels
self.threshold = threshold
if isinstance(loss_decode, dict):
self.loss_decode = MODELS.build(loss_decode)
elif isinstance(loss_decode, (list, tuple)):
self.loss_decode = nn.ModuleList()
for loss in loss_decode:
self.loss_decode.append(MODELS.build(loss))
else:
raise TypeError(f'loss_decode must be a dict or sequence of dict,\
but got {type(loss_decode)}')
if sampler is not None:
self.sampler = build_pixel_sampler(sampler, context=self)
else:
self.sampler = None
self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
if dropout_ratio > 0:
self.dropout = nn.Dropout2d(dropout_ratio)
else:
self.dropout = None
def extra_repr(self):
"""Extra repr."""
s = f'input_transform={self.input_transform}, ' \
f'ignore_index={self.ignore_index}, ' \
f'align_corners={self.align_corners}'
return s
def _init_inputs(self, in_channels, in_index, input_transform):
"""Check and initialize input transforms.
The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only a single feature map
        will be selected, so in_channels and in_index must be of type int.
        When input_transform is not None, in_channels and in_index must be
        sequences (list or tuple) of the same length.
Args:
in_channels (int|Sequence[int]): Input channels.
in_index (int|Sequence[int]): Input feature index.
input_transform (str|None): Transformation type of input features.
Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps will be resized to the
                    same size as the first one and then concatenated together.
                    Usually used in FCN head of HRNet.
                'multiple_select': Multiple feature maps will be bundled into
                    a list and passed into the decode head.
                None: Only one feature map will be selected.
"""
if input_transform is not None:
assert input_transform in ['resize_concat', 'multiple_select']
self.input_transform = input_transform
self.in_index = in_index
if input_transform is not None:
assert isinstance(in_channels, (list, tuple))
assert isinstance(in_index, (list, tuple))
assert len(in_channels) == len(in_index)
if input_transform == 'resize_concat':
self.in_channels = sum(in_channels)
else:
self.in_channels = in_channels
else:
assert isinstance(in_channels, int)
assert isinstance(in_index, int)
self.in_channels = in_channels
def _transform_inputs(self, inputs):
"""Transform inputs for decoder.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
Tensor: The transformed inputs
"""
if self.input_transform == 'resize_concat':
inputs = [inputs[i] for i in self.in_index]
upsampled_inputs = [
resize(
input=x,
size=inputs[0].shape[2:],
mode='bilinear',
align_corners=self.align_corners) for x in inputs
]
inputs = torch.cat(upsampled_inputs, dim=1)
elif self.input_transform == 'multiple_select':
inputs = [inputs[i] for i in self.in_index]
else:
inputs = inputs[self.in_index]
return inputs
@abstractmethod
def forward(self, inputs):
"""Placeholder of forward function."""
pass
def cls_seg(self, feat):
"""Classify each pixel."""
if self.dropout is not None:
feat = self.dropout(feat)
output = self.conv_seg(feat)
return output
def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList,
train_cfg: ConfigType) -> dict:
"""Forward function for training.
Args:
inputs (Tuple[Tensor]): List of multi-level img features.
batch_data_samples (list[:obj:`SegDataSample`]): The seg
data samples. It usually includes information such
as `img_metas` or `gt_semantic_seg`.
train_cfg (dict): The training config.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
seg_logits = self.forward(inputs)
losses = self.loss_by_feat(seg_logits, batch_data_samples)
return losses
def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
test_cfg: ConfigType) -> Tensor:
"""Forward function for prediction.
Args:
inputs (Tuple[Tensor]): List of multi-level img features.
            batch_img_metas (List[dict]): List of image meta info where each
                dict may also contain: 'img_shape', 'scale_factor', 'flip',
                'img_path', 'ori_shape', and 'pad_shape'.
For details on the values of these keys see
`mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
test_cfg (dict): The testing config.
Returns:
Tensor: Outputs segmentation logits map.
"""
seg_logits = self.forward(inputs)
return self.predict_by_feat(seg_logits, batch_img_metas)
def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor:
gt_semantic_segs = [
data_sample.gt_sem_seg.data for data_sample in batch_data_samples
]
return torch.stack(gt_semantic_segs, dim=0)
def loss_by_feat(self, seg_logits: Tensor,
batch_data_samples: SampleList) -> dict:
"""Compute segmentation loss.
Args:
seg_logits (Tensor): The output from decode head forward function.
batch_data_samples (List[:obj:`SegDataSample`]): The seg
data samples. It usually includes information such
as `metainfo` and `gt_sem_seg`.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
seg_label = self._stack_batch_gt(batch_data_samples)
loss = dict()
seg_logits = resize(
input=seg_logits,
size=seg_label.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
if self.sampler is not None:
seg_weight = self.sampler.sample(seg_logits, seg_label)
else:
seg_weight = None
seg_label = seg_label.squeeze(1)
if not isinstance(self.loss_decode, nn.ModuleList):
losses_decode = [self.loss_decode]
else:
losses_decode = self.loss_decode
for loss_decode in losses_decode:
if loss_decode.loss_name not in loss:
loss[loss_decode.loss_name] = loss_decode(
seg_logits,
seg_label,
weight=seg_weight,
ignore_index=self.ignore_index)
else:
loss[loss_decode.loss_name] += loss_decode(
seg_logits,
seg_label,
weight=seg_weight,
ignore_index=self.ignore_index)
loss['acc_seg'] = accuracy(
seg_logits, seg_label, ignore_index=self.ignore_index)
return loss
def predict_by_feat(self, seg_logits: Tensor,
batch_img_metas: List[dict]) -> Tensor:
"""Transform a batch of output seg_logits to the input shape.
Args:
seg_logits (Tensor): The output from decode head forward function.
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
Returns:
Tensor: Outputs segmentation logits map.
"""
if isinstance(batch_img_metas[0]['img_shape'], torch.Size):
# slide inference
size = batch_img_metas[0]['img_shape']
elif 'pad_shape' in batch_img_metas[0]:
size = batch_img_metas[0]['pad_shape'][:2]
else:
size = batch_img_metas[0]['img_shape']
seg_logits = resize(
input=seg_logits,
size=size,
mode='bilinear',
align_corners=self.align_corners)
return seg_logits
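
Every concrete head in this commit follows the same contract: subclass BaseDecodeHead, implement forward(), and inherit loss()/predict(). A hypothetical minimal subclass (not part of this commit; it assumes in_channels == channels so cls_seg applies directly):

import torch
from mmseg.models.decode_heads.decode_head import BaseDecodeHead

class ToyHead(BaseDecodeHead):  # hypothetical example
    def forward(self, inputs):
        x = self._transform_inputs(inputs)  # default: picks inputs[-1]
        return self.cls_seg(x)              # dropout + 1x1 conv_seg

head = ToyHead(in_channels=64, channels=64, num_classes=19)
seg_logits = head([torch.rand(1, 64, 32, 32)])  # -> (1, 19, 32, 32)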

View File

@@ -0,0 +1,141 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead
class DCM(nn.Module):
"""Dynamic Convolutional Module used in DMNet.
Args:
filter_size (int): The filter size of generated convolution kernel
used in Dynamic Convolutional Module.
fusion (bool): Add one conv to fuse DCM output feature.
in_channels (int): Input channels.
channels (int): Channels after modules, before conv_seg.
conv_cfg (dict | None): Config of conv layers.
norm_cfg (dict | None): Config of norm layers.
act_cfg (dict): Config of activation layers.
"""
def __init__(self, filter_size, fusion, in_channels, channels, conv_cfg,
norm_cfg, act_cfg):
super().__init__()
self.filter_size = filter_size
self.fusion = fusion
self.in_channels = in_channels
self.channels = channels
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.filter_gen_conv = nn.Conv2d(self.in_channels, self.channels, 1, 1,
0)
self.input_redu_conv = ConvModule(
self.in_channels,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
if self.norm_cfg is not None:
self.norm = build_norm_layer(self.norm_cfg, self.channels)[1]
else:
self.norm = None
self.activate = build_activation_layer(self.act_cfg)
if self.fusion:
self.fusion_conv = ConvModule(
self.channels,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def forward(self, x):
"""Forward function."""
generated_filter = self.filter_gen_conv(
F.adaptive_avg_pool2d(x, self.filter_size))
x = self.input_redu_conv(x)
b, c, h, w = x.shape
# [1, b * c, h, w], c = self.channels
x = x.view(1, b * c, h, w)
# [b * c, 1, filter_size, filter_size]
generated_filter = generated_filter.view(b * c, 1, self.filter_size,
self.filter_size)
pad = (self.filter_size - 1) // 2
if (self.filter_size - 1) % 2 == 0:
p2d = (pad, pad, pad, pad)
else:
p2d = (pad + 1, pad, pad + 1, pad)
x = F.pad(input=x, pad=p2d, mode='constant', value=0)
# [1, b * c, h, w]
output = F.conv2d(input=x, weight=generated_filter, groups=b * c)
# [b, c, h, w]
output = output.view(b, c, h, w)
if self.norm is not None:
output = self.norm(output)
output = self.activate(output)
if self.fusion:
output = self.fusion_conv(output)
return output
@MODELS.register_module()
class DMHead(BaseDecodeHead):
"""Dynamic Multi-scale Filters for Semantic Segmentation.
This head is the implementation of
`DMNet <https://openaccess.thecvf.com/content_ICCV_2019/papers/\
He_Dynamic_Multi-Scale_Filters_for_Semantic_Segmentation_\
ICCV_2019_paper.pdf>`_.
Args:
filter_sizes (tuple[int]): The size of generated convolutional filters
used in Dynamic Convolutional Module. Default: (1, 3, 5, 7).
fusion (bool): Add one conv to fuse DCM output feature.
"""
def __init__(self, filter_sizes=(1, 3, 5, 7), fusion=False, **kwargs):
super().__init__(**kwargs)
assert isinstance(filter_sizes, (list, tuple))
self.filter_sizes = filter_sizes
self.fusion = fusion
dcm_modules = []
for filter_size in self.filter_sizes:
dcm_modules.append(
DCM(filter_size,
self.fusion,
self.in_channels,
self.channels,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
self.dcm_modules = nn.ModuleList(dcm_modules)
self.bottleneck = ConvModule(
self.in_channels + len(filter_sizes) * self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def forward(self, inputs):
"""Forward function."""
x = self._transform_inputs(inputs)
dcm_outs = [x]
for dcm_module in self.dcm_modules:
dcm_outs.append(dcm_module(x))
dcm_outs = torch.cat(dcm_outs, dim=1)
output = self.bottleneck(dcm_outs)
output = self.cls_seg(output)
return output
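
A minimal usage sketch for DMHead (values illustrative):

import torch
from mmseg.models.decode_heads import DMHead

head = DMHead(
    in_channels=2048,
    channels=512,
    num_classes=19,
    filter_sizes=(1, 3, 5, 7))
seg_logits = head([torch.rand(1, 2048, 32, 32)])  # -> (1, 19, 32, 32)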

View File

@@ -0,0 +1,137 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn import NonLocal2d
from torch import nn
from mmseg.registry import MODELS
from .fcn_head import FCNHead
class DisentangledNonLocal2d(NonLocal2d):
"""Disentangled Non-Local Blocks.
Args:
temperature (float): Temperature to adjust attention. Default: 0.05
"""
    def __init__(self, *args, temperature, **kwargs):
        super().__init__(*args, **kwargs)
self.temperature = temperature
self.conv_mask = nn.Conv2d(self.in_channels, 1, kernel_size=1)
def embedded_gaussian(self, theta_x, phi_x):
"""Embedded gaussian with temperature."""
# NonLocal2d pairwise_weight: [N, HxW, HxW]
pairwise_weight = torch.matmul(theta_x, phi_x)
if self.use_scale:
# theta_x.shape[-1] is `self.inter_channels`
pairwise_weight /= torch.tensor(
theta_x.shape[-1],
dtype=torch.float,
device=pairwise_weight.device)**torch.tensor(
0.5, device=pairwise_weight.device)
pairwise_weight /= torch.tensor(
self.temperature, device=pairwise_weight.device)
pairwise_weight = pairwise_weight.softmax(dim=-1)
return pairwise_weight
def forward(self, x):
# x: [N, C, H, W]
n = x.size(0)
# g_x: [N, HxW, C]
g_x = self.g(x).view(n, self.inter_channels, -1)
g_x = g_x.permute(0, 2, 1)
# theta_x: [N, HxW, C], phi_x: [N, C, HxW]
if self.mode == 'gaussian':
theta_x = x.view(n, self.in_channels, -1)
theta_x = theta_x.permute(0, 2, 1)
if self.sub_sample:
phi_x = self.phi(x).view(n, self.in_channels, -1)
else:
phi_x = x.view(n, self.in_channels, -1)
elif self.mode == 'concatenation':
theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
else:
theta_x = self.theta(x).view(n, self.inter_channels, -1)
theta_x = theta_x.permute(0, 2, 1)
phi_x = self.phi(x).view(n, self.inter_channels, -1)
# subtract mean
theta_x -= theta_x.mean(dim=-2, keepdim=True)
phi_x -= phi_x.mean(dim=-1, keepdim=True)
pairwise_func = getattr(self, self.mode)
# pairwise_weight: [N, HxW, HxW]
pairwise_weight = pairwise_func(theta_x, phi_x)
# y: [N, HxW, C]
y = torch.matmul(pairwise_weight, g_x)
# y: [N, C, H, W]
y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
*x.size()[2:])
# unary_mask: [N, 1, HxW]
unary_mask = self.conv_mask(x)
unary_mask = unary_mask.view(n, 1, -1)
unary_mask = unary_mask.softmax(dim=-1)
# unary_x: [N, 1, C]
unary_x = torch.matmul(unary_mask, g_x)
# unary_x: [N, C, 1, 1]
unary_x = unary_x.permute(0, 2, 1).contiguous().reshape(
n, self.inter_channels, 1, 1)
output = x + self.conv_out(y + unary_x)
return output
@MODELS.register_module()
class DNLHead(FCNHead):
"""Disentangled Non-Local Neural Networks.
This head is the implementation of `DNLNet
<https://arxiv.org/abs/2006.06668>`_.
Args:
reduction (int): Reduction factor of projection transform. Default: 2.
        use_scale (bool): Whether to scale pairwise_weight by
            sqrt(1/inter_channels). Default: True.
        mode (str): The nonlocal mode. Options are 'embedded_gaussian',
            'dot_product'. Default: 'embedded_gaussian'.
temperature (float): Temperature to adjust attention. Default: 0.05
"""
def __init__(self,
reduction=2,
use_scale=True,
mode='embedded_gaussian',
temperature=0.05,
**kwargs):
super().__init__(num_convs=2, **kwargs)
self.reduction = reduction
self.use_scale = use_scale
self.mode = mode
self.temperature = temperature
self.dnl_block = DisentangledNonLocal2d(
in_channels=self.channels,
reduction=self.reduction,
use_scale=self.use_scale,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
mode=self.mode,
temperature=self.temperature)
def forward(self, inputs):
"""Forward function."""
x = self._transform_inputs(inputs)
output = self.convs[0](x)
output = self.dnl_block(output)
output = self.convs[1](output)
if self.concat_input:
output = self.conv_cat(torch.cat([x, output], dim=1))
output = self.cls_seg(output)
return output
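
A minimal usage sketch for DNLHead (values illustrative):

import torch
from mmseg.models.decode_heads import DNLHead

head = DNLHead(
    in_channels=2048,
    channels=512,
    num_classes=19,
    temperature=0.05)
seg_logits = head([torch.rand(1, 2048, 32, 32)])  # -> (1, 19, 32, 32)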

View File

@@ -0,0 +1,294 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, Linear, build_activation_layer
from mmengine.model import BaseModule
from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead
class ReassembleBlocks(BaseModule):
"""ViTPostProcessBlock, process cls_token in ViT backbone output and
rearrange the feature vector to feature map.
Args:
in_channels (int): ViT feature channels. Default: 768.
out_channels (List): output channels of each stage.
Default: [96, 192, 384, 768].
readout_type (str): Type of readout operation. Default: 'ignore'.
patch_size (int): The patch size. Default: 16.
init_cfg (dict, optional): Initialization config dict. Default: None.
"""
def __init__(self,
in_channels=768,
out_channels=[96, 192, 384, 768],
readout_type='ignore',
patch_size=16,
init_cfg=None):
super().__init__(init_cfg)
assert readout_type in ['ignore', 'add', 'project']
self.readout_type = readout_type
self.patch_size = patch_size
self.projects = nn.ModuleList([
ConvModule(
in_channels=in_channels,
out_channels=out_channel,
kernel_size=1,
act_cfg=None,
) for out_channel in out_channels
])
self.resize_layers = nn.ModuleList([
nn.ConvTranspose2d(
in_channels=out_channels[0],
out_channels=out_channels[0],
kernel_size=4,
stride=4,
padding=0),
nn.ConvTranspose2d(
in_channels=out_channels[1],
out_channels=out_channels[1],
kernel_size=2,
stride=2,
padding=0),
nn.Identity(),
nn.Conv2d(
in_channels=out_channels[3],
out_channels=out_channels[3],
kernel_size=3,
stride=2,
padding=1)
])
if self.readout_type == 'project':
self.readout_projects = nn.ModuleList()
for _ in range(len(self.projects)):
self.readout_projects.append(
nn.Sequential(
Linear(2 * in_channels, in_channels),
build_activation_layer(dict(type='GELU'))))
def forward(self, inputs):
assert isinstance(inputs, list)
out = []
for i, x in enumerate(inputs):
assert len(x) == 2
x, cls_token = x[0], x[1]
feature_shape = x.shape
if self.readout_type == 'project':
x = x.flatten(2).permute((0, 2, 1))
readout = cls_token.unsqueeze(1).expand_as(x)
x = self.readout_projects[i](torch.cat((x, readout), -1))
x = x.permute(0, 2, 1).reshape(feature_shape)
elif self.readout_type == 'add':
x = x.flatten(2) + cls_token.unsqueeze(-1)
x = x.reshape(feature_shape)
else:
pass
x = self.projects[i](x)
x = self.resize_layers[i](x)
out.append(x)
return out
class PreActResidualConvUnit(BaseModule):
"""ResidualConvUnit, pre-activate residual unit.
Args:
in_channels (int): number of channels in the input feature map.
act_cfg (dict): dictionary to construct and config activation layer.
norm_cfg (dict): dictionary to construct and config norm layer.
stride (int): stride of the first block. Default: 1
dilation (int): dilation rate for convs layers. Default: 1.
init_cfg (dict, optional): Initialization config dict. Default: None.
"""
def __init__(self,
in_channels,
act_cfg,
norm_cfg,
stride=1,
dilation=1,
init_cfg=None):
super().__init__(init_cfg)
self.conv1 = ConvModule(
in_channels,
in_channels,
3,
stride=stride,
padding=dilation,
dilation=dilation,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
bias=False,
order=('act', 'conv', 'norm'))
self.conv2 = ConvModule(
in_channels,
in_channels,
3,
padding=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
bias=False,
order=('act', 'conv', 'norm'))
def forward(self, inputs):
inputs_ = inputs.clone()
x = self.conv1(inputs)
x = self.conv2(x)
return x + inputs_
class FeatureFusionBlock(BaseModule):
"""FeatureFusionBlock, merge feature map from different stages.
Args:
in_channels (int): Input channels.
act_cfg (dict): The activation config for ResidualConvUnit.
norm_cfg (dict): Config dict for normalization layer.
expand (bool): Whether expand the channels in post process block.
Default: False.
align_corners (bool): align_corner setting for bilinear upsample.
Default: True.
init_cfg (dict, optional): Initialization config dict. Default: None.
"""
def __init__(self,
in_channels,
act_cfg,
norm_cfg,
expand=False,
align_corners=True,
init_cfg=None):
super().__init__(init_cfg)
self.in_channels = in_channels
self.expand = expand
self.align_corners = align_corners
self.out_channels = in_channels
if self.expand:
self.out_channels = in_channels // 2
self.project = ConvModule(
self.in_channels,
self.out_channels,
kernel_size=1,
act_cfg=None,
bias=True)
self.res_conv_unit1 = PreActResidualConvUnit(
in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
self.res_conv_unit2 = PreActResidualConvUnit(
in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
def forward(self, *inputs):
x = inputs[0]
if len(inputs) == 2:
if x.shape != inputs[1].shape:
res = resize(
inputs[1],
size=(x.shape[2], x.shape[3]),
mode='bilinear',
align_corners=False)
else:
res = inputs[1]
x = x + self.res_conv_unit1(res)
x = self.res_conv_unit2(x)
x = resize(
x,
scale_factor=2,
mode='bilinear',
align_corners=self.align_corners)
x = self.project(x)
return x
@MODELS.register_module()
class DPTHead(BaseDecodeHead):
"""Vision Transformers for Dense Prediction.
    This head is the implementation of `DPT
    <https://arxiv.org/abs/2103.13413>`_.
Args:
embed_dims (int): The embed dimension of the ViT backbone.
Default: 768.
post_process_channels (List): Out channels of post process conv
layers. Default: [96, 192, 384, 768].
readout_type (str): Type of readout operation. Default: 'ignore'.
patch_size (int): The patch size. Default: 16.
expand_channels (bool): Whether expand the channels in post process
block. Default: False.
act_cfg (dict): The activation config for residual conv unit.
Default dict(type='ReLU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
"""
def __init__(self,
embed_dims=768,
post_process_channels=[96, 192, 384, 768],
readout_type='ignore',
patch_size=16,
expand_channels=False,
act_cfg=dict(type='ReLU'),
norm_cfg=dict(type='BN'),
**kwargs):
super().__init__(**kwargs)
self.expand_channels = expand_channels
self.reassemble_blocks = ReassembleBlocks(embed_dims,
post_process_channels,
readout_type, patch_size)
        # Cast to int: math.pow returns a float, but conv layers need
        # integer channel counts when ``expand_channels`` is enabled.
        self.post_process_channels = [
            int(channel * math.pow(2, i)) if expand_channels else channel
            for i, channel in enumerate(post_process_channels)
        ]
self.convs = nn.ModuleList()
for channel in self.post_process_channels:
self.convs.append(
ConvModule(
channel,
self.channels,
kernel_size=3,
padding=1,
act_cfg=None,
bias=False))
self.fusion_blocks = nn.ModuleList()
for _ in range(len(self.convs)):
self.fusion_blocks.append(
FeatureFusionBlock(self.channels, act_cfg, norm_cfg))
self.fusion_blocks[0].res_conv_unit1 = None
self.project = ConvModule(
self.channels,
self.channels,
kernel_size=3,
padding=1,
norm_cfg=norm_cfg)
self.num_fusion_blocks = len(self.fusion_blocks)
self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
self.num_post_process_channels = len(self.post_process_channels)
assert self.num_fusion_blocks == self.num_reassemble_blocks
assert self.num_reassemble_blocks == self.num_post_process_channels
def forward(self, inputs):
assert len(inputs) == self.num_reassemble_blocks
x = self._transform_inputs(inputs)
x = self.reassemble_blocks(x)
x = [self.convs[i](feature) for i, feature in enumerate(x)]
out = self.fusion_blocks[0](x[-1])
for i in range(1, len(self.fusion_blocks)):
out = self.fusion_blocks[i](out, x[-(i + 1)])
out = self.project(out)
out = self.cls_seg(out)
return out
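
DPTHead consumes (feature map, cls_token) pairs, one per selected ViT stage, and progressively fuses them back toward input resolution. A minimal sketch mirroring the usual DPT configuration (all values illustrative):

import torch
from mmseg.models.decode_heads import DPTHead

head = DPTHead(
    in_channels=(768, 768, 768, 768),
    channels=256,
    embed_dims=768,
    post_process_channels=[96, 192, 384, 768],
    num_classes=150,
    readout_type='project',
    input_transform='multiple_select',
    in_index=(0, 1, 2, 3))
feats = [(torch.rand(1, 768, 32, 32), torch.rand(1, 768))
         for _ in range(4)]  # (feature, cls_token) per stage
seg_logits = head(feats)  # -> (1, 150, 256, 256) after the 2x fusions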

View File

@@ -0,0 +1,169 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead
def reduce_mean(tensor):
"""Reduce mean when distributed training."""
if not (dist.is_available() and dist.is_initialized()):
return tensor
tensor = tensor.clone()
dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
return tensor
class EMAModule(nn.Module):
"""Expectation Maximization Attention Module used in EMANet.
Args:
channels (int): Channels of the whole module.
num_bases (int): Number of bases.
        num_stages (int): Number of the EM iterations.
        momentum (float): Momentum used to update the bases.
"""
def __init__(self, channels, num_bases, num_stages, momentum):
super().__init__()
assert num_stages >= 1, 'num_stages must be at least 1!'
self.num_bases = num_bases
self.num_stages = num_stages
self.momentum = momentum
bases = torch.zeros(1, channels, self.num_bases)
bases.normal_(0, math.sqrt(2. / self.num_bases))
# [1, channels, num_bases]
bases = F.normalize(bases, dim=1, p=2)
self.register_buffer('bases', bases)
def forward(self, feats):
"""Forward function."""
batch_size, channels, height, width = feats.size()
# [batch_size, channels, height*width]
feats = feats.view(batch_size, channels, height * width)
# [batch_size, channels, num_bases]
bases = self.bases.repeat(batch_size, 1, 1)
with torch.no_grad():
for i in range(self.num_stages):
# [batch_size, height*width, num_bases]
attention = torch.einsum('bcn,bck->bnk', feats, bases)
attention = F.softmax(attention, dim=2)
# l1 norm
attention_normed = F.normalize(attention, dim=1, p=1)
# [batch_size, channels, num_bases]
bases = torch.einsum('bcn,bnk->bck', feats, attention_normed)
# l2 norm
bases = F.normalize(bases, dim=1, p=2)
feats_recon = torch.einsum('bck,bnk->bcn', bases, attention)
feats_recon = feats_recon.view(batch_size, channels, height, width)
if self.training:
bases = bases.mean(dim=0, keepdim=True)
bases = reduce_mean(bases)
# l2 norm
bases = F.normalize(bases, dim=1, p=2)
self.bases = (1 -
self.momentum) * self.bases + self.momentum * bases
return feats_recon
@MODELS.register_module()
class EMAHead(BaseDecodeHead):
"""Expectation Maximization Attention Networks for Semantic Segmentation.
This head is the implementation of `EMANet
<https://arxiv.org/abs/1907.13426>`_.
Args:
        ema_channels (int): EMA module channels.
num_bases (int): Number of bases.
num_stages (int): Number of the EM iterations.
concat_input (bool): Whether concat the input and output of convs
before classification layer. Default: True
momentum (float): Momentum to update the base. Default: 0.1.
"""
def __init__(self,
ema_channels,
num_bases,
num_stages,
concat_input=True,
momentum=0.1,
**kwargs):
super().__init__(**kwargs)
self.ema_channels = ema_channels
self.num_bases = num_bases
self.num_stages = num_stages
self.concat_input = concat_input
self.momentum = momentum
self.ema_module = EMAModule(self.ema_channels, self.num_bases,
self.num_stages, self.momentum)
self.ema_in_conv = ConvModule(
self.in_channels,
self.ema_channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
# project (0, inf) -> (-inf, inf)
self.ema_mid_conv = ConvModule(
self.ema_channels,
self.ema_channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=None,
act_cfg=None)
for param in self.ema_mid_conv.parameters():
param.requires_grad = False
self.ema_out_conv = ConvModule(
self.ema_channels,
self.ema_channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=None)
self.bottleneck = ConvModule(
self.ema_channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
if self.concat_input:
self.conv_cat = ConvModule(
self.in_channels + self.channels,
self.channels,
kernel_size=3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def forward(self, inputs):
"""Forward function."""
x = self._transform_inputs(inputs)
feats = self.ema_in_conv(x)
identity = feats
feats = self.ema_mid_conv(feats)
recon = self.ema_module(feats)
recon = F.relu(recon, inplace=True)
recon = self.ema_out_conv(recon)
output = F.relu(identity + recon, inplace=True)
output = self.bottleneck(output)
if self.concat_input:
output = self.conv_cat(torch.cat([x, output], dim=1))
output = self.cls_seg(output)
return output
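
# Illustrative sketch (self-contained, toy sizes): one E-step/M-step pair of
# the EM loop in EMAModule.forward above, checking the resulting base shape.
import torch
import torch.nn.functional as F

batch_size, channels, num_bases, n_pixels = 2, 8, 4, 16
feats = torch.randn(batch_size, channels, n_pixels)
bases = F.normalize(torch.randn(batch_size, channels, num_bases), dim=1, p=2)

# E-step: soft assignment of every pixel to every base.
attention = F.softmax(torch.einsum('bcn,bck->bnk', feats, bases), dim=2)
attention_normed = F.normalize(attention, dim=1, p=1)
# M-step: bases become the attention-weighted average of pixel features.
bases = F.normalize(
    torch.einsum('bcn,bnk->bck', feats, attention_normed), dim=1, p=2)
print(bases.shape)  # torch.Size([2, 8, 4])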

View File

@@ -0,0 +1,196 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_norm_layer
from torch import Tensor
from mmseg.registry import MODELS
from mmseg.utils import ConfigType, SampleList
from ..utils import Encoding, resize
from .decode_head import BaseDecodeHead
class EncModule(nn.Module):
"""Encoding Module used in EncNet.
Args:
in_channels (int): Input channels.
num_codes (int): Number of code words.
conv_cfg (dict|None): Config of conv layers.
norm_cfg (dict|None): Config of norm layers.
act_cfg (dict): Config of activation layers.
"""
def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg):
super().__init__()
self.encoding_project = ConvModule(
in_channels,
in_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
# TODO: resolve this hack
# change to 1d
if norm_cfg is not None:
encoding_norm_cfg = norm_cfg.copy()
if encoding_norm_cfg['type'] in ['BN', 'IN']:
encoding_norm_cfg['type'] += '1d'
else:
encoding_norm_cfg['type'] = encoding_norm_cfg['type'].replace(
'2d', '1d')
else:
# fallback to BN1d
encoding_norm_cfg = dict(type='BN1d')
self.encoding = nn.Sequential(
Encoding(channels=in_channels, num_codes=num_codes),
build_norm_layer(encoding_norm_cfg, num_codes)[1],
nn.ReLU(inplace=True))
self.fc = nn.Sequential(
nn.Linear(in_channels, in_channels), nn.Sigmoid())
def forward(self, x):
"""Forward function."""
encoding_projection = self.encoding_project(x)
encoding_feat = self.encoding(encoding_projection).mean(dim=1)
batch_size, channels, _, _ = x.size()
gamma = self.fc(encoding_feat)
y = gamma.view(batch_size, channels, 1, 1)
output = F.relu_(x + x * y)
return encoding_feat, output
@MODELS.register_module()
class EncHead(BaseDecodeHead):
"""Context Encoding for Semantic Segmentation.
This head is the implementation of `EncNet
<https://arxiv.org/abs/1803.08904>`_.
Args:
num_codes (int): Number of code words. Default: 32.
use_se_loss (bool): Whether use Semantic Encoding Loss (SE-loss) to
regularize the training. Default: True.
add_lateral (bool): Whether use lateral connection to fuse features.
Default: False.
loss_se_decode (dict): Config of decode loss.
Default: dict(type='CrossEntropyLoss', use_sigmoid=True).
"""
def __init__(self,
num_codes=32,
use_se_loss=True,
add_lateral=False,
loss_se_decode=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
loss_weight=0.2),
**kwargs):
super().__init__(input_transform='multiple_select', **kwargs)
self.use_se_loss = use_se_loss
self.add_lateral = add_lateral
self.num_codes = num_codes
self.bottleneck = ConvModule(
self.in_channels[-1],
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
if add_lateral:
self.lateral_convs = nn.ModuleList()
for in_channels in self.in_channels[:-1]: # skip the last one
self.lateral_convs.append(
ConvModule(
in_channels,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
self.fusion = ConvModule(
len(self.in_channels) * self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.enc_module = EncModule(
self.channels,
num_codes=num_codes,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
if self.use_se_loss:
self.loss_se_decode = MODELS.build(loss_se_decode)
self.se_layer = nn.Linear(self.channels, self.num_classes)
def forward(self, inputs):
"""Forward function."""
inputs = self._transform_inputs(inputs)
feat = self.bottleneck(inputs[-1])
if self.add_lateral:
laterals = [
resize(
lateral_conv(inputs[i]),
size=feat.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
for i, lateral_conv in enumerate(self.lateral_convs)
]
feat = self.fusion(torch.cat([feat, *laterals], 1))
encode_feat, output = self.enc_module(feat)
output = self.cls_seg(output)
if self.use_se_loss:
se_output = self.se_layer(encode_feat)
return output, se_output
else:
return output
def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
test_cfg: ConfigType):
"""Forward function for testing, ignore se_loss."""
if self.use_se_loss:
seg_logits = self.forward(inputs)[0]
else:
seg_logits = self.forward(inputs)
return self.predict_by_feat(seg_logits, batch_img_metas)
@staticmethod
def _convert_to_onehot_labels(seg_label, num_classes):
"""Convert segmentation label to onehot.
Args:
seg_label (Tensor): Segmentation label of shape (N, H, W).
num_classes (int): Number of classes.
Returns:
Tensor: Onehot labels of shape (N, num_classes).
"""
batch_size = seg_label.size(0)
onehot_labels = seg_label.new_zeros((batch_size, num_classes))
for i in range(batch_size):
hist = seg_label[i].float().histc(
bins=num_classes, min=0, max=num_classes - 1)
onehot_labels[i] = hist > 0
return onehot_labels
def loss_by_feat(self, seg_logit: Tuple[Tensor],
batch_data_samples: SampleList, **kwargs) -> dict:
"""Compute segmentation and semantic encoding loss."""
seg_logit, se_seg_logit = seg_logit
loss = dict()
loss.update(super().loss_by_feat(seg_logit, batch_data_samples))
seg_label = self._stack_batch_gt(batch_data_samples)
se_loss = self.loss_se_decode(
se_seg_logit,
self._convert_to_onehot_labels(seg_label, self.num_classes))
loss['loss_se'] = se_loss
return loss
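
# Quick self-contained illustration of the histc trick used by
# _convert_to_onehot_labels above, on a toy (H, W) label map.
import torch

num_classes = 4
seg_label = torch.tensor([[0, 1], [1, 3]])
hist = seg_label.float().histc(bins=num_classes, min=0, max=num_classes - 1)
print(hist > 0)  # tensor([ True,  True, False,  True]) -> classes 0, 1, 3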

View File

@@ -0,0 +1,96 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead
@MODELS.register_module()
class FCNHead(BaseDecodeHead):
"""Fully Convolution Networks for Semantic Segmentation.
    This head is the implementation of `FCN <https://arxiv.org/abs/1411.4038>`_.
Args:
num_convs (int): Number of convs in the head. Default: 2.
kernel_size (int): The kernel size for convs in the head. Default: 3.
concat_input (bool): Whether concat the input and output of convs
before classification layer.
dilation (int): The dilation rate for convs in the head. Default: 1.
"""
def __init__(self,
num_convs=2,
kernel_size=3,
concat_input=True,
dilation=1,
**kwargs):
assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int)
self.num_convs = num_convs
self.concat_input = concat_input
self.kernel_size = kernel_size
super().__init__(**kwargs)
if num_convs == 0:
assert self.in_channels == self.channels
conv_padding = (kernel_size // 2) * dilation
convs = []
convs.append(
ConvModule(
self.in_channels,
self.channels,
kernel_size=kernel_size,
padding=conv_padding,
dilation=dilation,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
for i in range(num_convs - 1):
convs.append(
ConvModule(
self.channels,
self.channels,
kernel_size=kernel_size,
padding=conv_padding,
dilation=dilation,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
if num_convs == 0:
self.convs = nn.Identity()
else:
self.convs = nn.Sequential(*convs)
if self.concat_input:
self.conv_cat = ConvModule(
self.in_channels + self.channels,
self.channels,
kernel_size=kernel_size,
padding=kernel_size // 2,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def _forward_feature(self, inputs):
"""Forward function for feature maps before classifying each pixel with
``self.cls_seg`` fc.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
feats (Tensor): A tensor of shape (batch_size, self.channels,
H, W) which is feature map for last layer of decoder head.
"""
x = self._transform_inputs(inputs)
feats = self.convs(x)
if self.concat_input:
feats = self.conv_cat(torch.cat([x, feats], dim=1))
return feats
def forward(self, inputs):
"""Forward function."""
output = self._forward_feature(inputs)
output = self.cls_seg(output)
return output
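
# Typical usage sketch: FCNHead is normally built from a config dict; the
# values below (in_channels=2048, num_classes=19, ...) are illustrative only.
decode_head = dict(
    type='FCNHead',
    in_channels=2048,
    in_index=3,
    channels=512,
    num_convs=2,
    concat_input=True,
    dropout_ratio=0.1,
    num_classes=19,
    norm_cfg=dict(type='SyncBN', requires_grad=True),
    align_corners=False,
    loss_decode=dict(
        type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))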

View File

@@ -0,0 +1,68 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.registry import MODELS
from ..utils import Upsample, resize
from .decode_head import BaseDecodeHead
@MODELS.register_module()
class FPNHead(BaseDecodeHead):
"""Panoptic Feature Pyramid Networks.
This head is the implementation of `Semantic FPN
<https://arxiv.org/abs/1901.02446>`_.
Args:
        feature_strides (tuple[int]): The strides for input feature maps,
            used to stack laterals. All strides are supposed to be powers
            of 2, with the first one having the largest resolution.
"""
def __init__(self, feature_strides, **kwargs):
super().__init__(input_transform='multiple_select', **kwargs)
assert len(feature_strides) == len(self.in_channels)
assert min(feature_strides) == feature_strides[0]
self.feature_strides = feature_strides
self.scale_heads = nn.ModuleList()
for i in range(len(feature_strides)):
head_length = max(
1,
int(np.log2(feature_strides[i]) - np.log2(feature_strides[0])))
scale_head = []
for k in range(head_length):
scale_head.append(
ConvModule(
self.in_channels[i] if k == 0 else self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
if feature_strides[i] != feature_strides[0]:
scale_head.append(
Upsample(
scale_factor=2,
mode='bilinear',
align_corners=self.align_corners))
self.scale_heads.append(nn.Sequential(*scale_head))
def forward(self, inputs):
x = self._transform_inputs(inputs)
output = self.scale_heads[0](x[0])
for i in range(1, len(self.feature_strides)):
# non inplace
output = output + resize(
self.scale_heads[i](x[i]),
size=output.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
output = self.cls_seg(output)
return output
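
# Self-contained check of the head_length rule above: each scale head gets
# one conv (plus a 2x upsample) per octave between its stride and the finest.
import numpy as np

feature_strides = (4, 8, 16, 32)
head_lengths = [
    max(1, int(np.log2(s) - np.log2(feature_strides[0])))
    for s in feature_strides
]
print(head_lengths)  # [1, 1, 2, 3]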

View File

@@ -0,0 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn import ContextBlock
from mmseg.registry import MODELS
from .fcn_head import FCNHead
@MODELS.register_module()
class GCHead(FCNHead):
"""GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond.
This head is the implementation of `GCNet
<https://arxiv.org/abs/1904.11492>`_.
Args:
ratio (float): Multiplier of channels ratio. Default: 1/4.
        pooling_type (str): The pooling type of context aggregation.
            Options are 'att', 'avg'. Default: 'att'.
fusion_types (tuple[str]): The fusion type for feature fusion.
Options are 'channel_add', 'channel_mul'. Default: ('channel_add',)
"""
def __init__(self,
ratio=1 / 4.,
pooling_type='att',
fusion_types=('channel_add', ),
**kwargs):
super().__init__(num_convs=2, **kwargs)
self.ratio = ratio
self.pooling_type = pooling_type
self.fusion_types = fusion_types
self.gc_block = ContextBlock(
in_channels=self.channels,
ratio=self.ratio,
pooling_type=self.pooling_type,
fusion_types=self.fusion_types)
def forward(self, inputs):
"""Forward function."""
x = self._transform_inputs(inputs)
output = self.convs[0](x)
output = self.gc_block(output)
output = self.convs[1](output)
if self.concat_input:
output = self.conv_cat(torch.cat([x, output], dim=1))
output = self.cls_seg(output)
return output

View File

@@ -0,0 +1,255 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Originally from https://github.com/visual-attention-network/segnext
# Licensed under the Apache License, Version 2.0 (the "License")
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.device import get_device
from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead
class Matrix_Decomposition_2D_Base(nn.Module):
"""Base class of 2D Matrix Decomposition.
Args:
        MD_S (int): The number of spatial coefficients in Matrix
            Decomposition; it is used to compute the latent dimension D
            (D = C // MD_S). Defaults: 1.
MD_R (int): The number of latent dimension R in
Matrix Decomposition. Defaults: 64.
train_steps (int): The number of iteration steps in
Multiplicative Update (MU) rule to solve Non-negative
Matrix Factorization (NMF) in training. Defaults: 6.
eval_steps (int): The number of iteration steps in
Multiplicative Update (MU) rule to solve Non-negative
Matrix Factorization (NMF) in evaluation. Defaults: 7.
        inv_t (int): Inverse temperature applied to the coefficients
            before the softmax; larger values sharpen the assignment.
            Defaults: 100.
rand_init (bool): Whether to initialize randomly.
Defaults: True.
"""
def __init__(self,
MD_S=1,
MD_R=64,
train_steps=6,
eval_steps=7,
inv_t=100,
rand_init=True):
super().__init__()
self.S = MD_S
self.R = MD_R
self.train_steps = train_steps
self.eval_steps = eval_steps
self.inv_t = inv_t
self.rand_init = rand_init
def _build_bases(self, B, S, D, R, device=None):
raise NotImplementedError
def local_step(self, x, bases, coef):
raise NotImplementedError
def local_inference(self, x, bases):
# (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
coef = torch.bmm(x.transpose(1, 2), bases)
coef = F.softmax(self.inv_t * coef, dim=-1)
steps = self.train_steps if self.training else self.eval_steps
for _ in range(steps):
bases, coef = self.local_step(x, bases, coef)
return bases, coef
def compute_coef(self, x, bases, coef):
raise NotImplementedError
def forward(self, x, return_bases=False):
"""Forward Function."""
B, C, H, W = x.shape
# (B, C, H, W) -> (B * S, D, N)
D = C // self.S
N = H * W
x = x.view(B * self.S, D, N)
if not self.rand_init and not hasattr(self, 'bases'):
bases = self._build_bases(1, self.S, D, self.R, device=x.device)
self.register_buffer('bases', bases)
# (S, D, R) -> (B * S, D, R)
if self.rand_init:
bases = self._build_bases(B, self.S, D, self.R, device=x.device)
else:
bases = self.bases.repeat(B, 1, 1)
bases, coef = self.local_inference(x, bases)
# (B * S, N, R)
coef = self.compute_coef(x, bases, coef)
# (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N)
x = torch.bmm(bases, coef.transpose(1, 2))
# (B * S, D, N) -> (B, C, H, W)
x = x.view(B, C, H, W)
return x
class NMF2D(Matrix_Decomposition_2D_Base):
"""Non-negative Matrix Factorization (NMF) module.
It is inherited from ``Matrix_Decomposition_2D_Base`` module.
"""
def __init__(self, args=dict()):
super().__init__(**args)
self.inv_t = 1
def _build_bases(self, B, S, D, R, device=None):
"""Build bases in initialization."""
if device is None:
device = get_device()
bases = torch.rand((B * S, D, R)).to(device)
bases = F.normalize(bases, dim=1)
return bases
def local_step(self, x, bases, coef):
"""Local step in iteration to renew bases and coefficient."""
# (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
numerator = torch.bmm(x.transpose(1, 2), bases)
# (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R)
denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
# Multiplicative Update
coef = coef * numerator / (denominator + 1e-6)
# (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)
numerator = torch.bmm(x, coef)
# (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R)
denominator = bases.bmm(coef.transpose(1, 2).bmm(coef))
# Multiplicative Update
bases = bases * numerator / (denominator + 1e-6)
return bases, coef
def compute_coef(self, x, bases, coef):
"""Compute coefficient."""
# (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
numerator = torch.bmm(x.transpose(1, 2), bases)
# (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R)
denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
        # Multiplicative Update
coef = coef * numerator / (denominator + 1e-6)
return coef
class Hamburger(nn.Module):
"""Hamburger Module. It consists of one slice of "ham" (matrix
decomposition) and two slices of "bread" (linear transformation).
Args:
ham_channels (int): Input and output channels of feature.
ham_kwargs (dict): Config of matrix decomposition module.
norm_cfg (dict | None): Config of norm layers.
"""
def __init__(self,
ham_channels=512,
ham_kwargs=dict(),
norm_cfg=None,
**kwargs):
super().__init__()
self.ham_in = ConvModule(
ham_channels, ham_channels, 1, norm_cfg=None, act_cfg=None)
self.ham = NMF2D(ham_kwargs)
self.ham_out = ConvModule(
ham_channels, ham_channels, 1, norm_cfg=norm_cfg, act_cfg=None)
def forward(self, x):
enjoy = self.ham_in(x)
enjoy = F.relu(enjoy, inplace=True)
enjoy = self.ham(enjoy)
enjoy = self.ham_out(enjoy)
ham = F.relu(x + enjoy, inplace=True)
return ham
@MODELS.register_module()
class LightHamHead(BaseDecodeHead):
"""SegNeXt decode head.
This decode head is the implementation of `SegNeXt: Rethinking
Convolutional Attention Design for Semantic
Segmentation <https://arxiv.org/abs/2209.08575>`_.
Inspiration from https://github.com/visual-attention-network/segnext.
Specifically, LightHamHead is inspired by HamNet from
`Is Attention Better Than Matrix Decomposition?
    <https://arxiv.org/abs/2109.04553>`_.
    Args:
        ham_channels (int): Input channels for Hamburger.
            Defaults: 512.
        ham_kwargs (dict): Keyword arguments for the Ham module.
            Defaults: dict().
"""
def __init__(self, ham_channels=512, ham_kwargs=dict(), **kwargs):
super().__init__(input_transform='multiple_select', **kwargs)
self.ham_channels = ham_channels
self.squeeze = ConvModule(
sum(self.in_channels),
self.ham_channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.hamburger = Hamburger(ham_channels, ham_kwargs, **kwargs)
self.align = ConvModule(
self.ham_channels,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def forward(self, inputs):
"""Forward function."""
inputs = self._transform_inputs(inputs)
inputs = [
resize(
level,
size=inputs[0].shape[2:],
mode='bilinear',
align_corners=self.align_corners) for level in inputs
]
inputs = torch.cat(inputs, dim=1)
# apply a conv block to squeeze feature map
x = self.squeeze(inputs)
# apply hamburger module
x = self.hamburger(x)
# apply a conv block to align feature map
output = self.align(x)
output = self.cls_seg(output)
return output
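
# Self-contained sketch (toy sizes) of the multiplicative-update rule that
# NMF2D.local_step applies above; the reconstruction error shrinks as the
# number of iterations grows.
import torch
import torch.nn.functional as F

D, N, R = 8, 32, 4
x = torch.rand(1, D, N)  # non-negative input of shape (B * S, D, N)
bases = F.normalize(torch.rand(1, D, R), dim=1)
coef = F.softmax(torch.bmm(x.transpose(1, 2), bases), dim=-1)

for _ in range(6):  # train_steps iterations of the MU rule
    coef = coef * torch.bmm(x.transpose(1, 2), bases) / (
        coef.bmm(bases.transpose(1, 2).bmm(bases)) + 1e-6)
    bases = bases * torch.bmm(x, coef) / (
        bases.bmm(coef.transpose(1, 2).bmm(coef)) + 1e-6)

recon = torch.bmm(bases, coef.transpose(1, 2))  # (1, D, N)
print(F.mse_loss(recon, x).item())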

View File

@@ -0,0 +1,143 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmseg.registry import MODELS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .decode_head import BaseDecodeHead
class SelfAttentionBlock(_SelfAttentionBlock):
"""Self-Attention Module.
Args:
in_channels (int): Input channels of key/query feature.
channels (int): Output channels of key/query transform.
conv_cfg (dict | None): Config of conv layers.
norm_cfg (dict | None): Config of norm layers.
act_cfg (dict | None): Config of activation layers.
"""
def __init__(self, in_channels, channels, conv_cfg, norm_cfg, act_cfg):
super().__init__(
key_in_channels=in_channels,
query_in_channels=in_channels,
channels=channels,
out_channels=in_channels,
share_key_query=False,
query_downsample=None,
key_downsample=None,
key_query_num_convs=2,
key_query_norm=True,
value_out_num_convs=1,
value_out_norm=False,
matmul_norm=True,
with_out=False,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.output_project = self.build_project(
in_channels,
in_channels,
num_convs=1,
use_conv_module=True,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
def forward(self, x):
"""Forward function."""
context = super().forward(x, x)
return self.output_project(context)
@MODELS.register_module()
class ISAHead(BaseDecodeHead):
"""Interlaced Sparse Self-Attention for Semantic Segmentation.
This head is the implementation of `ISA
<https://arxiv.org/abs/1907.12273>`_.
Args:
isa_channels (int): The channels of ISA Module.
down_factor (tuple[int]): The local group size of ISA.
"""
def __init__(self, isa_channels, down_factor=(8, 8), **kwargs):
super().__init__(**kwargs)
self.down_factor = down_factor
self.in_conv = ConvModule(
self.in_channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.global_relation = SelfAttentionBlock(
self.channels,
isa_channels,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.local_relation = SelfAttentionBlock(
self.channels,
isa_channels,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.out_conv = ConvModule(
self.channels * 2,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def forward(self, inputs):
"""Forward function."""
x_ = self._transform_inputs(inputs)
x = self.in_conv(x_)
residual = x
n, c, h, w = x.size()
loc_h, loc_w = self.down_factor # size of local group in H- and W-axes
glb_h, glb_w = math.ceil(h / loc_h), math.ceil(w / loc_w)
pad_h, pad_w = glb_h * loc_h - h, glb_w * loc_w - w
if pad_h > 0 or pad_w > 0: # pad if the size is not divisible
padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
pad_h - pad_h // 2)
x = F.pad(x, padding)
# global relation
x = x.view(n, c, glb_h, loc_h, glb_w, loc_w)
# do permutation to gather global group
x = x.permute(0, 3, 5, 1, 2, 4) # (n, loc_h, loc_w, c, glb_h, glb_w)
x = x.reshape(-1, c, glb_h, glb_w)
# apply attention within each global group
x = self.global_relation(x) # (n * loc_h * loc_w, c, glb_h, glb_w)
# local relation
x = x.view(n, loc_h, loc_w, c, glb_h, glb_w)
# do permutation to gather local group
x = x.permute(0, 4, 5, 3, 1, 2) # (n, glb_h, glb_w, c, loc_h, loc_w)
x = x.reshape(-1, c, loc_h, loc_w)
# apply attention within each local group
x = self.local_relation(x) # (n * glb_h * glb_w, c, loc_h, loc_w)
# permute each pixel back to its original position
x = x.view(n, glb_h, glb_w, c, loc_h, loc_w)
x = x.permute(0, 3, 1, 4, 2, 5) # (n, c, glb_h, loc_h, glb_w, loc_w)
x = x.reshape(n, c, glb_h * loc_h, glb_w * loc_w)
if pad_h > 0 or pad_w > 0: # remove padding
x = x[:, :, pad_h // 2:pad_h // 2 + h, pad_w // 2:pad_w // 2 + w]
x = self.out_conv(torch.cat([x, residual], dim=1))
out = self.cls_seg(x)
return out
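
# Self-contained sanity check (arbitrary sizes) of the interlaced grouping in
# ISAHead.forward: the global/local permutations plus the final inverse
# permutation put every pixel back in its original position.
import torch

n, c, loc_h, loc_w, glb_h, glb_w = 2, 3, 4, 4, 5, 5
x0 = torch.randn(n, c, glb_h * loc_h, glb_w * loc_w)

x = x0.view(n, c, glb_h, loc_h, glb_w, loc_w)
x = x.permute(0, 3, 5, 1, 2, 4).reshape(-1, c, glb_h, glb_w)  # global groups
x = x.view(n, loc_h, loc_w, c, glb_h, glb_w)
x = x.permute(0, 4, 5, 3, 1, 2).reshape(-1, c, loc_h, loc_w)  # local groups
x = x.view(n, glb_h, glb_w, c, loc_h, loc_w)
x = x.permute(0, 3, 1, 4, 2, 5).reshape(n, c, glb_h * loc_h, glb_w * loc_w)
print(torch.equal(x, x0))  # True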

View File

@@ -0,0 +1,461 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from mmcv.cnn.bricks.transformer import (FFN, MultiheadAttention,
build_transformer_layer)
from mmengine.logging import print_log
from torch import Tensor
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.registry import MODELS
from mmseg.utils import SampleList
@MODELS.register_module()
class KernelUpdator(nn.Module):
"""Dynamic Kernel Updator in Kernel Update Head.
Args:
in_channels (int): The number of channels of input feature map.
Default: 256.
feat_channels (int): The number of middle-stage channels in
the kernel updator. Default: 64.
out_channels (int): The number of output channels.
        gate_sigmoid (bool): Whether to use a sigmoid function in the gate
            mechanism. Default: True.
        gate_norm_act (bool): Whether to add normalization and activation
            layers in the gate mechanism. Default: False.
        activate_out (bool): Whether to add activation after the gate
            mechanism. Default: False.
norm_cfg (dict | None): Config of norm layers.
Default: dict(type='LN').
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
"""
def __init__(
self,
in_channels=256,
feat_channels=64,
out_channels=None,
gate_sigmoid=True,
gate_norm_act=False,
activate_out=False,
norm_cfg=dict(type='LN'),
act_cfg=dict(type='ReLU', inplace=True),
):
super().__init__()
self.in_channels = in_channels
self.feat_channels = feat_channels
self.out_channels_raw = out_channels
self.gate_sigmoid = gate_sigmoid
self.gate_norm_act = gate_norm_act
self.activate_out = activate_out
self.act_cfg = act_cfg
self.norm_cfg = norm_cfg
self.out_channels = out_channels if out_channels else in_channels
self.num_params_in = self.feat_channels
self.num_params_out = self.feat_channels
self.dynamic_layer = nn.Linear(
self.in_channels, self.num_params_in + self.num_params_out)
self.input_layer = nn.Linear(self.in_channels,
self.num_params_in + self.num_params_out,
1)
self.input_gate = nn.Linear(self.in_channels, self.feat_channels, 1)
self.update_gate = nn.Linear(self.in_channels, self.feat_channels, 1)
if self.gate_norm_act:
self.gate_norm = build_norm_layer(norm_cfg, self.feat_channels)[1]
self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
self.norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]
self.input_norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
self.input_norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]
self.activation = build_activation_layer(act_cfg)
self.fc_layer = nn.Linear(self.feat_channels, self.out_channels, 1)
self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]
def forward(self, update_feature, input_feature):
"""Forward function of KernelUpdator.
Args:
update_feature (torch.Tensor): Feature map assembled from
each group. It would be reshaped with last dimension
shape: `self.in_channels`.
input_feature (torch.Tensor): Intermediate feature
with shape: (N, num_classes, conv_kernel_size**2, channels).
Returns:
Tensor: The output tensor of shape (N*C1/C2, K*K, C2), where N is
the number of classes, C1 and C2 are the feature map channels of
KernelUpdateHead and KernelUpdator, respectively.
"""
update_feature = update_feature.reshape(-1, self.in_channels)
num_proposals = update_feature.size(0)
# dynamic_layer works for
# phi_1 and psi_3 in Eq.(4) and (5) of K-Net paper
parameters = self.dynamic_layer(update_feature)
param_in = parameters[:, :self.num_params_in].view(
-1, self.feat_channels)
param_out = parameters[:, -self.num_params_out:].view(
-1, self.feat_channels)
# input_layer works for
# phi_2 and psi_4 in Eq.(4) and (5) of K-Net paper
input_feats = self.input_layer(
input_feature.reshape(num_proposals, -1, self.feat_channels))
input_in = input_feats[..., :self.num_params_in]
input_out = input_feats[..., -self.num_params_out:]
# `gate_feats` is F^G in K-Net paper
gate_feats = input_in * param_in.unsqueeze(-2)
if self.gate_norm_act:
gate_feats = self.activation(self.gate_norm(gate_feats))
input_gate = self.input_norm_in(self.input_gate(gate_feats))
update_gate = self.norm_in(self.update_gate(gate_feats))
if self.gate_sigmoid:
input_gate = input_gate.sigmoid()
update_gate = update_gate.sigmoid()
param_out = self.norm_out(param_out)
input_out = self.input_norm_out(input_out)
if self.activate_out:
param_out = self.activation(param_out)
input_out = self.activation(input_out)
# Gate mechanism. Eq.(5) in original paper.
# param_out has shape (batch_size, feat_channels, out_channels)
features = update_gate * param_out.unsqueeze(
-2) + input_gate * input_out
features = self.fc_layer(features)
features = self.fc_norm(features)
features = self.activation(features)
return features
@MODELS.register_module()
class KernelUpdateHead(nn.Module):
"""Kernel Update Head in K-Net.
Args:
num_classes (int): Number of classes. Default: 150.
num_ffn_fcs (int): The number of fully-connected layers in
FFNs. Default: 2.
num_heads (int): The number of parallel attention heads.
Default: 8.
num_mask_fcs (int): The number of fully connected layers for
mask prediction. Default: 3.
feedforward_channels (int): The hidden dimension of FFNs.
Defaults: 2048.
in_channels (int): The number of channels of input feature map.
Default: 256.
out_channels (int): The number of output channels.
Default: 256.
dropout (float): The Probability of an element to be
zeroed in MultiheadAttention and FFN. Default 0.0.
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
ffn_act_cfg (dict): Config of activation layers in FFN.
Default: dict(type='ReLU').
        conv_kernel_size (int): The kernel size of convolution in
            Kernel Update Head for dynamic kernel updating.
            Default: 1.
        feat_transform_cfg (dict | None): Config of feature transform.
            Default: None.
        kernel_init (bool): Whether to initialize the mask kernel in the
            mask head. Default: False.
with_ffn (bool): Whether add FFN in kernel update head.
Default: True.
feat_gather_stride (int): Stride of convolution in feature transform.
Default: 1.
mask_transform_stride (int): Stride of mask transform.
Default: 1.
kernel_updator_cfg (dict): Config of kernel updator.
Default: dict(
type='DynamicConv',
in_channels=256,
feat_channels=64,
out_channels=256,
act_cfg=dict(type='ReLU', inplace=True),
norm_cfg=dict(type='LN')).
"""
def __init__(self,
num_classes=150,
num_ffn_fcs=2,
num_heads=8,
num_mask_fcs=3,
feedforward_channels=2048,
in_channels=256,
out_channels=256,
dropout=0.0,
act_cfg=dict(type='ReLU', inplace=True),
ffn_act_cfg=dict(type='ReLU', inplace=True),
conv_kernel_size=1,
feat_transform_cfg=None,
kernel_init=False,
with_ffn=True,
feat_gather_stride=1,
mask_transform_stride=1,
kernel_updator_cfg=dict(
type='DynamicConv',
in_channels=256,
feat_channels=64,
out_channels=256,
act_cfg=dict(type='ReLU', inplace=True),
norm_cfg=dict(type='LN'))):
super().__init__()
self.num_classes = num_classes
self.in_channels = in_channels
self.out_channels = out_channels
self.fp16_enabled = False
self.dropout = dropout
self.num_heads = num_heads
self.kernel_init = kernel_init
self.with_ffn = with_ffn
self.conv_kernel_size = conv_kernel_size
self.feat_gather_stride = feat_gather_stride
self.mask_transform_stride = mask_transform_stride
self.attention = MultiheadAttention(in_channels * conv_kernel_size**2,
num_heads, dropout)
self.attention_norm = build_norm_layer(
dict(type='LN'), in_channels * conv_kernel_size**2)[1]
self.kernel_update_conv = build_transformer_layer(kernel_updator_cfg)
if feat_transform_cfg is not None:
kernel_size = feat_transform_cfg.pop('kernel_size', 1)
transform_channels = in_channels
self.feat_transform = ConvModule(
transform_channels,
in_channels,
kernel_size,
stride=feat_gather_stride,
padding=int(feat_gather_stride // 2),
**feat_transform_cfg)
else:
self.feat_transform = None
if self.with_ffn:
self.ffn = FFN(
in_channels,
feedforward_channels,
num_ffn_fcs,
act_cfg=ffn_act_cfg,
dropout=dropout)
self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1]
self.mask_fcs = nn.ModuleList()
for _ in range(num_mask_fcs):
self.mask_fcs.append(
nn.Linear(in_channels, in_channels, bias=False))
self.mask_fcs.append(
build_norm_layer(dict(type='LN'), in_channels)[1])
self.mask_fcs.append(build_activation_layer(act_cfg))
self.fc_mask = nn.Linear(in_channels, out_channels)
def init_weights(self):
"""Use xavier initialization for all weight parameter and set
classification head bias as a specific value when use focal loss."""
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
else:
# adopt the default initialization for
# the weight and bias of the layer norm
pass
if self.kernel_init:
print_log(
'mask kernel in mask head is normal initialized by std 0.01')
nn.init.normal_(self.fc_mask.weight, mean=0, std=0.01)
def forward(self, x, proposal_feat, mask_preds, mask_shape=None):
"""Forward function of Dynamic Instance Interactive Head.
Args:
x (Tensor): Feature map from FPN with shape
(batch_size, feature_dimensions, H , W).
proposal_feat (Tensor): Intermediate feature get from
diihead in last stage, has shape
(batch_size, num_proposals, feature_dimensions)
mask_preds (Tensor): mask prediction from the former stage in shape
(batch_size, num_proposals, H, W).
Returns:
Tuple: The first tensor is predicted mask with shape
(N, num_classes, H, W), the second tensor is dynamic kernel
with shape (N, num_classes, channels, K, K).
"""
N, num_proposals = proposal_feat.shape[:2]
if self.feat_transform is not None:
x = self.feat_transform(x)
C, H, W = x.shape[-3:]
mask_h, mask_w = mask_preds.shape[-2:]
if mask_h != H or mask_w != W:
gather_mask = F.interpolate(
mask_preds, (H, W), align_corners=False, mode='bilinear')
else:
gather_mask = mask_preds
sigmoid_masks = gather_mask.softmax(dim=1)
# Group Feature Assembling. Eq.(3) in original paper.
# einsum is faster than bmm by 30%
x_feat = torch.einsum('bnhw,bchw->bnc', sigmoid_masks, x)
# obj_feat in shape [B, N, C, K, K] -> [B, N, C, K*K] -> [B, N, K*K, C]
proposal_feat = proposal_feat.reshape(N, num_proposals,
self.in_channels,
-1).permute(0, 1, 3, 2)
obj_feat = self.kernel_update_conv(x_feat, proposal_feat)
# [B, N, K*K, C] -> [B, N, K*K*C] -> [N, B, K*K*C]
obj_feat = obj_feat.reshape(N, num_proposals, -1).permute(1, 0, 2)
obj_feat = self.attention_norm(self.attention(obj_feat))
# [N, B, K*K*C] -> [B, N, K*K*C]
obj_feat = obj_feat.permute(1, 0, 2)
# obj_feat in shape [B, N, K*K*C] -> [B, N, K*K, C]
obj_feat = obj_feat.reshape(N, num_proposals, -1, self.in_channels)
# FFN
if self.with_ffn:
obj_feat = self.ffn_norm(self.ffn(obj_feat))
mask_feat = obj_feat
for reg_layer in self.mask_fcs:
mask_feat = reg_layer(mask_feat)
# [B, N, K*K, C] -> [B, N, C, K*K]
mask_feat = self.fc_mask(mask_feat).permute(0, 1, 3, 2)
if (self.mask_transform_stride == 2 and self.feat_gather_stride == 1):
mask_x = F.interpolate(
x, scale_factor=0.5, mode='bilinear', align_corners=False)
H, W = mask_x.shape[-2:]
else:
mask_x = x
# group conv is 5x faster than unfold and uses about 1/5 memory
# Group conv vs. unfold vs. concat batch, 2.9ms :13.5ms :3.8ms
# Group conv vs. unfold vs. concat batch, 278 : 1420 : 369
# but in real training group conv is slower than concat batch
# so we keep using concat batch.
# fold_x = F.unfold(
# mask_x,
# self.conv_kernel_size,
# padding=int(self.conv_kernel_size // 2))
# mask_feat = mask_feat.reshape(N, num_proposals, -1)
# new_mask_preds = torch.einsum('bnc,bcl->bnl', mask_feat, fold_x)
# [B, N, C, K*K] -> [B*N, C, K, K]
mask_feat = mask_feat.reshape(N, num_proposals, C,
self.conv_kernel_size,
self.conv_kernel_size)
# [B, C, H, W] -> [1, B*C, H, W]
new_mask_preds = []
for i in range(N):
new_mask_preds.append(
F.conv2d(
mask_x[i:i + 1],
mask_feat[i],
padding=int(self.conv_kernel_size // 2)))
new_mask_preds = torch.cat(new_mask_preds, dim=0)
new_mask_preds = new_mask_preds.reshape(N, num_proposals, H, W)
if self.mask_transform_stride == 2:
new_mask_preds = F.interpolate(
new_mask_preds,
scale_factor=2,
mode='bilinear',
align_corners=False)
if mask_shape is not None and mask_shape[0] != H:
new_mask_preds = F.interpolate(
new_mask_preds,
mask_shape,
align_corners=False,
mode='bilinear')
return new_mask_preds, obj_feat.permute(0, 1, 3, 2).reshape(
N, num_proposals, self.in_channels, self.conv_kernel_size,
self.conv_kernel_size)
@MODELS.register_module()
class IterativeDecodeHead(BaseDecodeHead):
"""K-Net: Towards Unified Image Segmentation.
This head is the implementation of
    `K-Net <https://arxiv.org/abs/2106.14855>`_.
Args:
num_stages (int): The number of stages (kernel update heads)
in IterativeDecodeHead. Default: 3.
        kernel_generate_head (dict): Config of kernel generate head, which
generate mask predictions, dynamic kernels and class predictions
for next kernel update heads.
kernel_update_head (dict): Config of kernel update head which refine
dynamic kernels and class predictions iteratively.
"""
def __init__(self, num_stages, kernel_generate_head, kernel_update_head,
**kwargs):
# ``IterativeDecodeHead`` would skip initialization of
# ``BaseDecodeHead`` which would be called when building
# ``self.kernel_generate_head``.
super(BaseDecodeHead, self).__init__(**kwargs)
assert num_stages == len(kernel_update_head)
self.num_stages = num_stages
self.kernel_generate_head = MODELS.build(kernel_generate_head)
self.kernel_update_head = nn.ModuleList()
self.align_corners = self.kernel_generate_head.align_corners
self.num_classes = self.kernel_generate_head.num_classes
self.input_transform = self.kernel_generate_head.input_transform
self.ignore_index = self.kernel_generate_head.ignore_index
self.out_channels = self.num_classes
for head_cfg in kernel_update_head:
self.kernel_update_head.append(MODELS.build(head_cfg))
def forward(self, inputs):
"""Forward function."""
feats = self.kernel_generate_head._forward_feature(inputs)
sem_seg = self.kernel_generate_head.cls_seg(feats)
seg_kernels = self.kernel_generate_head.conv_seg.weight.clone()
seg_kernels = seg_kernels[None].expand(
feats.size(0), *seg_kernels.size())
stage_segs = [sem_seg]
for i in range(self.num_stages):
sem_seg, seg_kernels = self.kernel_update_head[i](feats,
seg_kernels,
sem_seg)
stage_segs.append(sem_seg)
if self.training:
return stage_segs
# only return the prediction of the last stage during testing
return stage_segs[-1]
def loss_by_feat(self, seg_logits: List[Tensor],
batch_data_samples: SampleList, **kwargs) -> dict:
losses = dict()
for i, logit in enumerate(seg_logits):
loss = self.kernel_generate_head.loss_by_feat(
logit, batch_data_samples)
for k, v in loss.items():
losses[f'{k}.s{i}'] = v
return losses
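
# Minimal sketch (toy sizes) of the per-image dynamic convolution performed
# at the end of KernelUpdateHead.forward: each image in the batch is
# convolved with its own bank of predicted kernels.
import torch
import torch.nn.functional as F

N, num_proposals, C, H, W, K = 2, 5, 8, 16, 16, 1
feats = torch.randn(N, C, H, W)
kernels = torch.randn(N, num_proposals, C, K, K)

mask_preds = torch.cat([
    F.conv2d(feats[i:i + 1], kernels[i], padding=K // 2) for i in range(N)
], dim=0)
print(mask_preds.shape)  # torch.Size([2, 5, 16, 16])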

View File

@@ -0,0 +1,91 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.utils import is_tuple_of
from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead
@MODELS.register_module()
class LRASPPHead(BaseDecodeHead):
"""Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3.
This head is the improved implementation of `Searching for MobileNetV3
<https://ieeexplore.ieee.org/document/9008835>`_.
Args:
        branch_channels (tuple[int]): The number of output channels in
            each branch. Default: (32, 64).
"""
def __init__(self, branch_channels=(32, 64), **kwargs):
super().__init__(**kwargs)
if self.input_transform != 'multiple_select':
raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform '
f'must be \'multiple_select\'. But received '
f'\'{self.input_transform}\'')
assert is_tuple_of(branch_channels, int)
assert len(branch_channels) == len(self.in_channels) - 1
self.branch_channels = branch_channels
self.convs = nn.Sequential()
self.conv_ups = nn.Sequential()
for i in range(len(branch_channels)):
self.convs.add_module(
f'conv{i}',
nn.Conv2d(
self.in_channels[i], branch_channels[i], 1, bias=False))
self.conv_ups.add_module(
f'conv_up{i}',
ConvModule(
self.channels + branch_channels[i],
self.channels,
1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
bias=False))
self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1)
self.aspp_conv = ConvModule(
self.in_channels[-1],
self.channels,
1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
bias=False)
self.image_pool = nn.Sequential(
nn.AvgPool2d(kernel_size=49, stride=(16, 20)),
ConvModule(
self.in_channels[2],
self.channels,
1,
act_cfg=dict(type='Sigmoid'),
bias=False))
def forward(self, inputs):
"""Forward function."""
inputs = self._transform_inputs(inputs)
x = inputs[-1]
x = self.aspp_conv(x) * resize(
self.image_pool(x),
size=x.size()[2:],
mode='bilinear',
align_corners=self.align_corners)
x = self.conv_up_input(x)
for i in range(len(self.branch_channels) - 1, -1, -1):
x = resize(
x,
size=inputs[i].size()[2:],
mode='bilinear',
align_corners=self.align_corners)
x = torch.cat([x, self.convs[i](inputs[i])], 1)
x = self.conv_ups[i](x)
return self.cls_seg(x)
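
# Shape-level sketch of the LR-ASPP gating above (self-contained; the 1x1
# sigmoid conv is collapsed into a channel mean purely for brevity): a
# heavily strided average pool yields a coarse map that is resized back and
# multiplies the projected features.
import torch
import torch.nn.functional as F

x = torch.randn(1, 960, 30, 40)  # made-up backbone feature size
gate = torch.sigmoid(
    F.avg_pool2d(x, kernel_size=25, stride=8).mean(1, keepdim=True))
gate = F.interpolate(
    gate, size=x.shape[2:], mode='bilinear', align_corners=False)
print((x * gate).shape)  # torch.Size([1, 960, 30, 40])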

View File

@@ -0,0 +1,163 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
try:
from mmdet.models.dense_heads import \
Mask2FormerHead as MMDET_Mask2FormerHead
except ModuleNotFoundError:
MMDET_Mask2FormerHead = BaseModule
from mmengine.structures import InstanceData
from torch import Tensor
from mmseg.registry import MODELS
from mmseg.structures.seg_data_sample import SegDataSample
from mmseg.utils import ConfigType, SampleList
@MODELS.register_module()
class Mask2FormerHead(MMDET_Mask2FormerHead):
"""Implements the Mask2Former head.
See `Mask2Former: Masked-attention Mask Transformer for Universal Image
Segmentation <https://arxiv.org/abs/2112.01527>`_ for details.
Args:
num_classes (int): Number of classes. Default: 150.
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
ignore_index (int): The label index to be ignored. Default: 255.
"""
def __init__(self,
num_classes,
align_corners=False,
ignore_index=255,
**kwargs):
super().__init__(**kwargs)
self.num_classes = num_classes
self.align_corners = align_corners
self.out_channels = num_classes
self.ignore_index = ignore_index
feat_channels = kwargs['feat_channels']
self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
def _seg_data_to_instance_data(self, batch_data_samples: SampleList):
"""Perform forward propagation to convert paradigm from MMSegmentation
to MMDetection to ensure ``MMDET_Mask2FormerHead`` could be called
normally. Specifically, ``batch_gt_instances`` would be added.
Args:
batch_data_samples (List[:obj:`SegDataSample`]): The Data
Samples. It usually includes information such as
`gt_sem_seg`.
Returns:
tuple[Tensor]: A tuple contains two lists.
- batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``labels``, each is
unique ground truth label id of images, with
shape (num_gt, ) and ``masks``, each is ground truth
                masks of each instance of an image, shape (num_gt, h, w).
- batch_img_metas (list[dict]): List of image meta information.
"""
batch_img_metas = []
batch_gt_instances = []
for data_sample in batch_data_samples:
batch_img_metas.append(data_sample.metainfo)
gt_sem_seg = data_sample.gt_sem_seg.data
classes = torch.unique(
gt_sem_seg,
sorted=False,
return_inverse=False,
return_counts=False)
# remove ignored region
gt_labels = classes[classes != self.ignore_index]
masks = []
for class_id in gt_labels:
masks.append(gt_sem_seg == class_id)
if len(masks) == 0:
gt_masks = torch.zeros(
(0, gt_sem_seg.shape[-2],
gt_sem_seg.shape[-1])).to(gt_sem_seg).long()
else:
gt_masks = torch.stack(masks).squeeze(1).long()
instance_data = InstanceData(labels=gt_labels, masks=gt_masks)
batch_gt_instances.append(instance_data)
return batch_gt_instances, batch_img_metas
def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList,
train_cfg: ConfigType) -> dict:
"""Perform forward propagation and loss calculation of the decoder head
on the features of the upstream network.
Args:
x (tuple[Tensor]): Multi-level features from the upstream
network, each is a 4D-tensor.
batch_data_samples (List[:obj:`SegDataSample`]): The Data
Samples. It usually includes information such as
`gt_sem_seg`.
train_cfg (ConfigType): Training config.
Returns:
dict[str, Tensor]: a dictionary of loss components.
"""
# batch SegDataSample to InstanceDataSample
batch_gt_instances, batch_img_metas = self._seg_data_to_instance_data(
batch_data_samples)
# forward
all_cls_scores, all_mask_preds = self(x, batch_data_samples)
# loss
losses = self.loss_by_feat(all_cls_scores, all_mask_preds,
batch_gt_instances, batch_img_metas)
return losses
def predict(self, x: Tuple[Tensor], batch_img_metas: List[dict],
test_cfg: ConfigType) -> Tuple[Tensor]:
"""Test without augmentaton.
Args:
x (tuple[Tensor]): Multi-level features from the
upstream network, each is a 4D-tensor.
batch_img_metas (List[:obj:`SegDataSample`]): The Data
Samples. It usually includes information such as
`gt_sem_seg`.
test_cfg (ConfigType): Test config.
Returns:
Tensor: A tensor of segmentation mask.
"""
batch_data_samples = [
SegDataSample(metainfo=metainfo) for metainfo in batch_img_metas
]
all_cls_scores, all_mask_preds = self(x, batch_data_samples)
mask_cls_results = all_cls_scores[-1]
mask_pred_results = all_mask_preds[-1]
if 'pad_shape' in batch_img_metas[0]:
size = batch_img_metas[0]['pad_shape']
else:
size = batch_img_metas[0]['img_shape']
# upsample mask
mask_pred_results = F.interpolate(
mask_pred_results, size=size, mode='bilinear', align_corners=False)
cls_score = F.softmax(mask_cls_results, dim=-1)[..., :-1]
mask_pred = mask_pred_results.sigmoid()
seg_logits = torch.einsum('bqc, bqhw->bchw', cls_score, mask_pred)
return seg_logits
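
# Self-contained shape check of the closing einsum above: per-query class
# scores (the 'no object' column is dropped by the slice) combine with
# per-query mask probabilities into per-class segmentation logits.
import torch
import torch.nn.functional as F

B, Q, num_cls, H, W = 1, 100, 150, 32, 32
mask_cls = torch.randn(B, Q, num_cls + 1)
mask_pred = torch.randn(B, Q, H, W)

cls_score = F.softmax(mask_cls, dim=-1)[..., :-1]
seg_logits = torch.einsum('bqc,bqhw->bchw', cls_score, mask_pred.sigmoid())
print(seg_logits.shape)  # torch.Size([1, 150, 32, 32])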

View File

@@ -0,0 +1,174 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
try:
from mmdet.models.dense_heads import MaskFormerHead as MMDET_MaskFormerHead
except ModuleNotFoundError:
MMDET_MaskFormerHead = BaseModule
from mmengine.structures import InstanceData
from torch import Tensor
from mmseg.registry import MODELS
from mmseg.structures.seg_data_sample import SegDataSample
from mmseg.utils import ConfigType, SampleList
@MODELS.register_module()
class MaskFormerHead(MMDET_MaskFormerHead):
"""Implements the MaskFormer head.
See `Per-Pixel Classification is Not All You Need for Semantic Segmentation
<https://arxiv.org/pdf/2107.06278>`_ for details.
Args:
num_classes (int): Number of classes. Default: 150.
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
ignore_index (int): The label index to be ignored. Default: 255.
"""
def __init__(self,
num_classes: int = 150,
align_corners: bool = False,
ignore_index: int = 255,
**kwargs) -> None:
super().__init__(**kwargs)
        self.num_classes = num_classes
        self.align_corners = align_corners
        self.out_channels = num_classes
self.ignore_index = ignore_index
feat_channels = kwargs['feat_channels']
self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
def _seg_data_to_instance_data(self, batch_data_samples: SampleList):
"""Perform forward propagation to convert paradigm from MMSegmentation
to MMDetection to ensure ``MMDET_MaskFormerHead`` could be called
normally. Specifically, ``batch_gt_instances`` would be added.
Args:
batch_data_samples (List[:obj:`SegDataSample`]): The Data
Samples. It usually includes information such as
`gt_sem_seg`.
Returns:
tuple[Tensor]: A tuple contains two lists.
- batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``labels``, each is
unique ground truth label id of images, with
shape (num_gt, ) and ``masks``, each is ground truth
                masks of each instance of an image, shape (num_gt, h, w).
- batch_img_metas (list[dict]): List of image meta information.
"""
batch_img_metas = []
batch_gt_instances = []
for data_sample in batch_data_samples:
# Add `batch_input_shape` in metainfo of data_sample, which would
# be used in MaskFormerHead of MMDetection.
metainfo = data_sample.metainfo
metainfo['batch_input_shape'] = metainfo['img_shape']
data_sample.set_metainfo(metainfo)
batch_img_metas.append(data_sample.metainfo)
gt_sem_seg = data_sample.gt_sem_seg.data
classes = torch.unique(
gt_sem_seg,
sorted=False,
return_inverse=False,
return_counts=False)
# remove ignored region
gt_labels = classes[classes != self.ignore_index]
masks = []
for class_id in gt_labels:
masks.append(gt_sem_seg == class_id)
if len(masks) == 0:
gt_masks = torch.zeros((0, gt_sem_seg.shape[-2],
gt_sem_seg.shape[-1])).to(gt_sem_seg)
else:
gt_masks = torch.stack(masks).squeeze(1)
instance_data = InstanceData(
labels=gt_labels, masks=gt_masks.long())
batch_gt_instances.append(instance_data)
return batch_gt_instances, batch_img_metas
def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList,
train_cfg: ConfigType) -> dict:
"""Perform forward propagation and loss calculation of the decoder head
on the features of the upstream network.
Args:
x (tuple[Tensor]): Multi-level features from the upstream
network, each is a 4D-tensor.
batch_data_samples (List[:obj:`SegDataSample`]): The Data
Samples. It usually includes information such as
`gt_sem_seg`.
train_cfg (ConfigType): Training config.
Returns:
dict[str, Tensor]: a dictionary of loss components.
"""
# batch SegDataSample to InstanceDataSample
batch_gt_instances, batch_img_metas = self._seg_data_to_instance_data(
batch_data_samples)
# forward
all_cls_scores, all_mask_preds = self(x, batch_data_samples)
# loss
losses = self.loss_by_feat(all_cls_scores, all_mask_preds,
batch_gt_instances, batch_img_metas)
return losses
def predict(self, x: Tuple[Tensor], batch_img_metas: List[dict],
test_cfg: ConfigType) -> Tuple[Tensor]:
"""Test without augmentaton.
Args:
x (tuple[Tensor]): Multi-level features from the
upstream network, each is a 4D-tensor.
batch_img_metas (List[:obj:`SegDataSample`]): The Data
Samples. It usually includes information such as
`gt_sem_seg`.
test_cfg (ConfigType): Test config.
Returns:
Tensor: A tensor of segmentation mask.
"""
batch_data_samples = []
for metainfo in batch_img_metas:
metainfo['batch_input_shape'] = metainfo['img_shape']
batch_data_samples.append(SegDataSample(metainfo=metainfo))
        # The forward function of MaskFormerHead from MMDetection needs
        # 'batch_data_samples' as input, which here only carries image shapes.
all_cls_scores, all_mask_preds = self(x, batch_data_samples)
mask_cls_results = all_cls_scores[-1]
mask_pred_results = all_mask_preds[-1]
# upsample masks
img_shape = batch_img_metas[0]['batch_input_shape']
mask_pred_results = F.interpolate(
mask_pred_results,
size=img_shape,
mode='bilinear',
align_corners=False)
# semantic inference
cls_score = F.softmax(mask_cls_results, dim=-1)[..., :-1]
mask_pred = mask_pred_results.sigmoid()
seg_logits = torch.einsum('bqc,bqhw->bchw', cls_score, mask_pred)
return seg_logits
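
# Tiny self-contained example of the label-to-instance conversion in
# _seg_data_to_instance_data above: every class present in the semantic map,
# except the ignore index, becomes one binary mask.
import torch

ignore_index = 255
gt_sem_seg = torch.tensor([[0, 1], [1, ignore_index]])
classes = torch.unique(gt_sem_seg)
gt_labels = classes[classes != ignore_index]
gt_masks = torch.stack([gt_sem_seg == c for c in gt_labels]).long()
print(gt_labels.tolist(), gt_masks.shape)  # [0, 1] torch.Size([2, 2, 2])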

View File

@@ -0,0 +1,50 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn import NonLocal2d
from mmseg.registry import MODELS
from .fcn_head import FCNHead
@MODELS.register_module()
class NLHead(FCNHead):
"""Non-local Neural Networks.
This head is the implementation of `NLNet
<https://arxiv.org/abs/1711.07971>`_.
Args:
reduction (int): Reduction factor of projection transform. Default: 2.
use_scale (bool): Whether to scale pairwise_weight by
sqrt(1/inter_channels). Default: True.
        mode (str): The nonlocal mode. Options are 'embedded_gaussian',
            'dot_product'. Default: 'embedded_gaussian'.
"""
def __init__(self,
reduction=2,
use_scale=True,
mode='embedded_gaussian',
**kwargs):
super().__init__(num_convs=2, **kwargs)
self.reduction = reduction
self.use_scale = use_scale
self.mode = mode
self.nl_block = NonLocal2d(
in_channels=self.channels,
reduction=self.reduction,
use_scale=self.use_scale,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
mode=self.mode)
def forward(self, inputs):
"""Forward function."""
x = self._transform_inputs(inputs)
output = self.convs[0](x)
output = self.nl_block(output)
output = self.convs[1](output)
if self.concat_input:
output = self.conv_cat(torch.cat([x, output], dim=1))
output = self.cls_seg(output)
return output

View File

@@ -0,0 +1,127 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmseg.registry import MODELS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from ..utils import resize
from .cascade_decode_head import BaseCascadeDecodeHead
class SpatialGatherModule(nn.Module):
"""Aggregate the context features according to the initial predicted
probability distribution.
Employ the soft-weighted method to aggregate the context.
"""
def __init__(self, scale):
super().__init__()
self.scale = scale
def forward(self, feats, probs):
"""Forward function."""
batch_size, num_classes, height, width = probs.size()
channels = feats.size(1)
probs = probs.view(batch_size, num_classes, -1)
feats = feats.view(batch_size, channels, -1)
        # [batch_size, height*width, channels]
        feats = feats.permute(0, 2, 1)
        # [batch_size, num_classes, height*width]
        probs = F.softmax(self.scale * probs, dim=2)
        # [batch_size, num_classes, channels]
        ocr_context = torch.matmul(probs, feats)
        # [batch_size, channels, num_classes, 1]
        ocr_context = ocr_context.permute(0, 2, 1).contiguous().unsqueeze(3)
return ocr_context
class ObjectAttentionBlock(_SelfAttentionBlock):
"""Make a OCR used SelfAttentionBlock."""
def __init__(self, in_channels, channels, scale, conv_cfg, norm_cfg,
act_cfg):
if scale > 1:
query_downsample = nn.MaxPool2d(kernel_size=scale)
else:
query_downsample = None
super().__init__(
key_in_channels=in_channels,
query_in_channels=in_channels,
channels=channels,
out_channels=in_channels,
share_key_query=False,
query_downsample=query_downsample,
key_downsample=None,
key_query_num_convs=2,
key_query_norm=True,
value_out_num_convs=1,
value_out_norm=True,
matmul_norm=True,
with_out=True,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.bottleneck = ConvModule(
in_channels * 2,
in_channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def forward(self, query_feats, key_feats):
"""Forward function."""
context = super().forward(query_feats, key_feats)
output = self.bottleneck(torch.cat([context, query_feats], dim=1))
        if self.query_downsample is not None:
            # upsample the attended output back to the query resolution
            output = resize(output, size=query_feats.size()[2:])
return output
@MODELS.register_module()
class OCRHead(BaseCascadeDecodeHead):
"""Object-Contextual Representations for Semantic Segmentation.
This head is the implementation of `OCRNet
<https://arxiv.org/abs/1909.11065>`_.
Args:
ocr_channels (int): The intermediate channels of OCR block.
        scale (int): The scale of the probability map in
            SpatialGatherModule. Default: 1.
"""
def __init__(self, ocr_channels, scale=1, **kwargs):
super().__init__(**kwargs)
self.ocr_channels = ocr_channels
self.scale = scale
self.object_context_block = ObjectAttentionBlock(
self.channels,
self.ocr_channels,
self.scale,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.spatial_gather_module = SpatialGatherModule(self.scale)
self.bottleneck = ConvModule(
self.in_channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def forward(self, inputs, prev_output):
"""Forward function."""
x = self._transform_inputs(inputs)
feats = self.bottleneck(x)
context = self.spatial_gather_module(feats, prev_output)
object_context = self.object_context_block(feats, context)
output = self.cls_seg(object_context)
return output
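
# Self-contained shape walk-through (toy sizes) of SpatialGatherModule above:
# soft class probabilities pool pixel features into one context vector per
# class.
import torch
import torch.nn.functional as F

B, C, K, H, W = 1, 16, 19, 8, 8
feats = torch.randn(B, C, H, W).view(B, C, -1)      # (B, C, H*W)
probs = F.softmax(torch.randn(B, K, H * W), dim=2)  # (B, K, H*W)

ocr_context = torch.matmul(probs, feats.permute(0, 2, 1))  # (B, K, C)
ocr_context = ocr_context.permute(0, 2, 1).unsqueeze(3)    # (B, C, K, 1)
print(ocr_context.shape)  # torch.Size([1, 16, 19, 1])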

View File

@@ -0,0 +1,183 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from mmengine.model import BaseModule
from torch import Tensor
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.losses import accuracy
from mmseg.models.utils import resize
from mmseg.registry import MODELS
from mmseg.utils import OptConfigType, SampleList
class BasePIDHead(BaseModule):
"""Base class for PID head.
Args:
in_channels (int): Number of input channels.
channels (int): Number of output channels.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU', inplace=True).
init_cfg (dict or list[dict], optional): Init config dict.
Default: None.
"""
def __init__(self,
in_channels: int,
channels: int,
norm_cfg: OptConfigType = dict(type='BN'),
act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
init_cfg: OptConfigType = None):
super().__init__(init_cfg)
self.conv = ConvModule(
in_channels,
channels,
kernel_size=3,
padding=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
order=('norm', 'act', 'conv'))
_, self.norm = build_norm_layer(norm_cfg, num_features=channels)
self.act = build_activation_layer(act_cfg)
def forward(self, x: Tensor, cls_seg: Optional[nn.Module]) -> Tensor:
"""Forward function.
Args:
x (Tensor): Input tensor.
cls_seg (nn.Module, optional): The classification head.
Returns:
Tensor: Output tensor.
"""
x = self.conv(x)
x = self.norm(x)
x = self.act(x)
if cls_seg is not None:
x = cls_seg(x)
return x
@MODELS.register_module()
class PIDHead(BaseDecodeHead):
"""Decode head for PIDNet.
Args:
in_channels (int): Number of input channels.
channels (int): Number of output channels.
num_classes (int): Number of classes.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU', inplace=True).
"""
def __init__(self,
in_channels: int,
channels: int,
num_classes: int,
norm_cfg: OptConfigType = dict(type='BN'),
act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
**kwargs):
super().__init__(
in_channels,
channels,
num_classes=num_classes,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
**kwargs)
self.i_head = BasePIDHead(in_channels, channels, norm_cfg, act_cfg)
self.p_head = BasePIDHead(in_channels // 2, channels, norm_cfg,
act_cfg)
self.d_head = BasePIDHead(
in_channels // 2,
in_channels // 4,
norm_cfg,
)
self.p_cls_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
self.d_cls_seg = nn.Conv2d(in_channels // 4, 1, kernel_size=1)
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(
self,
inputs: Union[Tensor,
Tuple[Tensor]]) -> Union[Tensor, Tuple[Tensor]]:
"""Forward function.
Args:
inputs (Tensor | tuple[Tensor]): Input tensor or tuple of
Tensor. When training, the input is a tuple of three tensors,
(p_feat, i_feat, d_feat), and the output is a tuple of three
tensors, (p_seg_logit, i_seg_logit, d_seg_logit).
At inference time, only the integral-branch head is used: the
input is the integral feature map and the output is the
segmentation logit.
Returns:
Tensor | tuple[Tensor]: Output tensor or tuple of tensors.
"""
if self.training:
x_p, x_i, x_d = inputs
x_p = self.p_head(x_p, self.p_cls_seg)
x_i = self.i_head(x_i, self.cls_seg)
x_d = self.d_head(x_d, self.d_cls_seg)
return x_p, x_i, x_d
else:
return self.i_head(inputs, self.cls_seg)
def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tuple[Tensor]:
gt_semantic_segs = [
data_sample.gt_sem_seg.data for data_sample in batch_data_samples
]
gt_edge_segs = [
data_sample.gt_edge_map.data for data_sample in batch_data_samples
]
gt_sem_segs = torch.stack(gt_semantic_segs, dim=0)
gt_edge_segs = torch.stack(gt_edge_segs, dim=0)
return gt_sem_segs, gt_edge_segs
def loss_by_feat(self, seg_logits: Tuple[Tensor],
batch_data_samples: SampleList) -> dict:
loss = dict()
p_logit, i_logit, d_logit = seg_logits
sem_label, bd_label = self._stack_batch_gt(batch_data_samples)
p_logit = resize(
input=p_logit,
size=sem_label.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
i_logit = resize(
input=i_logit,
size=sem_label.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
d_logit = resize(
input=d_logit,
size=bd_label.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
sem_label = sem_label.squeeze(1)
bd_label = bd_label.squeeze(1)
loss['loss_sem_p'] = self.loss_decode[0](
p_logit, sem_label, ignore_index=self.ignore_index)
loss['loss_sem_i'] = self.loss_decode[1](i_logit, sem_label)
loss['loss_bd'] = self.loss_decode[2](d_logit, bd_label)
filler = torch.ones_like(sem_label) * self.ignore_index
sem_bd_label = torch.where(
torch.sigmoid(d_logit[:, 0, :, :]) > 0.8, sem_label, filler)
loss['loss_sem_bd'] = self.loss_decode[3](i_logit, sem_bd_label)
loss['acc_seg'] = accuracy(
i_logit, sem_label, ignore_index=self.ignore_index)
return loss
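
A minimal sketch of the branch head defined above, assuming mmseg and mmcv are importable; tensor sizes and the 19-class classifier are illustrative.

import torch
import torch.nn as nn
from mmseg.models.decode_heads.pid_head import BasePIDHead

branch = BasePIDHead(in_channels=64, channels=32)
x = torch.randn(2, 64, 16, 16)
feats = branch(x, cls_seg=None)             # (2, 32, 16, 16) features only
cls_seg = nn.Conv2d(32, 19, kernel_size=1)  # hypothetical classifier
logits = branch(x, cls_seg)                 # (2, 19, 16, 16) logits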

View File

@@ -0,0 +1,367 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py # noqa
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
try:
from mmcv.ops import point_sample
except ModuleNotFoundError:
point_sample = None
from typing import List
from mmseg.registry import MODELS
from mmseg.utils import SampleList
from ..losses import accuracy
from ..utils import resize
from .cascade_decode_head import BaseCascadeDecodeHead
def calculate_uncertainty(seg_logits):
"""Estimate uncertainty based on seg logits.
For each location of the prediction ``seg_logits``, the uncertainty is
estimated as the difference between the second-highest and the highest
predicted logits.
Args:
seg_logits (Tensor): Semantic segmentation logits,
shape (batch_size, num_classes, height, width).
Returns:
scores (Tensor): Uncertainty scores with the most uncertain
locations having the highest uncertainty score, shape
(batch_size, 1, height, width).
"""
top2_scores = torch.topk(seg_logits, k=2, dim=1)[0]
return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)
@MODELS.register_module()
class PointHead(BaseCascadeDecodeHead):
"""A mask point head use in PointRend.
This head is implemented of `PointRend: Image Segmentation as
Rendering <https://arxiv.org/abs/1912.08193>`_.
``PointHead`` use shared multi-layer perceptron (equivalent to
nn.Conv1d) to predict the logit of input points. The fine-grained feature
and coarse feature will be concatenate together for predication.
Args:
num_fcs (int): Number of fc layers in the head. Default: 3.
in_channels (int): Number of input channels. Default: 256.
fc_channels (int): Number of fc channels. Default: 256.
num_classes (int): Number of classes for logits. Default: 80.
class_agnostic (bool): Whether use class agnostic classification.
If so, the output channels of logits will be 1. Default: False.
coarse_pred_each_layer (bool): Whether concatenate coarse feature with
the output of each fc layer. Default: True.
conv_cfg (dict|None): Dictionary to construct and config conv layer.
Default: dict(type='Conv1d'))
norm_cfg (dict|None): Dictionary to construct and config norm layer.
Default: None.
loss_point (dict): Dictionary to construct and config loss layer of
point head. Default: dict(type='CrossEntropyLoss', use_mask=True,
loss_weight=1.0).
"""
def __init__(self,
num_fcs=3,
coarse_pred_each_layer=True,
conv_cfg=dict(type='Conv1d'),
norm_cfg=None,
act_cfg=dict(type='ReLU', inplace=False),
**kwargs):
super().__init__(
input_transform='multiple_select',
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
init_cfg=dict(
type='Normal', std=0.01, override=dict(name='fc_seg')),
**kwargs)
if point_sample is None:
raise RuntimeError('Please install mmcv-full for '
'point_sample ops')
self.num_fcs = num_fcs
self.coarse_pred_each_layer = coarse_pred_each_layer
fc_in_channels = sum(self.in_channels) + self.num_classes
fc_channels = self.channels
self.fcs = nn.ModuleList()
for k in range(num_fcs):
fc = ConvModule(
fc_in_channels,
fc_channels,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.fcs.append(fc)
fc_in_channels = fc_channels
fc_in_channels += self.num_classes if self.coarse_pred_each_layer \
else 0
self.fc_seg = nn.Conv1d(
fc_in_channels,
self.num_classes,
kernel_size=1,
stride=1,
padding=0)
if self.dropout_ratio > 0:
self.dropout = nn.Dropout(self.dropout_ratio)
delattr(self, 'conv_seg')
def cls_seg(self, feat):
"""Classify each pixel with fc."""
if self.dropout is not None:
feat = self.dropout(feat)
output = self.fc_seg(feat)
return output
def forward(self, fine_grained_point_feats, coarse_point_feats):
x = torch.cat([fine_grained_point_feats, coarse_point_feats], dim=1)
for fc in self.fcs:
x = fc(x)
if self.coarse_pred_each_layer:
x = torch.cat((x, coarse_point_feats), dim=1)
return self.cls_seg(x)
def _get_fine_grained_point_feats(self, x, points):
"""Sample from fine grained features.
Args:
x (list[Tensor]): Feature pyramid from neck or backbone.
points (Tensor): Point coordinates, shape (batch_size,
num_points, 2).
Returns:
fine_grained_feats (Tensor): Sampled fine grained feature,
shape (batch_size, sum(channels of x), num_points).
"""
fine_grained_feats_list = [
point_sample(_, points, align_corners=self.align_corners)
for _ in x
]
if len(fine_grained_feats_list) > 1:
fine_grained_feats = torch.cat(fine_grained_feats_list, dim=1)
else:
fine_grained_feats = fine_grained_feats_list[0]
return fine_grained_feats
def _get_coarse_point_feats(self, prev_output, points):
"""Sample from fine grained features.
Args:
prev_output (list[Tensor]): Prediction of previous decode head.
points (Tensor): Point coordinates, shape (batch_size,
num_points, 2).
Returns:
coarse_feats (Tensor): Sampled coarse feature, shape (batch_size,
num_classes, num_points).
"""
coarse_feats = point_sample(
prev_output, points, align_corners=self.align_corners)
return coarse_feats
def loss(self, inputs, prev_output, batch_data_samples: SampleList,
train_cfg, **kwargs):
"""Forward function for training.
Args:
inputs (list[Tensor]): List of multi-level img features.
prev_output (Tensor): The output of previous decode head.
batch_data_samples (list[:obj:`SegDataSample`]): The seg
data samples. It usually includes information such
as `img_metas` or `gt_semantic_seg`.
train_cfg (dict): The training config.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
x = self._transform_inputs(inputs)
with torch.no_grad():
points = self.get_points_train(
prev_output, calculate_uncertainty, cfg=train_cfg)
fine_grained_point_feats = self._get_fine_grained_point_feats(
x, points)
coarse_point_feats = self._get_coarse_point_feats(prev_output, points)
point_logits = self.forward(fine_grained_point_feats,
coarse_point_feats)
losses = self.loss_by_feat(point_logits, points, batch_data_samples)
return losses
def predict(self, inputs, prev_output, batch_img_metas: List[dict],
test_cfg, **kwargs):
"""Forward function for testing.
Args:
inputs (list[Tensor]): List of multi-level img features.
prev_output (Tensor): The output of previous decode head.
batch_img_metas (list[dict]): List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`mmseg/datasets/pipelines/formatting.py:Collect`.
test_cfg (dict): The testing config.
Returns:
Tensor: Output segmentation map.
"""
x = self._transform_inputs(inputs)
refined_seg_logits = prev_output.clone()
for _ in range(test_cfg.subdivision_steps):
refined_seg_logits = resize(
refined_seg_logits,
scale_factor=test_cfg.scale_factor,
mode='bilinear',
align_corners=self.align_corners)
batch_size, channels, height, width = refined_seg_logits.shape
point_indices, points = self.get_points_test(
refined_seg_logits, calculate_uncertainty, cfg=test_cfg)
fine_grained_point_feats = self._get_fine_grained_point_feats(
x, points)
coarse_point_feats = self._get_coarse_point_feats(
prev_output, points)
point_logits = self.forward(fine_grained_point_feats,
coarse_point_feats)
point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1)
refined_seg_logits = refined_seg_logits.reshape(
batch_size, channels, height * width)
refined_seg_logits = refined_seg_logits.scatter_(
2, point_indices, point_logits)
refined_seg_logits = refined_seg_logits.view(
batch_size, channels, height, width)
return self.predict_by_feat(refined_seg_logits, batch_img_metas,
**kwargs)
def loss_by_feat(self, point_logits, points, batch_data_samples, **kwargs):
"""Compute segmentation loss."""
gt_semantic_seg = self._stack_batch_gt(batch_data_samples)
point_label = point_sample(
gt_semantic_seg.float(),
points,
mode='nearest',
align_corners=self.align_corners)
point_label = point_label.squeeze(1).long()
loss = dict()
if not isinstance(self.loss_decode, nn.ModuleList):
losses_decode = [self.loss_decode]
else:
losses_decode = self.loss_decode
for loss_module in losses_decode:
loss['point' + loss_module.loss_name] = loss_module(
point_logits, point_label, ignore_index=self.ignore_index)
loss['acc_point'] = accuracy(
point_logits, point_label, ignore_index=self.ignore_index)
return loss
def get_points_train(self, seg_logits, uncertainty_func, cfg):
"""Sample points for training.
Sample points in [0, 1] x [0, 1] coordinate space based on their
uncertainty. The uncertainties are calculated for each point using
'uncertainty_func' function that takes point's logit prediction as
input.
Args:
seg_logits (Tensor): Semantic segmentation logits, shape (
batch_size, num_classes, height, width).
uncertainty_func (func): uncertainty calculation function.
cfg (dict): Training config of point head.
Returns:
point_coords (Tensor): A tensor of shape (batch_size, num_points,
2) that contains the coordinates of ``num_points`` sampled
points.
"""
num_points = cfg.num_points
oversample_ratio = cfg.oversample_ratio
importance_sample_ratio = cfg.importance_sample_ratio
assert oversample_ratio >= 1
assert 0 <= importance_sample_ratio <= 1
batch_size = seg_logits.shape[0]
num_sampled = int(num_points * oversample_ratio)
point_coords = torch.rand(
batch_size, num_sampled, 2, device=seg_logits.device)
point_logits = point_sample(seg_logits, point_coords)
# It is crucial to calculate uncertainty based on the sampled
# prediction value for the points. Calculating uncertainties of the
# coarse predictions first and sampling them for points leads to
# incorrect results. To illustrate this: assume uncertainty func(
# logits)=-abs(logits), a sampled point between two coarse
# predictions with -1 and 1 logits has 0 logits, and therefore 0
# uncertainty value. However, if we calculate uncertainties for the
# coarse predictions first, both will have -1 uncertainty,
# and sampled point will get -1 uncertainty.
point_uncertainties = uncertainty_func(point_logits)
num_uncertain_points = int(importance_sample_ratio * num_points)
num_random_points = num_points - num_uncertain_points
idx = torch.topk(
point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
shift = num_sampled * torch.arange(
batch_size, dtype=torch.long, device=seg_logits.device)
idx += shift[:, None]
point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
batch_size, num_uncertain_points, 2)
if num_random_points > 0:
rand_point_coords = torch.rand(
batch_size, num_random_points, 2, device=seg_logits.device)
point_coords = torch.cat((point_coords, rand_point_coords), dim=1)
return point_coords
def get_points_test(self, seg_logits, uncertainty_func, cfg):
"""Sample points for testing.
Find ``num_points`` most uncertain points from ``uncertainty_map``.
Args:
seg_logits (Tensor): A tensor of shape (batch_size, num_classes,
height, width) for class-specific or class-agnostic prediction.
uncertainty_func (func): uncertainty calculation function.
cfg (dict): Testing config of point head.
Returns:
point_indices (Tensor): A tensor of shape (batch_size, num_points)
that contains indices from [0, height x width) of the most
uncertain points.
point_coords (Tensor): A tensor of shape (batch_size, num_points,
2) that contains [0, 1] x [0, 1] normalized coordinates of the
most uncertain points from the ``height x width`` grid.
"""
num_points = cfg.subdivision_num_points
uncertainty_map = uncertainty_func(seg_logits)
batch_size, _, height, width = uncertainty_map.shape
h_step = 1.0 / height
w_step = 1.0 / width
uncertainty_map = uncertainty_map.view(batch_size, height * width)
num_points = min(height * width, num_points)
point_indices = uncertainty_map.topk(num_points, dim=1)[1]
point_coords = torch.zeros(
batch_size,
num_points,
2,
dtype=torch.float,
device=seg_logits.device)
point_coords[:, :, 0] = w_step / 2.0 + (point_indices %
width).float() * w_step
point_coords[:, :, 1] = h_step / 2.0 + (point_indices //
width).float() * h_step
return point_indices, point_coords
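
A self-contained sketch (pure PyTorch, no mmcv needed) of the top-2 uncertainty measure and the flat-index-to-coordinate mapping used by get_points_test; the shapes and num_points value are illustrative.

import torch

logits = torch.randn(2, 19, 32, 32)
top2 = torch.topk(logits, k=2, dim=1)[0]
uncertainty = (top2[:, 1] - top2[:, 0]).unsqueeze(1)  # (2, 1, 32, 32), <= 0
b, _, h, w = uncertainty.shape
num_points = 64
point_indices = uncertainty.view(b, h * w).topk(num_points, dim=1)[1]
w_step, h_step = 1.0 / w, 1.0 / h
# Normalized (x, y) centers of the selected grid cells, as in get_points_test.
xs = w_step / 2.0 + (point_indices % w).float() * w_step
ys = h_step / 2.0 + (point_indices // w).float() * h_step
point_coords = torch.stack([xs, ys], dim=2)  # (2, 64, 2) in [0, 1] x [0, 1]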

View File

@@ -0,0 +1,197 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead
try:
from mmcv.ops import PSAMask
except ModuleNotFoundError:
PSAMask = None
@MODELS.register_module()
class PSAHead(BaseDecodeHead):
"""Point-wise Spatial Attention Network for Scene Parsing.
This head is the implementation of `PSANet
<https://hszhao.github.io/papers/eccv18_psanet.pdf>`_.
Args:
mask_size (tuple[int]): The PSA mask size. It usually equals input
size.
psa_type (str): The type of psa module. Options are 'collect',
'distribute', 'bi-direction'. Default: 'bi-direction'
compact (bool): Whether to use a compact map for 'collect' mode.
Default: False.
shrink_factor (int): The downsample factor of the psa mask. Default: 2.
normalization_factor (float): The normalization factor of attention.
If None, it falls back to mask_h * mask_w. Default: 1.0.
psa_softmax (bool): Whether to use softmax for attention. Default: True.
"""
def __init__(self,
mask_size,
psa_type='bi-direction',
compact=False,
shrink_factor=2,
normalization_factor=1.0,
psa_softmax=True,
**kwargs):
if PSAMask is None:
raise RuntimeError('Please install mmcv-full for PSAMask ops')
super().__init__(**kwargs)
assert psa_type in ['collect', 'distribute', 'bi-direction']
self.psa_type = psa_type
self.compact = compact
self.shrink_factor = shrink_factor
self.mask_size = mask_size
mask_h, mask_w = mask_size
self.psa_softmax = psa_softmax
if normalization_factor is None:
normalization_factor = mask_h * mask_w
self.normalization_factor = normalization_factor
self.reduce = ConvModule(
self.in_channels,
self.channels,
kernel_size=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.attention = nn.Sequential(
ConvModule(
self.channels,
self.channels,
kernel_size=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
nn.Conv2d(
self.channels, mask_h * mask_w, kernel_size=1, bias=False))
if psa_type == 'bi-direction':
self.reduce_p = ConvModule(
self.in_channels,
self.channels,
kernel_size=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.attention_p = nn.Sequential(
ConvModule(
self.channels,
self.channels,
kernel_size=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
nn.Conv2d(
self.channels, mask_h * mask_w, kernel_size=1, bias=False))
self.psamask_collect = PSAMask('collect', mask_size)
self.psamask_distribute = PSAMask('distribute', mask_size)
else:
self.psamask = PSAMask(psa_type, mask_size)
self.proj = ConvModule(
self.channels * (2 if psa_type == 'bi-direction' else 1),
self.in_channels,
kernel_size=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.bottleneck = ConvModule(
self.in_channels * 2,
self.channels,
kernel_size=3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def forward(self, inputs):
"""Forward function."""
x = self._transform_inputs(inputs)
identity = x
align_corners = self.align_corners
if self.psa_type in ['collect', 'distribute']:
out = self.reduce(x)
n, c, h, w = out.size()
if self.shrink_factor != 1:
if h % self.shrink_factor and w % self.shrink_factor:
h = (h - 1) // self.shrink_factor + 1
w = (w - 1) // self.shrink_factor + 1
align_corners = True
else:
h = h // self.shrink_factor
w = w // self.shrink_factor
align_corners = False
out = resize(
out,
size=(h, w),
mode='bilinear',
align_corners=align_corners)
y = self.attention(out)
if self.compact:
if self.psa_type == 'collect':
y = y.view(n, h * w,
h * w).transpose(1, 2).view(n, h * w, h, w)
else:
y = self.psamask(y)
if self.psa_softmax:
y = F.softmax(y, dim=1)
out = torch.bmm(
out.view(n, c, h * w), y.view(n, h * w, h * w)).view(
n, c, h, w) * (1.0 / self.normalization_factor)
else:
x_col = self.reduce(x)
x_dis = self.reduce_p(x)
n, c, h, w = x_col.size()
if self.shrink_factor != 1:
if h % self.shrink_factor and w % self.shrink_factor:
h = (h - 1) // self.shrink_factor + 1
w = (w - 1) // self.shrink_factor + 1
align_corners = True
else:
h = h // self.shrink_factor
w = w // self.shrink_factor
align_corners = False
x_col = resize(
x_col,
size=(h, w),
mode='bilinear',
align_corners=align_corners)
x_dis = resize(
x_dis,
size=(h, w),
mode='bilinear',
align_corners=align_corners)
y_col = self.attention(x_col)
y_dis = self.attention_p(x_dis)
if self.compact:
y_dis = y_dis.view(n, h * w,
h * w).transpose(1, 2).view(n, h * w, h, w)
else:
y_col = self.psamask_collect(y_col)
y_dis = self.psamask_distribute(y_dis)
if self.psa_softmax:
y_col = F.softmax(y_col, dim=1)
y_dis = F.softmax(y_dis, dim=1)
x_col = torch.bmm(
x_col.view(n, c, h * w), y_col.view(n, h * w, h * w)).view(
n, c, h, w) * (1.0 / self.normalization_factor)
x_dis = torch.bmm(
x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w)).view(
n, c, h, w) * (1.0 / self.normalization_factor)
out = torch.cat([x_col, x_dis], 1)
out = self.proj(out)
out = resize(
out,
size=identity.shape[2:],
mode='bilinear',
align_corners=align_corners)
out = self.bottleneck(torch.cat((identity, out), dim=1))
out = self.cls_seg(out)
return out
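
A pure-PyTorch sketch of the point-wise attention aggregation performed by the torch.bmm calls above: every output position is a weighted sum over all h*w positions, with weights taken from the (softmaxed) attention map. Sizes are illustrative.

import torch
import torch.nn.functional as F

n, c, h, w = 1, 8, 6, 6
feats = torch.randn(n, c, h, w)
attn = F.softmax(torch.randn(n, h * w, h, w), dim=1)  # one map per position
out = torch.bmm(feats.view(n, c, h * w), attn.view(n, h * w, h * w))
out = out.view(n, c, h, w)  # same shape as feats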

View File

@@ -0,0 +1,117 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead
class PPM(nn.ModuleList):
"""Pooling Pyramid Module used in PSPNet.
Args:
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
Module.
in_channels (int): Input channels.
channels (int): Channels after modules, before conv_seg.
conv_cfg (dict|None): Config of conv layers.
norm_cfg (dict|None): Config of norm layers.
act_cfg (dict): Config of activation layers.
align_corners (bool): align_corners argument of F.interpolate.
"""
def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
act_cfg, align_corners, **kwargs):
super().__init__()
self.pool_scales = pool_scales
self.align_corners = align_corners
self.in_channels = in_channels
self.channels = channels
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
for pool_scale in pool_scales:
self.append(
nn.Sequential(
nn.AdaptiveAvgPool2d(pool_scale),
ConvModule(
self.in_channels,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
**kwargs)))
def forward(self, x):
"""Forward function."""
ppm_outs = []
for ppm in self:
ppm_out = ppm(x)
upsampled_ppm_out = resize(
ppm_out,
size=x.size()[2:],
mode='bilinear',
align_corners=self.align_corners)
ppm_outs.append(upsampled_ppm_out)
return ppm_outs
@MODELS.register_module()
class PSPHead(BaseDecodeHead):
"""Pyramid Scene Parsing Network.
This head is the implementation of
`PSPNet <https://arxiv.org/abs/1612.01105>`_.
Args:
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
Module. Default: (1, 2, 3, 6).
"""
def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
super().__init__(**kwargs)
assert isinstance(pool_scales, (list, tuple))
self.pool_scales = pool_scales
self.psp_modules = PPM(
self.pool_scales,
self.in_channels,
self.channels,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
align_corners=self.align_corners)
self.bottleneck = ConvModule(
self.in_channels + len(pool_scales) * self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def _forward_feature(self, inputs):
"""Forward function for feature maps before classifying each pixel with
``self.cls_seg`` fc.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
feats (Tensor): A tensor of shape (batch_size, self.channels,
H, W) which is feature map for last layer of decoder head.
"""
x = self._transform_inputs(inputs)
psp_outs = [x]
psp_outs.extend(self.psp_modules(x))
psp_outs = torch.cat(psp_outs, dim=1)
feats = self.bottleneck(psp_outs)
return feats
def forward(self, inputs):
"""Forward function."""
output = self._forward_feature(inputs)
output = self.cls_seg(output)
return output
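
A minimal sketch, assuming mmseg and mmcv are importable, of how the PPM pools the input to several scales and upsamples every branch back to the input resolution; channel counts and sizes are illustrative.

import torch
from mmseg.models.decode_heads.psp_head import PPM

ppm = PPM(
    pool_scales=(1, 2, 3, 6),
    in_channels=64,
    channels=16,
    conv_cfg=None,
    norm_cfg=None,
    act_cfg=dict(type='ReLU'),
    align_corners=False)
x = torch.randn(2, 64, 32, 32)
outs = ppm(x)  # four tensors, each (2, 16, 32, 32)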

View File

@@ -0,0 +1,736 @@
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
from typing import Dict, List, Tuple
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer
from mmcv.cnn.bricks.transformer import BaseTransformerLayer
from mmcv.ops import point_sample
from mmengine.dist import all_reduce
from mmengine.model.weight_init import (caffe2_xavier_init, normal_init,
trunc_normal_)
from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
from mmengine.structures import InstanceData
from torch import Tensor
from torch.nn import functional as F
from mmseg.models.backbones.vit import TransformerEncoderLayer
from mmseg.registry import MODELS
from mmseg.utils import (ConfigType, MatchMasks, SampleList,
seg_data_to_instance_data)
from ..utils import (MLP, LayerNorm2d, PatchEmbed, cross_attn_layer,
get_uncertain_point_coords_with_randomness, resize)
from .decode_head import BaseDecodeHead
class MLPMaskDecoder(nn.Module):
"""Module for decoding query and visual features with MLP layers to
generate the attention biases and the mask proposals."""
def __init__(
self,
*,
in_channels: int,
total_heads: int = 1,
total_layers: int = 1,
embed_channels: int = 256,
mlp_channels: int = 256,
mlp_num_layers: int = 3,
rescale_attn_bias: bool = False,
):
super().__init__()
self.total_heads = total_heads
self.total_layers = total_layers
dense_affine_func = partial(nn.Conv2d, kernel_size=1)
# Query Branch
self.query_mlp = MLP(in_channels, mlp_channels, embed_channels,
mlp_num_layers)
# Pixel Branch
self.pix_mlp = MLP(
in_channels,
mlp_channels,
embed_channels,
mlp_num_layers,
affine_func=dense_affine_func,
)
# Attention Bias Branch
self.attn_mlp = MLP(
in_channels,
mlp_channels,
embed_channels * self.total_heads * self.total_layers,
mlp_num_layers,
affine_func=dense_affine_func,
)
if rescale_attn_bias:
self.bias_scaling = nn.Linear(1, 1)
else:
self.bias_scaling = nn.Identity()
def forward(self, query: torch.Tensor,
x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""Forward function.
Args:
query (Tensor): Query Tokens [B,N,C].
x (Tensor): Visual features [B,C,H,W]
Returns:
mask_preds (Tensor): Mask proposals.
attn_bias (List[Tensor]): List of attention bias.
"""
query = self.query_mlp(query)
pix = self.pix_mlp(x)
b, c, h, w = pix.shape
# predict mask
mask_preds = torch.einsum('bqc,bchw->bqhw', query, pix)
# generate attn bias
attn = self.attn_mlp(x)
attn = attn.reshape(b, self.total_layers, self.total_heads, c, h, w)
attn_bias = torch.einsum('bqc,blnchw->blnqhw', query, attn)
attn_bias = self.bias_scaling(attn_bias[..., None]).squeeze(-1)
attn_bias = attn_bias.chunk(self.total_layers, dim=1)
attn_bias = [attn.squeeze(1) for attn in attn_bias]
return mask_preds, attn_bias
class SideAdapterNetwork(nn.Module):
"""Side Adapter Network for predicting mask proposals and attention bias.
Args:
in_channels (int): Number of input channels. Default: 3.
clip_channels (int): Number of channels of visual features.
Default: 768.
embed_dims (int): embedding dimension. Default: 240.
patch_size (int): The patch size. Default: 16.
patch_bias (bool): Whether use bias in patch embedding.
Default: True.
num_queries (int): Number of queries for mask proposals.
Default: 100.
fusion_index (List[int]): The layer number of the encode
transformer to fuse with the CLIP feature.
Default: [0, 1, 2, 3].
cfg_encoder (ConfigType): Configs for the encode layers.
cfg_decoder (ConfigType): Configs for the decode layers.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
"""
def __init__(
self,
in_channels: int = 3,
clip_channels: int = 768,
embed_dims: int = 240,
patch_size: int = 16,
patch_bias: bool = True,
num_queries: int = 100,
fusion_index: list = [0, 1, 2, 3],
cfg_encoder: ConfigType = ...,
cfg_decoder: ConfigType = ...,
norm_cfg: dict = dict(type='LN'),
):
super().__init__()
self.patch_embed = PatchEmbed(
in_channels=in_channels,
embed_dims=embed_dims,
conv_type='Conv2d',
kernel_size=patch_size,
stride=patch_size,
padding=0,
input_size=(640, 640),
bias=patch_bias,
norm_cfg=None,
init_cfg=None,
)
ori_h, ori_w = self.patch_embed.init_out_size
num_patches = ori_h * ori_w
self.pos_embed = nn.Parameter(
torch.randn(1, num_patches, embed_dims) * .02)
self.query_pos_embed = nn.Parameter(
torch.zeros(1, num_queries, embed_dims))
self.query_embed = nn.Parameter(
torch.zeros(1, num_queries, embed_dims))
encode_layers = []
for i in range(cfg_encoder.num_encode_layer):
encode_layers.append(
TransformerEncoderLayer(
embed_dims=embed_dims,
num_heads=cfg_encoder.num_heads,
feedforward_channels=cfg_encoder.mlp_ratio * embed_dims,
norm_cfg=norm_cfg))
self.encode_layers = nn.ModuleList(encode_layers)
conv_clips = []
for i in range(len(fusion_index)):
conv_clips.append(
nn.Sequential(
LayerNorm2d(clip_channels),
ConvModule(
clip_channels,
embed_dims,
kernel_size=1,
norm_cfg=None,
act_cfg=None)))
self.conv_clips = nn.ModuleList(conv_clips)
self.fusion_index = fusion_index
self.mask_decoder = MLPMaskDecoder(
in_channels=embed_dims,
total_heads=cfg_decoder.num_heads,
total_layers=cfg_decoder.num_layers,
embed_channels=cfg_decoder.embed_channels,
mlp_channels=cfg_decoder.mlp_channels,
mlp_num_layers=cfg_decoder.num_mlp,
rescale_attn_bias=cfg_decoder.rescale)
def init_weights(self):
trunc_normal_(self.pos_embed, std=0.02)
nn.init.normal_(self.query_embed, std=0.02)
nn.init.normal_(self.query_pos_embed, std=0.02)
for i in range(len(self.conv_clips)):
caffe2_xavier_init(self.conv_clips[i][1].conv)
def fuse_clip(self, fused_index: int, x: torch.Tensor,
clip_feature: torch.Tensor, hwshape: Tuple[int,
int], L: int):
"""Fuse CLIP feature and visual tokens."""
fused_clip = (resize(
self.conv_clips[fused_index](clip_feature.contiguous()),
size=hwshape,
mode='bilinear',
align_corners=False)).permute(0, 2, 3, 1).reshape(x[:, -L:,
...].shape)
x = torch.cat([x[:, :-L, ...], x[:, -L:, ...] + fused_clip], dim=1)
return x
def encode_feature(self, image: torch.Tensor,
clip_features: List[torch.Tensor],
deep_supervision_idxs: List[int]) -> List[List]:
"""Encode images by a lightweight vision transformer."""
assert len(self.fusion_index) == len(clip_features)
x, hwshape = self.patch_embed(image)
ori_h, ori_w = self.patch_embed.init_out_size
pos_embed = self.pos_embed
if self.pos_embed.shape[1] != x.shape[1]:
# resize the position embedding
pos_embed = (
resize(
self.pos_embed.reshape(1, ori_h, ori_w,
-1).permute(0, 3, 1, 2),
size=hwshape,
mode='bicubic',
align_corners=False,
).flatten(2).permute(0, 2, 1))
pos_embed = torch.cat([
self.query_pos_embed.expand(pos_embed.shape[0], -1, -1), pos_embed
],
dim=1)
x = torch.cat([self.query_embed.expand(x.shape[0], -1, -1), x], dim=1)
x = x + pos_embed
L = hwshape[0] * hwshape[1]
fused_index = 0
if self.fusion_index[fused_index] == 0:
x = self.fuse_clip(fused_index, x, clip_features[0][0], hwshape, L)
fused_index += 1
outs = []
for index, block in enumerate(self.encode_layers, start=1):
x = block(x)
if index < len(self.fusion_index
) and index == self.fusion_index[fused_index]:
x = self.fuse_clip(fused_index, x,
clip_features[fused_index][0], hwshape, L)
fused_index += 1
x_query = x[:, :-L, ...]
x_feat = x[:, -L:, ...].permute(0, 2, 1)\
.reshape(x.shape[0], x.shape[-1], hwshape[0], hwshape[1])
if index in deep_supervision_idxs or index == len(
self.encode_layers):
outs.append({'query': x_query, 'x': x_feat})
if index < len(self.encode_layers):
x = x + pos_embed
return outs
def decode_feature(self, features):
mask_embeds = []
attn_biases = []
for feature in features:
mask_embed, attn_bias = self.mask_decoder(**feature)
mask_embeds.append(mask_embed)
attn_biases.append(attn_bias)
return mask_embeds, attn_biases
def forward(
self, image: torch.Tensor, clip_features: List[torch.Tensor],
deep_supervision_idxs: List[int]
) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]:
"""Forward function."""
features = self.encode_feature(image, clip_features,
deep_supervision_idxs)
mask_embeds, attn_biases = self.decode_feature(features)
return mask_embeds, attn_biases
class RecWithAttnbias(nn.Module):
"""Mask recognition module by applying the attention biases to rest deeper
CLIP layers.
Args:
sos_token_format (str): The format of sos token. It should be
chosen from ["cls_token", "learnable_token", "pos_embedding"].
Default: 'cls_token'.
sos_token_num (int): Number of sos tokens. It should be equal to
the number of queries. Default: 100.
num_layers (int): Number of remaining CLIP layers used for mask
recognition. Default: 3.
cross_attn (bool): Whether use cross attention to update sos token.
Default: False.
embed_dims (int): The feature dimension of CLIP layers.
Default: 768.
num_heads (int): Parallel attention heads of CLIP layers.
Default: 12.
mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
Default: 4.
qkv_bias (bool): Whether to use bias in multihead-attention.
Default: True.
out_dims (int): Number of channels of the output mask proposals.
It should be equal to the out_dims of text_encoder.
Default: 512.
final_norm (bool): Whether to use a norm layer for the sos token. Default: True.
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
frozen_exclude (List): List of parameters that are not to be frozen.
"""
def __init__(self,
sos_token_format: str = 'cls_token',
sos_token_num: int = 100,
num_layers: int = 3,
cross_attn: bool = False,
embed_dims: int = 768,
num_heads: int = 12,
mlp_ratio: int = 4,
num_fcs: int = 2,
qkv_bias: bool = True,
out_dims: int = 512,
final_norm: bool = True,
act_cfg: dict = dict(type='GELU'),
norm_cfg: dict = dict(type='LN'),
frozen_exclude: List = []):
super().__init__()
assert sos_token_format in [
'cls_token', 'learnable_token', 'pos_embedding'
]
self.sos_token_format = sos_token_format
self.sos_token_num = sos_token_num
self.frozen_exclude = frozen_exclude
self.cross_attn = cross_attn
self.num_layers = num_layers
self.num_heads = num_heads
if sos_token_format in ['learnable_token', 'pos_embedding']:
# sos tokens live in the CLIP embedding space (embed_dims) and must
# stay trainable even when the other parameters are frozen.
self.sos_token = nn.Parameter(
torch.randn(sos_token_num, 1, embed_dims))
self.frozen_exclude.append('sos_token')
layers = []
for i in range(num_layers):
layers.append(
BaseTransformerLayer(
attn_cfgs=dict(
type='MultiheadAttention',
embed_dims=embed_dims,
num_heads=num_heads,
batch_first=False,
bias=qkv_bias),
ffn_cfgs=dict(
type='FFN',
embed_dims=embed_dims,
feedforward_channels=mlp_ratio * embed_dims,
act_cfg=act_cfg),
operation_order=('norm', 'self_attn', 'norm', 'ffn')))
self.layers = nn.ModuleList(layers)
self.ln_post = build_norm_layer(norm_cfg, embed_dims)[1]
self.proj = nn.Linear(embed_dims, out_dims, bias=False)
self.final_norm = final_norm
self._freeze()
def init_weights(self, rec_state_dict):
if hasattr(self, 'sos_token'):
normal_init(self.sos_token, std=0.02)
if rec_state_dict is not None:
load_state_dict(self, rec_state_dict, strict=False, logger=None)
else:
super().init_weights()
def _freeze(self):
if 'all' in self.frozen_exclude:
return
for name, param in self.named_parameters():
if not any([exclude in name for exclude in self.frozen_exclude]):
param.requires_grad = False
def _build_attn_biases(self, attn_biases, target_shape):
formatted_attn_biases = []
for attn_bias in attn_biases:
# convert it to proper format: N*num_head,L,L
# attn_bias: [N, num_head/1, num_sos,H,W]
n, num_head, num_sos, h, w = attn_bias.shape
# reshape and downsample
attn_bias = F.adaptive_max_pool2d(
attn_bias.reshape(n, num_head * num_sos, h, w),
output_size=target_shape)
attn_bias = attn_bias.reshape(n, num_head, num_sos, *target_shape)
true_num_head = self.num_heads
assert (num_head == 1 or num_head
== true_num_head), f'num_head={num_head} is not supported.'
if num_head == 1:
attn_bias = attn_bias.repeat(1, true_num_head, 1, 1, 1)
attn_bias = attn_bias.reshape(n * true_num_head, num_sos, -1)
L = attn_bias.shape[-1]
if self.cross_attn:
# [n*num_head, num_sos, L]
formatted_attn_biases.append(attn_bias)
else:
# [n*num_head, num_sos+1+L, num_sos+1+L]
new_attn_bias = attn_bias.new_zeros(num_sos + 1 + L,
num_sos + 1 + L)
new_attn_bias[:, :num_sos] = -100
new_attn_bias[torch.arange(num_sos), torch.arange(num_sos)] = 0
new_attn_bias[:num_sos, num_sos] = -100
new_attn_bias = (
new_attn_bias[None, ...].expand(n * true_num_head, -1,
-1).clone())
new_attn_bias[..., :num_sos, -L:] = attn_bias
formatted_attn_biases.append(new_attn_bias)
if len(formatted_attn_biases) == 1:
formatted_attn_biases = [
formatted_attn_biases[0] for _ in range(self.num_layers)
]
return formatted_attn_biases
def forward(self, bias: List[Tensor], feature: List[Tensor]):
"""Forward function to recognize the category of masks
Args:
bias (List[Tensor]): Attention bias for transformer layers
feature (List[Tensor]): Output of the image encoder,
including cls_token and img_feature.
"""
cls_token = feature[1].unsqueeze(0)
img_feature = feature[0]
b, c, h, w = img_feature.shape
# construct clip shadow features
x = torch.cat(
[cls_token,
img_feature.reshape(b, c, -1).permute(2, 0, 1)])
# construct sos token
if self.sos_token_format == 'cls_token':
sos_token = cls_token.repeat(self.sos_token_num, 1, 1)
elif self.sos_token_format == 'learnable_token':
sos_token = self.sos_token.expand(-1, b, -1)
elif self.sos_token_format == 'pos_embedding':
sos_token = self.sos_token.expand(-1, b, -1) + cls_token
# construct attn bias
attn_biases = self._build_attn_biases(bias, target_shape=(h, w))
if self.cross_attn:
for i, block in enumerate(self.layers):
sos_token = cross_attn_layer(
block,
sos_token,
x[1:, ],
attn_biases[i],
)
if i < len(self.layers) - 1:
x = block(x)
else:
x = torch.cat([sos_token, x], dim=0)
for i, block in enumerate(self.layers):
x = block(x, attn_masks=[attn_biases[i]])
sos_token = x[:self.sos_token_num]
sos_token = sos_token.permute(1, 0, 2) # LND -> NLD
sos_token = self.ln_post(sos_token)
sos_token = self.proj(sos_token)
if self.final_norm:
sos_token = F.normalize(sos_token, dim=-1)
return sos_token
@MODELS.register_module()
class SideAdapterCLIPHead(BaseDecodeHead):
"""Side Adapter Network (SAN) for open-vocabulary semantic segmentation
with pre-trained vision-language model.
This decode head is the implementation of `Side Adapter Network
for Open-Vocabulary Semantic Segmentation
<https://arxiv.org/abs/2302.12242>`_.
Modified from https://github.com/MendelXu/SAN/blob/main/san/model/side_adapter/side_adapter.py # noqa:E501
Copyright (c) 2023 MendelXu.
Licensed under the MIT License
Args:
num_classes (int): the number of classes.
san_cfg (ConfigType): Configs for SideAdapterNetwork module
maskgen_cfg (ConfigType): Configs for RecWithAttnbias module
"""
def __init__(self, num_classes: int, san_cfg: ConfigType,
maskgen_cfg: ConfigType, deep_supervision_idxs: List[int],
train_cfg: ConfigType, **kwargs):
super().__init__(
in_channels=san_cfg.in_channels,
channels=san_cfg.embed_dims,
num_classes=num_classes,
**kwargs)
assert san_cfg.num_queries == maskgen_cfg.sos_token_num, \
'num_queries in san_cfg should be equal to sos_token_num ' \
'in maskgen_cfg'
del self.conv_seg
self.side_adapter_network = SideAdapterNetwork(**san_cfg)
self.rec_with_attnbias = RecWithAttnbias(**maskgen_cfg)
self.deep_supervision_idxs = deep_supervision_idxs
self.train_cfg = train_cfg
if train_cfg:
self.match_masks = MatchMasks(
num_points=train_cfg.num_points,
num_queries=san_cfg.num_queries,
num_classes=num_classes,
assigner=train_cfg.assigner)
def init_weights(self):
rec_state_dict = None
if isinstance(self.init_cfg, dict) and \
self.init_cfg.get('type') == 'Pretrained_Part':
checkpoint = CheckpointLoader.load_checkpoint(
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
rec_state_dict = checkpoint.copy()
para_prefix = 'decode_head.rec_with_attnbias'
prefix_len = len(para_prefix) + 1
for k, v in checkpoint.items():
rec_state_dict.pop(k)
if para_prefix in k:
rec_state_dict[k[prefix_len:]] = v
self.side_adapter_network.init_weights()
self.rec_with_attnbias.init_weights(rec_state_dict)
def forward(self, inputs: Tuple[Tensor],
deep_supervision_idxs) -> Tuple[List]:
"""Forward function.
Args:
inputs (Tuple[Tensor]): A triplet including images,
list of multi-level visual features from image encoder and
class embeddings from text_encoder.
Returns:
mask_props (List[Tensor]): Mask proposals predicted by SAN.
mask_logits (List[Tensor]): Class logits of mask proposals.
"""
imgs, clip_feature, class_embeds = inputs
# predict mask proposals and attention bias
mask_props, attn_biases = self.side_adapter_network(
imgs, clip_feature, deep_supervision_idxs)
# mask recognition with attention bias
mask_embeds = [
self.rec_with_attnbias(att_bias, clip_feature[-1])
for att_bias in attn_biases
]
# Obtain class prediction of masks by comparing the similarity
# between the image token and the text embedding of class names.
mask_logits = [
torch.einsum('bqc,nc->bqn', mask_embed, class_embeds)
for mask_embed in mask_embeds
]
return mask_props, mask_logits
def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
test_cfg: ConfigType) -> Tensor:
"""Forward function for prediction.
Args:
inputs (Tuple[Tensor]): Images, visual features from image encoder
and class embedding from text encoder.
batch_img_metas (List[dict]): Image info, where each dict may
contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
'ori_shape', and 'pad_shape'.
For details on the values of these keys see
`mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
test_cfg (dict): The testing config.
Returns:
Tensor: Outputs segmentation logits map.
"""
mask_props, mask_logits = self.forward(inputs, [])
return self.predict_by_feat([mask_props[-1], mask_logits[-1]],
batch_img_metas)
def predict_by_feat(self, seg_logits: List[Tensor],
batch_img_metas: List[dict]) -> Tensor:
"""1. Transform a batch of mask proposals to the input shape.
2. Generate segmentation map with mask proposals and class logits.
"""
mask_pred = seg_logits[0]
cls_score = seg_logits[1]
if isinstance(batch_img_metas[0]['img_shape'], torch.Size):
# slide inference
size = batch_img_metas[0]['img_shape']
elif 'pad_shape' in batch_img_metas[0]:
size = batch_img_metas[0]['pad_shape'][:2]
else:
size = batch_img_metas[0]['img_shape']
# upsample mask
mask_pred = F.interpolate(
mask_pred, size=size, mode='bilinear', align_corners=False)
mask_cls = F.softmax(cls_score, dim=-1)[..., :-1]
mask_pred = mask_pred.sigmoid()
seg_logits = torch.einsum('bqc,bqhw->bchw', mask_cls, mask_pred)
return seg_logits
def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList,
train_cfg: ConfigType) -> dict:
"""Perform forward propagation and loss calculation of the decoder head
on the features of the upstream network.
Args:
x (tuple[Tensor]): Multi-level features from the upstream
network, each is a 4D-tensor.
batch_data_samples (List[:obj:`SegDataSample`]): The Data
Samples. It usually includes information such as
`gt_sem_seg`.
train_cfg (ConfigType): Training config.
Returns:
dict[str, Tensor]: a dictionary of loss components.
"""
# batch SegDataSample to InstanceDataSample
batch_gt_instances = seg_data_to_instance_data(self.ignore_index,
batch_data_samples)
# forward
all_mask_props, all_mask_logits = self.forward(
x, self.deep_supervision_idxs)
# loss
losses = self.loss_by_feat(all_mask_logits, all_mask_props,
batch_gt_instances)
return losses
def loss_by_feat(
self, all_cls_scores: Tensor, all_mask_preds: Tensor,
batch_gt_instances: List[InstanceData]) -> Dict[str, Tensor]:
"""Loss function.
Args:
all_cls_scores (Tensor): Classification scores for all decoder
layers with shape (num_decoder, batch_size, num_queries,
cls_out_channels). Note that `cls_out_channels` should include
the background.
all_mask_preds (Tensor): Mask scores for all decoder layers with
shape (num_decoder, batch_size, num_queries, h, w).
batch_gt_instances (list[obj:`InstanceData`]): each contains
``labels`` and ``masks``.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
num_dec_layers = len(all_cls_scores)
batch_gt_instances_list = [
batch_gt_instances for _ in range(num_dec_layers)
]
losses = []
for i in range(num_dec_layers):
cls_scores = all_cls_scores[i]
mask_preds = all_mask_preds[i]
# matching N mask predictions to K category labels
(labels, mask_targets, mask_weights,
avg_factor) = self.match_masks.get_targets(
cls_scores, mask_preds, batch_gt_instances_list[i])
cls_scores = cls_scores.flatten(0, 1)
labels = labels.flatten(0, 1)
num_total_masks = cls_scores.new_tensor([avg_factor],
dtype=torch.float)
all_reduce(num_total_masks, op='mean')
num_total_masks = max(num_total_masks, 1)
# extract positive ones
# shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
mask_preds = mask_preds[mask_weights > 0]
if mask_targets.shape[0] != 0:
with torch.no_grad():
points_coords = get_uncertain_point_coords_with_randomness(
mask_preds.unsqueeze(1), None,
self.train_cfg.num_points,
self.train_cfg.oversample_ratio,
self.train_cfg.importance_sample_ratio)
# shape (num_total_gts, h, w)
# -> (num_total_gts, num_points)
mask_point_targets = point_sample(
mask_targets.unsqueeze(1).float(),
points_coords).squeeze(1)
# shape (num_queries, h, w) -> (num_queries, num_points)
mask_point_preds = point_sample(
mask_preds.unsqueeze(1), points_coords).squeeze(1)
if not isinstance(self.loss_decode, nn.ModuleList):
losses_decode = [self.loss_decode]
else:
losses_decode = self.loss_decode
loss = dict()
for loss_decode in losses_decode:
if 'loss_cls' in loss_decode.loss_name:
if loss_decode.loss_name == 'loss_cls_ce':
loss[loss_decode.loss_name] = loss_decode(
cls_scores, labels)
else:
assert False, "Only support 'CrossEntropyLoss' in" \
' classification loss'
elif 'loss_mask' in loss_decode.loss_name:
if mask_targets.shape[0] == 0:
loss[loss_decode.loss_name] = mask_preds.sum()
elif loss_decode.loss_name == 'loss_mask_ce':
loss[loss_decode.loss_name] = loss_decode(
mask_point_preds,
mask_point_targets,
avg_factor=num_total_masks *
self.train_cfg.num_points)
elif loss_decode.loss_name == 'loss_mask_dice':
loss[loss_decode.loss_name] = loss_decode(
mask_point_preds,
mask_point_targets,
avg_factor=num_total_masks)
else:
assert False, "Only support 'CrossEntropyLoss' and" \
" 'DiceLoss' in mask loss"
else:
assert False, "Only support for 'loss_cls' and 'loss_mask'"
losses.append(loss)
loss_dict = dict()
# loss from the last decoder layer
loss_dict.update(losses[-1])
# loss from other decoder layers
for i, loss in enumerate(losses[:-1]):
for k, v in loss.items():
loss_dict[f'd{self.deep_supervision_idxs[i]}.{k}'] = v
return loss_dict
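
A pure-PyTorch sketch of the self-attention bias layout assembled in _build_attn_biases when cross_attn is False: sos queries may not attend to the other sos tokens or to the cls token, and see the image tokens through the mask-derived bias. All sizes are illustrative.

import torch

num_sos, L, n_heads = 4, 9, 2
attn_bias = torch.randn(n_heads, num_sos, L)  # mask-derived bias
new_bias = attn_bias.new_zeros(num_sos + 1 + L, num_sos + 1 + L)
new_bias[:, :num_sos] = -100                  # block attention to sos tokens
new_bias[torch.arange(num_sos), torch.arange(num_sos)] = 0  # except to self
new_bias[:num_sos, num_sos] = -100            # sos cannot see the cls token
new_bias = new_bias[None].expand(n_heads, -1, -1).clone()
new_bias[:, :num_sos, -L:] = attn_bias        # sos -> image-token bias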

View File

@@ -0,0 +1,66 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.registry import MODELS
from ..utils import resize
@MODELS.register_module()
class SegformerHead(BaseDecodeHead):
"""The all mlp Head of segformer.
This head is the implementation of
`Segformer <https://arxiv.org/abs/2105.15203>` _.
Args:
interpolate_mode (str): The interpolation mode of the MLP head
upsample operation. Default: 'bilinear'.
"""
def __init__(self, interpolate_mode='bilinear', **kwargs):
super().__init__(input_transform='multiple_select', **kwargs)
self.interpolate_mode = interpolate_mode
num_inputs = len(self.in_channels)
assert num_inputs == len(self.in_index)
self.convs = nn.ModuleList()
for i in range(num_inputs):
self.convs.append(
ConvModule(
in_channels=self.in_channels[i],
out_channels=self.channels,
kernel_size=1,
stride=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
self.fusion_conv = ConvModule(
in_channels=self.channels * num_inputs,
out_channels=self.channels,
kernel_size=1,
norm_cfg=self.norm_cfg)
def forward(self, inputs):
# Receive 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32
inputs = self._transform_inputs(inputs)
outs = []
for idx in range(len(inputs)):
x = inputs[idx]
conv = self.convs[idx]
outs.append(
resize(
input=conv(x),
size=inputs[0].shape[2:],
mode=self.interpolate_mode,
align_corners=self.align_corners))
out = self.fusion_conv(torch.cat(outs, dim=1))
out = self.cls_seg(out)
return out
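
A minimal usage sketch, assuming mmseg and mmcv are importable; the channel counts follow a typical MiT-B0 configuration but are otherwise illustrative.

import torch
from mmseg.models.decode_heads.segformer_head import SegformerHead

head = SegformerHead(
    in_channels=[32, 64, 160, 256],
    in_index=[0, 1, 2, 3],
    channels=256,
    num_classes=19,
    norm_cfg=dict(type='BN'))
feats = [
    torch.randn(1, 32, 64, 64),   # 1/4 scale
    torch.randn(1, 64, 32, 32),   # 1/8 scale
    torch.randn(1, 160, 16, 16),  # 1/16 scale
    torch.randn(1, 256, 8, 8),    # 1/32 scale
]
seg_logits = head(feats)  # (1, 19, 64, 64), logits at 1/4 scale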

View File

@@ -0,0 +1,132 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer
from mmengine.model import ModuleList
from mmengine.model.weight_init import (constant_init, trunc_normal_,
trunc_normal_init)
from mmseg.models.backbones.vit import TransformerEncoderLayer
from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead
@MODELS.register_module()
class SegmenterMaskTransformerHead(BaseDecodeHead):
"""Segmenter: Transformer for Semantic Segmentation.
This head is the implementation of
`Segmenter: <https://arxiv.org/abs/2105.05633>`_.
Args:
backbone_cfg:(dict): Config of backbone of
Context Path.
in_channels (int): The number of channels of input image.
num_layers (int): The depth of transformer.
num_heads (int): The number of attention heads.
embed_dims (int): The number of embedding dimension.
mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
Default: 4.
drop_path_rate (float): stochastic depth rate. Default 0.1.
drop_rate (float): Probability of an element to be zeroed.
Default 0.0
attn_drop_rate (float): The drop out rate for attention layer.
Default 0.0
num_fcs (int): The number of fully-connected layers for FFNs.
Default: 2.
qkv_bias (bool): Enable bias for qkv if True. Default: True.
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN')
init_std (float): The value of std in weight initialization.
Default: 0.02.
"""
def __init__(
self,
in_channels,
num_layers,
num_heads,
embed_dims,
mlp_ratio=4,
drop_path_rate=0.1,
drop_rate=0.0,
attn_drop_rate=0.0,
num_fcs=2,
qkv_bias=True,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
init_std=0.02,
**kwargs,
):
super().__init__(in_channels=in_channels, **kwargs)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]
self.layers = ModuleList()
for i in range(num_layers):
self.layers.append(
TransformerEncoderLayer(
embed_dims=embed_dims,
num_heads=num_heads,
feedforward_channels=mlp_ratio * embed_dims,
attn_drop_rate=attn_drop_rate,
drop_rate=drop_rate,
drop_path_rate=dpr[i],
num_fcs=num_fcs,
qkv_bias=qkv_bias,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
batch_first=True,
))
self.dec_proj = nn.Linear(in_channels, embed_dims)
self.cls_emb = nn.Parameter(
torch.randn(1, self.num_classes, embed_dims))
self.patch_proj = nn.Linear(embed_dims, embed_dims, bias=False)
self.classes_proj = nn.Linear(embed_dims, embed_dims, bias=False)
self.decoder_norm = build_norm_layer(
norm_cfg, embed_dims, postfix=1)[1]
self.mask_norm = build_norm_layer(
norm_cfg, self.num_classes, postfix=2)[1]
self.init_std = init_std
delattr(self, 'conv_seg')
def init_weights(self):
trunc_normal_(self.cls_emb, std=self.init_std)
trunc_normal_init(self.patch_proj, std=self.init_std)
trunc_normal_init(self.classes_proj, std=self.init_std)
for n, m in self.named_modules():
if isinstance(m, nn.Linear):
trunc_normal_init(m, std=self.init_std, bias=0)
elif isinstance(m, nn.LayerNorm):
constant_init(m, val=1.0, bias=0.0)
def forward(self, inputs):
x = self._transform_inputs(inputs)
b, c, h, w = x.shape
x = x.permute(0, 2, 3, 1).contiguous().view(b, -1, c)
x = self.dec_proj(x)
cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
x = torch.cat((x, cls_emb), 1)
for layer in self.layers:
x = layer(x)
x = self.decoder_norm(x)
patches = self.patch_proj(x[:, :-self.num_classes])
cls_seg_feat = self.classes_proj(x[:, -self.num_classes:])
patches = F.normalize(patches, dim=2, p=2)
cls_seg_feat = F.normalize(cls_seg_feat, dim=2, p=2)
masks = patches @ cls_seg_feat.transpose(1, 2)
masks = self.mask_norm(masks)
masks = masks.permute(0, 2, 1).contiguous().view(b, -1, h, w)
return masks
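
A pure-PyTorch sketch of the final mask computation above: L2-normalized patch embeddings are matched against L2-normalized class embeddings, so each mask logit is a cosine similarity. Sizes are illustrative.

import torch
import torch.nn.functional as F

b, n_patches, n_cls, d = 2, 16, 19, 32
patches = F.normalize(torch.randn(b, n_patches, d), dim=2, p=2)
cls_seg_feat = F.normalize(torch.randn(b, n_cls, d), dim=2, p=2)
masks = patches @ cls_seg_feat.transpose(1, 2)  # (2, 16, 19), in [-1, 1]
masks = masks.permute(0, 2, 1).reshape(b, n_cls, 4, 4)  # back to a 2D grid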

View File

@@ -0,0 +1,102 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmseg.registry import MODELS
from ..utils import resize
from .aspp_head import ASPPHead, ASPPModule
class DepthwiseSeparableASPPModule(ASPPModule):
"""Atrous Spatial Pyramid Pooling (ASPP) Module with depthwise separable
conv."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
for i, dilation in enumerate(self.dilations):
if dilation > 1:
self[i] = DepthwiseSeparableConvModule(
self.in_channels,
self.channels,
3,
dilation=dilation,
padding=dilation,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
@MODELS.register_module()
class DepthwiseSeparableASPPHead(ASPPHead):
"""Encoder-Decoder with Atrous Separable Convolution for Semantic Image
Segmentation.
This head is the implementation of `DeepLabV3+
<https://arxiv.org/abs/1802.02611>`_.
Args:
c1_in_channels (int): The input channels of the c1 decoder. If it is
0, no c1 decoder will be used.
c1_channels (int): The intermediate channels of c1 decoder.
"""
def __init__(self, c1_in_channels, c1_channels, **kwargs):
super().__init__(**kwargs)
assert c1_in_channels >= 0
self.aspp_modules = DepthwiseSeparableASPPModule(
dilations=self.dilations,
in_channels=self.in_channels,
channels=self.channels,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
if c1_in_channels > 0:
self.c1_bottleneck = ConvModule(
c1_in_channels,
c1_channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
else:
self.c1_bottleneck = None
self.sep_bottleneck = nn.Sequential(
DepthwiseSeparableConvModule(
self.channels + c1_channels,
self.channels,
3,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
DepthwiseSeparableConvModule(
self.channels,
self.channels,
3,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
def forward(self, inputs):
"""Forward function."""
x = self._transform_inputs(inputs)
aspp_outs = [
resize(
self.image_pool(x),
size=x.size()[2:],
mode='bilinear',
align_corners=self.align_corners)
]
aspp_outs.extend(self.aspp_modules(x))
aspp_outs = torch.cat(aspp_outs, dim=1)
output = self.bottleneck(aspp_outs)
if self.c1_bottleneck is not None:
c1_output = self.c1_bottleneck(inputs[0])
output = resize(
input=output,
size=c1_output.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
output = torch.cat([output, c1_output], dim=1)
output = self.sep_bottleneck(output)
output = self.cls_seg(output)
return output
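
A minimal sketch, assuming mmseg and mmcv are importable: in the module above, the dilated branches become depthwise-separable convs, while the dilation-1 branch stays a plain conv. Channel counts are illustrative.

import torch
from mmcv.cnn import DepthwiseSeparableConvModule
from mmseg.models.decode_heads.sep_aspp_head import \
    DepthwiseSeparableASPPModule

aspp = DepthwiseSeparableASPPModule(
    dilations=(1, 12, 24, 36),
    in_channels=64,
    channels=32,
    conv_cfg=None,
    norm_cfg=None,
    act_cfg=dict(type='ReLU'))
assert isinstance(aspp[1], DepthwiseSeparableConvModule)
outs = aspp(torch.randn(2, 64, 16, 16))  # four (2, 32, 16, 16) tensors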

View File

@@ -0,0 +1,60 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import DepthwiseSeparableConvModule
from mmseg.registry import MODELS
from .fcn_head import FCNHead
@MODELS.register_module()
class DepthwiseSeparableFCNHead(FCNHead):
"""Depthwise-Separable Fully Convolutional Network for Semantic
Segmentation.
This head is implemented according to `Fast-SCNN: Fast Semantic
Segmentation Network <https://arxiv.org/abs/1902.04502>`_.
Args:
in_channels (int): Number of output channels of FFM.
channels (int): Number of middle-stage channels in the decode head.
concat_input (bool): Whether to concatenate original decode input into
the result of several consecutive convolution layers.
Default: True.
num_classes (int): Used to determine the dimension of
final prediction tensor.
in_index (int): Corresponds with 'out_indices' in FastSCNN backbone.
norm_cfg (dict | None): Config of norm layers.
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
loss_decode(dict): Config of loss type and some
relevant additional options.
dw_act_cfg (dict): Activation config of the depthwise ConvModule. If it
is 'default', it will be the same as `act_cfg`. Default: None.
"""
def __init__(self, dw_act_cfg=None, **kwargs):
super().__init__(**kwargs)
self.convs[0] = DepthwiseSeparableConvModule(
self.in_channels,
self.channels,
kernel_size=self.kernel_size,
padding=self.kernel_size // 2,
norm_cfg=self.norm_cfg,
dw_act_cfg=dw_act_cfg)
for i in range(1, self.num_convs):
self.convs[i] = DepthwiseSeparableConvModule(
self.channels,
self.channels,
kernel_size=self.kernel_size,
padding=self.kernel_size // 2,
norm_cfg=self.norm_cfg,
dw_act_cfg=dw_act_cfg)
if self.concat_input:
self.conv_cat = DepthwiseSeparableConvModule(
self.in_channels + self.channels,
self.channels,
kernel_size=self.kernel_size,
padding=self.kernel_size // 2,
norm_cfg=self.norm_cfg,
dw_act_cfg=dw_act_cfg)

View File

@@ -0,0 +1,62 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.registry import MODELS
from ..utils import Upsample
from .decode_head import BaseDecodeHead
@MODELS.register_module()
class SETRMLAHead(BaseDecodeHead):
"""Multi level feature aggretation head of SETR.
MLA head of `SETR <https://arxiv.org/pdf/2012.15840.pdf>`_.
Args:
mlahead_channels (int): Channels of conv-conv-4x of multi-level feature
aggregation. Default: 128.
up_scale (int): The scale factor of interpolate. Default:4.
"""
def __init__(self, mla_channels=128, up_scale=4, **kwargs):
super().__init__(input_transform='multiple_select', **kwargs)
self.mla_channels = mla_channels
num_inputs = len(self.in_channels)
# Refer to self.cls_seg settings of BaseDecodeHead
assert self.channels == num_inputs * mla_channels
self.up_convs = nn.ModuleList()
for i in range(num_inputs):
self.up_convs.append(
nn.Sequential(
ConvModule(
in_channels=self.in_channels[i],
out_channels=mla_channels,
kernel_size=3,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
ConvModule(
in_channels=mla_channels,
out_channels=mla_channels,
kernel_size=3,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
Upsample(
scale_factor=up_scale,
mode='bilinear',
align_corners=self.align_corners)))
def forward(self, inputs):
inputs = self._transform_inputs(inputs)
outs = []
for x, up_conv in zip(inputs, self.up_convs):
outs.append(up_conv(x))
out = torch.cat(outs, dim=1)
out = self.cls_seg(out)
return out
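
A minimal usage sketch, assuming mmseg is importable. Note the constraint checked in __init__: channels must equal len(in_channels) * mla_channels. The sizes below are illustrative.

import torch
from mmseg.models.decode_heads.setr_mla_head import SETRMLAHead

head = SETRMLAHead(
    in_channels=[256, 256, 256, 256],
    in_index=[0, 1, 2, 3],
    channels=4 * 128,  # num_inputs * mla_channels
    mla_channels=128,
    num_classes=19,
    norm_cfg=dict(type='BN'))
feats = [torch.randn(1, 256, 16, 16) for _ in range(4)]
seg_logits = head(feats)  # (1, 19, 64, 64) after the 4x upsampling branches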

View File

@@ -0,0 +1,81 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer
from mmseg.registry import MODELS
from ..utils import Upsample
from .decode_head import BaseDecodeHead
@MODELS.register_module()
class SETRUPHead(BaseDecodeHead):
"""Naive upsampling head and Progressive upsampling head of SETR.
Naive or PUP head of `SETR <https://arxiv.org/pdf/2012.15840.pdf>`_.
Args:
norm_layer (dict): Config dict for input normalization.
Default: norm_layer=dict(type='LN', eps=1e-6, requires_grad=True).
num_convs (int): Number of decoder convolutions. Default: 1.
up_scale (int): The scale factor of interpolate. Default: 4.
kernel_size (int): The kernel size of convolution when decoding
feature information from backbone. Default: 3.
        init_cfg (dict | list[dict] | None): Initialization config dict.
            Default: [dict(type='Constant', val=1.0, bias=0,
            layer='LayerNorm'), dict(type='Normal', std=0.01,
            override=dict(name='conv_seg'))].
"""
def __init__(self,
norm_layer=dict(type='LN', eps=1e-6, requires_grad=True),
num_convs=1,
up_scale=4,
kernel_size=3,
init_cfg=[
dict(type='Constant', val=1.0, bias=0, layer='LayerNorm'),
dict(
type='Normal',
std=0.01,
override=dict(name='conv_seg'))
],
**kwargs):
assert kernel_size in [1, 3], 'kernel_size must be 1 or 3.'
super().__init__(init_cfg=init_cfg, **kwargs)
assert isinstance(self.in_channels, int)
_, self.norm = build_norm_layer(norm_layer, self.in_channels)
self.up_convs = nn.ModuleList()
in_channels = self.in_channels
out_channels = self.channels
for _ in range(num_convs):
self.up_convs.append(
nn.Sequential(
ConvModule(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=1,
padding=int(kernel_size - 1) // 2,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
Upsample(
scale_factor=up_scale,
mode='bilinear',
align_corners=self.align_corners)))
in_channels = out_channels
def forward(self, x):
x = self._transform_inputs(x)
n, c, h, w = x.shape
x = x.reshape(n, c, h * w).transpose(2, 1).contiguous()
x = self.norm(x)
x = x.transpose(1, 2).reshape(n, c, h, w).contiguous()
for up_conv in self.up_convs:
x = up_conv(x)
out = self.cls_seg(x)
return out
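

# A minimal usage sketch, assuming mmseg is installed for the registry.
# SETRUPHead expects a single ViT-style feature map; each of the
# `num_convs` stages is a conv followed by a `up_scale`x bilinear upsample.
# The shapes below are hypothetical.
if __name__ == '__main__':
    import torch
    head = SETRUPHead(
        in_channels=1024,
        channels=256,
        num_classes=19,
        num_convs=1,
        up_scale=4,
        norm_cfg=dict(type='BN'),
        in_index=-1)
    feats = [torch.randn(1, 1024, 32, 32)]
    seg_logits = head(feats)  # (1, 19, 128, 128)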

View File

@@ -0,0 +1,97 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
from mmengine.structures import PixelData
from torch import Tensor
from mmseg.registry import MODELS
from mmseg.structures import SegDataSample
from mmseg.utils import SampleList
from .fcn_head import FCNHead
@MODELS.register_module()
class STDCHead(FCNHead):
"""This head is the implementation of `Rethinking BiSeNet For Real-time
Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.
Args:
boundary_threshold (float): The threshold of calculating boundary.
Default: 0.1.
"""
def __init__(self, boundary_threshold=0.1, **kwargs):
super().__init__(**kwargs)
self.boundary_threshold = boundary_threshold
        # Register the Laplacian kernel as a buffer so it is moved to the
        # same device as `seg_label`.
self.register_buffer(
'laplacian_kernel',
torch.tensor([-1, -1, -1, -1, 8, -1, -1, -1, -1],
dtype=torch.float32,
requires_grad=False).reshape((1, 1, 3, 3)))
self.fusion_kernel = torch.nn.Parameter(
torch.tensor([[6. / 10], [3. / 10], [1. / 10]],
dtype=torch.float32).reshape(1, 3, 1, 1),
requires_grad=False)
def loss_by_feat(self, seg_logits: Tensor,
batch_data_samples: SampleList) -> dict:
"""Compute Detail Aggregation Loss."""
        # Note: The paper describes `fusion_kernel` as a trainable 1x1 conv
        # parameter. However, it is kept constant in the original repo and
        # other codebases, since the threshold operation would detach it
        # from the computation graph anyway.
seg_label = self._stack_batch_gt(batch_data_samples).to(
self.laplacian_kernel)
boundary_targets = F.conv2d(
seg_label, self.laplacian_kernel, padding=1)
boundary_targets = boundary_targets.clamp(min=0)
boundary_targets[boundary_targets > self.boundary_threshold] = 1
boundary_targets[boundary_targets <= self.boundary_threshold] = 0
boundary_targets_x2 = F.conv2d(
seg_label, self.laplacian_kernel, stride=2, padding=1)
boundary_targets_x2 = boundary_targets_x2.clamp(min=0)
boundary_targets_x4 = F.conv2d(
seg_label, self.laplacian_kernel, stride=4, padding=1)
boundary_targets_x4 = boundary_targets_x4.clamp(min=0)
boundary_targets_x4_up = F.interpolate(
boundary_targets_x4, boundary_targets.shape[2:], mode='nearest')
boundary_targets_x2_up = F.interpolate(
boundary_targets_x2, boundary_targets.shape[2:], mode='nearest')
boundary_targets_x2_up[
boundary_targets_x2_up > self.boundary_threshold] = 1
boundary_targets_x2_up[
boundary_targets_x2_up <= self.boundary_threshold] = 0
boundary_targets_x4_up[
boundary_targets_x4_up > self.boundary_threshold] = 1
boundary_targets_x4_up[
boundary_targets_x4_up <= self.boundary_threshold] = 0
boundary_targets_pyramids = torch.stack(
(boundary_targets, boundary_targets_x2_up, boundary_targets_x4_up),
dim=1)
boundary_targets_pyramids = boundary_targets_pyramids.squeeze(2)
        boundary_targets_pyramid = F.conv2d(boundary_targets_pyramids,
                                            self.fusion_kernel)
        boundary_targets_pyramid[
            boundary_targets_pyramid > self.boundary_threshold] = 1
        boundary_targets_pyramid[
            boundary_targets_pyramid <= self.boundary_threshold] = 0
        seg_labels = boundary_targets_pyramid.long()
batch_sample_list = []
for label in seg_labels:
seg_data_sample = SegDataSample()
seg_data_sample.gt_sem_seg = PixelData(data=label)
batch_sample_list.append(seg_data_sample)
loss = super().loss_by_feat(seg_logits, batch_sample_list)
return loss
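

# A standalone sketch of the detail-target computation above, without the
# mmseg data structures. The binary mask and its shape are hypothetical.
if __name__ == '__main__':
    seg_label = (torch.rand(2, 1, 64, 64) > 0.5).float()
    laplacian = torch.tensor([-1., -1., -1., -1., 8., -1., -1., -1., -1.],
                             dtype=torch.float32).reshape(1, 1, 3, 3)
    edges = F.conv2d(seg_label, laplacian, padding=1).clamp(min=0)
    boundary = (edges > 0.1).float()  # same role as `boundary_threshold`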

View File

@@ -0,0 +1,139 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead
from .psp_head import PPM
@MODELS.register_module()
class UPerHead(BaseDecodeHead):
"""Unified Perceptual Parsing for Scene Understanding.
This head is the implementation of `UPerNet
<https://arxiv.org/abs/1807.10221>`_.
Args:
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
Module applied on the last feature. Default: (1, 2, 3, 6).
"""
def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
super().__init__(input_transform='multiple_select', **kwargs)
# PSP Module
self.psp_modules = PPM(
pool_scales,
self.in_channels[-1],
self.channels,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
align_corners=self.align_corners)
self.bottleneck = ConvModule(
self.in_channels[-1] + len(pool_scales) * self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
# FPN Module
self.lateral_convs = nn.ModuleList()
self.fpn_convs = nn.ModuleList()
for in_channels in self.in_channels[:-1]: # skip the top layer
l_conv = ConvModule(
in_channels,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
inplace=False)
fpn_conv = ConvModule(
self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
inplace=False)
self.lateral_convs.append(l_conv)
self.fpn_convs.append(fpn_conv)
self.fpn_bottleneck = ConvModule(
len(self.in_channels) * self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def psp_forward(self, inputs):
"""Forward function of PSP module."""
x = inputs[-1]
psp_outs = [x]
psp_outs.extend(self.psp_modules(x))
psp_outs = torch.cat(psp_outs, dim=1)
output = self.bottleneck(psp_outs)
return output
def _forward_feature(self, inputs):
"""Forward function for feature maps before classifying each pixel with
``self.cls_seg`` fc.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
feats (Tensor): A tensor of shape (batch_size, self.channels,
H, W) which is feature map for last layer of decoder head.
"""
inputs = self._transform_inputs(inputs)
# build laterals
laterals = [
lateral_conv(inputs[i])
for i, lateral_conv in enumerate(self.lateral_convs)
]
laterals.append(self.psp_forward(inputs))
# build top-down path
used_backbone_levels = len(laterals)
for i in range(used_backbone_levels - 1, 0, -1):
prev_shape = laterals[i - 1].shape[2:]
laterals[i - 1] = laterals[i - 1] + resize(
laterals[i],
size=prev_shape,
mode='bilinear',
align_corners=self.align_corners)
# build outputs
fpn_outs = [
self.fpn_convs[i](laterals[i])
for i in range(used_backbone_levels - 1)
]
# append psp feature
fpn_outs.append(laterals[-1])
for i in range(used_backbone_levels - 1, 0, -1):
fpn_outs[i] = resize(
fpn_outs[i],
size=fpn_outs[0].shape[2:],
mode='bilinear',
align_corners=self.align_corners)
fpn_outs = torch.cat(fpn_outs, dim=1)
feats = self.fpn_bottleneck(fpn_outs)
return feats
def forward(self, inputs):
"""Forward function."""
output = self._forward_feature(inputs)
output = self.cls_seg(output)
return output
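

# A minimal usage sketch, assuming mmseg is installed for the registry. The
# channel counts mimic a hypothetical Swin-T backbone at strides 4/8/16/32;
# the head fuses all levels and predicts at the finest (stride-4) scale.
if __name__ == '__main__':
    in_channels = [96, 192, 384, 768]
    head = UPerHead(
        in_channels=in_channels,
        channels=512,
        num_classes=150,
        in_index=[0, 1, 2, 3],
        pool_scales=(1, 2, 3, 6))
    feats = [
        torch.randn(1, c, 128 // 2**i, 128 // 2**i)
        for i, c in enumerate(in_channels)
    ]
    seg_logits = head(feats)  # (1, 150, 128, 128)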

View File

@@ -0,0 +1,253 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer, build_norm_layer, build_upsample_layer
from mmengine.model import BaseModule
from torch import Tensor
from mmseg.registry import MODELS
from mmseg.utils import SampleList
from ..utils import resize
from .decode_head import BaseDecodeHead
class VPDDepthDecoder(BaseModule):
"""VPD Depth Decoder class.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
num_deconv_layers (int): Number of deconvolution layers.
num_deconv_filters (List[int]): List of output channels for
deconvolution layers.
init_cfg (Optional[Union[Dict, List[Dict]]], optional): Configuration
for weight initialization. Defaults to Normal for Conv2d and
ConvTranspose2d layers.
"""
def __init__(self,
in_channels: int,
out_channels: int,
num_deconv_layers: int,
num_deconv_filters: List[int],
init_cfg: Optional[Union[Dict, List[Dict]]] = dict(
type='Normal',
std=0.001,
layer=['Conv2d', 'ConvTranspose2d'])):
super().__init__(init_cfg=init_cfg)
self.in_channels = in_channels
self.deconv_layers = self._make_deconv_layer(
num_deconv_layers,
num_deconv_filters,
)
conv_layers = []
conv_layers.append(
build_conv_layer(
dict(type='Conv2d'),
in_channels=num_deconv_filters[-1],
out_channels=out_channels,
kernel_size=3,
stride=1,
padding=1))
conv_layers.append(build_norm_layer(dict(type='BN'), out_channels)[1])
conv_layers.append(nn.ReLU(inplace=True))
self.conv_layers = nn.Sequential(*conv_layers)
self.up_sample = nn.Upsample(
scale_factor=2, mode='bilinear', align_corners=False)
def forward(self, x):
"""Forward pass through the decoder network."""
out = self.deconv_layers(x)
out = self.conv_layers(out)
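        # apply the 2x bilinear upsampling twice for an overall 4x upscale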
out = self.up_sample(out)
out = self.up_sample(out)
return out
def _make_deconv_layer(self, num_layers, num_deconv_filters):
"""Make deconv layers."""
layers = []
in_channels = self.in_channels
for i in range(num_layers):
num_channels = num_deconv_filters[i]
layers.append(
build_upsample_layer(
dict(type='deconv'),
in_channels=in_channels,
out_channels=num_channels,
kernel_size=2,
stride=2,
padding=0,
output_padding=0,
bias=False))
layers.append(nn.BatchNorm2d(num_channels))
layers.append(nn.ReLU(inplace=True))
in_channels = num_channels
return nn.Sequential(*layers)
@MODELS.register_module()
class VPDDepthHead(BaseDecodeHead):
"""Depth Prediction Head for VPD.
.. _`VPD`: https://arxiv.org/abs/2303.02153
Args:
max_depth (float): Maximum depth value. Defaults to 10.0.
in_channels (Sequence[int]): Number of input channels for each
convolutional layer.
embed_dim (int): Dimension of embedding. Defaults to 192.
feature_dim (int): Dimension of aggregated feature. Defaults to 1536.
num_deconv_layers (int): Number of deconvolution layers in the
decoder. Defaults to 3.
num_deconv_filters (Sequence[int]): Number of filters for each deconv
layer. Defaults to (32, 32, 32).
fmap_border (Union[int, Sequence[int]]): Feature map border for
cropping. Defaults to 0.
align_corners (bool): Flag for align_corners in interpolation.
Defaults to False.
loss_decode (dict): Configurations for the loss function. Defaults to
dict(type='SiLogLoss').
init_cfg (dict): Initialization configurations. Defaults to
dict(type='TruncNormal', std=0.02, layer=['Conv2d', 'Linear']).
"""
num_classes = 1
out_channels = 1
input_transform = None
def __init__(
self,
max_depth: float = 10.0,
in_channels: Sequence[int] = [320, 640, 1280, 1280],
embed_dim: int = 192,
feature_dim: int = 1536,
num_deconv_layers: int = 3,
num_deconv_filters: Sequence[int] = (32, 32, 32),
fmap_border: Union[int, Sequence[int]] = 0,
align_corners: bool = False,
loss_decode: dict = dict(type='SiLogLoss'),
init_cfg=dict(
type='TruncNormal', std=0.02, layer=['Conv2d', 'Linear']),
):
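        # Deliberately skip BaseDecodeHead.__init__ and initialize
        # BaseModule directly: this head builds its own layers and loss,
        # and fixes num_classes/out_channels to 1 for depth regression.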
super(BaseDecodeHead, self).__init__(init_cfg=init_cfg)
# initialize parameters
self.in_channels = in_channels
self.max_depth = max_depth
self.align_corners = align_corners
# feature map border
if isinstance(fmap_border, int):
fmap_border = (fmap_border, fmap_border)
self.fmap_border = fmap_border
# define network layers
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels[0], in_channels[0], 3, stride=2, padding=1),
nn.GroupNorm(16, in_channels[0]),
nn.ReLU(),
nn.Conv2d(in_channels[0], in_channels[0], 3, stride=2, padding=1),
)
self.conv2 = nn.Conv2d(
in_channels[1], in_channels[1], 3, stride=2, padding=1)
self.conv_aggregation = nn.Sequential(
nn.Conv2d(sum(in_channels), feature_dim, 1),
nn.GroupNorm(16, feature_dim),
nn.ReLU(),
)
self.decoder = VPDDepthDecoder(
in_channels=embed_dim * 8,
out_channels=embed_dim,
num_deconv_layers=num_deconv_layers,
num_deconv_filters=num_deconv_filters)
self.depth_pred_layer = nn.Sequential(
nn.Conv2d(
embed_dim, embed_dim, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=False),
nn.Conv2d(embed_dim, 1, kernel_size=3, stride=1, padding=1))
# build loss
if isinstance(loss_decode, dict):
self.loss_decode = MODELS.build(loss_decode)
elif isinstance(loss_decode, (list, tuple)):
self.loss_decode = nn.ModuleList()
for loss in loss_decode:
self.loss_decode.append(MODELS.build(loss))
else:
            raise TypeError('loss_decode must be a dict or a sequence of '
                            f'dicts, but got {type(loss_decode)}')
def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor:
gt_depth_maps = [
data_sample.gt_depth_map.data for data_sample in batch_data_samples
]
return torch.stack(gt_depth_maps, dim=0)
def forward(self, x):
x = [
x[0], x[1],
torch.cat([x[2], F.interpolate(x[3], scale_factor=2)], dim=1)
]
x = torch.cat([self.conv1(x[0]), self.conv2(x[1]), x[2]], dim=1)
x = self.conv_aggregation(x)
x = x[:, :, :x.size(2) - self.fmap_border[0], :x.size(3) -
self.fmap_border[1]].contiguous()
x = self.decoder(x)
out = self.depth_pred_layer(x)
depth = torch.sigmoid(out) * self.max_depth
return depth
def loss_by_feat(self, pred_depth_map: Tensor,
batch_data_samples: SampleList) -> dict:
"""Compute depth estimation loss.
Args:
pred_depth_map (Tensor): The output from decode head forward
function.
batch_data_samples (List[:obj:`SegDataSample`]): The seg
data samples. It usually includes information such
                as `metainfo` and `gt_depth_map`.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
gt_depth_map = self._stack_batch_gt(batch_data_samples)
loss = dict()
pred_depth_map = resize(
input=pred_depth_map,
size=gt_depth_map.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
if not isinstance(self.loss_decode, nn.ModuleList):
losses_decode = [self.loss_decode]
else:
losses_decode = self.loss_decode
for loss_decode in losses_decode:
if loss_decode.loss_name not in loss:
loss[loss_decode.loss_name] = loss_decode(
pred_depth_map, gt_depth_map)
else:
loss[loss_decode.loss_name] += loss_decode(
pred_depth_map, gt_depth_map)
return loss
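

# A minimal usage sketch, assuming mmseg is installed so the default
# SiLogLoss can be built from the registry. The four feature shapes below
# are hypothetical; with the default settings the head regresses a depth
# map at 8x the resolution of the first (finest) input.
if __name__ == '__main__':
    head = VPDDepthHead(max_depth=10.0)
    feats = [
        torch.randn(1, 320, 64, 64),
        torch.randn(1, 640, 32, 32),
        torch.randn(1, 1280, 16, 16),
        torch.randn(1, 1280, 8, 8),
    ]
    depth = head(feats)  # (1, 1, 512, 512), values in [0, max_depth]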