commit 01adcfdf60
Author: esenke
Date:   2025-12-08 22:16:31 +08:00
305 changed files with 50879 additions and 0 deletions


@@ -0,0 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .featurepyramid import Feature2Pyramid
from .fpn import FPN
from .ic_neck import ICNeck
from .jpu import JPU
from .mla_neck import MLANeck
from .multilevel_neck import MultiLevelNeck
from .fusion_transformer import FusionTransformer
from .fusion_multilevel_neck import FusionMultiLevelNeck

__all__ = [
'FPN', 'MultiLevelNeck', 'MLANeck', 'ICNeck', 'JPU', 'Feature2Pyramid',
'FusionTransformer', 'FusionMultiLevelNeck'
]
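
For orientation, a minimal sketch of how one of the necks exported above is typically built through the mmseg registry; the config values here are hypothetical, and an installed mmseg with these modules importable is assumed:

from mmseg.registry import MODELS

# Hypothetical config: any neck registered above can be built this way.
neck_cfg = dict(
    type='MultiLevelNeck',
    in_channels=[768, 768, 768, 768],
    out_channels=256,
    scales=[0.5, 1, 2, 4])
neck = MODELS.build(neck_cfg)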


@@ -0,0 +1,67 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmseg.registry import MODELS
@MODELS.register_module()
class Feature2Pyramid(nn.Module):
"""Feature2Pyramid.

    A neck structure that connects a ViT backbone to decoder heads.

Args:
        embed_dim (int): Embedding dimension.
        rescales (list[float]): Resampling factors applied to the inputs
            to build the pyramid features. Default: [4, 2, 1, 0.5].
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='SyncBN', requires_grad=True).
"""
def __init__(self,
embed_dim,
rescales=[4, 2, 1, 0.5],
norm_cfg=dict(type='SyncBN', requires_grad=True)):
super().__init__()
self.rescales = rescales
self.upsample_4x = None
for k in self.rescales:
if k == 4:
self.upsample_4x = nn.Sequential(
nn.ConvTranspose2d(
embed_dim, embed_dim, kernel_size=2, stride=2),
build_norm_layer(norm_cfg, embed_dim)[1],
nn.GELU(),
nn.ConvTranspose2d(
embed_dim, embed_dim, kernel_size=2, stride=2),
)
elif k == 2:
self.upsample_2x = nn.Sequential(
nn.ConvTranspose2d(
embed_dim, embed_dim, kernel_size=2, stride=2))
elif k == 1:
self.identity = nn.Identity()
elif k == 0.5:
self.downsample_2x = nn.MaxPool2d(kernel_size=2, stride=2)
elif k == 0.25:
self.downsample_4x = nn.MaxPool2d(kernel_size=4, stride=4)
else:
                raise KeyError(f'invalid {k} for feature2pyramid')

def forward(self, inputs):
assert len(inputs) == len(self.rescales)
outputs = []
if self.upsample_4x is not None:
ops = [
self.upsample_4x, self.upsample_2x, self.identity,
self.downsample_2x
]
else:
ops = [
self.upsample_2x, self.identity, self.downsample_2x,
self.downsample_4x
]
for i in range(len(inputs)):
outputs.append(ops[i](inputs[i]))
return tuple(outputs)
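
A minimal smoke test for Feature2Pyramid, assuming hypothetical ViT-style shapes; 'BN' is substituted for the default 'SyncBN' so the sketch runs in a single, non-distributed process:

import torch

# Assumes Feature2Pyramid from the file above is in scope.
neck = Feature2Pyramid(
    embed_dim=768,
    rescales=[4, 2, 1, 0.5],
    norm_cfg=dict(type='BN', requires_grad=True))
feats = [torch.rand(2, 768, 16, 16) for _ in range(4)]
outs = neck(feats)
print([o.shape[-1] for o in outs])  # [64, 32, 16, 8]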


@@ -0,0 +1,212 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from mmseg.registry import MODELS
from ..utils import resize
@MODELS.register_module()
class FPN(BaseModule):
"""Feature Pyramid Network.
This neck is the implementation of `Feature Pyramid Networks for Object
Detection <https://arxiv.org/abs/1612.03144>`_.
Args:
in_channels (list[int]): Number of input channels per scale.
out_channels (int): Number of output channels (used at each scale).
num_outs (int): Number of output scales.
start_level (int): Index of the start input backbone level used to
build the feature pyramid. Default: 0.
end_level (int): Index of the end input backbone level (exclusive) to
build the feature pyramid. Default: -1, which means the last level.
add_extra_convs (bool | str): If bool, it decides whether to add conv
layers on top of the original feature maps. Default to False.
If True, its actual mode is specified by `extra_convs_on_inputs`.
If str, it specifies the source feature map of the extra convs.
Only the following options are allowed
- 'on_input': Last feat map of neck inputs (i.e. backbone feature).
- 'on_lateral': Last feature map after lateral convs.
- 'on_output': The last output feature map after fpn convs.
extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs
on the original feature from the backbone. If True,
it is equivalent to `add_extra_convs='on_input'`. If False, it is
            equivalent to setting `add_extra_convs='on_output'`.
            Default: False.
relu_before_extra_convs (bool): Whether to apply relu before the extra
conv. Default: False.
no_norm_on_lateral (bool): Whether to apply norm on lateral.
Default: False.
conv_cfg (dict): Config dict for convolution layer. Default: None.
norm_cfg (dict): Config dict for normalization layer. Default: None.
act_cfg (dict): Config dict for activation layer in ConvModule.
Default: None.
upsample_cfg (dict): Config dict for interpolate layer.
Default: dict(mode='nearest').
init_cfg (dict or list[dict], optional): Initialization config dict.
Example:
>>> import torch
>>> in_channels = [2, 3, 5, 7]
>>> scales = [340, 170, 84, 43]
>>> inputs = [torch.rand(1, c, s, s)
... for c, s in zip(in_channels, scales)]
>>> self = FPN(in_channels, 11, len(in_channels)).eval()
>>> outputs = self.forward(inputs)
>>> for i in range(len(outputs)):
... print(f'outputs[{i}].shape = {outputs[i].shape}')
outputs[0].shape = torch.Size([1, 11, 340, 340])
outputs[1].shape = torch.Size([1, 11, 170, 170])
outputs[2].shape = torch.Size([1, 11, 84, 84])
outputs[3].shape = torch.Size([1, 11, 43, 43])
"""
def __init__(self,
in_channels,
out_channels,
num_outs,
start_level=0,
end_level=-1,
add_extra_convs=False,
extra_convs_on_inputs=False,
relu_before_extra_convs=False,
no_norm_on_lateral=False,
conv_cfg=None,
norm_cfg=None,
act_cfg=None,
upsample_cfg=dict(mode='nearest'),
init_cfg=dict(
type='Xavier', layer='Conv2d', distribution='uniform')):
super().__init__(init_cfg)
assert isinstance(in_channels, list)
self.in_channels = in_channels
self.out_channels = out_channels
self.num_ins = len(in_channels)
self.num_outs = num_outs
self.relu_before_extra_convs = relu_before_extra_convs
self.no_norm_on_lateral = no_norm_on_lateral
self.fp16_enabled = False
self.upsample_cfg = upsample_cfg.copy()
if end_level == -1:
self.backbone_end_level = self.num_ins
assert num_outs >= self.num_ins - start_level
else:
            # if end_level is given (not -1), no extra level is allowed
self.backbone_end_level = end_level
assert end_level <= len(in_channels)
assert num_outs == end_level - start_level
self.start_level = start_level
self.end_level = end_level
self.add_extra_convs = add_extra_convs
assert isinstance(add_extra_convs, (str, bool))
if isinstance(add_extra_convs, str):
# Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
elif add_extra_convs: # True
if extra_convs_on_inputs:
# For compatibility with previous release
# TODO: deprecate `extra_convs_on_inputs`
self.add_extra_convs = 'on_input'
else:
self.add_extra_convs = 'on_output'
self.lateral_convs = nn.ModuleList()
self.fpn_convs = nn.ModuleList()
for i in range(self.start_level, self.backbone_end_level):
l_conv = ConvModule(
in_channels[i],
out_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
act_cfg=act_cfg,
inplace=False)
fpn_conv = ConvModule(
out_channels,
out_channels,
3,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
inplace=False)
self.lateral_convs.append(l_conv)
self.fpn_convs.append(fpn_conv)
# add extra conv layers (e.g., RetinaNet)
extra_levels = num_outs - self.backbone_end_level + self.start_level
if self.add_extra_convs and extra_levels >= 1:
for i in range(extra_levels):
if i == 0 and self.add_extra_convs == 'on_input':
in_channels = self.in_channels[self.backbone_end_level - 1]
else:
in_channels = out_channels
extra_fpn_conv = ConvModule(
in_channels,
out_channels,
3,
stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
inplace=False)
                self.fpn_convs.append(extra_fpn_conv)

def forward(self, inputs):
assert len(inputs) == len(self.in_channels)
# build laterals
laterals = [
lateral_conv(inputs[i + self.start_level])
for i, lateral_conv in enumerate(self.lateral_convs)
]
# build top-down path
used_backbone_levels = len(laterals)
for i in range(used_backbone_levels - 1, 0, -1):
# In some cases, fixing `scale factor` (e.g. 2) is preferred, but
# it cannot co-exist with `size` in `F.interpolate`.
if 'scale_factor' in self.upsample_cfg:
laterals[i - 1] = laterals[i - 1] + resize(
laterals[i], **self.upsample_cfg)
else:
prev_shape = laterals[i - 1].shape[2:]
laterals[i - 1] = laterals[i - 1] + resize(
laterals[i], size=prev_shape, **self.upsample_cfg)
# build outputs
# part 1: from original levels
outs = [
self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
]
# part 2: add extra levels
if self.num_outs > len(outs):
# use max pool to get more levels on top of outputs
# (e.g., Faster R-CNN, Mask R-CNN)
if not self.add_extra_convs:
for i in range(self.num_outs - used_backbone_levels):
outs.append(F.max_pool2d(outs[-1], 1, stride=2))
# add conv layers on top of original feature maps (RetinaNet)
else:
if self.add_extra_convs == 'on_input':
extra_source = inputs[self.backbone_end_level - 1]
elif self.add_extra_convs == 'on_lateral':
extra_source = laterals[-1]
elif self.add_extra_convs == 'on_output':
extra_source = outs[-1]
else:
raise NotImplementedError
outs.append(self.fpn_convs[used_backbone_levels](extra_source))
for i in range(used_backbone_levels + 1, self.num_outs):
if self.relu_before_extra_convs:
outs.append(self.fpn_convs[i](F.relu(outs[-1])))
else:
outs.append(self.fpn_convs[i](outs[-1]))
return tuple(outs)
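
The extra-levels path is easiest to see in a sketch with hypothetical sizes: with add_extra_convs='on_input' and num_outs one larger than the number of inputs, a stride-2 conv on the last backbone map supplies the extra scale:

import torch

# Assumes FPN from the file above is in scope.
fpn = FPN([256, 512], out_channels=64, num_outs=3,
          add_extra_convs='on_input').eval()
ins = [torch.rand(1, 256, 32, 32), torch.rand(1, 512, 16, 16)]
outs = fpn(ins)
print([o.shape[-1] for o in outs])  # [32, 16, 8]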


@@ -0,0 +1,90 @@
import torch
import torch.nn as nn
from .multilevel_neck import MultiLevelNeck
from .fusion_transformer import FusionTransformer
from mmseg.registry import MODELS
@MODELS.register_module()
class FusionMultiLevelNeck(nn.Module):
    """Multi-level neck with temporal fusion.

    Applies ``MultiLevelNeck`` to per-frame features, then fuses the
    ``ts_size`` frames at each spatial location with a
    ``FusionTransformer``.
    """

def __init__(self,
ts_size=10,
in_channels_ml=[768, 768, 768, 768],
out_channels_ml=768,
scales_ml=[0.5, 1, 2, 4],
norm_cfg_ml=None,
act_cfg_ml=None,
input_dims=768,
embed_dims=768,
num_layers=4,
num_heads=12,
mlp_ratio=4,
qkv_bias=True,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
with_cls_token=True,
output_cls_token=True,
norm_cfg=dict(type='LN'),
act_cfg=dict(type='GELU'),
num_fcs=2,
norm_eval=False,
with_cp=False,
init_cfg=None,
*args,
**kwargs):
super(FusionMultiLevelNeck, self).__init__()
self.in_channels = in_channels_ml
self.ts_size = ts_size
self.multilevel_neck = MultiLevelNeck(
in_channels_ml,
out_channels_ml,
scales_ml,
norm_cfg_ml,
act_cfg_ml
)
# self.up_head = UPHead(1024, 2816, 4)
self.fusion_transformer = FusionTransformer(
input_dims,
embed_dims,
num_layers,
num_heads,
mlp_ratio,
qkv_bias,
drop_rate,
attn_drop_rate,
drop_path_rate,
with_cls_token,
output_cls_token,
norm_cfg,
act_cfg,
num_fcs,
norm_eval,
with_cp,
init_cfg,
        )

    def init_weights(self):
        self.fusion_transformer.init_weights()

def forward(self, inputs, require_feat: bool = False, require_two: bool = False):
assert len(inputs) == len(self.in_channels)
inputs = self.multilevel_neck(inputs)
ts = self.ts_size
b_total, c, h, w = inputs[-1].shape
        b = b_total // ts
outs = []
for idx in range(len(inputs)):
input_feat = inputs[idx]
b_total, c, h, w = inputs[idx].shape
            # (b*ts, c, h, w) -> (b*h*w, ts, c)
            input_feat = input_feat.reshape(b, ts, c, h, w).permute(
                0, 3, 4, 1, 2).reshape(b * h * w, ts, c)
feat_fusion = self.fusion_transformer(input_feat, require_feat, require_two)
c_fusion = feat_fusion.shape[-1]
feat_fusion = feat_fusion.reshape(b, h, w, c_fusion).permute(0, 3, 1, 2) # b*h*w, c -> b, c, h, w
outs.append(feat_fusion)
return tuple(outs)
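
A hedged usage sketch with hypothetical sizes: the batch axis carries b * ts_size per-frame feature maps, and the neck fuses the ts_size frames independently at every spatial location:

import torch

# Assumes FusionMultiLevelNeck from the file above is in scope.
neck = FusionMultiLevelNeck(ts_size=2, num_layers=1)
feats = [torch.rand(2, 768, 8, 8) for _ in range(4)]  # b=1, ts=2
outs = neck(feats)
print([tuple(o.shape) for o in outs])
# (1, 768, 4, 4), (1, 768, 8, 8), (1, 768, 16, 16), (1, 768, 32, 32)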


@@ -0,0 +1,166 @@
# Copyright (c) Ant Group. All rights reserved.
from collections import OrderedDict
import torch
import torch.nn as nn
from mmengine.model import BaseModule
from mmengine.model.weight_init import (constant_init, kaiming_init,
                                        trunc_normal_)
from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
from torch.nn.modules.batchnorm import _BatchNorm
from mmseg.models.backbones.vit import TransformerEncoderLayer
# from mmseg.utils import get_root_logger
from mmseg.registry import MODELS
# @MODELS.register_module()
class FusionTransformer(BaseModule):
def __init__(self,
input_dims=768,
embed_dims=768,
num_layers=4,
num_heads=12,
mlp_ratio=4,
qkv_bias=True,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
with_cls_token=True,
output_cls_token=True,
norm_cfg=dict(type='LN'),
act_cfg=dict(type='GELU'),
num_fcs=2,
norm_eval=False,
with_cp=False,
init_cfg=None,
*args,
**kwargs):
        super().__init__(init_cfg=init_cfg)
        self.proj_linear = nn.Linear(input_dims, embed_dims)
if output_cls_token:
            assert with_cls_token is True, 'with_cls_token must be True ' \
                f'when output_cls_token is True, but got {with_cls_token}'
self.init_cfg = init_cfg
self.norm_eval = norm_eval
self.with_cp = with_cp
self.with_cls_token = with_cls_token
self.output_cls_token = output_cls_token
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
self.drop_after_pos = nn.Dropout(p=drop_rate)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, num_layers)
] # stochastic depth decay rule
self.layers = nn.ModuleList()
for i in range(num_layers):
self.layers.append(
TransformerEncoderLayer(embed_dims=embed_dims,
num_heads=num_heads,
feedforward_channels=mlp_ratio *
embed_dims,
attn_drop_rate=attn_drop_rate,
drop_rate=drop_rate,
drop_path_rate=dpr[i],
num_fcs=num_fcs,
qkv_bias=qkv_bias,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
with_cp=with_cp,
                                        batch_first=True))

def init_weights(self):
if isinstance(self.init_cfg, dict) and \
self.init_cfg.get('type') in ['Pretrained', 'Pretrained_Part']:
checkpoint = CheckpointLoader.load_checkpoint(
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
if self.init_cfg.get('type') == 'Pretrained':
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
elif self.init_cfg.get('type') == 'Pretrained_Part':
state_dict = checkpoint.copy()
para_prefix = 'image_encoder'
prefix_len = len(para_prefix) + 1
for k, v in checkpoint.items():
state_dict.pop(k)
if para_prefix in k:
state_dict[k[prefix_len:]] = v
# if 'pos_embed' in state_dict.keys():
# if self.pos_embed.shape != state_dict['pos_embed'].shape:
# print_log(msg=f'Resize the pos_embed shape from '
# f'{state_dict["pos_embed"].shape} to '
# f'{self.pos_embed.shape}')
# h, w = self.img_size
# pos_size = int(
# math.sqrt(state_dict['pos_embed'].shape[1] - 1))
# state_dict['pos_embed'] = self.resize_pos_embed(
# state_dict['pos_embed'],
# (h // self.patch_size, w // self.patch_size),
# (pos_size, pos_size), self.interpolate_mode)
load_state_dict(self, state_dict, strict=False, logger=None)
elif self.init_cfg is not None:
super().init_weights()
else:
# We only implement the 'jax_impl' initialization implemented at
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
# trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
for n, m in self.named_modules():
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
if 'ffn' in n:
nn.init.normal_(m.bias, mean=0., std=1e-6)
else:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv2d):
kaiming_init(m, mode='fan_in', bias=0.)
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                    constant_init(m, val=1.0, bias=0.)

def forward(self, inputs, require_feat: bool = False, require_two: bool = False):
        inputs = self.proj_linear(inputs)
B, N, C = inputs.shape
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, inputs), dim=1)
if not self.with_cls_token:
# Remove class token for transformer encoder input
x = x[:, 1:]
        # collect the hidden state of each block when requested
block_outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
if require_feat:
block_outs.append(x)
if self.output_cls_token:
if require_two:
x = x[:, :2]
else:
x = x[:, 0]
elif not self.output_cls_token and self.with_cls_token:
x = x[:, 1:]
if require_feat:
return x, block_outs
else:
            return x

def train(self, mode=True):
super(FusionTransformer, self).train(mode)
if mode and self.norm_eval:
for m in self.modules():
if isinstance(m, nn.LayerNorm):
                    m.eval()


if __name__ == '__main__':
fusion_transformer = FusionTransformer()
print(fusion_transformer)
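
A hedged shape check for the fusion module itself, with hypothetical sizes: each batch row is one spatial location, the sequence axis is time, and with output_cls_token=True the class token summarizes the whole sequence:

import torch

# Assumes FusionTransformer from the file above is in scope.
ft = FusionTransformer(num_layers=2)
tokens = torch.rand(4, 10, 768)  # (b*h*w, ts, input_dims)
fused = ft(tokens)
print(fused.shape)  # torch.Size([4, 768])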


@@ -0,0 +1,148 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from mmseg.registry import MODELS
from ..utils import resize
class CascadeFeatureFusion(BaseModule):
"""Cascade Feature Fusion Unit in ICNet.
Args:
low_channels (int): The number of input channels for
low resolution feature map.
high_channels (int): The number of input channels for
high resolution feature map.
out_channels (int): The number of output channels.
conv_cfg (dict): Dictionary to construct and config conv layer.
Default: None.
norm_cfg (dict): Dictionary to construct and config norm layer.
Default: dict(type='BN').
act_cfg (dict): Dictionary to construct and config act layer.
Default: dict(type='ReLU').
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Returns:
x (Tensor): The output tensor of shape (N, out_channels, H, W).
x_low (Tensor): The output tensor of shape (N, out_channels, H, W)
for Cascade Label Guidance in auxiliary heads.
"""
def __init__(self,
low_channels,
high_channels,
out_channels,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
align_corners=False,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.align_corners = align_corners
self.conv_low = ConvModule(
low_channels,
out_channels,
3,
padding=2,
dilation=2,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.conv_high = ConvModule(
high_channels,
out_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
            act_cfg=act_cfg)

def forward(self, x_low, x_high):
x_low = resize(
x_low,
size=x_high.size()[2:],
mode='bilinear',
align_corners=self.align_corners)
        # Note: Different from the original paper, `x_low` goes through
        # `self.conv_low` rather than another 1x1 conv classifier
        # before being used for the auxiliary head.
x_low = self.conv_low(x_low)
x_high = self.conv_high(x_high)
x = x_low + x_high
x = F.relu(x, inplace=True)
        return x, x_low


@MODELS.register_module()
class ICNeck(BaseModule):
"""ICNet for Real-Time Semantic Segmentation on High-Resolution Images.
This head is the implementation of `ICHead
<https://arxiv.org/abs/1704.08545>`_.
Args:
        in_channels (tuple[int]): The numbers of channels of the three
            input feature maps. Default: (64, 256, 256).
        out_channels (int): The number of output feature channels.
            Default: 128.
conv_cfg (dict): Dictionary to construct and config conv layer.
Default: None.
norm_cfg (dict): Dictionary to construct and config norm layer.
Default: dict(type='BN').
act_cfg (dict): Dictionary to construct and config act layer.
Default: dict(type='ReLU').
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(self,
in_channels=(64, 256, 256),
out_channels=128,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
align_corners=False,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
assert len(in_channels) == 3, 'Length of input channels \
must be 3!'
self.in_channels = in_channels
self.out_channels = out_channels
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.align_corners = align_corners
self.cff_24 = CascadeFeatureFusion(
self.in_channels[2],
self.in_channels[1],
self.out_channels,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
align_corners=self.align_corners)
self.cff_12 = CascadeFeatureFusion(
self.out_channels,
self.in_channels[0],
self.out_channels,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
            align_corners=self.align_corners)

def forward(self, inputs):
assert len(inputs) == 3, 'Length of input feature \
maps must be 3!'
x_sub1, x_sub2, x_sub4 = inputs
x_cff_24, x_24 = self.cff_24(x_sub4, x_sub2)
x_cff_12, x_12 = self.cff_12(x_cff_24, x_sub1)
# Note: `x_cff_12` is used for decode_head,
# `x_24` and `x_12` are used for auxiliary head.
return x_24, x_12, x_cff_12
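
A hedged smoke test with hypothetical map sizes for the three branches, ordered from highest to lowest resolution:

import torch

# Assumes ICNeck from the file above is in scope.
neck = ICNeck(in_channels=(64, 256, 256), out_channels=128).eval()
x_sub1 = torch.rand(1, 64, 64, 128)
x_sub2 = torch.rand(1, 256, 32, 64)
x_sub4 = torch.rand(1, 256, 16, 32)
x_24, x_12, x_cff_12 = neck((x_sub1, x_sub2, x_sub4))
print(x_24.shape, x_12.shape, x_cff_12.shape)
# (1, 128, 32, 64), (1, 128, 64, 128), (1, 128, 64, 128)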


@@ -0,0 +1,131 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmengine.model import BaseModule
from mmseg.registry import MODELS
from ..utils import resize
@MODELS.register_module()
class JPU(BaseModule):
"""FastFCN: Rethinking Dilated Convolution in the Backbone
for Semantic Segmentation.
This Joint Pyramid Upsampling (JPU) neck is the implementation of
`FastFCN <https://arxiv.org/abs/1903.11816>`_.
Args:
in_channels (Tuple[int], optional): The number of input channels
for each convolution operations before upsampling.
Default: (512, 1024, 2048).
mid_channels (int): The number of output channels of JPU.
Default: 512.
start_level (int): Index of the start input backbone level used to
build the feature pyramid. Default: 0.
end_level (int): Index of the end input backbone level (exclusive) to
build the feature pyramid. Default: -1, which means the last level.
dilations (tuple[int]): Dilation rate of each Depthwise
Separable ConvModule. Default: (1, 2, 4, 8).
align_corners (bool, optional): The align_corners argument of
resize operation. Default: False.
conv_cfg (dict | None): Config of conv layers.
Default: None.
norm_cfg (dict | None): Config of norm layers.
Default: dict(type='BN').
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(self,
in_channels=(512, 1024, 2048),
mid_channels=512,
start_level=0,
end_level=-1,
dilations=(1, 2, 4, 8),
align_corners=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
assert isinstance(in_channels, tuple)
assert isinstance(dilations, tuple)
self.in_channels = in_channels
self.mid_channels = mid_channels
self.start_level = start_level
self.num_ins = len(in_channels)
if end_level == -1:
self.backbone_end_level = self.num_ins
else:
self.backbone_end_level = end_level
assert end_level <= len(in_channels)
self.dilations = dilations
self.align_corners = align_corners
self.conv_layers = nn.ModuleList()
self.dilation_layers = nn.ModuleList()
for i in range(self.start_level, self.backbone_end_level):
conv_layer = nn.Sequential(
ConvModule(
self.in_channels[i],
self.mid_channels,
kernel_size=3,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
self.conv_layers.append(conv_layer)
for i in range(len(dilations)):
dilation_layer = nn.Sequential(
DepthwiseSeparableConvModule(
in_channels=(self.backbone_end_level - self.start_level) *
self.mid_channels,
out_channels=self.mid_channels,
kernel_size=3,
stride=1,
padding=dilations[i],
dilation=dilations[i],
dw_norm_cfg=norm_cfg,
dw_act_cfg=None,
pw_norm_cfg=norm_cfg,
pw_act_cfg=act_cfg))
            self.dilation_layers.append(dilation_layer)

def forward(self, inputs):
"""Forward function."""
assert len(inputs) == len(self.in_channels), 'Length of inputs must \
be the same with self.in_channels!'
feats = [
self.conv_layers[i - self.start_level](inputs[i])
for i in range(self.start_level, self.backbone_end_level)
]
h, w = feats[0].shape[2:]
for i in range(1, len(feats)):
feats[i] = resize(
feats[i],
size=(h, w),
mode='bilinear',
align_corners=self.align_corners)
feat = torch.cat(feats, dim=1)
concat_feat = torch.cat([
self.dilation_layers[i](feat) for i in range(len(self.dilations))
],
dim=1)
outs = []
# Default: outs[2] is the output of JPU for decoder head, outs[1] is
# the feature map from backbone for auxiliary head. Additionally,
# outs[0] can also be used for auxiliary head.
for i in range(self.start_level, self.backbone_end_level - 1):
outs.append(inputs[i])
outs.append(concat_feat)
return tuple(outs)
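
A hedged smoke test with hypothetical ResNet-style strides: the first two outputs pass through for auxiliary heads, while the last output is the concatenation of the four dilated branches at the highest input resolution:

import torch

# Assumes JPU from the file above is in scope.
neck = JPU(in_channels=(512, 1024, 2048), mid_channels=512).eval()
ins = [torch.rand(1, 512, 32, 32),
       torch.rand(1, 1024, 16, 16),
       torch.rand(1, 2048, 8, 8)]
outs = neck(ins)
print([tuple(o.shape) for o in outs])
# (1, 512, 32, 32), (1, 1024, 16, 16), (1, 2048, 32, 32)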


@@ -0,0 +1,118 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer
from mmseg.registry import MODELS
class MLAModule(nn.Module):
def __init__(self,
in_channels=[1024, 1024, 1024, 1024],
out_channels=256,
norm_cfg=None,
act_cfg=None):
super().__init__()
self.channel_proj = nn.ModuleList()
for i in range(len(in_channels)):
self.channel_proj.append(
ConvModule(
in_channels=in_channels[i],
out_channels=out_channels,
kernel_size=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
self.feat_extract = nn.ModuleList()
for i in range(len(in_channels)):
self.feat_extract.append(
ConvModule(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
norm_cfg=norm_cfg,
                    act_cfg=act_cfg))

def forward(self, inputs):
# feat_list -> [p2, p3, p4, p5]
feat_list = []
for x, conv in zip(inputs, self.channel_proj):
feat_list.append(conv(x))
# feat_list -> [p5, p4, p3, p2]
# mid_list -> [m5, m4, m3, m2]
feat_list = feat_list[::-1]
mid_list = []
for feat in feat_list:
if len(mid_list) == 0:
mid_list.append(feat)
else:
mid_list.append(mid_list[-1] + feat)
# mid_list -> [m5, m4, m3, m2]
# out_list -> [o2, o3, o4, o5]
out_list = []
for mid, conv in zip(mid_list, self.feat_extract):
out_list.append(conv(mid))
        return tuple(out_list)


@MODELS.register_module()
class MLANeck(nn.Module):
"""Multi-level Feature Aggregation.
This neck is `The Multi-level Feature Aggregation construction of
SETR <https://arxiv.org/abs/2012.15840>`_.
Args:
in_channels (List[int]): Number of input channels per scale.
out_channels (int): Number of output channels (used at each scale).
norm_layer (dict): Config dict for input normalization.
Default: norm_layer=dict(type='LN', eps=1e-6, requires_grad=True).
norm_cfg (dict): Config dict for normalization layer. Default: None.
act_cfg (dict): Config dict for activation layer in ConvModule.
Default: None.
"""
def __init__(self,
in_channels,
out_channels,
norm_layer=dict(type='LN', eps=1e-6, requires_grad=True),
norm_cfg=None,
act_cfg=None):
super().__init__()
assert isinstance(in_channels, list)
self.in_channels = in_channels
self.out_channels = out_channels
# In order to build general vision transformer backbone, we have to
# move MLA to neck.
self.norm = nn.ModuleList([
build_norm_layer(norm_layer, in_channels[i])[1]
for i in range(len(in_channels))
])
self.mla = MLAModule(
in_channels=in_channels,
out_channels=out_channels,
norm_cfg=norm_cfg,
            act_cfg=act_cfg)

def forward(self, inputs):
assert len(inputs) == len(self.in_channels)
# Convert from nchw to nlc
outs = []
for i in range(len(inputs)):
x = inputs[i]
n, c, h, w = x.shape
x = x.reshape(n, c, h * w).transpose(2, 1).contiguous()
x = self.norm[i](x)
x = x.transpose(1, 2).reshape(n, c, h, w).contiguous()
outs.append(x)
outs = self.mla(outs)
return tuple(outs)
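
A hedged smoke test with hypothetical SETR-style sizes: all levels share the ViT resolution, and the module returns one aggregated map per level:

import torch

# Assumes MLANeck from the file above is in scope.
neck = MLANeck(in_channels=[1024, 1024, 1024, 1024], out_channels=256)
ins = [torch.rand(1, 1024, 16, 16) for _ in range(4)]
outs = neck(ins)
print([tuple(o.shape) for o in outs])  # four (1, 256, 16, 16) maps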


@@ -0,0 +1,79 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model.weight_init import xavier_init
from mmseg.registry import MODELS
from ..utils import resize
@MODELS.register_module()
class MultiLevelNeck(nn.Module):
"""MultiLevelNeck.

    A neck structure that connects a ViT backbone to decoder heads.

Args:
in_channels (List[int]): Number of input channels per scale.
out_channels (int): Number of output channels (used at each scale).
scales (List[float]): Scale factors for each input feature map.
Default: [0.5, 1, 2, 4]
norm_cfg (dict): Config dict for normalization layer. Default: None.
act_cfg (dict): Config dict for activation layer in ConvModule.
Default: None.
"""
def __init__(self,
in_channels,
out_channels,
scales=[0.5, 1, 2, 4],
norm_cfg=None,
act_cfg=None):
super().__init__()
assert isinstance(in_channels, list)
self.in_channels = in_channels
self.out_channels = out_channels
self.scales = scales
self.num_outs = len(scales)
self.lateral_convs = nn.ModuleList()
self.convs = nn.ModuleList()
for in_channel in in_channels:
self.lateral_convs.append(
ConvModule(
in_channel,
out_channels,
kernel_size=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
for _ in range(self.num_outs):
self.convs.append(
ConvModule(
out_channels,
out_channels,
kernel_size=3,
padding=1,
stride=1,
norm_cfg=norm_cfg,
                    act_cfg=act_cfg))

    # default init_weights for conv (msra) and norm in ConvModule
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform')

def forward(self, inputs):
assert len(inputs) == len(self.in_channels)
inputs = [
lateral_conv(inputs[i])
for i, lateral_conv in enumerate(self.lateral_convs)
]
        # When a single input is given, reuse it for all output scales.
if len(inputs) == 1:
inputs = [inputs[0] for _ in range(self.num_outs)]
outs = []
for i in range(self.num_outs):
x_resize = resize(
inputs[i], scale_factor=self.scales[i], mode='bilinear')
outs.append(self.convs[i](x_resize))
return tuple(outs)
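
A hedged smoke test with hypothetical ViT feature sizes: same-resolution inputs are rescaled into a pyramid by the scales factors:

import torch

# Assumes MultiLevelNeck from the file above is in scope.
neck = MultiLevelNeck(
    in_channels=[768, 768, 768, 768], out_channels=256)
ins = [torch.rand(1, 768, 16, 16) for _ in range(4)]
outs = neck(ins)
print([o.shape[-1] for o in outs])  # [8, 16, 32, 64]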