init
48  finetune/mmseg/models/decode_heads/__init__.py  Normal file
@@ -0,0 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ann_head import ANNHead
from .apc_head import APCHead
from .aspp_head import ASPPHead
from .cc_head import CCHead
from .da_head import DAHead
from .ddr_head import DDRHead
from .dm_head import DMHead
from .dnl_head import DNLHead
from .dpt_head import DPTHead
from .ema_head import EMAHead
from .enc_head import EncHead
from .fcn_head import FCNHead
from .fpn_head import FPNHead
from .gc_head import GCHead
from .ham_head import LightHamHead
from .isa_head import ISAHead
from .knet_head import IterativeDecodeHead, KernelUpdateHead, KernelUpdator
from .lraspp_head import LRASPPHead
from .mask2former_head import Mask2FormerHead
from .maskformer_head import MaskFormerHead
from .nl_head import NLHead
from .ocr_head import OCRHead
from .pid_head import PIDHead
from .point_head import PointHead
from .psa_head import PSAHead
from .psp_head import PSPHead
from .san_head import SideAdapterCLIPHead
from .segformer_head import SegformerHead
from .segmenter_mask_head import SegmenterMaskTransformerHead
from .sep_aspp_head import DepthwiseSeparableASPPHead
from .sep_fcn_head import DepthwiseSeparableFCNHead
from .setr_mla_head import SETRMLAHead
from .setr_up_head import SETRUPHead
from .stdc_head import STDCHead
from .uper_head import UPerHead
from .vpd_depth_head import VPDDepthHead

__all__ = [
    'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
    'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
    'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
    'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
    'SETRMLAHead', 'DPTHead', 'SegmenterMaskTransformerHead',
    'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead',
    'KernelUpdateHead', 'KernelUpdator', 'MaskFormerHead', 'Mask2FormerHead',
    'LightHamHead', 'PIDHead', 'DDRHead', 'VPDDepthHead', 'SideAdapterCLIPHead'
]
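
Usage sketch (not part of the commit): once the package is imported, any of these heads can be built through the registry. The import path and the channel/class counts below are illustrative assumptions.

# Build a decode head from a config dict via the MODELS registry.
from mmseg.registry import MODELS
import finetune.mmseg.models  # assumed package root; importing triggers registration

head_cfg = dict(
    type='FCNHead',
    in_channels=2048,  # assumed backbone output channels
    channels=512,
    num_classes=19)    # e.g. Cityscapes
head = MODELS.build(head_cfg)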
245  finetune/mmseg/models/decode_heads/ann_head.py  Normal file
@@ -0,0 +1,245 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.registry import MODELS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .decode_head import BaseDecodeHead


class PPMConcat(nn.ModuleList):
    """Pyramid Pooling Module that only concats the features of each layer.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
    """

    def __init__(self, pool_scales=(1, 3, 6, 8)):
        super().__init__(
            [nn.AdaptiveAvgPool2d(pool_scale) for pool_scale in pool_scales])

    def forward(self, feats):
        """Forward function."""
        ppm_outs = []
        for ppm in self:
            ppm_out = ppm(feats)
            ppm_outs.append(ppm_out.view(*feats.shape[:2], -1))
        concat_outs = torch.cat(ppm_outs, dim=2)
        return concat_outs


class SelfAttentionBlock(_SelfAttentionBlock):
    """Make an ANN-specific SelfAttentionBlock.

    Args:
        low_in_channels (int): Input channels of lower level feature,
            which is the key feature for self-attention.
        high_in_channels (int): Input channels of higher level feature,
            which is the query feature for self-attention.
        channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
        share_key_query (bool): Whether to share projection weights between
            the key and query projections.
        query_scale (int): The scale of query feature map.
        key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module of key feature.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict|None): Config of activation layers.
    """

    def __init__(self, low_in_channels, high_in_channels, channels,
                 out_channels, share_key_query, query_scale, key_pool_scales,
                 conv_cfg, norm_cfg, act_cfg):
        key_psp = PPMConcat(key_pool_scales)
        if query_scale > 1:
            query_downsample = nn.MaxPool2d(kernel_size=query_scale)
        else:
            query_downsample = None
        super().__init__(
            key_in_channels=low_in_channels,
            query_in_channels=high_in_channels,
            channels=channels,
            out_channels=out_channels,
            share_key_query=share_key_query,
            query_downsample=query_downsample,
            key_downsample=key_psp,
            key_query_num_convs=1,
            key_query_norm=True,
            value_out_num_convs=1,
            value_out_norm=False,
            matmul_norm=True,
            with_out=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)


class AFNB(nn.Module):
    """Asymmetric Fusion Non-local Block (AFNB).

    Args:
        low_in_channels (int): Input channels of lower level feature,
            which is the key feature for self-attention.
        high_in_channels (int): Input channels of higher level feature,
            which is the query feature for self-attention.
        channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
        query_scales (tuple[int]): The scales of query feature map.
            Default: (1,)
        key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module of key feature.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict|None): Config of activation layers.
    """

    def __init__(self, low_in_channels, high_in_channels, channels,
                 out_channels, query_scales, key_pool_scales, conv_cfg,
                 norm_cfg, act_cfg):
        super().__init__()
        self.stages = nn.ModuleList()
        for query_scale in query_scales:
            self.stages.append(
                SelfAttentionBlock(
                    low_in_channels=low_in_channels,
                    high_in_channels=high_in_channels,
                    channels=channels,
                    out_channels=out_channels,
                    share_key_query=False,
                    query_scale=query_scale,
                    key_pool_scales=key_pool_scales,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
        self.bottleneck = ConvModule(
            out_channels + high_in_channels,
            out_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

    def forward(self, low_feats, high_feats):
        """Forward function."""
        priors = [stage(high_feats, low_feats) for stage in self.stages]
        context = torch.stack(priors, dim=0).sum(dim=0)
        output = self.bottleneck(torch.cat([context, high_feats], 1))
        return output


class APNB(nn.Module):
    """Asymmetric Pyramid Non-local Block (APNB).

    Args:
        in_channels (int): Input channels of key/query feature,
            which is the key feature for self-attention.
        channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
        query_scales (tuple[int]): The scales of query feature map.
            Default: (1,)
        key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module of key feature.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict|None): Config of activation layers.
    """

    def __init__(self, in_channels, channels, out_channels, query_scales,
                 key_pool_scales, conv_cfg, norm_cfg, act_cfg):
        super().__init__()
        self.stages = nn.ModuleList()
        for query_scale in query_scales:
            self.stages.append(
                SelfAttentionBlock(
                    low_in_channels=in_channels,
                    high_in_channels=in_channels,
                    channels=channels,
                    out_channels=out_channels,
                    share_key_query=True,
                    query_scale=query_scale,
                    key_pool_scales=key_pool_scales,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
        self.bottleneck = ConvModule(
            2 * in_channels,
            out_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    def forward(self, feats):
        """Forward function."""
        priors = [stage(feats, feats) for stage in self.stages]
        context = torch.stack(priors, dim=0).sum(dim=0)
        output = self.bottleneck(torch.cat([context, feats], 1))
        return output


@MODELS.register_module()
class ANNHead(BaseDecodeHead):
    """Asymmetric Non-local Neural Networks for Semantic Segmentation.

    This head is the implementation of `ANNNet
    <https://arxiv.org/abs/1908.07678>`_.

    Args:
        project_channels (int): Projection channels for Nonlocal.
        query_scales (tuple[int]): The scales of query feature map.
            Default: (1,)
        key_pool_scales (tuple[int]): The pooling scales of key feature map.
            Default: (1, 3, 6, 8).
    """

    def __init__(self,
                 project_channels,
                 query_scales=(1, ),
                 key_pool_scales=(1, 3, 6, 8),
                 **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)
        assert len(self.in_channels) == 2
        low_in_channels, high_in_channels = self.in_channels
        self.project_channels = project_channels
        self.fusion = AFNB(
            low_in_channels=low_in_channels,
            high_in_channels=high_in_channels,
            out_channels=high_in_channels,
            channels=project_channels,
            query_scales=query_scales,
            key_pool_scales=key_pool_scales,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.bottleneck = ConvModule(
            high_in_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.context = APNB(
            in_channels=self.channels,
            out_channels=self.channels,
            channels=project_channels,
            query_scales=query_scales,
            key_pool_scales=key_pool_scales,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        low_feats, high_feats = self._transform_inputs(inputs)
        output = self.fusion(low_feats, high_feats)
        output = self.dropout(output)
        output = self.bottleneck(output)
        output = self.context(output)
        output = self.cls_seg(output)

        return output
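
Usage sketch, assuming two ResNet-style feature levels selected with in_index; all shapes and channel counts below are illustrative assumptions.

import torch

head = ANNHead(
    in_channels=[1024, 2048],  # assumed low/high level channels
    in_index=[2, 3],
    channels=512,
    project_channels=256,
    num_classes=19)
feats = [torch.rand(1, 256, 128, 128), torch.rand(1, 512, 64, 64),
         torch.rand(1, 1024, 32, 32), torch.rand(1, 2048, 16, 16)]
seg_logits = head(feats)  # AFNB fuses the two levels, APNB refines: (1, 19, 16, 16)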
159  finetune/mmseg/models/decode_heads/apc_head.py  Normal file
@@ -0,0 +1,159 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead


class ACM(nn.Module):
    """Adaptive Context Module used in APCNet.

    Args:
        pool_scale (int): Pooling scale used in Adaptive Context
            Module to extract region features.
        fusion (bool): Add one conv to fuse residual feature.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict | None): Config of conv layers.
        norm_cfg (dict | None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
    """

    def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg,
                 norm_cfg, act_cfg):
        super().__init__()
        self.pool_scale = pool_scale
        self.fusion = fusion
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.pooled_redu_conv = ConvModule(
            self.in_channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.input_redu_conv = ConvModule(
            self.in_channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.global_info = ConvModule(
            self.channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.gla = nn.Conv2d(self.channels, self.pool_scale**2, 1, 1, 0)

        self.residual_conv = ConvModule(
            self.channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        if self.fusion:
            self.fusion_conv = ConvModule(
                self.channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def forward(self, x):
        """Forward function."""
        pooled_x = F.adaptive_avg_pool2d(x, self.pool_scale)
        # [batch_size, channels, h, w]
        x = self.input_redu_conv(x)
        # [batch_size, channels, pool_scale, pool_scale]
        pooled_x = self.pooled_redu_conv(pooled_x)
        batch_size = x.size(0)
        # [batch_size, pool_scale * pool_scale, channels]
        pooled_x = pooled_x.view(batch_size, self.channels,
                                 -1).permute(0, 2, 1).contiguous()
        # [batch_size, h * w, pool_scale * pool_scale]
        affinity_matrix = self.gla(x + resize(
            self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:])
                                   ).permute(0, 2, 3, 1).reshape(
                                       batch_size, -1, self.pool_scale**2)
        affinity_matrix = F.sigmoid(affinity_matrix)
        # [batch_size, h * w, channels]
        z_out = torch.matmul(affinity_matrix, pooled_x)
        # [batch_size, channels, h * w]
        z_out = z_out.permute(0, 2, 1).contiguous()
        # [batch_size, channels, h, w]
        z_out = z_out.view(batch_size, self.channels, x.size(2), x.size(3))
        z_out = self.residual_conv(z_out)
        z_out = F.relu(z_out + x)
        if self.fusion:
            z_out = self.fusion_conv(z_out)

        return z_out


@MODELS.register_module()
class APCHead(BaseDecodeHead):
    """Adaptive Pyramid Context Network for Semantic Segmentation.

    This head is the implementation of
    `APCNet <https://openaccess.thecvf.com/content_CVPR_2019/papers/\
He_Adaptive_Pyramid_Context_Network_for_Semantic_Segmentation_\
CVPR_2019_paper.pdf>`_.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Adaptive Context
            Module. Default: (1, 2, 3, 6).
        fusion (bool): Add one conv to fuse residual feature.
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, **kwargs):
        super().__init__(**kwargs)
        assert isinstance(pool_scales, (list, tuple))
        self.pool_scales = pool_scales
        self.fusion = fusion
        acm_modules = []
        for pool_scale in self.pool_scales:
            acm_modules.append(
                ACM(pool_scale,
                    self.fusion,
                    self.in_channels,
                    self.channels,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
        self.acm_modules = nn.ModuleList(acm_modules)
        self.bottleneck = ConvModule(
            self.in_channels + len(pool_scales) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        acm_outs = [x]
        for acm_module in self.acm_modules:
            acm_outs.append(acm_module(x))
        acm_outs = torch.cat(acm_outs, dim=1)
        output = self.bottleneck(acm_outs)
        output = self.cls_seg(output)
        return output
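
Usage sketch with assumed sizes: each ACM pools the input to pool_scale x pool_scale regions, then redistributes the region features per pixel through the sigmoid affinity matrix computed by self.gla.

import torch

head = APCHead(
    in_channels=2048,  # assumed backbone channels
    channels=512,
    num_classes=19,
    pool_scales=(1, 2, 3, 6))
x = torch.rand(1, 2048, 32, 32)
seg_logits = head([x])  # default in_index=-1 picks the last input: (1, 19, 32, 32)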
122  finetune/mmseg/models/decode_heads/aspp_head.py  Normal file
@@ -0,0 +1,122 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead


class ASPPModule(nn.ModuleList):
    """Atrous Spatial Pyramid Pooling (ASPP) Module.

    Args:
        dilations (tuple[int]): Dilation rate of each layer.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
    """

    def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg,
                 act_cfg):
        super().__init__()
        self.dilations = dilations
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        for dilation in dilations:
            self.append(
                ConvModule(
                    self.in_channels,
                    self.channels,
                    1 if dilation == 1 else 3,
                    dilation=dilation,
                    padding=0 if dilation == 1 else dilation,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))

    def forward(self, x):
        """Forward function."""
        aspp_outs = []
        for aspp_module in self:
            aspp_outs.append(aspp_module(x))

        return aspp_outs


@MODELS.register_module()
class ASPPHead(BaseDecodeHead):
    """Rethinking Atrous Convolution for Semantic Image Segmentation.

    This head is the implementation of `DeepLabV3
    <https://arxiv.org/abs/1706.05587>`_.

    Args:
        dilations (tuple[int]): Dilation rates for ASPP module.
            Default: (1, 6, 12, 18).
    """

    def __init__(self, dilations=(1, 6, 12, 18), **kwargs):
        super().__init__(**kwargs)
        assert isinstance(dilations, (list, tuple))
        self.dilations = dilations
        self.image_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvModule(
                self.in_channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))
        self.aspp_modules = ASPPModule(
            dilations,
            self.in_channels,
            self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.bottleneck = ConvModule(
            (len(dilations) + 1) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def _forward_feature(self, inputs):
        """Forward function for feature maps before classifying each pixel
        with ``self.cls_seg`` fc.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W) which is the feature map of the last decoder layer.
        """
        x = self._transform_inputs(inputs)
        aspp_outs = [
            resize(
                self.image_pool(x),
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        ]
        aspp_outs.extend(self.aspp_modules(x))
        aspp_outs = torch.cat(aspp_outs, dim=1)
        feats = self.bottleneck(aspp_outs)
        return feats

    def forward(self, inputs):
        """Forward function."""
        output = self._forward_feature(inputs)
        output = self.cls_seg(output)
        return output
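
Usage sketch with assumed sizes: the bottleneck consumes (len(dilations) + 1) * channels features, the extra branch being the global image-pooling path.

import torch

head = ASPPHead(
    in_channels=2048,  # assumed
    channels=512,
    num_classes=19,
    dilations=(1, 12, 24, 36))  # e.g. an output-stride-8 variant
x = torch.rand(1, 2048, 64, 64)
seg_logits = head([x])  # (1, 19, 64, 64)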
62  finetune/mmseg/models/decode_heads/cascade_decode_head.py  Normal file
@@ -0,0 +1,62 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import List

from torch import Tensor

from mmseg.utils import ConfigType
from .decode_head import BaseDecodeHead


class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
    """Base class for cascade decode head used in
    :class:`CascadeEncoderDecoder`."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @abstractmethod
    def forward(self, inputs, prev_output):
        """Placeholder of forward function."""
        pass

    def loss(self, inputs: List[Tensor], prev_output: Tensor,
             batch_data_samples: List[dict], train_cfg: ConfigType) -> Tensor:
        """Forward function for training.

        Args:
            inputs (List[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of the previous decode head.
            batch_data_samples (List[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `metainfo` and `gt_sem_seg`.
            train_cfg (dict): The training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self.forward(inputs, prev_output)
        losses = self.loss_by_feat(seg_logits, batch_data_samples)

        return losses

    def predict(self, inputs: List[Tensor], prev_output: Tensor,
                batch_img_metas: List[dict], test_cfg: ConfigType):
        """Forward function for testing.

        Args:
            inputs (List[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of the previous decode head.
            batch_img_metas (dict): List of image info dicts where each dict
                may contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
                'ori_shape', and 'pad_shape'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
            test_cfg (dict): The testing config.

        Returns:
            Tensor: Output segmentation map.
        """
        seg_logits = self.forward(inputs, prev_output)

        return self.predict_by_feat(seg_logits, batch_img_metas)
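
A minimal subclass sketch: a cascade head only has to implement forward(inputs, prev_output); loss() and predict() are inherited from this base. The class below is hypothetical and ignores prev_output purely for brevity.

class IdentityCascadeHead(BaseCascadeDecodeHead):  # hypothetical example

    def forward(self, inputs, prev_output):
        x = self._transform_inputs(inputs)
        # a real head would fuse prev_output here; this sketch assumes
        # in_channels == channels so cls_seg can be applied directly
        return self.cls_seg(x)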
43  finetune/mmseg/models/decode_heads/cc_head.py  Normal file
@@ -0,0 +1,43 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmseg.registry import MODELS
from .fcn_head import FCNHead

try:
    from mmcv.ops import CrissCrossAttention
except ModuleNotFoundError:
    CrissCrossAttention = None


@MODELS.register_module()
class CCHead(FCNHead):
    """CCNet: Criss-Cross Attention for Semantic Segmentation.

    This head is the implementation of `CCNet
    <https://arxiv.org/abs/1811.11721>`_.

    Args:
        recurrence (int): Number of recurrences of the Criss-Cross Attention
            module. Default: 2.
    """

    def __init__(self, recurrence=2, **kwargs):
        if CrissCrossAttention is None:
            raise RuntimeError('Please install mmcv-full for '
                               'CrissCrossAttention ops')
        super().__init__(num_convs=2, **kwargs)
        self.recurrence = recurrence
        self.cca = CrissCrossAttention(self.channels)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        output = self.convs[0](x)
        for _ in range(self.recurrence):
            output = self.cca(output)
        output = self.convs[1](output)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output
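
Usage sketch with assumed sizes; this requires mmcv built with the CrissCrossAttention op, otherwise __init__ raises the RuntimeError above.

import torch

head = CCHead(
    in_channels=2048,  # assumed
    channels=512,
    num_classes=19,
    recurrence=2)  # attention applied twice between the two convs
x = torch.rand(1, 2048, 32, 32)
seg_logits = head([x])  # (1, 19, 32, 32)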
184  finetune/mmseg/models/decode_heads/da_head.py  Normal file
@@ -0,0 +1,184 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
import torch.nn.functional as F
from mmcv.cnn import ConvModule, Scale
from torch import Tensor, nn

from mmseg.registry import MODELS
from mmseg.utils import SampleList, add_prefix
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .decode_head import BaseDecodeHead


class PAM(_SelfAttentionBlock):
    """Position Attention Module (PAM).

    Args:
        in_channels (int): Input channels of key/query feature.
        channels (int): Output channels of key/query transform.
    """

    def __init__(self, in_channels, channels):
        super().__init__(
            key_in_channels=in_channels,
            query_in_channels=in_channels,
            channels=channels,
            out_channels=in_channels,
            share_key_query=False,
            query_downsample=None,
            key_downsample=None,
            key_query_num_convs=1,
            key_query_norm=False,
            value_out_num_convs=1,
            value_out_norm=False,
            matmul_norm=False,
            with_out=False,
            conv_cfg=None,
            norm_cfg=None,
            act_cfg=None)

        self.gamma = Scale(0)

    def forward(self, x):
        """Forward function."""
        out = super().forward(x, x)

        out = self.gamma(out) + x
        return out


class CAM(nn.Module):
    """Channel Attention Module (CAM)."""

    def __init__(self):
        super().__init__()
        self.gamma = Scale(0)

    def forward(self, x):
        """Forward function."""
        batch_size, channels, height, width = x.size()
        proj_query = x.view(batch_size, channels, -1)
        proj_key = x.view(batch_size, channels, -1).permute(0, 2, 1)
        energy = torch.bmm(proj_query, proj_key)
        energy_new = torch.max(
            energy, -1, keepdim=True)[0].expand_as(energy) - energy
        attention = F.softmax(energy_new, dim=-1)
        proj_value = x.view(batch_size, channels, -1)

        out = torch.bmm(attention, proj_value)
        out = out.view(batch_size, channels, height, width)

        out = self.gamma(out) + x
        return out


@MODELS.register_module()
class DAHead(BaseDecodeHead):
    """Dual Attention Network for Scene Segmentation.

    This head is the implementation of `DANet
    <https://arxiv.org/abs/1809.02983>`_.

    Args:
        pam_channels (int): The channels of the Position Attention
            Module (PAM).
    """

    def __init__(self, pam_channels, **kwargs):
        super().__init__(**kwargs)
        self.pam_channels = pam_channels
        self.pam_in_conv = ConvModule(
            self.in_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.pam = PAM(self.channels, pam_channels)
        self.pam_out_conv = ConvModule(
            self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.pam_conv_seg = nn.Conv2d(
            self.channels, self.num_classes, kernel_size=1)

        self.cam_in_conv = ConvModule(
            self.in_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.cam = CAM()
        self.cam_out_conv = ConvModule(
            self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.cam_conv_seg = nn.Conv2d(
            self.channels, self.num_classes, kernel_size=1)

    def pam_cls_seg(self, feat):
        """PAM feature classification."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.pam_conv_seg(feat)
        return output

    def cam_cls_seg(self, feat):
        """CAM feature classification."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.cam_conv_seg(feat)
        return output

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        pam_feat = self.pam_in_conv(x)
        pam_feat = self.pam(pam_feat)
        pam_feat = self.pam_out_conv(pam_feat)
        pam_out = self.pam_cls_seg(pam_feat)

        cam_feat = self.cam_in_conv(x)
        cam_feat = self.cam(cam_feat)
        cam_feat = self.cam_out_conv(cam_feat)
        cam_out = self.cam_cls_seg(cam_feat)

        feat_sum = pam_feat + cam_feat
        pam_cam_out = self.cls_seg(feat_sum)

        return pam_cam_out, pam_out, cam_out

    def predict(self, inputs, batch_img_metas: List[dict], test_cfg,
                **kwargs) -> List[Tensor]:
        """Forward function for testing, only ``pam_cam`` is used."""
        seg_logits = self.forward(inputs)[0]
        return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs)

    def loss_by_feat(self, seg_logit: Tuple[Tensor],
                     batch_data_samples: SampleList, **kwargs) -> dict:
        """Compute ``pam_cam``, ``pam``, ``cam`` loss."""
        pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit
        loss = dict()
        loss.update(
            add_prefix(
                super().loss_by_feat(pam_cam_seg_logit, batch_data_samples),
                'pam_cam'))
        loss.update(
            add_prefix(super().loss_by_feat(pam_seg_logit, batch_data_samples),
                       'pam'))
        loss.update(
            add_prefix(super().loss_by_feat(cam_seg_logit, batch_data_samples),
                       'cam'))
        return loss
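
Usage sketch with assumed sizes: forward returns three logit maps; predict() keeps only the fused pam_cam output, while loss_by_feat supervises all three under the 'pam_cam', 'pam' and 'cam' prefixes.

import torch

head = DAHead(
    in_channels=2048,  # assumed
    channels=512,
    num_classes=19,
    pam_channels=64)
x = torch.rand(1, 2048, 32, 32)
pam_cam_out, pam_out, cam_out = head([x])  # each (1, 19, 32, 32)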
116  finetune/mmseg/models/decode_heads/ddr_head.py  Normal file
@@ -0,0 +1,116 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union

import torch.nn as nn
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from torch import Tensor

from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.losses import accuracy
from mmseg.models.utils import resize
from mmseg.registry import MODELS
from mmseg.utils import OptConfigType, SampleList


@MODELS.register_module()
class DDRHead(BaseDecodeHead):
    """Decode head for DDRNet.

    Args:
        in_channels (int): Number of input channels.
        channels (int): Number of output channels.
        num_classes (int): Number of classes.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict, optional): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
    """

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 num_classes: int,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 **kwargs):
        super().__init__(
            in_channels,
            channels,
            num_classes=num_classes,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **kwargs)

        self.head = self._make_base_head(self.in_channels, self.channels)
        self.aux_head = self._make_base_head(self.in_channels // 2,
                                             self.channels)
        self.aux_cls_seg = nn.Conv2d(
            self.channels, self.out_channels, kernel_size=1)

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(
            self,
            inputs: Union[Tensor,
                          Tuple[Tensor]]) -> Union[Tensor, Tuple[Tensor]]:
        if self.training:
            c3_feat, c5_feat = inputs
            x_c = self.head(c5_feat)
            x_c = self.cls_seg(x_c)
            x_s = self.aux_head(c3_feat)
            x_s = self.aux_cls_seg(x_s)

            return x_c, x_s
        else:
            x_c = self.head(inputs)
            x_c = self.cls_seg(x_c)
            return x_c

    def _make_base_head(self, in_channels: int,
                        channels: int) -> nn.Sequential:
        layers = [
            ConvModule(
                in_channels,
                channels,
                kernel_size=3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                order=('norm', 'act', 'conv')),
            build_norm_layer(self.norm_cfg, channels)[1],
            build_activation_layer(self.act_cfg),
        ]

        return nn.Sequential(*layers)

    def loss_by_feat(self, seg_logits: Tuple[Tensor],
                     batch_data_samples: SampleList) -> dict:
        loss = dict()
        context_logit, spatial_logit = seg_logits
        seg_label = self._stack_batch_gt(batch_data_samples)

        context_logit = resize(
            context_logit,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        spatial_logit = resize(
            spatial_logit,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        seg_label = seg_label.squeeze(1)

        loss['loss_context'] = self.loss_decode[0](context_logit, seg_label)
        loss['loss_spatial'] = self.loss_decode[1](spatial_logit, seg_label)
        loss['acc_seg'] = accuracy(
            context_logit, seg_label, ignore_index=self.ignore_index)

        return loss
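
Usage sketch with assumed sizes: in training mode the head expects a (c3, c5) feature pair and returns (context, spatial) logits; in eval mode it takes the c5 feature alone.

import torch

head = DDRHead(
    in_channels=128,  # assumed c5 channels; the aux head uses in_channels // 2
    channels=64,
    num_classes=19)
head.eval()  # forward() branches on self.training
c5 = torch.rand(1, 128, 32, 32)
seg_logits = head(c5)  # (1, 19, 32, 32)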
366  finetune/mmseg/models/decode_heads/decode_head.py  Normal file
@@ -0,0 +1,366 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from abc import ABCMeta, abstractmethod
from typing import List, Tuple

import torch
import torch.nn as nn
from mmengine.model import BaseModule
from torch import Tensor

from mmseg.registry import MODELS
from mmseg.structures import build_pixel_sampler
from mmseg.utils import ConfigType, SampleList
from ..losses import accuracy
from ..utils import resize


class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
    """Base class for BaseDecodeHead.

    1. The ``init_weights`` method is used to initialize decode_head's
    model parameters. After segmentor initialization, ``init_weights``
    is triggered when ``segmentor.init_weights()`` is called externally.

    2. The ``loss`` method is used to calculate the loss of decode_head,
    which includes two steps: (1) the decode_head model performs forward
    propagation to obtain the feature maps (2) The ``loss_by_feat`` method
    is called based on the feature maps to calculate the loss.

    .. code:: text

        loss(): forward() -> loss_by_feat()

    3. The ``predict`` method is used to predict segmentation results,
    which includes two steps: (1) the decode_head model performs forward
    propagation to obtain the feature maps (2) The ``predict_by_feat`` method
    is called based on the feature maps to predict segmentation results
    including post-processing.

    .. code:: text

        predict(): forward() -> predict_by_feat()

    Args:
        in_channels (int|Sequence[int]): Input channels.
        channels (int): Channels after modules, before conv_seg.
        num_classes (int): Number of classes.
        out_channels (int): Output channels of conv_seg. Default: None.
        threshold (float): Threshold for binary segmentation in the case of
            `num_classes==1`. Default: None.
        dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
        conv_cfg (dict|None): Config of conv layers. Default: None.
        norm_cfg (dict|None): Config of norm layers. Default: None.
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU')
        in_index (int|Sequence[int]): Input feature index. Default: -1
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            'resize_concat': Multiple feature maps will be resized to the
                same size as the first one and then concatenated together.
                Usually used in FCN head of HRNet.
            'multiple_select': Multiple feature maps will be bundled into
                a list and passed into decode head.
            None: Only one select feature map is allowed.
            Default: None.
        loss_decode (dict | Sequence[dict]): Config of decode loss.
            The `loss_name` is a property of the corresponding loss function
            which is shown in the training log. If you want this loss
            item to be included in the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_ce'.
            e.g. dict(type='CrossEntropyLoss'),
            [dict(type='CrossEntropyLoss', loss_name='loss_ce'),
             dict(type='DiceLoss', loss_name='loss_dice')]
            Default: dict(type='CrossEntropyLoss').
        ignore_index (int | None): The label index to be ignored. When using
            masked BCE loss, ignore_index should be set to None. Default: 255.
        sampler (dict|None): The config of segmentation map sampler.
            Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 *,
                 num_classes,
                 out_channels=None,
                 threshold=None,
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 in_index=-1,
                 input_transform=None,
                 loss_decode=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 ignore_index=255,
                 sampler=None,
                 align_corners=False,
                 init_cfg=dict(
                     type='Normal', std=0.01, override=dict(name='conv_seg'))):
        super().__init__(init_cfg)
        self._init_inputs(in_channels, in_index, input_transform)
        self.channels = channels
        self.dropout_ratio = dropout_ratio
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.in_index = in_index

        self.ignore_index = ignore_index
        self.align_corners = align_corners

        if out_channels is None:
            if num_classes == 2:
                warnings.warn('For binary segmentation, we suggest using '
                              '`out_channels = 1` to define the output '
                              'channels of segmentor, and use `threshold` '
                              'to convert `seg_logits` into a prediction '
                              'applying a threshold')
            out_channels = num_classes

        if out_channels != num_classes and out_channels != 1:
            raise ValueError(
                'out_channels should be equal to num_classes, '
                'except for binary segmentation, which sets out_channels == 1 '
                f'and num_classes == 2, but got out_channels={out_channels} '
                f'and num_classes={num_classes}')

        if out_channels == 1 and threshold is None:
            threshold = 0.3
            warnings.warn('threshold is not defined for binary segmentation '
                          'and defaults to 0.3')
        self.num_classes = num_classes
        self.out_channels = out_channels
        self.threshold = threshold

        if isinstance(loss_decode, dict):
            self.loss_decode = MODELS.build(loss_decode)
        elif isinstance(loss_decode, (list, tuple)):
            self.loss_decode = nn.ModuleList()
            for loss in loss_decode:
                self.loss_decode.append(MODELS.build(loss))
        else:
            raise TypeError(f'loss_decode must be a dict or sequence of dict,\
                but got {type(loss_decode)}')

        if sampler is not None:
            self.sampler = build_pixel_sampler(sampler, context=self)
        else:
            self.sampler = None

        self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            self.dropout = None

    def extra_repr(self):
        """Extra repr."""
        s = f'input_transform={self.input_transform}, ' \
            f'ignore_index={self.ignore_index}, ' \
            f'align_corners={self.align_corners}'
        return s

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only a single feature map
        will be selected, so in_channels and in_index must be of type int.
        When input_transform is not None, in_channels and in_index must be
        a list or tuple of the same length.

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps will be resized to the
                    same size as the first one and then concatenated together.
                    Usually used in FCN head of HRNet.
                'multiple_select': Multiple feature maps will be bundled into
                    a list and passed into decode head.
                None: Only one select feature map is allowed.
        """

        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: The transformed inputs
        """

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    @abstractmethod
    def forward(self, inputs):
        """Placeholder of forward function."""
        pass

    def cls_seg(self, feat):
        """Classify each pixel."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.conv_seg(feat)
        return output

    def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList,
             train_cfg: ConfigType) -> dict:
        """Forward function for training.

        Args:
            inputs (Tuple[Tensor]): List of multi-level img features.
            batch_data_samples (list[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `img_metas` or `gt_semantic_seg`.
            train_cfg (dict): The training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self.forward(inputs)
        losses = self.loss_by_feat(seg_logits, batch_data_samples)
        return losses

    def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
                test_cfg: ConfigType) -> Tensor:
        """Forward function for prediction.

        Args:
            inputs (Tuple[Tensor]): List of multi-level img features.
            batch_img_metas (dict): List of image info dicts where each dict
                may contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
                'ori_shape', and 'pad_shape'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
            test_cfg (dict): The testing config.

        Returns:
            Tensor: Outputs segmentation logits map.
        """
        seg_logits = self.forward(inputs)

        return self.predict_by_feat(seg_logits, batch_img_metas)

    def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor:
        gt_semantic_segs = [
            data_sample.gt_sem_seg.data for data_sample in batch_data_samples
        ]
        return torch.stack(gt_semantic_segs, dim=0)

    def loss_by_feat(self, seg_logits: Tensor,
                     batch_data_samples: SampleList) -> dict:
        """Compute segmentation loss.

        Args:
            seg_logits (Tensor): The output from decode head forward function.
            batch_data_samples (List[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `metainfo` and `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """

        seg_label = self._stack_batch_gt(batch_data_samples)
        loss = dict()
        seg_logits = resize(
            input=seg_logits,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        if self.sampler is not None:
            seg_weight = self.sampler.sample(seg_logits, seg_label)
        else:
            seg_weight = None
        seg_label = seg_label.squeeze(1)

        if not isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = [self.loss_decode]
        else:
            losses_decode = self.loss_decode
        for loss_decode in losses_decode:
            if loss_decode.loss_name not in loss:
                loss[loss_decode.loss_name] = loss_decode(
                    seg_logits,
                    seg_label,
                    weight=seg_weight,
                    ignore_index=self.ignore_index)
            else:
                loss[loss_decode.loss_name] += loss_decode(
                    seg_logits,
                    seg_label,
                    weight=seg_weight,
                    ignore_index=self.ignore_index)

        loss['acc_seg'] = accuracy(
            seg_logits, seg_label, ignore_index=self.ignore_index)
        return loss

    def predict_by_feat(self, seg_logits: Tensor,
                        batch_img_metas: List[dict]) -> Tensor:
        """Transform a batch of output seg_logits to the input shape.

        Args:
            seg_logits (Tensor): The output from decode head forward function.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.

        Returns:
            Tensor: Outputs segmentation logits map.
        """

        if isinstance(batch_img_metas[0]['img_shape'], torch.Size):
            # slide inference
            size = batch_img_metas[0]['img_shape']
        elif 'pad_shape' in batch_img_metas[0]:
            size = batch_img_metas[0]['pad_shape'][:2]
        else:
            size = batch_img_metas[0]['img_shape']

        seg_logits = resize(
            input=seg_logits,
            size=size,
            mode='bilinear',
            align_corners=self.align_corners)
        return seg_logits
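
A sketch of the three input_transform modes on assumed multi-level features; FCNHead (exported by this package) stands in as a concrete subclass.

import torch
from mmseg.models.decode_heads import FCNHead  # assumed import path

feats = [torch.rand(1, c, s, s) for c, s in
         [(256, 64), (512, 32), (1024, 16), (2048, 8)]]

# input_transform=None (default): in_index picks one level
head = FCNHead(in_channels=2048, in_index=3, channels=512, num_classes=19)
x = head._transform_inputs(feats)  # (1, 2048, 8, 8)

# 'resize_concat': chosen levels are resized to the first one and
# concatenated, so in_channels is their sum (256 + 512 here)
head = FCNHead(
    in_channels=[256, 512], in_index=[0, 1],
    input_transform='resize_concat', channels=512, num_classes=19)
x = head._transform_inputs(feats)  # (1, 768, 64, 64)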
141  finetune/mmseg/models/decode_heads/dm_head.py  Normal file
@@ -0,0 +1,141 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer

from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead


class DCM(nn.Module):
    """Dynamic Convolutional Module used in DMNet.

    Args:
        filter_size (int): The filter size of the generated convolution
            kernel used in the Dynamic Convolutional Module.
        fusion (bool): Add one conv to fuse DCM output feature.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict | None): Config of conv layers.
        norm_cfg (dict | None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
    """

    def __init__(self, filter_size, fusion, in_channels, channels, conv_cfg,
                 norm_cfg, act_cfg):
        super().__init__()
        self.filter_size = filter_size
        self.fusion = fusion
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.filter_gen_conv = nn.Conv2d(self.in_channels, self.channels, 1, 1,
                                         0)

        self.input_redu_conv = ConvModule(
            self.in_channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        if self.norm_cfg is not None:
            self.norm = build_norm_layer(self.norm_cfg, self.channels)[1]
        else:
            self.norm = None
        self.activate = build_activation_layer(self.act_cfg)

        if self.fusion:
            self.fusion_conv = ConvModule(
                self.channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def forward(self, x):
        """Forward function."""
        generated_filter = self.filter_gen_conv(
            F.adaptive_avg_pool2d(x, self.filter_size))
        x = self.input_redu_conv(x)
        b, c, h, w = x.shape
        # [1, b * c, h, w], c = self.channels
        x = x.view(1, b * c, h, w)
        # [b * c, 1, filter_size, filter_size]
        generated_filter = generated_filter.view(b * c, 1, self.filter_size,
                                                 self.filter_size)
        pad = (self.filter_size - 1) // 2
        if (self.filter_size - 1) % 2 == 0:
            p2d = (pad, pad, pad, pad)
        else:
            p2d = (pad + 1, pad, pad + 1, pad)
        x = F.pad(input=x, pad=p2d, mode='constant', value=0)
        # [1, b * c, h, w]
        output = F.conv2d(input=x, weight=generated_filter, groups=b * c)
        # [b, c, h, w]
        output = output.view(b, c, h, w)
        if self.norm is not None:
            output = self.norm(output)
        output = self.activate(output)

        if self.fusion:
            output = self.fusion_conv(output)

        return output


@MODELS.register_module()
class DMHead(BaseDecodeHead):
    """Dynamic Multi-scale Filters for Semantic Segmentation.

    This head is the implementation of
    `DMNet <https://openaccess.thecvf.com/content_ICCV_2019/papers/\
He_Dynamic_Multi-Scale_Filters_for_Semantic_Segmentation_\
ICCV_2019_paper.pdf>`_.

    Args:
        filter_sizes (tuple[int]): The size of generated convolutional filters
            used in Dynamic Convolutional Module. Default: (1, 3, 5, 7).
        fusion (bool): Add one conv to fuse DCM output feature.
    """

    def __init__(self, filter_sizes=(1, 3, 5, 7), fusion=False, **kwargs):
        super().__init__(**kwargs)
        assert isinstance(filter_sizes, (list, tuple))
        self.filter_sizes = filter_sizes
        self.fusion = fusion
        dcm_modules = []
        for filter_size in self.filter_sizes:
            dcm_modules.append(
                DCM(filter_size,
                    self.fusion,
                    self.in_channels,
                    self.channels,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
        self.dcm_modules = nn.ModuleList(dcm_modules)
        self.bottleneck = ConvModule(
            self.in_channels + len(filter_sizes) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        dcm_outs = [x]
        for dcm_module in self.dcm_modules:
            dcm_outs.append(dcm_module(x))
        dcm_outs = torch.cat(dcm_outs, dim=1)
        output = self.bottleneck(dcm_outs)
        output = self.cls_seg(output)
        return output
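
Usage sketch with assumed sizes: each DCM turns the pooled features into a per-image depthwise kernel (applied with groups=b*c) and convolves the channel-reduced feature map with it.

import torch

head = DMHead(
    in_channels=2048,  # assumed
    channels=512,
    num_classes=19,
    filter_sizes=(1, 3, 5, 7))
x = torch.rand(2, 2048, 32, 32)
seg_logits = head([x])  # (2, 19, 32, 32)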
137  finetune/mmseg/models/decode_heads/dnl_head.py  Normal file
@@ -0,0 +1,137 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn import NonLocal2d
from torch import nn

from mmseg.registry import MODELS
from .fcn_head import FCNHead


class DisentangledNonLocal2d(NonLocal2d):
    """Disentangled Non-Local Blocks.

    Args:
        temperature (float): Temperature to adjust attention. Default: 0.05
    """

    def __init__(self, *args, temperature, **kwargs):
        super().__init__(*args, **kwargs)
        self.temperature = temperature
        self.conv_mask = nn.Conv2d(self.in_channels, 1, kernel_size=1)

    def embedded_gaussian(self, theta_x, phi_x):
        """Embedded gaussian with temperature."""

        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        pairwise_weight = torch.matmul(theta_x, phi_x)
        if self.use_scale:
            # theta_x.shape[-1] is `self.inter_channels`
            pairwise_weight /= torch.tensor(
                theta_x.shape[-1],
                dtype=torch.float,
                device=pairwise_weight.device)**torch.tensor(
                    0.5, device=pairwise_weight.device)
        pairwise_weight /= torch.tensor(
            self.temperature, device=pairwise_weight.device)
        pairwise_weight = pairwise_weight.softmax(dim=-1)
        return pairwise_weight

    def forward(self, x):
        # x: [N, C, H, W]
        n = x.size(0)

        # g_x: [N, HxW, C]
        g_x = self.g(x).view(n, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # theta_x: [N, HxW, C], phi_x: [N, C, HxW]
        if self.mode == 'gaussian':
            theta_x = x.view(n, self.in_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            if self.sub_sample:
                phi_x = self.phi(x).view(n, self.in_channels, -1)
            else:
                phi_x = x.view(n, self.in_channels, -1)
        elif self.mode == 'concatenation':
            theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
            phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
        else:
            theta_x = self.theta(x).view(n, self.inter_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            phi_x = self.phi(x).view(n, self.inter_channels, -1)

        # subtract mean
        theta_x -= theta_x.mean(dim=-2, keepdim=True)
        phi_x -= phi_x.mean(dim=-1, keepdim=True)

        pairwise_func = getattr(self, self.mode)
        # pairwise_weight: [N, HxW, HxW]
        pairwise_weight = pairwise_func(theta_x, phi_x)

        # y: [N, HxW, C]
        y = torch.matmul(pairwise_weight, g_x)
        # y: [N, C, H, W]
        y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
                                                    *x.size()[2:])

        # unary_mask: [N, 1, HxW]
        unary_mask = self.conv_mask(x)
        unary_mask = unary_mask.view(n, 1, -1)
        unary_mask = unary_mask.softmax(dim=-1)
        # unary_x: [N, 1, C]
        unary_x = torch.matmul(unary_mask, g_x)
        # unary_x: [N, C, 1, 1]
        unary_x = unary_x.permute(0, 2, 1).contiguous().reshape(
            n, self.inter_channels, 1, 1)

        output = x + self.conv_out(y + unary_x)

        return output


@MODELS.register_module()
class DNLHead(FCNHead):
    """Disentangled Non-Local Neural Networks.

    This head is the implementation of `DNLNet
    <https://arxiv.org/abs/2006.06668>`_.

    Args:
        reduction (int): Reduction factor of projection transform. Default: 2.
        use_scale (bool): Whether to scale pairwise_weight by
            sqrt(1/inter_channels). Default: True.
        mode (str): The nonlocal mode. Options are 'embedded_gaussian',
            'dot_product'. Default: 'embedded_gaussian'.
        temperature (float): Temperature to adjust attention. Default: 0.05
    """

    def __init__(self,
                 reduction=2,
                 use_scale=True,
                 mode='embedded_gaussian',
                 temperature=0.05,
                 **kwargs):
        super().__init__(num_convs=2, **kwargs)
        self.reduction = reduction
        self.use_scale = use_scale
        self.mode = mode
        self.temperature = temperature
        self.dnl_block = DisentangledNonLocal2d(
            in_channels=self.channels,
            reduction=self.reduction,
            use_scale=self.use_scale,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            mode=self.mode,
            temperature=self.temperature)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        output = self.convs[0](x)
        output = self.dnl_block(output)
        output = self.convs[1](output)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output
|
||||
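
A quick smoke test for the disentangled block above. This is a minimal sketch, assuming illustrative sizes (64 channels, 32x32 maps) that are not taken from any config in this commit.

import torch

from mmseg.models.decode_heads.dnl_head import DisentangledNonLocal2d

# Illustrative sizes only.
block = DisentangledNonLocal2d(
    in_channels=64,
    reduction=2,
    use_scale=True,
    mode='embedded_gaussian',
    temperature=0.05)
x = torch.rand(2, 64, 32, 32)
out = block(x)
# The block is residual (output = x + conv_out(...)), so the shape is kept.
assert out.shape == x.shape
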
294
finetune/mmseg/models/decode_heads/dpt_head.py
Normal file
@@ -0,0 +1,294 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, Linear, build_activation_layer
from mmengine.model import BaseModule

from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead


class ReassembleBlocks(BaseModule):
    """ViTPostProcessBlock, process cls_token in ViT backbone output and
    rearrange the feature vector to feature map.

    Args:
        in_channels (int): ViT feature channels. Default: 768.
        out_channels (List): output channels of each stage.
            Default: [96, 192, 384, 768].
        readout_type (str): Type of readout operation. Default: 'ignore'.
        patch_size (int): The patch size. Default: 16.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self,
                 in_channels=768,
                 out_channels=[96, 192, 384, 768],
                 readout_type='ignore',
                 patch_size=16,
                 init_cfg=None):
        super().__init__(init_cfg)

        assert readout_type in ['ignore', 'add', 'project']
        self.readout_type = readout_type
        self.patch_size = patch_size

        self.projects = nn.ModuleList([
            ConvModule(
                in_channels=in_channels,
                out_channels=out_channel,
                kernel_size=1,
                act_cfg=None,
            ) for out_channel in out_channels
        ])

        self.resize_layers = nn.ModuleList([
            nn.ConvTranspose2d(
                in_channels=out_channels[0],
                out_channels=out_channels[0],
                kernel_size=4,
                stride=4,
                padding=0),
            nn.ConvTranspose2d(
                in_channels=out_channels[1],
                out_channels=out_channels[1],
                kernel_size=2,
                stride=2,
                padding=0),
            nn.Identity(),
            nn.Conv2d(
                in_channels=out_channels[3],
                out_channels=out_channels[3],
                kernel_size=3,
                stride=2,
                padding=1)
        ])
        if self.readout_type == 'project':
            self.readout_projects = nn.ModuleList()
            for _ in range(len(self.projects)):
                self.readout_projects.append(
                    nn.Sequential(
                        Linear(2 * in_channels, in_channels),
                        build_activation_layer(dict(type='GELU'))))

    def forward(self, inputs):
        assert isinstance(inputs, list)
        out = []
        for i, x in enumerate(inputs):
            assert len(x) == 2
            x, cls_token = x[0], x[1]
            feature_shape = x.shape
            if self.readout_type == 'project':
                x = x.flatten(2).permute((0, 2, 1))
                readout = cls_token.unsqueeze(1).expand_as(x)
                x = self.readout_projects[i](torch.cat((x, readout), -1))
                x = x.permute(0, 2, 1).reshape(feature_shape)
            elif self.readout_type == 'add':
                x = x.flatten(2) + cls_token.unsqueeze(-1)
                x = x.reshape(feature_shape)
            else:
                pass
            x = self.projects[i](x)
            x = self.resize_layers[i](x)
            out.append(x)
        return out


class PreActResidualConvUnit(BaseModule):
    """ResidualConvUnit, pre-activate residual unit.

    Args:
        in_channels (int): number of channels in the input feature map.
        act_cfg (dict): dictionary to construct and config activation layer.
        norm_cfg (dict): dictionary to construct and config norm layer.
        stride (int): stride of the first block. Default: 1
        dilation (int): dilation rate for convs layers. Default: 1.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self,
                 in_channels,
                 act_cfg,
                 norm_cfg,
                 stride=1,
                 dilation=1,
                 init_cfg=None):
        super().__init__(init_cfg)

        self.conv1 = ConvModule(
            in_channels,
            in_channels,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            bias=False,
            order=('act', 'conv', 'norm'))

        self.conv2 = ConvModule(
            in_channels,
            in_channels,
            3,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            bias=False,
            order=('act', 'conv', 'norm'))

    def forward(self, inputs):
        inputs_ = inputs.clone()
        x = self.conv1(inputs)
        x = self.conv2(x)
        return x + inputs_


class FeatureFusionBlock(BaseModule):
    """FeatureFusionBlock, merge feature map from different stages.

    Args:
        in_channels (int): Input channels.
        act_cfg (dict): The activation config for ResidualConvUnit.
        norm_cfg (dict): Config dict for normalization layer.
        expand (bool): Whether expand the channels in post process block.
            Default: False.
        align_corners (bool): align_corner setting for bilinear upsample.
            Default: True.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self,
                 in_channels,
                 act_cfg,
                 norm_cfg,
                 expand=False,
                 align_corners=True,
                 init_cfg=None):
        super().__init__(init_cfg)

        self.in_channels = in_channels
        self.expand = expand
        self.align_corners = align_corners

        self.out_channels = in_channels
        if self.expand:
            self.out_channels = in_channels // 2

        self.project = ConvModule(
            self.in_channels,
            self.out_channels,
            kernel_size=1,
            act_cfg=None,
            bias=True)

        self.res_conv_unit1 = PreActResidualConvUnit(
            in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
        self.res_conv_unit2 = PreActResidualConvUnit(
            in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)

    def forward(self, *inputs):
        x = inputs[0]
        if len(inputs) == 2:
            if x.shape != inputs[1].shape:
                res = resize(
                    inputs[1],
                    size=(x.shape[2], x.shape[3]),
                    mode='bilinear',
                    align_corners=False)
            else:
                res = inputs[1]
            x = x + self.res_conv_unit1(res)
        x = self.res_conv_unit2(x)
        x = resize(
            x,
            scale_factor=2,
            mode='bilinear',
            align_corners=self.align_corners)
        x = self.project(x)
        return x


@MODELS.register_module()
class DPTHead(BaseDecodeHead):
    """Vision Transformers for Dense Prediction.

    This head is the implementation of `DPT
    <https://arxiv.org/abs/2103.13413>`_.

    Args:
        embed_dims (int): The embed dimension of the ViT backbone.
            Default: 768.
        post_process_channels (List): Out channels of post process conv
            layers. Default: [96, 192, 384, 768].
        readout_type (str): Type of readout operation. Default: 'ignore'.
        patch_size (int): The patch size. Default: 16.
        expand_channels (bool): Whether expand the channels in post process
            block. Default: False.
        act_cfg (dict): The activation config for residual conv unit.
            Default: dict(type='ReLU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
    """

    def __init__(self,
                 embed_dims=768,
                 post_process_channels=[96, 192, 384, 768],
                 readout_type='ignore',
                 patch_size=16,
                 expand_channels=False,
                 act_cfg=dict(type='ReLU'),
                 norm_cfg=dict(type='BN'),
                 **kwargs):
        super().__init__(**kwargs)

        self.in_channels = self.in_channels
        self.expand_channels = expand_channels
        self.reassemble_blocks = ReassembleBlocks(embed_dims,
                                                  post_process_channels,
                                                  readout_type, patch_size)

        self.post_process_channels = [
            channel * math.pow(2, i) if expand_channels else channel
            for i, channel in enumerate(post_process_channels)
        ]
        self.convs = nn.ModuleList()
        for channel in self.post_process_channels:
            self.convs.append(
                ConvModule(
                    channel,
                    self.channels,
                    kernel_size=3,
                    padding=1,
                    act_cfg=None,
                    bias=False))
        self.fusion_blocks = nn.ModuleList()
        for _ in range(len(self.convs)):
            self.fusion_blocks.append(
                FeatureFusionBlock(self.channels, act_cfg, norm_cfg))
        self.fusion_blocks[0].res_conv_unit1 = None
        self.project = ConvModule(
            self.channels,
            self.channels,
            kernel_size=3,
            padding=1,
            norm_cfg=norm_cfg)
        self.num_fusion_blocks = len(self.fusion_blocks)
        self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
        self.num_post_process_channels = len(self.post_process_channels)
        assert self.num_fusion_blocks == self.num_reassemble_blocks
        assert self.num_reassemble_blocks == self.num_post_process_channels

    def forward(self, inputs):
        assert len(inputs) == self.num_reassemble_blocks
        x = self._transform_inputs(inputs)
        x = self.reassemble_blocks(x)
        x = [self.convs[i](feature) for i, feature in enumerate(x)]
        out = self.fusion_blocks[0](x[-1])
        for i in range(1, len(self.fusion_blocks)):
            out = self.fusion_blocks[i](out, x[-(i + 1)])
        out = self.project(out)
        out = self.cls_seg(out)
        return out
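
The tuple format that ReassembleBlocks expects is easy to get wrong, so here is a minimal sketch with assumed ViT-B sizes (768-dim tokens on a 14x14 patch grid); the numbers are illustrative, not from a config in this commit.

import torch

from mmseg.models.decode_heads.dpt_head import ReassembleBlocks

blocks = ReassembleBlocks(
    in_channels=768,
    out_channels=[96, 192, 384, 768],
    readout_type='project')
# Each stage yields a (patch feature map, cls_token) pair.
feats = [(torch.rand(1, 768, 14, 14), torch.rand(1, 768)) for _ in range(4)]
outs = blocks(feats)
# Stages are rescaled to 4x, 2x, 1x and 0.5x of the patch grid:
# [(1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)]
print([tuple(o.shape) for o in outs])
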
169
finetune/mmseg/models/decode_heads/ema_head.py
Normal file
@@ -0,0 +1,169 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead


def reduce_mean(tensor):
    """Reduce mean when distributed training."""
    if not (dist.is_available() and dist.is_initialized()):
        return tensor
    tensor = tensor.clone()
    dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
    return tensor


class EMAModule(nn.Module):
    """Expectation Maximization Attention Module used in EMANet.

    Args:
        channels (int): Channels of the whole module.
        num_bases (int): Number of bases.
        num_stages (int): Number of the EM iterations.
        momentum (float): Momentum to update the bases.
    """

    def __init__(self, channels, num_bases, num_stages, momentum):
        super().__init__()
        assert num_stages >= 1, 'num_stages must be at least 1!'
        self.num_bases = num_bases
        self.num_stages = num_stages
        self.momentum = momentum

        bases = torch.zeros(1, channels, self.num_bases)
        bases.normal_(0, math.sqrt(2. / self.num_bases))
        # [1, channels, num_bases]
        bases = F.normalize(bases, dim=1, p=2)
        self.register_buffer('bases', bases)

    def forward(self, feats):
        """Forward function."""
        batch_size, channels, height, width = feats.size()
        # [batch_size, channels, height*width]
        feats = feats.view(batch_size, channels, height * width)
        # [batch_size, channels, num_bases]
        bases = self.bases.repeat(batch_size, 1, 1)

        with torch.no_grad():
            for i in range(self.num_stages):
                # [batch_size, height*width, num_bases]
                attention = torch.einsum('bcn,bck->bnk', feats, bases)
                attention = F.softmax(attention, dim=2)
                # l1 norm
                attention_normed = F.normalize(attention, dim=1, p=1)
                # [batch_size, channels, num_bases]
                bases = torch.einsum('bcn,bnk->bck', feats, attention_normed)
                # l2 norm
                bases = F.normalize(bases, dim=1, p=2)

        feats_recon = torch.einsum('bck,bnk->bcn', bases, attention)
        feats_recon = feats_recon.view(batch_size, channels, height, width)

        if self.training:
            bases = bases.mean(dim=0, keepdim=True)
            bases = reduce_mean(bases)
            # l2 norm
            bases = F.normalize(bases, dim=1, p=2)
            self.bases = (1 -
                          self.momentum) * self.bases + self.momentum * bases

        return feats_recon


@MODELS.register_module()
class EMAHead(BaseDecodeHead):
    """Expectation Maximization Attention Networks for Semantic Segmentation.

    This head is the implementation of `EMANet
    <https://arxiv.org/abs/1907.13426>`_.

    Args:
        ema_channels (int): EMA module channels.
        num_bases (int): Number of bases.
        num_stages (int): Number of the EM iterations.
        concat_input (bool): Whether concat the input and output of convs
            before classification layer. Default: True
        momentum (float): Momentum to update the base. Default: 0.1.
    """

    def __init__(self,
                 ema_channels,
                 num_bases,
                 num_stages,
                 concat_input=True,
                 momentum=0.1,
                 **kwargs):
        super().__init__(**kwargs)
        self.ema_channels = ema_channels
        self.num_bases = num_bases
        self.num_stages = num_stages
        self.concat_input = concat_input
        self.momentum = momentum
        self.ema_module = EMAModule(self.ema_channels, self.num_bases,
                                    self.num_stages, self.momentum)

        self.ema_in_conv = ConvModule(
            self.in_channels,
            self.ema_channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # project (0, inf) -> (-inf, inf)
        self.ema_mid_conv = ConvModule(
            self.ema_channels,
            self.ema_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=None,
            act_cfg=None)
        for param in self.ema_mid_conv.parameters():
            param.requires_grad = False

        self.ema_out_conv = ConvModule(
            self.ema_channels,
            self.ema_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=None)
        self.bottleneck = ConvModule(
            self.ema_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if self.concat_input:
            self.conv_cat = ConvModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        feats = self.ema_in_conv(x)
        identity = feats
        feats = self.ema_mid_conv(feats)
        recon = self.ema_module(feats)
        recon = F.relu(recon, inplace=True)
        recon = self.ema_out_conv(recon)
        output = F.relu(identity + recon, inplace=True)
        output = self.bottleneck(output)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output
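
A minimal sketch of EMAModule in isolation, with toy sizes chosen only for illustration; eval() is used so the running-average base update is skipped.

import torch

from mmseg.models.decode_heads.ema_head import EMAModule

ema = EMAModule(channels=32, num_bases=8, num_stages=3, momentum=0.1)
ema.eval()
feats = torch.rand(2, 32, 16, 16)
recon = ema(feats)  # low-rank (num_bases) reconstruction of the input
assert recon.shape == feats.shape
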
196
finetune/mmseg/models/decode_heads/enc_head.py
Normal file
@@ -0,0 +1,196 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_norm_layer
from torch import Tensor

from mmseg.registry import MODELS
from mmseg.utils import ConfigType, SampleList
from ..utils import Encoding, resize
from .decode_head import BaseDecodeHead


class EncModule(nn.Module):
    """Encoding Module used in EncNet.

    Args:
        in_channels (int): Input channels.
        num_codes (int): Number of code words.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
    """

    def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg):
        super().__init__()
        self.encoding_project = ConvModule(
            in_channels,
            in_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # TODO: resolve this hack
        # change to 1d
        if norm_cfg is not None:
            encoding_norm_cfg = norm_cfg.copy()
            if encoding_norm_cfg['type'] in ['BN', 'IN']:
                encoding_norm_cfg['type'] += '1d'
            else:
                encoding_norm_cfg['type'] = encoding_norm_cfg['type'].replace(
                    '2d', '1d')
        else:
            # fallback to BN1d
            encoding_norm_cfg = dict(type='BN1d')
        self.encoding = nn.Sequential(
            Encoding(channels=in_channels, num_codes=num_codes),
            build_norm_layer(encoding_norm_cfg, num_codes)[1],
            nn.ReLU(inplace=True))
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels), nn.Sigmoid())

    def forward(self, x):
        """Forward function."""
        encoding_projection = self.encoding_project(x)
        encoding_feat = self.encoding(encoding_projection).mean(dim=1)
        batch_size, channels, _, _ = x.size()
        gamma = self.fc(encoding_feat)
        y = gamma.view(batch_size, channels, 1, 1)
        output = F.relu_(x + x * y)
        return encoding_feat, output


@MODELS.register_module()
class EncHead(BaseDecodeHead):
    """Context Encoding for Semantic Segmentation.

    This head is the implementation of `EncNet
    <https://arxiv.org/abs/1803.08904>`_.

    Args:
        num_codes (int): Number of code words. Default: 32.
        use_se_loss (bool): Whether use Semantic Encoding Loss (SE-loss) to
            regularize the training. Default: True.
        add_lateral (bool): Whether use lateral connection to fuse features.
            Default: False.
        loss_se_decode (dict): Config of decode loss.
            Default: dict(type='CrossEntropyLoss', use_sigmoid=True).
    """

    def __init__(self,
                 num_codes=32,
                 use_se_loss=True,
                 add_lateral=False,
                 loss_se_decode=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=0.2),
                 **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)
        self.use_se_loss = use_se_loss
        self.add_lateral = add_lateral
        self.num_codes = num_codes
        self.bottleneck = ConvModule(
            self.in_channels[-1],
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if add_lateral:
            self.lateral_convs = nn.ModuleList()
            for in_channels in self.in_channels[:-1]:  # skip the last one
                self.lateral_convs.append(
                    ConvModule(
                        in_channels,
                        self.channels,
                        1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
            self.fusion = ConvModule(
                len(self.in_channels) * self.channels,
                self.channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        self.enc_module = EncModule(
            self.channels,
            num_codes=num_codes,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if self.use_se_loss:
            self.loss_se_decode = MODELS.build(loss_se_decode)
            self.se_layer = nn.Linear(self.channels, self.num_classes)

    def forward(self, inputs):
        """Forward function."""
        inputs = self._transform_inputs(inputs)
        feat = self.bottleneck(inputs[-1])
        if self.add_lateral:
            laterals = [
                resize(
                    lateral_conv(inputs[i]),
                    size=feat.shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners)
                for i, lateral_conv in enumerate(self.lateral_convs)
            ]
            feat = self.fusion(torch.cat([feat, *laterals], 1))
        encode_feat, output = self.enc_module(feat)
        output = self.cls_seg(output)
        if self.use_se_loss:
            se_output = self.se_layer(encode_feat)
            return output, se_output
        else:
            return output

    def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
                test_cfg: ConfigType):
        """Forward function for testing, ignore se_loss."""
        if self.use_se_loss:
            seg_logits = self.forward(inputs)[0]
        else:
            seg_logits = self.forward(inputs)
        return self.predict_by_feat(seg_logits, batch_img_metas)

    @staticmethod
    def _convert_to_onehot_labels(seg_label, num_classes):
        """Convert segmentation label to onehot.

        Args:
            seg_label (Tensor): Segmentation label of shape (N, H, W).
            num_classes (int): Number of classes.

        Returns:
            Tensor: Onehot labels of shape (N, num_classes).
        """

        batch_size = seg_label.size(0)
        onehot_labels = seg_label.new_zeros((batch_size, num_classes))
        for i in range(batch_size):
            hist = seg_label[i].float().histc(
                bins=num_classes, min=0, max=num_classes - 1)
            onehot_labels[i] = hist > 0
        return onehot_labels

    def loss_by_feat(self, seg_logit: Tuple[Tensor],
                     batch_data_samples: SampleList, **kwargs) -> dict:
        """Compute segmentation and semantic encoding loss."""
        seg_logit, se_seg_logit = seg_logit
        loss = dict()
        loss.update(super().loss_by_feat(seg_logit, batch_data_samples))

        seg_label = self._stack_batch_gt(batch_data_samples)
        se_loss = self.loss_se_decode(
            se_seg_logit,
            self._convert_to_onehot_labels(seg_label, self.num_classes))
        loss['loss_se'] = se_loss
        return loss
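
The SE-loss target built by _convert_to_onehot_labels is simply "which classes appear in each image". A small worked example with a hypothetical 2x2 label map:

import torch

from mmseg.models.decode_heads.enc_head import EncHead

seg_label = torch.tensor([[[0, 2], [2, 2]],
                          [[1, 1], [1, 3]]])  # (N=2, H=2, W=2)
onehot = EncHead._convert_to_onehot_labels(seg_label, num_classes=4)
# image 0 contains classes {0, 2}, image 1 contains {1, 3}:
# tensor([[1, 0, 1, 0],
#         [0, 1, 0, 1]])
print(onehot)
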
96
finetune/mmseg/models/decode_heads/fcn_head.py
Normal file
@@ -0,0 +1,96 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead


@MODELS.register_module()
class FCNHead(BaseDecodeHead):
    """Fully Convolution Networks for Semantic Segmentation.

    This head is the implementation of `FCNNet
    <https://arxiv.org/abs/1411.4038>`_.

    Args:
        num_convs (int): Number of convs in the head. Default: 2.
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        concat_input (bool): Whether concat the input and output of convs
            before classification layer. Default: True.
        dilation (int): The dilation rate for convs in the head. Default: 1.
    """

    def __init__(self,
                 num_convs=2,
                 kernel_size=3,
                 concat_input=True,
                 dilation=1,
                 **kwargs):
        assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int)
        self.num_convs = num_convs
        self.concat_input = concat_input
        self.kernel_size = kernel_size
        super().__init__(**kwargs)
        if num_convs == 0:
            assert self.in_channels == self.channels

        conv_padding = (kernel_size // 2) * dilation
        convs = []
        convs.append(
            ConvModule(
                self.in_channels,
                self.channels,
                kernel_size=kernel_size,
                padding=conv_padding,
                dilation=dilation,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))
        for i in range(num_convs - 1):
            convs.append(
                ConvModule(
                    self.channels,
                    self.channels,
                    kernel_size=kernel_size,
                    padding=conv_padding,
                    dilation=dilation,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
        if num_convs == 0:
            self.convs = nn.Identity()
        else:
            self.convs = nn.Sequential(*convs)
        if self.concat_input:
            self.conv_cat = ConvModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def _forward_feature(self, inputs):
        """Forward function for feature maps before classifying each pixel with
        ``self.cls_seg`` fc.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W) which is feature map for last layer of decoder head.
        """
        x = self._transform_inputs(inputs)
        feats = self.convs(x)
        if self.concat_input:
            feats = self.conv_cat(torch.cat([x, feats], dim=1))
        return feats

    def forward(self, inputs):
        """Forward function."""
        output = self._forward_feature(inputs)
        output = self.cls_seg(output)
        return output
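
A minimal sketch of constructing the head directly; the channel counts and class number here are illustrative, not taken from any config in this commit.

import torch

from mmseg.models.decode_heads.fcn_head import FCNHead

head = FCNHead(
    in_channels=64,
    channels=32,
    num_classes=19,
    num_convs=2,
    concat_input=True,
    in_index=0)
logits = head([torch.rand(2, 64, 32, 32)])
# Per-pixel class logits at the feature resolution.
assert logits.shape == (2, 19, 32, 32)
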
68
finetune/mmseg/models/decode_heads/fpn_head.py
Normal file
@@ -0,0 +1,68 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.registry import MODELS
from ..utils import Upsample, resize
from .decode_head import BaseDecodeHead


@MODELS.register_module()
class FPNHead(BaseDecodeHead):
    """Panoptic Feature Pyramid Networks.

    This head is the implementation of `Semantic FPN
    <https://arxiv.org/abs/1901.02446>`_.

    Args:
        feature_strides (tuple[int]): The strides for input feature maps,
            in the same order as ``in_channels``. All strides are supposed
            to be powers of 2, and the first one is of the largest
            resolution.
    """

    def __init__(self, feature_strides, **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)
        assert len(feature_strides) == len(self.in_channels)
        assert min(feature_strides) == feature_strides[0]
        self.feature_strides = feature_strides

        self.scale_heads = nn.ModuleList()
        for i in range(len(feature_strides)):
            head_length = max(
                1,
                int(np.log2(feature_strides[i]) - np.log2(feature_strides[0])))
            scale_head = []
            for k in range(head_length):
                scale_head.append(
                    ConvModule(
                        self.in_channels[i] if k == 0 else self.channels,
                        self.channels,
                        3,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
                if feature_strides[i] != feature_strides[0]:
                    scale_head.append(
                        Upsample(
                            scale_factor=2,
                            mode='bilinear',
                            align_corners=self.align_corners))
            self.scale_heads.append(nn.Sequential(*scale_head))

    def forward(self, inputs):

        x = self._transform_inputs(inputs)

        output = self.scale_heads[0](x[0])
        for i in range(1, len(self.feature_strides)):
            # non inplace
            output = output + resize(
                self.scale_heads[i](x[i]),
                size=output.shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)

        output = self.cls_seg(output)
        return output
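
A minimal sketch of the progressive 2x upsampling above: each scale head brings its level to the stride of the first (finest) level before summation. The pyramid sizes below are hypothetical.

import torch

from mmseg.models.decode_heads.fpn_head import FPNHead

head = FPNHead(
    feature_strides=(4, 8, 16, 32),
    in_channels=[256, 256, 256, 256],
    in_index=[0, 1, 2, 3],
    channels=128,
    num_classes=19)
# Features of a hypothetical 256x256 image at strides 4, 8, 16, 32.
feats = [torch.rand(1, 256, 64 // s, 64 // s) for s in (1, 2, 4, 8)]
logits = head(feats)
assert logits.shape == (1, 19, 64, 64)  # stride-4 resolution
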
48
finetune/mmseg/models/decode_heads/gc_head.py
Normal file
@@ -0,0 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn import ContextBlock

from mmseg.registry import MODELS
from .fcn_head import FCNHead


@MODELS.register_module()
class GCHead(FCNHead):
    """GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond.

    This head is the implementation of `GCNet
    <https://arxiv.org/abs/1904.11492>`_.

    Args:
        ratio (float): Multiplier of channels ratio. Default: 1/4.
        pooling_type (str): The pooling type of context aggregation.
            Options are 'att', 'avg'. Default: 'att'.
        fusion_types (tuple[str]): The fusion type for feature fusion.
            Options are 'channel_add', 'channel_mul'. Default: ('channel_add',)
    """

    def __init__(self,
                 ratio=1 / 4.,
                 pooling_type='att',
                 fusion_types=('channel_add', ),
                 **kwargs):
        super().__init__(num_convs=2, **kwargs)
        self.ratio = ratio
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types
        self.gc_block = ContextBlock(
            in_channels=self.channels,
            ratio=self.ratio,
            pooling_type=self.pooling_type,
            fusion_types=self.fusion_types)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        output = self.convs[0](x)
        output = self.gc_block(output)
        output = self.convs[1](output)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output
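
The gc_block can also be exercised on its own; a minimal sketch with arbitrary sizes, using the same mmcv ContextBlock that the head wraps:

import torch
from mmcv.cnn import ContextBlock

gc = ContextBlock(
    in_channels=64,
    ratio=1 / 4.,
    pooling_type='att',
    fusion_types=('channel_add', ))
x = torch.rand(2, 64, 32, 32)
out = gc(x)  # one global context vector is added back per channel
assert out.shape == x.shape
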
255
finetune/mmseg/models/decode_heads/ham_head.py
Normal file
@@ -0,0 +1,255 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Originally from https://github.com/visual-attention-network/segnext
# Licensed under the Apache License, Version 2.0 (the "License")
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.device import get_device

from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead


class Matrix_Decomposition_2D_Base(nn.Module):
    """Base class of 2D Matrix Decomposition.

    Args:
        MD_S (int): The number of spatial coefficient in
            Matrix Decomposition, it may be used for calculation
            of the number of latent dimension D in Matrix
            Decomposition. Defaults: 1.
        MD_R (int): The number of latent dimension R in
            Matrix Decomposition. Defaults: 64.
        train_steps (int): The number of iteration steps in
            Multiplicative Update (MU) rule to solve Non-negative
            Matrix Factorization (NMF) in training. Defaults: 6.
        eval_steps (int): The number of iteration steps in
            Multiplicative Update (MU) rule to solve Non-negative
            Matrix Factorization (NMF) in evaluation. Defaults: 7.
        inv_t (int): Inverted multiple number to make coefficient
            smaller in softmax. Defaults: 100.
        rand_init (bool): Whether to initialize randomly.
            Defaults: True.
    """

    def __init__(self,
                 MD_S=1,
                 MD_R=64,
                 train_steps=6,
                 eval_steps=7,
                 inv_t=100,
                 rand_init=True):
        super().__init__()

        self.S = MD_S
        self.R = MD_R

        self.train_steps = train_steps
        self.eval_steps = eval_steps

        self.inv_t = inv_t

        self.rand_init = rand_init

    def _build_bases(self, B, S, D, R, device=None):
        raise NotImplementedError

    def local_step(self, x, bases, coef):
        raise NotImplementedError

    def local_inference(self, x, bases):
        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
        coef = torch.bmm(x.transpose(1, 2), bases)
        coef = F.softmax(self.inv_t * coef, dim=-1)

        steps = self.train_steps if self.training else self.eval_steps
        for _ in range(steps):
            bases, coef = self.local_step(x, bases, coef)

        return bases, coef

    def compute_coef(self, x, bases, coef):
        raise NotImplementedError

    def forward(self, x, return_bases=False):
        """Forward Function."""
        B, C, H, W = x.shape

        # (B, C, H, W) -> (B * S, D, N)
        D = C // self.S
        N = H * W
        x = x.view(B * self.S, D, N)
        if not self.rand_init and not hasattr(self, 'bases'):
            bases = self._build_bases(1, self.S, D, self.R, device=x.device)
            self.register_buffer('bases', bases)

        # (S, D, R) -> (B * S, D, R)
        if self.rand_init:
            bases = self._build_bases(B, self.S, D, self.R, device=x.device)
        else:
            bases = self.bases.repeat(B, 1, 1)

        bases, coef = self.local_inference(x, bases)

        # (B * S, N, R)
        coef = self.compute_coef(x, bases, coef)

        # (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N)
        x = torch.bmm(bases, coef.transpose(1, 2))

        # (B * S, D, N) -> (B, C, H, W)
        x = x.view(B, C, H, W)

        return x


class NMF2D(Matrix_Decomposition_2D_Base):
    """Non-negative Matrix Factorization (NMF) module.

    It is inherited from ``Matrix_Decomposition_2D_Base`` module.
    """

    def __init__(self, args=dict()):
        super().__init__(**args)

        self.inv_t = 1

    def _build_bases(self, B, S, D, R, device=None):
        """Build bases in initialization."""
        if device is None:
            device = get_device()
        bases = torch.rand((B * S, D, R)).to(device)
        bases = F.normalize(bases, dim=1)

        return bases

    def local_step(self, x, bases, coef):
        """Local step in iteration to renew bases and coefficient."""
        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
        numerator = torch.bmm(x.transpose(1, 2), bases)
        # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R)
        denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
        # Multiplicative Update
        coef = coef * numerator / (denominator + 1e-6)

        # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)
        numerator = torch.bmm(x, coef)
        # (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R)
        denominator = bases.bmm(coef.transpose(1, 2).bmm(coef))
        # Multiplicative Update
        bases = bases * numerator / (denominator + 1e-6)

        return bases, coef

    def compute_coef(self, x, bases, coef):
        """Compute coefficient."""
        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
        numerator = torch.bmm(x.transpose(1, 2), bases)
        # (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R)
        denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
        # multiplication update
        coef = coef * numerator / (denominator + 1e-6)

        return coef


class Hamburger(nn.Module):
    """Hamburger Module. It consists of one slice of "ham" (matrix
    decomposition) and two slices of "bread" (linear transformation).

    Args:
        ham_channels (int): Input and output channels of feature.
        ham_kwargs (dict): Config of matrix decomposition module.
        norm_cfg (dict | None): Config of norm layers.
    """

    def __init__(self,
                 ham_channels=512,
                 ham_kwargs=dict(),
                 norm_cfg=None,
                 **kwargs):
        super().__init__()

        self.ham_in = ConvModule(
            ham_channels, ham_channels, 1, norm_cfg=None, act_cfg=None)

        self.ham = NMF2D(ham_kwargs)

        self.ham_out = ConvModule(
            ham_channels, ham_channels, 1, norm_cfg=norm_cfg, act_cfg=None)

    def forward(self, x):
        enjoy = self.ham_in(x)
        enjoy = F.relu(enjoy, inplace=True)
        enjoy = self.ham(enjoy)
        enjoy = self.ham_out(enjoy)
        ham = F.relu(x + enjoy, inplace=True)

        return ham


@MODELS.register_module()
class LightHamHead(BaseDecodeHead):
    """SegNeXt decode head.

    This decode head is the implementation of `SegNeXt: Rethinking
    Convolutional Attention Design for Semantic
    Segmentation <https://arxiv.org/abs/2209.08575>`_.
    Inspiration from https://github.com/visual-attention-network/segnext.

    Specifically, LightHamHead is inspired by HamNet from
    `Is Attention Better Than Matrix Decomposition?
    <https://arxiv.org/abs/2109.04553>`.

    Args:
        ham_channels (int): input channels for Hamburger.
            Defaults: 512.
        ham_kwargs (dict): kwargs for Ham. Defaults: dict().
    """

    def __init__(self, ham_channels=512, ham_kwargs=dict(), **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)
        self.ham_channels = ham_channels

        self.squeeze = ConvModule(
            sum(self.in_channels),
            self.ham_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.hamburger = Hamburger(ham_channels, ham_kwargs, **kwargs)

        self.align = ConvModule(
            self.ham_channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        inputs = self._transform_inputs(inputs)

        inputs = [
            resize(
                level,
                size=inputs[0].shape[2:],
                mode='bilinear',
                align_corners=self.align_corners) for level in inputs
        ]

        inputs = torch.cat(inputs, dim=1)
        # apply a conv block to squeeze feature map
        x = self.squeeze(inputs)
        # apply hamburger module
        x = self.hamburger(x)

        # apply a conv block to align feature map
        output = self.align(x)
        output = self.cls_seg(output)
        return output
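
A minimal sketch of NMF2D by itself: the flattened feature map is approximated with MD_R non-negative bases and reconstructed. Sizes are illustrative; the SegNeXt configs use larger channels.

import torch

from mmseg.models.decode_heads.ham_head import NMF2D

nmf = NMF2D(dict(MD_R=16, train_steps=6, eval_steps=7))
nmf.eval()  # uses eval_steps MU iterations
x = torch.rand(2, 64, 32, 32)  # the ReLU in Hamburger keeps inputs non-negative
recon = nmf(x)  # rank-16 reconstruction, same shape as the input
assert recon.shape == x.shape
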
143
finetune/mmseg/models/decode_heads/isa_head.py
Normal file
@@ -0,0 +1,143 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from mmseg.registry import MODELS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .decode_head import BaseDecodeHead


class SelfAttentionBlock(_SelfAttentionBlock):
    """Self-Attention Module.

    Args:
        in_channels (int): Input channels of key/query feature.
        channels (int): Output channels of key/query transform.
        conv_cfg (dict | None): Config of conv layers.
        norm_cfg (dict | None): Config of norm layers.
        act_cfg (dict | None): Config of activation layers.
    """

    def __init__(self, in_channels, channels, conv_cfg, norm_cfg, act_cfg):
        super().__init__(
            key_in_channels=in_channels,
            query_in_channels=in_channels,
            channels=channels,
            out_channels=in_channels,
            share_key_query=False,
            query_downsample=None,
            key_downsample=None,
            key_query_num_convs=2,
            key_query_norm=True,
            value_out_num_convs=1,
            value_out_norm=False,
            matmul_norm=True,
            with_out=False,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        self.output_project = self.build_project(
            in_channels,
            in_channels,
            num_convs=1,
            use_conv_module=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    def forward(self, x):
        """Forward function."""
        context = super().forward(x, x)
        return self.output_project(context)


@MODELS.register_module()
class ISAHead(BaseDecodeHead):
    """Interlaced Sparse Self-Attention for Semantic Segmentation.

    This head is the implementation of `ISA
    <https://arxiv.org/abs/1907.12273>`_.

    Args:
        isa_channels (int): The channels of ISA Module.
        down_factor (tuple[int]): The local group size of ISA.
    """

    def __init__(self, isa_channels, down_factor=(8, 8), **kwargs):
        super().__init__(**kwargs)
        self.down_factor = down_factor

        self.in_conv = ConvModule(
            self.in_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.global_relation = SelfAttentionBlock(
            self.channels,
            isa_channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.local_relation = SelfAttentionBlock(
            self.channels,
            isa_channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.out_conv = ConvModule(
            self.channels * 2,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x_ = self._transform_inputs(inputs)
        x = self.in_conv(x_)
        residual = x

        n, c, h, w = x.size()
        loc_h, loc_w = self.down_factor  # size of local group in H- and W-axes
        glb_h, glb_w = math.ceil(h / loc_h), math.ceil(w / loc_w)
        pad_h, pad_w = glb_h * loc_h - h, glb_w * loc_w - w
        if pad_h > 0 or pad_w > 0:  # pad if the size is not divisible
            padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
                       pad_h - pad_h // 2)
            x = F.pad(x, padding)

        # global relation
        x = x.view(n, c, glb_h, loc_h, glb_w, loc_w)
        # do permutation to gather global group
        x = x.permute(0, 3, 5, 1, 2, 4)  # (n, loc_h, loc_w, c, glb_h, glb_w)
        x = x.reshape(-1, c, glb_h, glb_w)
        # apply attention within each global group
        x = self.global_relation(x)  # (n * loc_h * loc_w, c, glb_h, glb_w)

        # local relation
        x = x.view(n, loc_h, loc_w, c, glb_h, glb_w)
        # do permutation to gather local group
        x = x.permute(0, 4, 5, 3, 1, 2)  # (n, glb_h, glb_w, c, loc_h, loc_w)
        x = x.reshape(-1, c, loc_h, loc_w)
        # apply attention within each local group
        x = self.local_relation(x)  # (n * glb_h * glb_w, c, loc_h, loc_w)

        # permute each pixel back to its original position
        x = x.view(n, glb_h, glb_w, c, loc_h, loc_w)
        x = x.permute(0, 3, 1, 4, 2, 5)  # (n, c, glb_h, loc_h, glb_w, loc_w)
        x = x.reshape(n, c, glb_h * loc_h, glb_w * loc_w)
        if pad_h > 0 or pad_w > 0:  # remove padding
            x = x[:, :, pad_h // 2:pad_h // 2 + h, pad_w // 2:pad_w // 2 + w]

        x = self.out_conv(torch.cat([x, residual], dim=1))
        out = self.cls_seg(x)

        return out
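
The interlacing permutations are the heart of this head. A minimal, self-contained sketch showing that the global-group gather is an exact, invertible rearrangement (the sizes are arbitrary):

import torch

n, c = 1, 3
loc_h, loc_w, glb_h, glb_w = 2, 2, 4, 4  # 8x8 map, 2x2 local groups
x = torch.arange(n * c * 8 * 8).float().view(n, c, 8, 8)
# Gather "global" groups: pixels loc_h (resp. loc_w) apart land in the
# same sub-grid, so one attention pass sees long-range pairs.
g = x.view(n, c, glb_h, loc_h, glb_w, loc_w)
g = g.permute(0, 3, 5, 1, 2, 4).reshape(-1, c, glb_h, glb_w)
# Invert the permutation: every pixel returns to its original position.
back = g.view(n, loc_h, loc_w, c, glb_h, glb_w)
back = back.permute(0, 3, 4, 1, 5, 2).reshape(n, c, 8, 8)
assert torch.equal(back, x)
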
461
finetune/mmseg/models/decode_heads/knet_head.py
Normal file
@@ -0,0 +1,461 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from mmcv.cnn.bricks.transformer import (FFN, MultiheadAttention,
                                         build_transformer_layer)
from mmengine.logging import print_log
from torch import Tensor

from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.registry import MODELS
from mmseg.utils import SampleList


@MODELS.register_module()
class KernelUpdator(nn.Module):
    """Dynamic Kernel Updator in Kernel Update Head.

    Args:
        in_channels (int): The number of channels of input feature map.
            Default: 256.
        feat_channels (int): The number of middle-stage channels in
            the kernel updator. Default: 64.
        out_channels (int): The number of output channels.
        gate_sigmoid (bool): Whether use sigmoid function in gate
            mechanism. Default: True.
        gate_norm_act (bool): Whether add normalization and activation
            layer in gate mechanism. Default: False.
        activate_out: Whether add activation after gate mechanism.
            Default: False.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='LN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
    """

    def __init__(
        self,
        in_channels=256,
        feat_channels=64,
        out_channels=None,
        gate_sigmoid=True,
        gate_norm_act=False,
        activate_out=False,
        norm_cfg=dict(type='LN'),
        act_cfg=dict(type='ReLU', inplace=True),
    ):
        super().__init__()
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.out_channels_raw = out_channels
        self.gate_sigmoid = gate_sigmoid
        self.gate_norm_act = gate_norm_act
        self.activate_out = activate_out
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.out_channels = out_channels if out_channels else in_channels

        self.num_params_in = self.feat_channels
        self.num_params_out = self.feat_channels
        self.dynamic_layer = nn.Linear(
            self.in_channels, self.num_params_in + self.num_params_out)
        self.input_layer = nn.Linear(self.in_channels,
                                     self.num_params_in + self.num_params_out,
                                     1)
        self.input_gate = nn.Linear(self.in_channels, self.feat_channels, 1)
        self.update_gate = nn.Linear(self.in_channels, self.feat_channels, 1)
        if self.gate_norm_act:
            self.gate_norm = build_norm_layer(norm_cfg, self.feat_channels)[1]

        self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
        self.norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]
        self.input_norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
        self.input_norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]

        self.activation = build_activation_layer(act_cfg)

        self.fc_layer = nn.Linear(self.feat_channels, self.out_channels, 1)
        self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]

    def forward(self, update_feature, input_feature):
        """Forward function of KernelUpdator.

        Args:
            update_feature (torch.Tensor): Feature map assembled from
                each group. It would be reshaped with last dimension
                shape: `self.in_channels`.
            input_feature (torch.Tensor): Intermediate feature
                with shape: (N, num_classes, conv_kernel_size**2, channels).
        Returns:
            Tensor: The output tensor of shape (N*C1/C2, K*K, C2), where N is
            the number of classes, C1 and C2 are the feature map channels of
            KernelUpdateHead and KernelUpdator, respectively.
        """

        update_feature = update_feature.reshape(-1, self.in_channels)
        num_proposals = update_feature.size(0)
        # dynamic_layer works for
        # phi_1 and psi_3 in Eq.(4) and (5) of K-Net paper
        parameters = self.dynamic_layer(update_feature)
        param_in = parameters[:, :self.num_params_in].view(
            -1, self.feat_channels)
        param_out = parameters[:, -self.num_params_out:].view(
            -1, self.feat_channels)

        # input_layer works for
        # phi_2 and psi_4 in Eq.(4) and (5) of K-Net paper
        input_feats = self.input_layer(
            input_feature.reshape(num_proposals, -1, self.feat_channels))
        input_in = input_feats[..., :self.num_params_in]
        input_out = input_feats[..., -self.num_params_out:]

        # `gate_feats` is F^G in K-Net paper
        gate_feats = input_in * param_in.unsqueeze(-2)
        if self.gate_norm_act:
            gate_feats = self.activation(self.gate_norm(gate_feats))

        input_gate = self.input_norm_in(self.input_gate(gate_feats))
        update_gate = self.norm_in(self.update_gate(gate_feats))
        if self.gate_sigmoid:
            input_gate = input_gate.sigmoid()
            update_gate = update_gate.sigmoid()
        param_out = self.norm_out(param_out)
        input_out = self.input_norm_out(input_out)

        if self.activate_out:
            param_out = self.activation(param_out)
            input_out = self.activation(input_out)

        # Gate mechanism. Eq.(5) in original paper.
        # param_out has shape (batch_size, feat_channels, out_channels)
        features = update_gate * param_out.unsqueeze(
            -2) + input_gate * input_out

        features = self.fc_layer(features)
        features = self.fc_norm(features)
        features = self.activation(features)

        return features

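
A minimal sketch of one updater step, with all channels set to 64 purely for illustration (the K-Net configs use larger, asymmetric channels):

import torch

from mmseg.models.decode_heads.knet_head import KernelUpdator

updator = KernelUpdator(in_channels=64, feat_channels=64, out_channels=64)
num_proposals = 5
group_feats = torch.rand(num_proposals, 64)  # assembled per-proposal feature
kernels = torch.rand(num_proposals, 1, 64)   # K*K = 1 kernel entries
new_kernels = updator(group_feats, kernels)  # gated, adaptively updated
assert new_kernels.shape == (num_proposals, 1, 64)
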
@MODELS.register_module()
class KernelUpdateHead(nn.Module):
    """Kernel Update Head in K-Net.

    Args:
        num_classes (int): Number of classes. Default: 150.
        num_ffn_fcs (int): The number of fully-connected layers in
            FFNs. Default: 2.
        num_heads (int): The number of parallel attention heads.
            Default: 8.
        num_mask_fcs (int): The number of fully connected layers for
            mask prediction. Default: 3.
        feedforward_channels (int): The hidden dimension of FFNs.
            Defaults: 2048.
        in_channels (int): The number of channels of input feature map.
            Default: 256.
        out_channels (int): The number of output channels.
            Default: 256.
        dropout (float): The probability of an element to be
            zeroed in MultiheadAttention and FFN. Default 0.0.
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        ffn_act_cfg (dict): Config of activation layers in FFN.
            Default: dict(type='ReLU').
        conv_kernel_size (int): The kernel size of convolution in
            Kernel Update Head for dynamic kernel updates.
            Default: 1.
        feat_transform_cfg (dict | None): Config of feature transform.
            Default: None.
        kernel_init (bool): Whether initiate mask kernel in mask head.
            Default: False.
        with_ffn (bool): Whether add FFN in kernel update head.
            Default: True.
        feat_gather_stride (int): Stride of convolution in feature transform.
            Default: 1.
        mask_transform_stride (int): Stride of mask transform.
            Default: 1.
        kernel_updator_cfg (dict): Config of kernel updator.
            Default: dict(
                type='DynamicConv',
                in_channels=256,
                feat_channels=64,
                out_channels=256,
                act_cfg=dict(type='ReLU', inplace=True),
                norm_cfg=dict(type='LN')).
    """

    def __init__(self,
                 num_classes=150,
                 num_ffn_fcs=2,
                 num_heads=8,
                 num_mask_fcs=3,
                 feedforward_channels=2048,
                 in_channels=256,
                 out_channels=256,
                 dropout=0.0,
                 act_cfg=dict(type='ReLU', inplace=True),
                 ffn_act_cfg=dict(type='ReLU', inplace=True),
                 conv_kernel_size=1,
                 feat_transform_cfg=None,
                 kernel_init=False,
                 with_ffn=True,
                 feat_gather_stride=1,
                 mask_transform_stride=1,
                 kernel_updator_cfg=dict(
                     type='DynamicConv',
                     in_channels=256,
                     feat_channels=64,
                     out_channels=256,
                     act_cfg=dict(type='ReLU', inplace=True),
                     norm_cfg=dict(type='LN'))):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.fp16_enabled = False
        self.dropout = dropout
        self.num_heads = num_heads
        self.kernel_init = kernel_init
        self.with_ffn = with_ffn
        self.conv_kernel_size = conv_kernel_size
        self.feat_gather_stride = feat_gather_stride
        self.mask_transform_stride = mask_transform_stride

        self.attention = MultiheadAttention(in_channels * conv_kernel_size**2,
                                            num_heads, dropout)
        self.attention_norm = build_norm_layer(
            dict(type='LN'), in_channels * conv_kernel_size**2)[1]
        self.kernel_update_conv = build_transformer_layer(kernel_updator_cfg)

        if feat_transform_cfg is not None:
            kernel_size = feat_transform_cfg.pop('kernel_size', 1)
            transform_channels = in_channels
            self.feat_transform = ConvModule(
                transform_channels,
                in_channels,
                kernel_size,
                stride=feat_gather_stride,
                padding=int(feat_gather_stride // 2),
                **feat_transform_cfg)
        else:
            self.feat_transform = None

        if self.with_ffn:
            self.ffn = FFN(
                in_channels,
                feedforward_channels,
                num_ffn_fcs,
                act_cfg=ffn_act_cfg,
                dropout=dropout)
            self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1]

        self.mask_fcs = nn.ModuleList()
        for _ in range(num_mask_fcs):
            self.mask_fcs.append(
                nn.Linear(in_channels, in_channels, bias=False))
            self.mask_fcs.append(
                build_norm_layer(dict(type='LN'), in_channels)[1])
            self.mask_fcs.append(build_activation_layer(act_cfg))

        self.fc_mask = nn.Linear(in_channels, out_channels)

    def init_weights(self):
        """Use xavier initialization for all weight parameter and set
        classification head bias as a specific value when use focal loss."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
            else:
                # adopt the default initialization for
                # the weight and bias of the layer norm
                pass
        if self.kernel_init:
            print_log(
                'mask kernel in mask head is normal initialized by std 0.01')
            nn.init.normal_(self.fc_mask.weight, mean=0, std=0.01)

    def forward(self, x, proposal_feat, mask_preds, mask_shape=None):
        """Forward function of Dynamic Instance Interactive Head.

        Args:
            x (Tensor): Feature map from FPN with shape
                (batch_size, feature_dimensions, H , W).
            proposal_feat (Tensor): Intermediate feature get from
                diihead in last stage, has shape
                (batch_size, num_proposals, feature_dimensions)
            mask_preds (Tensor): mask prediction from the former stage in shape
                (batch_size, num_proposals, H, W).

        Returns:
            Tuple: The first tensor is predicted mask with shape
            (N, num_classes, H, W), the second tensor is dynamic kernel
            with shape (N, num_classes, channels, K, K).
        """
        N, num_proposals = proposal_feat.shape[:2]
        if self.feat_transform is not None:
            x = self.feat_transform(x)

        C, H, W = x.shape[-3:]

        mask_h, mask_w = mask_preds.shape[-2:]
        if mask_h != H or mask_w != W:
            gather_mask = F.interpolate(
                mask_preds, (H, W), align_corners=False, mode='bilinear')
        else:
            gather_mask = mask_preds

        sigmoid_masks = gather_mask.softmax(dim=1)

        # Group Feature Assembling. Eq.(3) in original paper.
        # einsum is faster than bmm by 30%
        x_feat = torch.einsum('bnhw,bchw->bnc', sigmoid_masks, x)

        # obj_feat in shape [B, N, C, K, K] -> [B, N, C, K*K] -> [B, N, K*K, C]
        proposal_feat = proposal_feat.reshape(N, num_proposals,
                                              self.in_channels,
                                              -1).permute(0, 1, 3, 2)
        obj_feat = self.kernel_update_conv(x_feat, proposal_feat)

        # [B, N, K*K, C] -> [B, N, K*K*C] -> [N, B, K*K*C]
        obj_feat = obj_feat.reshape(N, num_proposals, -1).permute(1, 0, 2)
        obj_feat = self.attention_norm(self.attention(obj_feat))
        # [N, B, K*K*C] -> [B, N, K*K*C]
        obj_feat = obj_feat.permute(1, 0, 2)

        # obj_feat in shape [B, N, K*K*C] -> [B, N, K*K, C]
        obj_feat = obj_feat.reshape(N, num_proposals, -1, self.in_channels)

        # FFN
        if self.with_ffn:
            obj_feat = self.ffn_norm(self.ffn(obj_feat))

        mask_feat = obj_feat

        for reg_layer in self.mask_fcs:
            mask_feat = reg_layer(mask_feat)

        # [B, N, K*K, C] -> [B, N, C, K*K]
        mask_feat = self.fc_mask(mask_feat).permute(0, 1, 3, 2)

        if (self.mask_transform_stride == 2 and self.feat_gather_stride == 1):
            mask_x = F.interpolate(
                x, scale_factor=0.5, mode='bilinear', align_corners=False)
            H, W = mask_x.shape[-2:]
        else:
            mask_x = x
        # group conv is 5x faster than unfold and uses about 1/5 memory
        # Group conv vs. unfold vs. concat batch, 2.9ms :13.5ms :3.8ms
        # Group conv vs. unfold vs. concat batch, 278 : 1420 : 369
        # but in real training group conv is slower than concat batch
        # so we keep using concat batch.
        # fold_x = F.unfold(
        #     mask_x,
        #     self.conv_kernel_size,
        #     padding=int(self.conv_kernel_size // 2))
        # mask_feat = mask_feat.reshape(N, num_proposals, -1)
        # new_mask_preds = torch.einsum('bnc,bcl->bnl', mask_feat, fold_x)
        # [B, N, C, K*K] -> [B*N, C, K, K]
        mask_feat = mask_feat.reshape(N, num_proposals, C,
                                      self.conv_kernel_size,
                                      self.conv_kernel_size)
        # [B, C, H, W] -> [1, B*C, H, W]
        new_mask_preds = []
        for i in range(N):
            new_mask_preds.append(
                F.conv2d(
                    mask_x[i:i + 1],
                    mask_feat[i],
                    padding=int(self.conv_kernel_size // 2)))

        new_mask_preds = torch.cat(new_mask_preds, dim=0)
        new_mask_preds = new_mask_preds.reshape(N, num_proposals, H, W)
        if self.mask_transform_stride == 2:
            new_mask_preds = F.interpolate(
                new_mask_preds,
                scale_factor=2,
                mode='bilinear',
                align_corners=False)

        if mask_shape is not None and mask_shape[0] != H:
            new_mask_preds = F.interpolate(
                new_mask_preds,
                mask_shape,
                align_corners=False,
                mode='bilinear')

        return new_mask_preds, obj_feat.permute(0, 1, 3, 2).reshape(
            N, num_proposals, self.in_channels, self.conv_kernel_size,
            self.conv_kernel_size)


@MODELS.register_module()
class IterativeDecodeHead(BaseDecodeHead):
    """K-Net: Towards Unified Image Segmentation.

    This head is the implementation of
    `K-Net: <https://arxiv.org/abs/2106.14855>`_.

    Args:
        num_stages (int): The number of stages (kernel update heads)
            in IterativeDecodeHead. Default: 3.
        kernel_generate_head (dict): Config of kernel generate head which
            generates mask predictions, dynamic kernels and class predictions
            for next kernel update heads.
        kernel_update_head (dict): Config of kernel update head which refines
            dynamic kernels and class predictions iteratively.

    """

    def __init__(self, num_stages, kernel_generate_head, kernel_update_head,
|
||||
**kwargs):
|
||||
# ``IterativeDecodeHead`` would skip initialization of
|
||||
# ``BaseDecodeHead`` which would be called when building
|
||||
# ``self.kernel_generate_head``.
|
||||
super(BaseDecodeHead, self).__init__(**kwargs)
|
||||
assert num_stages == len(kernel_update_head)
|
||||
self.num_stages = num_stages
|
||||
self.kernel_generate_head = MODELS.build(kernel_generate_head)
|
||||
self.kernel_update_head = nn.ModuleList()
|
||||
self.align_corners = self.kernel_generate_head.align_corners
|
||||
self.num_classes = self.kernel_generate_head.num_classes
|
||||
self.input_transform = self.kernel_generate_head.input_transform
|
||||
self.ignore_index = self.kernel_generate_head.ignore_index
|
||||
self.out_channels = self.num_classes
|
||||
|
||||
for head_cfg in kernel_update_head:
|
||||
self.kernel_update_head.append(MODELS.build(head_cfg))
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
feats = self.kernel_generate_head._forward_feature(inputs)
|
||||
sem_seg = self.kernel_generate_head.cls_seg(feats)
|
||||
seg_kernels = self.kernel_generate_head.conv_seg.weight.clone()
|
||||
seg_kernels = seg_kernels[None].expand(
|
||||
feats.size(0), *seg_kernels.size())
|
||||
|
||||
stage_segs = [sem_seg]
|
||||
for i in range(self.num_stages):
|
||||
sem_seg, seg_kernels = self.kernel_update_head[i](feats,
|
||||
seg_kernels,
|
||||
sem_seg)
|
||||
stage_segs.append(sem_seg)
|
||||
if self.training:
|
||||
return stage_segs
|
||||
# only return the prediction of the last stage during testing
|
||||
return stage_segs[-1]
|
||||
|
||||
def loss_by_feat(self, seg_logits: List[Tensor],
|
||||
batch_data_samples: SampleList, **kwargs) -> dict:
|
||||
losses = dict()
|
||||
for i, logit in enumerate(seg_logits):
|
||||
loss = self.kernel_generate_head.loss_by_feat(
|
||||
logit, batch_data_samples)
|
||||
for k, v in loss.items():
|
||||
losses[f'{k}.s{i}'] = v
|
||||
|
||||
return losses
|
||||
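Note: a minimal sketch of how this iterative head is typically wired up in a
config, for review convenience only (standard mmseg config semantics assumed;
the channel counts and the FCNHead choice for ``kernel_generate_head`` are
illustrative, not taken from this commit):

knet_decode_head = dict(
    type='IterativeDecodeHead',
    num_stages=3,
    kernel_generate_head=dict(
        type='FCNHead',
        in_channels=512,
        channels=512,
        num_classes=150),
    # one KernelUpdateHead config per stage
    kernel_update_head=[
        dict(
            type='KernelUpdateHead',
            num_classes=150,
            num_heads=8,
            in_channels=512,
            out_channels=512,
            conv_kernel_size=1) for _ in range(3)
    ])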
91
finetune/mmseg/models/decode_heads/lraspp_head.py
Normal file
@@ -0,0 +1,91 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.utils import is_tuple_of

from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead


@MODELS.register_module()
class LRASPPHead(BaseDecodeHead):
    """Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3.

    This head is the improved implementation of `Searching for MobileNetV3
    <https://ieeexplore.ieee.org/document/9008835>`_.

    Args:
        branch_channels (tuple[int]): The number of output channels in each
            branch. Default: (32, 64).
    """

    def __init__(self, branch_channels=(32, 64), **kwargs):
        super().__init__(**kwargs)
        if self.input_transform != 'multiple_select':
            raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform '
                             f'must be \'multiple_select\'. But received '
                             f'\'{self.input_transform}\'')
        assert is_tuple_of(branch_channels, int)
        assert len(branch_channels) == len(self.in_channels) - 1
        self.branch_channels = branch_channels

        self.convs = nn.Sequential()
        self.conv_ups = nn.Sequential()
        for i in range(len(branch_channels)):
            self.convs.add_module(
                f'conv{i}',
                nn.Conv2d(
                    self.in_channels[i], branch_channels[i], 1, bias=False))
            self.conv_ups.add_module(
                f'conv_up{i}',
                ConvModule(
                    self.channels + branch_channels[i],
                    self.channels,
                    1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    bias=False))

        self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1)

        self.aspp_conv = ConvModule(
            self.in_channels[-1],
            self.channels,
            1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            bias=False)
        self.image_pool = nn.Sequential(
            nn.AvgPool2d(kernel_size=49, stride=(16, 20)),
            ConvModule(
                self.in_channels[2],
                self.channels,
                1,
                act_cfg=dict(type='Sigmoid'),
                bias=False))

    def forward(self, inputs):
        """Forward function."""
        inputs = self._transform_inputs(inputs)

        x = inputs[-1]

        x = self.aspp_conv(x) * resize(
            self.image_pool(x),
            size=x.size()[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        x = self.conv_up_input(x)

        for i in range(len(self.branch_channels) - 1, -1, -1):
            x = resize(
                x,
                size=inputs[i].size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            x = torch.cat([x, self.convs[i](inputs[i])], 1)
            x = self.conv_ups[i](x)

        return self.cls_seg(x)
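Note: a minimal config sketch for wiring this head to a MobileNetV3-style
backbone (standard mmseg semantics assumed; the in_channels/in_index values
are illustrative, not taken from this commit):

lraspp_decode_head = dict(
    type='LRASPPHead',
    in_channels=(16, 24, 960),          # three selected backbone stages
    in_index=(0, 1, 2),
    channels=128,
    input_transform='multiple_select',  # required by this head
    branch_channels=(32, 64),
    num_classes=19,
    norm_cfg=dict(type='BN'),
    align_corners=False)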
163
finetune/mmseg/models/decode_heads/mask2former_head.py
Normal file
@@ -0,0 +1,163 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule

try:
    from mmdet.models.dense_heads import \
        Mask2FormerHead as MMDET_Mask2FormerHead
except ModuleNotFoundError:
    MMDET_Mask2FormerHead = BaseModule

from mmengine.structures import InstanceData
from torch import Tensor

from mmseg.registry import MODELS
from mmseg.structures.seg_data_sample import SegDataSample
from mmseg.utils import ConfigType, SampleList


@MODELS.register_module()
class Mask2FormerHead(MMDET_Mask2FormerHead):
    """Implements the Mask2Former head.

    See `Mask2Former: Masked-attention Mask Transformer for Universal Image
    Segmentation <https://arxiv.org/abs/2112.01527>`_ for details.

    Args:
        num_classes (int): Number of classes. Default: 150.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        ignore_index (int): The label index to be ignored. Default: 255.
    """

    def __init__(self,
                 num_classes,
                 align_corners=False,
                 ignore_index=255,
                 **kwargs):
        super().__init__(**kwargs)

        self.num_classes = num_classes
        self.align_corners = align_corners
        self.out_channels = num_classes
        self.ignore_index = ignore_index

        feat_channels = kwargs['feat_channels']
        self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)

    def _seg_data_to_instance_data(self, batch_data_samples: SampleList):
        """Convert data from the MMSegmentation paradigm to the MMDetection
        paradigm so that ``MMDET_Mask2FormerHead`` can be called normally.
        Specifically, ``batch_gt_instances`` would be added.

        Args:
            batch_data_samples (List[:obj:`SegDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_sem_seg`.

        Returns:
            tuple[Tensor]: A tuple contains two lists.

                - batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                    gt_instance. It usually includes ``labels``, each is a
                    unique ground truth label id of the image, with
                    shape (num_gt, ), and ``masks``, each is the ground truth
                    mask of each instance of an image, shape (num_gt, h, w).
                - batch_img_metas (list[dict]): List of image meta information.
        """
        batch_img_metas = []
        batch_gt_instances = []

        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            gt_sem_seg = data_sample.gt_sem_seg.data
            classes = torch.unique(
                gt_sem_seg,
                sorted=False,
                return_inverse=False,
                return_counts=False)

            # remove ignored region
            gt_labels = classes[classes != self.ignore_index]

            masks = []
            for class_id in gt_labels:
                masks.append(gt_sem_seg == class_id)

            if len(masks) == 0:
                gt_masks = torch.zeros(
                    (0, gt_sem_seg.shape[-2],
                     gt_sem_seg.shape[-1])).to(gt_sem_seg).long()
            else:
                gt_masks = torch.stack(masks).squeeze(1).long()

            instance_data = InstanceData(labels=gt_labels, masks=gt_masks)
            batch_gt_instances.append(instance_data)
        return batch_gt_instances, batch_img_metas

    def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList,
             train_cfg: ConfigType) -> dict:
        """Perform forward propagation and loss calculation of the decoder
        head on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Multi-level features from the upstream
                network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`SegDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_sem_seg`.
            train_cfg (ConfigType): Training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components.
        """
        # batch SegDataSample to InstanceDataSample
        batch_gt_instances, batch_img_metas = self._seg_data_to_instance_data(
            batch_data_samples)

        # forward
        all_cls_scores, all_mask_preds = self(x, batch_data_samples)

        # loss
        losses = self.loss_by_feat(all_cls_scores, all_mask_preds,
                                   batch_gt_instances, batch_img_metas)

        return losses

    def predict(self, x: Tuple[Tensor], batch_img_metas: List[dict],
                test_cfg: ConfigType) -> Tuple[Tensor]:
        """Test without augmentation.

        Args:
            x (tuple[Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            batch_img_metas (List[dict]): List of image meta information.
            test_cfg (ConfigType): Test config.

        Returns:
            Tensor: A tensor of segmentation mask.
        """
        batch_data_samples = [
            SegDataSample(metainfo=metainfo) for metainfo in batch_img_metas
        ]

        all_cls_scores, all_mask_preds = self(x, batch_data_samples)
        mask_cls_results = all_cls_scores[-1]
        mask_pred_results = all_mask_preds[-1]
        if 'pad_shape' in batch_img_metas[0]:
            size = batch_img_metas[0]['pad_shape']
        else:
            size = batch_img_metas[0]['img_shape']
        # upsample mask
        mask_pred_results = F.interpolate(
            mask_pred_results, size=size, mode='bilinear', align_corners=False)
        cls_score = F.softmax(mask_cls_results, dim=-1)[..., :-1]
        mask_pred = mask_pred_results.sigmoid()
        seg_logits = torch.einsum('bqc,bqhw->bchw', cls_score, mask_pred)
        return seg_logits
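Note: the final einsum in ``predict`` is the standard MaskFormer-style
semantic inference: per-query class probabilities weight per-query mask
probabilities. A standalone sketch of the same computation (shapes are
illustrative):

import torch
import torch.nn.functional as F

B, Q, C, H, W = 2, 100, 150, 128, 128
mask_cls = torch.randn(B, Q, C + 1)    # per-query class logits (+1 no-object)
mask_pred = torch.randn(B, Q, H, W)    # per-query mask logits

cls_score = F.softmax(mask_cls, dim=-1)[..., :-1]  # drop the no-object slot
seg_logits = torch.einsum('bqc,bqhw->bchw', cls_score, mask_pred.sigmoid())
assert seg_logits.shape == (B, C, H, W)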
174
finetune/mmseg/models/decode_heads/maskformer_head.py
Normal file
@@ -0,0 +1,174 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule

try:
    from mmdet.models.dense_heads import MaskFormerHead as MMDET_MaskFormerHead
except ModuleNotFoundError:
    MMDET_MaskFormerHead = BaseModule

from mmengine.structures import InstanceData
from torch import Tensor

from mmseg.registry import MODELS
from mmseg.structures.seg_data_sample import SegDataSample
from mmseg.utils import ConfigType, SampleList


@MODELS.register_module()
class MaskFormerHead(MMDET_MaskFormerHead):
    """Implements the MaskFormer head.

    See `Per-Pixel Classification is Not All You Need for Semantic
    Segmentation <https://arxiv.org/pdf/2107.06278>`_ for details.

    Args:
        num_classes (int): Number of classes. Default: 150.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        ignore_index (int): The label index to be ignored. Default: 255.
    """

    def __init__(self,
                 num_classes: int = 150,
                 align_corners: bool = False,
                 ignore_index: int = 255,
                 **kwargs) -> None:
        super().__init__(**kwargs)

        self.num_classes = num_classes
        self.align_corners = align_corners
        self.out_channels = num_classes
        self.ignore_index = ignore_index

        feat_channels = kwargs['feat_channels']
        self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)

    def _seg_data_to_instance_data(self, batch_data_samples: SampleList):
        """Convert data from the MMSegmentation paradigm to the MMDetection
        paradigm so that ``MMDET_MaskFormerHead`` can be called normally.
        Specifically, ``batch_gt_instances`` would be added.

        Args:
            batch_data_samples (List[:obj:`SegDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_sem_seg`.

        Returns:
            tuple[Tensor]: A tuple contains two lists.

                - batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                    gt_instance. It usually includes ``labels``, each is a
                    unique ground truth label id of the image, with
                    shape (num_gt, ), and ``masks``, each is the ground truth
                    mask of each instance of an image, shape (num_gt, h, w).
                - batch_img_metas (list[dict]): List of image meta information.
        """
        batch_img_metas = []
        batch_gt_instances = []
        for data_sample in batch_data_samples:
            # Add `batch_input_shape` in metainfo of data_sample, which would
            # be used in MaskFormerHead of MMDetection.
            metainfo = data_sample.metainfo
            metainfo['batch_input_shape'] = metainfo['img_shape']
            data_sample.set_metainfo(metainfo)
            batch_img_metas.append(data_sample.metainfo)
            gt_sem_seg = data_sample.gt_sem_seg.data
            classes = torch.unique(
                gt_sem_seg,
                sorted=False,
                return_inverse=False,
                return_counts=False)

            # remove ignored region
            gt_labels = classes[classes != self.ignore_index]

            masks = []
            for class_id in gt_labels:
                masks.append(gt_sem_seg == class_id)

            if len(masks) == 0:
                gt_masks = torch.zeros((0, gt_sem_seg.shape[-2],
                                        gt_sem_seg.shape[-1])).to(gt_sem_seg)
            else:
                gt_masks = torch.stack(masks).squeeze(1)

            instance_data = InstanceData(
                labels=gt_labels, masks=gt_masks.long())
            batch_gt_instances.append(instance_data)
        return batch_gt_instances, batch_img_metas

    def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList,
             train_cfg: ConfigType) -> dict:
        """Perform forward propagation and loss calculation of the decoder
        head on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Multi-level features from the upstream
                network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`SegDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_sem_seg`.
            train_cfg (ConfigType): Training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components.
        """
        # batch SegDataSample to InstanceDataSample
        batch_gt_instances, batch_img_metas = self._seg_data_to_instance_data(
            batch_data_samples)

        # forward
        all_cls_scores, all_mask_preds = self(x, batch_data_samples)

        # loss
        losses = self.loss_by_feat(all_cls_scores, all_mask_preds,
                                   batch_gt_instances, batch_img_metas)

        return losses

    def predict(self, x: Tuple[Tensor], batch_img_metas: List[dict],
                test_cfg: ConfigType) -> Tuple[Tensor]:
        """Test without augmentation.

        Args:
            x (tuple[Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            batch_img_metas (List[dict]): List of image meta information.
            test_cfg (ConfigType): Test config.

        Returns:
            Tensor: A tensor of segmentation mask.
        """
        batch_data_samples = []
        for metainfo in batch_img_metas:
            metainfo['batch_input_shape'] = metainfo['img_shape']
            batch_data_samples.append(SegDataSample(metainfo=metainfo))
        # The forward function of MaskFormerHead from MMDetection needs
        # 'batch_data_samples' as inputs, which here only carries the
        # image shapes.
        all_cls_scores, all_mask_preds = self(x, batch_data_samples)
        mask_cls_results = all_cls_scores[-1]
        mask_pred_results = all_mask_preds[-1]

        # upsample masks
        img_shape = batch_img_metas[0]['batch_input_shape']
        mask_pred_results = F.interpolate(
            mask_pred_results,
            size=img_shape,
            mode='bilinear',
            align_corners=False)

        # semantic inference
        cls_score = F.softmax(mask_cls_results, dim=-1)[..., :-1]
        mask_pred = mask_pred_results.sigmoid()
        seg_logits = torch.einsum('bqc,bqhw->bchw', cls_score, mask_pred)
        return seg_logits
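Note: the per-class mask extraction in ``_seg_data_to_instance_data`` is easy
to sanity-check in isolation. A small standalone sketch (values illustrative):

import torch

ignore_index = 255
gt_sem_seg = torch.tensor([[0, 0, 2], [2, 2, 255]])   # H x W label map

classes = torch.unique(gt_sem_seg)
gt_labels = classes[classes != ignore_index]          # tensor([0, 2])
gt_masks = torch.stack([gt_sem_seg == c for c in gt_labels]).long()
# gt_masks: (num_gt, H, W) binary masks, one per ground-truth class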
50
finetune/mmseg/models/decode_heads/nl_head.py
Normal file
@@ -0,0 +1,50 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn import NonLocal2d

from mmseg.registry import MODELS
from .fcn_head import FCNHead


@MODELS.register_module()
class NLHead(FCNHead):
    """Non-local Neural Networks.

    This head is the implementation of `NLNet
    <https://arxiv.org/abs/1711.07971>`_.

    Args:
        reduction (int): Reduction factor of projection transform. Default: 2.
        use_scale (bool): Whether to scale pairwise_weight by
            sqrt(1/inter_channels). Default: True.
        mode (str): The nonlocal mode. Options are 'embedded_gaussian',
            'dot_product'. Default: 'embedded_gaussian'.
    """

    def __init__(self,
                 reduction=2,
                 use_scale=True,
                 mode='embedded_gaussian',
                 **kwargs):
        super().__init__(num_convs=2, **kwargs)
        self.reduction = reduction
        self.use_scale = use_scale
        self.mode = mode
        self.nl_block = NonLocal2d(
            in_channels=self.channels,
            reduction=self.reduction,
            use_scale=self.use_scale,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            mode=self.mode)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        output = self.convs[0](x)
        output = self.nl_block(output)
        output = self.convs[1](output)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output
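Note: the embedded-Gaussian non-local block computes softmax attention over
all spatial positions. A minimal standalone sketch of the core operation,
ignoring the 1x1 embedding convs that NonLocal2d wraps around it (shapes
illustrative):

import torch
import torch.nn.functional as F

n, c, h, w = 1, 16, 8, 8
x = torch.randn(n, c, h, w)
theta = x.view(n, c, -1).permute(0, 2, 1)        # queries: (n, hw, c)
phi = x.view(n, c, -1)                           # keys:    (n, c, hw)
g = x.view(n, c, -1).permute(0, 2, 1)            # values:  (n, hw, c)

attn = F.softmax(theta @ phi / c**0.5, dim=-1)   # (n, hw, hw), scaled
y = (attn @ g).permute(0, 2, 1).view(n, c, h, w)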
127
finetune/mmseg/models/decode_heads/ocr_head.py
Normal file
@@ -0,0 +1,127 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from mmseg.registry import MODELS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from ..utils import resize
from .cascade_decode_head import BaseCascadeDecodeHead


class SpatialGatherModule(nn.Module):
    """Aggregate the context features according to the initial predicted
    probability distribution.

    Employ the soft-weighted method to aggregate the context.
    """

    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, feats, probs):
        """Forward function."""
        batch_size, num_classes, height, width = probs.size()
        channels = feats.size(1)
        probs = probs.view(batch_size, num_classes, -1)
        feats = feats.view(batch_size, channels, -1)
        # [batch_size, height*width, channels]
        feats = feats.permute(0, 2, 1)
        # [batch_size, num_classes, height*width]
        probs = F.softmax(self.scale * probs, dim=2)
        # [batch_size, num_classes, channels]
        ocr_context = torch.matmul(probs, feats)
        ocr_context = ocr_context.permute(0, 2, 1).contiguous().unsqueeze(3)
        return ocr_context


class ObjectAttentionBlock(_SelfAttentionBlock):
    """A SelfAttentionBlock variant used by OCR."""

    def __init__(self, in_channels, channels, scale, conv_cfg, norm_cfg,
                 act_cfg):
        if scale > 1:
            query_downsample = nn.MaxPool2d(kernel_size=scale)
        else:
            query_downsample = None
        super().__init__(
            key_in_channels=in_channels,
            query_in_channels=in_channels,
            channels=channels,
            out_channels=in_channels,
            share_key_query=False,
            query_downsample=query_downsample,
            key_downsample=None,
            key_query_num_convs=2,
            key_query_norm=True,
            value_out_num_convs=1,
            value_out_norm=True,
            matmul_norm=True,
            with_out=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.bottleneck = ConvModule(
            in_channels * 2,
            in_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, query_feats, key_feats):
        """Forward function."""
        context = super().forward(query_feats, key_feats)
        output = self.bottleneck(torch.cat([context, query_feats], dim=1))
        if self.query_downsample is not None:
            output = resize(query_feats)

        return output


@MODELS.register_module()
class OCRHead(BaseCascadeDecodeHead):
    """Object-Contextual Representations for Semantic Segmentation.

    This head is the implementation of `OCRNet
    <https://arxiv.org/abs/1909.11065>`_.

    Args:
        ocr_channels (int): The intermediate channels of OCR block.
        scale (int): The scale of probability map in SpatialGatherModule.
            Default: 1.
    """

    def __init__(self, ocr_channels, scale=1, **kwargs):
        super().__init__(**kwargs)
        self.ocr_channels = ocr_channels
        self.scale = scale
        self.object_context_block = ObjectAttentionBlock(
            self.channels,
            self.ocr_channels,
            self.scale,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.spatial_gather_module = SpatialGatherModule(self.scale)

        self.bottleneck = ConvModule(
            self.in_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs, prev_output):
        """Forward function."""
        x = self._transform_inputs(inputs)
        feats = self.bottleneck(x)
        context = self.spatial_gather_module(feats, prev_output)
        object_context = self.object_context_block(feats, context)
        output = self.cls_seg(object_context)

        return output
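Note: SpatialGatherModule is a soft class-wise pooling: each of the
num_classes context vectors is a probability-weighted average of pixel
features. A standalone sketch of that gather step (shapes illustrative):

import torch
import torch.nn.functional as F

B, C, K, H, W = 2, 64, 19, 32, 32
feats = torch.randn(B, C, H, W)
probs = torch.randn(B, K, H, W)                   # coarse per-class logits

probs = F.softmax(probs.view(B, K, -1), dim=2)    # (B, K, HW)
feats = feats.view(B, C, -1).permute(0, 2, 1)     # (B, HW, C)
context = probs @ feats                           # (B, K, C)
context = context.permute(0, 2, 1).unsqueeze(3)   # (B, C, K, 1)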
183
finetune/mmseg/models/decode_heads/pid_head.py
Normal file
@@ -0,0 +1,183 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from mmengine.model import BaseModule
from torch import Tensor

from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.losses import accuracy
from mmseg.models.utils import resize
from mmseg.registry import MODELS
from mmseg.utils import OptConfigType, SampleList


class BasePIDHead(BaseModule):
    """Base class for PID head.

    Args:
        in_channels (int): Number of input channels.
        channels (int): Number of output channels.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
        init_cfg (dict or list[dict], optional): Init config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg)
        self.conv = ConvModule(
            in_channels,
            channels,
            kernel_size=3,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            order=('norm', 'act', 'conv'))
        _, self.norm = build_norm_layer(norm_cfg, num_features=channels)
        self.act = build_activation_layer(act_cfg)

    def forward(self, x: Tensor, cls_seg: Optional[nn.Module]) -> Tensor:
        """Forward function.

        Args:
            x (Tensor): Input tensor.
            cls_seg (nn.Module, optional): The classification head.

        Returns:
            Tensor: Output tensor.
        """
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        if cls_seg is not None:
            x = cls_seg(x)
        return x


@MODELS.register_module()
class PIDHead(BaseDecodeHead):
    """Decode head for PIDNet.

    Args:
        in_channels (int): Number of input channels.
        channels (int): Number of output channels.
        num_classes (int): Number of classes.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
    """

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 num_classes: int,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 **kwargs):
        super().__init__(
            in_channels,
            channels,
            num_classes=num_classes,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **kwargs)
        self.i_head = BasePIDHead(in_channels, channels, norm_cfg, act_cfg)
        self.p_head = BasePIDHead(in_channels // 2, channels, norm_cfg,
                                  act_cfg)
        self.d_head = BasePIDHead(
            in_channels // 2,
            in_channels // 4,
            norm_cfg,
        )
        self.p_cls_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
        self.d_cls_seg = nn.Conv2d(in_channels // 4, 1, kernel_size=1)

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(
            self,
            inputs: Union[Tensor,
                          Tuple[Tensor]]) -> Union[Tensor, Tuple[Tensor]]:
        """Forward function.

        Args:
            inputs (Tensor | tuple[Tensor]): Input tensor or tuple of
                Tensor. During training, the input is a tuple of three
                tensors, (p_feat, i_feat, d_feat), and the output is a tuple
                of three tensors, (p_seg_logit, i_seg_logit, d_seg_logit).
                At inference time, only the integral-branch head is used;
                the input is the integral feature map and the output is the
                segmentation logit.

        Returns:
            Tensor | tuple[Tensor]: Output tensor or tuple of tensors.
        """
        if self.training:
            x_p, x_i, x_d = inputs
            x_p = self.p_head(x_p, self.p_cls_seg)
            x_i = self.i_head(x_i, self.cls_seg)
            x_d = self.d_head(x_d, self.d_cls_seg)
            return x_p, x_i, x_d
        else:
            return self.i_head(inputs, self.cls_seg)

    def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tuple[Tensor]:
        gt_semantic_segs = [
            data_sample.gt_sem_seg.data for data_sample in batch_data_samples
        ]
        gt_edge_segs = [
            data_sample.gt_edge_map.data for data_sample in batch_data_samples
        ]
        gt_sem_segs = torch.stack(gt_semantic_segs, dim=0)
        gt_edge_segs = torch.stack(gt_edge_segs, dim=0)
        return gt_sem_segs, gt_edge_segs

    def loss_by_feat(self, seg_logits: Tuple[Tensor],
                     batch_data_samples: SampleList) -> dict:
        loss = dict()
        p_logit, i_logit, d_logit = seg_logits
        sem_label, bd_label = self._stack_batch_gt(batch_data_samples)
        p_logit = resize(
            input=p_logit,
            size=sem_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        i_logit = resize(
            input=i_logit,
            size=sem_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        d_logit = resize(
            input=d_logit,
            size=bd_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        sem_label = sem_label.squeeze(1)
        bd_label = bd_label.squeeze(1)
        loss['loss_sem_p'] = self.loss_decode[0](
            p_logit, sem_label, ignore_index=self.ignore_index)
        loss['loss_sem_i'] = self.loss_decode[1](i_logit, sem_label)
        loss['loss_bd'] = self.loss_decode[2](d_logit, bd_label)
        filler = torch.ones_like(sem_label) * self.ignore_index
        sem_bd_label = torch.where(
            torch.sigmoid(d_logit[:, 0, :, :]) > 0.8, sem_label, filler)
        loss['loss_sem_bd'] = self.loss_decode[3](i_logit, sem_bd_label)
        loss['acc_seg'] = accuracy(
            i_logit, sem_label, ignore_index=self.ignore_index)
        return loss
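Note: the ``loss_sem_bd`` term above supervises the integral branch only at
pixels where the boundary branch is confident (sigmoid > 0.8); everywhere
else the label is replaced by ignore_index. A small sketch of that label
gating (values illustrative):

import torch

ignore_index = 255
sem_label = torch.randint(0, 19, (1, 64, 64))
d_logit = torch.randn(1, 1, 64, 64)              # boundary-branch logits

filler = torch.full_like(sem_label, ignore_index)
sem_bd_label = torch.where(
    torch.sigmoid(d_logit[:, 0]) > 0.8, sem_label, filler)
# pixels away from confident boundaries become ignore_index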
367
finetune/mmseg/models/decode_heads/point_head.py
Normal file
@@ -0,0 +1,367 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py  # noqa

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

try:
    from mmcv.ops import point_sample
except ModuleNotFoundError:
    point_sample = None

from typing import List

from mmseg.registry import MODELS
from mmseg.utils import SampleList
from ..losses import accuracy
from ..utils import resize
from .cascade_decode_head import BaseCascadeDecodeHead


def calculate_uncertainty(seg_logits):
    """Estimate uncertainty based on seg logits.

    For each location of the prediction ``seg_logits`` we estimate
    uncertainty as the difference between the top first and top second
    predicted logits.

    Args:
        seg_logits (Tensor): Semantic segmentation logits,
            shape (batch_size, num_classes, height, width).

    Returns:
        scores (Tensor): Uncertainty scores with the most uncertain
            locations having the highest uncertainty score, shape
            (batch_size, 1, height, width).
    """
    top2_scores = torch.topk(seg_logits, k=2, dim=1)[0]
    return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)


@MODELS.register_module()
class PointHead(BaseCascadeDecodeHead):
    """A mask point head used in PointRend.

    This head is the implementation of `PointRend: Image Segmentation as
    Rendering <https://arxiv.org/abs/1912.08193>`_.

    ``PointHead`` uses a shared multi-layer perceptron (equivalent to
    nn.Conv1d) to predict the logits of input points. The fine-grained
    feature and coarse feature are concatenated together for prediction.

    Args:
        num_fcs (int): Number of fc layers in the head. Default: 3.
        in_channels (int): Number of input channels. Default: 256.
        fc_channels (int): Number of fc channels. Default: 256.
        num_classes (int): Number of classes for logits. Default: 80.
        class_agnostic (bool): Whether to use class-agnostic classification.
            If so, the output channels of logits will be 1. Default: False.
        coarse_pred_each_layer (bool): Whether to concatenate coarse feature
            with the output of each fc layer. Default: True.
        conv_cfg (dict|None): Dictionary to construct and config conv layer.
            Default: dict(type='Conv1d').
        norm_cfg (dict|None): Dictionary to construct and config norm layer.
            Default: None.
        loss_point (dict): Dictionary to construct and config loss layer of
            point head. Default: dict(type='CrossEntropyLoss', use_mask=True,
            loss_weight=1.0).
    """

    def __init__(self,
                 num_fcs=3,
                 coarse_pred_each_layer=True,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU', inplace=False),
                 **kwargs):
        super().__init__(
            input_transform='multiple_select',
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            init_cfg=dict(
                type='Normal', std=0.01, override=dict(name='fc_seg')),
            **kwargs)
        if point_sample is None:
            raise RuntimeError('Please install mmcv-full for '
                               'point_sample ops')

        self.num_fcs = num_fcs
        self.coarse_pred_each_layer = coarse_pred_each_layer

        fc_in_channels = sum(self.in_channels) + self.num_classes
        fc_channels = self.channels
        self.fcs = nn.ModuleList()
        for k in range(num_fcs):
            fc = ConvModule(
                fc_in_channels,
                fc_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
            self.fcs.append(fc)
            fc_in_channels = fc_channels
            fc_in_channels += self.num_classes if self.coarse_pred_each_layer \
                else 0
        self.fc_seg = nn.Conv1d(
            fc_in_channels,
            self.num_classes,
            kernel_size=1,
            stride=1,
            padding=0)
        if self.dropout_ratio > 0:
            self.dropout = nn.Dropout(self.dropout_ratio)
        delattr(self, 'conv_seg')

    def cls_seg(self, feat):
        """Classify each pixel with fc."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.fc_seg(feat)
        return output

    def forward(self, fine_grained_point_feats, coarse_point_feats):
        x = torch.cat([fine_grained_point_feats, coarse_point_feats], dim=1)
        for fc in self.fcs:
            x = fc(x)
            if self.coarse_pred_each_layer:
                x = torch.cat((x, coarse_point_feats), dim=1)
        return self.cls_seg(x)

    def _get_fine_grained_point_feats(self, x, points):
        """Sample from fine-grained features.

        Args:
            x (list[Tensor]): Feature pyramid from neck or backbone.
            points (Tensor): Point coordinates, shape (batch_size,
                num_points, 2).

        Returns:
            fine_grained_feats (Tensor): Sampled fine-grained feature,
                shape (batch_size, sum(channels of x), num_points).
        """
        fine_grained_feats_list = [
            point_sample(_, points, align_corners=self.align_corners)
            for _ in x
        ]
        if len(fine_grained_feats_list) > 1:
            fine_grained_feats = torch.cat(fine_grained_feats_list, dim=1)
        else:
            fine_grained_feats = fine_grained_feats_list[0]

        return fine_grained_feats

    def _get_coarse_point_feats(self, prev_output, points):
        """Sample from coarse features.

        Args:
            prev_output (Tensor): Prediction of previous decode head.
            points (Tensor): Point coordinates, shape (batch_size,
                num_points, 2).

        Returns:
            coarse_feats (Tensor): Sampled coarse feature, shape (batch_size,
                num_classes, num_points).
        """
        coarse_feats = point_sample(
            prev_output, points, align_corners=self.align_corners)

        return coarse_feats

    def loss(self, inputs, prev_output, batch_data_samples: SampleList,
             train_cfg, **kwargs):
        """Forward function for training.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            batch_data_samples (list[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `img_metas` or `gt_semantic_seg`.
            train_cfg (dict): The training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        x = self._transform_inputs(inputs)
        with torch.no_grad():
            points = self.get_points_train(
                prev_output, calculate_uncertainty, cfg=train_cfg)
        fine_grained_point_feats = self._get_fine_grained_point_feats(
            x, points)
        coarse_point_feats = self._get_coarse_point_feats(prev_output, points)
        point_logits = self.forward(fine_grained_point_feats,
                                    coarse_point_feats)

        losses = self.loss_by_feat(point_logits, points, batch_data_samples)

        return losses

    def predict(self, inputs, prev_output, batch_img_metas: List[dict],
                test_cfg, **kwargs):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            batch_img_metas (list[dict]): List of image info dict where each
                dict has: 'img_shape', 'scale_factor', 'flip', and may also
                contain 'filename', 'ori_shape', 'pad_shape', and
                'img_norm_cfg'. For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            test_cfg (dict): The testing config.

        Returns:
            Tensor: Output segmentation map.
        """
        x = self._transform_inputs(inputs)
        refined_seg_logits = prev_output.clone()
        for _ in range(test_cfg.subdivision_steps):
            refined_seg_logits = resize(
                refined_seg_logits,
                scale_factor=test_cfg.scale_factor,
                mode='bilinear',
                align_corners=self.align_corners)
            batch_size, channels, height, width = refined_seg_logits.shape
            point_indices, points = self.get_points_test(
                refined_seg_logits, calculate_uncertainty, cfg=test_cfg)
            fine_grained_point_feats = self._get_fine_grained_point_feats(
                x, points)
            coarse_point_feats = self._get_coarse_point_feats(
                prev_output, points)
            point_logits = self.forward(fine_grained_point_feats,
                                        coarse_point_feats)

            point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1)
            refined_seg_logits = refined_seg_logits.reshape(
                batch_size, channels, height * width)
            refined_seg_logits = refined_seg_logits.scatter_(
                2, point_indices, point_logits)
            refined_seg_logits = refined_seg_logits.view(
                batch_size, channels, height, width)

        return self.predict_by_feat(refined_seg_logits, batch_img_metas,
                                    **kwargs)

    def loss_by_feat(self, point_logits, points, batch_data_samples, **kwargs):
        """Compute segmentation loss."""
        gt_semantic_seg = self._stack_batch_gt(batch_data_samples)
        point_label = point_sample(
            gt_semantic_seg.float(),
            points,
            mode='nearest',
            align_corners=self.align_corners)
        point_label = point_label.squeeze(1).long()

        loss = dict()
        if not isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = [self.loss_decode]
        else:
            losses_decode = self.loss_decode
        for loss_module in losses_decode:
            loss['point' + loss_module.loss_name] = loss_module(
                point_logits, point_label, ignore_index=self.ignore_index)

        loss['acc_point'] = accuracy(
            point_logits, point_label, ignore_index=self.ignore_index)
        return loss

    def get_points_train(self, seg_logits, uncertainty_func, cfg):
        """Sample points for training.

        Sample points in [0, 1] x [0, 1] coordinate space based on their
        uncertainty. The uncertainties are calculated for each point using
        the 'uncertainty_func' function that takes the point's logit
        prediction as input.

        Args:
            seg_logits (Tensor): Semantic segmentation logits, shape (
                batch_size, num_classes, height, width).
            uncertainty_func (func): uncertainty calculation function.
            cfg (dict): Training config of point head.

        Returns:
            point_coords (Tensor): A tensor of shape (batch_size, num_points,
                2) that contains the coordinates of ``num_points`` sampled
                points.
        """
        num_points = cfg.num_points
        oversample_ratio = cfg.oversample_ratio
        importance_sample_ratio = cfg.importance_sample_ratio
        assert oversample_ratio >= 1
        assert 0 <= importance_sample_ratio <= 1
        batch_size = seg_logits.shape[0]
        num_sampled = int(num_points * oversample_ratio)
        point_coords = torch.rand(
            batch_size, num_sampled, 2, device=seg_logits.device)
        point_logits = point_sample(seg_logits, point_coords)
        # It is crucial to calculate uncertainty based on the sampled
        # prediction value for the points. Calculating uncertainties of the
        # coarse predictions first and sampling them for points leads to
        # incorrect results. To illustrate this: assume uncertainty func(
        # logits)=-abs(logits), a sampled point between two coarse
        # predictions with -1 and 1 logits has 0 logits, and therefore 0
        # uncertainty value. However, if we calculate uncertainties for the
        # coarse predictions first, both will have -1 uncertainty,
        # and the sampled point will get -1 uncertainty.
        point_uncertainties = uncertainty_func(point_logits)
        num_uncertain_points = int(importance_sample_ratio * num_points)
        num_random_points = num_points - num_uncertain_points
        idx = torch.topk(
            point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
        shift = num_sampled * torch.arange(
            batch_size, dtype=torch.long, device=seg_logits.device)
        idx += shift[:, None]
        point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
            batch_size, num_uncertain_points, 2)
        if num_random_points > 0:
            rand_point_coords = torch.rand(
                batch_size, num_random_points, 2, device=seg_logits.device)
            point_coords = torch.cat((point_coords, rand_point_coords), dim=1)
        return point_coords

    def get_points_test(self, seg_logits, uncertainty_func, cfg):
        """Sample points for testing.

        Find ``num_points`` most uncertain points from ``uncertainty_map``.

        Args:
            seg_logits (Tensor): A tensor of shape (batch_size, num_classes,
                height, width) for class-specific or class-agnostic
                prediction.
            uncertainty_func (func): uncertainty calculation function.
            cfg (dict): Testing config of point head.

        Returns:
            point_indices (Tensor): A tensor of shape (batch_size, num_points)
                that contains indices from [0, height x width) of the most
                uncertain points.
            point_coords (Tensor): A tensor of shape (batch_size, num_points,
                2) that contains [0, 1] x [0, 1] normalized coordinates of
                the most uncertain points from the ``height x width`` grid.
        """
        num_points = cfg.subdivision_num_points
        uncertainty_map = uncertainty_func(seg_logits)
        batch_size, _, height, width = uncertainty_map.shape
        h_step = 1.0 / height
        w_step = 1.0 / width

        uncertainty_map = uncertainty_map.view(batch_size, height * width)
        num_points = min(height * width, num_points)
        point_indices = uncertainty_map.topk(num_points, dim=1)[1]
        point_coords = torch.zeros(
            batch_size,
            num_points,
            2,
            dtype=torch.float,
            device=seg_logits.device)
        point_coords[:, :, 0] = w_step / 2.0 + (point_indices %
                                                width).float() * w_step
        point_coords[:, :, 1] = h_step / 2.0 + (point_indices //
                                                width).float() * h_step
        return point_indices, point_coords
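Note: the training-time point selection above is the PointRend strategy of
"oversample, rank by uncertainty, top up with random points". A compact
standalone sketch of the same idea (shapes and scores illustrative):

import torch

batch_size, num_points, oversample_ratio, importance_ratio = 2, 8, 3, 0.75
num_sampled = int(num_points * oversample_ratio)
coords = torch.rand(batch_size, num_sampled, 2)       # candidate points
uncertainty = torch.rand(batch_size, num_sampled)     # stand-in scores

num_uncertain = int(importance_ratio * num_points)
idx = uncertainty.topk(num_uncertain, dim=1)[1]       # most uncertain first
picked = torch.gather(coords, 1, idx.unsqueeze(-1).expand(-1, -1, 2))
rand = torch.rand(batch_size, num_points - num_uncertain, 2)
points = torch.cat([picked, rand], dim=1)             # final (B, P, 2)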
197
finetune/mmseg/models/decode_heads/psa_head.py
Normal file
@@ -0,0 +1,197 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead

try:
    from mmcv.ops import PSAMask
except ModuleNotFoundError:
    PSAMask = None


@MODELS.register_module()
class PSAHead(BaseDecodeHead):
    """Point-wise Spatial Attention Network for Scene Parsing.

    This head is the implementation of `PSANet
    <https://hszhao.github.io/papers/eccv18_psanet.pdf>`_.

    Args:
        mask_size (tuple[int]): The PSA mask size. It usually equals the
            input size.
        psa_type (str): The type of psa module. Options are 'collect',
            'distribute', 'bi-direction'. Default: 'bi-direction'.
        compact (bool): Whether to use a compact map for 'collect' mode.
            Default: True.
        shrink_factor (int): The downsample factor of the psa mask.
            Default: 2.
        normalization_factor (float): The normalization factor of attention.
        psa_softmax (bool): Whether to use softmax for attention.
    """

    def __init__(self,
                 mask_size,
                 psa_type='bi-direction',
                 compact=False,
                 shrink_factor=2,
                 normalization_factor=1.0,
                 psa_softmax=True,
                 **kwargs):
        if PSAMask is None:
            raise RuntimeError('Please install mmcv-full for PSAMask ops')
        super().__init__(**kwargs)
        assert psa_type in ['collect', 'distribute', 'bi-direction']
        self.psa_type = psa_type
        self.compact = compact
        self.shrink_factor = shrink_factor
        self.mask_size = mask_size
        mask_h, mask_w = mask_size
        self.psa_softmax = psa_softmax
        if normalization_factor is None:
            normalization_factor = mask_h * mask_w
        self.normalization_factor = normalization_factor

        self.reduce = ConvModule(
            self.in_channels,
            self.channels,
            kernel_size=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.attention = nn.Sequential(
            ConvModule(
                self.channels,
                self.channels,
                kernel_size=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            nn.Conv2d(
                self.channels, mask_h * mask_w, kernel_size=1, bias=False))
        if psa_type == 'bi-direction':
            self.reduce_p = ConvModule(
                self.in_channels,
                self.channels,
                kernel_size=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            self.attention_p = nn.Sequential(
                ConvModule(
                    self.channels,
                    self.channels,
                    kernel_size=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg),
                nn.Conv2d(
                    self.channels, mask_h * mask_w, kernel_size=1, bias=False))
            self.psamask_collect = PSAMask('collect', mask_size)
            self.psamask_distribute = PSAMask('distribute', mask_size)
        else:
            self.psamask = PSAMask(psa_type, mask_size)
        self.proj = ConvModule(
            self.channels * (2 if psa_type == 'bi-direction' else 1),
            self.in_channels,
            kernel_size=1,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.bottleneck = ConvModule(
            self.in_channels * 2,
            self.channels,
            kernel_size=3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        identity = x
        align_corners = self.align_corners
        if self.psa_type in ['collect', 'distribute']:
            out = self.reduce(x)
            n, c, h, w = out.size()
            if self.shrink_factor != 1:
                if h % self.shrink_factor and w % self.shrink_factor:
                    h = (h - 1) // self.shrink_factor + 1
                    w = (w - 1) // self.shrink_factor + 1
                    align_corners = True
                else:
                    h = h // self.shrink_factor
                    w = w // self.shrink_factor
                    align_corners = False
                out = resize(
                    out,
                    size=(h, w),
                    mode='bilinear',
                    align_corners=align_corners)
            y = self.attention(out)
            if self.compact:
                if self.psa_type == 'collect':
                    y = y.view(n, h * w,
                               h * w).transpose(1, 2).view(n, h * w, h, w)
            else:
                y = self.psamask(y)
            if self.psa_softmax:
                y = F.softmax(y, dim=1)
            out = torch.bmm(
                out.view(n, c, h * w), y.view(n, h * w, h * w)).view(
                    n, c, h, w) * (1.0 / self.normalization_factor)
        else:
            x_col = self.reduce(x)
            x_dis = self.reduce_p(x)
            n, c, h, w = x_col.size()
            if self.shrink_factor != 1:
                if h % self.shrink_factor and w % self.shrink_factor:
                    h = (h - 1) // self.shrink_factor + 1
                    w = (w - 1) // self.shrink_factor + 1
                    align_corners = True
                else:
                    h = h // self.shrink_factor
                    w = w // self.shrink_factor
                    align_corners = False
                x_col = resize(
                    x_col,
                    size=(h, w),
                    mode='bilinear',
                    align_corners=align_corners)
                x_dis = resize(
                    x_dis,
                    size=(h, w),
                    mode='bilinear',
                    align_corners=align_corners)
            y_col = self.attention(x_col)
            y_dis = self.attention_p(x_dis)
            if self.compact:
                y_dis = y_dis.view(n, h * w,
                                   h * w).transpose(1, 2).view(n, h * w, h, w)
            else:
                y_col = self.psamask_collect(y_col)
                y_dis = self.psamask_distribute(y_dis)
            if self.psa_softmax:
                y_col = F.softmax(y_col, dim=1)
                y_dis = F.softmax(y_dis, dim=1)
            x_col = torch.bmm(
                x_col.view(n, c, h * w), y_col.view(n, h * w, h * w)).view(
                    n, c, h, w) * (1.0 / self.normalization_factor)
            x_dis = torch.bmm(
                x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w)).view(
                    n, c, h, w) * (1.0 / self.normalization_factor)
            out = torch.cat([x_col, x_dis], 1)
        out = self.proj(out)
        out = resize(
            out,
            size=identity.shape[2:],
            mode='bilinear',
            align_corners=align_corners)
        out = self.bottleneck(torch.cat((identity, out), dim=1))
        out = self.cls_seg(out)
        return out
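Note: a minimal config sketch for this head (standard mmseg semantics
assumed; ``mask_size`` must match the feature resolution the PSA mask sees,
and the values here are illustrative, not taken from this commit):

psa_decode_head = dict(
    type='PSAHead',
    in_channels=2048,
    in_index=3,
    channels=512,
    mask_size=(97, 97),        # feature-map size seen by the PSA mask
    psa_type='bi-direction',
    shrink_factor=2,
    psa_softmax=True,
    num_classes=19,
    norm_cfg=dict(type='SyncBN', requires_grad=True),
    align_corners=False)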
117
finetune/mmseg/models/decode_heads/psp_head.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead


class PPM(nn.ModuleList):
    """Pooling Pyramid Module used in PSPNet.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
        align_corners (bool): align_corners argument of F.interpolate.
    """

    def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
                 act_cfg, align_corners, **kwargs):
        super().__init__()
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        for pool_scale in pool_scales:
            self.append(
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(pool_scale),
                    ConvModule(
                        self.in_channels,
                        self.channels,
                        1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg,
                        **kwargs)))

    def forward(self, x):
        """Forward function."""
        ppm_outs = []
        for ppm in self:
            ppm_out = ppm(x)
            upsampled_ppm_out = resize(
                ppm_out,
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            ppm_outs.append(upsampled_ppm_out)
        return ppm_outs


@MODELS.register_module()
class PSPHead(BaseDecodeHead):
    """Pyramid Scene Parsing Network.

    This head is the implementation of
    `PSPNet <https://arxiv.org/abs/1612.01105>`_.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module. Default: (1, 2, 3, 6).
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super().__init__(**kwargs)
        assert isinstance(pool_scales, (list, tuple))
        self.pool_scales = pool_scales
        self.psp_modules = PPM(
            self.pool_scales,
            self.in_channels,
            self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)
        self.bottleneck = ConvModule(
            self.in_channels + len(pool_scales) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def _forward_feature(self, inputs):
        """Forward function for feature maps before classifying each pixel
        with ``self.cls_seg`` fc.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W) which is feature map for last layer of decoder head.
        """
        x = self._transform_inputs(inputs)
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x))
        psp_outs = torch.cat(psp_outs, dim=1)
        feats = self.bottleneck(psp_outs)
        return feats

    def forward(self, inputs):
        """Forward function."""
        output = self._forward_feature(inputs)
        output = self.cls_seg(output)
        return output
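A quick way to sanity-check the pooling-and-concat dataflow of `PPM` above is to run it standalone on a dummy feature map. This is a minimal sketch, not part of the commit; the tensor sizes are illustrative assumptions and the import path assumes this repo's vendored mmseg layout.

# Sketch: every pooled branch is resized back to the input resolution,
# so PSPHead can concatenate [x] + branch outputs along the channel dim.
import torch

from mmseg.models.decode_heads.psp_head import PPM  # assumed path

ppm = PPM(
    pool_scales=(1, 2, 3, 6),
    in_channels=512,
    channels=128,
    conv_cfg=None,
    norm_cfg=None,
    act_cfg=dict(type='ReLU'),
    align_corners=False)
x = torch.randn(2, 512, 32, 32)                    # (batch, C_in, H, W)
outs = ppm(x)                                      # one tensor per scale
assert all(o.shape == (2, 128, 32, 32) for o in outs)
# PSPHead then concatenates [x] + outs -> 512 + 4 * 128 = 1024 channels
# before its 3x3 bottleneck conv.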
736
finetune/mmseg/models/decode_heads/san_head.py
Normal file
@@ -0,0 +1,736 @@
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer
from mmcv.cnn.bricks.transformer import BaseTransformerLayer
from mmcv.ops import point_sample
from mmengine.dist import all_reduce
from mmengine.model.weight_init import (caffe2_xavier_init, normal_init,
                                        trunc_normal_)
from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
from mmengine.structures import InstanceData
from torch import Tensor
from torch.nn import functional as F

from mmseg.models.backbones.vit import TransformerEncoderLayer
from mmseg.registry import MODELS
from mmseg.utils import (ConfigType, MatchMasks, SampleList,
                         seg_data_to_instance_data)
from ..utils import (MLP, LayerNorm2d, PatchEmbed, cross_attn_layer,
                     get_uncertain_point_coords_with_randomness, resize)
from .decode_head import BaseDecodeHead


class MLPMaskDecoder(nn.Module):
    """Module for decoding query and visual features with MLP layers to
    generate the attention biases and the mask proposals."""

    def __init__(
        self,
        *,
        in_channels: int,
        total_heads: int = 1,
        total_layers: int = 1,
        embed_channels: int = 256,
        mlp_channels: int = 256,
        mlp_num_layers: int = 3,
        rescale_attn_bias: bool = False,
    ):
        super().__init__()
        self.total_heads = total_heads
        self.total_layers = total_layers

        dense_affine_func = partial(nn.Conv2d, kernel_size=1)
        # Query Branch
        self.query_mlp = MLP(in_channels, mlp_channels, embed_channels,
                             mlp_num_layers)
        # Pixel Branch
        self.pix_mlp = MLP(
            in_channels,
            mlp_channels,
            embed_channels,
            mlp_num_layers,
            affine_func=dense_affine_func,
        )
        # Attention Bias Branch
        self.attn_mlp = MLP(
            in_channels,
            mlp_channels,
            embed_channels * self.total_heads * self.total_layers,
            mlp_num_layers,
            affine_func=dense_affine_func,
        )
        if rescale_attn_bias:
            self.bias_scaling = nn.Linear(1, 1)
        else:
            self.bias_scaling = nn.Identity()

    def forward(self, query: torch.Tensor,
                x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward function.

        Args:
            query (Tensor): Query Tokens [B,N,C].
            x (Tensor): Visual features [B,C,H,W].

        Returns:
            mask_preds (Tensor): Mask proposals.
            attn_bias (List[Tensor]): List of attention biases.
        """
        query = self.query_mlp(query)
        pix = self.pix_mlp(x)
        b, c, h, w = pix.shape
        # predict mask
        mask_preds = torch.einsum('bqc,bchw->bqhw', query, pix)
        # generate attn bias
        attn = self.attn_mlp(x)
        attn = attn.reshape(b, self.total_layers, self.total_heads, c, h, w)
        attn_bias = torch.einsum('bqc,blnchw->blnqhw', query, attn)
        attn_bias = self.bias_scaling(attn_bias[..., None]).squeeze(-1)
        attn_bias = attn_bias.chunk(self.total_layers, dim=1)
        attn_bias = [attn.squeeze(1) for attn in attn_bias]
        return mask_preds, attn_bias
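The two einsum contractions above carry all of the decoder's structure, so a shape walk-through is worth spelling out. The sizes below are dummy assumptions chosen only to make the shapes concrete.

# Queries dot pixel embeddings give per-query mask logits; the same
# queries dot per-layer/per-head embeddings give the attention biases.
import torch

b, q, c, h, w = 2, 100, 256, 40, 40    # batch, queries, embed, feat H/W
layers, heads = 1, 12

query = torch.randn(b, q, c)
pix = torch.randn(b, c, h, w)
mask_preds = torch.einsum('bqc,bchw->bqhw', query, pix)
assert mask_preds.shape == (b, q, h, w)

attn = torch.randn(b, layers, heads, c, h, w)
attn_bias = torch.einsum('bqc,blnchw->blnqhw', query, attn)
assert attn_bias.shape == (b, layers, heads, q, h, w)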


class SideAdapterNetwork(nn.Module):
    """Side Adapter Network for predicting mask proposals and attention bias.

    Args:
        in_channels (int): Number of input channels. Default: 3.
        clip_channels (int): Number of channels of visual features.
            Default: 768.
        embed_dims (int): Embedding dimension. Default: 240.
        patch_size (int): The patch size. Default: 16.
        patch_bias (bool): Whether to use bias in patch embedding.
            Default: True.
        num_queries (int): Number of queries for mask proposals.
            Default: 100.
        fusion_index (List[int]): The layer numbers of the encode
            transformer to fuse with the CLIP feature.
            Default: [0, 1, 2, 3].
        cfg_encoder (ConfigType): Configs for the encode layers.
        cfg_decoder (ConfigType): Configs for the decode layers.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
    """

    def __init__(
        self,
        in_channels: int = 3,
        clip_channels: int = 768,
        embed_dims: int = 240,
        patch_size: int = 16,
        patch_bias: bool = True,
        num_queries: int = 100,
        fusion_index: list = [0, 1, 2, 3],
        cfg_encoder: ConfigType = ...,
        cfg_decoder: ConfigType = ...,
        norm_cfg: dict = dict(type='LN'),
    ):
        super().__init__()

        self.patch_embed = PatchEmbed(
            in_channels=in_channels,
            embed_dims=embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size,
            stride=patch_size,
            padding=0,
            input_size=(640, 640),
            bias=patch_bias,
            norm_cfg=None,
            init_cfg=None,
        )
        ori_h, ori_w = self.patch_embed.init_out_size
        num_patches = ori_h * ori_w
        self.pos_embed = nn.Parameter(
            torch.randn(1, num_patches, embed_dims) * .02)
        self.query_pos_embed = nn.Parameter(
            torch.zeros(1, num_queries, embed_dims))
        self.query_embed = nn.Parameter(
            torch.zeros(1, num_queries, embed_dims))
        encode_layers = []
        for i in range(cfg_encoder.num_encode_layer):
            encode_layers.append(
                TransformerEncoderLayer(
                    embed_dims=embed_dims,
                    num_heads=cfg_encoder.num_heads,
                    feedforward_channels=cfg_encoder.mlp_ratio * embed_dims,
                    norm_cfg=norm_cfg))
        self.encode_layers = nn.ModuleList(encode_layers)
        conv_clips = []
        for i in range(len(fusion_index)):
            conv_clips.append(
                nn.Sequential(
                    LayerNorm2d(clip_channels),
                    ConvModule(
                        clip_channels,
                        embed_dims,
                        kernel_size=1,
                        norm_cfg=None,
                        act_cfg=None)))
        self.conv_clips = nn.ModuleList(conv_clips)
        self.fusion_index = fusion_index
        self.mask_decoder = MLPMaskDecoder(
            in_channels=embed_dims,
            total_heads=cfg_decoder.num_heads,
            total_layers=cfg_decoder.num_layers,
            embed_channels=cfg_decoder.embed_channels,
            mlp_channels=cfg_decoder.mlp_channels,
            mlp_num_layers=cfg_decoder.num_mlp,
            rescale_attn_bias=cfg_decoder.rescale)

    def init_weights(self):
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.query_embed, std=0.02)
        nn.init.normal_(self.query_pos_embed, std=0.02)
        for i in range(len(self.conv_clips)):
            caffe2_xavier_init(self.conv_clips[i][1].conv)

    def fuse_clip(self, fused_index: int, x: torch.Tensor,
                  clip_feature: torch.Tensor, hwshape: Tuple[int, int],
                  L: int):
        """Fuse CLIP feature and visual tokens."""
        fused_clip = (resize(
            self.conv_clips[fused_index](clip_feature.contiguous()),
            size=hwshape,
            mode='bilinear',
            align_corners=False)).permute(0, 2, 3, 1).reshape(
                x[:, -L:, ...].shape)
        x = torch.cat([x[:, :-L, ...], x[:, -L:, ...] + fused_clip], dim=1)
        return x

    def encode_feature(self, image: torch.Tensor,
                       clip_features: List[torch.Tensor],
                       deep_supervision_idxs: List[int]) -> List[List]:
        """Encode images by a lightweight vision transformer."""
        assert len(self.fusion_index) == len(clip_features)
        x, hwshape = self.patch_embed(image)
        ori_h, ori_w = self.patch_embed.init_out_size
        pos_embed = self.pos_embed
        if self.pos_embed.shape[1] != x.shape[1]:
            # resize the position embedding
            pos_embed = (
                resize(
                    self.pos_embed.reshape(1, ori_h, ori_w,
                                           -1).permute(0, 3, 1, 2),
                    size=hwshape,
                    mode='bicubic',
                    align_corners=False,
                ).flatten(2).permute(0, 2, 1))
        pos_embed = torch.cat([
            self.query_pos_embed.expand(pos_embed.shape[0], -1, -1), pos_embed
        ],
                              dim=1)
        x = torch.cat([self.query_embed.expand(x.shape[0], -1, -1), x], dim=1)
        x = x + pos_embed
        L = hwshape[0] * hwshape[1]
        fused_index = 0
        if self.fusion_index[fused_index] == 0:
            x = self.fuse_clip(fused_index, x, clip_features[0][0], hwshape, L)
            fused_index += 1
        outs = []
        for index, block in enumerate(self.encode_layers, start=1):
            x = block(x)
            if index < len(self.fusion_index
                           ) and index == self.fusion_index[fused_index]:
                x = self.fuse_clip(fused_index, x,
                                   clip_features[fused_index][0], hwshape, L)
                fused_index += 1
            x_query = x[:, :-L, ...]
            x_feat = x[:, -L:, ...].permute(0, 2, 1)\
                .reshape(x.shape[0], x.shape[-1], hwshape[0], hwshape[1])

            if index in deep_supervision_idxs or index == len(
                    self.encode_layers):
                outs.append({'query': x_query, 'x': x_feat})

            if index < len(self.encode_layers):
                x = x + pos_embed
        return outs

    def decode_feature(self, features):
        mask_embeds = []
        attn_biases = []
        for feature in features:
            mask_embed, attn_bias = self.mask_decoder(**feature)
            mask_embeds.append(mask_embed)
            attn_biases.append(attn_bias)
        return mask_embeds, attn_biases

    def forward(
        self, image: torch.Tensor, clip_features: List[torch.Tensor],
        deep_supervision_idxs: List[int]
    ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]:
        """Forward function."""
        features = self.encode_feature(image, clip_features,
                                       deep_supervision_idxs)
        mask_embeds, attn_biases = self.decode_feature(features)
        return mask_embeds, attn_biases
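`encode_feature` above resizes the learned position embedding whenever the input grid differs from the grid seen at construction time. The following standalone sketch replays just that step with pure torch; all sizes are illustrative assumptions.

# A grid of learned embeddings is reshaped to 2-D, resized bicubically,
# and flattened back to a token sequence for the new feature grid.
import torch
import torch.nn.functional as F

embed_dims, ori_h, ori_w = 240, 40, 40
pos_embed = torch.randn(1, ori_h * ori_w, embed_dims)

new_h, new_w = 32, 48                       # grid of a new input size
pos_2d = pos_embed.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2)
pos_2d = F.interpolate(
    pos_2d, size=(new_h, new_w), mode='bicubic', align_corners=False)
pos_embed = pos_2d.flatten(2).permute(0, 2, 1)
assert pos_embed.shape == (1, new_h * new_w, embed_dims)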
class RecWithAttnbias(nn.Module):
    """Mask recognition module by applying the attention biases to the rest
    of the deeper CLIP layers.

    Args:
        sos_token_format (str): The format of sos token. It should be
            chosen from ["cls_token", "learnable_token", "pos_embedding"].
            Default: 'cls_token'.
        sos_token_num (int): Number of sos tokens. It should be equal to
            the number of queries. Default: 100.
        num_layers (int): Number of rest CLIP layers for mask recognition.
            Default: 3.
        cross_attn (bool): Whether to use cross attention to update the sos
            tokens. Default: False.
        embed_dims (int): The feature dimension of CLIP layers.
            Default: 768.
        num_heads (int): Parallel attention heads of CLIP layers.
            Default: 12.
        mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
            Default: 4.
        qkv_bias (bool): Whether to use bias in multihead-attention.
            Default: True.
        out_dims (int): Number of channels of the output mask proposals.
            It should be equal to the out_dims of text_encoder.
            Default: 512.
        final_norm (bool): Whether to use a norm layer for the sos tokens.
            Default: True.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        frozen_exclude (List): List of parameters that are not to be frozen.
    """

    def __init__(self,
                 sos_token_format: str = 'cls_token',
                 sos_token_num: int = 100,
                 num_layers: int = 3,
                 cross_attn: bool = False,
                 embed_dims: int = 768,
                 num_heads: int = 12,
                 mlp_ratio: int = 4,
                 num_fcs: int = 2,
                 qkv_bias: bool = True,
                 out_dims: int = 512,
                 final_norm: bool = True,
                 act_cfg: dict = dict(type='GELU'),
                 norm_cfg: dict = dict(type='LN'),
                 frozen_exclude: List = []):
        super().__init__()

        assert sos_token_format in [
            'cls_token', 'learnable_token', 'pos_embedding'
        ]
        self.sos_token_format = sos_token_format
        self.sos_token_num = sos_token_num
        self.frozen_exclude = frozen_exclude
        self.cross_attn = cross_attn
        self.num_layers = num_layers
        self.num_heads = num_heads
        if sos_token_format in ['learnable_token', 'pos_embedding']:
            # The sos tokens must match the CLIP embedding width, and are
            # excluded from freezing since they are newly learned.
            self.sos_token = nn.Parameter(
                torch.randn(sos_token_num, 1, embed_dims))
            self.frozen_exclude.append('sos_token')

        layers = []
        for i in range(num_layers):
            layers.append(
                BaseTransformerLayer(
                    attn_cfgs=dict(
                        type='MultiheadAttention',
                        embed_dims=embed_dims,
                        num_heads=num_heads,
                        batch_first=False,
                        bias=qkv_bias),
                    ffn_cfgs=dict(
                        type='FFN',
                        embed_dims=embed_dims,
                        feedforward_channels=mlp_ratio * embed_dims,
                        act_cfg=act_cfg),
                    operation_order=('norm', 'self_attn', 'norm', 'ffn')))
        self.layers = nn.ModuleList(layers)

        self.ln_post = build_norm_layer(norm_cfg, embed_dims)[1]
        self.proj = nn.Linear(embed_dims, out_dims, bias=False)

        self.final_norm = final_norm
        self._freeze()

    def init_weights(self, rec_state_dict):
        if hasattr(self, 'sos_token'):
            normal_init(self.sos_token, std=0.02)
        if rec_state_dict is not None:
            load_state_dict(self, rec_state_dict, strict=False, logger=None)
        else:
            super().init_weights()

    def _freeze(self):
        if 'all' in self.frozen_exclude:
            return
        for name, param in self.named_parameters():
            if not any([exclude in name for exclude in self.frozen_exclude]):
                param.requires_grad = False

    def _build_attn_biases(self, attn_biases, target_shape):
        formatted_attn_biases = []
        for attn_bias in attn_biases:
            # convert it to proper format: N*num_head,L,L
            # attn_bias: [N, num_head/1, num_sos,H,W]
            n, num_head, num_sos, h, w = attn_bias.shape
            # reshape and downsample
            attn_bias = F.adaptive_max_pool2d(
                attn_bias.reshape(n, num_head * num_sos, h, w),
                output_size=target_shape)
            attn_bias = attn_bias.reshape(n, num_head, num_sos, *target_shape)

            true_num_head = self.num_heads
            assert (num_head == 1 or num_head
                    == true_num_head), f'num_head={num_head} is not supported.'
            if num_head == 1:
                attn_bias = attn_bias.repeat(1, true_num_head, 1, 1, 1)
            attn_bias = attn_bias.reshape(n * true_num_head, num_sos, -1)
            L = attn_bias.shape[-1]
            if self.cross_attn:
                # [n*num_head, num_sos, L]
                formatted_attn_biases.append(attn_bias)
            else:
                # [n*num_head, num_sos+1+L, num_sos+1+L]
                new_attn_bias = attn_bias.new_zeros(num_sos + 1 + L,
                                                    num_sos + 1 + L)
                new_attn_bias[:, :num_sos] = -100
                new_attn_bias[torch.arange(num_sos), torch.arange(num_sos)] = 0
                new_attn_bias[:num_sos, num_sos] = -100
                new_attn_bias = (
                    new_attn_bias[None, ...].expand(n * true_num_head, -1,
                                                    -1).clone())
                new_attn_bias[..., :num_sos, -L:] = attn_bias
                formatted_attn_biases.append(new_attn_bias)

        if len(formatted_attn_biases) == 1:
            formatted_attn_biases = [
                formatted_attn_biases[0] for _ in range(self.num_layers)
            ]
        return formatted_attn_biases

    def forward(self, bias: List[Tensor], feature: List[Tensor]):
        """Forward function to recognize the category of masks.

        Args:
            bias (List[Tensor]): Attention biases for transformer layers.
            feature (List[Tensor]): Output of the image encoder,
                including cls_token and img_feature.
        """
        cls_token = feature[1].unsqueeze(0)
        img_feature = feature[0]
        b, c, h, w = img_feature.shape
        # construct clip shadow features
        x = torch.cat(
            [cls_token,
             img_feature.reshape(b, c, -1).permute(2, 0, 1)])

        # construct sos token
        if self.sos_token_format == 'cls_token':
            sos_token = cls_token.repeat(self.sos_token_num, 1, 1)
        elif self.sos_token_format == 'learnable_token':
            sos_token = self.sos_token.expand(-1, b, -1)
        elif self.sos_token_format == 'pos_embedding':
            sos_token = self.sos_token.expand(-1, b, -1) + cls_token

        # construct attn bias
        attn_biases = self._build_attn_biases(bias, target_shape=(h, w))

        if self.cross_attn:
            for i, block in enumerate(self.layers):
                sos_token = cross_attn_layer(
                    block,
                    sos_token,
                    x[1:, ],
                    attn_biases[i],
                )
                if i < len(self.layers) - 1:
                    x = block(x)
        else:
            x = torch.cat([sos_token, x], dim=0)
            for i, block in enumerate(self.layers):
                x = block(x, attn_masks=[attn_biases[i]])
            sos_token = x[:self.sos_token_num]

        sos_token = sos_token.permute(1, 0, 2)  # LND -> NLD
        sos_token = self.ln_post(sos_token)
        sos_token = self.proj(sos_token)
        if self.final_norm:
            sos_token = F.normalize(sos_token, dim=-1)
        return sos_token
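The square bias built in `_build_attn_biases` for the self-attention path is easiest to read on a toy example. Tiny sizes below are assumptions for illustration: rows are queries, columns are keys; no token may attend to the [SOS] block except each [SOS] token to itself, and the [SOS] rows receive the predicted spatial biases over the L patch positions.

import torch

num_sos, L = 3, 4
attn_bias = torch.randn(1, num_sos, L)              # n * num_head = 1
new_attn_bias = attn_bias.new_zeros(num_sos + 1 + L, num_sos + 1 + L)
new_attn_bias[:, :num_sos] = -100                   # block attention to SOS
new_attn_bias[torch.arange(num_sos), torch.arange(num_sos)] = 0
new_attn_bias[:num_sos, num_sos] = -100             # SOS must not see CLS
new_attn_bias = new_attn_bias[None, ...].expand(1, -1, -1).clone()
new_attn_bias[..., :num_sos, -L:] = attn_bias       # spatial bias, SOS rows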
@MODELS.register_module()
class SideAdapterCLIPHead(BaseDecodeHead):
    """Side Adapter Network (SAN) for open-vocabulary semantic segmentation
    with pre-trained vision-language model.

    This decode head is the implementation of `Side Adapter Network
    for Open-Vocabulary Semantic Segmentation
    <https://arxiv.org/abs/2302.12242>`_.
    Modified from https://github.com/MendelXu/SAN/blob/main/san/model/side_adapter/side_adapter.py  # noqa:E501
    Copyright (c) 2023 MendelXu.
    Licensed under the MIT License.

    Args:
        num_classes (int): The number of classes.
        san_cfg (ConfigType): Configs for the SideAdapterNetwork module.
        maskgen_cfg (ConfigType): Configs for the RecWithAttnbias module.
    """

    def __init__(self, num_classes: int, san_cfg: ConfigType,
                 maskgen_cfg: ConfigType, deep_supervision_idxs: List[int],
                 train_cfg: ConfigType, **kwargs):
        super().__init__(
            in_channels=san_cfg.in_channels,
            channels=san_cfg.embed_dims,
            num_classes=num_classes,
            **kwargs)
        assert san_cfg.num_queries == maskgen_cfg.sos_token_num, \
            'num_queries in san_cfg should be equal to sos_token_num ' \
            'in maskgen_cfg'
        del self.conv_seg
        self.side_adapter_network = SideAdapterNetwork(**san_cfg)
        self.rec_with_attnbias = RecWithAttnbias(**maskgen_cfg)
        self.deep_supervision_idxs = deep_supervision_idxs
        self.train_cfg = train_cfg
        if train_cfg:
            self.match_masks = MatchMasks(
                num_points=train_cfg.num_points,
                num_queries=san_cfg.num_queries,
                num_classes=num_classes,
                assigner=train_cfg.assigner)

    def init_weights(self):

        rec_state_dict = None
        if isinstance(self.init_cfg, dict) and \
                self.init_cfg.get('type') == 'Pretrained_Part':
            checkpoint = CheckpointLoader.load_checkpoint(
                self.init_cfg['checkpoint'], logger=None, map_location='cpu')

            rec_state_dict = checkpoint.copy()
            para_prefix = 'decode_head.rec_with_attnbias'
            prefix_len = len(para_prefix) + 1
            for k, v in checkpoint.items():
                rec_state_dict.pop(k)
                if para_prefix in k:
                    rec_state_dict[k[prefix_len:]] = v

        self.side_adapter_network.init_weights()
        self.rec_with_attnbias.init_weights(rec_state_dict)

    def forward(self, inputs: Tuple[Tensor],
                deep_supervision_idxs) -> Tuple[List]:
        """Forward function.

        Args:
            inputs (Tuple[Tensor]): A triplet including images,
                list of multi-level visual features from image encoder and
                class embeddings from text_encoder.

        Returns:
            mask_props (List[Tensor]): Mask proposals predicted by SAN.
            mask_logits (List[Tensor]): Class logits of mask proposals.
        """
        imgs, clip_feature, class_embeds = inputs
        # predict mask proposals and attention bias
        mask_props, attn_biases = self.side_adapter_network(
            imgs, clip_feature, deep_supervision_idxs)

        # mask recognition with attention bias
        mask_embeds = [
            self.rec_with_attnbias(att_bias, clip_feature[-1])
            for att_bias in attn_biases
        ]
        # Obtain class prediction of masks by comparing the similarity
        # between the image token and the text embedding of class names.
        mask_logits = [
            torch.einsum('bqc,nc->bqn', mask_embed, class_embeds)
            for mask_embed in mask_embeds
        ]
        return mask_props, mask_logits

    def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
                test_cfg: ConfigType) -> Tensor:
        """Forward function for prediction.

        Args:
            inputs (Tuple[Tensor]): Images, visual features from image encoder
                and class embedding from text encoder.
            batch_img_metas (List[dict]): List of image info where each dict
                may also contain: 'img_shape', 'scale_factor', 'flip',
                'img_path', 'ori_shape', and 'pad_shape'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
            test_cfg (dict): The testing config.

        Returns:
            Tensor: Output segmentation logits map.
        """
        mask_props, mask_logits = self.forward(inputs, [])

        return self.predict_by_feat([mask_props[-1], mask_logits[-1]],
                                    batch_img_metas)

    def predict_by_feat(self, seg_logits: List[Tensor],
                        batch_img_metas: List[dict]) -> Tensor:
        """1. Transform a batch of mask proposals to the input shape.
        2. Generate segmentation map with mask proposals and class logits.
        """
        mask_pred = seg_logits[0]
        cls_score = seg_logits[1]
        if isinstance(batch_img_metas[0]['img_shape'], torch.Size):
            # slide inference
            size = batch_img_metas[0]['img_shape']
        elif 'pad_shape' in batch_img_metas[0]:
            size = batch_img_metas[0]['pad_shape'][:2]
        else:
            size = batch_img_metas[0]['img_shape']
        # upsample mask
        mask_pred = F.interpolate(
            mask_pred, size=size, mode='bilinear', align_corners=False)

        mask_cls = F.softmax(cls_score, dim=-1)[..., :-1]
        mask_pred = mask_pred.sigmoid()
        seg_logits = torch.einsum('bqc,bqhw->bchw', mask_cls, mask_pred)
        return seg_logits

    def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList,
             train_cfg: ConfigType) -> dict:
        """Perform forward propagation and loss calculation of the decode
        head on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Multi-level features from the upstream
                network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`SegDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_sem_seg`.
            train_cfg (ConfigType): Training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components.
        """
        # batch SegDataSample to InstanceDataSample
        batch_gt_instances = seg_data_to_instance_data(self.ignore_index,
                                                       batch_data_samples)

        # forward
        all_mask_props, all_mask_logits = self.forward(
            x, self.deep_supervision_idxs)

        # loss
        losses = self.loss_by_feat(all_mask_logits, all_mask_props,
                                   batch_gt_instances)

        return losses

    def loss_by_feat(
            self, all_cls_scores: Tensor, all_mask_preds: Tensor,
            batch_gt_instances: List[InstanceData]) -> Dict[str, Tensor]:
        """Loss function.

        Args:
            all_cls_scores (Tensor): Classification scores for all decoder
                layers with shape (num_decoder, batch_size, num_queries,
                cls_out_channels). Note `cls_out_channels` should include
                background.
            all_mask_preds (Tensor): Mask scores for all decoder layers with
                shape (num_decoder, batch_size, num_queries, h, w).
            batch_gt_instances (list[obj:`InstanceData`]): each contains
                ``labels`` and ``masks``.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        num_dec_layers = len(all_cls_scores)
        batch_gt_instances_list = [
            batch_gt_instances for _ in range(num_dec_layers)
        ]

        losses = []
        for i in range(num_dec_layers):
            cls_scores = all_cls_scores[i]
            mask_preds = all_mask_preds[i]
            # matching N mask predictions to K category labels
            (labels, mask_targets, mask_weights,
             avg_factor) = self.match_masks.get_targets(
                 cls_scores, mask_preds, batch_gt_instances_list[i])
            cls_scores = cls_scores.flatten(0, 1)
            labels = labels.flatten(0, 1)
            num_total_masks = cls_scores.new_tensor([avg_factor],
                                                    dtype=torch.float)
            all_reduce(num_total_masks, op='mean')
            num_total_masks = max(num_total_masks, 1)

            # extract positive ones
            # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
            mask_preds = mask_preds[mask_weights > 0]

            if mask_targets.shape[0] != 0:
                with torch.no_grad():
                    points_coords = get_uncertain_point_coords_with_randomness(
                        mask_preds.unsqueeze(1), None,
                        self.train_cfg.num_points,
                        self.train_cfg.oversample_ratio,
                        self.train_cfg.importance_sample_ratio)
                    # shape (num_total_gts, h, w)
                    # -> (num_total_gts, num_points)
                    mask_point_targets = point_sample(
                        mask_targets.unsqueeze(1).float(),
                        points_coords).squeeze(1)
                # shape (num_queries, h, w) -> (num_queries, num_points)
                mask_point_preds = point_sample(
                    mask_preds.unsqueeze(1), points_coords).squeeze(1)

            if not isinstance(self.loss_decode, nn.ModuleList):
                losses_decode = [self.loss_decode]
            else:
                losses_decode = self.loss_decode
            loss = dict()
            for loss_decode in losses_decode:
                if 'loss_cls' in loss_decode.loss_name:
                    if loss_decode.loss_name == 'loss_cls_ce':
                        loss[loss_decode.loss_name] = loss_decode(
                            cls_scores, labels)
                    else:
                        assert False, "Only support 'CrossEntropyLoss' in" \
                            ' classification loss'

                elif 'loss_mask' in loss_decode.loss_name:
                    if mask_targets.shape[0] == 0:
                        loss[loss_decode.loss_name] = mask_preds.sum()
                    elif loss_decode.loss_name == 'loss_mask_ce':
                        loss[loss_decode.loss_name] = loss_decode(
                            mask_point_preds,
                            mask_point_targets,
                            avg_factor=num_total_masks *
                            self.train_cfg.num_points)
                    elif loss_decode.loss_name == 'loss_mask_dice':
                        loss[loss_decode.loss_name] = loss_decode(
                            mask_point_preds,
                            mask_point_targets,
                            avg_factor=num_total_masks)
                    else:
                        assert False, "Only support 'CrossEntropyLoss' and" \
                            " 'DiceLoss' in mask loss"
                else:
                    assert False, "Only support for 'loss_cls' and 'loss_mask'"

            losses.append(loss)

        loss_dict = dict()
        # loss from the last decoder layer
        loss_dict.update(losses[-1])
        # loss from other decoder layers
        for i, loss in enumerate(losses[:-1]):
            for k, v in loss.items():
                loss_dict[f'd{self.deep_supervision_idxs[i]}.{k}'] = v
        return loss_dict
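The final map assembly in `predict_by_feat` is a weighted sum over queries, which the following standalone sketch replays with dummy sizes (all values are assumptions for illustration): per-query class probabilities, with the "no object" channel dropped, weight per-query sigmoid masks and are summed into per-class logits.

import torch
import torch.nn.functional as F

b, q, num_cls, h, w = 1, 100, 19, 128, 128
cls_score = torch.randn(b, q, num_cls + 1)          # last channel: no object
mask_pred = torch.randn(b, q, h, w)

mask_cls = F.softmax(cls_score, dim=-1)[..., :-1]   # drop the void class
mask_pred = mask_pred.sigmoid()
seg_logits = torch.einsum('bqc,bqhw->bchw', mask_cls, mask_pred)
assert seg_logits.shape == (b, num_cls, h, w)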
66
finetune/mmseg/models/decode_heads/segformer_head.py
Normal file
@@ -0,0 +1,66 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.registry import MODELS
from ..utils import resize


@MODELS.register_module()
class SegformerHead(BaseDecodeHead):
    """The all-MLP head of SegFormer.

    This head is the implementation of
    `SegFormer <https://arxiv.org/abs/2105.15203>`_.

    Args:
        interpolate_mode (str): The interpolate mode of the MLP head
            upsample operation. Default: 'bilinear'.
    """

    def __init__(self, interpolate_mode='bilinear', **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)

        self.interpolate_mode = interpolate_mode
        num_inputs = len(self.in_channels)

        assert num_inputs == len(self.in_index)

        self.convs = nn.ModuleList()
        for i in range(num_inputs):
            self.convs.append(
                ConvModule(
                    in_channels=self.in_channels[i],
                    out_channels=self.channels,
                    kernel_size=1,
                    stride=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))

        self.fusion_conv = ConvModule(
            in_channels=self.channels * num_inputs,
            out_channels=self.channels,
            kernel_size=1,
            norm_cfg=self.norm_cfg)

    def forward(self, inputs):
        # Receive 4 stage backbone feature maps: 1/4, 1/8, 1/16, 1/32
        inputs = self._transform_inputs(inputs)
        outs = []
        for idx in range(len(inputs)):
            x = inputs[idx]
            conv = self.convs[idx]
            outs.append(
                resize(
                    input=conv(x),
                    size=inputs[0].shape[2:],
                    mode=self.interpolate_mode,
                    align_corners=self.align_corners))

        out = self.fusion_conv(torch.cat(outs, dim=1))

        out = self.cls_seg(out)

        return out
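A minimal usage sketch of the head above, not part of the commit: the config values are assumptions matching a MiT-B0-style backbone. Four stage features are projected to a common width, upsampled to the 1/4-scale map, fused by a 1x1 conv, and classified.

import torch

from mmseg.models.decode_heads.segformer_head import SegformerHead  # assumed path

head = SegformerHead(
    in_channels=[32, 64, 160, 256],
    in_index=[0, 1, 2, 3],
    channels=256,
    num_classes=19,
    norm_cfg=dict(type='BN', requires_grad=True))
feats = [
    torch.randn(1, 32, 64, 64),   # 1/4
    torch.randn(1, 64, 32, 32),   # 1/8
    torch.randn(1, 160, 16, 16),  # 1/16
    torch.randn(1, 256, 8, 8),    # 1/32
]
out = head(feats)
assert out.shape == (1, 19, 64, 64)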
132
finetune/mmseg/models/decode_heads/segmenter_mask_head.py
Normal file
@@ -0,0 +1,132 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer
from mmengine.model import ModuleList
from mmengine.model.weight_init import (constant_init, trunc_normal_,
                                        trunc_normal_init)

from mmseg.models.backbones.vit import TransformerEncoderLayer
from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead


@MODELS.register_module()
class SegmenterMaskTransformerHead(BaseDecodeHead):
    """Segmenter: Transformer for Semantic Segmentation.

    This head is the implementation of
    `Segmenter <https://arxiv.org/abs/2105.05633>`_.

    Args:
        in_channels (int): The number of channels of input image.
        num_layers (int): The depth of transformer.
        num_heads (int): The number of attention heads.
        embed_dims (int): The number of embedding dimension.
        mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
            Default: 4.
        drop_path_rate (float): Stochastic depth rate. Default: 0.1.
        drop_rate (float): Probability of an element to be zeroed.
            Default: 0.0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default: 0.0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default: 2.
        qkv_bias (bool): Enable bias for qkv if True. Default: True.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        init_std (float): The value of std in weight initialization.
            Default: 0.02.
    """

    def __init__(
        self,
        in_channels,
        num_layers,
        num_heads,
        embed_dims,
        mlp_ratio=4,
        drop_path_rate=0.1,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        num_fcs=2,
        qkv_bias=True,
        act_cfg=dict(type='GELU'),
        norm_cfg=dict(type='LN'),
        init_std=0.02,
        **kwargs,
    ):
        super().__init__(in_channels=in_channels, **kwargs)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]
        self.layers = ModuleList()
        for i in range(num_layers):
            self.layers.append(
                TransformerEncoderLayer(
                    embed_dims=embed_dims,
                    num_heads=num_heads,
                    feedforward_channels=mlp_ratio * embed_dims,
                    attn_drop_rate=attn_drop_rate,
                    drop_rate=drop_rate,
                    drop_path_rate=dpr[i],
                    num_fcs=num_fcs,
                    qkv_bias=qkv_bias,
                    act_cfg=act_cfg,
                    norm_cfg=norm_cfg,
                    batch_first=True,
                ))

        self.dec_proj = nn.Linear(in_channels, embed_dims)

        self.cls_emb = nn.Parameter(
            torch.randn(1, self.num_classes, embed_dims))
        self.patch_proj = nn.Linear(embed_dims, embed_dims, bias=False)
        self.classes_proj = nn.Linear(embed_dims, embed_dims, bias=False)

        self.decoder_norm = build_norm_layer(
            norm_cfg, embed_dims, postfix=1)[1]
        self.mask_norm = build_norm_layer(
            norm_cfg, self.num_classes, postfix=2)[1]

        self.init_std = init_std

        delattr(self, 'conv_seg')

    def init_weights(self):
        trunc_normal_(self.cls_emb, std=self.init_std)
        trunc_normal_init(self.patch_proj, std=self.init_std)
        trunc_normal_init(self.classes_proj, std=self.init_std)
        for n, m in self.named_modules():
            if isinstance(m, nn.Linear):
                trunc_normal_init(m, std=self.init_std, bias=0)
            elif isinstance(m, nn.LayerNorm):
                constant_init(m, val=1.0, bias=0.0)

    def forward(self, inputs):
        x = self._transform_inputs(inputs)
        b, c, h, w = x.shape
        x = x.permute(0, 2, 3, 1).contiguous().view(b, -1, c)

        x = self.dec_proj(x)
        cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
        x = torch.cat((x, cls_emb), 1)
        for layer in self.layers:
            x = layer(x)
        x = self.decoder_norm(x)

        patches = self.patch_proj(x[:, :-self.num_classes])
        cls_seg_feat = self.classes_proj(x[:, -self.num_classes:])

        patches = F.normalize(patches, dim=2, p=2)
        cls_seg_feat = F.normalize(cls_seg_feat, dim=2, p=2)

        masks = patches @ cls_seg_feat.transpose(1, 2)
        masks = self.mask_norm(masks)
        masks = masks.permute(0, 2, 1).contiguous().view(b, -1, h, w)

        return masks
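The core of the mask decoder above is a cosine-similarity match between tokens, which the following pure-torch sketch isolates (dummy sizes are assumptions): L2-normalized patch tokens are matched against L2-normalized class tokens, so each mask logit is a cosine similarity before `mask_norm` rescales it.

import torch
import torch.nn.functional as F

b, n_patches, n_cls, d = 2, 1024, 19, 192
patches = F.normalize(torch.randn(b, n_patches, d), dim=2, p=2)
cls_tokens = F.normalize(torch.randn(b, n_cls, d), dim=2, p=2)
masks = patches @ cls_tokens.transpose(1, 2)       # (b, n_patches, n_cls)
h = w = int(n_patches ** 0.5)
masks = masks.permute(0, 2, 1).reshape(b, n_cls, h, w)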
102
finetune/mmseg/models/decode_heads/sep_aspp_head.py
Normal file
@@ -0,0 +1,102 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule

from mmseg.registry import MODELS
from ..utils import resize
from .aspp_head import ASPPHead, ASPPModule


class DepthwiseSeparableASPPModule(ASPPModule):
    """Atrous Spatial Pyramid Pooling (ASPP) Module with depthwise separable
    conv."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        for i, dilation in enumerate(self.dilations):
            if dilation > 1:
                self[i] = DepthwiseSeparableConvModule(
                    self.in_channels,
                    self.channels,
                    3,
                    dilation=dilation,
                    padding=dilation,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg)


@MODELS.register_module()
class DepthwiseSeparableASPPHead(ASPPHead):
    """Encoder-Decoder with Atrous Separable Convolution for Semantic Image
    Segmentation.

    This head is the implementation of `DeepLabV3+
    <https://arxiv.org/abs/1802.02611>`_.

    Args:
        c1_in_channels (int): The input channels of c1 decoder. If it is 0,
            no decoder will be used.
        c1_channels (int): The intermediate channels of c1 decoder.
    """

    def __init__(self, c1_in_channels, c1_channels, **kwargs):
        super().__init__(**kwargs)
        assert c1_in_channels >= 0
        self.aspp_modules = DepthwiseSeparableASPPModule(
            dilations=self.dilations,
            in_channels=self.in_channels,
            channels=self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if c1_in_channels > 0:
            self.c1_bottleneck = ConvModule(
                c1_in_channels,
                c1_channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        else:
            self.c1_bottleneck = None
        self.sep_bottleneck = nn.Sequential(
            DepthwiseSeparableConvModule(
                self.channels + c1_channels,
                self.channels,
                3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            DepthwiseSeparableConvModule(
                self.channels,
                self.channels,
                3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        aspp_outs = [
            resize(
                self.image_pool(x),
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        ]
        aspp_outs.extend(self.aspp_modules(x))
        aspp_outs = torch.cat(aspp_outs, dim=1)
        output = self.bottleneck(aspp_outs)
        if self.c1_bottleneck is not None:
            c1_output = self.c1_bottleneck(inputs[0])
            output = resize(
                input=output,
                size=c1_output.shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            output = torch.cat([output, c1_output], dim=1)
        output = self.sep_bottleneck(output)
        output = self.cls_seg(output)
        return output
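A quick parameter count shows why the separable variant above is cheaper; the channel sizes are illustrative assumptions. A standard 3x3 conv is compared against the depthwise 3x3 plus pointwise 1x1 pair that `DepthwiseSeparableConvModule` wraps.

import torch.nn as nn

cin, cout = 512, 512
standard = nn.Conv2d(cin, cout, 3, padding=1)
depthwise = nn.Conv2d(cin, cin, 3, padding=1, groups=cin)
pointwise = nn.Conv2d(cin, cout, 1)

n_std = sum(p.numel() for p in standard.parameters())
n_sep = sum(p.numel() for p in depthwise.parameters()) + \
        sum(p.numel() for p in pointwise.parameters())
print(n_std, n_sep)   # ~2.36M vs ~0.27M weights+biases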
60
finetune/mmseg/models/decode_heads/sep_fcn_head.py
Normal file
@@ -0,0 +1,60 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import DepthwiseSeparableConvModule

from mmseg.registry import MODELS
from .fcn_head import FCNHead


@MODELS.register_module()
class DepthwiseSeparableFCNHead(FCNHead):
    """Depthwise-Separable Fully Convolutional Network for Semantic
    Segmentation.

    This head is implemented according to `Fast-SCNN: Fast Semantic
    Segmentation Network <https://arxiv.org/abs/1902.04502>`_.

    Args:
        in_channels (int): Number of output channels of FFM.
        channels (int): Number of middle-stage channels in the decode head.
        concat_input (bool): Whether to concatenate original decode input
            into the result of several consecutive convolution layers.
            Default: True.
        num_classes (int): Used to determine the dimension of
            final prediction tensor.
        in_index (int): Correspond with 'out_indices' in FastSCNN backbone.
        norm_cfg (dict | None): Config of norm layers.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        loss_decode (dict): Config of loss type and some
            relevant additional options.
        dw_act_cfg (dict): Activation config of depthwise ConvModule. If it
            is 'default', it will be the same as `act_cfg`. Default: None.
    """

    def __init__(self, dw_act_cfg=None, **kwargs):
        super().__init__(**kwargs)
        self.convs[0] = DepthwiseSeparableConvModule(
            self.in_channels,
            self.channels,
            kernel_size=self.kernel_size,
            padding=self.kernel_size // 2,
            norm_cfg=self.norm_cfg,
            dw_act_cfg=dw_act_cfg)

        for i in range(1, self.num_convs):
            self.convs[i] = DepthwiseSeparableConvModule(
                self.channels,
                self.channels,
                kernel_size=self.kernel_size,
                padding=self.kernel_size // 2,
                norm_cfg=self.norm_cfg,
                dw_act_cfg=dw_act_cfg)

        if self.concat_input:
            self.conv_cat = DepthwiseSeparableConvModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=self.kernel_size,
                padding=self.kernel_size // 2,
                norm_cfg=self.norm_cfg,
                dw_act_cfg=dw_act_cfg)
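A hypothetical instantiation in the Fast-SCNN style, not part of the commit; all config values are assumptions for illustration. Every conv in the head, including the optional concat-fusion conv, becomes depthwise separable.

import torch

from mmseg.models.decode_heads.sep_fcn_head import DepthwiseSeparableFCNHead  # assumed path

head = DepthwiseSeparableFCNHead(
    in_channels=128,
    channels=128,
    concat_input=False,
    num_classes=19,
    in_index=-1,
    norm_cfg=dict(type='BN', requires_grad=True),
    align_corners=False)
out = head([torch.randn(1, 128, 32, 64)])
assert out.shape == (1, 19, 32, 64)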
62
finetune/mmseg/models/decode_heads/setr_mla_head.py
Normal file
@@ -0,0 +1,62 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.registry import MODELS
from ..utils import Upsample
from .decode_head import BaseDecodeHead


@MODELS.register_module()
class SETRMLAHead(BaseDecodeHead):
    """Multi-level feature aggregation head of SETR.

    MLA head of `SETR <https://arxiv.org/pdf/2012.15840.pdf>`_.

    Args:
        mla_channels (int): Channels of the conv-conv-4x branch of
            multi-level feature aggregation. Default: 128.
        up_scale (int): The scale factor of interpolate. Default: 4.
    """

    def __init__(self, mla_channels=128, up_scale=4, **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)
        self.mla_channels = mla_channels

        num_inputs = len(self.in_channels)

        # Refer to self.cls_seg settings of BaseDecodeHead
        assert self.channels == num_inputs * mla_channels

        self.up_convs = nn.ModuleList()
        for i in range(num_inputs):
            self.up_convs.append(
                nn.Sequential(
                    ConvModule(
                        in_channels=self.in_channels[i],
                        out_channels=mla_channels,
                        kernel_size=3,
                        padding=1,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg),
                    ConvModule(
                        in_channels=mla_channels,
                        out_channels=mla_channels,
                        kernel_size=3,
                        padding=1,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg),
                    Upsample(
                        scale_factor=up_scale,
                        mode='bilinear',
                        align_corners=self.align_corners)))

    def forward(self, inputs):
        inputs = self._transform_inputs(inputs)
        outs = []
        for x, up_conv in zip(inputs, self.up_convs):
            outs.append(up_conv(x))
        out = torch.cat(outs, dim=1)
        out = self.cls_seg(out)
        return out
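The constructor above requires `channels == num_inputs * mla_channels` because the per-branch outputs are concatenated before `cls_seg`. A hypothetical 4-branch setup, with illustrative values only:

import torch

from mmseg.models.decode_heads.setr_mla_head import SETRMLAHead  # assumed path

head = SETRMLAHead(
    in_channels=[256, 256, 256, 256],
    in_index=[0, 1, 2, 3],
    channels=4 * 128,            # num_inputs * mla_channels
    mla_channels=128,
    num_classes=19,
    norm_cfg=dict(type='BN', requires_grad=True),
    align_corners=False)
feats = [torch.randn(1, 256, 16, 16) for _ in range(4)]
out = head(feats)
assert out.shape == (1, 19, 64, 64)  # each branch upsamples by 4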
81
finetune/mmseg/models/decode_heads/setr_up_head.py
Normal file
@@ -0,0 +1,81 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer

from mmseg.registry import MODELS
from ..utils import Upsample
from .decode_head import BaseDecodeHead


@MODELS.register_module()
class SETRUPHead(BaseDecodeHead):
    """Naive upsampling head and Progressive upsampling head of SETR.

    Naive or PUP head of `SETR <https://arxiv.org/pdf/2012.15840.pdf>`_.

    Args:
        norm_layer (dict): Config dict for input normalization.
            Default: norm_layer=dict(type='LN', eps=1e-6, requires_grad=True).
        num_convs (int): Number of decoder convolutions. Default: 1.
        up_scale (int): The scale factor of interpolate. Default: 4.
        kernel_size (int): The kernel size of convolution when decoding
            feature information from backbone. Default: 3.
        init_cfg (dict | list[dict] | None): Initialization config dict.
            Default: a Constant init for LayerNorm plus a Normal init
            (std=0.01) overriding `conv_seg`.
    """

    def __init__(self,
                 norm_layer=dict(type='LN', eps=1e-6, requires_grad=True),
                 num_convs=1,
                 up_scale=4,
                 kernel_size=3,
                 init_cfg=[
                     dict(type='Constant', val=1.0, bias=0, layer='LayerNorm'),
                     dict(
                         type='Normal',
                         std=0.01,
                         override=dict(name='conv_seg'))
                 ],
                 **kwargs):

        assert kernel_size in [1, 3], 'kernel_size must be 1 or 3.'

        super().__init__(init_cfg=init_cfg, **kwargs)

        assert isinstance(self.in_channels, int)

        _, self.norm = build_norm_layer(norm_layer, self.in_channels)

        self.up_convs = nn.ModuleList()
        in_channels = self.in_channels
        out_channels = self.channels
        for _ in range(num_convs):
            self.up_convs.append(
                nn.Sequential(
                    ConvModule(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=kernel_size,
                        stride=1,
                        padding=int(kernel_size - 1) // 2,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg),
                    Upsample(
                        scale_factor=up_scale,
                        mode='bilinear',
                        align_corners=self.align_corners)))
            in_channels = out_channels

    def forward(self, x):
        x = self._transform_inputs(x)

        n, c, h, w = x.shape
        x = x.reshape(n, c, h * w).transpose(2, 1).contiguous()
        x = self.norm(x)
        x = x.transpose(1, 2).reshape(n, c, h, w).contiguous()

        for up_conv in self.up_convs:
            x = up_conv(x)
        out = self.cls_seg(x)
        return out
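The forward above applies LayerNorm over channel vectors by flattening the spatial grid into tokens first; a standalone shape sketch with assumed sizes:

import torch
import torch.nn as nn

n, c, h, w = 1, 768, 32, 32
x = torch.randn(n, c, h, w)
norm = nn.LayerNorm(c, eps=1e-6)
tokens = x.reshape(n, c, h * w).transpose(2, 1)   # (n, h*w, c)
tokens = norm(tokens)                             # normalize each token
x = tokens.transpose(1, 2).reshape(n, c, h, w)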
97
finetune/mmseg/models/decode_heads/stdc_head.py
Normal file
@@ -0,0 +1,97 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
from mmengine.structures import PixelData
from torch import Tensor

from mmseg.registry import MODELS
from mmseg.structures import SegDataSample
from mmseg.utils import SampleList
from .fcn_head import FCNHead


@MODELS.register_module()
class STDCHead(FCNHead):
    """This head is the implementation of `Rethinking BiSeNet For Real-time
    Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.

    Args:
        boundary_threshold (float): The threshold of calculating boundary.
            Default: 0.1.
    """

    def __init__(self, boundary_threshold=0.1, **kwargs):
        super().__init__(**kwargs)
        self.boundary_threshold = boundary_threshold
        # Using register_buffer to keep the laplacian kernel on the same
        # device as `seg_label`.
        self.register_buffer(
            'laplacian_kernel',
            torch.tensor([-1, -1, -1, -1, 8, -1, -1, -1, -1],
                         dtype=torch.float32,
                         requires_grad=False).reshape((1, 1, 3, 3)))
        self.fusion_kernel = torch.nn.Parameter(
            torch.tensor([[6. / 10], [3. / 10], [1. / 10]],
                         dtype=torch.float32).reshape(1, 3, 1, 1),
            requires_grad=False)

    def loss_by_feat(self, seg_logits: Tensor,
                     batch_data_samples: SampleList) -> dict:
        """Compute Detail Aggregation Loss."""
        # Note: The paper claims `fusion_kernel` is a trainable 1x1 conv
        # parameter. However, it is a constant in the original repo and
        # other codebases because it would not be added into the
        # computation graph after the threshold operation.
        seg_label = self._stack_batch_gt(batch_data_samples).to(
            self.laplacian_kernel)
        boundary_targets = F.conv2d(
            seg_label, self.laplacian_kernel, padding=1)
        boundary_targets = boundary_targets.clamp(min=0)
        boundary_targets[boundary_targets > self.boundary_threshold] = 1
        boundary_targets[boundary_targets <= self.boundary_threshold] = 0

        boundary_targets_x2 = F.conv2d(
            seg_label, self.laplacian_kernel, stride=2, padding=1)
        boundary_targets_x2 = boundary_targets_x2.clamp(min=0)

        boundary_targets_x4 = F.conv2d(
            seg_label, self.laplacian_kernel, stride=4, padding=1)
        boundary_targets_x4 = boundary_targets_x4.clamp(min=0)

        boundary_targets_x4_up = F.interpolate(
            boundary_targets_x4, boundary_targets.shape[2:], mode='nearest')
        boundary_targets_x2_up = F.interpolate(
            boundary_targets_x2, boundary_targets.shape[2:], mode='nearest')

        boundary_targets_x2_up[
            boundary_targets_x2_up > self.boundary_threshold] = 1
        boundary_targets_x2_up[
            boundary_targets_x2_up <= self.boundary_threshold] = 0

        boundary_targets_x4_up[
            boundary_targets_x4_up > self.boundary_threshold] = 1
        boundary_targets_x4_up[
            boundary_targets_x4_up <= self.boundary_threshold] = 0

        boundary_targets_pyramids = torch.stack(
            (boundary_targets, boundary_targets_x2_up,
             boundary_targets_x4_up),
            dim=1)

        boundary_targets_pyramids = boundary_targets_pyramids.squeeze(2)
        boundary_targets_pyramid = F.conv2d(boundary_targets_pyramids,
                                            self.fusion_kernel)

        boundary_targets_pyramid[
            boundary_targets_pyramid > self.boundary_threshold] = 1
        boundary_targets_pyramid[
            boundary_targets_pyramid <= self.boundary_threshold] = 0

        seg_labels = boundary_targets_pyramid.long()
        batch_sample_list = []
        for label in seg_labels:
            seg_data_sample = SegDataSample()
            seg_data_sample.gt_sem_seg = PixelData(data=label)
            batch_sample_list.append(seg_data_sample)

        loss = super().loss_by_feat(seg_logits, batch_sample_list)
        return loss
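The detail-target extraction above can be replayed on a toy label map; the example below is an illustration only. A Laplacian kernel turns a label image into a soft edge response, which thresholding binarizes into boundary targets.

import torch
import torch.nn.functional as F

laplacian = torch.tensor([-1, -1, -1, -1, 8, -1, -1, -1, -1],
                         dtype=torch.float32).reshape(1, 1, 3, 3)
label = torch.zeros(1, 1, 8, 8)
label[..., :, 4:] = 1.0                 # two regions, vertical boundary
edges = F.conv2d(label, laplacian, padding=1).clamp(min=0)
boundary = (edges > 0.1).float()        # nonzero along the region's borders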
139
finetune/mmseg/models/decode_heads/uper_head.py
Normal file
@@ -0,0 +1,139 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead
from .psp_head import PPM


@MODELS.register_module()
class UPerHead(BaseDecodeHead):
    """Unified Perceptual Parsing for Scene Understanding.

    This head is the implementation of `UPerNet
    <https://arxiv.org/abs/1807.10221>`_.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module applied on the last feature. Default: (1, 2, 3, 6).
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)
        # PSP Module
        self.psp_modules = PPM(
            pool_scales,
            self.in_channels[-1],
            self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)
        self.bottleneck = ConvModule(
            self.in_channels[-1] + len(pool_scales) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # FPN Module
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for in_channels in self.in_channels[:-1]:  # skip the top layer
            l_conv = ConvModule(
                in_channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                inplace=False)
            fpn_conv = ConvModule(
                self.channels,
                self.channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                inplace=False)
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

        self.fpn_bottleneck = ConvModule(
            len(self.in_channels) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def psp_forward(self, inputs):
        """Forward function of PSP module."""
        x = inputs[-1]
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x))
        psp_outs = torch.cat(psp_outs, dim=1)
        output = self.bottleneck(psp_outs)

        return output

    def _forward_feature(self, inputs):
        """Forward function for feature maps before classifying each pixel
        with ``self.cls_seg`` fc.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W) which is feature map for last layer of decoder head.
        """
        inputs = self._transform_inputs(inputs)

        # build laterals
        laterals = [
            lateral_conv(inputs[i])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        laterals.append(self.psp_forward(inputs))

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] = laterals[i - 1] + resize(
                laterals[i],
                size=prev_shape,
                mode='bilinear',
                align_corners=self.align_corners)

        # build outputs
        fpn_outs = [
            self.fpn_convs[i](laterals[i])
            for i in range(used_backbone_levels - 1)
        ]
        # append psp feature
        fpn_outs.append(laterals[-1])

        for i in range(used_backbone_levels - 1, 0, -1):
            fpn_outs[i] = resize(
                fpn_outs[i],
                size=fpn_outs[0].shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        fpn_outs = torch.cat(fpn_outs, dim=1)
        feats = self.fpn_bottleneck(fpn_outs)
        return feats

    def forward(self, inputs):
        """Forward function."""
        output = self._forward_feature(inputs)
        output = self.cls_seg(output)
        return output
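A minimal usage sketch of UPerHead, not part of the commit; channel sizes are assumptions matching a ResNet-50-style backbone. PSP runs on the top level, the FPN fuses all levels top-down, and everything is concatenated at 1/4 resolution.

import torch

from mmseg.models.decode_heads.uper_head import UPerHead  # assumed path

head = UPerHead(
    in_channels=[256, 512, 1024, 2048],
    in_index=[0, 1, 2, 3],
    channels=512,
    pool_scales=(1, 2, 3, 6),
    num_classes=19,
    norm_cfg=dict(type='BN', requires_grad=True),
    align_corners=False)
feats = [
    torch.randn(1, 256, 64, 64),
    torch.randn(1, 512, 32, 32),
    torch.randn(1, 1024, 16, 16),
    torch.randn(1, 2048, 8, 8),
]
out = head(feats)
assert out.shape == (1, 19, 64, 64)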
253
finetune/mmseg/models/decode_heads/vpd_depth_head.py
Normal file
@@ -0,0 +1,253 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer, build_norm_layer, build_upsample_layer
from mmengine.model import BaseModule
from torch import Tensor

from mmseg.registry import MODELS
from mmseg.utils import SampleList
from ..utils import resize
from .decode_head import BaseDecodeHead


class VPDDepthDecoder(BaseModule):
    """VPD Depth Decoder class.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        num_deconv_layers (int): Number of deconvolution layers.
        num_deconv_filters (List[int]): List of output channels for
            deconvolution layers.
        init_cfg (Optional[Union[Dict, List[Dict]]], optional): Configuration
            for weight initialization. Defaults to Normal for Conv2d and
            ConvTranspose2d layers.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 num_deconv_layers: int,
                 num_deconv_filters: List[int],
                 init_cfg: Optional[Union[Dict, List[Dict]]] = dict(
                     type='Normal',
                     std=0.001,
                     layer=['Conv2d', 'ConvTranspose2d'])):
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels

        self.deconv_layers = self._make_deconv_layer(
            num_deconv_layers,
            num_deconv_filters,
        )

        conv_layers = []
        conv_layers.append(
            build_conv_layer(
                dict(type='Conv2d'),
                in_channels=num_deconv_filters[-1],
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding=1))
        conv_layers.append(build_norm_layer(dict(type='BN'), out_channels)[1])
        conv_layers.append(nn.ReLU(inplace=True))
        self.conv_layers = nn.Sequential(*conv_layers)

        self.up_sample = nn.Upsample(
            scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, x):
        """Forward pass through the decoder network."""
        out = self.deconv_layers(x)
        out = self.conv_layers(out)

        out = self.up_sample(out)
        out = self.up_sample(out)

        return out

    def _make_deconv_layer(self, num_layers, num_deconv_filters):
        """Make deconv layers."""

        layers = []
        in_channels = self.in_channels
        for i in range(num_layers):

            num_channels = num_deconv_filters[i]
            layers.append(
                build_upsample_layer(
                    dict(type='deconv'),
                    in_channels=in_channels,
                    out_channels=num_channels,
                    kernel_size=2,
                    stride=2,
                    padding=0,
                    output_padding=0,
                    bias=False))
            layers.append(nn.BatchNorm2d(num_channels))
            layers.append(nn.ReLU(inplace=True))
            in_channels = num_channels

        return nn.Sequential(*layers)
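A quick shape sanity check for this decoder (a hedged sketch, not part of the commit; the numbers assume the default VPD configuration used further below, where in_channels = embed_dim * 8 = 1536 with embed_dim = 192):

import torch

decoder = VPDDepthDecoder(
    in_channels=1536,
    out_channels=192,
    num_deconv_layers=3,
    num_deconv_filters=[32, 32, 32])

feat = torch.randn(1, 1536, 16, 16)
out = decoder(feat)
# 3 deconvs (x2 each) plus 2 bilinear upsamples (x2 each) -> x32 overall
print(out.shape)  # torch.Size([1, 192, 512, 512])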
@MODELS.register_module()
class VPDDepthHead(BaseDecodeHead):
    """Depth Prediction Head for VPD.

    .. _`VPD`: https://arxiv.org/abs/2303.02153

    Args:
        max_depth (float): Maximum depth value. Defaults to 10.0.
        in_channels (Sequence[int]): Number of input channels for each
            convolutional layer.
        embed_dim (int): Dimension of embedding. Defaults to 192.
        feature_dim (int): Dimension of aggregated feature. Defaults to 1536.
        num_deconv_layers (int): Number of deconvolution layers in the
            decoder. Defaults to 3.
        num_deconv_filters (Sequence[int]): Number of filters for each deconv
            layer. Defaults to (32, 32, 32).
        fmap_border (Union[int, Sequence[int]]): Feature map border for
            cropping. Defaults to 0.
        align_corners (bool): Flag for align_corners in interpolation.
            Defaults to False.
        loss_decode (dict): Configurations for the loss function. Defaults to
            dict(type='SiLogLoss').
        init_cfg (dict): Initialization configurations. Defaults to
            dict(type='TruncNormal', std=0.02, layer=['Conv2d', 'Linear']).
    """

    num_classes = 1
    out_channels = 1
    input_transform = None

    def __init__(
        self,
        max_depth: float = 10.0,
        in_channels: Sequence[int] = [320, 640, 1280, 1280],
        embed_dim: int = 192,
        feature_dim: int = 1536,
        num_deconv_layers: int = 3,
        num_deconv_filters: Sequence[int] = (32, 32, 32),
        fmap_border: Union[int, Sequence[int]] = 0,
        align_corners: bool = False,
        loss_decode: dict = dict(type='SiLogLoss'),
        init_cfg=dict(
            type='TruncNormal', std=0.02, layer=['Conv2d', 'Linear']),
    ):

        super(BaseDecodeHead, self).__init__(init_cfg=init_cfg)

        # initialize parameters
        self.in_channels = in_channels
        self.max_depth = max_depth
        self.align_corners = align_corners

        # feature map border
        if isinstance(fmap_border, int):
            fmap_border = (fmap_border, fmap_border)
        self.fmap_border = fmap_border

        # define network layers
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels[0], in_channels[0], 3, stride=2, padding=1),
            nn.GroupNorm(16, in_channels[0]),
            nn.ReLU(),
            nn.Conv2d(in_channels[0], in_channels[0], 3, stride=2, padding=1),
        )
        self.conv2 = nn.Conv2d(
            in_channels[1], in_channels[1], 3, stride=2, padding=1)

        self.conv_aggregation = nn.Sequential(
            nn.Conv2d(sum(in_channels), feature_dim, 1),
            nn.GroupNorm(16, feature_dim),
            nn.ReLU(),
        )

        self.decoder = VPDDepthDecoder(
            in_channels=embed_dim * 8,
            out_channels=embed_dim,
            num_deconv_layers=num_deconv_layers,
            num_deconv_filters=num_deconv_filters)

        self.depth_pred_layer = nn.Sequential(
            nn.Conv2d(
                embed_dim, embed_dim, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=False),
            nn.Conv2d(embed_dim, 1, kernel_size=3, stride=1, padding=1))

        # build loss
        if isinstance(loss_decode, dict):
            self.loss_decode = MODELS.build(loss_decode)
        elif isinstance(loss_decode, (list, tuple)):
            self.loss_decode = nn.ModuleList()
            for loss in loss_decode:
                self.loss_decode.append(MODELS.build(loss))
        else:
            raise TypeError(f'loss_decode must be a dict or sequence of dict,\
                but got {type(loss_decode)}')

    def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor:
        gt_depth_maps = [
            data_sample.gt_depth_map.data for data_sample in batch_data_samples
        ]
        return torch.stack(gt_depth_maps, dim=0)

    def forward(self, x):
        x = [
            x[0], x[1],
            torch.cat([x[2], F.interpolate(x[3], scale_factor=2)], dim=1)
        ]
        x = torch.cat([self.conv1(x[0]), self.conv2(x[1]), x[2]], dim=1)
        x = self.conv_aggregation(x)

        x = x[:, :, :x.size(2) - self.fmap_border[0], :x.size(3) -
              self.fmap_border[1]].contiguous()
        x = self.decoder(x)
        out = self.depth_pred_layer(x)

        depth = torch.sigmoid(out) * self.max_depth

        return depth
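To see why conv_aggregation takes sum(in_channels) inputs and why the decoder is built with in_channels=embed_dim * 8, here is a hedged walk-through of the shape arithmetic in forward; the four feature maps are assumed to sit at strides 8/16/32/64 of a 512x512 crop, which is an illustrative choice, not something fixed by this commit:

import torch
import torch.nn.functional as F

B = 1
x0 = torch.randn(B, 320, 64, 64)   # finest level
x1 = torch.randn(B, 640, 32, 32)
x2 = torch.randn(B, 1280, 16, 16)
x3 = torch.randn(B, 1280, 8, 8)    # coarsest level

# x3 is upsampled x2 and fused with x2 channel-wise
top = torch.cat([x2, F.interpolate(x3, scale_factor=2)], dim=1)
print(top.shape)  # torch.Size([1, 2560, 16, 16])

# conv1 halves x0 twice (two stride-2 convs) and conv2 halves x1 once,
# so all branches meet on the 16x16 grid before concatenation:
print(2560 + 320 + 640)  # 3520 == sum(in_channels)
# conv_aggregation then projects 3520 -> feature_dim = 1536 = 192 * 8,
# which matches the decoder's expected in_channels (embed_dim * 8)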
    def loss_by_feat(self, pred_depth_map: Tensor,
                     batch_data_samples: SampleList) -> dict:
        """Compute depth estimation loss.

        Args:
            pred_depth_map (Tensor): The output from decode head forward
                function.
            batch_data_samples (List[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `metainfo` and `gt_depth_map`.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """

        gt_depth_map = self._stack_batch_gt(batch_data_samples)
        loss = dict()
        pred_depth_map = resize(
            input=pred_depth_map,
            size=gt_depth_map.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)

        if not isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = [self.loss_decode]
        else:
            losses_decode = self.loss_decode
        for loss_decode in losses_decode:
            if loss_decode.loss_name not in loss:
                loss[loss_decode.loss_name] = loss_decode(
                    pred_depth_map, gt_depth_map)
            else:
                loss[loss_decode.loss_name] += loss_decode(
                    pred_depth_map, gt_depth_map)

        return loss
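Finally, a hedged sketch of how this registered head might appear in an mmseg config; the field values mirror the constructor defaults above, and the registry build call is the standard mmengine path rather than anything introduced by this commit:

# hypothetical config fragment
decode_head = dict(
    type='VPDDepthHead',
    max_depth=10.0,
    in_channels=[320, 640, 1280, 1280],
    embed_dim=192,
    feature_dim=1536,
    num_deconv_layers=3,
    num_deconv_filters=(32, 32, 32),
    fmap_border=0,
    align_corners=False,
    loss_decode=dict(type='SiLogLoss'))

# at runtime mmseg builds it through the registry:
# from mmseg.registry import MODELS
# head = MODELS.build(decode_head)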