init

finetune/mmseg/models/backbones/stdc.py (new file, 422 lines)

@@ -0,0 +1,422 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from https://github.com/MichaelFan01/STDC-Seg."""

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule, ModuleList, Sequential

from mmseg.registry import MODELS
from ..utils import resize
from .bisenetv1 import AttentionRefinementModule

class STDCModule(BaseModule):
    """STDCModule.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels before scaling.
        stride (int): The stride of the first conv layer.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): The activation config for conv layers.
        num_convs (int): Number of conv layers. Default: 4.
        fusion_type (str): Type of fusion operation, 'add' or 'cat'.
            Default: 'add'.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """
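    # Fusion behaviour (summarized from the code below): with
    # ``fusion_type='cat'`` the outputs of the stacked convs are concatenated
    # along the channel axis (the first branch is average-pooled when
    # ``stride == 2``); with ``'add'`` the input (passed through the depthwise
    # 3x3 + 1x1 skip path when ``stride == 2``) is added to that
    # concatenation as a residual.
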
    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 norm_cfg=None,
                 act_cfg=None,
                 num_convs=4,
                 fusion_type='add',
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert num_convs > 1
        assert fusion_type in ['add', 'cat']
        self.stride = stride
        self.with_downsample = True if self.stride == 2 else False
        self.fusion_type = fusion_type

        self.layers = ModuleList()
        conv_0 = ConvModule(
            in_channels, out_channels // 2, kernel_size=1, norm_cfg=norm_cfg)

        if self.with_downsample:
            self.downsample = ConvModule(
                out_channels // 2,
                out_channels // 2,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=out_channels // 2,
                norm_cfg=norm_cfg,
                act_cfg=None)

            if self.fusion_type == 'add':
                self.layers.append(nn.Sequential(conv_0, self.downsample))
                self.skip = Sequential(
                    ConvModule(
                        in_channels,
                        in_channels,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        groups=in_channels,
                        norm_cfg=norm_cfg,
                        act_cfg=None),
                    ConvModule(
                        in_channels,
                        out_channels,
                        1,
                        norm_cfg=norm_cfg,
                        act_cfg=None))
            else:
                self.layers.append(conv_0)
                self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        else:
            self.layers.append(conv_0)

        for i in range(1, num_convs):
            out_factor = 2**(i + 1) if i != num_convs - 1 else 2**i
            self.layers.append(
                ConvModule(
                    out_channels // 2**i,
                    out_channels // out_factor,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))

    def forward(self, inputs):
        if self.fusion_type == 'add':
            out = self.forward_add(inputs)
        else:
            out = self.forward_cat(inputs)
        return out

    def forward_add(self, inputs):
        layer_outputs = []
        x = inputs.clone()
        for layer in self.layers:
            x = layer(x)
            layer_outputs.append(x)
        if self.with_downsample:
            inputs = self.skip(inputs)

        return torch.cat(layer_outputs, dim=1) + inputs

    def forward_cat(self, inputs):
        x0 = self.layers[0](inputs)
        layer_outputs = [x0]
        for i, layer in enumerate(self.layers[1:]):
            if i == 0:
                if self.with_downsample:
                    x = layer(self.downsample(x0))
                else:
                    x = layer(x0)
            else:
                x = layer(x)
            layer_outputs.append(x)
        if self.with_downsample:
            layer_outputs[0] = self.skip(x0)
        return torch.cat(layer_outputs, dim=1)

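# A minimal shape check for STDCModule (illustrative sketch; the norm/act
# configs below are assumptions, not values mandated by this file). With the
# default ``num_convs=4`` the branch widths are ``out_channels // 2``,
# ``out_channels // 4``, ``out_channels // 8`` and ``out_channels // 8``,
# which sum to ``out_channels``, so the module returns the requested width:
#
#     >>> m = STDCModule(32, 64, stride=1, norm_cfg=dict(type='BN'),
#     ...                act_cfg=dict(type='ReLU'), fusion_type='cat')
#     >>> m(torch.rand(1, 32, 64, 64)).shape
#     torch.Size([1, 64, 64, 64])
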
class FeatureFusionModule(BaseModule):
    """Feature Fusion Module.

    This module is different from the FeatureFusionModule in BiSeNetV1: it
    uses two ConvModules in `self.attention`, whose intermediate channel
    number is computed from the given `scale_factor`, while the
    FeatureFusionModule in BiSeNetV1 only uses one ConvModule in
    `self.conv_atten`.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        scale_factor (int): The channel scale factor. Default: 4.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): The activation config for conv layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 scale_factor=4,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        channels = out_channels // scale_factor
        self.conv0 = ConvModule(
            in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            ConvModule(
                out_channels,
                channels,
                1,
                norm_cfg=None,
                bias=False,
                act_cfg=act_cfg),
            ConvModule(
                channels,
                out_channels,
                1,
                norm_cfg=None,
                bias=False,
                act_cfg=None), nn.Sigmoid())

    def forward(self, spatial_inputs, context_inputs):
        inputs = torch.cat([spatial_inputs, context_inputs], dim=1)
        x = self.conv0(inputs)
        attn = self.attention(x)
        x_attn = x * attn
        return x_attn + x

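# FeatureFusionModule concatenates the spatial and context features along the
# channel axis, so ``in_channels`` must equal the sum of the two input widths.
# A minimal sketch, assuming 256-channel spatial and 128-channel context maps:
#
#     >>> ffm = FeatureFusionModule(in_channels=384, out_channels=256)
#     >>> spatial = torch.rand(1, 256, 64, 64)
#     >>> context = torch.rand(1, 128, 64, 64)
#     >>> ffm(spatial, context).shape
#     torch.Size([1, 256, 64, 64])
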
@MODELS.register_module()
class STDCNet(BaseModule):
    """This backbone is the implementation of `Rethinking BiSeNet For Real-time
    Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.

    Args:
        stdc_type (str): The type of backbone structure.
            `STDCNet1` and `STDCNet2` denote the two main backbones in the
            paper, whose FLOPs are 813M and 1446M, respectively.
        in_channels (int): The number of input channels.
        channels (tuple[int]): The output channels for each stage.
        bottleneck_type (str): The type of STDC Module, the value must
            be 'add' or 'cat'.
        norm_cfg (dict): Config dict for normalization layer.
        act_cfg (dict): The activation config for conv layers.
        num_convs (int): Number of conv layers in each STDC Module.
            Default: 4.
        with_final_conv (bool): Whether to add a conv layer at the module
            output. Default: False.
        pretrained (str, optional): Model pretrained path. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Example:
        >>> import torch
        >>> stdc_type = 'STDCNet1'
        >>> in_channels = 3
        >>> channels = (32, 64, 256, 512, 1024)
        >>> bottleneck_type = 'cat'
        >>> inputs = torch.rand(1, 3, 1024, 2048)
        >>> self = STDCNet(stdc_type, in_channels,
        ...                channels, bottleneck_type,
        ...                norm_cfg=dict(type='BN'),
        ...                act_cfg=dict(type='ReLU')).eval()
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 256, 128, 256])
        outputs[1].shape = torch.Size([1, 512, 64, 128])
        outputs[2].shape = torch.Size([1, 1024, 32, 64])
    """
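    # Each tuple in `arch_settings` lists the stride of every STDCModule in
    # one of the three deep stages (Stage3-Stage5 in the paper): the leading
    # stride-2 module halves the resolution and the following stride-1
    # modules keep it, so STDCNet2 simply stacks more modules per stage than
    # STDCNet1.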
    arch_settings = {
        'STDCNet1': [(2, 1), (2, 1), (2, 1)],
        'STDCNet2': [(2, 1, 1, 1), (2, 1, 1, 1, 1), (2, 1, 1)]
    }

    def __init__(self,
                 stdc_type,
                 in_channels,
                 channels,
                 bottleneck_type,
                 norm_cfg,
                 act_cfg,
                 num_convs=4,
                 with_final_conv=False,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert stdc_type in self.arch_settings, \
            f'invalid structure {stdc_type} for STDCNet.'
        assert bottleneck_type in ['add', 'cat'], \
            f'bottleneck_type must be `add` or `cat`, got {bottleneck_type}'

        assert len(channels) == 5, \
            f'invalid channels length {len(channels)} for STDCNet.'

        self.in_channels = in_channels
        self.channels = channels
        self.stage_strides = self.arch_settings[stdc_type]
        self.pretrained = pretrained
        self.num_convs = num_convs
        self.with_final_conv = with_final_conv

        self.stages = ModuleList([
            ConvModule(
                self.in_channels,
                self.channels[0],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg),
            ConvModule(
                self.channels[0],
                self.channels[1],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        ])
        # `self.num_shallow_features` is the number of shallow modules in
        # `STDCNet`, which are noted as `Stage1` and `Stage2` in the original
        # paper. Neither is used by later modules such as the Attention
        # Refinement Module and the Feature Fusion Module, so they are cut
        # from `outs`. Please refer to Figure 4 of the original paper for
        # more details.
        self.num_shallow_features = len(self.stages)

        for strides in self.stage_strides:
            idx = len(self.stages) - 1
            self.stages.append(
                self._make_stage(self.channels[idx], self.channels[idx + 1],
                                 strides, norm_cfg, act_cfg, bottleneck_type))
        # After appending, `self.stages` is a ModuleList consisting of the
        # shallow modules followed by the STDCModules:
        # len(self.stages) ==
        #     self.num_shallow_features + len(self.stage_strides)
        if self.with_final_conv:
            self.final_conv = ConvModule(
                self.channels[-1],
                max(1024, self.channels[-1]),
                1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
    def _make_stage(self, in_channels, out_channels, strides, norm_cfg,
                    act_cfg, bottleneck_type):
        layers = []
        for i, stride in enumerate(strides):
            layers.append(
                STDCModule(
                    in_channels if i == 0 else out_channels,
                    out_channels,
                    stride,
                    norm_cfg,
                    act_cfg,
                    num_convs=self.num_convs,
                    fusion_type=bottleneck_type))
        return Sequential(*layers)

    def forward(self, x):
        outs = []
        for stage in self.stages:
            x = stage(x)
            outs.append(x)
        if self.with_final_conv:
            outs[-1] = self.final_conv(outs[-1])
        outs = outs[self.num_shallow_features:]
        return tuple(outs)

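# Because STDCNet is registered in ``MODELS``, it can also be built from a
# config dict. A minimal sketch (the concrete values below are illustrative
# choices, not defaults defined in this file):
#
#     >>> from mmseg.registry import MODELS
#     >>> backbone = MODELS.build(
#     ...     dict(
#     ...         type='STDCNet',
#     ...         stdc_type='STDCNet1',
#     ...         in_channels=3,
#     ...         channels=(32, 64, 256, 512, 1024),
#     ...         bottleneck_type='cat',
#     ...         norm_cfg=dict(type='BN'),
#     ...         act_cfg=dict(type='ReLU')))
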
@MODELS.register_module()
class STDCContextPathNet(BaseModule):
    """STDCNet with Context Path.

    The `outs` below is a list of three feature maps from deep to shallow,
    whose heights and widths therefore go from small to big. The biggest
    feature map of `outs` is output for `STDCHead`, where the Detail Loss is
    computed against the Detail Ground-truth. The other two feature maps each
    go through an Attention Refinement Module. Finally, the biggest feature
    map of `outs` and the last output of the Attention Refinement Modules are
    concatenated for the Feature Fusion Module, and the fused feature map
    `feat_fuse` is output for `decode_head`. For more details, please refer
    to Figure 4 of the original paper.

    Args:
        backbone_cfg (dict): Config dict for the STDC backbone.
        last_in_channels (tuple(int)): The number of channels of the last
            two feature maps from the STDC backbone. Default: (1024, 512).
        out_channels (int): The channels of the output feature maps.
            Default: 128.
        ffm_cfg (dict): Config dict for the Feature Fusion Module. Default:
            `dict(in_channels=512, out_channels=256, scale_factor=4)`.
        upsample_mode (str): Algorithm used for upsampling:
            ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
            ``'trilinear'``. Default: ``'nearest'``.
        align_corners (bool, optional): align_corners argument of
            F.interpolate. It must be `None` if upsample_mode is
            ``'nearest'``. Default: None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        outputs (tuple): The tuple of lists of output feature maps for the
            auxiliary heads and the decode head.
    """
    def __init__(self,
                 backbone_cfg,
                 last_in_channels=(1024, 512),
                 out_channels=128,
                 ffm_cfg=dict(
                     in_channels=512, out_channels=256, scale_factor=4),
                 upsample_mode='nearest',
                 align_corners=None,
                 norm_cfg=dict(type='BN'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.backbone = MODELS.build(backbone_cfg)
        self.arms = ModuleList()
        self.convs = ModuleList()
        for channels in last_in_channels:
            self.arms.append(AttentionRefinementModule(channels, out_channels))
            self.convs.append(
                ConvModule(
                    out_channels,
                    out_channels,
                    3,
                    padding=1,
                    norm_cfg=norm_cfg))
        self.conv_avg = ConvModule(
            last_in_channels[0], out_channels, 1, norm_cfg=norm_cfg)

        self.ffm = FeatureFusionModule(**ffm_cfg)

        self.upsample_mode = upsample_mode
        self.align_corners = align_corners
    def forward(self, x):
        outs = list(self.backbone(x))
        avg = F.adaptive_avg_pool2d(outs[-1], 1)
        avg_feat = self.conv_avg(avg)

        feature_up = resize(
            avg_feat,
            size=outs[-1].shape[2:],
            mode=self.upsample_mode,
            align_corners=self.align_corners)
        arms_out = []
        for i in range(len(self.arms)):
            x_arm = self.arms[i](outs[len(outs) - 1 - i]) + feature_up
            feature_up = resize(
                x_arm,
                size=outs[len(outs) - 1 - i - 1].shape[2:],
                mode=self.upsample_mode,
                align_corners=self.align_corners)
            feature_up = self.convs[i](feature_up)
            arms_out.append(feature_up)

        feat_fuse = self.ffm(outs[0], arms_out[1])

        # The `outputs` tuple has four feature maps:
        # `outs[0]` goes to the `STDCHead` auxiliary head,
        # the two feature maps in `arms_out` go to auxiliary heads, and
        # `feat_fuse` goes to the decode head.
        outputs = [outs[0]] + list(arms_out) + [feat_fuse]
        return tuple(outputs)

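# STDCContextPathNet returns four feature maps: ``outs[0]`` for the STDCHead
# detail-loss auxiliary head, the two refined maps in ``arms_out`` for further
# auxiliary heads, and ``feat_fuse`` for the decode head. A minimal sketch,
# assuming the STDCNet1 settings from the docstring example above and an
# ``ffm_cfg`` whose ``in_channels`` matches the concatenated width
# (256 + 128 = 384):
#
#     >>> net = STDCContextPathNet(
#     ...     backbone_cfg=dict(
#     ...         type='STDCNet',
#     ...         stdc_type='STDCNet1',
#     ...         in_channels=3,
#     ...         channels=(32, 64, 256, 512, 1024),
#     ...         bottleneck_type='cat',
#     ...         norm_cfg=dict(type='BN'),
#     ...         act_cfg=dict(type='ReLU')),
#     ...     ffm_cfg=dict(in_channels=384, out_channels=256,
#     ...                  scale_factor=4)).eval()
#     >>> outputs = net(torch.rand(1, 3, 512, 1024))
#     >>> len(outputs)
#     4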