init
This commit is contained in:
408
finetune/mmseg/models/backbones/fast_scnn.py
Normal file
408
finetune/mmseg/models/backbones/fast_scnn.py
Normal file
@@ -0,0 +1,408 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.models.decode_heads.psp_head import PPM
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import InvertedResidual, resize
|
||||
|
||||
|
||||
class LearningToDownsample(nn.Module):
|
||||
"""Learning to downsample module.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
dw_channels (tuple[int]): Number of output channels of the first and
|
||||
the second depthwise conv (dwconv) layers.
|
||||
out_channels (int): Number of output channels of the whole
|
||||
'learning to downsample' module.
|
||||
conv_cfg (dict | None): Config of conv layers. Default: None
|
||||
norm_cfg (dict | None): Config of norm layers. Default:
|
||||
dict(type='BN')
|
||||
act_cfg (dict): Config of activation layers. Default:
|
||||
dict(type='ReLU')
|
||||
dw_act_cfg (dict): In DepthwiseSeparableConvModule, activation config
|
||||
of depthwise ConvModule. If it is 'default', it will be the same
|
||||
as `act_cfg`. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
dw_channels,
|
||||
out_channels,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
dw_act_cfg=None):
|
||||
super().__init__()
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.dw_act_cfg = dw_act_cfg
|
||||
dw_channels1 = dw_channels[0]
|
||||
dw_channels2 = dw_channels[1]
|
||||
|
||||
self.conv = ConvModule(
|
||||
in_channels,
|
||||
dw_channels1,
|
||||
3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
self.dsconv1 = DepthwiseSeparableConvModule(
|
||||
dw_channels1,
|
||||
dw_channels2,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
dw_act_cfg=self.dw_act_cfg)
|
||||
|
||||
self.dsconv2 = DepthwiseSeparableConvModule(
|
||||
dw_channels2,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
dw_act_cfg=self.dw_act_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.dsconv1(x)
|
||||
x = self.dsconv2(x)
|
||||
return x
|
||||
|
||||
|
||||
class GlobalFeatureExtractor(nn.Module):
|
||||
"""Global feature extractor module.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels of the GFE module.
|
||||
Default: 64
|
||||
block_channels (tuple[int]): Tuple of ints. Each int specifies the
|
||||
number of output channels of each Inverted Residual module.
|
||||
Default: (64, 96, 128)
|
||||
out_channels(int): Number of output channels of the GFE module.
|
||||
Default: 128
|
||||
expand_ratio (int): Adjusts number of channels of the hidden layer
|
||||
in InvertedResidual by this amount.
|
||||
Default: 6
|
||||
num_blocks (tuple[int]): Tuple of ints. Each int specifies the
|
||||
number of times each Inverted Residual module is repeated.
|
||||
The repeated Inverted Residual modules are called a 'group'.
|
||||
Default: (3, 3, 3)
|
||||
strides (tuple[int]): Tuple of ints. Each int specifies
|
||||
the downsampling factor of each 'group'.
|
||||
Default: (2, 2, 1)
|
||||
pool_scales (tuple[int]): Tuple of ints. Each int specifies
|
||||
the parameter required in 'global average pooling' within PPM.
|
||||
Default: (1, 2, 3, 6)
|
||||
conv_cfg (dict | None): Config of conv layers. Default: None
|
||||
norm_cfg (dict | None): Config of norm layers. Default:
|
||||
dict(type='BN')
|
||||
act_cfg (dict): Config of activation layers. Default:
|
||||
dict(type='ReLU')
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=64,
|
||||
block_channels=(64, 96, 128),
|
||||
out_channels=128,
|
||||
expand_ratio=6,
|
||||
num_blocks=(3, 3, 3),
|
||||
strides=(2, 2, 1),
|
||||
pool_scales=(1, 2, 3, 6),
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
align_corners=False):
|
||||
super().__init__()
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
assert len(block_channels) == len(num_blocks) == 3
|
||||
self.bottleneck1 = self._make_layer(in_channels, block_channels[0],
|
||||
num_blocks[0], strides[0],
|
||||
expand_ratio)
|
||||
self.bottleneck2 = self._make_layer(block_channels[0],
|
||||
block_channels[1], num_blocks[1],
|
||||
strides[1], expand_ratio)
|
||||
self.bottleneck3 = self._make_layer(block_channels[1],
|
||||
block_channels[2], num_blocks[2],
|
||||
strides[2], expand_ratio)
|
||||
self.ppm = PPM(
|
||||
pool_scales,
|
||||
block_channels[2],
|
||||
block_channels[2] // 4,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
align_corners=align_corners)
|
||||
|
||||
self.out = ConvModule(
|
||||
block_channels[2] * 2,
|
||||
out_channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def _make_layer(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
blocks,
|
||||
stride=1,
|
||||
expand_ratio=6):
|
||||
layers = [
|
||||
InvertedResidual(
|
||||
in_channels,
|
||||
out_channels,
|
||||
stride,
|
||||
expand_ratio,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
]
|
||||
for i in range(1, blocks):
|
||||
layers.append(
|
||||
InvertedResidual(
|
||||
out_channels,
|
||||
out_channels,
|
||||
1,
|
||||
expand_ratio,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg))
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.bottleneck1(x)
|
||||
x = self.bottleneck2(x)
|
||||
x = self.bottleneck3(x)
|
||||
x = torch.cat([x, *self.ppm(x)], dim=1)
|
||||
x = self.out(x)
|
||||
return x
|
||||
|
||||
|
||||
class FeatureFusionModule(nn.Module):
|
||||
"""Feature fusion module.
|
||||
|
||||
Args:
|
||||
higher_in_channels (int): Number of input channels of the
|
||||
higher-resolution branch.
|
||||
lower_in_channels (int): Number of input channels of the
|
||||
lower-resolution branch.
|
||||
out_channels (int): Number of output channels.
|
||||
conv_cfg (dict | None): Config of conv layers. Default: None
|
||||
norm_cfg (dict | None): Config of norm layers. Default:
|
||||
dict(type='BN')
|
||||
dwconv_act_cfg (dict): Config of activation layers in 3x3 conv.
|
||||
Default: dict(type='ReLU').
|
||||
conv_act_cfg (dict): Config of activation layers in the two 1x1 conv.
|
||||
Default: None.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
higher_in_channels,
|
||||
lower_in_channels,
|
||||
out_channels,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
dwconv_act_cfg=dict(type='ReLU'),
|
||||
conv_act_cfg=None,
|
||||
align_corners=False):
|
||||
super().__init__()
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.dwconv_act_cfg = dwconv_act_cfg
|
||||
self.conv_act_cfg = conv_act_cfg
|
||||
self.align_corners = align_corners
|
||||
self.dwconv = ConvModule(
|
||||
lower_in_channels,
|
||||
out_channels,
|
||||
3,
|
||||
padding=1,
|
||||
groups=out_channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.dwconv_act_cfg)
|
||||
self.conv_lower_res = ConvModule(
|
||||
out_channels,
|
||||
out_channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.conv_act_cfg)
|
||||
|
||||
self.conv_higher_res = ConvModule(
|
||||
higher_in_channels,
|
||||
out_channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.conv_act_cfg)
|
||||
|
||||
self.relu = nn.ReLU(True)
|
||||
|
||||
def forward(self, higher_res_feature, lower_res_feature):
|
||||
lower_res_feature = resize(
|
||||
lower_res_feature,
|
||||
size=higher_res_feature.size()[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
lower_res_feature = self.dwconv(lower_res_feature)
|
||||
lower_res_feature = self.conv_lower_res(lower_res_feature)
|
||||
|
||||
higher_res_feature = self.conv_higher_res(higher_res_feature)
|
||||
out = higher_res_feature + lower_res_feature
|
||||
return self.relu(out)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class FastSCNN(BaseModule):
|
||||
"""Fast-SCNN Backbone.
|
||||
|
||||
This backbone is the implementation of `Fast-SCNN: Fast Semantic
|
||||
Segmentation Network <https://arxiv.org/abs/1902.04502>`_.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input image channels. Default: 3.
|
||||
downsample_dw_channels (tuple[int]): Number of output channels after
|
||||
the first conv layer & the second conv layer in
|
||||
Learning-To-Downsample (LTD) module.
|
||||
Default: (32, 48).
|
||||
global_in_channels (int): Number of input channels of
|
||||
Global Feature Extractor(GFE).
|
||||
Equal to number of output channels of LTD.
|
||||
Default: 64.
|
||||
global_block_channels (tuple[int]): Tuple of integers that describe
|
||||
the output channels for each of the MobileNet-v2 bottleneck
|
||||
residual blocks in GFE.
|
||||
Default: (64, 96, 128).
|
||||
global_block_strides (tuple[int]): Tuple of integers
|
||||
that describe the strides (downsampling factors) for each of the
|
||||
MobileNet-v2 bottleneck residual blocks in GFE.
|
||||
Default: (2, 2, 1).
|
||||
global_out_channels (int): Number of output channels of GFE.
|
||||
Default: 128.
|
||||
higher_in_channels (int): Number of input channels of the higher
|
||||
resolution branch in FFM.
|
||||
Equal to global_in_channels.
|
||||
Default: 64.
|
||||
lower_in_channels (int): Number of input channels of the lower
|
||||
resolution branch in FFM.
|
||||
Equal to global_out_channels.
|
||||
Default: 128.
|
||||
fusion_out_channels (int): Number of output channels of FFM.
|
||||
Default: 128.
|
||||
out_indices (tuple): Tuple of indices of list
|
||||
[higher_res_features, lower_res_features, fusion_output].
|
||||
Often set to (0,1,2) to enable aux. heads.
|
||||
Default: (0, 1, 2).
|
||||
conv_cfg (dict | None): Config of conv layers. Default: None
|
||||
norm_cfg (dict | None): Config of norm layers. Default:
|
||||
dict(type='BN')
|
||||
act_cfg (dict): Config of activation layers. Default:
|
||||
dict(type='ReLU')
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False
|
||||
dw_act_cfg (dict): In DepthwiseSeparableConvModule, activation config
|
||||
of depthwise ConvModule. If it is 'default', it will be the same
|
||||
as `act_cfg`. Default: None.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
downsample_dw_channels=(32, 48),
|
||||
global_in_channels=64,
|
||||
global_block_channels=(64, 96, 128),
|
||||
global_block_strides=(2, 2, 1),
|
||||
global_out_channels=128,
|
||||
higher_in_channels=64,
|
||||
lower_in_channels=128,
|
||||
fusion_out_channels=128,
|
||||
out_indices=(0, 1, 2),
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
align_corners=False,
|
||||
dw_act_cfg=None,
|
||||
init_cfg=None):
|
||||
|
||||
super().__init__(init_cfg)
|
||||
|
||||
if init_cfg is None:
|
||||
self.init_cfg = [
|
||||
dict(type='Kaiming', layer='Conv2d'),
|
||||
dict(
|
||||
type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
|
||||
]
|
||||
|
||||
if global_in_channels != higher_in_channels:
|
||||
raise AssertionError('Global Input Channels must be the same \
|
||||
with Higher Input Channels!')
|
||||
elif global_out_channels != lower_in_channels:
|
||||
raise AssertionError('Global Output Channels must be the same \
|
||||
with Lower Input Channels!')
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.downsample_dw_channels1 = downsample_dw_channels[0]
|
||||
self.downsample_dw_channels2 = downsample_dw_channels[1]
|
||||
self.global_in_channels = global_in_channels
|
||||
self.global_block_channels = global_block_channels
|
||||
self.global_block_strides = global_block_strides
|
||||
self.global_out_channels = global_out_channels
|
||||
self.higher_in_channels = higher_in_channels
|
||||
self.lower_in_channels = lower_in_channels
|
||||
self.fusion_out_channels = fusion_out_channels
|
||||
self.out_indices = out_indices
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.align_corners = align_corners
|
||||
self.learning_to_downsample = LearningToDownsample(
|
||||
in_channels,
|
||||
downsample_dw_channels,
|
||||
global_in_channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
dw_act_cfg=dw_act_cfg)
|
||||
self.global_feature_extractor = GlobalFeatureExtractor(
|
||||
global_in_channels,
|
||||
global_block_channels,
|
||||
global_out_channels,
|
||||
strides=self.global_block_strides,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
align_corners=self.align_corners)
|
||||
self.feature_fusion = FeatureFusionModule(
|
||||
higher_in_channels,
|
||||
lower_in_channels,
|
||||
fusion_out_channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
dwconv_act_cfg=self.act_cfg,
|
||||
align_corners=self.align_corners)
|
||||
|
||||
def forward(self, x):
|
||||
higher_res_features = self.learning_to_downsample(x)
|
||||
lower_res_features = self.global_feature_extractor(higher_res_features)
|
||||
fusion_output = self.feature_fusion(higher_res_features,
|
||||
lower_res_features)
|
||||
|
||||
outs = [higher_res_features, lower_res_features, fusion_output]
|
||||
outs = [outs[i] for i in self.out_indices]
|
||||
return tuple(outs)
|
||||
Reference in New Issue
Block a user