init
This commit is contained in:
14
finetune/mmseg/models/necks/__init__.py
Normal file
14
finetune/mmseg/models/necks/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .featurepyramid import Feature2Pyramid
|
||||
from .fpn import FPN
|
||||
from .ic_neck import ICNeck
|
||||
from .jpu import JPU
|
||||
from .mla_neck import MLANeck
|
||||
from .multilevel_neck import MultiLevelNeck
|
||||
from .fusion_transformer import FusionTransformer
|
||||
from .fusion_multilevel_neck import FusionMultiLevelNeck
|
||||
|
||||
__all__ = [
|
||||
'FPN', 'MultiLevelNeck', 'MLANeck', 'ICNeck', 'JPU', 'Feature2Pyramid',
|
||||
'FusionTransformer', 'FusionMultiLevelNeck'
|
||||
]
|
||||
67
finetune/mmseg/models/necks/featurepyramid.py
Normal file
67
finetune/mmseg/models/necks/featurepyramid.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import build_norm_layer
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class Feature2Pyramid(nn.Module):
|
||||
"""Feature2Pyramid.
|
||||
|
||||
A neck structure connect ViT backbone and decoder_heads.
|
||||
|
||||
Args:
|
||||
embed_dims (int): Embedding dimension.
|
||||
rescales (list[float]): Different sampling multiples were
|
||||
used to obtain pyramid features. Default: [4, 2, 1, 0.5].
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='SyncBN', requires_grad=True).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dim,
|
||||
rescales=[4, 2, 1, 0.5],
|
||||
norm_cfg=dict(type='SyncBN', requires_grad=True)):
|
||||
super().__init__()
|
||||
self.rescales = rescales
|
||||
self.upsample_4x = None
|
||||
for k in self.rescales:
|
||||
if k == 4:
|
||||
self.upsample_4x = nn.Sequential(
|
||||
nn.ConvTranspose2d(
|
||||
embed_dim, embed_dim, kernel_size=2, stride=2),
|
||||
build_norm_layer(norm_cfg, embed_dim)[1],
|
||||
nn.GELU(),
|
||||
nn.ConvTranspose2d(
|
||||
embed_dim, embed_dim, kernel_size=2, stride=2),
|
||||
)
|
||||
elif k == 2:
|
||||
self.upsample_2x = nn.Sequential(
|
||||
nn.ConvTranspose2d(
|
||||
embed_dim, embed_dim, kernel_size=2, stride=2))
|
||||
elif k == 1:
|
||||
self.identity = nn.Identity()
|
||||
elif k == 0.5:
|
||||
self.downsample_2x = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
elif k == 0.25:
|
||||
self.downsample_4x = nn.MaxPool2d(kernel_size=4, stride=4)
|
||||
else:
|
||||
raise KeyError(f'invalid {k} for feature2pyramid')
|
||||
|
||||
def forward(self, inputs):
|
||||
assert len(inputs) == len(self.rescales)
|
||||
outputs = []
|
||||
if self.upsample_4x is not None:
|
||||
ops = [
|
||||
self.upsample_4x, self.upsample_2x, self.identity,
|
||||
self.downsample_2x
|
||||
]
|
||||
else:
|
||||
ops = [
|
||||
self.upsample_2x, self.identity, self.downsample_2x,
|
||||
self.downsample_4x
|
||||
]
|
||||
for i in range(len(inputs)):
|
||||
outputs.append(ops[i](inputs[i]))
|
||||
return tuple(outputs)
|
||||
212
finetune/mmseg/models/necks/fpn.py
Normal file
212
finetune/mmseg/models/necks/fpn.py
Normal file
@@ -0,0 +1,212 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class FPN(BaseModule):
|
||||
"""Feature Pyramid Network.
|
||||
|
||||
This neck is the implementation of `Feature Pyramid Networks for Object
|
||||
Detection <https://arxiv.org/abs/1612.03144>`_.
|
||||
|
||||
Args:
|
||||
in_channels (list[int]): Number of input channels per scale.
|
||||
out_channels (int): Number of output channels (used at each scale).
|
||||
num_outs (int): Number of output scales.
|
||||
start_level (int): Index of the start input backbone level used to
|
||||
build the feature pyramid. Default: 0.
|
||||
end_level (int): Index of the end input backbone level (exclusive) to
|
||||
build the feature pyramid. Default: -1, which means the last level.
|
||||
add_extra_convs (bool | str): If bool, it decides whether to add conv
|
||||
layers on top of the original feature maps. Default to False.
|
||||
If True, its actual mode is specified by `extra_convs_on_inputs`.
|
||||
If str, it specifies the source feature map of the extra convs.
|
||||
Only the following options are allowed
|
||||
|
||||
- 'on_input': Last feat map of neck inputs (i.e. backbone feature).
|
||||
- 'on_lateral': Last feature map after lateral convs.
|
||||
- 'on_output': The last output feature map after fpn convs.
|
||||
extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs
|
||||
on the original feature from the backbone. If True,
|
||||
it is equivalent to `add_extra_convs='on_input'`. If False, it is
|
||||
equivalent to set `add_extra_convs='on_output'`. Default to True.
|
||||
relu_before_extra_convs (bool): Whether to apply relu before the extra
|
||||
conv. Default: False.
|
||||
no_norm_on_lateral (bool): Whether to apply norm on lateral.
|
||||
Default: False.
|
||||
conv_cfg (dict): Config dict for convolution layer. Default: None.
|
||||
norm_cfg (dict): Config dict for normalization layer. Default: None.
|
||||
act_cfg (dict): Config dict for activation layer in ConvModule.
|
||||
Default: None.
|
||||
upsample_cfg (dict): Config dict for interpolate layer.
|
||||
Default: dict(mode='nearest').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
|
||||
Example:
|
||||
>>> import torch
|
||||
>>> in_channels = [2, 3, 5, 7]
|
||||
>>> scales = [340, 170, 84, 43]
|
||||
>>> inputs = [torch.rand(1, c, s, s)
|
||||
... for c, s in zip(in_channels, scales)]
|
||||
>>> self = FPN(in_channels, 11, len(in_channels)).eval()
|
||||
>>> outputs = self.forward(inputs)
|
||||
>>> for i in range(len(outputs)):
|
||||
... print(f'outputs[{i}].shape = {outputs[i].shape}')
|
||||
outputs[0].shape = torch.Size([1, 11, 340, 340])
|
||||
outputs[1].shape = torch.Size([1, 11, 170, 170])
|
||||
outputs[2].shape = torch.Size([1, 11, 84, 84])
|
||||
outputs[3].shape = torch.Size([1, 11, 43, 43])
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
num_outs,
|
||||
start_level=0,
|
||||
end_level=-1,
|
||||
add_extra_convs=False,
|
||||
extra_convs_on_inputs=False,
|
||||
relu_before_extra_convs=False,
|
||||
no_norm_on_lateral=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=None,
|
||||
act_cfg=None,
|
||||
upsample_cfg=dict(mode='nearest'),
|
||||
init_cfg=dict(
|
||||
type='Xavier', layer='Conv2d', distribution='uniform')):
|
||||
super().__init__(init_cfg)
|
||||
assert isinstance(in_channels, list)
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_ins = len(in_channels)
|
||||
self.num_outs = num_outs
|
||||
self.relu_before_extra_convs = relu_before_extra_convs
|
||||
self.no_norm_on_lateral = no_norm_on_lateral
|
||||
self.fp16_enabled = False
|
||||
self.upsample_cfg = upsample_cfg.copy()
|
||||
|
||||
if end_level == -1:
|
||||
self.backbone_end_level = self.num_ins
|
||||
assert num_outs >= self.num_ins - start_level
|
||||
else:
|
||||
# if end_level < inputs, no extra level is allowed
|
||||
self.backbone_end_level = end_level
|
||||
assert end_level <= len(in_channels)
|
||||
assert num_outs == end_level - start_level
|
||||
self.start_level = start_level
|
||||
self.end_level = end_level
|
||||
self.add_extra_convs = add_extra_convs
|
||||
assert isinstance(add_extra_convs, (str, bool))
|
||||
if isinstance(add_extra_convs, str):
|
||||
# Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
|
||||
assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
|
||||
elif add_extra_convs: # True
|
||||
if extra_convs_on_inputs:
|
||||
# For compatibility with previous release
|
||||
# TODO: deprecate `extra_convs_on_inputs`
|
||||
self.add_extra_convs = 'on_input'
|
||||
else:
|
||||
self.add_extra_convs = 'on_output'
|
||||
|
||||
self.lateral_convs = nn.ModuleList()
|
||||
self.fpn_convs = nn.ModuleList()
|
||||
|
||||
for i in range(self.start_level, self.backbone_end_level):
|
||||
l_conv = ConvModule(
|
||||
in_channels[i],
|
||||
out_channels,
|
||||
1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
|
||||
act_cfg=act_cfg,
|
||||
inplace=False)
|
||||
fpn_conv = ConvModule(
|
||||
out_channels,
|
||||
out_channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
inplace=False)
|
||||
|
||||
self.lateral_convs.append(l_conv)
|
||||
self.fpn_convs.append(fpn_conv)
|
||||
|
||||
# add extra conv layers (e.g., RetinaNet)
|
||||
extra_levels = num_outs - self.backbone_end_level + self.start_level
|
||||
if self.add_extra_convs and extra_levels >= 1:
|
||||
for i in range(extra_levels):
|
||||
if i == 0 and self.add_extra_convs == 'on_input':
|
||||
in_channels = self.in_channels[self.backbone_end_level - 1]
|
||||
else:
|
||||
in_channels = out_channels
|
||||
extra_fpn_conv = ConvModule(
|
||||
in_channels,
|
||||
out_channels,
|
||||
3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
inplace=False)
|
||||
self.fpn_convs.append(extra_fpn_conv)
|
||||
|
||||
def forward(self, inputs):
|
||||
assert len(inputs) == len(self.in_channels)
|
||||
|
||||
# build laterals
|
||||
laterals = [
|
||||
lateral_conv(inputs[i + self.start_level])
|
||||
for i, lateral_conv in enumerate(self.lateral_convs)
|
||||
]
|
||||
|
||||
# build top-down path
|
||||
used_backbone_levels = len(laterals)
|
||||
for i in range(used_backbone_levels - 1, 0, -1):
|
||||
# In some cases, fixing `scale factor` (e.g. 2) is preferred, but
|
||||
# it cannot co-exist with `size` in `F.interpolate`.
|
||||
if 'scale_factor' in self.upsample_cfg:
|
||||
laterals[i - 1] = laterals[i - 1] + resize(
|
||||
laterals[i], **self.upsample_cfg)
|
||||
else:
|
||||
prev_shape = laterals[i - 1].shape[2:]
|
||||
laterals[i - 1] = laterals[i - 1] + resize(
|
||||
laterals[i], size=prev_shape, **self.upsample_cfg)
|
||||
|
||||
# build outputs
|
||||
# part 1: from original levels
|
||||
outs = [
|
||||
self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
|
||||
]
|
||||
# part 2: add extra levels
|
||||
if self.num_outs > len(outs):
|
||||
# use max pool to get more levels on top of outputs
|
||||
# (e.g., Faster R-CNN, Mask R-CNN)
|
||||
if not self.add_extra_convs:
|
||||
for i in range(self.num_outs - used_backbone_levels):
|
||||
outs.append(F.max_pool2d(outs[-1], 1, stride=2))
|
||||
# add conv layers on top of original feature maps (RetinaNet)
|
||||
else:
|
||||
if self.add_extra_convs == 'on_input':
|
||||
extra_source = inputs[self.backbone_end_level - 1]
|
||||
elif self.add_extra_convs == 'on_lateral':
|
||||
extra_source = laterals[-1]
|
||||
elif self.add_extra_convs == 'on_output':
|
||||
extra_source = outs[-1]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
outs.append(self.fpn_convs[used_backbone_levels](extra_source))
|
||||
for i in range(used_backbone_levels + 1, self.num_outs):
|
||||
if self.relu_before_extra_convs:
|
||||
outs.append(self.fpn_convs[i](F.relu(outs[-1])))
|
||||
else:
|
||||
outs.append(self.fpn_convs[i](outs[-1]))
|
||||
return tuple(outs)
|
||||
90
finetune/mmseg/models/necks/fusion_multilevel_neck.py
Normal file
90
finetune/mmseg/models/necks/fusion_multilevel_neck.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .multilevel_neck import MultiLevelNeck
|
||||
from .fusion_transformer import FusionTransformer
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class FusionMultiLevelNeck(nn.Module):
|
||||
def __init__(self,
|
||||
ts_size=10,
|
||||
in_channels_ml=[768, 768, 768, 768],
|
||||
out_channels_ml=768,
|
||||
scales_ml=[0.5, 1, 2, 4],
|
||||
norm_cfg_ml=None,
|
||||
act_cfg_ml=None,
|
||||
input_dims=768,
|
||||
embed_dims=768,
|
||||
num_layers=4,
|
||||
num_heads=12,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
with_cls_token=True,
|
||||
output_cls_token=True,
|
||||
norm_cfg=dict(type='LN'),
|
||||
act_cfg=dict(type='GELU'),
|
||||
num_fcs=2,
|
||||
norm_eval=False,
|
||||
with_cp=False,
|
||||
init_cfg=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
super(FusionMultiLevelNeck, self).__init__()
|
||||
self.in_channels = in_channels_ml
|
||||
self.ts_size = ts_size
|
||||
self.multilevel_neck = MultiLevelNeck(
|
||||
in_channels_ml,
|
||||
out_channels_ml,
|
||||
scales_ml,
|
||||
norm_cfg_ml,
|
||||
act_cfg_ml
|
||||
)
|
||||
# self.up_head = UPHead(1024, 2816, 4)
|
||||
|
||||
self.fusion_transformer = FusionTransformer(
|
||||
input_dims,
|
||||
embed_dims,
|
||||
num_layers,
|
||||
num_heads,
|
||||
mlp_ratio,
|
||||
qkv_bias,
|
||||
drop_rate,
|
||||
attn_drop_rate,
|
||||
drop_path_rate,
|
||||
with_cls_token,
|
||||
output_cls_token,
|
||||
norm_cfg,
|
||||
act_cfg,
|
||||
num_fcs,
|
||||
norm_eval,
|
||||
with_cp,
|
||||
init_cfg,
|
||||
)
|
||||
|
||||
def init_weights(self):
|
||||
self.fusion_transformer.init_weights()
|
||||
|
||||
def forward(self, inputs, require_feat: bool = False, require_two: bool = False):
|
||||
assert len(inputs) == len(self.in_channels)
|
||||
|
||||
inputs = self.multilevel_neck(inputs)
|
||||
|
||||
ts = self.ts_size
|
||||
b_total, c, h, w = inputs[-1].shape
|
||||
b = int(b_total / ts)
|
||||
outs = []
|
||||
for idx in range(len(inputs)):
|
||||
|
||||
input_feat = inputs[idx]
|
||||
b_total, c, h, w = inputs[idx].shape
|
||||
input_feat = input_feat.reshape(b, ts, c, h, w).permute(0, 3, 4, 1, 2).reshape(b*h*w, ts, c) # b*ts, c, h, w转换为b*h*w, ts, c
|
||||
feat_fusion = self.fusion_transformer(input_feat, require_feat, require_two)
|
||||
c_fusion = feat_fusion.shape[-1]
|
||||
feat_fusion = feat_fusion.reshape(b, h, w, c_fusion).permute(0, 3, 1, 2) # b*h*w, c -> b, c, h, w
|
||||
outs.append(feat_fusion)
|
||||
|
||||
return tuple(outs)
|
||||
166
finetune/mmseg/models/necks/fusion_transformer.py
Normal file
166
finetune/mmseg/models/necks/fusion_transformer.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# Copyright (c) Ant Group. All rights reserved.
|
||||
from collections import OrderedDict
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.model.weight_init import (constant_init, kaiming_init,
|
||||
trunc_normal_)
|
||||
from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
from mmseg.models.backbones.vit import TransformerEncoderLayer
|
||||
|
||||
# from mmseg.utils import get_root_logger
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
# @MODELS.register_module()
|
||||
class FusionTransformer(nn.Module):
|
||||
def __init__(self,
|
||||
input_dims=768,
|
||||
embed_dims=768,
|
||||
num_layers=4,
|
||||
num_heads=12,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
with_cls_token=True,
|
||||
output_cls_token=True,
|
||||
norm_cfg=dict(type='LN'),
|
||||
act_cfg=dict(type='GELU'),
|
||||
num_fcs=2,
|
||||
norm_eval=False,
|
||||
with_cp=False,
|
||||
init_cfg=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
super(FusionTransformer, self).__init__()
|
||||
|
||||
self.porj_linear = nn.Linear(input_dims, embed_dims)
|
||||
if output_cls_token:
|
||||
assert with_cls_token is True, f'with_cls_token must be True if' \
|
||||
f'set output_cls_token to True, but got {with_cls_token}'
|
||||
|
||||
self.init_cfg = init_cfg
|
||||
self.norm_eval = norm_eval
|
||||
self.with_cp = with_cp
|
||||
self.with_cls_token = with_cls_token
|
||||
self.output_cls_token = output_cls_token
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
|
||||
self.drop_after_pos = nn.Dropout(p=drop_rate)
|
||||
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, num_layers)
|
||||
] # stochastic depth decay rule
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
for i in range(num_layers):
|
||||
self.layers.append(
|
||||
TransformerEncoderLayer(embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
feedforward_channels=mlp_ratio *
|
||||
embed_dims,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_rate=drop_rate,
|
||||
drop_path_rate=dpr[i],
|
||||
num_fcs=num_fcs,
|
||||
qkv_bias=qkv_bias,
|
||||
act_cfg=act_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
with_cp=with_cp,
|
||||
batch_first=True))
|
||||
|
||||
def init_weights(self):
|
||||
if isinstance(self.init_cfg, dict) and \
|
||||
self.init_cfg.get('type') in ['Pretrained', 'Pretrained_Part']:
|
||||
checkpoint = CheckpointLoader.load_checkpoint(
|
||||
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
|
||||
|
||||
if self.init_cfg.get('type') == 'Pretrained':
|
||||
if 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
|
||||
elif self.init_cfg.get('type') == 'Pretrained_Part':
|
||||
state_dict = checkpoint.copy()
|
||||
para_prefix = 'image_encoder'
|
||||
prefix_len = len(para_prefix) + 1
|
||||
for k, v in checkpoint.items():
|
||||
state_dict.pop(k)
|
||||
if para_prefix in k:
|
||||
state_dict[k[prefix_len:]] = v
|
||||
|
||||
# if 'pos_embed' in state_dict.keys():
|
||||
# if self.pos_embed.shape != state_dict['pos_embed'].shape:
|
||||
# print_log(msg=f'Resize the pos_embed shape from '
|
||||
# f'{state_dict["pos_embed"].shape} to '
|
||||
# f'{self.pos_embed.shape}')
|
||||
# h, w = self.img_size
|
||||
# pos_size = int(
|
||||
# math.sqrt(state_dict['pos_embed'].shape[1] - 1))
|
||||
# state_dict['pos_embed'] = self.resize_pos_embed(
|
||||
# state_dict['pos_embed'],
|
||||
# (h // self.patch_size, w // self.patch_size),
|
||||
# (pos_size, pos_size), self.interpolate_mode)
|
||||
|
||||
load_state_dict(self, state_dict, strict=False, logger=None)
|
||||
elif self.init_cfg is not None:
|
||||
super().init_weights()
|
||||
else:
|
||||
# We only implement the 'jax_impl' initialization implemented at
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
|
||||
# trunc_normal_(self.pos_embed, std=.02)
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
for n, m in self.named_modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if m.bias is not None:
|
||||
if 'ffn' in n:
|
||||
nn.init.normal_(m.bias, mean=0., std=1e-6)
|
||||
else:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
kaiming_init(m, mode='fan_in', bias=0.)
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
|
||||
constant_init(m, val=1.0, bias=0.)
|
||||
|
||||
def forward(self, inputs, require_feat: bool = False, require_two: bool = False):
|
||||
inputs = self.porj_linear(inputs)
|
||||
B, N, C = inputs.shape
|
||||
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, inputs), dim=1)
|
||||
if not self.with_cls_token:
|
||||
# Remove class token for transformer encoder input
|
||||
x = x[:, 1:]
|
||||
|
||||
# add hidden and atten state
|
||||
block_outs = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
if require_feat:
|
||||
block_outs.append(x)
|
||||
|
||||
if self.output_cls_token:
|
||||
if require_two:
|
||||
x = x[:, :2]
|
||||
else:
|
||||
x = x[:, 0]
|
||||
elif not self.output_cls_token and self.with_cls_token:
|
||||
x = x[:, 1:]
|
||||
|
||||
if require_feat:
|
||||
return x, block_outs
|
||||
else:
|
||||
return x
|
||||
|
||||
def train(self, mode=True):
|
||||
super(FusionTransformer, self).train(mode)
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.LayerNorm):
|
||||
m.eval()
|
||||
|
||||
if __name__ == '__main__':
|
||||
fusion_transformer = FusionTransformer()
|
||||
print(fusion_transformer)
|
||||
148
finetune/mmseg/models/necks/ic_neck.py
Normal file
148
finetune/mmseg/models/necks/ic_neck.py
Normal file
@@ -0,0 +1,148 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
class CascadeFeatureFusion(BaseModule):
|
||||
"""Cascade Feature Fusion Unit in ICNet.
|
||||
|
||||
Args:
|
||||
low_channels (int): The number of input channels for
|
||||
low resolution feature map.
|
||||
high_channels (int): The number of input channels for
|
||||
high resolution feature map.
|
||||
out_channels (int): The number of output channels.
|
||||
conv_cfg (dict): Dictionary to construct and config conv layer.
|
||||
Default: None.
|
||||
norm_cfg (dict): Dictionary to construct and config norm layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Dictionary to construct and config act layer.
|
||||
Default: dict(type='ReLU').
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
|
||||
Returns:
|
||||
x (Tensor): The output tensor of shape (N, out_channels, H, W).
|
||||
x_low (Tensor): The output tensor of shape (N, out_channels, H, W)
|
||||
for Cascade Label Guidance in auxiliary heads.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
low_channels,
|
||||
high_channels,
|
||||
out_channels,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
align_corners=False,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.align_corners = align_corners
|
||||
self.conv_low = ConvModule(
|
||||
low_channels,
|
||||
out_channels,
|
||||
3,
|
||||
padding=2,
|
||||
dilation=2,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
self.conv_high = ConvModule(
|
||||
high_channels,
|
||||
out_channels,
|
||||
1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
def forward(self, x_low, x_high):
|
||||
x_low = resize(
|
||||
x_low,
|
||||
size=x_high.size()[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
# Note: Different from original paper, `x_low` is underwent
|
||||
# `self.conv_low` rather than another 1x1 conv classifier
|
||||
# before being used for auxiliary head.
|
||||
x_low = self.conv_low(x_low)
|
||||
x_high = self.conv_high(x_high)
|
||||
x = x_low + x_high
|
||||
x = F.relu(x, inplace=True)
|
||||
return x, x_low
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ICNeck(BaseModule):
|
||||
"""ICNet for Real-Time Semantic Segmentation on High-Resolution Images.
|
||||
|
||||
This head is the implementation of `ICHead
|
||||
<https://arxiv.org/abs/1704.08545>`_.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input image channels. Default: 3.
|
||||
out_channels (int): The numbers of output feature channels.
|
||||
Default: 128.
|
||||
conv_cfg (dict): Dictionary to construct and config conv layer.
|
||||
Default: None.
|
||||
norm_cfg (dict): Dictionary to construct and config norm layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Dictionary to construct and config act layer.
|
||||
Default: dict(type='ReLU').
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=(64, 256, 256),
|
||||
out_channels=128,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
align_corners=False,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
assert len(in_channels) == 3, 'Length of input channels \
|
||||
must be 3!'
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.align_corners = align_corners
|
||||
self.cff_24 = CascadeFeatureFusion(
|
||||
self.in_channels[2],
|
||||
self.in_channels[1],
|
||||
self.out_channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
align_corners=self.align_corners)
|
||||
|
||||
self.cff_12 = CascadeFeatureFusion(
|
||||
self.out_channels,
|
||||
self.in_channels[0],
|
||||
self.out_channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
align_corners=self.align_corners)
|
||||
|
||||
def forward(self, inputs):
|
||||
assert len(inputs) == 3, 'Length of input feature \
|
||||
maps must be 3!'
|
||||
|
||||
x_sub1, x_sub2, x_sub4 = inputs
|
||||
x_cff_24, x_24 = self.cff_24(x_sub4, x_sub2)
|
||||
x_cff_12, x_12 = self.cff_12(x_cff_24, x_sub1)
|
||||
# Note: `x_cff_12` is used for decode_head,
|
||||
# `x_24` and `x_12` are used for auxiliary head.
|
||||
return x_24, x_12, x_cff_12
|
||||
131
finetune/mmseg/models/necks/jpu.py
Normal file
131
finetune/mmseg/models/necks/jpu.py
Normal file
@@ -0,0 +1,131 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class JPU(BaseModule):
|
||||
"""FastFCN: Rethinking Dilated Convolution in the Backbone
|
||||
for Semantic Segmentation.
|
||||
|
||||
This Joint Pyramid Upsampling (JPU) neck is the implementation of
|
||||
`FastFCN <https://arxiv.org/abs/1903.11816>`_.
|
||||
|
||||
Args:
|
||||
in_channels (Tuple[int], optional): The number of input channels
|
||||
for each convolution operations before upsampling.
|
||||
Default: (512, 1024, 2048).
|
||||
mid_channels (int): The number of output channels of JPU.
|
||||
Default: 512.
|
||||
start_level (int): Index of the start input backbone level used to
|
||||
build the feature pyramid. Default: 0.
|
||||
end_level (int): Index of the end input backbone level (exclusive) to
|
||||
build the feature pyramid. Default: -1, which means the last level.
|
||||
dilations (tuple[int]): Dilation rate of each Depthwise
|
||||
Separable ConvModule. Default: (1, 2, 4, 8).
|
||||
align_corners (bool, optional): The align_corners argument of
|
||||
resize operation. Default: False.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=(512, 1024, 2048),
|
||||
mid_channels=512,
|
||||
start_level=0,
|
||||
end_level=-1,
|
||||
dilations=(1, 2, 4, 8),
|
||||
align_corners=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
assert isinstance(in_channels, tuple)
|
||||
assert isinstance(dilations, tuple)
|
||||
self.in_channels = in_channels
|
||||
self.mid_channels = mid_channels
|
||||
self.start_level = start_level
|
||||
self.num_ins = len(in_channels)
|
||||
if end_level == -1:
|
||||
self.backbone_end_level = self.num_ins
|
||||
else:
|
||||
self.backbone_end_level = end_level
|
||||
assert end_level <= len(in_channels)
|
||||
|
||||
self.dilations = dilations
|
||||
self.align_corners = align_corners
|
||||
|
||||
self.conv_layers = nn.ModuleList()
|
||||
self.dilation_layers = nn.ModuleList()
|
||||
for i in range(self.start_level, self.backbone_end_level):
|
||||
conv_layer = nn.Sequential(
|
||||
ConvModule(
|
||||
self.in_channels[i],
|
||||
self.mid_channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
self.conv_layers.append(conv_layer)
|
||||
for i in range(len(dilations)):
|
||||
dilation_layer = nn.Sequential(
|
||||
DepthwiseSeparableConvModule(
|
||||
in_channels=(self.backbone_end_level - self.start_level) *
|
||||
self.mid_channels,
|
||||
out_channels=self.mid_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=dilations[i],
|
||||
dilation=dilations[i],
|
||||
dw_norm_cfg=norm_cfg,
|
||||
dw_act_cfg=None,
|
||||
pw_norm_cfg=norm_cfg,
|
||||
pw_act_cfg=act_cfg))
|
||||
self.dilation_layers.append(dilation_layer)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
assert len(inputs) == len(self.in_channels), 'Length of inputs must \
|
||||
be the same with self.in_channels!'
|
||||
|
||||
feats = [
|
||||
self.conv_layers[i - self.start_level](inputs[i])
|
||||
for i in range(self.start_level, self.backbone_end_level)
|
||||
]
|
||||
|
||||
h, w = feats[0].shape[2:]
|
||||
for i in range(1, len(feats)):
|
||||
feats[i] = resize(
|
||||
feats[i],
|
||||
size=(h, w),
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
|
||||
feat = torch.cat(feats, dim=1)
|
||||
concat_feat = torch.cat([
|
||||
self.dilation_layers[i](feat) for i in range(len(self.dilations))
|
||||
],
|
||||
dim=1)
|
||||
|
||||
outs = []
|
||||
|
||||
# Default: outs[2] is the output of JPU for decoder head, outs[1] is
|
||||
# the feature map from backbone for auxiliary head. Additionally,
|
||||
# outs[0] can also be used for auxiliary head.
|
||||
for i in range(self.start_level, self.backbone_end_level - 1):
|
||||
outs.append(inputs[i])
|
||||
outs.append(concat_feat)
|
||||
return tuple(outs)
|
||||
118
finetune/mmseg/models/necks/mla_neck.py
Normal file
118
finetune/mmseg/models/necks/mla_neck.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, build_norm_layer
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
class MLAModule(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_channels=[1024, 1024, 1024, 1024],
|
||||
out_channels=256,
|
||||
norm_cfg=None,
|
||||
act_cfg=None):
|
||||
super().__init__()
|
||||
self.channel_proj = nn.ModuleList()
|
||||
for i in range(len(in_channels)):
|
||||
self.channel_proj.append(
|
||||
ConvModule(
|
||||
in_channels=in_channels[i],
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
self.feat_extract = nn.ModuleList()
|
||||
for i in range(len(in_channels)):
|
||||
self.feat_extract.append(
|
||||
ConvModule(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
# feat_list -> [p2, p3, p4, p5]
|
||||
feat_list = []
|
||||
for x, conv in zip(inputs, self.channel_proj):
|
||||
feat_list.append(conv(x))
|
||||
|
||||
# feat_list -> [p5, p4, p3, p2]
|
||||
# mid_list -> [m5, m4, m3, m2]
|
||||
feat_list = feat_list[::-1]
|
||||
mid_list = []
|
||||
for feat in feat_list:
|
||||
if len(mid_list) == 0:
|
||||
mid_list.append(feat)
|
||||
else:
|
||||
mid_list.append(mid_list[-1] + feat)
|
||||
|
||||
# mid_list -> [m5, m4, m3, m2]
|
||||
# out_list -> [o2, o3, o4, o5]
|
||||
out_list = []
|
||||
for mid, conv in zip(mid_list, self.feat_extract):
|
||||
out_list.append(conv(mid))
|
||||
|
||||
return tuple(out_list)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MLANeck(nn.Module):
|
||||
"""Multi-level Feature Aggregation.
|
||||
|
||||
This neck is `The Multi-level Feature Aggregation construction of
|
||||
SETR <https://arxiv.org/abs/2012.15840>`_.
|
||||
|
||||
|
||||
Args:
|
||||
in_channels (List[int]): Number of input channels per scale.
|
||||
out_channels (int): Number of output channels (used at each scale).
|
||||
norm_layer (dict): Config dict for input normalization.
|
||||
Default: norm_layer=dict(type='LN', eps=1e-6, requires_grad=True).
|
||||
norm_cfg (dict): Config dict for normalization layer. Default: None.
|
||||
act_cfg (dict): Config dict for activation layer in ConvModule.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
norm_layer=dict(type='LN', eps=1e-6, requires_grad=True),
|
||||
norm_cfg=None,
|
||||
act_cfg=None):
|
||||
super().__init__()
|
||||
assert isinstance(in_channels, list)
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
# In order to build general vision transformer backbone, we have to
|
||||
# move MLA to neck.
|
||||
self.norm = nn.ModuleList([
|
||||
build_norm_layer(norm_layer, in_channels[i])[1]
|
||||
for i in range(len(in_channels))
|
||||
])
|
||||
|
||||
self.mla = MLAModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
def forward(self, inputs):
|
||||
assert len(inputs) == len(self.in_channels)
|
||||
|
||||
# Convert from nchw to nlc
|
||||
outs = []
|
||||
for i in range(len(inputs)):
|
||||
x = inputs[i]
|
||||
n, c, h, w = x.shape
|
||||
x = x.reshape(n, c, h * w).transpose(2, 1).contiguous()
|
||||
x = self.norm[i](x)
|
||||
x = x.transpose(1, 2).reshape(n, c, h, w).contiguous()
|
||||
outs.append(x)
|
||||
|
||||
outs = self.mla(outs)
|
||||
return tuple(outs)
|
||||
79
finetune/mmseg/models/necks/multilevel_neck.py
Normal file
79
finetune/mmseg/models/necks/multilevel_neck.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model.weight_init import xavier_init
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MultiLevelNeck(nn.Module):
|
||||
"""MultiLevelNeck.
|
||||
|
||||
A neck structure connect vit backbone and decoder_heads.
|
||||
|
||||
Args:
|
||||
in_channels (List[int]): Number of input channels per scale.
|
||||
out_channels (int): Number of output channels (used at each scale).
|
||||
scales (List[float]): Scale factors for each input feature map.
|
||||
Default: [0.5, 1, 2, 4]
|
||||
norm_cfg (dict): Config dict for normalization layer. Default: None.
|
||||
act_cfg (dict): Config dict for activation layer in ConvModule.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
scales=[0.5, 1, 2, 4],
|
||||
norm_cfg=None,
|
||||
act_cfg=None):
|
||||
super().__init__()
|
||||
assert isinstance(in_channels, list)
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.scales = scales
|
||||
self.num_outs = len(scales)
|
||||
self.lateral_convs = nn.ModuleList()
|
||||
self.convs = nn.ModuleList()
|
||||
for in_channel in in_channels:
|
||||
self.lateral_convs.append(
|
||||
ConvModule(
|
||||
in_channel,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
for _ in range(self.num_outs):
|
||||
self.convs.append(
|
||||
ConvModule(
|
||||
out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
stride=1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
|
||||
# default init_weights for conv(msra) and norm in ConvModule
|
||||
def init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
xavier_init(m, distribution='uniform')
|
||||
|
||||
def forward(self, inputs):
|
||||
assert len(inputs) == len(self.in_channels)
|
||||
inputs = [
|
||||
lateral_conv(inputs[i])
|
||||
for i, lateral_conv in enumerate(self.lateral_convs)
|
||||
]
|
||||
# for len(inputs) not equal to self.num_outs
|
||||
if len(inputs) == 1:
|
||||
inputs = [inputs[0] for _ in range(self.num_outs)]
|
||||
outs = []
|
||||
for i in range(self.num_outs):
|
||||
x_resize = resize(
|
||||
inputs[i], scale_factor=self.scales[i], mode='bilinear')
|
||||
outs.append(self.convs[i](x_resize))
|
||||
return tuple(outs)
|
||||
Reference in New Issue
Block a user