This commit is contained in:
esenke
2025-12-08 22:16:31 +08:00
commit 01adcfdf60
305 changed files with 50879 additions and 0 deletions

View File

@@ -0,0 +1,35 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .beit import BEiT
from .bisenetv1 import BiSeNetV1
from .bisenetv2 import BiSeNetV2
from .cgnet import CGNet
from .ddrnet import DDRNet
from .erfnet import ERFNet
from .fast_scnn import FastSCNN
from .hrnet import HRNet
from .icnet import ICNet
from .mae import MAE
from .mit import MixVisionTransformer
from .mobilenet_v2 import MobileNetV2
from .mobilenet_v3 import MobileNetV3
from .mscan import MSCAN
from .pidnet import PIDNet
from .resnest import ResNeSt
from .resnet import ResNet, ResNetV1c, ResNetV1d
from .resnext import ResNeXt
from .stdc import STDCContextPathNet, STDCNet
from .swin import SwinTransformer
from .timm_backbone import TIMMBackbone
from .twins import PCPVT, SVT
from .unet import UNet
from .vit import VisionTransformer
from .vpd import VPD
__all__ = [
'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT',
'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet', 'MSCAN',
'DDRNet', 'VPD'
]

View File

@@ -0,0 +1,554 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import (constant_init, kaiming_init,
trunc_normal_)
from mmengine.runner.checkpoint import _load_checkpoint
from scipy import interpolate
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.utils import _pair as to_2tuple
from mmseg.registry import MODELS
from ..utils import PatchEmbed
from .vit import TransformerEncoderLayer as VisionTransformerEncoderLayer
class BEiTAttention(BaseModule):
"""Window based multi-head self-attention (W-MSA) module with relative
position bias.
Args:
embed_dims (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (tuple[int]): The height and width of the window.
bias (bool): The option to add leanable bias for q, k, v. If bias is
True, it will add leanable bias. If bias is 'qv_bias', it will only
add leanable bias for q, v. If bias is False, it will not add bias
for q, k, v. Default to 'qv_bias'.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
attn_drop_rate (float): Dropout ratio of attention weight.
Default: 0.0
proj_drop_rate (float): Dropout ratio of output. Default: 0.
init_cfg (dict | None, optional): The Config for initialization.
Default: None.
"""
def __init__(self,
embed_dims,
num_heads,
window_size,
bias='qv_bias',
qk_scale=None,
attn_drop_rate=0.,
proj_drop_rate=0.,
init_cfg=None,
**kwargs):
super().__init__(init_cfg=init_cfg)
self.embed_dims = embed_dims
self.num_heads = num_heads
head_embed_dims = embed_dims // num_heads
self.bias = bias
self.scale = qk_scale or head_embed_dims**-0.5
qkv_bias = bias
if bias == 'qv_bias':
self._init_qv_bias()
qkv_bias = False
self.window_size = window_size
self._init_rel_pos_embedding()
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop_rate)
self.proj = nn.Linear(embed_dims, embed_dims)
self.proj_drop = nn.Dropout(proj_drop_rate)
def _init_qv_bias(self):
self.q_bias = nn.Parameter(torch.zeros(self.embed_dims))
self.v_bias = nn.Parameter(torch.zeros(self.embed_dims))
def _init_rel_pos_embedding(self):
Wh, Ww = self.window_size
# cls to token & token 2 cls & cls to cls
self.num_relative_distance = (2 * Wh - 1) * (2 * Ww - 1) + 3
# relative_position_bias_table shape is (2*Wh-1 * 2*Ww-1 + 3, nH)
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, self.num_heads))
# get pair-wise relative position index for
# each token inside the window
coords_h = torch.arange(Wh)
coords_w = torch.arange(Ww)
# coords shape is (2, Wh, Ww)
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
# coords_flatten shape is (2, Wh*Ww)
coords_flatten = torch.flatten(coords, 1)
relative_coords = (
coords_flatten[:, :, None] - coords_flatten[:, None, :])
# relative_coords shape is (Wh*Ww, Wh*Ww, 2)
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
# shift to start from 0
relative_coords[:, :, 0] += Wh - 1
relative_coords[:, :, 1] += Ww - 1
relative_coords[:, :, 0] *= 2 * Ww - 1
relative_position_index = torch.zeros(
size=(Wh * Ww + 1, ) * 2, dtype=relative_coords.dtype)
# relative_position_index shape is (Wh*Ww, Wh*Ww)
relative_position_index[1:, 1:] = relative_coords.sum(-1)
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer('relative_position_index',
relative_position_index)
def init_weights(self):
trunc_normal_(self.relative_position_bias_table, std=0.02)
def forward(self, x):
"""
Args:
x (tensor): input features with shape of (num_windows*B, N, C).
"""
B, N, C = x.shape
if self.bias == 'qv_bias':
k_bias = torch.zeros_like(self.v_bias, requires_grad=False)
qkv_bias = torch.cat((self.q_bias, k_bias, self.v_bias))
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
else:
qkv = self.qkv(x)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
if self.relative_position_bias_table is not None:
Wh = self.window_size[0]
Ww = self.window_size[1]
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)].view(
Wh * Ww + 1, Wh * Ww + 1, -1)
relative_position_bias = relative_position_bias.permute(
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class BEiTTransformerEncoderLayer(VisionTransformerEncoderLayer):
"""Implements one encoder layer in Vision Transformer.
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
attn_drop_rate (float): The drop out rate for attention layer.
Default: 0.0.
drop_path_rate (float): Stochastic depth rate. Default 0.0.
num_fcs (int): The number of fully-connected layers for FFNs.
Default: 2.
bias (bool): The option to add leanable bias for q, k, v. If bias is
True, it will add leanable bias. If bias is 'qv_bias', it will only
add leanable bias for q, v. If bias is False, it will not add bias
for q, k, v. Default to 'qv_bias'.
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
window_size (tuple[int], optional): The height and width of the window.
Default: None.
init_values (float, optional): Initialize the values of BEiTAttention
and FFN with learnable scaling. Default: None.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
attn_drop_rate=0.,
drop_path_rate=0.,
num_fcs=2,
bias='qv_bias',
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
window_size=None,
attn_cfg=dict(),
ffn_cfg=dict(add_identity=False),
init_values=None):
attn_cfg.update(dict(window_size=window_size, qk_scale=None))
super().__init__(
embed_dims=embed_dims,
num_heads=num_heads,
feedforward_channels=feedforward_channels,
attn_drop_rate=attn_drop_rate,
drop_path_rate=0.,
drop_rate=0.,
num_fcs=num_fcs,
qkv_bias=bias,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
attn_cfg=attn_cfg,
ffn_cfg=ffn_cfg)
# NOTE: drop path for stochastic depth, we shall see if
# this is better than dropout here
dropout_layer = dict(type='DropPath', drop_prob=drop_path_rate)
self.drop_path = build_dropout(
dropout_layer) if dropout_layer else nn.Identity()
self.gamma_1 = nn.Parameter(
init_values * torch.ones(embed_dims), requires_grad=True)
self.gamma_2 = nn.Parameter(
init_values * torch.ones(embed_dims), requires_grad=True)
def build_attn(self, attn_cfg):
self.attn = BEiTAttention(**attn_cfg)
def forward(self, x):
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
x = x + self.drop_path(self.gamma_2 * self.ffn(self.norm2(x)))
return x
@MODELS.register_module()
class BEiT(BaseModule):
"""BERT Pre-Training of Image Transformers.
Args:
img_size (int | tuple): Input image size. Default: 224.
patch_size (int): The patch size. Default: 16.
in_channels (int): Number of input channels. Default: 3.
embed_dims (int): Embedding dimension. Default: 768.
num_layers (int): Depth of transformer. Default: 12.
num_heads (int): Number of attention heads. Default: 12.
mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
Default: 4.
out_indices (list | tuple | int): Output from which stages.
Default: -1.
qv_bias (bool): Enable bias for qv if True. Default: True.
attn_drop_rate (float): The drop out rate for attention layer.
Default 0.0
drop_path_rate (float): Stochastic depth rate. Default 0.0.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN')
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
patch_norm (bool): Whether to add a norm in PatchEmbed Block.
Default: False.
final_norm (bool): Whether to add a additional layer to normalize
final feature map. Default: False.
num_fcs (int): The number of fully-connected layers for FFNs.
Default: 2.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
pretrained (str, optional): Model pretrained path. Default: None.
init_values (float): Initialize the values of BEiTAttention and FFN
with learnable scaling.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(self,
img_size=224,
patch_size=16,
in_channels=3,
embed_dims=768,
num_layers=12,
num_heads=12,
mlp_ratio=4,
out_indices=-1,
qv_bias=True,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_cfg=dict(type='LN'),
act_cfg=dict(type='GELU'),
patch_norm=False,
final_norm=False,
num_fcs=2,
norm_eval=False,
pretrained=None,
init_values=0.1,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
if isinstance(img_size, int):
img_size = to_2tuple(img_size)
elif isinstance(img_size, tuple):
if len(img_size) == 1:
img_size = to_2tuple(img_size[0])
assert len(img_size) == 2, \
f'The size of image should have length 1 or 2, ' \
f'but got {len(img_size)}'
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be set at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is not None:
raise TypeError('pretrained must be a str or None')
self.in_channels = in_channels
self.img_size = img_size
self.patch_size = patch_size
self.norm_eval = norm_eval
self.pretrained = pretrained
self.num_layers = num_layers
self.embed_dims = embed_dims
self.num_heads = num_heads
self.mlp_ratio = mlp_ratio
self.attn_drop_rate = attn_drop_rate
self.drop_path_rate = drop_path_rate
self.num_fcs = num_fcs
self.qv_bias = qv_bias
self.act_cfg = act_cfg
self.norm_cfg = norm_cfg
self.patch_norm = patch_norm
self.init_values = init_values
self.window_size = (img_size[0] // patch_size,
img_size[1] // patch_size)
self.patch_shape = self.window_size
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
self._build_patch_embedding()
self._build_layers()
if isinstance(out_indices, int):
if out_indices == -1:
out_indices = num_layers - 1
self.out_indices = [out_indices]
elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
self.out_indices = out_indices
else:
raise TypeError('out_indices must be type of int, list or tuple')
self.final_norm = final_norm
if final_norm:
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, embed_dims, postfix=1)
self.add_module(self.norm1_name, norm1)
def _build_patch_embedding(self):
"""Build patch embedding layer."""
self.patch_embed = PatchEmbed(
in_channels=self.in_channels,
embed_dims=self.embed_dims,
conv_type='Conv2d',
kernel_size=self.patch_size,
stride=self.patch_size,
padding=0,
norm_cfg=self.norm_cfg if self.patch_norm else None,
init_cfg=None)
def _build_layers(self):
"""Build transformer encoding layers."""
dpr = [
x.item()
for x in torch.linspace(0, self.drop_path_rate, self.num_layers)
]
self.layers = ModuleList()
for i in range(self.num_layers):
self.layers.append(
BEiTTransformerEncoderLayer(
embed_dims=self.embed_dims,
num_heads=self.num_heads,
feedforward_channels=self.mlp_ratio * self.embed_dims,
attn_drop_rate=self.attn_drop_rate,
drop_path_rate=dpr[i],
num_fcs=self.num_fcs,
bias='qv_bias' if self.qv_bias else False,
act_cfg=self.act_cfg,
norm_cfg=self.norm_cfg,
window_size=self.window_size,
init_values=self.init_values))
@property
def norm1(self):
return getattr(self, self.norm1_name)
def _geometric_sequence_interpolation(self, src_size, dst_size, sequence,
num):
"""Get new sequence via geometric sequence interpolation.
Args:
src_size (int): Pos_embedding size in pre-trained model.
dst_size (int): Pos_embedding size in the current model.
sequence (tensor): The relative position bias of the pretrain
model after removing the extra tokens.
num (int): Number of attention heads.
Returns:
new_sequence (tensor): Geometric sequence interpolate the
pre-trained relative position bias to the size of
the current model.
"""
def geometric_progression(a, r, n):
return a * (1.0 - r**n) / (1.0 - r)
# Here is a binary function.
left, right = 1.01, 1.5
while right - left > 1e-6:
q = (left + right) / 2.0
gp = geometric_progression(1, q, src_size // 2)
if gp > dst_size // 2:
right = q
else:
left = q
# The position of each interpolated point is determined
# by the ratio obtained by dichotomy.
dis = []
cur = 1
for i in range(src_size // 2):
dis.append(cur)
cur += q**(i + 1)
r_ids = [-_ for _ in reversed(dis)]
x = r_ids + [0] + dis
y = r_ids + [0] + dis
t = dst_size // 2.0
dx = np.arange(-t, t + 0.1, 1.0)
dy = np.arange(-t, t + 0.1, 1.0)
# Interpolation functions are being executed and called.
new_sequence = []
for i in range(num):
z = sequence[:, i].view(src_size, src_size).float().numpy()
f = interpolate.interp2d(x, y, z, kind='cubic')
new_sequence.append(
torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(sequence))
new_sequence = torch.cat(new_sequence, dim=-1)
return new_sequence
def resize_rel_pos_embed(self, checkpoint):
"""Resize relative pos_embed weights.
This function is modified from
https://github.com/microsoft/unilm/blob/master/beit/semantic_segmentation/mmcv_custom/checkpoint.py. # noqa: E501
Copyright (c) Microsoft Corporation
Licensed under the MIT License
Args:
checkpoint (dict): Key and value of the pretrain model.
Returns:
state_dict (dict): Interpolate the relative pos_embed weights
in the pre-train model to the current model size.
"""
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
all_keys = list(state_dict.keys())
for key in all_keys:
if 'relative_position_index' in key:
state_dict.pop(key)
# In order to keep the center of pos_bias as consistent as
# possible after interpolation, and vice versa in the edge
# area, the geometric sequence interpolation method is adopted.
if 'relative_position_bias_table' in key:
rel_pos_bias = state_dict[key]
src_num_pos, num_attn_heads = rel_pos_bias.size()
dst_num_pos, _ = self.state_dict()[key].size()
dst_patch_shape = self.patch_shape
if dst_patch_shape[0] != dst_patch_shape[1]:
raise NotImplementedError()
# Count the number of extra tokens.
num_extra_tokens = dst_num_pos - (
dst_patch_shape[0] * 2 - 1) * (
dst_patch_shape[1] * 2 - 1)
src_size = int((src_num_pos - num_extra_tokens)**0.5)
dst_size = int((dst_num_pos - num_extra_tokens)**0.5)
if src_size != dst_size:
extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
new_rel_pos_bias = self._geometric_sequence_interpolation(
src_size, dst_size, rel_pos_bias, num_attn_heads)
new_rel_pos_bias = torch.cat(
(new_rel_pos_bias, extra_tokens), dim=0)
state_dict[key] = new_rel_pos_bias
return state_dict
def init_weights(self):
def _init_weights(m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
self.apply(_init_weights)
if (isinstance(self.init_cfg, dict)
and self.init_cfg.get('type') == 'Pretrained'):
checkpoint = _load_checkpoint(
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
state_dict = self.resize_rel_pos_embed(checkpoint)
self.load_state_dict(state_dict, False)
elif self.init_cfg is not None:
super().init_weights()
else:
# We only implement the 'jax_impl' initialization implemented at
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
# Copyright 2019 Ross Wightman
# Licensed under the Apache License, Version 2.0 (the "License")
trunc_normal_(self.cls_token, std=.02)
for n, m in self.named_modules():
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
if 'ffn' in n:
nn.init.normal_(m.bias, mean=0., std=1e-6)
else:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv2d):
kaiming_init(m, mode='fan_in', bias=0.)
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
constant_init(m, val=1.0, bias=0.)
def forward(self, inputs):
B = inputs.shape[0]
x, hw_shape = self.patch_embed(inputs)
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
if i == len(self.layers) - 1:
if self.final_norm:
x = self.norm1(x)
if i in self.out_indices:
# Remove class token and reshape token for decoder head
out = x[:, 1:]
B, _, C = out.shape
out = out.reshape(B, hw_shape[0], hw_shape[1],
C).permute(0, 3, 1, 2).contiguous()
outs.append(out)
return tuple(outs)
def train(self, mode=True):
super().train(mode)
if mode and self.norm_eval:
for m in self.modules():
if isinstance(m, nn.LayerNorm):
m.eval()

View File

@@ -0,0 +1,332 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from mmseg.registry import MODELS
from ..utils import resize
class SpatialPath(BaseModule):
"""Spatial Path to preserve the spatial size of the original input image
and encode affluent spatial information.
Args:
in_channels(int): The number of channels of input
image. Default: 3.
num_channels (Tuple[int]): The number of channels of
each layers in Spatial Path.
Default: (64, 64, 64, 128).
Returns:
x (torch.Tensor): Feature map for Feature Fusion Module.
"""
def __init__(self,
in_channels=3,
num_channels=(64, 64, 64, 128),
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
assert len(num_channels) == 4, 'Length of input channels \
of Spatial Path must be 4!'
self.layers = []
for i in range(len(num_channels)):
layer_name = f'layer{i + 1}'
self.layers.append(layer_name)
if i == 0:
self.add_module(
layer_name,
ConvModule(
in_channels=in_channels,
out_channels=num_channels[i],
kernel_size=7,
stride=2,
padding=3,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
elif i == len(num_channels) - 1:
self.add_module(
layer_name,
ConvModule(
in_channels=num_channels[i - 1],
out_channels=num_channels[i],
kernel_size=1,
stride=1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
else:
self.add_module(
layer_name,
ConvModule(
in_channels=num_channels[i - 1],
out_channels=num_channels[i],
kernel_size=3,
stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
def forward(self, x):
for i, layer_name in enumerate(self.layers):
layer_stage = getattr(self, layer_name)
x = layer_stage(x)
return x
class AttentionRefinementModule(BaseModule):
"""Attention Refinement Module (ARM) to refine the features of each stage.
Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
Returns:
x_out (torch.Tensor): Feature map of Attention Refinement Module.
"""
def __init__(self,
in_channels,
out_channel,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.conv_layer = ConvModule(
in_channels=in_channels,
out_channels=out_channel,
kernel_size=3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.atten_conv_layer = nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
ConvModule(
in_channels=out_channel,
out_channels=out_channel,
kernel_size=1,
bias=False,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None), nn.Sigmoid())
def forward(self, x):
x = self.conv_layer(x)
x_atten = self.atten_conv_layer(x)
x_out = x * x_atten
return x_out
class ContextPath(BaseModule):
"""Context Path to provide sufficient receptive field.
Args:
backbone_cfg:(dict): Config of backbone of
Context Path.
context_channels (Tuple[int]): The number of channel numbers
of various modules in Context Path.
Default: (128, 256, 512).
align_corners (bool, optional): The align_corners argument of
resize operation. Default: False.
Returns:
x_16_up, x_32_up (torch.Tensor, torch.Tensor): Two feature maps
undergoing upsampling from 1/16 and 1/32 downsampling
feature maps. These two feature maps are used for Feature
Fusion Module and Auxiliary Head.
"""
def __init__(self,
backbone_cfg,
context_channels=(128, 256, 512),
align_corners=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
assert len(context_channels) == 3, 'Length of input channels \
of Context Path must be 3!'
self.backbone = MODELS.build(backbone_cfg)
self.align_corners = align_corners
self.arm16 = AttentionRefinementModule(context_channels[1],
context_channels[0])
self.arm32 = AttentionRefinementModule(context_channels[2],
context_channels[0])
self.conv_head32 = ConvModule(
in_channels=context_channels[0],
out_channels=context_channels[0],
kernel_size=3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.conv_head16 = ConvModule(
in_channels=context_channels[0],
out_channels=context_channels[0],
kernel_size=3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.gap_conv = nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
ConvModule(
in_channels=context_channels[2],
out_channels=context_channels[0],
kernel_size=1,
stride=1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
def forward(self, x):
x_4, x_8, x_16, x_32 = self.backbone(x)
x_gap = self.gap_conv(x_32)
x_32_arm = self.arm32(x_32)
x_32_sum = x_32_arm + x_gap
x_32_up = resize(input=x_32_sum, size=x_16.shape[2:], mode='nearest')
x_32_up = self.conv_head32(x_32_up)
x_16_arm = self.arm16(x_16)
x_16_sum = x_16_arm + x_32_up
x_16_up = resize(input=x_16_sum, size=x_8.shape[2:], mode='nearest')
x_16_up = self.conv_head16(x_16_up)
return x_16_up, x_32_up
class FeatureFusionModule(BaseModule):
"""Feature Fusion Module to fuse low level output feature of Spatial Path
and high level output feature of Context Path.
Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
Returns:
x_out (torch.Tensor): Feature map of Feature Fusion Module.
"""
def __init__(self,
in_channels,
out_channels,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.conv1 = ConvModule(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.gap = nn.AdaptiveAvgPool2d((1, 1))
self.conv_atten = nn.Sequential(
ConvModule(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=False,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg), nn.Sigmoid())
def forward(self, x_sp, x_cp):
x_concat = torch.cat([x_sp, x_cp], dim=1)
x_fuse = self.conv1(x_concat)
x_atten = self.gap(x_fuse)
# Note: No BN and more 1x1 conv in paper.
x_atten = self.conv_atten(x_atten)
x_atten = x_fuse * x_atten
x_out = x_atten + x_fuse
return x_out
@MODELS.register_module()
class BiSeNetV1(BaseModule):
"""BiSeNetV1 backbone.
This backbone is the implementation of `BiSeNet: Bilateral
Segmentation Network for Real-time Semantic
Segmentation <https://arxiv.org/abs/1808.00897>`_.
Args:
backbone_cfg:(dict): Config of backbone of
Context Path.
in_channels (int): The number of channels of input
image. Default: 3.
spatial_channels (Tuple[int]): Size of channel numbers of
various layers in Spatial Path.
Default: (64, 64, 64, 128).
context_channels (Tuple[int]): Size of channel numbers of
various modules in Context Path.
Default: (128, 256, 512).
out_indices (Tuple[int] | int, optional): Output from which stages.
Default: (0, 1, 2).
align_corners (bool, optional): The align_corners argument of
resize operation in Bilateral Guided Aggregation Layer.
Default: False.
out_channels(int): The number of channels of output.
It must be the same with `in_channels` of decode_head.
Default: 256.
"""
def __init__(self,
backbone_cfg,
in_channels=3,
spatial_channels=(64, 64, 64, 128),
context_channels=(128, 256, 512),
out_indices=(0, 1, 2),
align_corners=False,
out_channels=256,
conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
assert len(spatial_channels) == 4, 'Length of input channels \
of Spatial Path must be 4!'
assert len(context_channels) == 3, 'Length of input channels \
of Context Path must be 3!'
self.out_indices = out_indices
self.align_corners = align_corners
self.context_path = ContextPath(backbone_cfg, context_channels,
self.align_corners)
self.spatial_path = SpatialPath(in_channels, spatial_channels)
self.ffm = FeatureFusionModule(context_channels[1], out_channels)
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
def forward(self, x):
# stole refactoring code from Coin Cheung, thanks
x_context8, x_context16 = self.context_path(x)
x_spatial = self.spatial_path(x)
x_fuse = self.ffm(x_spatial, x_context8)
outs = [x_fuse, x_context8, x_context16]
outs = [outs[i] for i in self.out_indices]
return tuple(outs)

View File

@@ -0,0 +1,622 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule,
build_activation_layer, build_norm_layer)
from mmengine.model import BaseModule
from mmseg.registry import MODELS
from ..utils import resize
class DetailBranch(BaseModule):
"""Detail Branch with wide channels and shallow layers to capture low-level
details and generate high-resolution feature representation.
Args:
detail_channels (Tuple[int]): Size of channel numbers of each stage
in Detail Branch, in paper it has 3 stages.
Default: (64, 64, 128).
in_channels (int): Number of channels of input image. Default: 3.
conv_cfg (dict | None): Config of conv layers.
Default: None.
norm_cfg (dict | None): Config of norm layers.
Default: dict(type='BN').
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Returns:
x (torch.Tensor): Feature map of Detail Branch.
"""
def __init__(self,
detail_channels=(64, 64, 128),
in_channels=3,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
detail_branch = []
for i in range(len(detail_channels)):
if i == 0:
detail_branch.append(
nn.Sequential(
ConvModule(
in_channels=in_channels,
out_channels=detail_channels[i],
kernel_size=3,
stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
ConvModule(
in_channels=detail_channels[i],
out_channels=detail_channels[i],
kernel_size=3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)))
else:
detail_branch.append(
nn.Sequential(
ConvModule(
in_channels=detail_channels[i - 1],
out_channels=detail_channels[i],
kernel_size=3,
stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
ConvModule(
in_channels=detail_channels[i],
out_channels=detail_channels[i],
kernel_size=3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
ConvModule(
in_channels=detail_channels[i],
out_channels=detail_channels[i],
kernel_size=3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)))
self.detail_branch = nn.ModuleList(detail_branch)
def forward(self, x):
for stage in self.detail_branch:
x = stage(x)
return x
class StemBlock(BaseModule):
"""Stem Block at the beginning of Semantic Branch.
Args:
in_channels (int): Number of input channels.
Default: 3.
out_channels (int): Number of output channels.
Default: 16.
conv_cfg (dict | None): Config of conv layers.
Default: None.
norm_cfg (dict | None): Config of norm layers.
Default: dict(type='BN').
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Returns:
x (torch.Tensor): First feature map in Semantic Branch.
"""
def __init__(self,
in_channels=3,
out_channels=16,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.conv_first = ConvModule(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.convs = nn.Sequential(
ConvModule(
in_channels=out_channels,
out_channels=out_channels // 2,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
ConvModule(
in_channels=out_channels // 2,
out_channels=out_channels,
kernel_size=3,
stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
self.pool = nn.MaxPool2d(
kernel_size=3, stride=2, padding=1, ceil_mode=False)
self.fuse_last = ConvModule(
in_channels=out_channels * 2,
out_channels=out_channels,
kernel_size=3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
def forward(self, x):
x = self.conv_first(x)
x_left = self.convs(x)
x_right = self.pool(x)
x = self.fuse_last(torch.cat([x_left, x_right], dim=1))
return x
class GELayer(BaseModule):
"""Gather-and-Expansion Layer.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
exp_ratio (int): Expansion ratio for middle channels.
Default: 6.
stride (int): Stride of GELayer. Default: 1
conv_cfg (dict | None): Config of conv layers.
Default: None.
norm_cfg (dict | None): Config of norm layers.
Default: dict(type='BN').
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Returns:
x (torch.Tensor): Intermediate feature map in
Semantic Branch.
"""
def __init__(self,
in_channels,
out_channels,
exp_ratio=6,
stride=1,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
mid_channel = in_channels * exp_ratio
self.conv1 = ConvModule(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
if stride == 1:
self.dwconv = nn.Sequential(
# ReLU in ConvModule not shown in paper
ConvModule(
in_channels=in_channels,
out_channels=mid_channel,
kernel_size=3,
stride=stride,
padding=1,
groups=in_channels,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
self.shortcut = None
else:
self.dwconv = nn.Sequential(
ConvModule(
in_channels=in_channels,
out_channels=mid_channel,
kernel_size=3,
stride=stride,
padding=1,
groups=in_channels,
bias=False,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None),
# ReLU in ConvModule not shown in paper
ConvModule(
in_channels=mid_channel,
out_channels=mid_channel,
kernel_size=3,
stride=1,
padding=1,
groups=mid_channel,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
)
self.shortcut = nn.Sequential(
DepthwiseSeparableConvModule(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
padding=1,
dw_norm_cfg=norm_cfg,
dw_act_cfg=None,
pw_norm_cfg=norm_cfg,
pw_act_cfg=None,
))
self.conv2 = nn.Sequential(
ConvModule(
in_channels=mid_channel,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=False,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None,
))
self.act = build_activation_layer(act_cfg)
def forward(self, x):
identity = x
x = self.conv1(x)
x = self.dwconv(x)
x = self.conv2(x)
if self.shortcut is not None:
shortcut = self.shortcut(identity)
x = x + shortcut
else:
x = x + identity
x = self.act(x)
return x
class CEBlock(BaseModule):
"""Context Embedding Block for large receptive filed in Semantic Branch.
Args:
in_channels (int): Number of input channels.
Default: 3.
out_channels (int): Number of output channels.
Default: 16.
conv_cfg (dict | None): Config of conv layers.
Default: None.
norm_cfg (dict | None): Config of norm layers.
Default: dict(type='BN').
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Returns:
x (torch.Tensor): Last feature map in Semantic Branch.
"""
def __init__(self,
in_channels=3,
out_channels=16,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.in_channels = in_channels
self.out_channels = out_channels
self.gap = nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
build_norm_layer(norm_cfg, self.in_channels)[1])
self.conv_gap = ConvModule(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
# Note: in paper here is naive conv2d, no bn-relu
self.conv_last = ConvModule(
in_channels=self.out_channels,
out_channels=self.out_channels,
kernel_size=3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
def forward(self, x):
identity = x
x = self.gap(x)
x = self.conv_gap(x)
x = identity + x
x = self.conv_last(x)
return x
class SemanticBranch(BaseModule):
"""Semantic Branch which is lightweight with narrow channels and deep
layers to obtain high-level semantic context.
Args:
semantic_channels(Tuple[int]): Size of channel numbers of
various stages in Semantic Branch.
Default: (16, 32, 64, 128).
in_channels (int): Number of channels of input image. Default: 3.
exp_ratio (int): Expansion ratio for middle channels.
Default: 6.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Returns:
semantic_outs (List[torch.Tensor]): List of several feature maps
for auxiliary heads (Booster) and Bilateral
Guided Aggregation Layer.
"""
def __init__(self,
semantic_channels=(16, 32, 64, 128),
in_channels=3,
exp_ratio=6,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.in_channels = in_channels
self.semantic_channels = semantic_channels
self.semantic_stages = []
for i in range(len(semantic_channels)):
stage_name = f'stage{i + 1}'
self.semantic_stages.append(stage_name)
if i == 0:
self.add_module(
stage_name,
StemBlock(self.in_channels, semantic_channels[i]))
elif i == (len(semantic_channels) - 1):
self.add_module(
stage_name,
nn.Sequential(
GELayer(semantic_channels[i - 1], semantic_channels[i],
exp_ratio, 2),
GELayer(semantic_channels[i], semantic_channels[i],
exp_ratio, 1),
GELayer(semantic_channels[i], semantic_channels[i],
exp_ratio, 1),
GELayer(semantic_channels[i], semantic_channels[i],
exp_ratio, 1)))
else:
self.add_module(
stage_name,
nn.Sequential(
GELayer(semantic_channels[i - 1], semantic_channels[i],
exp_ratio, 2),
GELayer(semantic_channels[i], semantic_channels[i],
exp_ratio, 1)))
self.add_module(f'stage{len(semantic_channels)}_CEBlock',
CEBlock(semantic_channels[-1], semantic_channels[-1]))
self.semantic_stages.append(f'stage{len(semantic_channels)}_CEBlock')
def forward(self, x):
semantic_outs = []
for stage_name in self.semantic_stages:
semantic_stage = getattr(self, stage_name)
x = semantic_stage(x)
semantic_outs.append(x)
return semantic_outs
class BGALayer(BaseModule):
"""Bilateral Guided Aggregation Layer to fuse the complementary information
from both Detail Branch and Semantic Branch.
Args:
out_channels (int): Number of output channels.
Default: 128.
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
conv_cfg (dict | None): Config of conv layers.
Default: None.
norm_cfg (dict | None): Config of norm layers.
Default: dict(type='BN').
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Returns:
output (torch.Tensor): Output feature map for Segment heads.
"""
def __init__(self,
out_channels=128,
align_corners=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.out_channels = out_channels
self.align_corners = align_corners
self.detail_dwconv = nn.Sequential(
DepthwiseSeparableConvModule(
in_channels=self.out_channels,
out_channels=self.out_channels,
kernel_size=3,
stride=1,
padding=1,
dw_norm_cfg=norm_cfg,
dw_act_cfg=None,
pw_norm_cfg=None,
pw_act_cfg=None,
))
self.detail_down = nn.Sequential(
ConvModule(
in_channels=self.out_channels,
out_channels=self.out_channels,
kernel_size=3,
stride=2,
padding=1,
bias=False,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None),
nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False))
self.semantic_conv = nn.Sequential(
ConvModule(
in_channels=self.out_channels,
out_channels=self.out_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None))
self.semantic_dwconv = nn.Sequential(
DepthwiseSeparableConvModule(
in_channels=self.out_channels,
out_channels=self.out_channels,
kernel_size=3,
stride=1,
padding=1,
dw_norm_cfg=norm_cfg,
dw_act_cfg=None,
pw_norm_cfg=None,
pw_act_cfg=None,
))
self.conv = ConvModule(
in_channels=self.out_channels,
out_channels=self.out_channels,
kernel_size=3,
stride=1,
padding=1,
inplace=True,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
)
def forward(self, x_d, x_s):
detail_dwconv = self.detail_dwconv(x_d)
detail_down = self.detail_down(x_d)
semantic_conv = self.semantic_conv(x_s)
semantic_dwconv = self.semantic_dwconv(x_s)
semantic_conv = resize(
input=semantic_conv,
size=detail_dwconv.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
fuse_1 = detail_dwconv * torch.sigmoid(semantic_conv)
fuse_2 = detail_down * torch.sigmoid(semantic_dwconv)
fuse_2 = resize(
input=fuse_2,
size=fuse_1.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
output = self.conv(fuse_1 + fuse_2)
return output
@MODELS.register_module()
class BiSeNetV2(BaseModule):
"""BiSeNetV2: Bilateral Network with Guided Aggregation for
Real-time Semantic Segmentation.
This backbone is the implementation of
`BiSeNetV2 <https://arxiv.org/abs/2004.02147>`_.
Args:
in_channels (int): Number of channel of input image. Default: 3.
detail_channels (Tuple[int], optional): Channels of each stage
in Detail Branch. Default: (64, 64, 128).
semantic_channels (Tuple[int], optional): Channels of each stage
in Semantic Branch. Default: (16, 32, 64, 128).
See Table 1 and Figure 3 of paper for more details.
semantic_expansion_ratio (int, optional): The expansion factor
expanding channel number of middle channels in Semantic Branch.
Default: 6.
bga_channels (int, optional): Number of middle channels in
Bilateral Guided Aggregation Layer. Default: 128.
out_indices (Tuple[int] | int, optional): Output from which stages.
Default: (0, 1, 2, 3, 4).
align_corners (bool, optional): The align_corners argument of
resize operation in Bilateral Guided Aggregation Layer.
Default: False.
conv_cfg (dict | None): Config of conv layers.
Default: None.
norm_cfg (dict | None): Config of norm layers.
Default: dict(type='BN').
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(self,
in_channels=3,
detail_channels=(64, 64, 128),
semantic_channels=(16, 32, 64, 128),
semantic_expansion_ratio=6,
bga_channels=128,
out_indices=(0, 1, 2, 3, 4),
align_corners=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
init_cfg=None):
if init_cfg is None:
init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
]
super().__init__(init_cfg=init_cfg)
self.in_channels = in_channels
self.out_indices = out_indices
self.detail_channels = detail_channels
self.semantic_channels = semantic_channels
self.semantic_expansion_ratio = semantic_expansion_ratio
self.bga_channels = bga_channels
self.align_corners = align_corners
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.detail = DetailBranch(self.detail_channels, self.in_channels)
self.semantic = SemanticBranch(self.semantic_channels,
self.in_channels,
self.semantic_expansion_ratio)
self.bga = BGALayer(self.bga_channels, self.align_corners)
def forward(self, x):
# stole refactoring code from Coin Cheung, thanks
x_detail = self.detail(x)
x_semantic_lst = self.semantic(x)
x_head = self.bga(x_detail, x_semantic_lst[-1])
outs = [x_head] + x_semantic_lst[:-1]
outs = [outs[i] for i in self.out_indices]
return tuple(outs)

View File

@@ -0,0 +1,372 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer
from mmengine.model import BaseModule
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmseg.registry import MODELS
class GlobalContextExtractor(nn.Module):
"""Global Context Extractor for CGNet.
This class is employed to refine the joint feature of both local feature
and surrounding context.
Args:
channel (int): Number of input feature channels.
reduction (int): Reductions for global context extractor. Default: 16.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
"""
def __init__(self, channel, reduction=16, with_cp=False):
super().__init__()
self.channel = channel
self.reduction = reduction
assert reduction >= 1 and channel >= reduction
self.with_cp = with_cp
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel), nn.Sigmoid())
def forward(self, x):
def _inner_forward(x):
num_batch, num_channel = x.size()[:2]
y = self.avg_pool(x).view(num_batch, num_channel)
y = self.fc(y).view(num_batch, num_channel, 1, 1)
return x * y
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
return out
class ContextGuidedBlock(nn.Module):
"""Context Guided Block for CGNet.
This class consists of four components: local feature extractor,
surrounding feature extractor, joint feature extractor and global
context extractor.
Args:
in_channels (int): Number of input feature channels.
out_channels (int): Number of output feature channels.
dilation (int): Dilation rate for surrounding context extractor.
Default: 2.
reduction (int): Reduction for global context extractor. Default: 16.
skip_connect (bool): Add input to output or not. Default: True.
downsample (bool): Downsample the input to 1/2 or not. Default: False.
conv_cfg (dict): Config dict for convolution layer.
Default: None, which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN', requires_grad=True).
act_cfg (dict): Config dict for activation layer.
Default: dict(type='PReLU').
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
"""
def __init__(self,
in_channels,
out_channels,
dilation=2,
reduction=16,
skip_connect=True,
downsample=False,
conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True),
act_cfg=dict(type='PReLU'),
with_cp=False):
super().__init__()
self.with_cp = with_cp
self.downsample = downsample
channels = out_channels if downsample else out_channels // 2
if 'type' in act_cfg and act_cfg['type'] == 'PReLU':
act_cfg['num_parameters'] = channels
kernel_size = 3 if downsample else 1
stride = 2 if downsample else 1
padding = (kernel_size - 1) // 2
self.conv1x1 = ConvModule(
in_channels,
channels,
kernel_size,
stride,
padding,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.f_loc = build_conv_layer(
conv_cfg,
channels,
channels,
kernel_size=3,
padding=1,
groups=channels,
bias=False)
self.f_sur = build_conv_layer(
conv_cfg,
channels,
channels,
kernel_size=3,
padding=dilation,
groups=channels,
dilation=dilation,
bias=False)
self.bn = build_norm_layer(norm_cfg, 2 * channels)[1]
self.activate = nn.PReLU(2 * channels)
if downsample:
self.bottleneck = build_conv_layer(
conv_cfg,
2 * channels,
out_channels,
kernel_size=1,
bias=False)
self.skip_connect = skip_connect and not downsample
self.f_glo = GlobalContextExtractor(out_channels, reduction, with_cp)
def forward(self, x):
def _inner_forward(x):
out = self.conv1x1(x)
loc = self.f_loc(out)
sur = self.f_sur(out)
joi_feat = torch.cat([loc, sur], 1) # the joint feature
joi_feat = self.bn(joi_feat)
joi_feat = self.activate(joi_feat)
if self.downsample:
joi_feat = self.bottleneck(joi_feat) # channel = out_channels
# f_glo is employed to refine the joint feature
out = self.f_glo(joi_feat)
if self.skip_connect:
return x + out
else:
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
return out
class InputInjection(nn.Module):
"""Downsampling module for CGNet."""
def __init__(self, num_downsampling):
super().__init__()
self.pool = nn.ModuleList()
for i in range(num_downsampling):
self.pool.append(nn.AvgPool2d(3, stride=2, padding=1))
def forward(self, x):
for pool in self.pool:
x = pool(x)
return x
@MODELS.register_module()
class CGNet(BaseModule):
"""CGNet backbone.
This backbone is the implementation of `A Light-weight Context Guided
Network for Semantic Segmentation <https://arxiv.org/abs/1811.08201>`_.
Args:
in_channels (int): Number of input image channels. Normally 3.
num_channels (tuple[int]): Numbers of feature channels at each stages.
Default: (32, 64, 128).
num_blocks (tuple[int]): Numbers of CG blocks at stage 1 and stage 2.
Default: (3, 21).
dilations (tuple[int]): Dilation rate for surrounding context
extractors at stage 1 and stage 2. Default: (2, 4).
reductions (tuple[int]): Reductions for global context extractors at
stage 1 and stage 2. Default: (8, 16).
conv_cfg (dict): Config dict for convolution layer.
Default: None, which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN', requires_grad=True).
act_cfg (dict): Config dict for activation layer.
Default: dict(type='PReLU').
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
pretrained (str, optional): model pretrained path. Default: None
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
"""
def __init__(self,
in_channels=3,
num_channels=(32, 64, 128),
num_blocks=(3, 21),
dilations=(2, 4),
reductions=(8, 16),
conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True),
act_cfg=dict(type='PReLU'),
norm_eval=False,
with_cp=False,
pretrained=None,
init_cfg=None):
super().__init__(init_cfg)
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer=['Conv2d', 'Linear']),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm']),
dict(type='Constant', val=0, layer='PReLU')
]
else:
raise TypeError('pretrained must be a str or None')
self.in_channels = in_channels
self.num_channels = num_channels
assert isinstance(self.num_channels, tuple) and len(
self.num_channels) == 3
self.num_blocks = num_blocks
assert isinstance(self.num_blocks, tuple) and len(self.num_blocks) == 2
self.dilations = dilations
assert isinstance(self.dilations, tuple) and len(self.dilations) == 2
self.reductions = reductions
assert isinstance(self.reductions, tuple) and len(self.reductions) == 2
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
if 'type' in self.act_cfg and self.act_cfg['type'] == 'PReLU':
self.act_cfg['num_parameters'] = num_channels[0]
self.norm_eval = norm_eval
self.with_cp = with_cp
cur_channels = in_channels
self.stem = nn.ModuleList()
for i in range(3):
self.stem.append(
ConvModule(
cur_channels,
num_channels[0],
3,
2 if i == 0 else 1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
cur_channels = num_channels[0]
self.inject_2x = InputInjection(1) # down-sample for Input, factor=2
self.inject_4x = InputInjection(2) # down-sample for Input, factor=4
cur_channels += in_channels
self.norm_prelu_0 = nn.Sequential(
build_norm_layer(norm_cfg, cur_channels)[1],
nn.PReLU(cur_channels))
# stage 1
self.level1 = nn.ModuleList()
for i in range(num_blocks[0]):
self.level1.append(
ContextGuidedBlock(
cur_channels if i == 0 else num_channels[1],
num_channels[1],
dilations[0],
reductions[0],
downsample=(i == 0),
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
with_cp=with_cp)) # CG block
cur_channels = 2 * num_channels[1] + in_channels
self.norm_prelu_1 = nn.Sequential(
build_norm_layer(norm_cfg, cur_channels)[1],
nn.PReLU(cur_channels))
# stage 2
self.level2 = nn.ModuleList()
for i in range(num_blocks[1]):
self.level2.append(
ContextGuidedBlock(
cur_channels if i == 0 else num_channels[2],
num_channels[2],
dilations[1],
reductions[1],
downsample=(i == 0),
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
with_cp=with_cp)) # CG block
cur_channels = 2 * num_channels[2]
self.norm_prelu_2 = nn.Sequential(
build_norm_layer(norm_cfg, cur_channels)[1],
nn.PReLU(cur_channels))
def forward(self, x):
output = []
# stage 0
inp_2x = self.inject_2x(x)
inp_4x = self.inject_4x(x)
for layer in self.stem:
x = layer(x)
x = self.norm_prelu_0(torch.cat([x, inp_2x], 1))
output.append(x)
# stage 1
for i, layer in enumerate(self.level1):
x = layer(x)
if i == 0:
down1 = x
x = self.norm_prelu_1(torch.cat([x, down1, inp_4x], 1))
output.append(x)
# stage 2
for i, layer in enumerate(self.level2):
x = layer(x)
if i == 0:
down2 = x
x = self.norm_prelu_2(torch.cat([down2, x], 1))
output.append(x)
return output
def train(self, mode=True):
"""Convert the model into training mode will keeping the normalization
layer freezed."""
super().train(mode)
if mode and self.norm_eval:
for m in self.modules():
# trick: eval have effect on BatchNorm only
if isinstance(m, _BatchNorm):
m.eval()

View File

@@ -0,0 +1,222 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer
from mmengine.model import BaseModule
from mmseg.models.utils import DAPPM, BasicBlock, Bottleneck, resize
from mmseg.registry import MODELS
from mmseg.utils import OptConfigType
@MODELS.register_module()
class DDRNet(BaseModule):
"""DDRNet backbone.
This backbone is the implementation of `Deep Dual-resolution Networks for
Real-time and Accurate Semantic Segmentation of Road Scenes
<http://arxiv.org/abs/2101.06085>`_.
Modified from https://github.com/ydhongHIT/DDRNet.
Args:
in_channels (int): Number of input image channels. Default: 3.
channels: (int): The base channels of DDRNet. Default: 32.
ppm_channels (int): The channels of PPM module. Default: 128.
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
norm_cfg (dict): Config dict to build norm layer.
Default: dict(type='BN', requires_grad=True).
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU', inplace=True).
init_cfg (dict, optional): Initialization config dict.
Default: None.
"""
def __init__(self,
in_channels: int = 3,
channels: int = 32,
ppm_channels: int = 128,
align_corners: bool = False,
norm_cfg: OptConfigType = dict(type='BN', requires_grad=True),
act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
init_cfg: OptConfigType = None):
super().__init__(init_cfg)
self.in_channels = in_channels
self.ppm_channels = ppm_channels
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.align_corners = align_corners
# stage 0-2
self.stem = self._make_stem_layer(in_channels, channels, num_blocks=2)
self.relu = nn.ReLU()
# low resolution(context) branch
self.context_branch_layers = nn.ModuleList()
for i in range(3):
self.context_branch_layers.append(
self._make_layer(
block=BasicBlock if i < 2 else Bottleneck,
inplanes=channels * 2**(i + 1),
planes=channels * 8 if i > 0 else channels * 4,
num_blocks=2 if i < 2 else 1,
stride=2))
# bilateral fusion
self.compression_1 = ConvModule(
channels * 4,
channels * 2,
kernel_size=1,
norm_cfg=self.norm_cfg,
act_cfg=None)
self.down_1 = ConvModule(
channels * 2,
channels * 4,
kernel_size=3,
stride=2,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=None)
self.compression_2 = ConvModule(
channels * 8,
channels * 2,
kernel_size=1,
norm_cfg=self.norm_cfg,
act_cfg=None)
self.down_2 = nn.Sequential(
ConvModule(
channels * 2,
channels * 4,
kernel_size=3,
stride=2,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
ConvModule(
channels * 4,
channels * 8,
kernel_size=3,
stride=2,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=None))
# high resolution(spatial) branch
self.spatial_branch_layers = nn.ModuleList()
for i in range(3):
self.spatial_branch_layers.append(
self._make_layer(
block=BasicBlock if i < 2 else Bottleneck,
inplanes=channels * 2,
planes=channels * 2,
num_blocks=2 if i < 2 else 1,
))
self.spp = DAPPM(
channels * 16, ppm_channels, channels * 4, num_scales=5)
def _make_stem_layer(self, in_channels, channels, num_blocks):
layers = [
ConvModule(
in_channels,
channels,
kernel_size=3,
stride=2,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
ConvModule(
channels,
channels,
kernel_size=3,
stride=2,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
]
layers.extend([
self._make_layer(BasicBlock, channels, channels, num_blocks),
nn.ReLU(),
self._make_layer(
BasicBlock, channels, channels * 2, num_blocks, stride=2),
nn.ReLU(),
])
return nn.Sequential(*layers)
def _make_layer(self, block, inplanes, planes, num_blocks, stride=1):
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(
inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias=False),
build_norm_layer(self.norm_cfg, planes * block.expansion)[1])
layers = [
block(
in_channels=inplanes,
channels=planes,
stride=stride,
downsample=downsample)
]
inplanes = planes * block.expansion
for i in range(1, num_blocks):
layers.append(
block(
in_channels=inplanes,
channels=planes,
stride=1,
norm_cfg=self.norm_cfg,
act_cfg_out=None if i == num_blocks - 1 else self.act_cfg))
return nn.Sequential(*layers)
def forward(self, x):
"""Forward function."""
out_size = (x.shape[-2] // 8, x.shape[-1] // 8)
# stage 0-2
x = self.stem(x)
# stage3
x_c = self.context_branch_layers[0](x)
x_s = self.spatial_branch_layers[0](x)
comp_c = self.compression_1(self.relu(x_c))
x_c += self.down_1(self.relu(x_s))
x_s += resize(
comp_c,
size=out_size,
mode='bilinear',
align_corners=self.align_corners)
if self.training:
temp_context = x_s.clone()
# stage4
x_c = self.context_branch_layers[1](self.relu(x_c))
x_s = self.spatial_branch_layers[1](self.relu(x_s))
comp_c = self.compression_2(self.relu(x_c))
x_c += self.down_2(self.relu(x_s))
x_s += resize(
comp_c,
size=out_size,
mode='bilinear',
align_corners=self.align_corners)
# stage5
x_s = self.spatial_branch_layers[2](self.relu(x_s))
x_c = self.context_branch_layers[2](self.relu(x_c))
x_c = self.spp(x_c)
x_c = resize(
x_c,
size=out_size,
mode='bilinear',
align_corners=self.align_corners)
return (temp_context, x_s + x_c) if self.training else x_s + x_c

View File

@@ -0,0 +1,329 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
from mmengine.model import BaseModule
from mmseg.registry import MODELS
from ..utils import resize
class DownsamplerBlock(BaseModule):
"""Downsampler block of ERFNet.
This module is a little different from basical ConvModule.
The features from Conv and MaxPool layers are
concatenated before BatchNorm.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
conv_cfg (dict | None): Config of conv layers.
Default: None.
norm_cfg (dict | None): Config of norm layers.
Default: dict(type='BN').
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(self,
in_channels,
out_channels,
conv_cfg=None,
norm_cfg=dict(type='BN', eps=1e-3),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.conv = build_conv_layer(
self.conv_cfg,
in_channels,
out_channels - in_channels,
kernel_size=3,
stride=2,
padding=1)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.bn = build_norm_layer(self.norm_cfg, out_channels)[1]
self.act = build_activation_layer(self.act_cfg)
def forward(self, input):
conv_out = self.conv(input)
pool_out = self.pool(input)
pool_out = resize(
input=pool_out,
size=conv_out.size()[2:],
mode='bilinear',
align_corners=False)
output = torch.cat([conv_out, pool_out], 1)
output = self.bn(output)
output = self.act(output)
return output
class NonBottleneck1d(BaseModule):
"""Non-bottleneck block of ERFNet.
Args:
channels (int): Number of channels in Non-bottleneck block.
drop_rate (float): Probability of an element to be zeroed.
Default 0.
dilation (int): Dilation rate for last two conv layers.
Default 1.
num_conv_layer (int): Number of 3x1 and 1x3 convolution layers.
Default 2.
conv_cfg (dict | None): Config of conv layers.
Default: None.
norm_cfg (dict | None): Config of norm layers.
Default: dict(type='BN').
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(self,
channels,
drop_rate=0,
dilation=1,
num_conv_layer=2,
conv_cfg=None,
norm_cfg=dict(type='BN', eps=1e-3),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.act = build_activation_layer(self.act_cfg)
self.convs_layers = nn.ModuleList()
for conv_layer in range(num_conv_layer):
first_conv_padding = (1, 0) if conv_layer == 0 else (dilation, 0)
first_conv_dilation = 1 if conv_layer == 0 else (dilation, 1)
second_conv_padding = (0, 1) if conv_layer == 0 else (0, dilation)
second_conv_dilation = 1 if conv_layer == 0 else (1, dilation)
self.convs_layers.append(
build_conv_layer(
self.conv_cfg,
channels,
channels,
kernel_size=(3, 1),
stride=1,
padding=first_conv_padding,
bias=True,
dilation=first_conv_dilation))
self.convs_layers.append(self.act)
self.convs_layers.append(
build_conv_layer(
self.conv_cfg,
channels,
channels,
kernel_size=(1, 3),
stride=1,
padding=second_conv_padding,
bias=True,
dilation=second_conv_dilation))
self.convs_layers.append(
build_norm_layer(self.norm_cfg, channels)[1])
if conv_layer == 0:
self.convs_layers.append(self.act)
else:
self.convs_layers.append(nn.Dropout(p=drop_rate))
def forward(self, input):
output = input
for conv in self.convs_layers:
output = conv(output)
output = self.act(output + input)
return output
class UpsamplerBlock(BaseModule):
"""Upsampler block of ERFNet.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
conv_cfg (dict | None): Config of conv layers.
Default: None.
norm_cfg (dict | None): Config of norm layers.
Default: dict(type='BN').
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(self,
in_channels,
out_channels,
conv_cfg=None,
norm_cfg=dict(type='BN', eps=1e-3),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.conv = nn.ConvTranspose2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
bias=True)
self.bn = build_norm_layer(self.norm_cfg, out_channels)[1]
self.act = build_activation_layer(self.act_cfg)
def forward(self, input):
output = self.conv(input)
output = self.bn(output)
output = self.act(output)
return output
@MODELS.register_module()
class ERFNet(BaseModule):
"""ERFNet backbone.
This backbone is the implementation of `ERFNet: Efficient Residual
Factorized ConvNet for Real-time SemanticSegmentation
<https://ieeexplore.ieee.org/document/8063438>`_.
Args:
in_channels (int): The number of channels of input
image. Default: 3.
enc_downsample_channels (Tuple[int]): Size of channel
numbers of various Downsampler block in encoder.
Default: (16, 64, 128).
enc_stage_non_bottlenecks (Tuple[int]): Number of stages of
Non-bottleneck block in encoder.
Default: (5, 8).
enc_non_bottleneck_dilations (Tuple[int]): Dilation rate of each
stage of Non-bottleneck block of encoder.
Default: (2, 4, 8, 16).
enc_non_bottleneck_channels (Tuple[int]): Size of channel
numbers of various Non-bottleneck block in encoder.
Default: (64, 128).
dec_upsample_channels (Tuple[int]): Size of channel numbers of
various Deconvolution block in decoder.
Default: (64, 16).
dec_stages_non_bottleneck (Tuple[int]): Number of stages of
Non-bottleneck block in decoder.
Default: (2, 2).
dec_non_bottleneck_channels (Tuple[int]): Size of channel
numbers of various Non-bottleneck block in decoder.
Default: (64, 16).
drop_rate (float): Probability of an element to be zeroed.
Default 0.1.
"""
def __init__(self,
in_channels=3,
enc_downsample_channels=(16, 64, 128),
enc_stage_non_bottlenecks=(5, 8),
enc_non_bottleneck_dilations=(2, 4, 8, 16),
enc_non_bottleneck_channels=(64, 128),
dec_upsample_channels=(64, 16),
dec_stages_non_bottleneck=(2, 2),
dec_non_bottleneck_channels=(64, 16),
dropout_ratio=0.1,
conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
assert len(enc_downsample_channels) \
== len(dec_upsample_channels)+1, 'Number of downsample\
block of encoder does not \
match number of upsample block of decoder!'
assert len(enc_downsample_channels) \
== len(enc_stage_non_bottlenecks)+1, 'Number of \
downsample block of encoder does not match \
number of Non-bottleneck block of encoder!'
assert len(enc_downsample_channels) \
== len(enc_non_bottleneck_channels)+1, 'Number of \
downsample block of encoder does not match \
number of channels of Non-bottleneck block of encoder!'
assert enc_stage_non_bottlenecks[-1] \
% len(enc_non_bottleneck_dilations) == 0, 'Number of \
Non-bottleneck block of encoder does not match \
number of Non-bottleneck block of encoder!'
assert len(dec_upsample_channels) \
== len(dec_stages_non_bottleneck), 'Number of \
upsample block of decoder does not match \
number of Non-bottleneck block of decoder!'
assert len(dec_stages_non_bottleneck) \
== len(dec_non_bottleneck_channels), 'Number of \
Non-bottleneck block of decoder does not match \
number of channels of Non-bottleneck block of decoder!'
self.in_channels = in_channels
self.enc_downsample_channels = enc_downsample_channels
self.enc_stage_non_bottlenecks = enc_stage_non_bottlenecks
self.enc_non_bottleneck_dilations = enc_non_bottleneck_dilations
self.enc_non_bottleneck_channels = enc_non_bottleneck_channels
self.dec_upsample_channels = dec_upsample_channels
self.dec_stages_non_bottleneck = dec_stages_non_bottleneck
self.dec_non_bottleneck_channels = dec_non_bottleneck_channels
self.dropout_ratio = dropout_ratio
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.encoder.append(
DownsamplerBlock(self.in_channels, enc_downsample_channels[0]))
for i in range(len(enc_downsample_channels) - 1):
self.encoder.append(
DownsamplerBlock(enc_downsample_channels[i],
enc_downsample_channels[i + 1]))
# Last part of encoder is some dilated NonBottleneck1d blocks.
if i == len(enc_downsample_channels) - 2:
iteration_times = int(enc_stage_non_bottlenecks[-1] /
len(enc_non_bottleneck_dilations))
for j in range(iteration_times):
for k in range(len(enc_non_bottleneck_dilations)):
self.encoder.append(
NonBottleneck1d(enc_downsample_channels[-1],
self.dropout_ratio,
enc_non_bottleneck_dilations[k]))
else:
for j in range(enc_stage_non_bottlenecks[i]):
self.encoder.append(
NonBottleneck1d(enc_downsample_channels[i + 1],
self.dropout_ratio))
for i in range(len(dec_upsample_channels)):
if i == 0:
self.decoder.append(
UpsamplerBlock(enc_downsample_channels[-1],
dec_non_bottleneck_channels[i]))
else:
self.decoder.append(
UpsamplerBlock(dec_non_bottleneck_channels[i - 1],
dec_non_bottleneck_channels[i]))
for j in range(dec_stages_non_bottleneck[i]):
self.decoder.append(
NonBottleneck1d(dec_non_bottleneck_channels[i]))
def forward(self, x):
for enc in self.encoder:
x = enc(x)
for dec in self.decoder:
x = dec(x)
return [x]

View File

@@ -0,0 +1,408 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmengine.model import BaseModule
from mmseg.models.decode_heads.psp_head import PPM
from mmseg.registry import MODELS
from ..utils import InvertedResidual, resize
class LearningToDownsample(nn.Module):
"""Learning to downsample module.
Args:
in_channels (int): Number of input channels.
dw_channels (tuple[int]): Number of output channels of the first and
the second depthwise conv (dwconv) layers.
out_channels (int): Number of output channels of the whole
'learning to downsample' module.
conv_cfg (dict | None): Config of conv layers. Default: None
norm_cfg (dict | None): Config of norm layers. Default:
dict(type='BN')
act_cfg (dict): Config of activation layers. Default:
dict(type='ReLU')
dw_act_cfg (dict): In DepthwiseSeparableConvModule, activation config
of depthwise ConvModule. If it is 'default', it will be the same
as `act_cfg`. Default: None.
"""
def __init__(self,
in_channels,
dw_channels,
out_channels,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
dw_act_cfg=None):
super().__init__()
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.dw_act_cfg = dw_act_cfg
dw_channels1 = dw_channels[0]
dw_channels2 = dw_channels[1]
self.conv = ConvModule(
in_channels,
dw_channels1,
3,
stride=2,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.dsconv1 = DepthwiseSeparableConvModule(
dw_channels1,
dw_channels2,
kernel_size=3,
stride=2,
padding=1,
norm_cfg=self.norm_cfg,
dw_act_cfg=self.dw_act_cfg)
self.dsconv2 = DepthwiseSeparableConvModule(
dw_channels2,
out_channels,
kernel_size=3,
stride=2,
padding=1,
norm_cfg=self.norm_cfg,
dw_act_cfg=self.dw_act_cfg)
def forward(self, x):
x = self.conv(x)
x = self.dsconv1(x)
x = self.dsconv2(x)
return x
class GlobalFeatureExtractor(nn.Module):
"""Global feature extractor module.
Args:
in_channels (int): Number of input channels of the GFE module.
Default: 64
block_channels (tuple[int]): Tuple of ints. Each int specifies the
number of output channels of each Inverted Residual module.
Default: (64, 96, 128)
out_channels(int): Number of output channels of the GFE module.
Default: 128
expand_ratio (int): Adjusts number of channels of the hidden layer
in InvertedResidual by this amount.
Default: 6
num_blocks (tuple[int]): Tuple of ints. Each int specifies the
number of times each Inverted Residual module is repeated.
The repeated Inverted Residual modules are called a 'group'.
Default: (3, 3, 3)
strides (tuple[int]): Tuple of ints. Each int specifies
the downsampling factor of each 'group'.
Default: (2, 2, 1)
pool_scales (tuple[int]): Tuple of ints. Each int specifies
the parameter required in 'global average pooling' within PPM.
Default: (1, 2, 3, 6)
conv_cfg (dict | None): Config of conv layers. Default: None
norm_cfg (dict | None): Config of norm layers. Default:
dict(type='BN')
act_cfg (dict): Config of activation layers. Default:
dict(type='ReLU')
align_corners (bool): align_corners argument of F.interpolate.
Default: False
"""
def __init__(self,
in_channels=64,
block_channels=(64, 96, 128),
out_channels=128,
expand_ratio=6,
num_blocks=(3, 3, 3),
strides=(2, 2, 1),
pool_scales=(1, 2, 3, 6),
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
align_corners=False):
super().__init__()
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
assert len(block_channels) == len(num_blocks) == 3
self.bottleneck1 = self._make_layer(in_channels, block_channels[0],
num_blocks[0], strides[0],
expand_ratio)
self.bottleneck2 = self._make_layer(block_channels[0],
block_channels[1], num_blocks[1],
strides[1], expand_ratio)
self.bottleneck3 = self._make_layer(block_channels[1],
block_channels[2], num_blocks[2],
strides[2], expand_ratio)
self.ppm = PPM(
pool_scales,
block_channels[2],
block_channels[2] // 4,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
align_corners=align_corners)
self.out = ConvModule(
block_channels[2] * 2,
out_channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def _make_layer(self,
in_channels,
out_channels,
blocks,
stride=1,
expand_ratio=6):
layers = [
InvertedResidual(
in_channels,
out_channels,
stride,
expand_ratio,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
]
for i in range(1, blocks):
layers.append(
InvertedResidual(
out_channels,
out_channels,
1,
expand_ratio,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
return nn.Sequential(*layers)
def forward(self, x):
x = self.bottleneck1(x)
x = self.bottleneck2(x)
x = self.bottleneck3(x)
x = torch.cat([x, *self.ppm(x)], dim=1)
x = self.out(x)
return x
class FeatureFusionModule(nn.Module):
"""Feature fusion module.
Args:
higher_in_channels (int): Number of input channels of the
higher-resolution branch.
lower_in_channels (int): Number of input channels of the
lower-resolution branch.
out_channels (int): Number of output channels.
conv_cfg (dict | None): Config of conv layers. Default: None
norm_cfg (dict | None): Config of norm layers. Default:
dict(type='BN')
dwconv_act_cfg (dict): Config of activation layers in 3x3 conv.
Default: dict(type='ReLU').
conv_act_cfg (dict): Config of activation layers in the two 1x1 conv.
Default: None.
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
"""
def __init__(self,
higher_in_channels,
lower_in_channels,
out_channels,
conv_cfg=None,
norm_cfg=dict(type='BN'),
dwconv_act_cfg=dict(type='ReLU'),
conv_act_cfg=None,
align_corners=False):
super().__init__()
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.dwconv_act_cfg = dwconv_act_cfg
self.conv_act_cfg = conv_act_cfg
self.align_corners = align_corners
self.dwconv = ConvModule(
lower_in_channels,
out_channels,
3,
padding=1,
groups=out_channels,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.dwconv_act_cfg)
self.conv_lower_res = ConvModule(
out_channels,
out_channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.conv_act_cfg)
self.conv_higher_res = ConvModule(
higher_in_channels,
out_channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.conv_act_cfg)
self.relu = nn.ReLU(True)
def forward(self, higher_res_feature, lower_res_feature):
lower_res_feature = resize(
lower_res_feature,
size=higher_res_feature.size()[2:],
mode='bilinear',
align_corners=self.align_corners)
lower_res_feature = self.dwconv(lower_res_feature)
lower_res_feature = self.conv_lower_res(lower_res_feature)
higher_res_feature = self.conv_higher_res(higher_res_feature)
out = higher_res_feature + lower_res_feature
return self.relu(out)
@MODELS.register_module()
class FastSCNN(BaseModule):
"""Fast-SCNN Backbone.
This backbone is the implementation of `Fast-SCNN: Fast Semantic
Segmentation Network <https://arxiv.org/abs/1902.04502>`_.
Args:
in_channels (int): Number of input image channels. Default: 3.
downsample_dw_channels (tuple[int]): Number of output channels after
the first conv layer & the second conv layer in
Learning-To-Downsample (LTD) module.
Default: (32, 48).
global_in_channels (int): Number of input channels of
Global Feature Extractor(GFE).
Equal to number of output channels of LTD.
Default: 64.
global_block_channels (tuple[int]): Tuple of integers that describe
the output channels for each of the MobileNet-v2 bottleneck
residual blocks in GFE.
Default: (64, 96, 128).
global_block_strides (tuple[int]): Tuple of integers
that describe the strides (downsampling factors) for each of the
MobileNet-v2 bottleneck residual blocks in GFE.
Default: (2, 2, 1).
global_out_channels (int): Number of output channels of GFE.
Default: 128.
higher_in_channels (int): Number of input channels of the higher
resolution branch in FFM.
Equal to global_in_channels.
Default: 64.
lower_in_channels (int): Number of input channels of the lower
resolution branch in FFM.
Equal to global_out_channels.
Default: 128.
fusion_out_channels (int): Number of output channels of FFM.
Default: 128.
out_indices (tuple): Tuple of indices of list
[higher_res_features, lower_res_features, fusion_output].
Often set to (0,1,2) to enable aux. heads.
Default: (0, 1, 2).
conv_cfg (dict | None): Config of conv layers. Default: None
norm_cfg (dict | None): Config of norm layers. Default:
dict(type='BN')
act_cfg (dict): Config of activation layers. Default:
dict(type='ReLU')
align_corners (bool): align_corners argument of F.interpolate.
Default: False
dw_act_cfg (dict): In DepthwiseSeparableConvModule, activation config
of depthwise ConvModule. If it is 'default', it will be the same
as `act_cfg`. Default: None.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
"""
def __init__(self,
in_channels=3,
downsample_dw_channels=(32, 48),
global_in_channels=64,
global_block_channels=(64, 96, 128),
global_block_strides=(2, 2, 1),
global_out_channels=128,
higher_in_channels=64,
lower_in_channels=128,
fusion_out_channels=128,
out_indices=(0, 1, 2),
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
align_corners=False,
dw_act_cfg=None,
init_cfg=None):
super().__init__(init_cfg)
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
]
if global_in_channels != higher_in_channels:
raise AssertionError('Global Input Channels must be the same \
with Higher Input Channels!')
elif global_out_channels != lower_in_channels:
raise AssertionError('Global Output Channels must be the same \
with Lower Input Channels!')
self.in_channels = in_channels
self.downsample_dw_channels1 = downsample_dw_channels[0]
self.downsample_dw_channels2 = downsample_dw_channels[1]
self.global_in_channels = global_in_channels
self.global_block_channels = global_block_channels
self.global_block_strides = global_block_strides
self.global_out_channels = global_out_channels
self.higher_in_channels = higher_in_channels
self.lower_in_channels = lower_in_channels
self.fusion_out_channels = fusion_out_channels
self.out_indices = out_indices
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.align_corners = align_corners
self.learning_to_downsample = LearningToDownsample(
in_channels,
downsample_dw_channels,
global_in_channels,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
dw_act_cfg=dw_act_cfg)
self.global_feature_extractor = GlobalFeatureExtractor(
global_in_channels,
global_block_channels,
global_out_channels,
strides=self.global_block_strides,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
align_corners=self.align_corners)
self.feature_fusion = FeatureFusionModule(
higher_in_channels,
lower_in_channels,
fusion_out_channels,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
dwconv_act_cfg=self.act_cfg,
align_corners=self.align_corners)
def forward(self, x):
higher_res_features = self.learning_to_downsample(x)
lower_res_features = self.global_feature_extractor(higher_res_features)
fusion_output = self.feature_fusion(higher_res_features,
lower_res_features)
outs = [higher_res_features, lower_res_features, fusion_output]
outs = [outs[i] for i in self.out_indices]
return tuple(outs)

View File

@@ -0,0 +1,642 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import torch.nn as nn
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmengine.model import BaseModule, ModuleList, Sequential
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmseg.registry import MODELS
from ..utils import Upsample, resize
from .resnet import BasicBlock, Bottleneck
class HRModule(BaseModule):
"""High-Resolution Module for HRNet.
In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange
is in this module.
"""
def __init__(self,
num_branches,
blocks,
num_blocks,
in_channels,
num_channels,
multiscale_output=True,
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True),
block_init_cfg=None,
init_cfg=None):
super().__init__(init_cfg)
self.block_init_cfg = block_init_cfg
self._check_branches(num_branches, num_blocks, in_channels,
num_channels)
self.in_channels = in_channels
self.num_branches = num_branches
self.multiscale_output = multiscale_output
self.norm_cfg = norm_cfg
self.conv_cfg = conv_cfg
self.with_cp = with_cp
self.branches = self._make_branches(num_branches, blocks, num_blocks,
num_channels)
self.fuse_layers = self._make_fuse_layers()
self.relu = nn.ReLU(inplace=False)
def _check_branches(self, num_branches, num_blocks, in_channels,
num_channels):
"""Check branches configuration."""
if num_branches != len(num_blocks):
error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_BLOCKS(' \
f'{len(num_blocks)})'
raise ValueError(error_msg)
if num_branches != len(num_channels):
error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_CHANNELS(' \
f'{len(num_channels)})'
raise ValueError(error_msg)
if num_branches != len(in_channels):
error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_INCHANNELS(' \
f'{len(in_channels)})'
raise ValueError(error_msg)
def _make_one_branch(self,
branch_index,
block,
num_blocks,
num_channels,
stride=1):
"""Build one branch."""
downsample = None
if stride != 1 or \
self.in_channels[branch_index] != \
num_channels[branch_index] * block.expansion:
downsample = nn.Sequential(
build_conv_layer(
self.conv_cfg,
self.in_channels[branch_index],
num_channels[branch_index] * block.expansion,
kernel_size=1,
stride=stride,
bias=False),
build_norm_layer(self.norm_cfg, num_channels[branch_index] *
block.expansion)[1])
layers = []
layers.append(
block(
self.in_channels[branch_index],
num_channels[branch_index],
stride,
downsample=downsample,
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg,
init_cfg=self.block_init_cfg))
self.in_channels[branch_index] = \
num_channels[branch_index] * block.expansion
for i in range(1, num_blocks[branch_index]):
layers.append(
block(
self.in_channels[branch_index],
num_channels[branch_index],
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg,
init_cfg=self.block_init_cfg))
return Sequential(*layers)
def _make_branches(self, num_branches, block, num_blocks, num_channels):
"""Build multiple branch."""
branches = []
for i in range(num_branches):
branches.append(
self._make_one_branch(i, block, num_blocks, num_channels))
return ModuleList(branches)
def _make_fuse_layers(self):
"""Build fuse layer."""
if self.num_branches == 1:
return None
num_branches = self.num_branches
in_channels = self.in_channels
fuse_layers = []
num_out_branches = num_branches if self.multiscale_output else 1
for i in range(num_out_branches):
fuse_layer = []
for j in range(num_branches):
if j > i:
fuse_layer.append(
nn.Sequential(
build_conv_layer(
self.conv_cfg,
in_channels[j],
in_channels[i],
kernel_size=1,
stride=1,
padding=0,
bias=False),
build_norm_layer(self.norm_cfg, in_channels[i])[1],
# we set align_corners=False for HRNet
Upsample(
scale_factor=2**(j - i),
mode='bilinear',
align_corners=False)))
elif j == i:
fuse_layer.append(None)
else:
conv_downsamples = []
for k in range(i - j):
if k == i - j - 1:
conv_downsamples.append(
nn.Sequential(
build_conv_layer(
self.conv_cfg,
in_channels[j],
in_channels[i],
kernel_size=3,
stride=2,
padding=1,
bias=False),
build_norm_layer(self.norm_cfg,
in_channels[i])[1]))
else:
conv_downsamples.append(
nn.Sequential(
build_conv_layer(
self.conv_cfg,
in_channels[j],
in_channels[j],
kernel_size=3,
stride=2,
padding=1,
bias=False),
build_norm_layer(self.norm_cfg,
in_channels[j])[1],
nn.ReLU(inplace=False)))
fuse_layer.append(nn.Sequential(*conv_downsamples))
fuse_layers.append(nn.ModuleList(fuse_layer))
return nn.ModuleList(fuse_layers)
def forward(self, x):
"""Forward function."""
if self.num_branches == 1:
return [self.branches[0](x[0])]
for i in range(self.num_branches):
x[i] = self.branches[i](x[i])
x_fuse = []
for i in range(len(self.fuse_layers)):
y = 0
for j in range(self.num_branches):
if i == j:
y += x[j]
elif j > i:
y = y + resize(
self.fuse_layers[i][j](x[j]),
size=x[i].shape[2:],
mode='bilinear',
align_corners=False)
else:
y += self.fuse_layers[i][j](x[j])
x_fuse.append(self.relu(y))
return x_fuse
@MODELS.register_module()
class HRNet(BaseModule):
"""HRNet backbone.
This backbone is the implementation of `High-Resolution Representations
for Labeling Pixels and Regions <https://arxiv.org/abs/1904.04514>`_.
Args:
extra (dict): Detailed configuration for each stage of HRNet.
There must be 4 stages, the configuration for each stage must have
5 keys:
- num_modules (int): The number of HRModule in this stage.
- num_branches (int): The number of branches in the HRModule.
- block (str): The type of convolution block.
- num_blocks (tuple): The number of blocks in each branch.
The length must be equal to num_branches.
- num_channels (tuple): The number of channels in each branch.
The length must be equal to num_branches.
in_channels (int): Number of input image channels. Normally 3.
conv_cfg (dict): Dictionary to construct and config conv layer.
Default: None.
norm_cfg (dict): Dictionary to construct and config norm layer.
Use `BN` by default.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Default: -1.
zero_init_residual (bool): Whether to use zero init for last norm layer
in resblocks to let them behave as identity. Default: False.
multiscale_output (bool): Whether to output multi-level features
produced by multiple branches. If False, only the first level
feature will be output. Default: True.
pretrained (str, optional): Model pretrained path. Default: None.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Example:
>>> from mmseg.models import HRNet
>>> import torch
>>> extra = dict(
>>> stage1=dict(
>>> num_modules=1,
>>> num_branches=1,
>>> block='BOTTLENECK',
>>> num_blocks=(4, ),
>>> num_channels=(64, )),
>>> stage2=dict(
>>> num_modules=1,
>>> num_branches=2,
>>> block='BASIC',
>>> num_blocks=(4, 4),
>>> num_channels=(32, 64)),
>>> stage3=dict(
>>> num_modules=4,
>>> num_branches=3,
>>> block='BASIC',
>>> num_blocks=(4, 4, 4),
>>> num_channels=(32, 64, 128)),
>>> stage4=dict(
>>> num_modules=3,
>>> num_branches=4,
>>> block='BASIC',
>>> num_blocks=(4, 4, 4, 4),
>>> num_channels=(32, 64, 128, 256)))
>>> self = HRNet(extra, in_channels=1)
>>> self.eval()
>>> inputs = torch.rand(1, 1, 32, 32)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 32, 8, 8)
(1, 64, 4, 4)
(1, 128, 2, 2)
(1, 256, 1, 1)
"""
blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
def __init__(self,
extra,
in_channels=3,
conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=False,
with_cp=False,
frozen_stages=-1,
zero_init_residual=False,
multiscale_output=True,
pretrained=None,
init_cfg=None):
super().__init__(init_cfg)
self.pretrained = pretrained
self.zero_init_residual = zero_init_residual
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]
else:
raise TypeError('pretrained must be a str or None')
# Assert configurations of 4 stages are in extra
assert 'stage1' in extra and 'stage2' in extra \
and 'stage3' in extra and 'stage4' in extra
# Assert whether the length of `num_blocks` and `num_channels` are
# equal to `num_branches`
for i in range(4):
cfg = extra[f'stage{i + 1}']
assert len(cfg['num_blocks']) == cfg['num_branches'] and \
len(cfg['num_channels']) == cfg['num_branches']
self.extra = extra
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.norm_eval = norm_eval
self.with_cp = with_cp
self.frozen_stages = frozen_stages
# stem net
self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)
self.conv1 = build_conv_layer(
self.conv_cfg,
in_channels,
64,
kernel_size=3,
stride=2,
padding=1,
bias=False)
self.add_module(self.norm1_name, norm1)
self.conv2 = build_conv_layer(
self.conv_cfg,
64,
64,
kernel_size=3,
stride=2,
padding=1,
bias=False)
self.add_module(self.norm2_name, norm2)
self.relu = nn.ReLU(inplace=True)
# stage 1
self.stage1_cfg = self.extra['stage1']
num_channels = self.stage1_cfg['num_channels'][0]
block_type = self.stage1_cfg['block']
num_blocks = self.stage1_cfg['num_blocks'][0]
block = self.blocks_dict[block_type]
stage1_out_channels = num_channels * block.expansion
self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
# stage 2
self.stage2_cfg = self.extra['stage2']
num_channels = self.stage2_cfg['num_channels']
block_type = self.stage2_cfg['block']
block = self.blocks_dict[block_type]
num_channels = [channel * block.expansion for channel in num_channels]
self.transition1 = self._make_transition_layer([stage1_out_channels],
num_channels)
self.stage2, pre_stage_channels = self._make_stage(
self.stage2_cfg, num_channels)
# stage 3
self.stage3_cfg = self.extra['stage3']
num_channels = self.stage3_cfg['num_channels']
block_type = self.stage3_cfg['block']
block = self.blocks_dict[block_type]
num_channels = [channel * block.expansion for channel in num_channels]
self.transition2 = self._make_transition_layer(pre_stage_channels,
num_channels)
self.stage3, pre_stage_channels = self._make_stage(
self.stage3_cfg, num_channels)
# stage 4
self.stage4_cfg = self.extra['stage4']
num_channels = self.stage4_cfg['num_channels']
block_type = self.stage4_cfg['block']
block = self.blocks_dict[block_type]
num_channels = [channel * block.expansion for channel in num_channels]
self.transition3 = self._make_transition_layer(pre_stage_channels,
num_channels)
self.stage4, pre_stage_channels = self._make_stage(
self.stage4_cfg, num_channels, multiscale_output=multiscale_output)
self._freeze_stages()
@property
def norm1(self):
"""nn.Module: the normalization layer named "norm1" """
return getattr(self, self.norm1_name)
@property
def norm2(self):
"""nn.Module: the normalization layer named "norm2" """
return getattr(self, self.norm2_name)
def _make_transition_layer(self, num_channels_pre_layer,
num_channels_cur_layer):
"""Make transition layer."""
num_branches_cur = len(num_channels_cur_layer)
num_branches_pre = len(num_channels_pre_layer)
transition_layers = []
for i in range(num_branches_cur):
if i < num_branches_pre:
if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
transition_layers.append(
nn.Sequential(
build_conv_layer(
self.conv_cfg,
num_channels_pre_layer[i],
num_channels_cur_layer[i],
kernel_size=3,
stride=1,
padding=1,
bias=False),
build_norm_layer(self.norm_cfg,
num_channels_cur_layer[i])[1],
nn.ReLU(inplace=True)))
else:
transition_layers.append(None)
else:
conv_downsamples = []
for j in range(i + 1 - num_branches_pre):
in_channels = num_channels_pre_layer[-1]
out_channels = num_channels_cur_layer[i] \
if j == i - num_branches_pre else in_channels
conv_downsamples.append(
nn.Sequential(
build_conv_layer(
self.conv_cfg,
in_channels,
out_channels,
kernel_size=3,
stride=2,
padding=1,
bias=False),
build_norm_layer(self.norm_cfg, out_channels)[1],
nn.ReLU(inplace=True)))
transition_layers.append(nn.Sequential(*conv_downsamples))
return nn.ModuleList(transition_layers)
def _make_layer(self, block, inplanes, planes, blocks, stride=1):
"""Make each layer."""
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
downsample = nn.Sequential(
build_conv_layer(
self.conv_cfg,
inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias=False),
build_norm_layer(self.norm_cfg, planes * block.expansion)[1])
layers = []
block_init_cfg = None
if self.pretrained is None and not hasattr(
self, 'init_cfg') and self.zero_init_residual:
if block is BasicBlock:
block_init_cfg = dict(
type='Constant', val=0, override=dict(name='norm2'))
elif block is Bottleneck:
block_init_cfg = dict(
type='Constant', val=0, override=dict(name='norm3'))
layers.append(
block(
inplanes,
planes,
stride,
downsample=downsample,
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg,
init_cfg=block_init_cfg))
inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(
block(
inplanes,
planes,
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg,
init_cfg=block_init_cfg))
return Sequential(*layers)
def _make_stage(self, layer_config, in_channels, multiscale_output=True):
"""Make each stage."""
num_modules = layer_config['num_modules']
num_branches = layer_config['num_branches']
num_blocks = layer_config['num_blocks']
num_channels = layer_config['num_channels']
block = self.blocks_dict[layer_config['block']]
hr_modules = []
block_init_cfg = None
if self.pretrained is None and not hasattr(
self, 'init_cfg') and self.zero_init_residual:
if block is BasicBlock:
block_init_cfg = dict(
type='Constant', val=0, override=dict(name='norm2'))
elif block is Bottleneck:
block_init_cfg = dict(
type='Constant', val=0, override=dict(name='norm3'))
for i in range(num_modules):
# multi_scale_output is only used for the last module
if not multiscale_output and i == num_modules - 1:
reset_multiscale_output = False
else:
reset_multiscale_output = True
hr_modules.append(
HRModule(
num_branches,
block,
num_blocks,
in_channels,
num_channels,
reset_multiscale_output,
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg,
block_init_cfg=block_init_cfg))
return Sequential(*hr_modules), in_channels
def _freeze_stages(self):
"""Freeze stages param and norm stats."""
if self.frozen_stages >= 0:
self.norm1.eval()
self.norm2.eval()
for m in [self.conv1, self.norm1, self.conv2, self.norm2]:
for param in m.parameters():
param.requires_grad = False
for i in range(1, self.frozen_stages + 1):
if i == 1:
m = getattr(self, f'layer{i}')
t = getattr(self, f'transition{i}')
elif i == 4:
m = getattr(self, f'stage{i}')
else:
m = getattr(self, f'stage{i}')
t = getattr(self, f'transition{i}')
m.eval()
for param in m.parameters():
param.requires_grad = False
t.eval()
for param in t.parameters():
param.requires_grad = False
def forward(self, x):
"""Forward function."""
x = self.conv1(x)
x = self.norm1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.norm2(x)
x = self.relu(x)
x = self.layer1(x)
x_list = []
for i in range(self.stage2_cfg['num_branches']):
if self.transition1[i] is not None:
x_list.append(self.transition1[i](x))
else:
x_list.append(x)
y_list = self.stage2(x_list)
x_list = []
for i in range(self.stage3_cfg['num_branches']):
if self.transition2[i] is not None:
x_list.append(self.transition2[i](y_list[-1]))
else:
x_list.append(y_list[i])
y_list = self.stage3(x_list)
x_list = []
for i in range(self.stage4_cfg['num_branches']):
if self.transition3[i] is not None:
x_list.append(self.transition3[i](y_list[-1]))
else:
x_list.append(y_list[i])
y_list = self.stage4(x_list)
return y_list
def train(self, mode=True):
"""Convert the model into training mode will keeping the normalization
layer freezed."""
super().train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
# trick: eval have effect on BatchNorm only
if isinstance(m, _BatchNorm):
m.eval()

View File

@@ -0,0 +1,166 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from mmseg.registry import MODELS
from ..decode_heads.psp_head import PPM
from ..utils import resize
@MODELS.register_module()
class ICNet(BaseModule):
"""ICNet for Real-Time Semantic Segmentation on High-Resolution Images.
This backbone is the implementation of
`ICNet <https://arxiv.org/abs/1704.08545>`_.
Args:
backbone_cfg (dict): Config dict to build backbone. Usually it is
ResNet but it can also be other backbones.
in_channels (int): The number of input image channels. Default: 3.
layer_channels (Sequence[int]): The numbers of feature channels at
layer 2 and layer 4 in ResNet. It can also be other backbones.
Default: (512, 2048).
light_branch_middle_channels (int): The number of channels of the
middle layer in light branch. Default: 32.
psp_out_channels (int): The number of channels of the output of PSP
module. Default: 512.
out_channels (Sequence[int]): The numbers of output feature channels
at each branches. Default: (64, 256, 256).
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
Module. Default: (1, 2, 3, 6).
conv_cfg (dict): Dictionary to construct and config conv layer.
Default: None.
norm_cfg (dict): Dictionary to construct and config norm layer.
Default: dict(type='BN').
act_cfg (dict): Dictionary to construct and config act layer.
Default: dict(type='ReLU').
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(self,
backbone_cfg,
in_channels=3,
layer_channels=(512, 2048),
light_branch_middle_channels=32,
psp_out_channels=512,
out_channels=(64, 256, 256),
pool_scales=(1, 2, 3, 6),
conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True),
act_cfg=dict(type='ReLU'),
align_corners=False,
init_cfg=None):
if backbone_cfg is None:
raise TypeError('backbone_cfg must be passed from config file!')
if init_cfg is None:
init_cfg = [
dict(type='Kaiming', mode='fan_out', layer='Conv2d'),
dict(type='Constant', val=1, layer='_BatchNorm'),
dict(type='Normal', mean=0.01, layer='Linear')
]
super().__init__(init_cfg=init_cfg)
self.align_corners = align_corners
self.backbone = MODELS.build(backbone_cfg)
# Note: Default `ceil_mode` is false in nn.MaxPool2d, set
# `ceil_mode=True` to keep information in the corner of feature map.
self.backbone.maxpool = nn.MaxPool2d(
kernel_size=3, stride=2, padding=1, ceil_mode=True)
self.psp_modules = PPM(
pool_scales=pool_scales,
in_channels=layer_channels[1],
channels=psp_out_channels,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
align_corners=align_corners)
self.psp_bottleneck = ConvModule(
layer_channels[1] + len(pool_scales) * psp_out_channels,
psp_out_channels,
3,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.conv_sub1 = nn.Sequential(
ConvModule(
in_channels=in_channels,
out_channels=light_branch_middle_channels,
kernel_size=3,
stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg),
ConvModule(
in_channels=light_branch_middle_channels,
out_channels=light_branch_middle_channels,
kernel_size=3,
stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg),
ConvModule(
in_channels=light_branch_middle_channels,
out_channels=out_channels[0],
kernel_size=3,
stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg))
self.conv_sub2 = ConvModule(
layer_channels[0],
out_channels[1],
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg)
self.conv_sub4 = ConvModule(
psp_out_channels,
out_channels[2],
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg)
def forward(self, x):
output = []
# sub 1
output.append(self.conv_sub1(x))
# sub 2
x = resize(
x,
scale_factor=0.5,
mode='bilinear',
align_corners=self.align_corners)
x = self.backbone.stem(x)
x = self.backbone.maxpool(x)
x = self.backbone.layer1(x)
x = self.backbone.layer2(x)
output.append(self.conv_sub2(x))
# sub 4
x = resize(
x,
scale_factor=0.5,
mode='bilinear',
align_corners=self.align_corners)
x = self.backbone.layer3(x)
x = self.backbone.layer4(x)
psp_outs = self.psp_modules(x) + [x]
psp_outs = torch.cat(psp_outs, dim=1)
x = self.psp_bottleneck(psp_outs)
output.append(self.conv_sub4(x))
return output

View File

@@ -0,0 +1,260 @@
# Copyright (c) OpenMMLab. All rights reserved.import math
import math
import torch
import torch.nn as nn
from mmengine.model import ModuleList
from mmengine.model.weight_init import (constant_init, kaiming_init,
trunc_normal_)
from mmengine.runner.checkpoint import _load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm
from mmseg.registry import MODELS
from .beit import BEiT, BEiTAttention, BEiTTransformerEncoderLayer
class MAEAttention(BEiTAttention):
"""Multi-head self-attention with relative position bias used in MAE.
This module is different from ``BEiTAttention`` by initializing the
relative bias table with zeros.
"""
def init_weights(self):
"""Initialize relative position bias with zeros."""
# As MAE initializes relative position bias as zeros and this class
# inherited from BEiT which initializes relative position bias
# with `trunc_normal`, `init_weights` here does
# nothing and just passes directly
pass
class MAETransformerEncoderLayer(BEiTTransformerEncoderLayer):
"""Implements one encoder layer in Vision Transformer.
This module is different from ``BEiTTransformerEncoderLayer`` by replacing
``BEiTAttention`` with ``MAEAttention``.
"""
def build_attn(self, attn_cfg):
self.attn = MAEAttention(**attn_cfg)
@MODELS.register_module()
class MAE(BEiT):
"""VisionTransformer with support for patch.
Args:
img_size (int | tuple): Input image size. Default: 224.
patch_size (int): The patch size. Default: 16.
in_channels (int): Number of input channels. Default: 3.
embed_dims (int): embedding dimension. Default: 768.
num_layers (int): depth of transformer. Default: 12.
num_heads (int): number of attention heads. Default: 12.
mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
Default: 4.
out_indices (list | tuple | int): Output from which stages.
Default: -1.
attn_drop_rate (float): The drop out rate for attention layer.
Default 0.0
drop_path_rate (float): stochastic depth rate. Default 0.0.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN')
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
patch_norm (bool): Whether to add a norm in PatchEmbed Block.
Default: False.
final_norm (bool): Whether to add a additional layer to normalize
final feature map. Default: False.
num_fcs (int): The number of fully-connected layers for FFNs.
Default: 2.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
pretrained (str, optional): model pretrained path. Default: None.
init_values (float): Initialize the values of Attention and FFN
with learnable scaling. Defaults to 0.1.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(self,
img_size=224,
patch_size=16,
in_channels=3,
embed_dims=768,
num_layers=12,
num_heads=12,
mlp_ratio=4,
out_indices=-1,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_cfg=dict(type='LN'),
act_cfg=dict(type='GELU'),
patch_norm=False,
final_norm=False,
num_fcs=2,
norm_eval=False,
pretrained=None,
init_values=0.1,
init_cfg=None):
super().__init__(
img_size=img_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dims=embed_dims,
num_layers=num_layers,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
out_indices=out_indices,
qv_bias=False,
attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
patch_norm=patch_norm,
final_norm=final_norm,
num_fcs=num_fcs,
norm_eval=norm_eval,
pretrained=pretrained,
init_values=init_values,
init_cfg=init_cfg)
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
self.num_patches = self.patch_shape[0] * self.patch_shape[1]
self.pos_embed = nn.Parameter(
torch.zeros(1, self.num_patches + 1, embed_dims))
def _build_layers(self):
dpr = [
x.item()
for x in torch.linspace(0, self.drop_path_rate, self.num_layers)
]
self.layers = ModuleList()
for i in range(self.num_layers):
self.layers.append(
MAETransformerEncoderLayer(
embed_dims=self.embed_dims,
num_heads=self.num_heads,
feedforward_channels=self.mlp_ratio * self.embed_dims,
attn_drop_rate=self.attn_drop_rate,
drop_path_rate=dpr[i],
num_fcs=self.num_fcs,
bias=True,
act_cfg=self.act_cfg,
norm_cfg=self.norm_cfg,
window_size=self.patch_shape,
init_values=self.init_values))
def fix_init_weight(self):
"""Rescale the initialization according to layer id.
This function is copied from https://github.com/microsoft/unilm/blob/master/beit/modeling_pretrain.py. # noqa: E501
Copyright (c) Microsoft Corporation
Licensed under the MIT License
"""
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.layers):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.ffn.layers[1].weight.data, layer_id + 1)
def init_weights(self):
def _init_weights(m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
self.apply(_init_weights)
self.fix_init_weight()
if (isinstance(self.init_cfg, dict)
and self.init_cfg.get('type') == 'Pretrained'):
checkpoint = _load_checkpoint(
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
state_dict = self.resize_rel_pos_embed(checkpoint)
state_dict = self.resize_abs_pos_embed(state_dict)
self.load_state_dict(state_dict, False)
elif self.init_cfg is not None:
super().init_weights()
else:
# We only implement the 'jax_impl' initialization implemented at
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
# Copyright 2019 Ross Wightman
# Licensed under the Apache License, Version 2.0 (the "License")
trunc_normal_(self.cls_token, std=.02)
for n, m in self.named_modules():
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
if 'ffn' in n:
nn.init.normal_(m.bias, mean=0., std=1e-6)
else:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv2d):
kaiming_init(m, mode='fan_in', bias=0.)
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
constant_init(m, val=1.0, bias=0.)
def resize_abs_pos_embed(self, state_dict):
if 'pos_embed' in state_dict:
pos_embed_checkpoint = state_dict['pos_embed']
embedding_size = pos_embed_checkpoint.shape[-1]
num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches
# height (== width) for the checkpoint position embedding
orig_size = int(
(pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5)
# height (== width) for the new position embedding
new_size = int(self.num_patches**0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size,
embedding_size).permute(
0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens,
size=(new_size, new_size),
mode='bicubic',
align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
state_dict['pos_embed'] = new_pos_embed
return state_dict
def forward(self, inputs):
B = inputs.shape[0]
x, hw_shape = self.patch_embed(inputs)
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed
outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
if i == len(self.layers) - 1:
if self.final_norm:
x = self.norm1(x)
if i in self.out_indices:
out = x[:, 1:]
B, _, C = out.shape
out = out.reshape(B, hw_shape[0], hw_shape[1],
C).permute(0, 3, 1, 2).contiguous()
outs.append(out)
return tuple(outs)

View File

@@ -0,0 +1,450 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import MultiheadAttention
from mmengine.model import BaseModule, ModuleList, Sequential
from mmengine.model.weight_init import (constant_init, normal_init,
trunc_normal_init)
from mmseg.registry import MODELS
from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw
class MixFFN(BaseModule):
"""An implementation of MixFFN of Segformer.
The differences between MixFFN & FFN:
1. Use 1X1 Conv to replace Linear layer.
2. Introduce 3X3 Conv to encode positional information.
Args:
embed_dims (int): The feature dimension. Same as
`MultiheadAttention`. Defaults: 256.
feedforward_channels (int): The hidden dimension of FFNs.
Defaults: 1024.
act_cfg (dict, optional): The activation config for FFNs.
Default: dict(type='ReLU')
ffn_drop (float, optional): Probability of an element to be
zeroed in FFN. Default 0.0.
dropout_layer (obj:`ConfigDict`): The dropout_layer used
when adding the shortcut.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(self,
embed_dims,
feedforward_channels,
act_cfg=dict(type='GELU'),
ffn_drop=0.,
dropout_layer=None,
init_cfg=None):
super().__init__(init_cfg)
self.embed_dims = embed_dims
self.feedforward_channels = feedforward_channels
self.act_cfg = act_cfg
self.activate = build_activation_layer(act_cfg)
in_channels = embed_dims
fc1 = Conv2d(
in_channels=in_channels,
out_channels=feedforward_channels,
kernel_size=1,
stride=1,
bias=True)
# 3x3 depth wise conv to provide positional encode information
pe_conv = Conv2d(
in_channels=feedforward_channels,
out_channels=feedforward_channels,
kernel_size=3,
stride=1,
padding=(3 - 1) // 2,
bias=True,
groups=feedforward_channels)
fc2 = Conv2d(
in_channels=feedforward_channels,
out_channels=in_channels,
kernel_size=1,
stride=1,
bias=True)
drop = nn.Dropout(ffn_drop)
layers = [fc1, pe_conv, self.activate, drop, fc2, drop]
self.layers = Sequential(*layers)
self.dropout_layer = build_dropout(
dropout_layer) if dropout_layer else torch.nn.Identity()
def forward(self, x, hw_shape, identity=None):
out = nlc_to_nchw(x, hw_shape)
out = self.layers(out)
out = nchw_to_nlc(out)
if identity is None:
identity = x
return identity + self.dropout_layer(out)
class EfficientMultiheadAttention(MultiheadAttention):
"""An implementation of Efficient Multi-head Attention of Segformer.
This module is modified from MultiheadAttention which is a module from
mmcv.cnn.bricks.transformer.
Args:
embed_dims (int): The embedding dimension.
num_heads (int): Parallel attention heads.
attn_drop (float): A Dropout layer on attn_output_weights.
Default: 0.0.
proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
Default: 0.0.
dropout_layer (obj:`ConfigDict`): The dropout_layer used
when adding the shortcut. Default: None.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
batch_first (bool): Key, Query and Value are shape of
(batch, n, embed_dim)
or (n, batch, embed_dim). Default: False.
qkv_bias (bool): enable bias for qkv if True. Default True.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
Attention of Segformer. Default: 1.
"""
def __init__(self,
embed_dims,
num_heads,
attn_drop=0.,
proj_drop=0.,
dropout_layer=None,
init_cfg=None,
batch_first=True,
qkv_bias=False,
norm_cfg=dict(type='LN'),
sr_ratio=1):
super().__init__(
embed_dims,
num_heads,
attn_drop,
proj_drop,
dropout_layer=dropout_layer,
init_cfg=init_cfg,
batch_first=batch_first,
bias=qkv_bias)
self.sr_ratio = sr_ratio
if sr_ratio > 1:
self.sr = Conv2d(
in_channels=embed_dims,
out_channels=embed_dims,
kernel_size=sr_ratio,
stride=sr_ratio)
# The ret[0] of build_norm_layer is norm name.
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
# handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa
from mmseg import digit_version, mmcv_version
if mmcv_version < digit_version('1.3.17'):
warnings.warn('The legacy version of forward function in'
'EfficientMultiheadAttention is deprecated in'
'mmcv>=1.3.17 and will no longer support in the'
'future. Please upgrade your mmcv.')
self.forward = self.legacy_forward
def forward(self, x, hw_shape, identity=None):
x_q = x
if self.sr_ratio > 1:
x_kv = nlc_to_nchw(x, hw_shape)
x_kv = self.sr(x_kv)
x_kv = nchw_to_nlc(x_kv)
x_kv = self.norm(x_kv)
else:
x_kv = x
if identity is None:
identity = x_q
# Because the dataflow('key', 'query', 'value') of
# ``torch.nn.MultiheadAttention`` is (num_query, batch,
# embed_dims), We should adjust the shape of dataflow from
# batch_first (batch, num_query, embed_dims) to num_query_first
# (num_query ,batch, embed_dims), and recover ``attn_output``
# from num_query_first to batch_first.
if self.batch_first:
x_q = x_q.transpose(0, 1)
x_kv = x_kv.transpose(0, 1)
out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
if self.batch_first:
out = out.transpose(0, 1)
return identity + self.dropout_layer(self.proj_drop(out))
def legacy_forward(self, x, hw_shape, identity=None):
"""multi head attention forward in mmcv version < 1.3.17."""
x_q = x
if self.sr_ratio > 1:
x_kv = nlc_to_nchw(x, hw_shape)
x_kv = self.sr(x_kv)
x_kv = nchw_to_nlc(x_kv)
x_kv = self.norm(x_kv)
else:
x_kv = x
if identity is None:
identity = x_q
# `need_weights=True` will let nn.MultiHeadAttention
# `return attn_output, attn_output_weights.sum(dim=1) / num_heads`
# The `attn_output_weights.sum(dim=1)` may cause cuda error. So, we set
# `need_weights=False` to ignore `attn_output_weights.sum(dim=1)`.
# This issue - `https://github.com/pytorch/pytorch/issues/37583` report
# the error that large scale tensor sum operation may cause cuda error.
out = self.attn(query=x_q, key=x_kv, value=x_kv, need_weights=False)[0]
return identity + self.dropout_layer(self.proj_drop(out))
class TransformerEncoderLayer(BaseModule):
"""Implements one encoder layer in Segformer.
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
drop_rate (float): Probability of an element to be zeroed.
after the feed forward layer. Default 0.0.
attn_drop_rate (float): The drop out rate for attention layer.
Default 0.0.
drop_path_rate (float): stochastic depth rate. Default 0.0.
qkv_bias (bool): enable bias for qkv if True.
Default: True.
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
batch_first (bool): Key, Query and Value are shape of
(batch, n, embed_dim)
or (n, batch, embed_dim). Default: False.
init_cfg (dict, optional): Initialization config dict.
Default:None.
sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
Attention of Segformer. Default: 1.
with_cp (bool): Use checkpoint or not. Using checkpoint will save
some memory while slowing down the training speed. Default: False.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
qkv_bias=True,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
batch_first=True,
sr_ratio=1,
with_cp=False):
super().__init__()
# The ret[0] of build_norm_layer is norm name.
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
self.attn = EfficientMultiheadAttention(
embed_dims=embed_dims,
num_heads=num_heads,
attn_drop=attn_drop_rate,
proj_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
batch_first=batch_first,
qkv_bias=qkv_bias,
norm_cfg=norm_cfg,
sr_ratio=sr_ratio)
# The ret[0] of build_norm_layer is norm name.
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
self.ffn = MixFFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg)
self.with_cp = with_cp
def forward(self, x, hw_shape):
def _inner_forward(x):
x = self.attn(self.norm1(x), hw_shape, identity=x)
x = self.ffn(self.norm2(x), hw_shape, identity=x)
return x
if self.with_cp and x.requires_grad:
x = cp.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x
@MODELS.register_module()
class MixVisionTransformer(BaseModule):
"""The backbone of Segformer.
This backbone is the implementation of `SegFormer: Simple and
Efficient Design for Semantic Segmentation with
Transformers <https://arxiv.org/abs/2105.15203>`_.
Args:
in_channels (int): Number of input channels. Default: 3.
embed_dims (int): Embedding dimension. Default: 768.
num_stags (int): The num of stages. Default: 4.
num_layers (Sequence[int]): The layer number of each transformer encode
layer. Default: [3, 4, 6, 3].
num_heads (Sequence[int]): The attention heads of each transformer
encode layer. Default: [1, 2, 4, 8].
patch_sizes (Sequence[int]): The patch_size of each overlapped patch
embedding. Default: [7, 3, 3, 3].
strides (Sequence[int]): The stride of each overlapped patch embedding.
Default: [4, 2, 2, 2].
sr_ratios (Sequence[int]): The spatial reduction rate of each
transformer encode layer. Default: [8, 4, 2, 1].
out_indices (Sequence[int] | int): Output from which stages.
Default: (0, 1, 2, 3).
mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
Default: 4.
qkv_bias (bool): Enable bias for qkv if True. Default: True.
drop_rate (float): Probability of an element to be zeroed.
Default 0.0
attn_drop_rate (float): The drop out rate for attention layer.
Default 0.0
drop_path_rate (float): stochastic depth rate. Default 0.0
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN')
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
pretrained (str, optional): model pretrained path. Default: None.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
with_cp (bool): Use checkpoint or not. Using checkpoint will save
some memory while slowing down the training speed. Default: False.
"""
def __init__(self,
in_channels=3,
embed_dims=64,
num_stages=4,
num_layers=[3, 4, 6, 3],
num_heads=[1, 2, 4, 8],
patch_sizes=[7, 3, 3, 3],
strides=[4, 2, 2, 2],
sr_ratios=[8, 4, 2, 1],
out_indices=(0, 1, 2, 3),
mlp_ratio=4,
qkv_bias=True,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN', eps=1e-6),
pretrained=None,
init_cfg=None,
with_cp=False):
super().__init__(init_cfg=init_cfg)
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be set at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is not None:
raise TypeError('pretrained must be a str or None')
self.embed_dims = embed_dims
self.num_stages = num_stages
self.num_layers = num_layers
self.num_heads = num_heads
self.patch_sizes = patch_sizes
self.strides = strides
self.sr_ratios = sr_ratios
self.with_cp = with_cp
assert num_stages == len(num_layers) == len(num_heads) \
== len(patch_sizes) == len(strides) == len(sr_ratios)
self.out_indices = out_indices
assert max(out_indices) < self.num_stages
# transformer encoder
dpr = [
x.item()
for x in torch.linspace(0, drop_path_rate, sum(num_layers))
] # stochastic num_layer decay rule
cur = 0
self.layers = ModuleList()
for i, num_layer in enumerate(num_layers):
embed_dims_i = embed_dims * num_heads[i]
patch_embed = PatchEmbed(
in_channels=in_channels,
embed_dims=embed_dims_i,
kernel_size=patch_sizes[i],
stride=strides[i],
padding=patch_sizes[i] // 2,
norm_cfg=norm_cfg)
layer = ModuleList([
TransformerEncoderLayer(
embed_dims=embed_dims_i,
num_heads=num_heads[i],
feedforward_channels=mlp_ratio * embed_dims_i,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=dpr[cur + idx],
qkv_bias=qkv_bias,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
with_cp=with_cp,
sr_ratio=sr_ratios[i]) for idx in range(num_layer)
])
in_channels = embed_dims_i
# The ret[0] of build_norm_layer is norm name.
norm = build_norm_layer(norm_cfg, embed_dims_i)[1]
self.layers.append(ModuleList([patch_embed, layer, norm]))
cur += num_layer
def init_weights(self):
if self.init_cfg is None:
for m in self.modules():
if isinstance(m, nn.Linear):
trunc_normal_init(m, std=.02, bias=0.)
elif isinstance(m, nn.LayerNorm):
constant_init(m, val=1.0, bias=0.)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[
1] * m.out_channels
fan_out //= m.groups
normal_init(
m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
else:
super().init_weights()
def forward(self, x):
outs = []
for i, layer in enumerate(self.layers):
x, hw_shape = layer[0](x)
for block in layer[1]:
x = block(x, hw_shape)
x = layer[2](x)
x = nlc_to_nchw(x, hw_shape)
if i in self.out_indices:
outs.append(x)
return outs

View File

@@ -0,0 +1,197 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm
from mmseg.registry import MODELS
from ..utils import InvertedResidual, make_divisible
@MODELS.register_module()
class MobileNetV2(BaseModule):
"""MobileNetV2 backbone.
This backbone is the implementation of
`MobileNetV2: Inverted Residuals and Linear Bottlenecks
<https://arxiv.org/abs/1801.04381>`_.
Args:
widen_factor (float): Width multiplier, multiply number of
channels in each layer by this amount. Default: 1.0.
strides (Sequence[int], optional): Strides of the first block of each
layer. If not specified, default config in ``arch_setting`` will
be used.
dilations (Sequence[int]): Dilation of each layer.
out_indices (None or Sequence[int]): Output from which stages.
Default: (7, ).
frozen_stages (int): Stages to be frozen (all param fixed).
Default: -1, which means not freezing any parameters.
conv_cfg (dict): Config dict for convolution layer.
Default: None, which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU6').
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
pretrained (str, optional): model pretrained path. Default: None
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
"""
# Parameters to build layers. 3 parameters are needed to construct a
# layer, from left to right: expand_ratio, channel, num_blocks.
arch_settings = [[1, 16, 1], [6, 24, 2], [6, 32, 3], [6, 64, 4],
[6, 96, 3], [6, 160, 3], [6, 320, 1]]
def __init__(self,
widen_factor=1.,
strides=(1, 2, 2, 2, 1, 2, 1),
dilations=(1, 1, 1, 1, 1, 1, 1),
out_indices=(1, 2, 4, 6),
frozen_stages=-1,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6'),
norm_eval=False,
with_cp=False,
pretrained=None,
init_cfg=None):
super().__init__(init_cfg)
self.pretrained = pretrained
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]
else:
raise TypeError('pretrained must be a str or None')
self.widen_factor = widen_factor
self.strides = strides
self.dilations = dilations
assert len(strides) == len(dilations) == len(self.arch_settings)
self.out_indices = out_indices
for index in out_indices:
if index not in range(0, 7):
raise ValueError('the item in out_indices must in '
f'range(0, 7). But received {index}')
if frozen_stages not in range(-1, 7):
raise ValueError('frozen_stages must be in range(-1, 7). '
f'But received {frozen_stages}')
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.norm_eval = norm_eval
self.with_cp = with_cp
self.in_channels = make_divisible(32 * widen_factor, 8)
self.conv1 = ConvModule(
in_channels=3,
out_channels=self.in_channels,
kernel_size=3,
stride=2,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.layers = []
for i, layer_cfg in enumerate(self.arch_settings):
expand_ratio, channel, num_blocks = layer_cfg
stride = self.strides[i]
dilation = self.dilations[i]
out_channels = make_divisible(channel * widen_factor, 8)
inverted_res_layer = self.make_layer(
out_channels=out_channels,
num_blocks=num_blocks,
stride=stride,
dilation=dilation,
expand_ratio=expand_ratio)
layer_name = f'layer{i + 1}'
self.add_module(layer_name, inverted_res_layer)
self.layers.append(layer_name)
def make_layer(self, out_channels, num_blocks, stride, dilation,
expand_ratio):
"""Stack InvertedResidual blocks to build a layer for MobileNetV2.
Args:
out_channels (int): out_channels of block.
num_blocks (int): Number of blocks.
stride (int): Stride of the first block.
dilation (int): Dilation of the first block.
expand_ratio (int): Expand the number of channels of the
hidden layer in InvertedResidual by this ratio.
"""
layers = []
for i in range(num_blocks):
layers.append(
InvertedResidual(
self.in_channels,
out_channels,
stride if i == 0 else 1,
expand_ratio=expand_ratio,
dilation=dilation if i == 0 else 1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
with_cp=self.with_cp))
self.in_channels = out_channels
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
outs = []
for i, layer_name in enumerate(self.layers):
layer = getattr(self, layer_name)
x = layer(x)
if i in self.out_indices:
outs.append(x)
if len(outs) == 1:
return outs[0]
else:
return tuple(outs)
def _freeze_stages(self):
if self.frozen_stages >= 0:
for param in self.conv1.parameters():
param.requires_grad = False
for i in range(1, self.frozen_stages + 1):
layer = getattr(self, f'layer{i}')
layer.eval()
for param in layer.parameters():
param.requires_grad = False
def train(self, mode=True):
super().train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
if isinstance(m, _BatchNorm):
m.eval()

View File

@@ -0,0 +1,267 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import Conv2dAdaptivePadding
from mmengine.model import BaseModule
from mmengine.utils import is_tuple_of
from torch.nn.modules.batchnorm import _BatchNorm
from mmseg.registry import MODELS
from ..utils import InvertedResidualV3 as InvertedResidual
@MODELS.register_module()
class MobileNetV3(BaseModule):
"""MobileNetV3 backbone.
This backbone is the improved implementation of `Searching for MobileNetV3
<https://ieeexplore.ieee.org/document/9008835>`_.
Args:
arch (str): Architecture of mobilnetv3, from {'small', 'large'}.
Default: 'small'.
conv_cfg (dict): Config dict for convolution layer.
Default: None, which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
out_indices (tuple[int]): Output from which layer.
Default: (0, 1, 12).
frozen_stages (int): Stages to be frozen (all param fixed).
Default: -1, which means not freezing any parameters.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save
some memory while slowing down the training speed.
Default: False.
pretrained (str, optional): model pretrained path. Default: None
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
"""
# Parameters to build each block:
# [kernel size, mid channels, out channels, with_se, act type, stride]
arch_settings = {
'small': [[3, 16, 16, True, 'ReLU', 2], # block0 layer1 os=4
[3, 72, 24, False, 'ReLU', 2], # block1 layer2 os=8
[3, 88, 24, False, 'ReLU', 1],
[5, 96, 40, True, 'HSwish', 2], # block2 layer4 os=16
[5, 240, 40, True, 'HSwish', 1],
[5, 240, 40, True, 'HSwish', 1],
[5, 120, 48, True, 'HSwish', 1], # block3 layer7 os=16
[5, 144, 48, True, 'HSwish', 1],
[5, 288, 96, True, 'HSwish', 2], # block4 layer9 os=32
[5, 576, 96, True, 'HSwish', 1],
[5, 576, 96, True, 'HSwish', 1]],
'large': [[3, 16, 16, False, 'ReLU', 1], # block0 layer1 os=2
[3, 64, 24, False, 'ReLU', 2], # block1 layer2 os=4
[3, 72, 24, False, 'ReLU', 1],
[5, 72, 40, True, 'ReLU', 2], # block2 layer4 os=8
[5, 120, 40, True, 'ReLU', 1],
[5, 120, 40, True, 'ReLU', 1],
[3, 240, 80, False, 'HSwish', 2], # block3 layer7 os=16
[3, 200, 80, False, 'HSwish', 1],
[3, 184, 80, False, 'HSwish', 1],
[3, 184, 80, False, 'HSwish', 1],
[3, 480, 112, True, 'HSwish', 1], # block4 layer11 os=16
[3, 672, 112, True, 'HSwish', 1],
[5, 672, 160, True, 'HSwish', 2], # block5 layer13 os=32
[5, 960, 160, True, 'HSwish', 1],
[5, 960, 160, True, 'HSwish', 1]]
} # yapf: disable
def __init__(self,
arch='small',
conv_cfg=None,
norm_cfg=dict(type='BN'),
out_indices=(0, 1, 12),
frozen_stages=-1,
reduction_factor=1,
norm_eval=False,
with_cp=False,
pretrained=None,
init_cfg=None):
super().__init__(init_cfg)
self.pretrained = pretrained
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]
else:
raise TypeError('pretrained must be a str or None')
assert arch in self.arch_settings
assert isinstance(reduction_factor, int) and reduction_factor > 0
assert is_tuple_of(out_indices, int)
for index in out_indices:
if index not in range(0, len(self.arch_settings[arch]) + 2):
raise ValueError(
'the item in out_indices must in '
f'range(0, {len(self.arch_settings[arch])+2}). '
f'But received {index}')
if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2):
raise ValueError('frozen_stages must be in range(-1, '
f'{len(self.arch_settings[arch])+2}). '
f'But received {frozen_stages}')
self.arch = arch
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.reduction_factor = reduction_factor
self.norm_eval = norm_eval
self.with_cp = with_cp
self.layers = self._make_layer()
def _make_layer(self):
layers = []
# build the first layer (layer0)
in_channels = 16
layer = ConvModule(
in_channels=3,
out_channels=in_channels,
kernel_size=3,
stride=2,
padding=1,
conv_cfg=dict(type='Conv2dAdaptivePadding'),
norm_cfg=self.norm_cfg,
act_cfg=dict(type='HSwish'))
self.add_module('layer0', layer)
layers.append('layer0')
layer_setting = self.arch_settings[self.arch]
for i, params in enumerate(layer_setting):
(kernel_size, mid_channels, out_channels, with_se, act,
stride) = params
if self.arch == 'large' and i >= 12 or self.arch == 'small' and \
i >= 8:
mid_channels = mid_channels // self.reduction_factor
out_channels = out_channels // self.reduction_factor
if with_se:
se_cfg = dict(
channels=mid_channels,
ratio=4,
act_cfg=(dict(type='ReLU'),
dict(type='HSigmoid', bias=3.0, divisor=6.0)))
else:
se_cfg = None
layer = InvertedResidual(
in_channels=in_channels,
out_channels=out_channels,
mid_channels=mid_channels,
kernel_size=kernel_size,
stride=stride,
se_cfg=se_cfg,
with_expand_conv=(in_channels != mid_channels),
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=dict(type=act),
with_cp=self.with_cp)
in_channels = out_channels
layer_name = f'layer{i + 1}'
self.add_module(layer_name, layer)
layers.append(layer_name)
# build the last layer
# block5 layer12 os=32 for small model
# block6 layer16 os=32 for large model
layer = ConvModule(
in_channels=in_channels,
out_channels=576 if self.arch == 'small' else 960,
kernel_size=1,
stride=1,
dilation=4,
padding=0,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=dict(type='HSwish'))
layer_name = f'layer{len(layer_setting) + 1}'
self.add_module(layer_name, layer)
layers.append(layer_name)
# next, convert backbone MobileNetV3 to a semantic segmentation version
if self.arch == 'small':
self.layer4.depthwise_conv.conv.stride = (1, 1)
self.layer9.depthwise_conv.conv.stride = (1, 1)
for i in range(4, len(layers)):
layer = getattr(self, layers[i])
if isinstance(layer, InvertedResidual):
modified_module = layer.depthwise_conv.conv
else:
modified_module = layer.conv
if i < 9:
modified_module.dilation = (2, 2)
pad = 2
else:
modified_module.dilation = (4, 4)
pad = 4
if not isinstance(modified_module, Conv2dAdaptivePadding):
# Adjust padding
pad *= (modified_module.kernel_size[0] - 1) // 2
modified_module.padding = (pad, pad)
else:
self.layer7.depthwise_conv.conv.stride = (1, 1)
self.layer13.depthwise_conv.conv.stride = (1, 1)
for i in range(7, len(layers)):
layer = getattr(self, layers[i])
if isinstance(layer, InvertedResidual):
modified_module = layer.depthwise_conv.conv
else:
modified_module = layer.conv
if i < 13:
modified_module.dilation = (2, 2)
pad = 2
else:
modified_module.dilation = (4, 4)
pad = 4
if not isinstance(modified_module, Conv2dAdaptivePadding):
# Adjust padding
pad *= (modified_module.kernel_size[0] - 1) // 2
modified_module.padding = (pad, pad)
return layers
def forward(self, x):
outs = []
for i, layer_name in enumerate(self.layers):
layer = getattr(self, layer_name)
x = layer(x)
if i in self.out_indices:
outs.append(x)
return outs
def _freeze_stages(self):
for i in range(self.frozen_stages + 1):
layer = getattr(self, f'layer{i}')
layer.eval()
for param in layer.parameters():
param.requires_grad = False
def train(self, mode=True):
super().train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
if isinstance(m, _BatchNorm):
m.eval()

View File

@@ -0,0 +1,467 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Originally from https://github.com/visual-attention-network/segnext
# Licensed under the Apache License, Version 2.0 (the "License")
import math
import warnings
import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer, build_norm_layer
from mmcv.cnn.bricks import DropPath
from mmengine.model import BaseModule
from mmengine.model.weight_init import (constant_init, normal_init,
trunc_normal_init)
from mmseg.registry import MODELS
class Mlp(BaseModule):
"""Multi Layer Perceptron (MLP) Module.
Args:
in_features (int): The dimension of input features.
hidden_features (int): The dimension of hidden features.
Defaults: None.
out_features (int): The dimension of output features.
Defaults: None.
act_cfg (dict): Config dict for activation layer in block.
Default: dict(type='GELU').
drop (float): The number of dropout rate in MLP block.
Defaults: 0.0.
"""
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_cfg=dict(type='GELU'),
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
self.dwconv = nn.Conv2d(
hidden_features,
hidden_features,
3,
1,
1,
bias=True,
groups=hidden_features)
self.act = build_activation_layer(act_cfg)
self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
self.drop = nn.Dropout(drop)
def forward(self, x):
"""Forward function."""
x = self.fc1(x)
x = self.dwconv(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class StemConv(BaseModule):
"""Stem Block at the beginning of Semantic Branch.
Args:
in_channels (int): The dimension of input channels.
out_channels (int): The dimension of output channels.
act_cfg (dict): Config dict for activation layer in block.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Defaults: dict(type='SyncBN', requires_grad=True).
"""
def __init__(self,
in_channels,
out_channels,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='SyncBN', requires_grad=True)):
super().__init__()
self.proj = nn.Sequential(
nn.Conv2d(
in_channels,
out_channels // 2,
kernel_size=(3, 3),
stride=(2, 2),
padding=(1, 1)),
build_norm_layer(norm_cfg, out_channels // 2)[1],
build_activation_layer(act_cfg),
nn.Conv2d(
out_channels // 2,
out_channels,
kernel_size=(3, 3),
stride=(2, 2),
padding=(1, 1)),
build_norm_layer(norm_cfg, out_channels)[1],
)
def forward(self, x):
"""Forward function."""
x = self.proj(x)
_, _, H, W = x.size()
x = x.flatten(2).transpose(1, 2)
return x, H, W
class MSCAAttention(BaseModule):
"""Attention Module in Multi-Scale Convolutional Attention Module (MSCA).
Args:
channels (int): The dimension of channels.
kernel_sizes (list): The size of attention
kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]].
paddings (list): The number of
corresponding padding value in attention module.
Defaults: [2, [0, 3], [0, 5], [0, 10]].
"""
def __init__(self,
channels,
kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
paddings=[2, [0, 3], [0, 5], [0, 10]]):
super().__init__()
self.conv0 = nn.Conv2d(
channels,
channels,
kernel_size=kernel_sizes[0],
padding=paddings[0],
groups=channels)
for i, (kernel_size,
padding) in enumerate(zip(kernel_sizes[1:], paddings[1:])):
kernel_size_ = [kernel_size, kernel_size[::-1]]
padding_ = [padding, padding[::-1]]
conv_name = [f'conv{i}_1', f'conv{i}_2']
for i_kernel, i_pad, i_conv in zip(kernel_size_, padding_,
conv_name):
self.add_module(
i_conv,
nn.Conv2d(
channels,
channels,
tuple(i_kernel),
padding=i_pad,
groups=channels))
self.conv3 = nn.Conv2d(channels, channels, 1)
def forward(self, x):
"""Forward function."""
u = x.clone()
attn = self.conv0(x)
# Multi-Scale Feature extraction
attn_0 = self.conv0_1(attn)
attn_0 = self.conv0_2(attn_0)
attn_1 = self.conv1_1(attn)
attn_1 = self.conv1_2(attn_1)
attn_2 = self.conv2_1(attn)
attn_2 = self.conv2_2(attn_2)
attn = attn + attn_0 + attn_1 + attn_2
# Channel Mixing
attn = self.conv3(attn)
# Convolutional Attention
x = attn * u
return x
class MSCASpatialAttention(BaseModule):
"""Spatial Attention Module in Multi-Scale Convolutional Attention Module
(MSCA).
Args:
in_channels (int): The dimension of channels.
attention_kernel_sizes (list): The size of attention
kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]].
attention_kernel_paddings (list): The number of
corresponding padding value in attention module.
Defaults: [2, [0, 3], [0, 5], [0, 10]].
act_cfg (dict): Config dict for activation layer in block.
Default: dict(type='GELU').
"""
def __init__(self,
in_channels,
attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
act_cfg=dict(type='GELU')):
super().__init__()
self.proj_1 = nn.Conv2d(in_channels, in_channels, 1)
self.activation = build_activation_layer(act_cfg)
self.spatial_gating_unit = MSCAAttention(in_channels,
attention_kernel_sizes,
attention_kernel_paddings)
self.proj_2 = nn.Conv2d(in_channels, in_channels, 1)
def forward(self, x):
"""Forward function."""
shorcut = x.clone()
x = self.proj_1(x)
x = self.activation(x)
x = self.spatial_gating_unit(x)
x = self.proj_2(x)
x = x + shorcut
return x
class MSCABlock(BaseModule):
"""Basic Multi-Scale Convolutional Attention Block. It leverage the large-
kernel attention (LKA) mechanism to build both channel and spatial
attention. In each branch, it uses two depth-wise strip convolutions to
approximate standard depth-wise convolutions with large kernels. The kernel
size for each branch is set to 7, 11, and 21, respectively.
Args:
channels (int): The dimension of channels.
attention_kernel_sizes (list): The size of attention
kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]].
attention_kernel_paddings (list): The number of
corresponding padding value in attention module.
Defaults: [2, [0, 3], [0, 5], [0, 10]].
mlp_ratio (float): The ratio of multiple input dimension to
calculate hidden feature in MLP layer. Defaults: 4.0.
drop (float): The number of dropout rate in MLP block.
Defaults: 0.0.
drop_path (float): The ratio of drop paths.
Defaults: 0.0.
act_cfg (dict): Config dict for activation layer in block.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Defaults: dict(type='SyncBN', requires_grad=True).
"""
def __init__(self,
channels,
attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
mlp_ratio=4.,
drop=0.,
drop_path=0.,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='SyncBN', requires_grad=True)):
super().__init__()
self.norm1 = build_norm_layer(norm_cfg, channels)[1]
self.attn = MSCASpatialAttention(channels, attention_kernel_sizes,
attention_kernel_paddings, act_cfg)
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = build_norm_layer(norm_cfg, channels)[1]
mlp_hidden_channels = int(channels * mlp_ratio)
self.mlp = Mlp(
in_features=channels,
hidden_features=mlp_hidden_channels,
act_cfg=act_cfg,
drop=drop)
layer_scale_init_value = 1e-2
self.layer_scale_1 = nn.Parameter(
layer_scale_init_value * torch.ones(channels), requires_grad=True)
self.layer_scale_2 = nn.Parameter(
layer_scale_init_value * torch.ones(channels), requires_grad=True)
def forward(self, x, H, W):
"""Forward function."""
B, N, C = x.shape
x = x.permute(0, 2, 1).view(B, C, H, W)
x = x + self.drop_path(
self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) *
self.attn(self.norm1(x)))
x = x + self.drop_path(
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) *
self.mlp(self.norm2(x)))
x = x.view(B, C, N).permute(0, 2, 1)
return x
class OverlapPatchEmbed(BaseModule):
"""Image to Patch Embedding.
Args:
patch_size (int): The patch size.
Defaults: 7.
stride (int): Stride of the convolutional layer.
Default: 4.
in_channels (int): The number of input channels.
Defaults: 3.
embed_dims (int): The dimensions of embedding.
Defaults: 768.
norm_cfg (dict): Config dict for normalization layer.
Defaults: dict(type='SyncBN', requires_grad=True).
"""
def __init__(self,
patch_size=7,
stride=4,
in_channels=3,
embed_dim=768,
norm_cfg=dict(type='SyncBN', requires_grad=True)):
super().__init__()
self.proj = nn.Conv2d(
in_channels,
embed_dim,
kernel_size=patch_size,
stride=stride,
padding=patch_size // 2)
self.norm = build_norm_layer(norm_cfg, embed_dim)[1]
def forward(self, x):
"""Forward function."""
x = self.proj(x)
_, _, H, W = x.shape
x = self.norm(x)
x = x.flatten(2).transpose(1, 2)
return x, H, W
@MODELS.register_module()
class MSCAN(BaseModule):
"""SegNeXt Multi-Scale Convolutional Attention Network (MCSAN) backbone.
This backbone is the implementation of `SegNeXt: Rethinking
Convolutional Attention Design for Semantic
Segmentation <https://arxiv.org/abs/2209.08575>`_.
Inspiration from https://github.com/visual-attention-network/segnext.
Args:
in_channels (int): The number of input channels. Defaults: 3.
embed_dims (list[int]): Embedding dimension.
Defaults: [64, 128, 256, 512].
mlp_ratios (list[int]): Ratio of mlp hidden dim to embedding dim.
Defaults: [4, 4, 4, 4].
drop_rate (float): Dropout rate. Defaults: 0.
drop_path_rate (float): Stochastic depth rate. Defaults: 0.
depths (list[int]): Depths of each Swin Transformer stage.
Default: [3, 4, 6, 3].
num_stages (int): MSCAN stages. Default: 4.
attention_kernel_sizes (list): Size of attention kernel in
Attention Module (Figure 2(b) of original paper).
Defaults: [5, [1, 7], [1, 11], [1, 21]].
attention_kernel_paddings (list): Size of attention paddings
in Attention Module (Figure 2(b) of original paper).
Defaults: [2, [0, 3], [0, 5], [0, 10]].
norm_cfg (dict): Config of norm layers.
Defaults: dict(type='SyncBN', requires_grad=True).
pretrained (str, optional): model pretrained path.
Default: None.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(self,
in_channels=3,
embed_dims=[64, 128, 256, 512],
mlp_ratios=[4, 4, 4, 4],
drop_rate=0.,
drop_path_rate=0.,
depths=[3, 4, 6, 3],
num_stages=4,
attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='SyncBN', requires_grad=True),
pretrained=None,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be set at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is not None:
raise TypeError('pretrained must be a str or None')
self.depths = depths
self.num_stages = num_stages
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
] # stochastic depth decay rule
cur = 0
for i in range(num_stages):
if i == 0:
patch_embed = StemConv(3, embed_dims[0], norm_cfg=norm_cfg)
else:
patch_embed = OverlapPatchEmbed(
patch_size=7 if i == 0 else 3,
stride=4 if i == 0 else 2,
in_channels=in_channels if i == 0 else embed_dims[i - 1],
embed_dim=embed_dims[i],
norm_cfg=norm_cfg)
block = nn.ModuleList([
MSCABlock(
channels=embed_dims[i],
attention_kernel_sizes=attention_kernel_sizes,
attention_kernel_paddings=attention_kernel_paddings,
mlp_ratio=mlp_ratios[i],
drop=drop_rate,
drop_path=dpr[cur + j],
act_cfg=act_cfg,
norm_cfg=norm_cfg) for j in range(depths[i])
])
norm = nn.LayerNorm(embed_dims[i])
cur += depths[i]
setattr(self, f'patch_embed{i + 1}', patch_embed)
setattr(self, f'block{i + 1}', block)
setattr(self, f'norm{i + 1}', norm)
def init_weights(self):
"""Initialize modules of MSCAN."""
print('init cfg', self.init_cfg)
if self.init_cfg is None:
for m in self.modules():
if isinstance(m, nn.Linear):
trunc_normal_init(m, std=.02, bias=0.)
elif isinstance(m, nn.LayerNorm):
constant_init(m, val=1.0, bias=0.)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[
1] * m.out_channels
fan_out //= m.groups
normal_init(
m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
else:
super().init_weights()
def forward(self, x):
"""Forward function."""
B = x.shape[0]
outs = []
for i in range(self.num_stages):
patch_embed = getattr(self, f'patch_embed{i + 1}')
block = getattr(self, f'block{i + 1}')
norm = getattr(self, f'norm{i + 1}')
x, H, W = patch_embed(x)
for blk in block:
x = blk(x, H, W)
x = norm(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
return outs

View File

@@ -0,0 +1,522 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from mmengine.runner import CheckpointLoader
from torch import Tensor
from mmseg.registry import MODELS
from mmseg.utils import OptConfigType
from ..utils import DAPPM, PAPPM, BasicBlock, Bottleneck
class PagFM(BaseModule):
"""Pixel-attention-guided fusion module.
Args:
in_channels (int): The number of input channels.
channels (int): The number of channels.
after_relu (bool): Whether to use ReLU before attention.
Default: False.
with_channel (bool): Whether to use channel attention.
Default: False.
upsample_mode (str): The mode of upsample. Default: 'bilinear'.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(typ='ReLU', inplace=True).
init_cfg (dict): Config dict for initialization. Default: None.
"""
def __init__(self,
in_channels: int,
channels: int,
after_relu: bool = False,
with_channel: bool = False,
upsample_mode: str = 'bilinear',
norm_cfg: OptConfigType = dict(type='BN'),
act_cfg: OptConfigType = dict(typ='ReLU', inplace=True),
init_cfg: OptConfigType = None):
super().__init__(init_cfg)
self.after_relu = after_relu
self.with_channel = with_channel
self.upsample_mode = upsample_mode
self.f_i = ConvModule(
in_channels, channels, 1, norm_cfg=norm_cfg, act_cfg=None)
self.f_p = ConvModule(
in_channels, channels, 1, norm_cfg=norm_cfg, act_cfg=None)
if with_channel:
self.up = ConvModule(
channels, in_channels, 1, norm_cfg=norm_cfg, act_cfg=None)
if after_relu:
self.relu = MODELS.build(act_cfg)
def forward(self, x_p: Tensor, x_i: Tensor) -> Tensor:
"""Forward function.
Args:
x_p (Tensor): The featrue map from P branch.
x_i (Tensor): The featrue map from I branch.
Returns:
Tensor: The feature map with pixel-attention-guided fusion.
"""
if self.after_relu:
x_p = self.relu(x_p)
x_i = self.relu(x_i)
f_i = self.f_i(x_i)
f_i = F.interpolate(
f_i,
size=x_p.shape[2:],
mode=self.upsample_mode,
align_corners=False)
f_p = self.f_p(x_p)
if self.with_channel:
sigma = torch.sigmoid(self.up(f_p * f_i))
else:
sigma = torch.sigmoid(torch.sum(f_p * f_i, dim=1).unsqueeze(1))
x_i = F.interpolate(
x_i,
size=x_p.shape[2:],
mode=self.upsample_mode,
align_corners=False)
out = sigma * x_i + (1 - sigma) * x_p
return out
class Bag(BaseModule):
"""Boundary-attention-guided fusion module.
Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
kernel_size (int): The kernel size of the convolution. Default: 3.
padding (int): The padding of the convolution. Default: 1.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU', inplace=True).
conv_cfg (dict): Config dict for convolution layer.
Default: dict(order=('norm', 'act', 'conv')).
init_cfg (dict): Config dict for initialization. Default: None.
"""
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
padding: int = 1,
norm_cfg: OptConfigType = dict(type='BN'),
act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
conv_cfg: OptConfigType = dict(order=('norm', 'act', 'conv')),
init_cfg: OptConfigType = None):
super().__init__(init_cfg)
self.conv = ConvModule(
in_channels,
out_channels,
kernel_size,
padding=padding,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
**conv_cfg)
def forward(self, x_p: Tensor, x_i: Tensor, x_d: Tensor) -> Tensor:
"""Forward function.
Args:
x_p (Tensor): The featrue map from P branch.
x_i (Tensor): The featrue map from I branch.
x_d (Tensor): The featrue map from D branch.
Returns:
Tensor: The feature map with boundary-attention-guided fusion.
"""
sigma = torch.sigmoid(x_d)
return self.conv(sigma * x_p + (1 - sigma) * x_i)
class LightBag(BaseModule):
"""Light Boundary-attention-guided fusion module.
Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer. Default: None.
init_cfg (dict): Config dict for initialization. Default: None.
"""
def __init__(self,
in_channels: int,
out_channels: int,
norm_cfg: OptConfigType = dict(type='BN'),
act_cfg: OptConfigType = None,
init_cfg: OptConfigType = None):
super().__init__(init_cfg)
self.f_p = ConvModule(
in_channels,
out_channels,
kernel_size=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.f_i = ConvModule(
in_channels,
out_channels,
kernel_size=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
def forward(self, x_p: Tensor, x_i: Tensor, x_d: Tensor) -> Tensor:
"""Forward function.
Args:
x_p (Tensor): The featrue map from P branch.
x_i (Tensor): The featrue map from I branch.
x_d (Tensor): The featrue map from D branch.
Returns:
Tensor: The feature map with light boundary-attention-guided
fusion.
"""
sigma = torch.sigmoid(x_d)
f_p = self.f_p((1 - sigma) * x_i + x_p)
f_i = self.f_i(x_i + sigma * x_p)
return f_p + f_i
@MODELS.register_module()
class PIDNet(BaseModule):
"""PIDNet backbone.
This backbone is the implementation of `PIDNet: A Real-time Semantic
Segmentation Network Inspired from PID Controller
<https://arxiv.org/abs/2206.02066>`_.
Modified from https://github.com/XuJiacong/PIDNet.
Licensed under the MIT License.
Args:
in_channels (int): The number of input channels. Default: 3.
channels (int): The number of channels in the stem layer. Default: 64.
ppm_channels (int): The number of channels in the PPM layer.
Default: 96.
num_stem_blocks (int): The number of blocks in the stem layer.
Default: 2.
num_branch_blocks (int): The number of blocks in the branch layer.
Default: 3.
align_corners (bool): The align_corners argument of F.interpolate.
Default: False.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU', inplace=True).
init_cfg (dict): Config dict for initialization. Default: None.
"""
def __init__(self,
in_channels: int = 3,
channels: int = 64,
ppm_channels: int = 96,
num_stem_blocks: int = 2,
num_branch_blocks: int = 3,
align_corners: bool = False,
norm_cfg: OptConfigType = dict(type='BN'),
act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
init_cfg: OptConfigType = None,
**kwargs):
super().__init__(init_cfg)
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.align_corners = align_corners
# stem layer
self.stem = self._make_stem_layer(in_channels, channels,
num_stem_blocks)
self.relu = nn.ReLU()
# I Branch
self.i_branch_layers = nn.ModuleList()
for i in range(3):
self.i_branch_layers.append(
self._make_layer(
block=BasicBlock if i < 2 else Bottleneck,
in_channels=channels * 2**(i + 1),
channels=channels * 8 if i > 0 else channels * 4,
num_blocks=num_branch_blocks if i < 2 else 2,
stride=2))
# P Branch
self.p_branch_layers = nn.ModuleList()
for i in range(3):
self.p_branch_layers.append(
self._make_layer(
block=BasicBlock if i < 2 else Bottleneck,
in_channels=channels * 2,
channels=channels * 2,
num_blocks=num_stem_blocks if i < 2 else 1))
self.compression_1 = ConvModule(
channels * 4,
channels * 2,
kernel_size=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=None)
self.compression_2 = ConvModule(
channels * 8,
channels * 2,
kernel_size=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=None)
self.pag_1 = PagFM(channels * 2, channels)
self.pag_2 = PagFM(channels * 2, channels)
# D Branch
if num_stem_blocks == 2:
self.d_branch_layers = nn.ModuleList([
self._make_single_layer(BasicBlock, channels * 2, channels),
self._make_layer(Bottleneck, channels, channels, 1)
])
channel_expand = 1
spp_module = PAPPM
dfm_module = LightBag
act_cfg_dfm = None
else:
self.d_branch_layers = nn.ModuleList([
self._make_single_layer(BasicBlock, channels * 2,
channels * 2),
self._make_single_layer(BasicBlock, channels * 2, channels * 2)
])
channel_expand = 2
spp_module = DAPPM
dfm_module = Bag
act_cfg_dfm = act_cfg
self.diff_1 = ConvModule(
channels * 4,
channels * channel_expand,
kernel_size=3,
padding=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=None)
self.diff_2 = ConvModule(
channels * 8,
channels * 2,
kernel_size=3,
padding=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=None)
self.spp = spp_module(
channels * 16, ppm_channels, channels * 4, num_scales=5)
self.dfm = dfm_module(
channels * 4, channels * 4, norm_cfg=norm_cfg, act_cfg=act_cfg_dfm)
self.d_branch_layers.append(
self._make_layer(Bottleneck, channels * 2, channels * 2, 1))
def _make_stem_layer(self, in_channels: int, channels: int,
num_blocks: int) -> nn.Sequential:
"""Make stem layer.
Args:
in_channels (int): Number of input channels.
channels (int): Number of output channels.
num_blocks (int): Number of blocks.
Returns:
nn.Sequential: The stem layer.
"""
layers = [
ConvModule(
in_channels,
channels,
kernel_size=3,
stride=2,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
ConvModule(
channels,
channels,
kernel_size=3,
stride=2,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
]
layers.append(
self._make_layer(BasicBlock, channels, channels, num_blocks))
layers.append(nn.ReLU())
layers.append(
self._make_layer(
BasicBlock, channels, channels * 2, num_blocks, stride=2))
layers.append(nn.ReLU())
return nn.Sequential(*layers)
def _make_layer(self,
block: BasicBlock,
in_channels: int,
channels: int,
num_blocks: int,
stride: int = 1) -> nn.Sequential:
"""Make layer for PIDNet backbone.
Args:
block (BasicBlock): Basic block.
in_channels (int): Number of input channels.
channels (int): Number of output channels.
num_blocks (int): Number of blocks.
stride (int): Stride of the first block. Default: 1.
Returns:
nn.Sequential: The Branch Layer.
"""
downsample = None
if stride != 1 or in_channels != channels * block.expansion:
downsample = ConvModule(
in_channels,
channels * block.expansion,
kernel_size=1,
stride=stride,
norm_cfg=self.norm_cfg,
act_cfg=None)
layers = [block(in_channels, channels, stride, downsample)]
in_channels = channels * block.expansion
for i in range(1, num_blocks):
layers.append(
block(
in_channels,
channels,
stride=1,
act_cfg_out=None if i == num_blocks - 1 else self.act_cfg))
return nn.Sequential(*layers)
def _make_single_layer(self,
block: Union[BasicBlock, Bottleneck],
in_channels: int,
channels: int,
stride: int = 1) -> nn.Module:
"""Make single layer for PIDNet backbone.
Args:
block (BasicBlock or Bottleneck): Basic block or Bottleneck.
in_channels (int): Number of input channels.
channels (int): Number of output channels.
stride (int): Stride of the first block. Default: 1.
Returns:
nn.Module
"""
downsample = None
if stride != 1 or in_channels != channels * block.expansion:
downsample = ConvModule(
in_channels,
channels * block.expansion,
kernel_size=1,
stride=stride,
norm_cfg=self.norm_cfg,
act_cfg=None)
return block(
in_channels, channels, stride, downsample, act_cfg_out=None)
def init_weights(self):
"""Initialize the weights in backbone.
Since the D branch is not initialized by the pre-trained model, we
initialize it with the same method as the ResNet.
"""
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
if self.init_cfg is not None:
assert 'checkpoint' in self.init_cfg, f'Only support ' \
f'specify `Pretrained` in ' \
f'`init_cfg` in ' \
f'{self.__class__.__name__} '
ckpt = CheckpointLoader.load_checkpoint(
self.init_cfg['checkpoint'], map_location='cpu')
self.load_state_dict(ckpt, strict=False)
def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor]]:
"""Forward function.
Args:
x (Tensor): Input tensor with shape (B, C, H, W).
Returns:
Tensor or tuple[Tensor]: If self.training is True, return
tuple[Tensor], else return Tensor.
"""
w_out = x.shape[-1] // 8
h_out = x.shape[-2] // 8
# stage 0-2
x = self.stem(x)
# stage 3
x_i = self.relu(self.i_branch_layers[0](x))
x_p = self.p_branch_layers[0](x)
x_d = self.d_branch_layers[0](x)
comp_i = self.compression_1(x_i)
x_p = self.pag_1(x_p, comp_i)
diff_i = self.diff_1(x_i)
x_d += F.interpolate(
diff_i,
size=[h_out, w_out],
mode='bilinear',
align_corners=self.align_corners)
if self.training:
temp_p = x_p.clone()
# stage 4
x_i = self.relu(self.i_branch_layers[1](x_i))
x_p = self.p_branch_layers[1](self.relu(x_p))
x_d = self.d_branch_layers[1](self.relu(x_d))
comp_i = self.compression_2(x_i)
x_p = self.pag_2(x_p, comp_i)
diff_i = self.diff_2(x_i)
x_d += F.interpolate(
diff_i,
size=[h_out, w_out],
mode='bilinear',
align_corners=self.align_corners)
if self.training:
temp_d = x_d.clone()
# stage 5
x_i = self.i_branch_layers[2](x_i)
x_p = self.p_branch_layers[2](self.relu(x_p))
x_d = self.d_branch_layers[2](self.relu(x_d))
x_i = self.spp(x_i)
x_i = F.interpolate(
x_i,
size=[h_out, w_out],
mode='bilinear',
align_corners=self.align_corners)
out = self.dfm(x_p, x_i, x_d)
return (temp_p, out, temp_d) if self.training else out

View File

@@ -0,0 +1,318 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmseg.registry import MODELS
from ..utils import ResLayer
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNetV1d
class RSoftmax(nn.Module):
"""Radix Softmax module in ``SplitAttentionConv2d``.
Args:
radix (int): Radix of input.
groups (int): Groups of input.
"""
def __init__(self, radix, groups):
super().__init__()
self.radix = radix
self.groups = groups
def forward(self, x):
batch = x.size(0)
if self.radix > 1:
x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2)
x = F.softmax(x, dim=1)
x = x.reshape(batch, -1)
else:
x = torch.sigmoid(x)
return x
class SplitAttentionConv2d(nn.Module):
"""Split-Attention Conv2d in ResNeSt.
Args:
in_channels (int): Same as nn.Conv2d.
out_channels (int): Same as nn.Conv2d.
kernel_size (int | tuple[int]): Same as nn.Conv2d.
stride (int | tuple[int]): Same as nn.Conv2d.
padding (int | tuple[int]): Same as nn.Conv2d.
dilation (int | tuple[int]): Same as nn.Conv2d.
groups (int): Same as nn.Conv2d.
radix (int): Radix of SpltAtConv2d. Default: 2
reduction_factor (int): Reduction factor of inter_channels. Default: 4.
conv_cfg (dict): Config dict for convolution layer. Default: None,
which means using conv2d.
norm_cfg (dict): Config dict for normalization layer. Default: None.
dcn (dict): Config dict for DCN. Default: None.
"""
def __init__(self,
in_channels,
channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
radix=2,
reduction_factor=4,
conv_cfg=None,
norm_cfg=dict(type='BN'),
dcn=None):
super().__init__()
inter_channels = max(in_channels * radix // reduction_factor, 32)
self.radix = radix
self.groups = groups
self.channels = channels
self.with_dcn = dcn is not None
self.dcn = dcn
fallback_on_stride = False
if self.with_dcn:
fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
if self.with_dcn and not fallback_on_stride:
assert conv_cfg is None, 'conv_cfg must be None for DCN'
conv_cfg = dcn
self.conv = build_conv_layer(
conv_cfg,
in_channels,
channels * radix,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups * radix,
bias=False)
self.norm0_name, norm0 = build_norm_layer(
norm_cfg, channels * radix, postfix=0)
self.add_module(self.norm0_name, norm0)
self.relu = nn.ReLU(inplace=True)
self.fc1 = build_conv_layer(
None, channels, inter_channels, 1, groups=self.groups)
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, inter_channels, postfix=1)
self.add_module(self.norm1_name, norm1)
self.fc2 = build_conv_layer(
None, inter_channels, channels * radix, 1, groups=self.groups)
self.rsoftmax = RSoftmax(radix, groups)
@property
def norm0(self):
"""nn.Module: the normalization layer named "norm0" """
return getattr(self, self.norm0_name)
@property
def norm1(self):
"""nn.Module: the normalization layer named "norm1" """
return getattr(self, self.norm1_name)
def forward(self, x):
x = self.conv(x)
x = self.norm0(x)
x = self.relu(x)
batch, rchannel = x.shape[:2]
batch = x.size(0)
if self.radix > 1:
splits = x.view(batch, self.radix, -1, *x.shape[2:])
gap = splits.sum(dim=1)
else:
gap = x
gap = F.adaptive_avg_pool2d(gap, 1)
gap = self.fc1(gap)
gap = self.norm1(gap)
gap = self.relu(gap)
atten = self.fc2(gap)
atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
if self.radix > 1:
attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
out = torch.sum(attens * splits, dim=1)
else:
out = atten * x
return out.contiguous()
class Bottleneck(_Bottleneck):
"""Bottleneck block for ResNeSt.
Args:
inplane (int): Input planes of this block.
planes (int): Middle planes of this block.
groups (int): Groups of conv2.
width_per_group (int): Width per group of conv2. 64x4d indicates
``groups=64, width_per_group=4`` and 32x8d indicates
``groups=32, width_per_group=8``.
radix (int): Radix of SpltAtConv2d. Default: 2
reduction_factor (int): Reduction factor of inter_channels in
SplitAttentionConv2d. Default: 4.
avg_down_stride (bool): Whether to use average pool for stride in
Bottleneck. Default: True.
kwargs (dict): Key word arguments for base class.
"""
expansion = 4
def __init__(self,
inplanes,
planes,
groups=1,
base_width=4,
base_channels=64,
radix=2,
reduction_factor=4,
avg_down_stride=True,
**kwargs):
"""Bottleneck block for ResNeSt."""
super().__init__(inplanes, planes, **kwargs)
if groups == 1:
width = self.planes
else:
width = math.floor(self.planes *
(base_width / base_channels)) * groups
self.avg_down_stride = avg_down_stride and self.conv2_stride > 1
self.norm1_name, norm1 = build_norm_layer(
self.norm_cfg, width, postfix=1)
self.norm3_name, norm3 = build_norm_layer(
self.norm_cfg, self.planes * self.expansion, postfix=3)
self.conv1 = build_conv_layer(
self.conv_cfg,
self.inplanes,
width,
kernel_size=1,
stride=self.conv1_stride,
bias=False)
self.add_module(self.norm1_name, norm1)
self.with_modulated_dcn = False
self.conv2 = SplitAttentionConv2d(
width,
width,
kernel_size=3,
stride=1 if self.avg_down_stride else self.conv2_stride,
padding=self.dilation,
dilation=self.dilation,
groups=groups,
radix=radix,
reduction_factor=reduction_factor,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
dcn=self.dcn)
delattr(self, self.norm2_name)
if self.avg_down_stride:
self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)
self.conv3 = build_conv_layer(
self.conv_cfg,
width,
self.planes * self.expansion,
kernel_size=1,
bias=False)
self.add_module(self.norm3_name, norm3)
def forward(self, x):
def _inner_forward(x):
identity = x
out = self.conv1(x)
out = self.norm1(out)
out = self.relu(out)
if self.with_plugins:
out = self.forward_plugin(out, self.after_conv1_plugin_names)
out = self.conv2(out)
if self.avg_down_stride:
out = self.avd_layer(out)
if self.with_plugins:
out = self.forward_plugin(out, self.after_conv2_plugin_names)
out = self.conv3(out)
out = self.norm3(out)
if self.with_plugins:
out = self.forward_plugin(out, self.after_conv3_plugin_names)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
out = self.relu(out)
return out
@MODELS.register_module()
class ResNeSt(ResNetV1d):
"""ResNeSt backbone.
This backbone is the implementation of `ResNeSt:
Split-Attention Networks <https://arxiv.org/abs/2004.08955>`_.
Args:
groups (int): Number of groups of Bottleneck. Default: 1
base_width (int): Base width of Bottleneck. Default: 4
radix (int): Radix of SpltAtConv2d. Default: 2
reduction_factor (int): Reduction factor of inter_channels in
SplitAttentionConv2d. Default: 4.
avg_down_stride (bool): Whether to use average pool for stride in
Bottleneck. Default: True.
kwargs (dict): Keyword arguments for ResNet.
"""
arch_settings = {
50: (Bottleneck, (3, 4, 6, 3)),
101: (Bottleneck, (3, 4, 23, 3)),
152: (Bottleneck, (3, 8, 36, 3)),
200: (Bottleneck, (3, 24, 36, 3))
}
def __init__(self,
groups=1,
base_width=4,
radix=2,
reduction_factor=4,
avg_down_stride=True,
**kwargs):
self.groups = groups
self.base_width = base_width
self.radix = radix
self.reduction_factor = reduction_factor
self.avg_down_stride = avg_down_stride
super().__init__(**kwargs)
def make_res_layer(self, **kwargs):
"""Pack all blocks in a stage into a ``ResLayer``."""
return ResLayer(
groups=self.groups,
base_width=self.base_width,
base_channels=self.base_channels,
radix=self.radix,
reduction_factor=self.reduction_factor,
avg_down_stride=self.avg_down_stride,
**kwargs)

View File

@@ -0,0 +1,712 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_conv_layer, build_norm_layer, build_plugin_layer
from mmengine.model import BaseModule
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmseg.registry import MODELS
from ..utils import ResLayer
class BasicBlock(BaseModule):
"""Basic block for ResNet."""
expansion = 1
def __init__(self,
inplanes,
planes,
stride=1,
dilation=1,
downsample=None,
style='pytorch',
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
dcn=None,
plugins=None,
init_cfg=None):
super().__init__(init_cfg)
assert dcn is None, 'Not implemented yet.'
assert plugins is None, 'Not implemented yet.'
self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
self.conv1 = build_conv_layer(
conv_cfg,
inplanes,
planes,
3,
stride=stride,
padding=dilation,
dilation=dilation,
bias=False)
self.add_module(self.norm1_name, norm1)
self.conv2 = build_conv_layer(
conv_cfg, planes, planes, 3, padding=1, bias=False)
self.add_module(self.norm2_name, norm2)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
self.with_cp = with_cp
@property
def norm1(self):
"""nn.Module: normalization layer after the first convolution layer"""
return getattr(self, self.norm1_name)
@property
def norm2(self):
"""nn.Module: normalization layer after the second convolution layer"""
return getattr(self, self.norm2_name)
def forward(self, x):
"""Forward function."""
def _inner_forward(x):
identity = x
out = self.conv1(x)
out = self.norm1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.norm2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
out = self.relu(out)
return out
class Bottleneck(BaseModule):
"""Bottleneck block for ResNet.
If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is
"caffe", the stride-two layer is the first 1x1 conv layer.
"""
expansion = 4
def __init__(self,
inplanes,
planes,
stride=1,
dilation=1,
downsample=None,
style='pytorch',
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
dcn=None,
plugins=None,
init_cfg=None):
super().__init__(init_cfg)
assert style in ['pytorch', 'caffe']
assert dcn is None or isinstance(dcn, dict)
assert plugins is None or isinstance(plugins, list)
if plugins is not None:
allowed_position = ['after_conv1', 'after_conv2', 'after_conv3']
assert all(p['position'] in allowed_position for p in plugins)
self.inplanes = inplanes
self.planes = planes
self.stride = stride
self.dilation = dilation
self.style = style
self.with_cp = with_cp
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.dcn = dcn
self.with_dcn = dcn is not None
self.plugins = plugins
self.with_plugins = plugins is not None
if self.with_plugins:
# collect plugins for conv1/conv2/conv3
self.after_conv1_plugins = [
plugin['cfg'] for plugin in plugins
if plugin['position'] == 'after_conv1'
]
self.after_conv2_plugins = [
plugin['cfg'] for plugin in plugins
if plugin['position'] == 'after_conv2'
]
self.after_conv3_plugins = [
plugin['cfg'] for plugin in plugins
if plugin['position'] == 'after_conv3'
]
if self.style == 'pytorch':
self.conv1_stride = 1
self.conv2_stride = stride
else:
self.conv1_stride = stride
self.conv2_stride = 1
self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
self.norm3_name, norm3 = build_norm_layer(
norm_cfg, planes * self.expansion, postfix=3)
self.conv1 = build_conv_layer(
conv_cfg,
inplanes,
planes,
kernel_size=1,
stride=self.conv1_stride,
bias=False)
self.add_module(self.norm1_name, norm1)
fallback_on_stride = False
if self.with_dcn:
fallback_on_stride = dcn.pop('fallback_on_stride', False)
if not self.with_dcn or fallback_on_stride:
self.conv2 = build_conv_layer(
conv_cfg,
planes,
planes,
kernel_size=3,
stride=self.conv2_stride,
padding=dilation,
dilation=dilation,
bias=False)
else:
assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
self.conv2 = build_conv_layer(
dcn,
planes,
planes,
kernel_size=3,
stride=self.conv2_stride,
padding=dilation,
dilation=dilation,
bias=False)
self.add_module(self.norm2_name, norm2)
self.conv3 = build_conv_layer(
conv_cfg,
planes,
planes * self.expansion,
kernel_size=1,
bias=False)
self.add_module(self.norm3_name, norm3)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
if self.with_plugins:
self.after_conv1_plugin_names = self.make_block_plugins(
planes, self.after_conv1_plugins)
self.after_conv2_plugin_names = self.make_block_plugins(
planes, self.after_conv2_plugins)
self.after_conv3_plugin_names = self.make_block_plugins(
planes * self.expansion, self.after_conv3_plugins)
def make_block_plugins(self, in_channels, plugins):
"""make plugins for block.
Args:
in_channels (int): Input channels of plugin.
plugins (list[dict]): List of plugins cfg to build.
Returns:
list[str]: List of the names of plugin.
"""
assert isinstance(plugins, list)
plugin_names = []
for plugin in plugins:
plugin = plugin.copy()
name, layer = build_plugin_layer(
plugin,
in_channels=in_channels,
postfix=plugin.pop('postfix', ''))
assert not hasattr(self, name), f'duplicate plugin {name}'
self.add_module(name, layer)
plugin_names.append(name)
return plugin_names
def forward_plugin(self, x, plugin_names):
"""Forward function for plugins."""
out = x
for name in plugin_names:
out = getattr(self, name)(x)
return out
@property
def norm1(self):
"""nn.Module: normalization layer after the first convolution layer"""
return getattr(self, self.norm1_name)
@property
def norm2(self):
"""nn.Module: normalization layer after the second convolution layer"""
return getattr(self, self.norm2_name)
@property
def norm3(self):
"""nn.Module: normalization layer after the third convolution layer"""
return getattr(self, self.norm3_name)
def forward(self, x):
"""Forward function."""
def _inner_forward(x):
identity = x
out = self.conv1(x)
out = self.norm1(out)
out = self.relu(out)
if self.with_plugins:
out = self.forward_plugin(out, self.after_conv1_plugin_names)
out = self.conv2(out)
out = self.norm2(out)
out = self.relu(out)
if self.with_plugins:
out = self.forward_plugin(out, self.after_conv2_plugin_names)
out = self.conv3(out)
out = self.norm3(out)
if self.with_plugins:
out = self.forward_plugin(out, self.after_conv3_plugin_names)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
out = self.relu(out)
return out
@MODELS.register_module()
class ResNet(BaseModule):
"""ResNet backbone.
This backbone is the improved implementation of `Deep Residual Learning
for Image Recognition <https://arxiv.org/abs/1512.03385>`_.
Args:
depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
in_channels (int): Number of input image channels. Default: 3.
stem_channels (int): Number of stem channels. Default: 64.
base_channels (int): Number of base channels of res layer. Default: 64.
num_stages (int): Resnet stages, normally 4. Default: 4.
strides (Sequence[int]): Strides of the first block of each stage.
Default: (1, 2, 2, 2).
dilations (Sequence[int]): Dilation of each stage.
Default: (1, 1, 1, 1).
out_indices (Sequence[int]): Output from which stages.
Default: (0, 1, 2, 3).
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer. Default: 'pytorch'.
deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
Default: False.
avg_down (bool): Use AvgPool instead of stride conv when
downsampling in the bottleneck. Default: False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Default: -1.
conv_cfg (dict | None): Dictionary to construct and config conv layer.
When conv_cfg is None, cfg will be set to dict(type='Conv2d').
Default: None.
norm_cfg (dict): Dictionary to construct and config norm layer.
Default: dict(type='BN', requires_grad=True).
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
dcn (dict | None): Dictionary to construct and config DCN conv layer.
When dcn is not None, conv_cfg must be None. Default: None.
stage_with_dcn (Sequence[bool]): Whether to set DCN conv for each
stage. The length of stage_with_dcn is equal to num_stages.
Default: (False, False, False, False).
plugins (list[dict]): List of plugins for stages, each dict contains:
- cfg (dict, required): Cfg dict to build plugin.
- position (str, required): Position inside block to insert plugin,
options: 'after_conv1', 'after_conv2', 'after_conv3'.
- stages (tuple[bool], optional): Stages to apply plugin, length
should be same as 'num_stages'.
Default: None.
multi_grid (Sequence[int]|None): Multi grid dilation rates of last
stage. Default: None.
contract_dilation (bool): Whether contract first dilation of each layer
Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
zero_init_residual (bool): Whether to use zero init for last norm layer
in resblocks to let them behave as identity. Default: True.
pretrained (str, optional): model pretrained path. Default: None.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Example:
>>> from mmseg.models import ResNet
>>> import torch
>>> self = ResNet(depth=18)
>>> self.eval()
>>> inputs = torch.rand(1, 3, 32, 32)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 64, 8, 8)
(1, 128, 4, 4)
(1, 256, 2, 2)
(1, 512, 1, 1)
"""
arch_settings = {
18: (BasicBlock, (2, 2, 2, 2)),
34: (BasicBlock, (3, 4, 6, 3)),
50: (Bottleneck, (3, 4, 6, 3)),
101: (Bottleneck, (3, 4, 23, 3)),
152: (Bottleneck, (3, 8, 36, 3))
}
def __init__(self,
depth,
in_channels=3,
stem_channels=64,
base_channels=64,
num_stages=4,
strides=(1, 2, 2, 2),
dilations=(1, 1, 1, 1),
out_indices=(0, 1, 2, 3),
style='pytorch',
deep_stem=False,
avg_down=False,
frozen_stages=-1,
conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=False,
dcn=None,
stage_with_dcn=(False, False, False, False),
plugins=None,
multi_grid=None,
contract_dilation=False,
with_cp=False,
zero_init_residual=True,
pretrained=None,
init_cfg=None):
super().__init__(init_cfg)
if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for resnet')
self.pretrained = pretrained
self.zero_init_residual = zero_init_residual
block_init_cfg = None
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]
block = self.arch_settings[depth][0]
if self.zero_init_residual:
if block is BasicBlock:
block_init_cfg = dict(
type='Constant',
val=0,
override=dict(name='norm2'))
elif block is Bottleneck:
block_init_cfg = dict(
type='Constant',
val=0,
override=dict(name='norm3'))
else:
raise TypeError('pretrained must be a str or None')
self.depth = depth
self.stem_channels = stem_channels
self.base_channels = base_channels
self.num_stages = num_stages
assert num_stages >= 1 and num_stages <= 4
self.strides = strides
self.dilations = dilations
assert len(strides) == len(dilations) == num_stages
self.out_indices = out_indices
assert max(out_indices) < num_stages
self.style = style
self.deep_stem = deep_stem
self.avg_down = avg_down
self.frozen_stages = frozen_stages
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.with_cp = with_cp
self.norm_eval = norm_eval
self.dcn = dcn
self.stage_with_dcn = stage_with_dcn
if dcn is not None:
assert len(stage_with_dcn) == num_stages
self.plugins = plugins
self.multi_grid = multi_grid
self.contract_dilation = contract_dilation
self.block, stage_blocks = self.arch_settings[depth]
self.stage_blocks = stage_blocks[:num_stages]
self.inplanes = stem_channels
self._make_stem_layer(in_channels, stem_channels)
self.res_layers = []
for i, num_blocks in enumerate(self.stage_blocks):
stride = strides[i]
dilation = dilations[i]
dcn = self.dcn if self.stage_with_dcn[i] else None
if plugins is not None:
stage_plugins = self.make_stage_plugins(plugins, i)
else:
stage_plugins = None
# multi grid is applied to last layer only
stage_multi_grid = multi_grid if i == len(
self.stage_blocks) - 1 else None
planes = base_channels * 2**i
res_layer = self.make_res_layer(
block=self.block,
inplanes=self.inplanes,
planes=planes,
num_blocks=num_blocks,
stride=stride,
dilation=dilation,
style=self.style,
avg_down=self.avg_down,
with_cp=with_cp,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
dcn=dcn,
plugins=stage_plugins,
multi_grid=stage_multi_grid,
contract_dilation=contract_dilation,
init_cfg=block_init_cfg)
self.inplanes = planes * self.block.expansion
layer_name = f'layer{i+1}'
self.add_module(layer_name, res_layer)
self.res_layers.append(layer_name)
self._freeze_stages()
self.feat_dim = self.block.expansion * base_channels * 2**(
len(self.stage_blocks) - 1)
def make_stage_plugins(self, plugins, stage_idx):
"""make plugins for ResNet 'stage_idx'th stage .
Currently we support to insert 'context_block',
'empirical_attention_block', 'nonlocal_block' into the backbone like
ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of
Bottleneck.
An example of plugins format could be :
>>> plugins=[
... dict(cfg=dict(type='xxx', arg1='xxx'),
... stages=(False, True, True, True),
... position='after_conv2'),
... dict(cfg=dict(type='yyy'),
... stages=(True, True, True, True),
... position='after_conv3'),
... dict(cfg=dict(type='zzz', postfix='1'),
... stages=(True, True, True, True),
... position='after_conv3'),
... dict(cfg=dict(type='zzz', postfix='2'),
... stages=(True, True, True, True),
... position='after_conv3')
... ]
>>> self = ResNet(depth=18)
>>> stage_plugins = self.make_stage_plugins(plugins, 0)
>>> assert len(stage_plugins) == 3
Suppose 'stage_idx=0', the structure of blocks in the stage would be:
conv1-> conv2->conv3->yyy->zzz1->zzz2
Suppose 'stage_idx=1', the structure of blocks in the stage would be:
conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2
If stages is missing, the plugin would be applied to all stages.
Args:
plugins (list[dict]): List of plugins cfg to build. The postfix is
required if multiple same type plugins are inserted.
stage_idx (int): Index of stage to build
Returns:
list[dict]: Plugins for current stage
"""
stage_plugins = []
for plugin in plugins:
plugin = plugin.copy()
stages = plugin.pop('stages', None)
assert stages is None or len(stages) == self.num_stages
# whether to insert plugin into current stage
if stages is None or stages[stage_idx]:
stage_plugins.append(plugin)
return stage_plugins
def make_res_layer(self, **kwargs):
"""Pack all blocks in a stage into a ``ResLayer``."""
return ResLayer(**kwargs)
@property
def norm1(self):
"""nn.Module: the normalization layer named "norm1" """
return getattr(self, self.norm1_name)
def _make_stem_layer(self, in_channels, stem_channels):
"""Make stem layer for ResNet."""
if self.deep_stem:
self.stem = nn.Sequential(
build_conv_layer(
self.conv_cfg,
in_channels,
stem_channels // 2,
kernel_size=3,
stride=2,
padding=1,
bias=False),
build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
nn.ReLU(inplace=True),
build_conv_layer(
self.conv_cfg,
stem_channels // 2,
stem_channels // 2,
kernel_size=3,
stride=1,
padding=1,
bias=False),
build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
nn.ReLU(inplace=True),
build_conv_layer(
self.conv_cfg,
stem_channels // 2,
stem_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False),
build_norm_layer(self.norm_cfg, stem_channels)[1],
nn.ReLU(inplace=True))
else:
self.conv1 = build_conv_layer(
self.conv_cfg,
in_channels,
stem_channels,
kernel_size=7,
stride=2,
padding=3,
bias=False)
self.norm1_name, norm1 = build_norm_layer(
self.norm_cfg, stem_channels, postfix=1)
self.add_module(self.norm1_name, norm1)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
def _freeze_stages(self):
"""Freeze stages param and norm stats."""
if self.frozen_stages >= 0:
if self.deep_stem:
self.stem.eval()
for param in self.stem.parameters():
param.requires_grad = False
else:
self.norm1.eval()
for m in [self.conv1, self.norm1]:
for param in m.parameters():
param.requires_grad = False
for i in range(1, self.frozen_stages + 1):
m = getattr(self, f'layer{i}')
m.eval()
for param in m.parameters():
param.requires_grad = False
def forward(self, x):
"""Forward function."""
if self.deep_stem:
x = self.stem(x)
else:
x = self.conv1(x)
x = self.norm1(x)
x = self.relu(x)
x = self.maxpool(x)
outs = []
for i, layer_name in enumerate(self.res_layers):
res_layer = getattr(self, layer_name)
x = res_layer(x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)
def train(self, mode=True):
"""Convert the model into training mode while keep normalization layer
freezed."""
super().train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
# trick: eval have effect on BatchNorm only
if isinstance(m, _BatchNorm):
m.eval()
@MODELS.register_module()
class ResNetV1c(ResNet):
"""ResNetV1c variant described in [1]_.
Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv in
the input stem with three 3x3 convs. For more details please refer to `Bag
of Tricks for Image Classification with Convolutional Neural Networks
<https://arxiv.org/abs/1812.01187>`_.
"""
def __init__(self, **kwargs):
super().__init__(deep_stem=True, avg_down=False, **kwargs)
@MODELS.register_module()
class ResNetV1d(ResNet):
"""ResNetV1d variant described in [1]_.
Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in
the input stem with three 3x3 convs. And in the downsampling block, a 2x2
avg_pool with stride 2 is added before conv, whose stride is changed to 1.
"""
def __init__(self, **kwargs):
super().__init__(deep_stem=True, avg_down=True, **kwargs)

View File

@@ -0,0 +1,150 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmseg.registry import MODELS
from ..utils import ResLayer
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNet
class Bottleneck(_Bottleneck):
"""Bottleneck block for ResNeXt.
If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is
"caffe", the stride-two layer is the first 1x1 conv layer.
"""
def __init__(self,
inplanes,
planes,
groups=1,
base_width=4,
base_channels=64,
**kwargs):
super().__init__(inplanes, planes, **kwargs)
if groups == 1:
width = self.planes
else:
width = math.floor(self.planes *
(base_width / base_channels)) * groups
self.norm1_name, norm1 = build_norm_layer(
self.norm_cfg, width, postfix=1)
self.norm2_name, norm2 = build_norm_layer(
self.norm_cfg, width, postfix=2)
self.norm3_name, norm3 = build_norm_layer(
self.norm_cfg, self.planes * self.expansion, postfix=3)
self.conv1 = build_conv_layer(
self.conv_cfg,
self.inplanes,
width,
kernel_size=1,
stride=self.conv1_stride,
bias=False)
self.add_module(self.norm1_name, norm1)
fallback_on_stride = False
self.with_modulated_dcn = False
if self.with_dcn:
fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
if not self.with_dcn or fallback_on_stride:
self.conv2 = build_conv_layer(
self.conv_cfg,
width,
width,
kernel_size=3,
stride=self.conv2_stride,
padding=self.dilation,
dilation=self.dilation,
groups=groups,
bias=False)
else:
assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
self.conv2 = build_conv_layer(
self.dcn,
width,
width,
kernel_size=3,
stride=self.conv2_stride,
padding=self.dilation,
dilation=self.dilation,
groups=groups,
bias=False)
self.add_module(self.norm2_name, norm2)
self.conv3 = build_conv_layer(
self.conv_cfg,
width,
self.planes * self.expansion,
kernel_size=1,
bias=False)
self.add_module(self.norm3_name, norm3)
@MODELS.register_module()
class ResNeXt(ResNet):
"""ResNeXt backbone.
This backbone is the implementation of `Aggregated
Residual Transformations for Deep Neural
Networks <https://arxiv.org/abs/1611.05431>`_.
Args:
depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
in_channels (int): Number of input image channels. Normally 3.
num_stages (int): Resnet stages, normally 4.
groups (int): Group of resnext.
base_width (int): Base width of resnext.
strides (Sequence[int]): Strides of the first block of each stage.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int]): Output from which stages.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters.
norm_cfg (dict): dictionary to construct and config norm layer.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
zero_init_residual (bool): whether to use zero init for last norm layer
in resblocks to let them behave as identity.
Example:
>>> from mmseg.models import ResNeXt
>>> import torch
>>> self = ResNeXt(depth=50)
>>> self.eval()
>>> inputs = torch.rand(1, 3, 32, 32)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 256, 8, 8)
(1, 512, 4, 4)
(1, 1024, 2, 2)
(1, 2048, 1, 1)
"""
arch_settings = {
50: (Bottleneck, (3, 4, 6, 3)),
101: (Bottleneck, (3, 4, 23, 3)),
152: (Bottleneck, (3, 8, 36, 3))
}
def __init__(self, groups=1, base_width=4, **kwargs):
self.groups = groups
self.base_width = base_width
super().__init__(**kwargs)
def make_res_layer(self, **kwargs):
"""Pack all blocks in a stage into a ``ResLayer``"""
return ResLayer(
groups=self.groups,
base_width=self.base_width,
base_channels=self.base_channels,
**kwargs)

View File

@@ -0,0 +1,422 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from https://github.com/MichaelFan01/STDC-Seg."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule, ModuleList, Sequential
from mmseg.registry import MODELS
from ..utils import resize
from .bisenetv1 import AttentionRefinementModule
class STDCModule(BaseModule):
"""STDCModule.
Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels before scaling.
stride (int): The number of stride for the first conv layer.
norm_cfg (dict): Config dict for normalization layer. Default: None.
act_cfg (dict): The activation config for conv layers.
num_convs (int): Numbers of conv layers.
fusion_type (str): Type of fusion operation. Default: 'add'.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(self,
in_channels,
out_channels,
stride,
norm_cfg=None,
act_cfg=None,
num_convs=4,
fusion_type='add',
init_cfg=None):
super().__init__(init_cfg=init_cfg)
assert num_convs > 1
assert fusion_type in ['add', 'cat']
self.stride = stride
self.with_downsample = True if self.stride == 2 else False
self.fusion_type = fusion_type
self.layers = ModuleList()
conv_0 = ConvModule(
in_channels, out_channels // 2, kernel_size=1, norm_cfg=norm_cfg)
if self.with_downsample:
self.downsample = ConvModule(
out_channels // 2,
out_channels // 2,
kernel_size=3,
stride=2,
padding=1,
groups=out_channels // 2,
norm_cfg=norm_cfg,
act_cfg=None)
if self.fusion_type == 'add':
self.layers.append(nn.Sequential(conv_0, self.downsample))
self.skip = Sequential(
ConvModule(
in_channels,
in_channels,
kernel_size=3,
stride=2,
padding=1,
groups=in_channels,
norm_cfg=norm_cfg,
act_cfg=None),
ConvModule(
in_channels,
out_channels,
1,
norm_cfg=norm_cfg,
act_cfg=None))
else:
self.layers.append(conv_0)
self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
else:
self.layers.append(conv_0)
for i in range(1, num_convs):
out_factor = 2**(i + 1) if i != num_convs - 1 else 2**i
self.layers.append(
ConvModule(
out_channels // 2**i,
out_channels // out_factor,
kernel_size=3,
stride=1,
padding=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
def forward(self, inputs):
if self.fusion_type == 'add':
out = self.forward_add(inputs)
else:
out = self.forward_cat(inputs)
return out
def forward_add(self, inputs):
layer_outputs = []
x = inputs.clone()
for layer in self.layers:
x = layer(x)
layer_outputs.append(x)
if self.with_downsample:
inputs = self.skip(inputs)
return torch.cat(layer_outputs, dim=1) + inputs
def forward_cat(self, inputs):
x0 = self.layers[0](inputs)
layer_outputs = [x0]
for i, layer in enumerate(self.layers[1:]):
if i == 0:
if self.with_downsample:
x = layer(self.downsample(x0))
else:
x = layer(x0)
else:
x = layer(x)
layer_outputs.append(x)
if self.with_downsample:
layer_outputs[0] = self.skip(x0)
return torch.cat(layer_outputs, dim=1)
class FeatureFusionModule(BaseModule):
"""Feature Fusion Module. This module is different from FeatureFusionModule
in BiSeNetV1. It uses two ConvModules in `self.attention` whose inter
channel number is calculated by given `scale_factor`, while
FeatureFusionModule in BiSeNetV1 only uses one ConvModule in
`self.conv_atten`.
Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
scale_factor (int): The number of channel scale factor.
Default: 4.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): The activation config for conv layers.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(self,
in_channels,
out_channels,
scale_factor=4,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
channels = out_channels // scale_factor
self.conv0 = ConvModule(
in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg)
self.attention = nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
ConvModule(
out_channels,
channels,
1,
norm_cfg=None,
bias=False,
act_cfg=act_cfg),
ConvModule(
channels,
out_channels,
1,
norm_cfg=None,
bias=False,
act_cfg=None), nn.Sigmoid())
def forward(self, spatial_inputs, context_inputs):
inputs = torch.cat([spatial_inputs, context_inputs], dim=1)
x = self.conv0(inputs)
attn = self.attention(x)
x_attn = x * attn
return x_attn + x
@MODELS.register_module()
class STDCNet(BaseModule):
"""This backbone is the implementation of `Rethinking BiSeNet For Real-time
Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.
Args:
stdc_type (int): The type of backbone structure,
`STDCNet1` and`STDCNet2` denotes two main backbones in paper,
whose FLOPs is 813M and 1446M, respectively.
in_channels (int): The num of input_channels.
channels (tuple[int]): The output channels for each stage.
bottleneck_type (str): The type of STDC Module type, the value must
be 'add' or 'cat'.
norm_cfg (dict): Config dict for normalization layer.
act_cfg (dict): The activation config for conv layers.
num_convs (int): Numbers of conv layer at each STDC Module.
Default: 4.
with_final_conv (bool): Whether add a conv layer at the Module output.
Default: True.
pretrained (str, optional): Model pretrained path. Default: None.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Example:
>>> import torch
>>> stdc_type = 'STDCNet1'
>>> in_channels = 3
>>> channels = (32, 64, 256, 512, 1024)
>>> bottleneck_type = 'cat'
>>> inputs = torch.rand(1, 3, 1024, 2048)
>>> self = STDCNet(stdc_type, in_channels,
... channels, bottleneck_type).eval()
>>> outputs = self.forward(inputs)
>>> for i in range(len(outputs)):
... print(f'outputs[{i}].shape = {outputs[i].shape}')
outputs[0].shape = torch.Size([1, 256, 128, 256])
outputs[1].shape = torch.Size([1, 512, 64, 128])
outputs[2].shape = torch.Size([1, 1024, 32, 64])
"""
arch_settings = {
'STDCNet1': [(2, 1), (2, 1), (2, 1)],
'STDCNet2': [(2, 1, 1, 1), (2, 1, 1, 1, 1), (2, 1, 1)]
}
def __init__(self,
stdc_type,
in_channels,
channels,
bottleneck_type,
norm_cfg,
act_cfg,
num_convs=4,
with_final_conv=False,
pretrained=None,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
assert stdc_type in self.arch_settings, \
f'invalid structure {stdc_type} for STDCNet.'
assert bottleneck_type in ['add', 'cat'],\
f'bottleneck_type must be `add` or `cat`, got {bottleneck_type}'
assert len(channels) == 5,\
f'invalid channels length {len(channels)} for STDCNet.'
self.in_channels = in_channels
self.channels = channels
self.stage_strides = self.arch_settings[stdc_type]
self.prtrained = pretrained
self.num_convs = num_convs
self.with_final_conv = with_final_conv
self.stages = ModuleList([
ConvModule(
self.in_channels,
self.channels[0],
kernel_size=3,
stride=2,
padding=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
ConvModule(
self.channels[0],
self.channels[1],
kernel_size=3,
stride=2,
padding=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
])
# `self.num_shallow_features` is the number of shallow modules in
# `STDCNet`, which is noted as `Stage1` and `Stage2` in original paper.
# They are both not used for following modules like Attention
# Refinement Module and Feature Fusion Module.
# Thus they would be cut from `outs`. Please refer to Figure 4
# of original paper for more details.
self.num_shallow_features = len(self.stages)
for strides in self.stage_strides:
idx = len(self.stages) - 1
self.stages.append(
self._make_stage(self.channels[idx], self.channels[idx + 1],
strides, norm_cfg, act_cfg, bottleneck_type))
# After appending, `self.stages` is a ModuleList including several
# shallow modules and STDCModules.
# (len(self.stages) ==
# self.num_shallow_features + len(self.stage_strides))
if self.with_final_conv:
self.final_conv = ConvModule(
self.channels[-1],
max(1024, self.channels[-1]),
1,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
def _make_stage(self, in_channels, out_channels, strides, norm_cfg,
act_cfg, bottleneck_type):
layers = []
for i, stride in enumerate(strides):
layers.append(
STDCModule(
in_channels if i == 0 else out_channels,
out_channels,
stride,
norm_cfg,
act_cfg,
num_convs=self.num_convs,
fusion_type=bottleneck_type))
return Sequential(*layers)
def forward(self, x):
outs = []
for stage in self.stages:
x = stage(x)
outs.append(x)
if self.with_final_conv:
outs[-1] = self.final_conv(outs[-1])
outs = outs[self.num_shallow_features:]
return tuple(outs)
@MODELS.register_module()
class STDCContextPathNet(BaseModule):
"""STDCNet with Context Path. The `outs` below is a list of three feature
maps from deep to shallow, whose height and width is from small to big,
respectively. The biggest feature map of `outs` is outputted for
`STDCHead`, where Detail Loss would be calculated by Detail Ground-truth.
The other two feature maps are used for Attention Refinement Module,
respectively. Besides, the biggest feature map of `outs` and the last
output of Attention Refinement Module are concatenated for Feature Fusion
Module. Then, this fusion feature map `feat_fuse` would be outputted for
`decode_head`. More details please refer to Figure 4 of original paper.
Args:
backbone_cfg (dict): Config dict for stdc backbone.
last_in_channels (tuple(int)), The number of channels of last
two feature maps from stdc backbone. Default: (1024, 512).
out_channels (int): The channels of output feature maps.
Default: 128.
ffm_cfg (dict): Config dict for Feature Fusion Module. Default:
`dict(in_channels=512, out_channels=256, scale_factor=4)`.
upsample_mode (str): Algorithm used for upsampling:
``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
``'trilinear'``. Default: ``'nearest'``.
align_corners (str): align_corners argument of F.interpolate. It
must be `None` if upsample_mode is ``'nearest'``. Default: None.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Return:
outputs (tuple): The tuple of list of output feature map for
auxiliary heads and decoder head.
"""
def __init__(self,
backbone_cfg,
last_in_channels=(1024, 512),
out_channels=128,
ffm_cfg=dict(
in_channels=512, out_channels=256, scale_factor=4),
upsample_mode='nearest',
align_corners=None,
norm_cfg=dict(type='BN'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.backbone = MODELS.build(backbone_cfg)
self.arms = ModuleList()
self.convs = ModuleList()
for channels in last_in_channels:
self.arms.append(AttentionRefinementModule(channels, out_channels))
self.convs.append(
ConvModule(
out_channels,
out_channels,
3,
padding=1,
norm_cfg=norm_cfg))
self.conv_avg = ConvModule(
last_in_channels[0], out_channels, 1, norm_cfg=norm_cfg)
self.ffm = FeatureFusionModule(**ffm_cfg)
self.upsample_mode = upsample_mode
self.align_corners = align_corners
def forward(self, x):
outs = list(self.backbone(x))
avg = F.adaptive_avg_pool2d(outs[-1], 1)
avg_feat = self.conv_avg(avg)
feature_up = resize(
avg_feat,
size=outs[-1].shape[2:],
mode=self.upsample_mode,
align_corners=self.align_corners)
arms_out = []
for i in range(len(self.arms)):
x_arm = self.arms[i](outs[len(outs) - 1 - i]) + feature_up
feature_up = resize(
x_arm,
size=outs[len(outs) - 1 - i - 1].shape[2:],
mode=self.upsample_mode,
align_corners=self.align_corners)
feature_up = self.convs[i](feature_up)
arms_out.append(feature_up)
feat_fuse = self.ffm(outs[0], arms_out[1])
# The `outputs` has four feature maps.
# `outs[0]` is outputted for `STDCHead` auxiliary head.
# Two feature maps of `arms_out` are outputted for auxiliary head.
# `feat_fuse` is outputted for decoder head.
outputs = [outs[0]] + list(arms_out) + [feat_fuse]
return tuple(outputs)

View File

@@ -0,0 +1,757 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from collections import OrderedDict
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, build_dropout
from mmengine.logging import print_log
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import (constant_init, trunc_normal_,
trunc_normal_init)
from mmengine.runner import CheckpointLoader
from mmengine.utils import to_2tuple
from mmseg.registry import MODELS
from ..utils.embed import PatchEmbed, PatchMerging
class WindowMSA(BaseModule):
"""Window based multi-head self-attention (W-MSA) module with relative
position bias.
Args:
embed_dims (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (tuple[int]): The height and width of the window.
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Default: True.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
attn_drop_rate (float, optional): Dropout ratio of attention weight.
Default: 0.0
proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
init_cfg (dict | None, optional): The Config for initialization.
Default: None.
"""
def __init__(self,
embed_dims,
num_heads,
window_size,
qkv_bias=True,
qk_scale=None,
attn_drop_rate=0.,
proj_drop_rate=0.,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.embed_dims = embed_dims
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_embed_dims = embed_dims // num_heads
self.scale = qk_scale or head_embed_dims**-0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# About 2x faster than original impl
Wh, Ww = self.window_size
rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
rel_position_index = rel_index_coords + rel_index_coords.T
rel_position_index = rel_position_index.flip(1).contiguous()
self.register_buffer('relative_position_index', rel_position_index)
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop_rate)
self.proj = nn.Linear(embed_dims, embed_dims)
self.proj_drop = nn.Dropout(proj_drop_rate)
self.softmax = nn.Softmax(dim=-1)
def init_weights(self):
trunc_normal_(self.relative_position_bias_table, std=0.02)
def forward(self, x, mask=None):
"""
Args:
x (tensor): input features with shape of (num_windows*B, N, C)
mask (tensor | None, Optional): mask with shape of (num_windows,
Wh*Ww, Wh*Ww), value should be between (-inf, 0].
"""
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
# make torchscript happy (cannot use tensor as tuple)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1],
-1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B // nW, nW, self.num_heads, N,
N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
@staticmethod
def double_step_seq(step1, len1, step2, len2):
seq1 = torch.arange(0, step1 * len1, step1)
seq2 = torch.arange(0, step2 * len2, step2)
return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
class ShiftWindowMSA(BaseModule):
"""Shifted Window Multihead Self-Attention Module.
Args:
embed_dims (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): The height and width of the window.
shift_size (int, optional): The shift step of each window towards
right-bottom. If zero, act as regular window-msa. Defaults to 0.
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Default: True
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Defaults: None.
attn_drop_rate (float, optional): Dropout ratio of attention weight.
Defaults: 0.
proj_drop_rate (float, optional): Dropout ratio of output.
Defaults: 0.
dropout_layer (dict, optional): The dropout_layer used before output.
Defaults: dict(type='DropPath', drop_prob=0.).
init_cfg (dict, optional): The extra config for initialization.
Default: None.
"""
def __init__(self,
embed_dims,
num_heads,
window_size,
shift_size=0,
qkv_bias=True,
qk_scale=None,
attn_drop_rate=0,
proj_drop_rate=0,
dropout_layer=dict(type='DropPath', drop_prob=0.),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.window_size = window_size
self.shift_size = shift_size
assert 0 <= self.shift_size < self.window_size
self.w_msa = WindowMSA(
embed_dims=embed_dims,
num_heads=num_heads,
window_size=to_2tuple(window_size),
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop_rate=attn_drop_rate,
proj_drop_rate=proj_drop_rate,
init_cfg=None)
self.drop = build_dropout(dropout_layer)
def forward(self, query, hw_shape):
B, L, C = query.shape
H, W = hw_shape
assert L == H * W, 'input feature has wrong size'
query = query.view(B, H, W, C)
# pad feature maps to multiples of window size
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))
H_pad, W_pad = query.shape[1], query.shape[2]
# cyclic shift
if self.shift_size > 0:
shifted_query = torch.roll(
query,
shifts=(-self.shift_size, -self.shift_size),
dims=(1, 2))
# calculate attention mask for SW-MSA
img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device)
h_slices = (slice(0, -self.window_size),
slice(-self.window_size,
-self.shift_size), slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size,
-self.shift_size), slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
# nW, window_size, window_size, 1
mask_windows = self.window_partition(img_mask)
mask_windows = mask_windows.view(
-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0,
float(-100.0)).masked_fill(
attn_mask == 0, float(0.0))
else:
shifted_query = query
attn_mask = None
# nW*B, window_size, window_size, C
query_windows = self.window_partition(shifted_query)
# nW*B, window_size*window_size, C
query_windows = query_windows.view(-1, self.window_size**2, C)
# W-MSA/SW-MSA (nW*B, window_size*window_size, C)
attn_windows = self.w_msa(query_windows, mask=attn_mask)
# merge windows
attn_windows = attn_windows.view(-1, self.window_size,
self.window_size, C)
# B H' W' C
shifted_x = self.window_reverse(attn_windows, H_pad, W_pad)
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(
shifted_x,
shifts=(self.shift_size, self.shift_size),
dims=(1, 2))
else:
x = shifted_x
if pad_r > 0 or pad_b:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
x = self.drop(x)
return x
def window_reverse(self, windows, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
window_size = self.window_size
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size,
window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
def window_partition(self, x):
"""
Args:
x: (B, H, W, C)
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
window_size = self.window_size
x = x.view(B, H // window_size, window_size, W // window_size,
window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
windows = windows.view(-1, window_size, window_size, C)
return windows
class SwinBlock(BaseModule):
""""
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
window_size (int, optional): The local window scale. Default: 7.
shift (bool, optional): whether to shift window or not. Default False.
qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
drop_rate (float, optional): Dropout rate. Default: 0.
attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
drop_path_rate (float, optional): Stochastic depth rate. Default: 0.
act_cfg (dict, optional): The config dict of activation function.
Default: dict(type='GELU').
norm_cfg (dict, optional): The config dict of normalization.
Default: dict(type='LN').
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
will save some memory while slowing down the training speed.
Default: False.
init_cfg (dict | list | None, optional): The init config.
Default: None.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
window_size=7,
shift=False,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
with_cp=False,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.with_cp = with_cp
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
self.attn = ShiftWindowMSA(
embed_dims=embed_dims,
num_heads=num_heads,
window_size=window_size,
shift_size=window_size // 2 if shift else 0,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop_rate=attn_drop_rate,
proj_drop_rate=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
init_cfg=None)
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
self.ffn = FFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=2,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg,
add_identity=True,
init_cfg=None)
def forward(self, x, hw_shape):
def _inner_forward(x):
identity = x
x = self.norm1(x)
x = self.attn(x, hw_shape)
x = x + identity
identity = x
x = self.norm2(x)
x = self.ffn(x, identity=identity)
return x
if self.with_cp and x.requires_grad:
x = cp.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x
class SwinBlockSequence(BaseModule):
"""Implements one stage in Swin Transformer.
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
depth (int): The number of blocks in this stage.
window_size (int, optional): The local window scale. Default: 7.
qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
drop_rate (float, optional): Dropout rate. Default: 0.
attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
drop_path_rate (float | list[float], optional): Stochastic depth
rate. Default: 0.
downsample (BaseModule | None, optional): The downsample operation
module. Default: None.
act_cfg (dict, optional): The config dict of activation function.
Default: dict(type='GELU').
norm_cfg (dict, optional): The config dict of normalization.
Default: dict(type='LN').
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
will save some memory while slowing down the training speed.
Default: False.
init_cfg (dict | list | None, optional): The init config.
Default: None.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
depth,
window_size=7,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
downsample=None,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
with_cp=False,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
if isinstance(drop_path_rate, list):
drop_path_rates = drop_path_rate
assert len(drop_path_rates) == depth
else:
drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)]
self.blocks = ModuleList()
for i in range(depth):
block = SwinBlock(
embed_dims=embed_dims,
num_heads=num_heads,
feedforward_channels=feedforward_channels,
window_size=window_size,
shift=False if i % 2 == 0 else True,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rates[i],
act_cfg=act_cfg,
norm_cfg=norm_cfg,
with_cp=with_cp,
init_cfg=None)
self.blocks.append(block)
self.downsample = downsample
def forward(self, x, hw_shape):
for block in self.blocks:
x = block(x, hw_shape)
if self.downsample:
x_down, down_hw_shape = self.downsample(x, hw_shape)
return x_down, down_hw_shape, x, hw_shape
else:
return x, hw_shape, x, hw_shape
@MODELS.register_module()
class SwinTransformer(BaseModule):
"""Swin Transformer backbone.
This backbone is the implementation of `Swin Transformer:
Hierarchical Vision Transformer using Shifted
Windows <https://arxiv.org/abs/2103.14030>`_.
Inspiration from https://github.com/microsoft/Swin-Transformer.
Args:
pretrain_img_size (int | tuple[int]): The size of input image when
pretrain. Defaults: 224.
in_channels (int): The num of input channels.
Defaults: 3.
embed_dims (int): The feature dimension. Default: 96.
patch_size (int | tuple[int]): Patch size. Default: 4.
window_size (int): Window size. Default: 7.
mlp_ratio (int | float): Ratio of mlp hidden dim to embedding dim.
Default: 4.
depths (tuple[int]): Depths of each Swin Transformer stage.
Default: (2, 2, 6, 2).
num_heads (tuple[int]): Parallel attention heads of each Swin
Transformer stage. Default: (3, 6, 12, 24).
strides (tuple[int]): The patch merging or patch embedding stride of
each Swin Transformer stage. (In swin, we set kernel size equal to
stride.) Default: (4, 2, 2, 2).
out_indices (tuple[int]): Output from which stages.
Default: (0, 1, 2, 3).
qkv_bias (bool, optional): If True, add a learnable bias to query, key,
value. Default: True
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
patch_norm (bool): If add a norm layer for patch embed and patch
merging. Default: True.
drop_rate (float): Dropout rate. Defaults: 0.
attn_drop_rate (float): Attention dropout rate. Default: 0.
drop_path_rate (float): Stochastic depth rate. Defaults: 0.1.
use_abs_pos_embed (bool): If True, add absolute position embedding to
the patch embedding. Defaults: False.
act_cfg (dict): Config dict for activation layer.
Default: dict(type='LN').
norm_cfg (dict): Config dict for normalization layer at
output of backone. Defaults: dict(type='LN').
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
will save some memory while slowing down the training speed.
Default: False.
pretrained (str, optional): model pretrained path. Default: None.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
pretrain_img_size=224,
in_channels=3,
embed_dims=96,
patch_size=4,
window_size=7,
mlp_ratio=4,
depths=(2, 2, 6, 2),
num_heads=(3, 6, 12, 24),
strides=(4, 2, 2, 2),
out_indices=(0, 1, 2, 3),
qkv_bias=True,
qk_scale=None,
patch_norm=True,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.1,
use_abs_pos_embed=False,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
with_cp=False,
pretrained=None,
frozen_stages=-1,
init_cfg=None):
self.frozen_stages = frozen_stages
if isinstance(pretrain_img_size, int):
pretrain_img_size = to_2tuple(pretrain_img_size)
elif isinstance(pretrain_img_size, tuple):
if len(pretrain_img_size) == 1:
pretrain_img_size = to_2tuple(pretrain_img_size[0])
assert len(pretrain_img_size) == 2, \
f'The size of image should have length 1 or 2, ' \
f'but got {len(pretrain_img_size)}'
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be specified at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
init_cfg = init_cfg
else:
raise TypeError('pretrained must be a str or None')
super().__init__(init_cfg=init_cfg)
num_layers = len(depths)
self.out_indices = out_indices
self.use_abs_pos_embed = use_abs_pos_embed
assert strides[0] == patch_size, 'Use non-overlapping patch embed.'
self.patch_embed = PatchEmbed(
in_channels=in_channels,
embed_dims=embed_dims,
conv_type='Conv2d',
kernel_size=patch_size,
stride=strides[0],
padding='corner',
norm_cfg=norm_cfg if patch_norm else None,
init_cfg=None)
if self.use_abs_pos_embed:
patch_row = pretrain_img_size[0] // patch_size
patch_col = pretrain_img_size[1] // patch_size
num_patches = patch_row * patch_col
self.absolute_pos_embed = nn.Parameter(
torch.zeros((1, num_patches, embed_dims)))
self.drop_after_pos = nn.Dropout(p=drop_rate)
# set stochastic depth decay rule
total_depth = sum(depths)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
]
self.stages = ModuleList()
in_channels = embed_dims
for i in range(num_layers):
if i < num_layers - 1:
downsample = PatchMerging(
in_channels=in_channels,
out_channels=2 * in_channels,
stride=strides[i + 1],
norm_cfg=norm_cfg if patch_norm else None,
init_cfg=None)
else:
downsample = None
stage = SwinBlockSequence(
embed_dims=in_channels,
num_heads=num_heads[i],
feedforward_channels=int(mlp_ratio * in_channels),
depth=depths[i],
window_size=window_size,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=dpr[sum(depths[:i]):sum(depths[:i + 1])],
downsample=downsample,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
with_cp=with_cp,
init_cfg=None)
self.stages.append(stage)
if downsample:
in_channels = downsample.out_channels
self.num_features = [int(embed_dims * 2**i) for i in range(num_layers)]
# Add a norm layer for each output
for i in out_indices:
layer = build_norm_layer(norm_cfg, self.num_features[i])[1]
layer_name = f'norm{i}'
self.add_module(layer_name, layer)
def train(self, mode=True):
"""Convert the model into training mode while keep layers freezed."""
super().train(mode)
self._freeze_stages()
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
if self.use_abs_pos_embed:
self.absolute_pos_embed.requires_grad = False
self.drop_after_pos.eval()
for i in range(1, self.frozen_stages + 1):
if (i - 1) in self.out_indices:
norm_layer = getattr(self, f'norm{i-1}')
norm_layer.eval()
for param in norm_layer.parameters():
param.requires_grad = False
m = self.stages[i - 1]
m.eval()
for param in m.parameters():
param.requires_grad = False
def init_weights(self):
if self.init_cfg is None:
print_log(f'No pre-trained weights for '
f'{self.__class__.__name__}, '
f'training start from scratch')
if self.use_abs_pos_embed:
trunc_normal_(self.absolute_pos_embed, std=0.02)
for m in self.modules():
if isinstance(m, nn.Linear):
trunc_normal_init(m, std=.02, bias=0.)
elif isinstance(m, nn.LayerNorm):
constant_init(m, val=1.0, bias=0.)
else:
assert 'checkpoint' in self.init_cfg, f'Only support ' \
f'specify `Pretrained` in ' \
f'`init_cfg` in ' \
f'{self.__class__.__name__} '
ckpt = CheckpointLoader.load_checkpoint(
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
if 'state_dict' in ckpt:
_state_dict = ckpt['state_dict']
elif 'model' in ckpt:
_state_dict = ckpt['model']
else:
_state_dict = ckpt
state_dict = OrderedDict()
for k, v in _state_dict.items():
if k.startswith('backbone.'):
state_dict[k[9:]] = v
else:
state_dict[k] = v
# strip prefix of state_dict
if list(state_dict.keys())[0].startswith('module.'):
state_dict = {k[7:]: v for k, v in state_dict.items()}
# reshape absolute position embedding
if state_dict.get('absolute_pos_embed') is not None:
absolute_pos_embed = state_dict['absolute_pos_embed']
N1, L, C1 = absolute_pos_embed.size()
N2, C2, H, W = self.absolute_pos_embed.size()
if N1 != N2 or C1 != C2 or L != H * W:
print_log('Error in loading absolute_pos_embed, pass')
else:
state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
N2, H, W, C2).permute(0, 3, 1, 2).contiguous()
# interpolate position bias table if needed
relative_position_bias_table_keys = [
k for k in state_dict.keys()
if 'relative_position_bias_table' in k
]
for table_key in relative_position_bias_table_keys:
table_pretrained = state_dict[table_key]
if table_key in self.state_dict():
table_current = self.state_dict()[table_key]
L1, nH1 = table_pretrained.size()
L2, nH2 = table_current.size()
if nH1 != nH2:
print_log(f'Error in loading {table_key}, pass')
elif L1 != L2:
S1 = int(L1**0.5)
S2 = int(L2**0.5)
table_pretrained_resized = F.interpolate(
table_pretrained.permute(1, 0).reshape(
1, nH1, S1, S1),
size=(S2, S2),
mode='bicubic')
state_dict[table_key] = table_pretrained_resized.view(
nH2, L2).permute(1, 0).contiguous()
# load state_dict
self.load_state_dict(state_dict, strict=False)
def forward(self, x):
x, hw_shape = self.patch_embed(x)
if self.use_abs_pos_embed:
x = x + self.absolute_pos_embed
x = self.drop_after_pos(x)
outs = []
for i, stage in enumerate(self.stages):
x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
out = norm_layer(out)
out = out.view(-1, *out_hw_shape,
self.num_features[i]).permute(0, 3, 1,
2).contiguous()
outs.append(out)
return outs

View File

@@ -0,0 +1,63 @@
# Copyright (c) OpenMMLab. All rights reserved.
try:
import timm
except ImportError:
timm = None
from mmengine.model import BaseModule
from mmengine.registry import MODELS as MMENGINE_MODELS
from mmseg.registry import MODELS
@MODELS.register_module()
class TIMMBackbone(BaseModule):
"""Wrapper to use backbones from timm library. More details can be found in
`timm <https://github.com/rwightman/pytorch-image-models>`_ .
Args:
model_name (str): Name of timm model to instantiate.
pretrained (bool): Load pretrained weights if True.
checkpoint_path (str): Path of checkpoint to load after
model is initialized.
in_channels (int): Number of input image channels. Default: 3.
init_cfg (dict, optional): Initialization config dict
**kwargs: Other timm & model specific arguments.
"""
def __init__(
self,
model_name,
features_only=True,
pretrained=True,
checkpoint_path='',
in_channels=3,
init_cfg=None,
**kwargs,
):
if timm is None:
raise RuntimeError('timm is not installed')
super().__init__(init_cfg)
if 'norm_layer' in kwargs:
kwargs['norm_layer'] = MMENGINE_MODELS.get(kwargs['norm_layer'])
self.timm_model = timm.create_model(
model_name=model_name,
features_only=features_only,
pretrained=pretrained,
in_chans=in_channels,
checkpoint_path=checkpoint_path,
**kwargs,
)
# Make unused parameters None
self.timm_model.global_pool = None
self.timm_model.fc = None
self.timm_model.classifier = None
# Hack to use pretrained weights from timm
if pretrained or checkpoint_path:
self._is_init = True
def forward(self, x):
features = self.timm_model(x)
return features

View File

@@ -0,0 +1,588 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import FFN
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import (constant_init, normal_init,
trunc_normal_init)
from torch.nn.modules.batchnorm import _BatchNorm
from mmseg.models.backbones.mit import EfficientMultiheadAttention
from mmseg.registry import MODELS
from ..utils.embed import PatchEmbed
class GlobalSubsampledAttention(EfficientMultiheadAttention):
"""Global Sub-sampled Attention (Spatial Reduction Attention)
This module is modified from EfficientMultiheadAttention
which is a module from mmseg.models.backbones.mit.py.
Specifically, there is no difference between
`GlobalSubsampledAttention` and `EfficientMultiheadAttention`,
`GlobalSubsampledAttention` is built as a brand new class
because it is renamed as `Global sub-sampled attention (GSA)`
in paper.
Args:
embed_dims (int): The embedding dimension.
num_heads (int): Parallel attention heads.
attn_drop (float): A Dropout layer on attn_output_weights.
Default: 0.0.
proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
Default: 0.0.
dropout_layer (obj:`ConfigDict`): The dropout_layer used
when adding the shortcut. Default: None.
batch_first (bool): Key, Query and Value are shape of
(batch, n, embed_dims)
or (n, batch, embed_dims). Default: False.
qkv_bias (bool): enable bias for qkv if True. Default: True.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
sr_ratio (int): The ratio of spatial reduction of GSA of PCPVT.
Default: 1.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads,
attn_drop=0.,
proj_drop=0.,
dropout_layer=None,
batch_first=True,
qkv_bias=True,
norm_cfg=dict(type='LN'),
sr_ratio=1,
init_cfg=None):
super().__init__(
embed_dims,
num_heads,
attn_drop=attn_drop,
proj_drop=proj_drop,
dropout_layer=dropout_layer,
batch_first=batch_first,
qkv_bias=qkv_bias,
norm_cfg=norm_cfg,
sr_ratio=sr_ratio,
init_cfg=init_cfg)
class GSAEncoderLayer(BaseModule):
"""Implements one encoder layer with GSA.
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Default: 0.0.
attn_drop_rate (float): The drop out rate for attention layer.
Default: 0.0.
drop_path_rate (float): Stochastic depth rate. Default 0.0.
num_fcs (int): The number of fully-connected layers for FFNs.
Default: 2.
qkv_bias (bool): Enable bias for qkv if True. Default: True
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
sr_ratio (float): Kernel_size of conv in Attention modules. Default: 1.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
num_fcs=2,
qkv_bias=True,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
sr_ratio=1.,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1]
self.attn = GlobalSubsampledAttention(
embed_dims=embed_dims,
num_heads=num_heads,
attn_drop=attn_drop_rate,
proj_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
qkv_bias=qkv_bias,
norm_cfg=norm_cfg,
sr_ratio=sr_ratio)
self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1]
self.ffn = FFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=num_fcs,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg,
add_identity=False)
self.drop_path = build_dropout(
dict(type='DropPath', drop_prob=drop_path_rate)
) if drop_path_rate > 0. else nn.Identity()
def forward(self, x, hw_shape):
x = x + self.drop_path(self.attn(self.norm1(x), hw_shape, identity=0.))
x = x + self.drop_path(self.ffn(self.norm2(x)))
return x
class LocallyGroupedSelfAttention(BaseModule):
"""Locally-grouped Self Attention (LSA) module.
Args:
embed_dims (int): Number of input channels.
num_heads (int): Number of attention heads. Default: 8
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Default: False.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
attn_drop_rate (float, optional): Dropout ratio of attention weight.
Default: 0.0
proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
window_size(int): Window size of LSA. Default: 1.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop_rate=0.,
proj_drop_rate=0.,
window_size=1,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
assert embed_dims % num_heads == 0, f'dim {embed_dims} should be ' \
f'divided by num_heads ' \
f'{num_heads}.'
self.embed_dims = embed_dims
self.num_heads = num_heads
head_dim = embed_dims // num_heads
self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop_rate)
self.proj = nn.Linear(embed_dims, embed_dims)
self.proj_drop = nn.Dropout(proj_drop_rate)
self.window_size = window_size
def forward(self, x, hw_shape):
b, n, c = x.shape
h, w = hw_shape
x = x.view(b, h, w, c)
# pad feature maps to multiples of Local-groups
pad_l = pad_t = 0
pad_r = (self.window_size - w % self.window_size) % self.window_size
pad_b = (self.window_size - h % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
# calculate attention mask for LSA
Hp, Wp = x.shape[1:-1]
_h, _w = Hp // self.window_size, Wp // self.window_size
mask = torch.zeros((1, Hp, Wp), device=x.device)
mask[:, -pad_b:, :].fill_(1)
mask[:, :, -pad_r:].fill_(1)
# [B, _h, _w, window_size, window_size, C]
x = x.reshape(b, _h, self.window_size, _w, self.window_size,
c).transpose(2, 3)
mask = mask.reshape(1, _h, self.window_size, _w,
self.window_size).transpose(2, 3).reshape(
1, _h * _w,
self.window_size * self.window_size)
# [1, _h*_w, window_size*window_size, window_size*window_size]
attn_mask = mask.unsqueeze(2) - mask.unsqueeze(3)
attn_mask = attn_mask.masked_fill(attn_mask != 0,
float(-1000.0)).masked_fill(
attn_mask == 0, float(0.0))
# [3, B, _w*_h, nhead, window_size*window_size, dim]
qkv = self.qkv(x).reshape(b, _h * _w,
self.window_size * self.window_size, 3,
self.num_heads, c // self.num_heads).permute(
3, 0, 1, 4, 2, 5)
q, k, v = qkv[0], qkv[1], qkv[2]
# [B, _h*_w, n_head, window_size*window_size, window_size*window_size]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn + attn_mask.unsqueeze(2)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
attn = (attn @ v).transpose(2, 3).reshape(b, _h, _w, self.window_size,
self.window_size, c)
x = attn.transpose(2, 3).reshape(b, _h * self.window_size,
_w * self.window_size, c)
if pad_r > 0 or pad_b > 0:
x = x[:, :h, :w, :].contiguous()
x = x.reshape(b, n, c)
x = self.proj(x)
x = self.proj_drop(x)
return x
class LSAEncoderLayer(BaseModule):
"""Implements one encoder layer in Twins-SVT.
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Default: 0.0.
attn_drop_rate (float, optional): Dropout ratio of attention weight.
Default: 0.0
drop_path_rate (float): Stochastic depth rate. Default 0.0.
num_fcs (int): The number of fully-connected layers for FFNs.
Default: 2.
qkv_bias (bool): Enable bias for qkv if True. Default: True
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
window_size (int): Window size of LSA. Default: 1.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
num_fcs=2,
qkv_bias=True,
qk_scale=None,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
window_size=1,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1]
self.attn = LocallyGroupedSelfAttention(embed_dims, num_heads,
qkv_bias, qk_scale,
attn_drop_rate, drop_rate,
window_size)
self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1]
self.ffn = FFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=num_fcs,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg,
add_identity=False)
self.drop_path = build_dropout(
dict(type='DropPath', drop_prob=drop_path_rate)
) if drop_path_rate > 0. else nn.Identity()
def forward(self, x, hw_shape):
x = x + self.drop_path(self.attn(self.norm1(x), hw_shape))
x = x + self.drop_path(self.ffn(self.norm2(x)))
return x
class ConditionalPositionEncoding(BaseModule):
"""The Conditional Position Encoding (CPE) module.
The CPE is the implementation of 'Conditional Positional Encodings
for Vision Transformers <https://arxiv.org/abs/2102.10882>'_.
Args:
in_channels (int): Number of input channels.
embed_dims (int): The feature dimension. Default: 768.
stride (int): Stride of conv layer. Default: 1.
"""
def __init__(self, in_channels, embed_dims=768, stride=1, init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.proj = nn.Conv2d(
in_channels,
embed_dims,
kernel_size=3,
stride=stride,
padding=1,
bias=True,
groups=embed_dims)
self.stride = stride
def forward(self, x, hw_shape):
b, n, c = x.shape
h, w = hw_shape
feat_token = x
cnn_feat = feat_token.transpose(1, 2).view(b, c, h, w)
if self.stride == 1:
x = self.proj(cnn_feat) + cnn_feat
else:
x = self.proj(cnn_feat)
x = x.flatten(2).transpose(1, 2)
return x
@MODELS.register_module()
class PCPVT(BaseModule):
"""The backbone of Twins-PCPVT.
This backbone is the implementation of `Twins: Revisiting the Design
of Spatial Attention in Vision Transformers
<https://arxiv.org/abs/1512.03385>`_.
Args:
in_channels (int): Number of input channels. Default: 3.
embed_dims (list): Embedding dimension. Default: [64, 128, 256, 512].
patch_sizes (list): The patch sizes. Default: [4, 2, 2, 2].
strides (list): The strides. Default: [4, 2, 2, 2].
num_heads (int): Number of attention heads. Default: [1, 2, 4, 8].
mlp_ratios (int): Ratio of mlp hidden dim to embedding dim.
Default: [4, 4, 4, 4].
out_indices (tuple[int]): Output from which stages.
Default: (0, 1, 2, 3).
qkv_bias (bool): Enable bias for qkv if True. Default: False.
drop_rate (float): Probability of an element to be zeroed.
Default 0.
attn_drop_rate (float): The drop out rate for attention layer.
Default 0.0
drop_path_rate (float): Stochastic depth rate. Default 0.0
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN')
depths (list): Depths of each stage. Default [3, 4, 6, 3]
sr_ratios (list): Kernel_size of conv in each Attn module in
Transformer encoder layer. Default: [8, 4, 2, 1].
norm_after_stagebool): Add extra norm. Default False.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
in_channels=3,
embed_dims=[64, 128, 256, 512],
patch_sizes=[4, 2, 2, 2],
strides=[4, 2, 2, 2],
num_heads=[1, 2, 4, 8],
mlp_ratios=[4, 4, 4, 4],
out_indices=(0, 1, 2, 3),
qkv_bias=False,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_cfg=dict(type='LN'),
depths=[3, 4, 6, 3],
sr_ratios=[8, 4, 2, 1],
norm_after_stage=False,
pretrained=None,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be set at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is not None:
raise TypeError('pretrained must be a str or None')
self.depths = depths
# patch_embed
self.patch_embeds = ModuleList()
self.position_encoding_drops = ModuleList()
self.layers = ModuleList()
for i in range(len(depths)):
self.patch_embeds.append(
PatchEmbed(
in_channels=in_channels if i == 0 else embed_dims[i - 1],
embed_dims=embed_dims[i],
conv_type='Conv2d',
kernel_size=patch_sizes[i],
stride=strides[i],
padding='corner',
norm_cfg=norm_cfg))
self.position_encoding_drops.append(nn.Dropout(p=drop_rate))
self.position_encodings = ModuleList([
ConditionalPositionEncoding(embed_dim, embed_dim)
for embed_dim in embed_dims
])
# transformer encoder
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
] # stochastic depth decay rule
cur = 0
for k in range(len(depths)):
_block = ModuleList([
GSAEncoderLayer(
embed_dims=embed_dims[k],
num_heads=num_heads[k],
feedforward_channels=mlp_ratios[k] * embed_dims[k],
attn_drop_rate=attn_drop_rate,
drop_rate=drop_rate,
drop_path_rate=dpr[cur + i],
num_fcs=2,
qkv_bias=qkv_bias,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
sr_ratio=sr_ratios[k]) for i in range(depths[k])
])
self.layers.append(_block)
cur += depths[k]
self.norm_name, norm = build_norm_layer(
norm_cfg, embed_dims[-1], postfix=1)
self.out_indices = out_indices
self.norm_after_stage = norm_after_stage
if self.norm_after_stage:
self.norm_list = ModuleList()
for dim in embed_dims:
self.norm_list.append(build_norm_layer(norm_cfg, dim)[1])
def init_weights(self):
if self.init_cfg is not None:
super().init_weights()
else:
for m in self.modules():
if isinstance(m, nn.Linear):
trunc_normal_init(m, std=.02, bias=0.)
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
constant_init(m, val=1.0, bias=0.)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[
1] * m.out_channels
fan_out //= m.groups
normal_init(
m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
def forward(self, x):
outputs = list()
b = x.shape[0]
for i in range(len(self.depths)):
x, hw_shape = self.patch_embeds[i](x)
h, w = hw_shape
x = self.position_encoding_drops[i](x)
for j, blk in enumerate(self.layers[i]):
x = blk(x, hw_shape)
if j == 0:
x = self.position_encodings[i](x, hw_shape)
if self.norm_after_stage:
x = self.norm_list[i](x)
x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
if i in self.out_indices:
outputs.append(x)
return tuple(outputs)
@MODELS.register_module()
class SVT(PCPVT):
"""The backbone of Twins-SVT.
This backbone is the implementation of `Twins: Revisiting the Design
of Spatial Attention in Vision Transformers
<https://arxiv.org/abs/1512.03385>`_.
Args:
in_channels (int): Number of input channels. Default: 3.
embed_dims (list): Embedding dimension. Default: [64, 128, 256, 512].
patch_sizes (list): The patch sizes. Default: [4, 2, 2, 2].
strides (list): The strides. Default: [4, 2, 2, 2].
num_heads (int): Number of attention heads. Default: [1, 2, 4].
mlp_ratios (int): Ratio of mlp hidden dim to embedding dim.
Default: [4, 4, 4].
out_indices (tuple[int]): Output from which stages.
Default: (0, 1, 2, 3).
qkv_bias (bool): Enable bias for qkv if True. Default: False.
drop_rate (float): Dropout rate. Default 0.
attn_drop_rate (float): Dropout ratio of attention weight.
Default 0.0
drop_path_rate (float): Stochastic depth rate. Default 0.2.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN')
depths (list): Depths of each stage. Default [4, 4, 4].
sr_ratios (list): Kernel_size of conv in each Attn module in
Transformer encoder layer. Default: [4, 2, 1].
windiow_sizes (list): Window size of LSA. Default: [7, 7, 7],
input_features_slicebool): Input features need slice. Default: False.
norm_after_stagebool): Add extra norm. Default False.
strides (list): Strides in patch-Embedding modules. Default: (2, 2, 2)
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
in_channels=3,
embed_dims=[64, 128, 256],
patch_sizes=[4, 2, 2, 2],
strides=[4, 2, 2, 2],
num_heads=[1, 2, 4],
mlp_ratios=[4, 4, 4],
out_indices=(0, 1, 2, 3),
qkv_bias=False,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.2,
norm_cfg=dict(type='LN'),
depths=[4, 4, 4],
sr_ratios=[4, 2, 1],
windiow_sizes=[7, 7, 7],
norm_after_stage=True,
pretrained=None,
init_cfg=None):
super().__init__(in_channels, embed_dims, patch_sizes, strides,
num_heads, mlp_ratios, out_indices, qkv_bias,
drop_rate, attn_drop_rate, drop_path_rate, norm_cfg,
depths, sr_ratios, norm_after_stage, pretrained,
init_cfg)
# transformer encoder
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
] # stochastic depth decay rule
for k in range(len(depths)):
for i in range(depths[k]):
if i % 2 == 0:
self.layers[k][i] = \
LSAEncoderLayer(
embed_dims=embed_dims[k],
num_heads=num_heads[k],
feedforward_channels=mlp_ratios[k] * embed_dims[k],
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=dpr[sum(depths[:k])+i],
qkv_bias=qkv_bias,
window_size=windiow_sizes[k])

View File

@@ -0,0 +1,436 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from mmengine.model import BaseModule
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmseg.registry import MODELS
from ..utils import UpConvBlock, Upsample
class BasicConvBlock(nn.Module):
"""Basic convolutional block for UNet.
This module consists of several plain convolutional layers.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
num_convs (int): Number of convolutional layers. Default: 2.
stride (int): Whether use stride convolution to downsample
the input feature map. If stride=2, it only uses stride convolution
in the first convolutional layer to downsample the input feature
map. Options are 1 or 2. Default: 1.
dilation (int): Whether use dilated convolution to expand the
receptive field. Set dilation rate of each convolutional layer and
the dilation rate of the first convolutional layer is always 1.
Default: 1.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
conv_cfg (dict | None): Config dict for convolution layer.
Default: None.
norm_cfg (dict | None): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict | None): Config dict for activation layer in ConvModule.
Default: dict(type='ReLU').
dcn (bool): Use deformable convolution in convolutional layer or not.
Default: None.
plugins (dict): plugins for convolutional layers. Default: None.
"""
def __init__(self,
in_channels,
out_channels,
num_convs=2,
stride=1,
dilation=1,
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
dcn=None,
plugins=None):
super().__init__()
assert dcn is None, 'Not implemented yet.'
assert plugins is None, 'Not implemented yet.'
self.with_cp = with_cp
convs = []
for i in range(num_convs):
convs.append(
ConvModule(
in_channels=in_channels if i == 0 else out_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride if i == 0 else 1,
dilation=1 if i == 0 else dilation,
padding=1 if i == 0 else dilation,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
self.convs = nn.Sequential(*convs)
def forward(self, x):
"""Forward function."""
if self.with_cp and x.requires_grad:
out = cp.checkpoint(self.convs, x)
else:
out = self.convs(x)
return out
@MODELS.register_module()
class DeconvModule(nn.Module):
"""Deconvolution upsample module in decoder for UNet (2X upsample).
This module uses deconvolution to upsample feature map in the decoder
of UNet.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
norm_cfg (dict | None): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict | None): Config dict for activation layer in ConvModule.
Default: dict(type='ReLU').
kernel_size (int): Kernel size of the convolutional layer. Default: 4.
"""
def __init__(self,
in_channels,
out_channels,
with_cp=False,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
*,
kernel_size=4,
scale_factor=2):
super().__init__()
assert (kernel_size - scale_factor >= 0) and\
(kernel_size - scale_factor) % 2 == 0,\
f'kernel_size should be greater than or equal to scale_factor '\
f'and (kernel_size - scale_factor) should be even numbers, '\
f'while the kernel size is {kernel_size} and scale_factor is '\
f'{scale_factor}.'
stride = scale_factor
padding = (kernel_size - scale_factor) // 2
self.with_cp = with_cp
deconv = nn.ConvTranspose2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding)
norm_name, norm = build_norm_layer(norm_cfg, out_channels)
activate = build_activation_layer(act_cfg)
self.deconv_upsamping = nn.Sequential(deconv, norm, activate)
def forward(self, x):
"""Forward function."""
if self.with_cp and x.requires_grad:
out = cp.checkpoint(self.deconv_upsamping, x)
else:
out = self.deconv_upsamping(x)
return out
@MODELS.register_module()
class InterpConv(nn.Module):
"""Interpolation upsample module in decoder for UNet.
This module uses interpolation to upsample feature map in the decoder
of UNet. It consists of one interpolation upsample layer and one
convolutional layer. It can be one interpolation upsample layer followed
by one convolutional layer (conv_first=False) or one convolutional layer
followed by one interpolation upsample layer (conv_first=True).
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
norm_cfg (dict | None): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict | None): Config dict for activation layer in ConvModule.
Default: dict(type='ReLU').
conv_cfg (dict | None): Config dict for convolution layer.
Default: None.
conv_first (bool): Whether convolutional layer or interpolation
upsample layer first. Default: False. It means interpolation
upsample layer followed by one convolutional layer.
kernel_size (int): Kernel size of the convolutional layer. Default: 1.
stride (int): Stride of the convolutional layer. Default: 1.
padding (int): Padding of the convolutional layer. Default: 1.
upsample_cfg (dict): Interpolation config of the upsample layer.
Default: dict(
scale_factor=2, mode='bilinear', align_corners=False).
"""
def __init__(self,
in_channels,
out_channels,
with_cp=False,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
*,
conv_cfg=None,
conv_first=False,
kernel_size=1,
stride=1,
padding=0,
upsample_cfg=dict(
scale_factor=2, mode='bilinear', align_corners=False)):
super().__init__()
self.with_cp = with_cp
conv = ConvModule(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
upsample = Upsample(**upsample_cfg)
if conv_first:
self.interp_upsample = nn.Sequential(conv, upsample)
else:
self.interp_upsample = nn.Sequential(upsample, conv)
def forward(self, x):
"""Forward function."""
if self.with_cp and x.requires_grad:
out = cp.checkpoint(self.interp_upsample, x)
else:
out = self.interp_upsample(x)
return out
@MODELS.register_module()
class UNet(BaseModule):
"""UNet backbone.
This backbone is the implementation of `U-Net: Convolutional Networks
for Biomedical Image Segmentation <https://arxiv.org/abs/1505.04597>`_.
Args:
in_channels (int): Number of input image channels. Default" 3.
base_channels (int): Number of base channels of each stage.
The output channels of the first stage. Default: 64.
num_stages (int): Number of stages in encoder, normally 5. Default: 5.
strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
len(strides) is equal to num_stages. Normally the stride of the
first stage in encoder is 1. If strides[i]=2, it uses stride
convolution to downsample in the correspondence encoder stage.
Default: (1, 1, 1, 1, 1).
enc_num_convs (Sequence[int]): Number of convolutional layers in the
convolution block of the correspondence encoder stage.
Default: (2, 2, 2, 2, 2).
dec_num_convs (Sequence[int]): Number of convolutional layers in the
convolution block of the correspondence decoder stage.
Default: (2, 2, 2, 2).
downsamples (Sequence[int]): Whether use MaxPool to downsample the
feature map after the first stage of encoder
(stages: [1, num_stages)). If the correspondence encoder stage use
stride convolution (strides[i]=2), it will never use MaxPool to
downsample, even downsamples[i-1]=True.
Default: (True, True, True, True).
enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
Default: (1, 1, 1, 1, 1).
dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
Default: (1, 1, 1, 1).
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
conv_cfg (dict | None): Config dict for convolution layer.
Default: None.
norm_cfg (dict | None): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict | None): Config dict for activation layer in ConvModule.
Default: dict(type='ReLU').
upsample_cfg (dict): The upsample config of the upsample module in
decoder. Default: dict(type='InterpConv').
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
dcn (bool): Use deformable convolution in convolutional layer or not.
Default: None.
plugins (dict): plugins for convolutional layers. Default: None.
pretrained (str, optional): model pretrained path. Default: None
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
Notice:
The input image size should be divisible by the whole downsample rate
of the encoder. More detail of the whole downsample rate can be found
in UNet._check_input_divisible.
"""
def __init__(self,
in_channels=3,
base_channels=64,
num_stages=5,
strides=(1, 1, 1, 1, 1),
enc_num_convs=(2, 2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2),
downsamples=(True, True, True, True),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1),
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
upsample_cfg=dict(type='InterpConv'),
norm_eval=False,
dcn=None,
plugins=None,
pretrained=None,
init_cfg=None):
super().__init__(init_cfg)
self.pretrained = pretrained
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]
else:
raise TypeError('pretrained must be a str or None')
assert dcn is None, 'Not implemented yet.'
assert plugins is None, 'Not implemented yet.'
assert len(strides) == num_stages, \
'The length of strides should be equal to num_stages, '\
f'while the strides is {strides}, the length of '\
f'strides is {len(strides)}, and the num_stages is '\
f'{num_stages}.'
assert len(enc_num_convs) == num_stages, \
'The length of enc_num_convs should be equal to num_stages, '\
f'while the enc_num_convs is {enc_num_convs}, the length of '\
f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
f'{num_stages}.'
assert len(dec_num_convs) == (num_stages-1), \
'The length of dec_num_convs should be equal to (num_stages-1), '\
f'while the dec_num_convs is {dec_num_convs}, the length of '\
f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
f'{num_stages}.'
assert len(downsamples) == (num_stages-1), \
'The length of downsamples should be equal to (num_stages-1), '\
f'while the downsamples is {downsamples}, the length of '\
f'downsamples is {len(downsamples)}, and the num_stages is '\
f'{num_stages}.'
assert len(enc_dilations) == num_stages, \
'The length of enc_dilations should be equal to num_stages, '\
f'while the enc_dilations is {enc_dilations}, the length of '\
f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
f'{num_stages}.'
assert len(dec_dilations) == (num_stages-1), \
'The length of dec_dilations should be equal to (num_stages-1), '\
f'while the dec_dilations is {dec_dilations}, the length of '\
f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
f'{num_stages}.'
self.num_stages = num_stages
self.strides = strides
self.downsamples = downsamples
self.norm_eval = norm_eval
self.base_channels = base_channels
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
for i in range(num_stages):
enc_conv_block = []
if i != 0:
if strides[i] == 1 and downsamples[i - 1]:
enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
upsample = (strides[i] != 1 or downsamples[i - 1])
self.decoder.append(
UpConvBlock(
conv_block=BasicConvBlock,
in_channels=base_channels * 2**i,
skip_channels=base_channels * 2**(i - 1),
out_channels=base_channels * 2**(i - 1),
num_convs=dec_num_convs[i - 1],
stride=1,
dilation=dec_dilations[i - 1],
with_cp=with_cp,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
upsample_cfg=upsample_cfg if upsample else None,
dcn=None,
plugins=None))
enc_conv_block.append(
BasicConvBlock(
in_channels=in_channels,
out_channels=base_channels * 2**i,
num_convs=enc_num_convs[i],
stride=strides[i],
dilation=enc_dilations[i],
with_cp=with_cp,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
dcn=None,
plugins=None))
self.encoder.append(nn.Sequential(*enc_conv_block))
in_channels = base_channels * 2**i
def forward(self, x):
self._check_input_divisible(x)
enc_outs = []
for enc in self.encoder:
x = enc(x)
enc_outs.append(x)
dec_outs = [x]
for i in reversed(range(len(self.decoder))):
x = self.decoder[i](enc_outs[i], x)
dec_outs.append(x)
return dec_outs
def train(self, mode=True):
"""Convert the model into training mode while keep normalization layer
freezed."""
super().train(mode)
if mode and self.norm_eval:
for m in self.modules():
# trick: eval have effect on BatchNorm only
if isinstance(m, _BatchNorm):
m.eval()
def _check_input_divisible(self, x):
h, w = x.shape[-2:]
whole_downsample_rate = 1
for i in range(1, self.num_stages):
if self.strides[i] == 2 or self.downsamples[i - 1]:
whole_downsample_rate *= 2
assert (h % whole_downsample_rate == 0) \
and (w % whole_downsample_rate == 0),\
f'The input image size {(h, w)} should be divisible by the whole '\
f'downsample rate {whole_downsample_rate}, when num_stages is '\
f'{self.num_stages}, strides is {self.strides}, and downsamples '\
f'is {self.downsamples}.'

View File

@@ -0,0 +1,501 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmengine.logging import print_log
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import (constant_init, kaiming_init,
trunc_normal_)
from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.utils import _pair as to_2tuple
from mmseg.registry import MODELS
from ..utils import PatchEmbed, resize
class TransformerEncoderLayer(BaseModule):
"""Implements one encoder layer in Vision Transformer.
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Default: 0.0.
attn_drop_rate (float): The drop out rate for attention layer.
Default: 0.0.
drop_path_rate (float): stochastic depth rate. Default 0.0.
num_fcs (int): The number of fully-connected layers for FFNs.
Default: 2.
qkv_bias (bool): enable bias for qkv if True. Default: True
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
batch_first (bool): Key, Query and Value are shape of
(batch, n, embed_dim)
or (n, batch, embed_dim). Default: True.
with_cp (bool): Use checkpoint or not. Using checkpoint will save
some memory while slowing down the training speed. Default: False.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
num_fcs=2,
qkv_bias=True,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
batch_first=True,
attn_cfg=dict(),
ffn_cfg=dict(),
with_cp=False):
super().__init__()
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, embed_dims, postfix=1)
self.add_module(self.norm1_name, norm1)
attn_cfg.update(
dict(
embed_dims=embed_dims,
num_heads=num_heads,
attn_drop=attn_drop_rate,
proj_drop=drop_rate,
batch_first=batch_first,
bias=qkv_bias))
self.build_attn(attn_cfg)
self.norm2_name, norm2 = build_norm_layer(
norm_cfg, embed_dims, postfix=2)
self.add_module(self.norm2_name, norm2)
ffn_cfg.update(
dict(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=num_fcs,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate)
if drop_path_rate > 0 else None,
act_cfg=act_cfg))
self.build_ffn(ffn_cfg)
self.with_cp = with_cp
def build_attn(self, attn_cfg):
self.attn = MultiheadAttention(**attn_cfg)
def build_ffn(self, ffn_cfg):
self.ffn = FFN(**ffn_cfg)
@property
def norm1(self):
return getattr(self, self.norm1_name)
@property
def norm2(self):
return getattr(self, self.norm2_name)
def forward(self, x):
def _inner_forward(x):
x = self.attn(self.norm1(x), identity=x)
x = self.ffn(self.norm2(x), identity=x)
return x
if self.with_cp and x.requires_grad:
x = cp.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x
@MODELS.register_module()
class VisionTransformer(BaseModule):
"""Vision Transformer.
This backbone is the implementation of `An Image is Worth 16x16 Words:
Transformers for Image Recognition at
Scale <https://arxiv.org/abs/2010.11929>`_.
Args:
img_size (int | tuple): Input image size. Default: 224.
patch_size (int): The patch size. Default: 16.
patch_pad (str | int | None): The padding method in patch embedding.
Default: 'corner'.
in_channels (int): Number of input channels. Default: 3.
embed_dims (int): embedding dimension. Default: 768.
num_layers (int): depth of transformer. Default: 12.
num_heads (int): number of attention heads. Default: 12.
mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
Default: 4.
out_origin (bool): Whether to output the original input embedding.
Default: False
out_indices (list | tuple | int): Output from which stages.
Default: -1.
qkv_bias (bool): enable bias for qkv if True. Default: True.
drop_rate (float): Probability of an element to be zeroed.
Default 0.0
attn_drop_rate (float): The drop out rate for attention layer.
Default 0.0
drop_path_rate (float): stochastic depth rate. Default 0.0
with_cls_token (bool): Whether concatenating class token into image
tokens as transformer input. Default: True.
output_cls_token (bool): Whether output the cls_token. If set True,
`with_cls_token` must be True. Default: False.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN')
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
patch_bias (dict): Whether use bias in convolution of PatchEmbed Block.
Default: True.
patch_norm (bool): Whether to add a norm in PatchEmbed Block.
Default: False.
pre_norm (bool): Whether to add a norm before Transformer Layers.
Default: False.
final_norm (bool): Whether to add a additional layer to normalize
final feature map. Default: False.
interpolate_mode (str): Select the interpolate mode for position
embeding vector resize. Default: bicubic.
num_fcs (int): The number of fully-connected layers for FFNs.
Default: 2.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save
some memory while slowing down the training speed. Default: False.
frozen_exclude (List): List of parameters that are not to be frozen.
Default: ["all"], "all" means there are no frozen parameters.
pretrained (str, optional): model pretrained path. Default: None.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(self,
img_size=224,
patch_size=16,
patch_pad='corner',
in_channels=3,
embed_dims=768,
num_layers=12,
num_heads=12,
mlp_ratio=4,
out_origin=False,
out_indices=-1,
qkv_bias=True,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
with_cls_token=True,
output_cls_token=False,
norm_cfg=dict(type='LN'),
act_cfg=dict(type='GELU'),
patch_norm=False,
patch_bias=False,
pre_norm=False,
final_norm=False,
interpolate_mode='bicubic',
num_fcs=2,
norm_eval=False,
with_cp=False,
frozen_exclude=['all'],
pretrained=None,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
if isinstance(img_size, int):
img_size = to_2tuple(img_size)
elif isinstance(img_size, tuple):
if len(img_size) == 1:
img_size = to_2tuple(img_size[0])
assert len(img_size) == 2, \
f'The size of image should have length 1 or 2, ' \
f'but got {len(img_size)}'
if output_cls_token:
assert with_cls_token is True, f'with_cls_token must be True if' \
f'set output_cls_token to True, but got {with_cls_token}'
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be set at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is not None:
raise TypeError('pretrained must be a str or None')
self.img_size = img_size
self.patch_size = patch_size
self.interpolate_mode = interpolate_mode
self.norm_eval = norm_eval
self.with_cp = with_cp
self.pretrained = pretrained
self.out_origin = out_origin
self.frozen_exclude = frozen_exclude
self.patch_embed = PatchEmbed(
in_channels=in_channels,
embed_dims=embed_dims,
conv_type='Conv2d',
kernel_size=patch_size,
stride=patch_size,
padding=patch_pad,
bias=patch_bias,
norm_cfg=norm_cfg if patch_norm else None,
init_cfg=None,
)
num_patches = (img_size[0] // patch_size) * \
(img_size[1] // patch_size)
self.with_cls_token = with_cls_token
self.output_cls_token = output_cls_token
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, embed_dims))
self.drop_after_pos = nn.Dropout(p=drop_rate)
self.pre_norm = pre_norm
if self.pre_norm:
self.pre_ln_name, pre_ln = build_norm_layer(
norm_cfg, embed_dims, postfix='_pre')
self.add_module(self.pre_ln_name, pre_ln)
if isinstance(out_indices, int):
if out_indices == -1:
out_indices = num_layers - 1
self.out_indices = [out_indices]
elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
self.out_indices = out_indices
else:
raise TypeError('out_indices must be type of int, list or tuple')
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, num_layers)
] # stochastic depth decay rule
self.layers = ModuleList()
for i in range(num_layers):
self.layers.append(
TransformerEncoderLayer(
embed_dims=embed_dims,
num_heads=num_heads,
feedforward_channels=mlp_ratio * embed_dims,
attn_drop_rate=attn_drop_rate,
drop_rate=drop_rate,
drop_path_rate=dpr[i],
num_fcs=num_fcs,
qkv_bias=qkv_bias,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
with_cp=with_cp,
batch_first=True))
self.final_norm = final_norm
if final_norm:
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, embed_dims, postfix=1)
self.add_module(self.norm1_name, norm1)
self._freeze()
@property
def pre_ln(self):
return getattr(self, self.pre_ln_name)
@property
def norm1(self):
return getattr(self, self.norm1_name)
def init_weights(self):
if isinstance(self.init_cfg, dict) and \
self.init_cfg.get('type') in ['Pretrained', 'Pretrained_Part']:
checkpoint = CheckpointLoader.load_checkpoint(
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
if self.init_cfg.get('type') == 'Pretrained':
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
elif self.init_cfg.get('type') == 'Pretrained_Part':
state_dict = checkpoint.copy()
para_prefix = 'image_encoder'
prefix_len = len(para_prefix) + 1
for k, v in checkpoint.items():
state_dict.pop(k)
if para_prefix in k:
state_dict[k[prefix_len:]] = v
if 'pos_embed' in state_dict.keys():
if self.pos_embed.shape != state_dict['pos_embed'].shape:
print_log(msg=f'Resize the pos_embed shape from '
f'{state_dict["pos_embed"].shape} to '
f'{self.pos_embed.shape}')
h, w = self.img_size
pos_size = int(
math.sqrt(state_dict['pos_embed'].shape[1] - 1))
state_dict['pos_embed'] = self.resize_pos_embed(
state_dict['pos_embed'],
(h // self.patch_size, w // self.patch_size),
(pos_size, pos_size), self.interpolate_mode)
load_state_dict(self, state_dict, strict=False, logger=None)
elif self.init_cfg is not None:
super().init_weights()
else:
# We only implement the 'jax_impl' initialization implemented at
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
for n, m in self.named_modules():
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
if 'ffn' in n:
nn.init.normal_(m.bias, mean=0., std=1e-6)
else:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv2d):
kaiming_init(m, mode='fan_in', bias=0.)
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
constant_init(m, val=1.0, bias=0.)
def _freeze(self):
if 'all' in self.frozen_exclude:
return
for name, param in self.named_parameters():
if not any([exclude in name for exclude in self.frozen_exclude]):
param.requires_grad = False
def _pos_embeding(self, patched_img, hw_shape, pos_embed):
"""Positioning embeding method.
Resize the pos_embed, if the input image size doesn't match
the training size.
Args:
patched_img (torch.Tensor): The patched image, it should be
shape of [B, L1, C].
hw_shape (tuple): The downsampled image resolution.
pos_embed (torch.Tensor): The pos_embed weighs, it should be
shape of [B, L2, c].
Return:
torch.Tensor: The pos encoded image feature.
"""
assert patched_img.ndim == 3 and pos_embed.ndim == 3, \
'the shapes of patched_img and pos_embed must be [B, L, C]'
x_len, pos_len = patched_img.shape[1], pos_embed.shape[1]
if x_len != pos_len:
if pos_len == (self.img_size[0] // self.patch_size) * (
self.img_size[1] // self.patch_size) + 1:
pos_h = self.img_size[0] // self.patch_size
pos_w = self.img_size[1] // self.patch_size
else:
raise ValueError(
'Unexpected shape of pos_embed, got {}.'.format(
pos_embed.shape))
pos_embed = self.resize_pos_embed(pos_embed, hw_shape,
(pos_h, pos_w),
self.interpolate_mode)
return self.drop_after_pos(patched_img + pos_embed)
@staticmethod
def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode):
"""Resize pos_embed weights.
Resize pos_embed using bicubic interpolate method.
Args:
pos_embed (torch.Tensor): Position embedding weights.
input_shpae (tuple): Tuple for (downsampled input image height,
downsampled input image width).
pos_shape (tuple): The resolution of downsampled origin training
image.
mode (str): Algorithm used for upsampling:
``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
``'trilinear'``. Default: ``'nearest'``
Return:
torch.Tensor: The resized pos_embed of shape [B, L_new, C]
"""
assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
pos_h, pos_w = pos_shape
cls_token_weight = pos_embed[:, 0]
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
pos_embed_weight = pos_embed_weight.reshape(
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
pos_embed_weight = resize(
pos_embed_weight, size=input_shpae, align_corners=False, mode=mode)
cls_token_weight = cls_token_weight.unsqueeze(1)
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
return pos_embed
def forward(self, inputs):
B = inputs.shape[0]
x, hw_shape = self.patch_embed(inputs)
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = self._pos_embeding(x, hw_shape, self.pos_embed)
if not self.with_cls_token:
# Remove class token for transformer encoder input
x = x[:, 1:]
if self.pre_norm:
x = self.pre_ln(x)
outs = []
if self.out_origin:
if self.with_cls_token:
# Remove class token and reshape token for decoder head
out = x[:, 1:]
else:
out = x
B, _, C = out.shape
out = out.reshape(B, hw_shape[0], hw_shape[1],
C).permute(0, 3, 1, 2).contiguous()
if self.output_cls_token:
out = [out, x[:, 0]]
outs.append(out)
for i, layer in enumerate(self.layers):
x = layer(x)
if i == len(self.layers) - 1:
if self.final_norm:
x = self.norm1(x)
if i in self.out_indices:
if self.with_cls_token:
# Remove class token and reshape token for decoder head
out = x[:, 1:]
else:
out = x
B, _, C = out.shape
out = out.reshape(B, hw_shape[0], hw_shape[1],
C).permute(0, 3, 1, 2).contiguous()
if self.output_cls_token:
out = [out, x[:, 0]]
outs.append(out)
return tuple(outs)
def train(self, mode=True):
super().train(mode)
if mode and self.norm_eval:
for m in self.modules():
if isinstance(m, nn.LayerNorm):
m.eval()

View File

@@ -0,0 +1,395 @@
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------------
# Adapted from https://github.com/wl-zhao/VPD/blob/main/vpd/models.py
# Original licence: MIT License
# ------------------------------------------------------------------------------
import math
from typing import List, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
from mmengine.runner import CheckpointLoader, load_checkpoint
from mmseg.registry import MODELS
from mmseg.utils import ConfigType, OptConfigType
try:
from ldm.modules.diffusionmodules.util import timestep_embedding
from ldm.util import instantiate_from_config
has_ldm = True
except ImportError:
has_ldm = False
def register_attention_control(model, controller):
"""Registers a control function to manage attention within a model.
Args:
model: The model to which attention is to be registered.
controller: The control function responsible for managing attention.
"""
def ca_forward(self, place_in_unet):
"""Custom forward method for attention.
Args:
self: Reference to the current object.
place_in_unet: The location in UNet (down/mid/up).
Returns:
The modified forward method.
"""
def forward(x, context=None, mask=None):
h = self.heads
is_cross = context is not None
context = context or x # if context is None, use x
q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
q, k, v = (
tensor.view(tensor.shape[0] * h, tensor.shape[1],
tensor.shape[2] // h) for tensor in [q, k, v])
sim = torch.matmul(q, k.transpose(-2, -1)) * self.scale
if mask is not None:
mask = mask.flatten(1).unsqueeze(1).repeat(h, 1, 1)
max_neg_value = -torch.finfo(sim.dtype).max
sim.masked_fill_(~mask, max_neg_value)
attn = sim.softmax(dim=-1)
attn_mean = attn.view(h, attn.shape[0] // h,
*attn.shape[1:]).mean(0)
controller(attn_mean, is_cross, place_in_unet)
out = torch.matmul(attn, v)
out = out.view(out.shape[0] // h, out.shape[1], out.shape[2] * h)
return self.to_out(out)
return forward
def register_recr(net_, count, place_in_unet):
"""Recursive function to register the custom forward method to all
CrossAttention layers.
Args:
net_: The network layer currently being processed.
count: The current count of layers processed.
place_in_unet: The location in UNet (down/mid/up).
Returns:
The updated count of layers processed.
"""
if net_.__class__.__name__ == 'CrossAttention':
net_.forward = ca_forward(net_, place_in_unet)
return count + 1
if hasattr(net_, 'children'):
return sum(
register_recr(child, 0, place_in_unet)
for child in net_.children())
return count
cross_att_count = sum(
register_recr(net[1], 0, place) for net, place in [
(child, 'down') if 'input_blocks' in name else (
child, 'up') if 'output_blocks' in name else
(child,
'mid') if 'middle_block' in name else (None, None) # Default case
for name, child in model.diffusion_model.named_children()
] if net is not None)
controller.num_att_layers = cross_att_count
class AttentionStore:
"""A class for storing attention information in the UNet model.
Attributes:
base_size (int): Base size for storing attention information.
max_size (int): Maximum size for storing attention information.
"""
def __init__(self, base_size=64, max_size=None):
"""Initialize AttentionStore with default or custom sizes."""
self.reset()
self.base_size = base_size
self.max_size = max_size or (base_size // 2)
self.num_att_layers = -1
@staticmethod
def get_empty_store():
"""Returns an empty store for holding attention values."""
return {
key: []
for key in [
'down_cross', 'mid_cross', 'up_cross', 'down_self', 'mid_self',
'up_self'
]
}
def reset(self):
"""Resets the step and attention stores to their initial states."""
self.cur_step = 0
self.cur_att_layer = 0
self.step_store = self.get_empty_store()
self.attention_store = {}
def forward(self, attn, is_cross: bool, place_in_unet: str):
"""Processes a single forward step, storing the attention.
Args:
attn: The attention tensor.
is_cross (bool): Whether it's cross attention.
place_in_unet (str): The location in UNet (down/mid/up).
Returns:
The unmodified attention tensor.
"""
key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
if attn.shape[1] <= (self.max_size)**2:
self.step_store[key].append(attn)
return attn
def between_steps(self):
"""Processes and stores attention information between steps."""
if not self.attention_store:
self.attention_store = self.step_store
else:
for key in self.attention_store:
self.attention_store[key] = [
stored + step for stored, step in zip(
self.attention_store[key], self.step_store[key])
]
self.step_store = self.get_empty_store()
def get_average_attention(self):
"""Calculates and returns the average attention across all steps."""
return {
key: [item for item in self.step_store[key]]
for key in self.step_store
}
def __call__(self, attn, is_cross: bool, place_in_unet: str):
"""Allows the class instance to be callable."""
return self.forward(attn, is_cross, place_in_unet)
@property
def num_uncond_att_layers(self):
"""Returns the number of unconditional attention layers (default is
0)."""
return 0
def step_callback(self, x_t):
"""A placeholder for a step callback.
Returns the input unchanged.
"""
return x_t
class UNetWrapper(nn.Module):
"""A wrapper for UNet with optional attention mechanisms.
Args:
unet (nn.Module): The UNet model to wrap
use_attn (bool): Whether to use attention. Defaults to True
base_size (int): Base size for the attention store. Defaults to 512
max_attn_size (int, optional): Maximum size for the attention store.
Defaults to None
attn_selector (str): The types of attention to use.
Defaults to 'up_cross+down_cross'
"""
def __init__(self,
unet,
use_attn=True,
base_size=512,
max_attn_size=None,
attn_selector='up_cross+down_cross'):
super().__init__()
assert has_ldm, 'To use UNetWrapper, please install required ' \
'packages via `pip install -r requirements/optional.txt`.'
self.unet = unet
self.attention_store = AttentionStore(
base_size=base_size // 8, max_size=max_attn_size)
self.attn_selector = attn_selector.split('+')
self.use_attn = use_attn
self.init_sizes(base_size)
if self.use_attn:
register_attention_control(unet, self.attention_store)
def init_sizes(self, base_size):
"""Initialize sizes based on the base size."""
self.size16 = base_size // 32
self.size32 = base_size // 16
self.size64 = base_size // 8
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
"""Forward pass through the model."""
diffusion_model = self.unet.diffusion_model
if self.use_attn:
self.attention_store.reset()
hs, emb, out_list = self._unet_forward(x, timesteps, context, y,
diffusion_model)
if self.use_attn:
self._append_attn_to_output(out_list)
return out_list[::-1]
def _unet_forward(self, x, timesteps, context, y, diffusion_model):
hs = []
t_emb = timestep_embedding(
timesteps, diffusion_model.model_channels, repeat_only=False)
emb = diffusion_model.time_embed(t_emb)
h = x.type(diffusion_model.dtype)
for module in diffusion_model.input_blocks:
h = module(h, emb, context)
hs.append(h)
h = diffusion_model.middle_block(h, emb, context)
out_list = []
for i_out, module in enumerate(diffusion_model.output_blocks):
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb, context)
if i_out in [1, 4, 7]:
out_list.append(h)
h = h.type(x.dtype)
out_list.append(h)
return hs, emb, out_list
def _append_attn_to_output(self, out_list):
avg_attn = self.attention_store.get_average_attention()
attns = {self.size16: [], self.size32: [], self.size64: []}
for k in self.attn_selector:
for up_attn in avg_attn[k]:
size = int(math.sqrt(up_attn.shape[1]))
up_attn = up_attn.transpose(-1, -2).reshape(
*up_attn.shape[:2], size, -1)
attns[size].append(up_attn)
attn16 = torch.stack(attns[self.size16]).mean(0)
attn32 = torch.stack(attns[self.size32]).mean(0)
attn64 = torch.stack(attns[self.size64]).mean(0) if len(
attns[self.size64]) > 0 else None
out_list[1] = torch.cat([out_list[1], attn16], dim=1)
out_list[2] = torch.cat([out_list[2], attn32], dim=1)
if attn64 is not None:
out_list[3] = torch.cat([out_list[3], attn64], dim=1)
class TextAdapter(nn.Module):
"""A PyTorch Module that serves as a text adapter.
This module takes text embeddings and adjusts them based on a scaling
factor gamma.
"""
def __init__(self, text_dim=768):
super().__init__()
self.fc = nn.Sequential(
nn.Linear(text_dim, text_dim), nn.GELU(),
nn.Linear(text_dim, text_dim))
def forward(self, texts, gamma):
texts_after = self.fc(texts)
texts = texts + gamma * texts_after
return texts
@MODELS.register_module()
class VPD(BaseModule):
"""VPD (Visual Perception Diffusion) model.
.. _`VPD`: https://arxiv.org/abs/2303.02153
Args:
diffusion_cfg (dict): Configuration for diffusion model.
class_embed_path (str): Path for class embeddings.
unet_cfg (dict, optional): Configuration for U-Net.
gamma (float, optional): Gamma for text adaptation. Defaults to 1e-4.
class_embed_select (bool, optional): If True, enables class embedding
selection. Defaults to False.
pad_shape (Optional[Union[int, List[int]]], optional): Padding shape.
Defaults to None.
pad_val (Union[int, List[int]], optional): Padding value.
Defaults to 0.
init_cfg (dict, optional): Configuration for network initialization.
"""
def __init__(self,
diffusion_cfg: ConfigType,
class_embed_path: str,
unet_cfg: OptConfigType = dict(),
gamma: float = 1e-4,
class_embed_select=False,
pad_shape: Optional[Union[int, List[int]]] = None,
pad_val: Union[int, List[int]] = 0,
init_cfg: OptConfigType = None):
super().__init__(init_cfg=init_cfg)
assert has_ldm, 'To use VPD model, please install required packages' \
' via `pip install -r requirements/optional.txt`.'
if pad_shape is not None:
if not isinstance(pad_shape, (list, tuple)):
pad_shape = (pad_shape, pad_shape)
self.pad_shape = pad_shape
self.pad_val = pad_val
# diffusion model
diffusion_checkpoint = diffusion_cfg.pop('checkpoint', None)
sd_model = instantiate_from_config(diffusion_cfg)
if diffusion_checkpoint is not None:
load_checkpoint(sd_model, diffusion_checkpoint, strict=False)
self.encoder_vq = sd_model.first_stage_model
self.unet = UNetWrapper(sd_model.model, **unet_cfg)
# class embeddings & text adapter
class_embeddings = CheckpointLoader.load_checkpoint(class_embed_path)
text_dim = class_embeddings.size(-1)
self.text_adapter = TextAdapter(text_dim=text_dim)
self.class_embed_select = class_embed_select
if class_embed_select:
class_embeddings = torch.cat(
(class_embeddings, class_embeddings.mean(dim=0,
keepdims=True)),
dim=0)
self.register_buffer('class_embeddings', class_embeddings)
self.gamma = nn.Parameter(torch.ones(text_dim) * gamma)
def forward(self, x):
"""Extract features from images."""
# calculate cross-attn map
if self.class_embed_select:
if isinstance(x, (tuple, list)):
x, class_ids = x[:2]
class_ids = class_ids.tolist()
else:
class_ids = [-1] * x.size(0)
class_embeddings = self.class_embeddings[class_ids]
c_crossattn = self.text_adapter(class_embeddings, self.gamma)
c_crossattn = c_crossattn.unsqueeze(1)
else:
class_embeddings = self.class_embeddings
c_crossattn = self.text_adapter(class_embeddings, self.gamma)
c_crossattn = c_crossattn.unsqueeze(0).repeat(x.size(0), 1, 1)
# pad to required input shape for pretrained diffusion model
if self.pad_shape is not None:
pad_width = max(0, self.pad_shape[1] - x.shape[-1])
pad_height = max(0, self.pad_shape[0] - x.shape[-2])
x = F.pad(x, (0, pad_width, 0, pad_height), value=self.pad_val)
# forward the denoising model
with torch.no_grad():
latents = self.encoder_vq.encode(x).mode().detach()
t = torch.ones((x.shape[0], ), device=x.device).long()
outs = self.unet(latents, t, context=c_crossattn)
return outs