esenke
2025-12-08 22:16:31 +08:00
commit 01adcfdf60
305 changed files with 50879 additions and 0 deletions


@@ -0,0 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .basic_block import BasicBlock, Bottleneck
from .embed import PatchEmbed
from .encoding import Encoding
from .inverted_residual import InvertedResidual, InvertedResidualV3
from .make_divisible import make_divisible
from .point_sample import get_uncertain_point_coords_with_randomness
from .ppm import DAPPM, PAPPM
from .res_layer import ResLayer
from .se_layer import SELayer
from .self_attention_block import SelfAttentionBlock
from .shape_convert import (nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc,
nlc_to_nchw)
from .up_conv_block import UpConvBlock
# isort: off
from .wrappers import Upsample, resize
from .san_layers import MLP, LayerNorm2d, cross_attn_layer
__all__ = [
'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'PatchEmbed',
'nchw_to_nlc', 'nlc_to_nchw', 'nchw2nlc2nchw', 'nlc2nchw2nlc', 'Encoding',
'Upsample', 'resize', 'DAPPM', 'PAPPM', 'BasicBlock', 'Bottleneck',
'cross_attn_layer', 'LayerNorm2d', 'MLP',
'get_uncertain_point_coords_with_randomness'
]


@@ -0,0 +1,143 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from torch import Tensor
from mmseg.registry import MODELS
from mmseg.utils import OptConfigType
class BasicBlock(BaseModule):
"""Basic block from `ResNet <https://arxiv.org/abs/1512.03385>`_.
Args:
in_channels (int): Input channels.
channels (int): Output channels.
stride (int): Stride of the first block. Default: 1.
downsample (nn.Module, optional): Downsample operation on identity.
Default: None.
norm_cfg (dict, optional): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict, optional): Config dict for activation layer in
ConvModule. Default: dict(type='ReLU', inplace=True).
act_cfg_out (dict, optional): Config dict for the activation layer applied
at the end of the block. Default: dict(type='ReLU', inplace=True).
init_cfg (dict, optional): Initialization config dict. Default: None.
"""
expansion = 1
def __init__(self,
in_channels: int,
channels: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
norm_cfg: OptConfigType = dict(type='BN'),
act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
act_cfg_out: OptConfigType = dict(type='ReLU', inplace=True),
init_cfg: OptConfigType = None):
super().__init__(init_cfg)
self.conv1 = ConvModule(
in_channels,
channels,
kernel_size=3,
stride=stride,
padding=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.conv2 = ConvModule(
channels,
channels,
kernel_size=3,
padding=1,
norm_cfg=norm_cfg,
act_cfg=None)
self.downsample = downsample
if act_cfg_out:
self.act = MODELS.build(act_cfg_out)
def forward(self, x: Tensor) -> Tensor:
residual = x
out = self.conv1(x)
out = self.conv2(out)
if self.downsample:
residual = self.downsample(x)
out += residual
if hasattr(self, 'act'):
out = self.act(out)
return out
class Bottleneck(BaseModule):
"""Bottleneck block from `ResNet <https://arxiv.org/abs/1512.03385>`_.
Args:
in_channels (int): Input channels.
channels (int): Output channels.
stride (int): Stride of the first block. Default: 1.
downsample (nn.Module, optional): Downsample operation on identity.
Default: None.
norm_cfg (dict, optional): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict, optional): Config dict for activation layer in
ConvModule. Default: dict(type='ReLU', inplace=True).
act_cfg_out (dict, optional): Config dict for the activation layer applied
at the end of the block. Default: None.
init_cfg (dict, optional): Initialization config dict. Default: None.
"""
expansion = 2
def __init__(self,
in_channels: int,
channels: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
norm_cfg: OptConfigType = dict(type='BN'),
act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
act_cfg_out: OptConfigType = None,
init_cfg: OptConfigType = None):
super().__init__(init_cfg)
self.conv1 = ConvModule(
in_channels, channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg)
self.conv2 = ConvModule(
channels,
channels,
3,
stride,
1,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.conv3 = ConvModule(
channels,
channels * self.expansion,
1,
norm_cfg=norm_cfg,
act_cfg=None)
if act_cfg_out:
self.act = MODELS.build(act_cfg_out)
self.downsample = downsample
def forward(self, x: Tensor) -> Tensor:
residual = x
out = self.conv1(x)
out = self.conv2(out)
out = self.conv3(out)
if self.downsample:
residual = self.downsample(x)
out += residual
if hasattr(self, 'act'):
out = self.act(out)
return out
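
A minimal usage sketch for the two blocks above, assuming this file is installed as part of mmsegmentation so both classes are importable from mmseg.models.utils:

import torch
import torch.nn as nn
from mmseg.models.utils import BasicBlock, Bottleneck

x = torch.rand(2, 64, 32, 32)

# BasicBlock keeps the channel count (expansion = 1), so no downsample is
# needed on the identity path when stride == 1.
basic = BasicBlock(in_channels=64, channels=64)
print(basic(x).shape)  # torch.Size([2, 64, 32, 32])

# Bottleneck widens the output by `expansion` (2), so the identity path
# must be projected to the new width before the addition.
proj = nn.Conv2d(64, 64 * Bottleneck.expansion, kernel_size=1)
bottleneck = Bottleneck(in_channels=64, channels=64, downsample=proj)
print(bottleneck(x).shape)  # torch.Size([2, 128, 32, 32])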


@@ -0,0 +1,330 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Sequence
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmengine.model import BaseModule
from mmengine.utils import to_2tuple
class AdaptivePadding(nn.Module):
"""Applies padding to input (if needed) so that input can get fully covered
by filter you specified. It support two modes "same" and "corner". The
"same" mode is same with "SAME" padding mode in TensorFlow, pad zero around
input. The "corner" mode would pad zero to bottom right.
Args:
kernel_size (int | tuple): Size of the kernel.
stride (int | tuple): Stride of the filter. Default: 1.
dilation (int | tuple): Spacing between kernel elements.
Default: 1.
padding (str): Support "same" and "corner", "corner" mode
would pad zero to bottom right, and "same" mode would
pad zero around input. Default: "corner".
Example:
>>> kernel_size = 16
>>> stride = 16
>>> dilation = 1
>>> input = torch.rand(1, 1, 15, 17)
>>> adap_pad = AdaptivePadding(
>>> kernel_size=kernel_size,
>>> stride=stride,
>>> dilation=dilation,
>>> padding="corner")
>>> out = adap_pad(input)
>>> assert (out.shape[2], out.shape[3]) == (16, 32)
>>> input = torch.rand(1, 1, 16, 17)
>>> out = adap_pad(input)
>>> assert (out.shape[2], out.shape[3]) == (16, 32)
"""
def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
super().__init__()
assert padding in ('same', 'corner')
kernel_size = to_2tuple(kernel_size)
stride = to_2tuple(stride)
dilation = to_2tuple(dilation)
self.padding = padding
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
def get_pad_shape(self, input_shape):
input_h, input_w = input_shape
kernel_h, kernel_w = self.kernel_size
stride_h, stride_w = self.stride
output_h = math.ceil(input_h / stride_h)
output_w = math.ceil(input_w / stride_w)
pad_h = max((output_h - 1) * stride_h +
(kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
pad_w = max((output_w - 1) * stride_w +
(kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
return pad_h, pad_w
def forward(self, x):
pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
if pad_h > 0 or pad_w > 0:
if self.padding == 'corner':
x = F.pad(x, [0, pad_w, 0, pad_h])
elif self.padding == 'same':
x = F.pad(x, [
pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
pad_h - pad_h // 2
])
return x
class PatchEmbed(BaseModule):
"""Image to Patch Embedding.
We use a conv layer to implement PatchEmbed.
Args:
in_channels (int): The num of input channels. Default: 3
embed_dims (int): The dimensions of embedding. Default: 768
conv_type (str): The config dict for embedding
conv layer type selection. Default: "Conv2d".
kernel_size (int): The kernel_size of embedding conv. Default: 16.
stride (int, optional): The stride of the embedding conv.
Default: None (Would be set as `kernel_size`).
padding (int | tuple | string ): The padding length of
embedding conv. When it is a string, it means the mode
of adaptive padding, support "same" and "corner" now.
Default: "corner".
dilation (int): The dilation rate of embedding conv. Default: 1.
bias (bool): Bias of embed conv. Default: True.
norm_cfg (dict, optional): Config dict for normalization layer.
Default: None.
input_size (int | tuple | None): The size of input, which will be
used to calculate the out size. Only works when `dynamic_size`
is False. Default: None.
init_cfg (`mmengine.ConfigDict`, optional): The Config for
initialization. Default: None.
"""
def __init__(self,
in_channels=3,
embed_dims=768,
conv_type='Conv2d',
kernel_size=16,
stride=None,
padding='corner',
dilation=1,
bias=True,
norm_cfg=None,
input_size=None,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.embed_dims = embed_dims
if stride is None:
stride = kernel_size
kernel_size = to_2tuple(kernel_size)
stride = to_2tuple(stride)
dilation = to_2tuple(dilation)
if isinstance(padding, str):
self.adap_padding = AdaptivePadding(
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding)
# disable the padding of conv
padding = 0
else:
self.adap_padding = None
padding = to_2tuple(padding)
self.projection = build_conv_layer(
dict(type=conv_type),
in_channels=in_channels,
out_channels=embed_dims,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias)
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
else:
self.norm = None
if input_size:
input_size = to_2tuple(input_size)
# `init_out_size` would be used outside to calculate the num_patches
# when `use_abs_pos_embed` is used outside
self.init_input_size = input_size
if self.adap_padding:
pad_h, pad_w = self.adap_padding.get_pad_shape(input_size)
input_h, input_w = input_size
input_h = input_h + pad_h
input_w = input_w + pad_w
input_size = (input_h, input_w)
# https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
(kernel_size[0] - 1) - 1) // stride[0] + 1
w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
(kernel_size[1] - 1) - 1) // stride[1] + 1
self.init_out_size = (h_out, w_out)
else:
self.init_input_size = None
self.init_out_size = None
def forward(self, x):
"""
Args:
x (Tensor): Has shape (B, C, H, W). In most cases, C is 3.
Returns:
tuple: Contains merged results and its spatial shape.
- x (Tensor): Has shape (B, out_h * out_w, embed_dims)
- out_size (tuple[int]): Spatial shape of x, arranged as
(out_h, out_w).
"""
if self.adap_padding:
x = self.adap_padding(x)
x = self.projection(x)
out_size = (x.shape[2], x.shape[3])
x = x.flatten(2).transpose(1, 2)
if self.norm is not None:
x = self.norm(x)
return x, out_size
class PatchMerging(BaseModule):
"""Merge patch feature map.
This layer groups the feature map by kernel_size, and applies norm and
linear layers to the grouped feature map. Our implementation uses
`nn.Unfold` to merge patches, which is about 25% faster than the original
implementation. However, pretrained models need to be modified for
compatibility.
Args:
in_channels (int): The num of input channels.
out_channels (int): The num of output channels.
kernel_size (int | tuple, optional): the kernel size in the unfold
layer. Defaults to 2.
stride (int | tuple, optional): the stride of the sliding blocks in the
unfold layer. Default: None. (Would be set as `kernel_size`)
padding (int | tuple | string ): The padding length of
embedding conv. When it is a string, it means the mode
of adaptive padding, support "same" and "corner" now.
Default: "corner".
dilation (int | tuple, optional): dilation parameter in the unfold
layer. Default: 1.
bias (bool, optional): Whether to add bias in the linear layer or not.
Default: False.
norm_cfg (dict, optional): Config dict for normalization layer.
Default: dict(type='LN').
init_cfg (dict, optional): The extra config for initialization.
Default: None.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size=2,
stride=None,
padding='corner',
dilation=1,
bias=False,
norm_cfg=dict(type='LN'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.in_channels = in_channels
self.out_channels = out_channels
if stride is None:
stride = kernel_size
kernel_size = to_2tuple(kernel_size)
stride = to_2tuple(stride)
dilation = to_2tuple(dilation)
if isinstance(padding, str):
self.adap_padding = AdaptivePadding(
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding)
# disable the padding of unfold
padding = 0
else:
self.adap_padding = None
padding = to_2tuple(padding)
self.sampler = nn.Unfold(
kernel_size=kernel_size,
dilation=dilation,
padding=padding,
stride=stride)
sample_dim = kernel_size[0] * kernel_size[1] * in_channels
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
else:
self.norm = None
self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
def forward(self, x, input_size):
"""
Args:
x (Tensor): Has shape (B, H*W, C_in).
input_size (tuple[int]): The spatial shape of x, arranged as (H, W).
Returns:
tuple: Contains merged results and its spatial shape.
- x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
- out_size (tuple[int]): Spatial shape of x, arranged as
(Merged_H, Merged_W).
"""
B, L, C = x.shape
assert isinstance(input_size, Sequence), \
f'Expect input_size to be a `Sequence`, but got {input_size}'
H, W = input_size
assert L == H * W, 'input feature has wrong size'
x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W
# Use nn.Unfold to merge patch. About 25% faster than original method,
# but need to modify pretrained model for compatibility
if self.adap_padding:
x = self.adap_padding(x)
H, W = x.shape[-2:]
x = self.sampler(x)
# if kernel_size=2 and stride=2, x should have shape (B, 4*C, H/2*W/2)
out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
(self.sampler.kernel_size[0] - 1) -
1) // self.sampler.stride[0] + 1
out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
(self.sampler.kernel_size[1] - 1) -
1) // self.sampler.stride[1] + 1
output_size = (out_h, out_w)
x = x.transpose(1, 2) # B, H/2*W/2, 4*C
x = self.norm(x) if self.norm else x
x = self.reduction(x)
return x, output_size
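
A short usage sketch of the patch modules above, assuming the file lives at mmseg/models/utils/embed.py as in upstream mmsegmentation (PatchEmbed is re-exported from mmseg.models.utils; PatchMerging is not):

import torch
from mmseg.models.utils import PatchEmbed
from mmseg.models.utils.embed import PatchMerging

# 16x16 patches on a 64x64 RGB image -> 4x4 = 16 tokens of dim 96.
patch_embed = PatchEmbed(in_channels=3, embed_dims=96, kernel_size=16)
tokens, hw_shape = patch_embed(torch.rand(2, 3, 64, 64))
print(tokens.shape, hw_shape)  # torch.Size([2, 16, 96]) (4, 4)

# PatchMerging halves the token grid and doubles the channel dimension,
# as done between Swin Transformer stages.
merging = PatchMerging(in_channels=96, out_channels=192)
merged, merged_hw = merging(tokens, hw_shape)
print(merged.shape, merged_hw)  # torch.Size([2, 4, 192]) (2, 2)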


@@ -0,0 +1,75 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import nn
from torch.nn import functional as F
class Encoding(nn.Module):
"""Encoding Layer: a learnable residual encoder.
Input is of shape (batch_size, channels, height, width).
Output is of shape (batch_size, num_codes, channels).
Args:
channels: dimension of the features or feature channels
num_codes: number of code words
"""
def __init__(self, channels, num_codes):
super().__init__()
# init codewords and smoothing factor
self.channels, self.num_codes = channels, num_codes
std = 1. / ((num_codes * channels)**0.5)
# [num_codes, channels]
self.codewords = nn.Parameter(
torch.empty(num_codes, channels,
dtype=torch.float).uniform_(-std, std),
requires_grad=True)
# [num_codes]
self.scale = nn.Parameter(
torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0),
requires_grad=True)
@staticmethod
def scaled_l2(x, codewords, scale):
num_codes, channels = codewords.size()
batch_size = x.size(0)
reshaped_scale = scale.view((1, 1, num_codes))
expanded_x = x.unsqueeze(2).expand(
(batch_size, x.size(1), num_codes, channels))
reshaped_codewords = codewords.view((1, 1, num_codes, channels))
scaled_l2_norm = reshaped_scale * (
expanded_x - reshaped_codewords).pow(2).sum(dim=3)
return scaled_l2_norm
@staticmethod
def aggregate(assignment_weights, x, codewords):
num_codes, channels = codewords.size()
reshaped_codewords = codewords.view((1, 1, num_codes, channels))
batch_size = x.size(0)
expanded_x = x.unsqueeze(2).expand(
(batch_size, x.size(1), num_codes, channels))
encoded_feat = (assignment_weights.unsqueeze(3) *
(expanded_x - reshaped_codewords)).sum(dim=1)
return encoded_feat
def forward(self, x):
assert x.dim() == 4 and x.size(1) == self.channels
# [batch_size, channels, height, width]
batch_size = x.size(0)
# [batch_size, height x width, channels]
x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous()
# assignment_weights: [batch_size, height x width, num_codes]
assignment_weights = F.softmax(
self.scaled_l2(x, self.codewords, self.scale), dim=2)
# aggregate
encoded_feat = self.aggregate(assignment_weights, x, self.codewords)
return encoded_feat
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \
f'x{self.channels})'
return repr_str
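
A quick sketch of the Encoding layer, assuming it is importable from mmseg.models.utils; the output size depends only on num_codes and channels, not on the spatial resolution:

import torch
from mmseg.models.utils import Encoding

enc = Encoding(channels=16, num_codes=32)
feat = torch.rand(4, 16, 8, 8)
# Soft-assign every spatial location to the 32 codewords and aggregate
# the residuals per codeword.
print(enc(feat).shape)  # torch.Size([4, 32, 16])
print(enc)              # Encoding(Nx16xHxW =>Nx32x16)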


@@ -0,0 +1,213 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import ConvModule
from torch import nn
from torch.utils import checkpoint as cp
from .se_layer import SELayer
class InvertedResidual(nn.Module):
"""InvertedResidual block for MobileNetV2.
Args:
in_channels (int): The input channels of the InvertedResidual block.
out_channels (int): The output channels of the InvertedResidual block.
stride (int): Stride of the middle (first) 3x3 convolution.
expand_ratio (int): Adjusts number of channels of the hidden layer
in InvertedResidual by this amount.
dilation (int): Dilation rate of depthwise conv. Default: 1
conv_cfg (dict): Config dict for convolution layer.
Default: None, which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU6').
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
Returns:
Tensor: The output tensor.
"""
def __init__(self,
in_channels,
out_channels,
stride,
expand_ratio,
dilation=1,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6'),
with_cp=False,
**kwargs):
super().__init__()
self.stride = stride
assert stride in [1, 2], f'stride must be in [1, 2]. ' \
f'But received {stride}.'
self.with_cp = with_cp
self.use_res_connect = self.stride == 1 and in_channels == out_channels
hidden_dim = int(round(in_channels * expand_ratio))
layers = []
if expand_ratio != 1:
layers.append(
ConvModule(
in_channels=in_channels,
out_channels=hidden_dim,
kernel_size=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
**kwargs))
layers.extend([
ConvModule(
in_channels=hidden_dim,
out_channels=hidden_dim,
kernel_size=3,
stride=stride,
padding=dilation,
dilation=dilation,
groups=hidden_dim,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
**kwargs),
ConvModule(
in_channels=hidden_dim,
out_channels=out_channels,
kernel_size=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None,
**kwargs)
])
self.conv = nn.Sequential(*layers)
def forward(self, x):
def _inner_forward(x):
if self.use_res_connect:
return x + self.conv(x)
else:
return self.conv(x)
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
return out
class InvertedResidualV3(nn.Module):
"""Inverted Residual Block for MobileNetV3.
Args:
in_channels (int): The input channels of this Module.
out_channels (int): The output channels of this Module.
mid_channels (int): The input channels of the depthwise convolution.
kernel_size (int): The kernel size of the depthwise convolution.
Default: 3.
stride (int): The stride of the depthwise convolution. Default: 1.
se_cfg (dict): Config dict for se layer. Default: None, which means no
se layer.
with_expand_conv (bool): Use expand conv or not. If set False,
mid_channels must be the same as in_channels. Default: True.
conv_cfg (dict): Config dict for convolution layer. Default: None,
which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
Returns:
Tensor: The output tensor.
"""
def __init__(self,
in_channels,
out_channels,
mid_channels,
kernel_size=3,
stride=1,
se_cfg=None,
with_expand_conv=True,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
with_cp=False):
super().__init__()
self.with_res_shortcut = (stride == 1 and in_channels == out_channels)
assert stride in [1, 2]
self.with_cp = with_cp
self.with_se = se_cfg is not None
self.with_expand_conv = with_expand_conv
if self.with_se:
assert isinstance(se_cfg, dict)
if not self.with_expand_conv:
assert mid_channels == in_channels
if self.with_expand_conv:
self.expand_conv = ConvModule(
in_channels=in_channels,
out_channels=mid_channels,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.depthwise_conv = ConvModule(
in_channels=mid_channels,
out_channels=mid_channels,
kernel_size=kernel_size,
stride=stride,
padding=kernel_size // 2,
groups=mid_channels,
conv_cfg=dict(
type='Conv2dAdaptivePadding') if stride == 2 else conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
if self.with_se:
self.se = SELayer(**se_cfg)
self.linear_conv = ConvModule(
in_channels=mid_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None)
def forward(self, x):
def _inner_forward(x):
out = x
if self.with_expand_conv:
out = self.expand_conv(out)
out = self.depthwise_conv(out)
if self.with_se:
out = self.se(out)
out = self.linear_conv(out)
if self.with_res_shortcut:
return x + out
else:
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
return out
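
A usage sketch for both inverted residual variants, assuming the classes are importable from mmseg.models.utils:

import torch
from mmseg.models.utils import InvertedResidual, InvertedResidualV3

x = torch.rand(1, 32, 56, 56)

# MobileNetV2-style block: 1x1 expand -> 3x3 depthwise (stride 2) -> 1x1 project.
block_v2 = InvertedResidual(
    in_channels=32, out_channels=64, stride=2, expand_ratio=6)
print(block_v2(x).shape)  # torch.Size([1, 64, 28, 28])

# MobileNetV3-style block with an SE layer on the expanded features; the
# residual shortcut is used because stride == 1 and in == out channels.
block_v3 = InvertedResidualV3(
    in_channels=32,
    out_channels=32,
    mid_channels=128,
    kernel_size=3,
    stride=1,
    se_cfg=dict(channels=128))
print(block_v3(x).shape)  # torch.Size([1, 32, 56, 56])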


@@ -0,0 +1,28 @@
# Copyright (c) OpenMMLab. All rights reserved.
def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
"""Make divisible function.
This function rounds the channel number to the nearest value that can be
divisible by the divisor. It is taken from the original tf repo. It ensures
that all layers have a channel number that is divisible by divisor. It can
be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py # noqa
Args:
value (int): The original channel number.
divisor (int): The divisor to fully divide the channel number.
min_value (int): The minimum value of the output channel.
Default: None, which means the minimum value equals the divisor.
min_ratio (float): The minimum ratio of the rounded channel number to
the original channel number. Default: 0.9.
Returns:
int: The modified output channel number.
"""
if min_value is None:
min_value = divisor
new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than (1-min_ratio).
if new_value < min_ratio * value:
new_value += divisor
return new_value
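
A few worked values for make_divisible, assuming it is importable from mmseg.models.utils:

from mmseg.models.utils import make_divisible

print(make_divisible(32, 8))   # 32 -> already a multiple of 8
print(make_divisible(30, 8))   # 32 -> rounded to the nearest multiple of 8
print(make_divisible(55, 16))  # 64 -> 48 would be < 0.9 * 55, so one divisor is added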


@@ -0,0 +1,88 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.ops import point_sample
from torch import Tensor
def get_uncertainty(mask_preds: Tensor, labels: Tensor) -> Tensor:
"""Estimate uncertainty based on pred logits.
We estimate uncertainty as the L1 distance between 0.0 and the logit
prediction in 'mask_preds' for the foreground class given by `labels`.
Args:
mask_preds (Tensor): mask prediction logits, shape (num_rois,
num_classes, mask_height, mask_width).
labels (Tensor): Either predicted or ground truth label for
each predicted mask, of length num_rois.
Returns:
scores (Tensor): Uncertainty scores with the most uncertain
locations having the highest uncertainty score,
shape (num_rois, 1, mask_height, mask_width)
"""
if mask_preds.shape[1] == 1:
gt_class_logits = mask_preds.clone()
else:
inds = torch.arange(mask_preds.shape[0], device=mask_preds.device)
gt_class_logits = mask_preds[inds, labels].unsqueeze(1)
return -torch.abs(gt_class_logits)
def get_uncertain_point_coords_with_randomness(
mask_preds: Tensor, labels: Tensor, num_points: int,
oversample_ratio: float, importance_sample_ratio: float) -> Tensor:
"""Get ``num_points`` most uncertain points with random points during
train.
Sample points in [0, 1] x [0, 1] coordinate space based on their
uncertainty. The uncertainties are calculated for each point using
'get_uncertainty()' function that takes point's logit prediction as
input.
Args:
mask_preds (Tensor): A tensor of shape (num_rois, num_classes,
mask_height, mask_width) for class-specific or class-agnostic
prediction.
labels (Tensor): The ground truth class for each instance.
num_points (int): The number of points to sample.
oversample_ratio (float): Oversampling parameter.
importance_sample_ratio (float): Ratio of points that are sampled
via importance sampling.
Returns:
point_coords (Tensor): A tensor of shape (num_rois, num_points, 2)
that contains the coordinates of the sampled points.
"""
assert oversample_ratio >= 1
assert 0 <= importance_sample_ratio <= 1
batch_size = mask_preds.shape[0]
num_sampled = int(num_points * oversample_ratio)
point_coords = torch.rand(
batch_size, num_sampled, 2, device=mask_preds.device)
point_logits = point_sample(mask_preds, point_coords)
# It is crucial to calculate uncertainty based on the sampled
# prediction value for the points. Calculating uncertainties of the
# coarse predictions first and sampling them for points leads to
# incorrect results. To illustrate this: assume uncertainty func(
# logits)=-abs(logits), a sampled point between two coarse
# predictions with -1 and 1 logits has 0 logits, and therefore 0
# uncertainty value. However, if we calculate uncertainties for the
# coarse predictions first, both will have -1 uncertainty,
# and sampled point will get -1 uncertainty.
point_uncertainties = get_uncertainty(point_logits, labels)
num_uncertain_points = int(importance_sample_ratio * num_points)
num_random_points = num_points - num_uncertain_points
idx = torch.topk(
point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
shift = num_sampled * torch.arange(
batch_size, dtype=torch.long, device=mask_preds.device)
idx += shift[:, None]
point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
batch_size, num_uncertain_points, 2)
if num_random_points > 0:
rand_roi_coords = torch.rand(
batch_size, num_random_points, 2, device=mask_preds.device)
point_coords = torch.cat((point_coords, rand_roi_coords), dim=1)
return point_coords
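
A small sketch of importance-based point sampling, assuming mmcv.ops is available (point_sample is built on grid_sample, so it also runs on CPU):

import torch
from mmseg.models.utils import get_uncertain_point_coords_with_randomness

# 3 mask proposals over 10 classes on a 32x32 logit grid.
mask_preds = torch.randn(3, 10, 32, 32)
labels = torch.tensor([1, 4, 7])

# 75% of the 12 points are the most uncertain (near-zero-logit) locations
# of a 3x oversampled candidate set; the remaining 25% are uniform random.
coords = get_uncertain_point_coords_with_randomness(
    mask_preds, labels, num_points=12,
    oversample_ratio=3.0, importance_sample_ratio=0.75)
print(coords.shape)  # torch.Size([3, 12, 2]), values in [0, 1]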


@@ -0,0 +1,193 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule, ModuleList, Sequential
from torch import Tensor
class DAPPM(BaseModule):
"""DAPPM module in `DDRNet <https://arxiv.org/abs/2101.06085>`_.
Args:
in_channels (int): Input channels.
branch_channels (int): Branch channels.
out_channels (int): Output channels.
num_scales (int): Number of scales.
kernel_sizes (list[int]): Kernel sizes of each scale.
strides (list[int]): Strides of each scale.
paddings (list[int]): Paddings of each scale.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN', momentum=0.1).
act_cfg (dict): Config dict for activation layer in ConvModule.
Default: dict(type='ReLU', inplace=True).
conv_cfg (dict): Config dict for convolution layer in ConvModule.
Default: dict(order=('norm', 'act', 'conv'), bias=False).
upsample_mode (str): Upsample mode. Default: 'bilinear'.
"""
def __init__(self,
in_channels: int,
branch_channels: int,
out_channels: int,
num_scales: int,
kernel_sizes: List[int] = [5, 9, 17],
strides: List[int] = [2, 4, 8],
paddings: List[int] = [2, 4, 8],
norm_cfg: Dict = dict(type='BN', momentum=0.1),
act_cfg: Dict = dict(type='ReLU', inplace=True),
conv_cfg: Dict = dict(
order=('norm', 'act', 'conv'), bias=False),
upsample_mode: str = 'bilinear'):
super().__init__()
self.num_scales = num_scales
self.upsample_mode = upsample_mode
self.in_channels = in_channels
self.branch_channels = branch_channels
self.out_channels = out_channels
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.conv_cfg = conv_cfg
self.scales = ModuleList([
ConvModule(
in_channels,
branch_channels,
kernel_size=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
**conv_cfg)
])
for i in range(1, num_scales - 1):
self.scales.append(
Sequential(*[
nn.AvgPool2d(
kernel_size=kernel_sizes[i - 1],
stride=strides[i - 1],
padding=paddings[i - 1]),
ConvModule(
in_channels,
branch_channels,
kernel_size=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
**conv_cfg)
]))
self.scales.append(
Sequential(*[
nn.AdaptiveAvgPool2d((1, 1)),
ConvModule(
in_channels,
branch_channels,
kernel_size=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
**conv_cfg)
]))
self.processes = ModuleList()
for i in range(num_scales - 1):
self.processes.append(
ConvModule(
branch_channels,
branch_channels,
kernel_size=3,
padding=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
**conv_cfg))
self.compression = ConvModule(
branch_channels * num_scales,
out_channels,
kernel_size=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
**conv_cfg)
self.shortcut = ConvModule(
in_channels,
out_channels,
kernel_size=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
**conv_cfg)
def forward(self, inputs: Tensor):
feats = []
feats.append(self.scales[0](inputs))
for i in range(1, self.num_scales):
feat_up = F.interpolate(
self.scales[i](inputs),
size=inputs.shape[2:],
mode=self.upsample_mode)
feats.append(self.processes[i - 1](feat_up + feats[i - 1]))
return self.compression(torch.cat(feats,
dim=1)) + self.shortcut(inputs)
class PAPPM(DAPPM):
"""PAPPM module in `PIDNet <https://arxiv.org/abs/2206.02066>`_.
Args:
in_channels (int): Input channels.
branch_channels (int): Branch channels.
out_channels (int): Output channels.
num_scales (int): Number of scales.
kernel_sizes (list[int]): Kernel sizes of each scale.
strides (list[int]): Strides of each scale.
paddings (list[int]): Paddings of each scale.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN', momentum=0.1).
act_cfg (dict): Config dict for activation layer in ConvModule.
Default: dict(type='ReLU', inplace=True).
conv_cfg (dict): Config dict for convolution layer in ConvModule.
Default: dict(order=('norm', 'act', 'conv'), bias=False).
upsample_mode (str): Upsample mode. Default: 'bilinear'.
"""
def __init__(self,
in_channels: int,
branch_channels: int,
out_channels: int,
num_scales: int,
kernel_sizes: List[int] = [5, 9, 17],
strides: List[int] = [2, 4, 8],
paddings: List[int] = [2, 4, 8],
norm_cfg: Dict = dict(type='BN', momentum=0.1),
act_cfg: Dict = dict(type='ReLU', inplace=True),
conv_cfg: Dict = dict(
order=('norm', 'act', 'conv'), bias=False),
upsample_mode: str = 'bilinear'):
super().__init__(in_channels, branch_channels, out_channels,
num_scales, kernel_sizes, strides, paddings, norm_cfg,
act_cfg, conv_cfg, upsample_mode)
self.processes = ConvModule(
self.branch_channels * (self.num_scales - 1),
self.branch_channels * (self.num_scales - 1),
kernel_size=3,
padding=1,
groups=self.num_scales - 1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
**self.conv_cfg)
def forward(self, inputs: Tensor):
x_ = self.scales[0](inputs)
feats = []
for i in range(1, self.num_scales):
feat_up = F.interpolate(
self.scales[i](inputs),
size=inputs.shape[2:],
mode=self.upsample_mode,
align_corners=False)
feats.append(feat_up + x_)
scale_out = self.processes(torch.cat(feats, dim=1))
return self.compression(torch.cat([x_, scale_out],
dim=1)) + self.shortcut(inputs)
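
A shape-level sketch of both pyramid pooling modules, assuming they are importable from mmseg.models.utils; both keep the input resolution and only change the channel count:

import torch
from mmseg.models.utils import DAPPM, PAPPM

x = torch.rand(2, 512, 8, 16)  # e.g. a low-resolution DDRNet/PIDNet feature

dappm = DAPPM(
    in_channels=512, branch_channels=96, out_channels=128, num_scales=5)
print(dappm(x).shape)  # torch.Size([2, 128, 8, 16])

pappm = PAPPM(
    in_channels=512, branch_channels=96, out_channels=128, num_scales=5)
print(pappm(x).shape)  # torch.Size([2, 128, 8, 16])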


@@ -0,0 +1,96 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmengine.model import Sequential
from torch import nn as nn
class ResLayer(Sequential):
"""ResLayer to build ResNet style backbone.
Args:
block (nn.Module): block used to build ResLayer.
inplanes (int): inplanes of block.
planes (int): planes of block.
num_blocks (int): number of blocks.
stride (int): stride of the first block. Default: 1
dilation (int): dilation rate of the blocks. Default: 1
avg_down (bool): Use AvgPool instead of stride conv when
downsampling in the bottleneck. Default: False
conv_cfg (dict): dictionary to construct and config conv layer.
Default: None
norm_cfg (dict): dictionary to construct and config norm layer.
Default: dict(type='BN')
multi_grid (tuple[int] | None): Multi-grid dilation rates of the last
stage. Default: None
contract_dilation (bool): Whether to contract the first dilation of each
layer. Default: False
"""
def __init__(self,
block,
inplanes,
planes,
num_blocks,
stride=1,
dilation=1,
avg_down=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
multi_grid=None,
contract_dilation=False,
**kwargs):
self.block = block
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
downsample = []
conv_stride = stride
if avg_down:
conv_stride = 1
downsample.append(
nn.AvgPool2d(
kernel_size=stride,
stride=stride,
ceil_mode=True,
count_include_pad=False))
downsample.extend([
build_conv_layer(
conv_cfg,
inplanes,
planes * block.expansion,
kernel_size=1,
stride=conv_stride,
bias=False),
build_norm_layer(norm_cfg, planes * block.expansion)[1]
])
downsample = nn.Sequential(*downsample)
layers = []
if multi_grid is None:
if dilation > 1 and contract_dilation:
first_dilation = dilation // 2
else:
first_dilation = dilation
else:
first_dilation = multi_grid[0]
layers.append(
block(
inplanes=inplanes,
planes=planes,
stride=stride,
dilation=first_dilation,
downsample=downsample,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
**kwargs))
inplanes = planes * block.expansion
for i in range(1, num_blocks):
layers.append(
block(
inplanes=inplanes,
planes=planes,
stride=1,
dilation=dilation if multi_grid is None else multi_grid[i],
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
**kwargs))
super().__init__(*layers)
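
A usage sketch for ResLayer. Note that it forwards `inplanes`/`planes` keywords, so it expects the block classes of mmseg's ResNet backbone (not the BasicBlock defined in basic_block.py above, which takes `in_channels`/`channels`); the import below assumes the standard mmseg ResNet backbone is available:

import torch
from mmseg.models.backbones.resnet import BasicBlock
from mmseg.models.utils import ResLayer

# A stage of two residual blocks that downsamples by 2; the 1x1 downsample
# branch is created automatically because inplanes != planes * expansion.
layer = ResLayer(BasicBlock, inplanes=64, planes=128, num_blocks=2, stride=2)
print(layer(torch.rand(1, 64, 56, 56)).shape)  # torch.Size([1, 128, 28, 28])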


@@ -0,0 +1,418 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/MendelXu/SAN/blob/main/san/model/attn_helper.py # noqa: E501
# Copyright (c) 2023 MendelXu.
# Licensed under the MIT License
import warnings
from typing import Optional
import torch
from mmcv.cnn.bricks.transformer import BaseTransformerLayer
from torch import Tensor, nn
from torch.nn import functional as F
def cross_attn_with_self_bias(
query: Tensor,
key: Tensor,
value: Tensor,
embed_dim_to_check: int,
num_heads: int,
in_proj_weight: Tensor,
in_proj_bias: Tensor,
bias_k: Optional[Tensor],
bias_v: Optional[Tensor],
add_zero_attn: bool,
dropout_p: float,
out_proj_weight: Tensor,
out_proj_bias: Tensor,
training: bool = True,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
use_separate_proj_weight: bool = False,
q_proj_weight: Optional[Tensor] = None,
k_proj_weight: Optional[Tensor] = None,
v_proj_weight: Optional[Tensor] = None,
static_k: Optional[Tensor] = None,
static_v: Optional[Tensor] = None,
):
"""Forward function of multi-head attention. Modified from
multi_head_attention_forward in
https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py.
Args:
query, key, value: map a query and a set of key-value pairs to an output.
See "Attention Is All You Need" for more details.
embed_dim_to_check: total dimension of the model.
num_heads: parallel attention heads.
in_proj_weight, in_proj_bias: input projection weight and bias.
bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
add_zero_attn: add a new batch of zeros to the key and
value sequences at dim=1.
dropout_p: probability of an element to be zeroed.
out_proj_weight, out_proj_bias: the output projection weight and bias.
training: apply dropout if is ``True``.
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. This is a binary mask. When the value is True,
the corresponding value on the attention layer will be filled with -inf.
need_weights: output attn_output_weights.
Default: `True`
Note: `need_weights` defaults to `True`, but should be set to `False`
for best performance when attention weights are not needed;
setting it to `True` leads to a significant performance degradation.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcast across all
the batches while a 3D mask allows specifying a different mask for the entries of each batch.
use_separate_proj_weight: the function accepts the projection weights for query, key,
and value in different forms. If False, in_proj_weight will be used, which is
a combination of q_proj_weight, k_proj_weight, v_proj_weight.
q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
static_k, static_v: static key and value used for attention operators.
""" # noqa: E501
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == embed_dim_to_check
# allow MHA to have different sizes for the feature dimension
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, \
'embed_dim must be divisible by num_heads'
scaling = float(head_dim)**-0.5
if not use_separate_proj_weight:
if (query is key or torch.equal(
query, key)) and (key is value or torch.equal(key, value)):
# self-attention
raise NotImplementedError('self-attention is not implemented')
elif key is value or torch.equal(key, value):
# encoder-decoder attention
# This is inline in_proj function
# with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = 0
_end = embed_dim
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
q = F.linear(query, _w, _b)
if key is None:
assert value is None
k = None
v = None
q_k = None
q_v = None
else:
# This is inline in_proj function with
# in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
_end = None
_w = in_proj_weight[_start:, :]
if _b is not None:
_b = _b[_start:]
k, v = F.linear(key, _w, _b).chunk(2, dim=-1)
q_k, q_v = F.linear(query, _w, _b).chunk(2, dim=-1)
else:
# This is inline in_proj function with
# in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = 0
_end = embed_dim
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
q = F.linear(query, _w, _b)
# This is inline in_proj function with
# in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
_end = embed_dim * 2
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
k = F.linear(key, _w, _b)
q_k = F.linear(query, _w, _b)
# This is inline in_proj function with
# in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim * 2
_end = None
_w = in_proj_weight[_start:, :]
if _b is not None:
_b = _b[_start:]
v = F.linear(value, _w, _b)
q_v = F.linear(query, _w, _b)
else:
q_proj_weight_non_opt = \
torch.jit._unwrap_optional(q_proj_weight)
len1, len2 = q_proj_weight_non_opt.size()
assert len1 == embed_dim and len2 == query.size(-1)
k_proj_weight_non_opt = \
torch.jit._unwrap_optional(k_proj_weight)
len1, len2 = k_proj_weight_non_opt.size()
assert len1 == embed_dim and len2 == key.size(-1)
v_proj_weight_non_opt = \
torch.jit._unwrap_optional(v_proj_weight)
len1, len2 = v_proj_weight_non_opt.size()
assert len1 == embed_dim and len2 == value.size(-1)
if in_proj_bias is not None:
q = F.linear(query, q_proj_weight_non_opt,
in_proj_bias[0:embed_dim])
k = F.linear(key, k_proj_weight_non_opt,
in_proj_bias[embed_dim:(embed_dim * 2)])
v = F.linear(value, v_proj_weight_non_opt,
in_proj_bias[(embed_dim * 2):])
else:
q = F.linear(query, q_proj_weight_non_opt, in_proj_bias)
k = F.linear(key, k_proj_weight_non_opt, in_proj_bias)
v = F.linear(value, v_proj_weight_non_opt, in_proj_bias)
q = q * scaling
if attn_mask is not None:
assert (
attn_mask.dtype == torch.float32
or attn_mask.dtype == torch.float64
or attn_mask.dtype == torch.float16
or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool
), 'Only float, byte, and bool types are supported for ' \
'attn_mask, not {}'.format(attn_mask.dtype)
if attn_mask.dtype == torch.uint8:
warnings.warn('Byte tensor for attn_mask in nn.MultiheadAttention '
'is deprecated. Use bool tensor instead.')
attn_mask = attn_mask.to(torch.bool)
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise RuntimeError(
'The size of the 2D attn_mask is not correct.')
elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [
bsz * num_heads,
query.size(0), key.size(0)
]:
raise RuntimeError(
'The size of the 3D attn_mask is not correct.')
else:
raise RuntimeError(
"attn_mask's dimension {} is not supported".format(
attn_mask.dim()))
# attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
warnings.warn(
'Byte tensor for key_padding_mask in nn.MultiheadAttention '
'is deprecated. Use bool tensor instead.')
key_padding_mask = key_padding_mask.to(torch.bool)
if bias_k is not None and bias_v is not None:
if static_k is None and static_v is None:
k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = F.pad(attn_mask, (0, 1))
if key_padding_mask is not None:
key_padding_mask = F.pad(key_padding_mask, (0, 1))
else:
assert static_k is None, 'bias cannot be added to static key.'
assert static_v is None, 'bias cannot be added to static value.'
else:
assert bias_k is None
assert bias_v is None
q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
if k is not None:
k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
q_k = q_k.contiguous().view(tgt_len, bsz * num_heads,
head_dim).transpose(0, 1)
if v is not None:
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
q_v = q_v.contiguous().view(tgt_len, bsz * num_heads,
head_dim).transpose(0, 1)
if static_k is not None:
assert static_k.size(0) == bsz * num_heads
assert static_k.size(2) == head_dim
k = static_k
if static_v is not None:
assert static_v.size(0) == bsz * num_heads
assert static_v.size(2) == head_dim
v = static_v
src_len = k.size(1)
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if add_zero_attn:
src_len += 1
k = torch.cat(
[
k,
torch.zeros(
(k.size(0), 1) + k.size()[2:],
dtype=k.dtype,
device=k.device),
],
dim=1,
)
v = torch.cat(
[
v,
torch.zeros(
(v.size(0), 1) + v.size()[2:],
dtype=v.dtype,
device=v.device),
],
dim=1,
)
if attn_mask is not None:
attn_mask = F.pad(attn_mask, (0, 1))
if key_padding_mask is not None:
key_padding_mask = F.pad(key_padding_mask, (0, 1))
attn_output_weights = torch.bmm(q, k.transpose(1, 2))
assert list(
attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_output_weights.masked_fill_(attn_mask, float('-inf'))
else:
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len,
src_len)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float('-inf'),
)
attn_output_weights = attn_output_weights.view(bsz * num_heads,
tgt_len, src_len)
# attn_out_weights: [bsz * num_heads, tgt_len, src_len]
# ->[bsz * num_heads, tgt_len, src_len+1]
self_weight = (q * q_k).sum(
dim=-1, keepdim=True) # [bsz * num_heads, tgt_len, 1]
total_attn_output_weights = torch.cat([attn_output_weights, self_weight],
dim=-1)
total_attn_output_weights = F.softmax(total_attn_output_weights, dim=-1)
total_attn_output_weights = F.dropout(
total_attn_output_weights, p=dropout_p, training=training)
attn_output_weights = \
total_attn_output_weights[:, :, : -1]
# [bsz * num_heads, tgt_len, src_len]
self_weight = \
total_attn_output_weights[:, :, -1:] # [bsz * num_heads, tgt_len, 1]
attn_output = torch.bmm(attn_output_weights,
v) # [bsz * num_heads, tgt_len, head_dim]
attn_output = (attn_output + self_weight * q_v
) # [bsz * num_heads, tgt_len, head_dim]
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = attn_output.transpose(0, 1).contiguous().view(
tgt_len, bsz, embed_dim)
attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
if need_weights:
# average attention weights over heads
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len,
src_len)
return attn_output, attn_output_weights # .sum(dim=1) / num_heads
else:
return attn_output, None
def cross_attn_layer(tf_layer: BaseTransformerLayer, x, mem, attn_bias):
"""Implementation of transformer layer with cross attention. The cross
attention shares the embedding weights with self-attention of tf_layer.
Args:
tf_layer (BaseTransformerLayer): The transformer layer module.
x (Tensor): query [K,N,C]
mem (Tensor): key and value [L,N,C]
attn_bias (Tensor): attention bias [N*num_head,K,L]
Return:
x (Tensor): cross attention output [K,N,C]
"""
self_attn_layer = tf_layer.attentions[0].attn
attn_layer_paras = {
'embed_dim_to_check': self_attn_layer.embed_dim,
'num_heads': self_attn_layer.num_heads,
'in_proj_weight': self_attn_layer.in_proj_weight,
'in_proj_bias': self_attn_layer.in_proj_bias,
'bias_k': self_attn_layer.bias_k,
'bias_v': self_attn_layer.bias_v,
'add_zero_attn': self_attn_layer.add_zero_attn,
'dropout_p': self_attn_layer.dropout,
'out_proj_weight': self_attn_layer.out_proj.weight,
'out_proj_bias': self_attn_layer.out_proj.bias,
'training': self_attn_layer.training
}
q_x = tf_layer.norms[0](x)
k_x = v_x = tf_layer.norms[0](mem)
x = x + cross_attn_with_self_bias(
q_x,
k_x,
v_x,
attn_mask=attn_bias,
need_weights=False,
**attn_layer_paras)[0]
x = tf_layer.ffns[0](tf_layer.norms[1](x), identity=x)
return x
class LayerNorm2d(nn.Module):
"""A LayerNorm variant, popularized by Transformers, that performs point-
wise mean and variance normalization over the channel dimension for inputs
that have shape (batch_size, channels, height, width).
https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa B950
"""
def __init__(self, normalized_shape, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.normalized_shape = (normalized_shape, )
def forward(self, x: torch.Tensor):
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
class MLP(nn.Module):
"""Very simple multi-layer perceptron (also called FFN)"""
def __init__(self,
input_dim,
hidden_dim,
output_dim,
num_layers,
affine_func=nn.Linear):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(
affine_func(n, k)
for n, k in zip([input_dim] + h, h + [output_dim]))
def forward(self, x: torch.Tensor):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
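
A minimal sketch for the two small helpers above that need no transformer layer, assuming they are importable from mmseg.models.utils:

import torch
from mmseg.models.utils import MLP, LayerNorm2d

# LayerNorm2d normalizes over the channel dimension of an NCHW tensor.
ln = LayerNorm2d(64)
print(ln(torch.rand(2, 64, 16, 16)).shape)  # torch.Size([2, 64, 16, 16])

# A 3-layer MLP: 256 -> 512 -> 512 -> 10, with ReLU after the hidden layers.
mlp = MLP(input_dim=256, hidden_dim=512, output_dim=10, num_layers=3)
print(mlp(torch.rand(4, 256)).shape)  # torch.Size([4, 10])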


@@ -0,0 +1,58 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.utils import is_tuple_of
from .make_divisible import make_divisible
class SELayer(nn.Module):
"""Squeeze-and-Excitation Module.
Args:
channels (int): The input (and output) channels of the SE layer.
ratio (int): Squeeze ratio in SELayer, the intermediate channel will be
``int(channels/ratio)``. Default: 16.
conv_cfg (None or dict): Config dict for convolution layer.
Default: None, which means using conv2d.
act_cfg (dict or Sequence[dict]): Config dict for activation layer.
If act_cfg is a dict, two activation layers will be configured
by this dict. If act_cfg is a sequence of dicts, the first
activation layer will be configured by the first dict and the
second activation layer will be configured by the second dict.
Default: (dict(type='ReLU'), dict(type='HSigmoid', bias=3.0,
divisor=6.0)).
"""
def __init__(self,
channels,
ratio=16,
conv_cfg=None,
act_cfg=(dict(type='ReLU'),
dict(type='HSigmoid', bias=3.0, divisor=6.0))):
super().__init__()
if isinstance(act_cfg, dict):
act_cfg = (act_cfg, act_cfg)
assert len(act_cfg) == 2
assert is_tuple_of(act_cfg, dict)
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
self.conv1 = ConvModule(
in_channels=channels,
out_channels=make_divisible(channels // ratio, 8),
kernel_size=1,
stride=1,
conv_cfg=conv_cfg,
act_cfg=act_cfg[0])
self.conv2 = ConvModule(
in_channels=make_divisible(channels // ratio, 8),
out_channels=channels,
kernel_size=1,
stride=1,
conv_cfg=conv_cfg,
act_cfg=act_cfg[1])
def forward(self, x):
out = self.global_avgpool(x)
out = self.conv1(out)
out = self.conv2(out)
return x * out
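
A usage sketch for SELayer, assuming it is importable from mmseg.models.utils:

import torch
from mmseg.models.utils import SELayer

# Channel re-weighting: global average pool -> 1x1 squeeze -> 1x1 excite,
# then scale the input feature map channel-wise.
se = SELayer(channels=64, ratio=16)
print(se(torch.rand(2, 64, 32, 32)).shape)  # torch.Size([2, 64, 32, 32])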


@@ -0,0 +1,161 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn import ConvModule
from mmengine.model.weight_init import constant_init
from torch import nn as nn
from torch.nn import functional as F
class SelfAttentionBlock(nn.Module):
"""General self-attention block/non-local block.
Please refer to https://arxiv.org/abs/1706.03762 for details about key,
query and value.
Args:
key_in_channels (int): Input channels of key feature.
query_in_channels (int): Input channels of query feature.
channels (int): Output channels of key/query transform.
out_channels (int): Output channels.
share_key_query (bool): Whether to share projection weights between key
and query projection.
query_downsample (nn.Module): Query downsample module.
key_downsample (nn.Module): Key downsample module.
key_query_num_convs (int): Number of convs for key/query projection.
value_out_num_convs (int): Number of convs for value/out projection.
key_query_norm (bool): Whether to use ConvModule (with norm and
activation) for key/query projection.
value_out_norm (bool): Whether to use ConvModule (with norm and
activation) for value/out projection.
matmul_norm (bool): Whether to normalize the attention map by the
square root of channels.
with_out (bool): Whether to use an out projection.
conv_cfg (dict|None): Config of conv layers.
norm_cfg (dict|None): Config of norm layers.
act_cfg (dict|None): Config of activation layers.
"""
def __init__(self, key_in_channels, query_in_channels, channels,
out_channels, share_key_query, query_downsample,
key_downsample, key_query_num_convs, value_out_num_convs,
key_query_norm, value_out_norm, matmul_norm, with_out,
conv_cfg, norm_cfg, act_cfg):
super().__init__()
if share_key_query:
assert key_in_channels == query_in_channels
self.key_in_channels = key_in_channels
self.query_in_channels = query_in_channels
self.out_channels = out_channels
self.channels = channels
self.share_key_query = share_key_query
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.key_project = self.build_project(
key_in_channels,
channels,
num_convs=key_query_num_convs,
use_conv_module=key_query_norm,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
if share_key_query:
self.query_project = self.key_project
else:
self.query_project = self.build_project(
query_in_channels,
channels,
num_convs=key_query_num_convs,
use_conv_module=key_query_norm,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.value_project = self.build_project(
key_in_channels,
channels if with_out else out_channels,
num_convs=value_out_num_convs,
use_conv_module=value_out_norm,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
if with_out:
self.out_project = self.build_project(
channels,
out_channels,
num_convs=value_out_num_convs,
use_conv_module=value_out_norm,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
else:
self.out_project = None
self.query_downsample = query_downsample
self.key_downsample = key_downsample
self.matmul_norm = matmul_norm
self.init_weights()
def init_weights(self):
"""Initialize weight of later layer."""
if self.out_project is not None:
if not isinstance(self.out_project, ConvModule):
constant_init(self.out_project, 0)
def build_project(self, in_channels, channels, num_convs, use_conv_module,
conv_cfg, norm_cfg, act_cfg):
"""Build projection layer for key/query/value/out."""
if use_conv_module:
convs = [
ConvModule(
in_channels,
channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
]
for _ in range(num_convs - 1):
convs.append(
ConvModule(
channels,
channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
else:
convs = [nn.Conv2d(in_channels, channels, 1)]
for _ in range(num_convs - 1):
convs.append(nn.Conv2d(channels, channels, 1))
if len(convs) > 1:
convs = nn.Sequential(*convs)
else:
convs = convs[0]
return convs
def forward(self, query_feats, key_feats):
"""Forward function."""
batch_size = query_feats.size(0)
query = self.query_project(query_feats)
if self.query_downsample is not None:
query = self.query_downsample(query)
query = query.reshape(*query.shape[:2], -1)
query = query.permute(0, 2, 1).contiguous()
key = self.key_project(key_feats)
value = self.value_project(key_feats)
if self.key_downsample is not None:
key = self.key_downsample(key)
value = self.key_downsample(value)
key = key.reshape(*key.shape[:2], -1)
value = value.reshape(*value.shape[:2], -1)
value = value.permute(0, 2, 1).contiguous()
sim_map = torch.matmul(query, key)
if self.matmul_norm:
sim_map = (self.channels**-.5) * sim_map
sim_map = F.softmax(sim_map, dim=-1)
context = torch.matmul(sim_map, value)
context = context.permute(0, 2, 1).contiguous()
context = context.reshape(batch_size, -1, *query_feats.shape[2:])
if self.out_project is not None:
context = self.out_project(context)
return context
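
A sketch of a plain non-local configuration of SelfAttentionBlock, assuming it is importable from mmseg.models.utils; all projections here are single 1x1 convs without norm:

import torch
from mmseg.models.utils import SelfAttentionBlock

block = SelfAttentionBlock(
    key_in_channels=256,
    query_in_channels=256,
    channels=64,
    out_channels=256,
    share_key_query=False,
    query_downsample=None,
    key_downsample=None,
    key_query_num_convs=1,
    value_out_num_convs=1,
    key_query_norm=False,
    value_out_norm=False,
    matmul_norm=True,
    with_out=True,
    conv_cfg=None,
    norm_cfg=dict(type='BN'),
    act_cfg=dict(type='ReLU'))

feats = torch.rand(2, 256, 16, 16)
# Using the same map as query and key gives standard self-attention.
print(block(feats, feats).shape)  # torch.Size([2, 256, 16, 16])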


@@ -0,0 +1,107 @@
# Copyright (c) OpenMMLab. All rights reserved.
def nlc_to_nchw(x, hw_shape):
"""Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor.
Args:
x (Tensor): The input tensor of shape [N, L, C] before conversion.
hw_shape (Sequence[int]): The height and width of output feature map.
Returns:
Tensor: The output tensor of shape [N, C, H, W] after conversion.
"""
H, W = hw_shape
assert len(x.shape) == 3
B, L, C = x.shape
assert L == H * W, 'The seq_len doesn\'t match H, W'
return x.transpose(1, 2).reshape(B, C, H, W)
def nchw_to_nlc(x):
"""Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor.
Args:
x (Tensor): The input tensor of shape [N, C, H, W] before conversion.
Returns:
Tensor: The output tensor of shape [N, L, C] after conversion.
"""
assert len(x.shape) == 4
return x.flatten(2).transpose(1, 2).contiguous()
def nchw2nlc2nchw(module, x, contiguous=False, **kwargs):
"""Flatten [N, C, H, W] shape tensor `x` to [N, L, C] shape tensor. Use the
reshaped tensor as the input of `module`, and the convert the output of
`module`, whose shape is.
[N, L, C], to [N, C, H, W].
Args:
module (Callable): A callable object the takes a tensor
with shape [N, L, C] as input.
x (Tensor): The input tensor of shape [N, C, H, W].
contiguous:
contiguous (Bool): Whether to make the tensor contiguous
after each shape transform.
Returns:
Tensor: The output tensor of shape [N, C, H, W].
Example:
>>> import torch
>>> import torch.nn as nn
>>> norm = nn.LayerNorm(4)
>>> feature_map = torch.rand(4, 4, 5, 5)
>>> output = nchw2nlc2nchw(norm, feature_map)
"""
B, C, H, W = x.shape
if not contiguous:
x = x.flatten(2).transpose(1, 2)
x = module(x, **kwargs)
x = x.transpose(1, 2).reshape(B, C, H, W)
else:
x = x.flatten(2).transpose(1, 2).contiguous()
x = module(x, **kwargs)
x = x.transpose(1, 2).reshape(B, C, H, W).contiguous()
return x
def nlc2nchw2nlc(module, x, hw_shape, contiguous=False, **kwargs):
"""Convert [N, L, C] shape tensor `x` to [N, C, H, W] shape tensor. Use the
reshaped tensor as the input of `module`, and convert the output of
`module`, whose shape is.
[N, C, H, W], to [N, L, C].
Args:
module (Callable): A callable object the takes a tensor
with shape [N, C, H, W] as input.
x (Tensor): The input tensor of shape [N, L, C].
hw_shape: (Sequence[int]): The height and width of the
feature map with shape [N, C, H, W].
contiguous (Bool): Whether to make the tensor contiguous
after each shape transform.
Returns:
Tensor: The output tensor of shape [N, L, C].
Example:
>>> import torch
>>> import torch.nn as nn
>>> conv = nn.Conv2d(16, 16, 3, 1, 1)
>>> feature_map = torch.rand(4, 25, 16)
>>> output = nlc2nchw2nlc(conv, feature_map, (5, 5))
"""
H, W = hw_shape
assert len(x.shape) == 3
B, L, C = x.shape
assert L == H * W, 'The seq_len doesn\'t match H, W'
if not contiguous:
x = x.transpose(1, 2).reshape(B, C, H, W)
x = module(x, **kwargs)
x = x.flatten(2).transpose(1, 2)
else:
x = x.transpose(1, 2).reshape(B, C, H, W).contiguous()
x = module(x, **kwargs)
x = x.flatten(2).transpose(1, 2).contiguous()
return x
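
A round-trip sketch for the two plain conversion helpers (the wrapped variants already carry doctest-style examples above), assuming they are importable from mmseg.models.utils:

import torch
from mmseg.models.utils import nchw_to_nlc, nlc_to_nchw

x = torch.rand(2, 32, 7, 9)
tokens = nchw_to_nlc(x)                  # torch.Size([2, 63, 32]), L = 7 * 9
restored = nlc_to_nchw(tokens, (7, 9))   # back to torch.Size([2, 32, 7, 9])
print(torch.equal(x, restored))          # True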


@@ -0,0 +1,102 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, build_upsample_layer
class UpConvBlock(nn.Module):
"""Upsample convolution block in decoder for UNet.
This upsample convolution block consists of one upsample module
followed by one convolution block. The upsample module expands the
high-level low-resolution feature map and the convolution block fuses
the upsampled high-level low-resolution feature map and the low-level
high-resolution feature map from encoder.
Args:
conv_block (nn.Sequential): Sequential of convolutional layers.
in_channels (int): Number of input channels of the high-level,
low-resolution feature map from the decoder.
skip_channels (int): Number of input channels of the low-level,
high-resolution feature map from the encoder.
out_channels (int): Number of output channels.
num_convs (int): Number of convolutional layers in the conv_block.
Default: 2.
stride (int): Stride of convolutional layer in conv_block. Default: 1.
dilation (int): Dilation rate of convolutional layer in conv_block.
Default: 1.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
conv_cfg (dict | None): Config dict for convolution layer.
Default: None.
norm_cfg (dict | None): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict | None): Config dict for activation layer in ConvModule.
Default: dict(type='ReLU').
upsample_cfg (dict): The upsample config of the upsample module in
decoder. Default: dict(type='InterpConv'). If the size of
the high-level feature map is the same as that of the skip feature map
(the low-level feature map from the encoder), there is no need to
upsample the high-level feature map, and upsample_cfg should be None.
dcn (bool): Use deformable convolution in convolutional layer or not.
Default: None.
plugins (dict): plugins for convolutional layers. Default: None.
"""
def __init__(self,
conv_block,
in_channels,
skip_channels,
out_channels,
num_convs=2,
stride=1,
dilation=1,
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
upsample_cfg=dict(type='InterpConv'),
dcn=None,
plugins=None):
super().__init__()
assert dcn is None, 'Not implemented yet.'
assert plugins is None, 'Not implemented yet.'
self.conv_block = conv_block(
in_channels=2 * skip_channels,
out_channels=out_channels,
num_convs=num_convs,
stride=stride,
dilation=dilation,
with_cp=with_cp,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
dcn=None,
plugins=None)
if upsample_cfg is not None:
self.upsample = build_upsample_layer(
cfg=upsample_cfg,
in_channels=in_channels,
out_channels=skip_channels,
with_cp=with_cp,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
else:
self.upsample = ConvModule(
in_channels,
skip_channels,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
def forward(self, skip, x):
"""Forward function."""
x = self.upsample(x)
out = torch.cat([skip, x], dim=1)
out = self.conv_block(out)
return out
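
A self-contained sketch for UpConvBlock. Upstream mmseg normally passes BasicConvBlock from its UNet backbone as conv_block and the 'InterpConv' upsample layer; SimpleConvBlock below is a hypothetical stand-in that only needs to accept the keyword arguments UpConvBlock forwards, and upsample_cfg=None takes the 1x1-conv branch, so skip and x must already share a spatial size:

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.models.utils import UpConvBlock

class SimpleConvBlock(nn.Module):
    """Hypothetical conv_block: a stack of num_convs 3x3 ConvModules."""

    def __init__(self, in_channels, out_channels, num_convs=2, stride=1,
                 dilation=1, with_cp=False, conv_cfg=None,
                 norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'),
                 dcn=None, plugins=None):
        super().__init__()
        convs = []
        for i in range(num_convs):
            convs.append(
                ConvModule(
                    in_channels if i == 0 else out_channels,
                    out_channels,
                    kernel_size=3,
                    stride=stride if i == 0 else 1,
                    padding=dilation,
                    dilation=dilation,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
        self.convs = nn.Sequential(*convs)

    def forward(self, x):
        return self.convs(x)

block = UpConvBlock(
    conv_block=SimpleConvBlock,
    in_channels=128,
    skip_channels=64,
    out_channels=64,
    upsample_cfg=None)
skip = torch.rand(1, 64, 32, 32)   # low-level, high-resolution feature
x = torch.rand(1, 128, 32, 32)     # high-level feature, same size here
print(block(skip, x).shape)        # torch.Size([1, 64, 32, 32])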


@@ -0,0 +1,51 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import torch.nn as nn
import torch.nn.functional as F
def resize(input,
size=None,
scale_factor=None,
mode='nearest',
align_corners=None,
warning=True):
if warning:
if size is not None and align_corners:
input_h, input_w = tuple(int(x) for x in input.shape[2:])
output_h, output_w = tuple(int(x) for x in size)
if output_h > input_h or output_w > input_w:
if ((output_h > 1 and output_w > 1 and input_h > 1
and input_w > 1) and (output_h - 1) % (input_h - 1)
and (output_w - 1) % (input_w - 1)):
warnings.warn(
f'When align_corners={align_corners}, '
'the output would be more aligned if '
f'input size {(input_h, input_w)} is `x+1` and '
f'out size {(output_h, output_w)} is `nx+1`')
return F.interpolate(input, size, scale_factor, mode, align_corners)
class Upsample(nn.Module):
def __init__(self,
size=None,
scale_factor=None,
mode='nearest',
align_corners=None):
super().__init__()
self.size = size
if isinstance(scale_factor, tuple):
self.scale_factor = tuple(float(factor) for factor in scale_factor)
else:
self.scale_factor = float(scale_factor) if scale_factor else None
self.mode = mode
self.align_corners = align_corners
def forward(self, x):
if not self.size:
size = [int(t * self.scale_factor) for t in x.shape[-2:]]
else:
size = self.size
return resize(x, size, None, self.mode, self.align_corners)
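
A usage sketch for the wrappers above, assuming they are importable from mmseg.models.utils:

import torch
from mmseg.models.utils import Upsample, resize

x = torch.rand(1, 19, 32, 32)

# Functional form: interpolate logits back to the label resolution.
print(resize(x, size=(128, 128), mode='bilinear',
             align_corners=False).shape)  # torch.Size([1, 19, 128, 128])

# Module form, handy inside nn.Sequential; the target size is recomputed
# from the input shape and scale_factor at every call.
up = Upsample(scale_factor=2, mode='bilinear', align_corners=False)
print(up(x).shape)  # torch.Size([1, 19, 64, 64])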