init
27  finetune/mmseg/models/utils/__init__.py  Normal file
@@ -0,0 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .basic_block import BasicBlock, Bottleneck
from .embed import PatchEmbed
from .encoding import Encoding
from .inverted_residual import InvertedResidual, InvertedResidualV3
from .make_divisible import make_divisible
from .point_sample import get_uncertain_point_coords_with_randomness
from .ppm import DAPPM, PAPPM
from .res_layer import ResLayer
from .se_layer import SELayer
from .self_attention_block import SelfAttentionBlock
from .shape_convert import (nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc,
                            nlc_to_nchw)
from .up_conv_block import UpConvBlock

# isort: off
from .wrappers import Upsample, resize
from .san_layers import MLP, LayerNorm2d, cross_attn_layer

__all__ = [
    'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
    'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'PatchEmbed',
    'nchw_to_nlc', 'nlc_to_nchw', 'nchw2nlc2nchw', 'nlc2nchw2nlc', 'Encoding',
    'Upsample', 'resize', 'DAPPM', 'PAPPM', 'BasicBlock', 'Bottleneck',
    'cross_attn_layer', 'LayerNorm2d', 'MLP',
    'get_uncertain_point_coords_with_randomness'
]
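Downstream model code typically imports these helpers from the package rather than from the individual modules; a one-line sketch, assuming the `finetune` directory is on `PYTHONPATH` so that this vendored `mmseg` is the copy that resolves:

    from mmseg.models.utils import ResLayer, SELayer, DAPPM, make_divisible, resize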
143  finetune/mmseg/models/utils/basic_block.py  Normal file
@@ -0,0 +1,143 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from torch import Tensor

from mmseg.registry import MODELS
from mmseg.utils import OptConfigType


class BasicBlock(BaseModule):
    """Basic block from `ResNet <https://arxiv.org/abs/1512.03385>`_.

    Args:
        in_channels (int): Input channels.
        channels (int): Output channels.
        stride (int): Stride of the first block. Default: 1.
        downsample (nn.Module, optional): Downsample operation on identity.
            Default: None.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict, optional): Config dict for activation layer in
            ConvModule. Default: dict(type='ReLU', inplace=True).
        act_cfg_out (dict, optional): Config dict for activation layer at the
            end of the block. Default: dict(type='ReLU', inplace=True).
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    expansion = 1

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 act_cfg_out: OptConfigType = dict(type='ReLU', inplace=True),
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg)
        self.conv1 = ConvModule(
            in_channels,
            channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.conv2 = ConvModule(
            channels,
            channels,
            kernel_size=3,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=None)
        self.downsample = downsample
        if act_cfg_out:
            self.act = MODELS.build(act_cfg_out)

    def forward(self, x: Tensor) -> Tensor:
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)

        if self.downsample:
            residual = self.downsample(x)

        out += residual

        if hasattr(self, 'act'):
            out = self.act(out)

        return out


class Bottleneck(BaseModule):
    """Bottleneck block from `ResNet <https://arxiv.org/abs/1512.03385>`_.

    Args:
        in_channels (int): Input channels.
        channels (int): Output channels.
        stride (int): Stride of the first block. Default: 1.
        downsample (nn.Module, optional): Downsample operation on identity.
            Default: None.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict, optional): Config dict for activation layer in
            ConvModule. Default: dict(type='ReLU', inplace=True).
        act_cfg_out (dict, optional): Config dict for activation layer at
            the end of the block. Default: None.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    expansion = 2

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 act_cfg_out: OptConfigType = None,
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg)
        self.conv1 = ConvModule(
            in_channels, channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.conv2 = ConvModule(
            channels,
            channels,
            3,
            stride,
            1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.conv3 = ConvModule(
            channels,
            channels * self.expansion,
            1,
            norm_cfg=norm_cfg,
            act_cfg=None)
        if act_cfg_out:
            self.act = MODELS.build(act_cfg_out)
        self.downsample = downsample

    def forward(self, x: Tensor) -> Tensor:
        residual = x

        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        if self.downsample:
            residual = self.downsample(x)

        out += residual

        if hasattr(self, 'act'):
            out = self.act(out)

        return out
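A minimal shape-check sketch for the two blocks above, assuming torch, mmcv and mmengine are installed and the mmseg registry is importable; the channel numbers are illustrative only:

    import torch
    from mmseg.models.utils import BasicBlock, Bottleneck

    x = torch.rand(2, 64, 56, 56)
    # BasicBlock keeps the channel count and, with stride=1, the resolution.
    assert BasicBlock(64, 64)(x).shape == (2, 64, 56, 56)

    # Bottleneck expands channels by `expansion`, so the identity path needs
    # a 1x1 projection when in_channels != channels * expansion or stride != 1.
    downsample = torch.nn.Conv2d(64, 128 * Bottleneck.expansion, 1, stride=2)
    block = Bottleneck(64, 128, stride=2, downsample=downsample)
    assert block(x).shape == (2, 256, 28, 28)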
330  finetune/mmseg/models/utils/embed.py  Normal file
@@ -0,0 +1,330 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Sequence

import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmengine.model import BaseModule
from mmengine.utils import to_2tuple


class AdaptivePadding(nn.Module):
    """Applies padding to the input (if needed) so that the input can be
    fully covered by the filter you specified. It supports two modes, "same"
    and "corner". The "same" mode is the same as the "SAME" padding mode in
    TensorFlow and pads zeros around the input; the "corner" mode pads zeros
    to the bottom right.

    Args:
        kernel_size (int | tuple): Size of the kernel. Default: 1.
        stride (int | tuple): Stride of the filter. Default: 1.
        dilation (int | tuple): Spacing between kernel elements.
            Default: 1.
        padding (str): Support "same" and "corner"; "corner" mode
            pads zeros to the bottom right, and "same" mode pads
            zeros around the input. Default: "corner".
    Example:
        >>> kernel_size = 16
        >>> stride = 16
        >>> dilation = 1
        >>> input = torch.rand(1, 1, 15, 17)
        >>> adap_pad = AdaptivePadding(
        >>>     kernel_size=kernel_size,
        >>>     stride=stride,
        >>>     dilation=dilation,
        >>>     padding="corner")
        >>> out = adap_pad(input)
        >>> assert (out.shape[2], out.shape[3]) == (16, 32)
        >>> input = torch.rand(1, 1, 16, 17)
        >>> out = adap_pad(input)
        >>> assert (out.shape[2], out.shape[3]) == (16, 32)
    """

    def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):

        super().__init__()

        assert padding in ('same', 'corner')

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)

        self.padding = padding
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation

    def get_pad_shape(self, input_shape):
        input_h, input_w = input_shape
        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.stride
        output_h = math.ceil(input_h / stride_h)
        output_w = math.ceil(input_w / stride_w)
        pad_h = max((output_h - 1) * stride_h +
                    (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
        pad_w = max((output_w - 1) * stride_w +
                    (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
        return pad_h, pad_w

    def forward(self, x):
        pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
        if pad_h > 0 or pad_w > 0:
            if self.padding == 'corner':
                x = F.pad(x, [0, pad_w, 0, pad_h])
            elif self.padding == 'same':
                x = F.pad(x, [
                    pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
                    pad_h - pad_h // 2
                ])
        return x


class PatchEmbed(BaseModule):
    """Image to Patch Embedding.

    We use a conv layer to implement PatchEmbed.

    Args:
        in_channels (int): The num of input channels. Default: 3.
        embed_dims (int): The dimensions of embedding. Default: 768.
        conv_type (str): The type of convolution used for the embedding
            conv layer. Default: "Conv2d".
        kernel_size (int): The kernel_size of embedding conv. Default: 16.
        stride (int, optional): The slide stride of embedding conv.
            Default: None (would be set as `kernel_size`).
        padding (int | tuple | string): The padding length of
            embedding conv. When it is a string, it means the mode
            of adaptive padding, support "same" and "corner" now.
            Default: "corner".
        dilation (int): The dilation rate of embedding conv. Default: 1.
        bias (bool): Bias of embed conv. Default: True.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None.
        input_size (int | tuple | None): The size of input, which will be
            used to calculate the out size. Only works when `dynamic_size`
            is False. Default: None.
        init_cfg (`mmengine.ConfigDict`, optional): The Config for
            initialization. Default: None.
    """

    def __init__(self,
                 in_channels=3,
                 embed_dims=768,
                 conv_type='Conv2d',
                 kernel_size=16,
                 stride=None,
                 padding='corner',
                 dilation=1,
                 bias=True,
                 norm_cfg=None,
                 input_size=None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        if stride is None:
            stride = kernel_size

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)

        if isinstance(padding, str):
            self.adap_padding = AdaptivePadding(
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding)
            # disable the padding of conv
            padding = 0
        else:
            self.adap_padding = None
        padding = to_2tuple(padding)

        self.projection = build_conv_layer(
            dict(type=conv_type),
            in_channels=in_channels,
            out_channels=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        else:
            self.norm = None

        if input_size:
            input_size = to_2tuple(input_size)
            # `init_out_size` would be used outside to
            # calculate the num_patches
            # when `use_abs_pos_embed` outside
            self.init_input_size = input_size
            if self.adap_padding:
                pad_h, pad_w = self.adap_padding.get_pad_shape(input_size)
                input_h, input_w = input_size
                input_h = input_h + pad_h
                input_w = input_w + pad_w
                input_size = (input_h, input_w)

            # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
            h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
                     (kernel_size[0] - 1) - 1) // stride[0] + 1
            w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
                     (kernel_size[1] - 1) - 1) // stride[1] + 1
            self.init_out_size = (h_out, w_out)
        else:
            self.init_input_size = None
            self.init_out_size = None

    def forward(self, x):
        """
        Args:
            x (Tensor): Has shape (B, C, H, W). In most cases, C is 3.

        Returns:
            tuple: Contains merged results and its spatial shape.

            - x (Tensor): Has shape (B, out_h * out_w, embed_dims)
            - out_size (tuple[int]): Spatial shape of x, arranged as
              (out_h, out_w).
        """

        if self.adap_padding:
            x = self.adap_padding(x)

        x = self.projection(x)
        out_size = (x.shape[2], x.shape[3])
        x = x.flatten(2).transpose(1, 2)
        if self.norm is not None:
            x = self.norm(x)
        return x, out_size


class PatchMerging(BaseModule):
    """Merge patch feature map.

    This layer groups the feature map by kernel_size and applies norm and
    linear layers to the grouped feature map. Our implementation uses
    `nn.Unfold` to merge patches, which is about 25% faster than the original
    implementation. However, pretrained models need to be modified for
    compatibility.

    Args:
        in_channels (int): The num of input channels.
        out_channels (int): The num of output channels.
        kernel_size (int | tuple, optional): the kernel size in the unfold
            layer. Defaults to 2.
        stride (int | tuple, optional): the stride of the sliding blocks in
            the unfold layer. Default: None (would be set as `kernel_size`).
        padding (int | tuple | string): The padding length of
            embedding conv. When it is a string, it means the mode
            of adaptive padding, support "same" and "corner" now.
            Default: "corner".
        dilation (int | tuple, optional): dilation parameter in the unfold
            layer. Default: 1.
        bias (bool, optional): Whether to add bias in linear layer or not.
            Defaults: False.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='LN').
        init_cfg (dict, optional): The extra config for initialization.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=2,
                 stride=None,
                 padding='corner',
                 dilation=1,
                 bias=False,
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        if stride:
            stride = stride
        else:
            stride = kernel_size

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)

        if isinstance(padding, str):
            self.adap_padding = AdaptivePadding(
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding)
            # disable the padding of unfold
            padding = 0
        else:
            self.adap_padding = None

        padding = to_2tuple(padding)
        self.sampler = nn.Unfold(
            kernel_size=kernel_size,
            dilation=dilation,
            padding=padding,
            stride=stride)

        sample_dim = kernel_size[0] * kernel_size[1] * in_channels

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
        else:
            self.norm = None

        self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)

    def forward(self, x, input_size):
        """
        Args:
            x (Tensor): Has shape (B, H*W, C_in).
            input_size (tuple[int]): The spatial shape of x, arranged as
                (H, W). Default: None.

        Returns:
            tuple: Contains merged results and its spatial shape.

            - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
            - out_size (tuple[int]): Spatial shape of x, arranged as
              (Merged_H, Merged_W).
        """
        B, L, C = x.shape
        assert isinstance(input_size, Sequence), f'Expect ' \
            f'input_size is ' \
            f'`Sequence` ' \
            f'but get {input_size}'

        H, W = input_size
        assert L == H * W, 'input feature has wrong size'

        x = x.view(B, H, W, C).permute([0, 3, 1, 2])  # B, C, H, W
        # Use nn.Unfold to merge patches. About 25% faster than the original
        # method, but the pretrained model needs to be modified
        # for compatibility.

        if self.adap_padding:
            x = self.adap_padding(x)
            H, W = x.shape[-2:]

        x = self.sampler(x)
        # if kernel_size=2 and stride=2, x should have shape
        # (B, 4*C, H/2*W/2)

        out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
                 (self.sampler.kernel_size[0] - 1) -
                 1) // self.sampler.stride[0] + 1
        out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
                 (self.sampler.kernel_size[1] - 1) -
                 1) // self.sampler.stride[1] + 1

        output_size = (out_h, out_w)
        x = x.transpose(1, 2)  # B, H/2*W/2, 4*C
        x = self.norm(x) if self.norm else x
        x = self.reduction(x)
        return x, output_size
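A minimal sketch of the expected shapes, assuming torch, mmcv and mmengine are available; the 224x224 input and 16-pixel patches are illustrative only:

    import torch
    from mmseg.models.utils.embed import PatchEmbed, PatchMerging

    embed = PatchEmbed(in_channels=3, embed_dims=96, kernel_size=16)
    tokens, hw = embed(torch.rand(1, 3, 224, 224))
    assert hw == (14, 14) and tokens.shape == (1, 14 * 14, 96)

    # PatchMerging halves the spatial resolution and maps 4*C_in -> C_out.
    merge = PatchMerging(in_channels=96, out_channels=192)
    merged, merged_hw = merge(tokens, hw)
    assert merged_hw == (7, 7) and merged.shape == (1, 7 * 7, 192)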
75  finetune/mmseg/models/utils/encoding.py  Normal file
@@ -0,0 +1,75 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import nn
from torch.nn import functional as F


class Encoding(nn.Module):
    """Encoding Layer: a learnable residual encoder.

    Input is of shape (batch_size, channels, height, width).
    Output is of shape (batch_size, num_codes, channels).

    Args:
        channels: dimension of the features or feature channels
        num_codes: number of code words
    """

    def __init__(self, channels, num_codes):
        super().__init__()
        # init codewords and smoothing factor
        self.channels, self.num_codes = channels, num_codes
        std = 1. / ((num_codes * channels)**0.5)
        # [num_codes, channels]
        self.codewords = nn.Parameter(
            torch.empty(num_codes, channels,
                        dtype=torch.float).uniform_(-std, std),
            requires_grad=True)
        # [num_codes]
        self.scale = nn.Parameter(
            torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0),
            requires_grad=True)

    @staticmethod
    def scaled_l2(x, codewords, scale):
        num_codes, channels = codewords.size()
        batch_size = x.size(0)
        reshaped_scale = scale.view((1, 1, num_codes))
        expanded_x = x.unsqueeze(2).expand(
            (batch_size, x.size(1), num_codes, channels))
        reshaped_codewords = codewords.view((1, 1, num_codes, channels))

        scaled_l2_norm = reshaped_scale * (
            expanded_x - reshaped_codewords).pow(2).sum(dim=3)
        return scaled_l2_norm

    @staticmethod
    def aggregate(assignment_weights, x, codewords):
        num_codes, channels = codewords.size()
        reshaped_codewords = codewords.view((1, 1, num_codes, channels))
        batch_size = x.size(0)

        expanded_x = x.unsqueeze(2).expand(
            (batch_size, x.size(1), num_codes, channels))
        encoded_feat = (assignment_weights.unsqueeze(3) *
                        (expanded_x - reshaped_codewords)).sum(dim=1)
        return encoded_feat

    def forward(self, x):
        assert x.dim() == 4 and x.size(1) == self.channels
        # [batch_size, channels, height, width]
        batch_size = x.size(0)
        # [batch_size, height x width, channels]
        x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous()
        # assignment_weights: [batch_size, height x width, num_codes]
        assignment_weights = F.softmax(
            self.scaled_l2(x, self.codewords, self.scale), dim=2)
        # aggregate
        encoded_feat = self.aggregate(assignment_weights, x, self.codewords)
        return encoded_feat

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(Nx{self.channels}xHxW => Nx{self.num_codes}' \
                    f'x{self.channels})'
        return repr_str
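A minimal shape sketch for the encoding layer, assuming torch is installed; the sizes are arbitrary:

    import torch
    from mmseg.models.utils import Encoding

    enc = Encoding(channels=32, num_codes=8)
    feat = torch.rand(4, 32, 16, 16)
    # Each of the 8 codewords aggregates soft residuals over all 16*16 positions.
    assert enc(feat).shape == (4, 8, 32)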
213  finetune/mmseg/models/utils/inverted_residual.py  Normal file
@@ -0,0 +1,213 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import ConvModule
from torch import nn
from torch.utils import checkpoint as cp

from .se_layer import SELayer


class InvertedResidual(nn.Module):
    """InvertedResidual block for MobileNetV2.

    Args:
        in_channels (int): The input channels of the InvertedResidual block.
        out_channels (int): The output channels of the InvertedResidual block.
        stride (int): Stride of the middle (first) 3x3 convolution.
        expand_ratio (int): Adjusts number of channels of the hidden layer
            in InvertedResidual by this amount.
        dilation (int): Dilation rate of depthwise conv. Default: 1
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU6').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.

    Returns:
        Tensor: The output tensor.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 expand_ratio,
                 dilation=1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU6'),
                 with_cp=False,
                 **kwargs):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2], f'stride must be in [1, 2]. ' \
            f'But received {stride}.'
        self.with_cp = with_cp
        self.use_res_connect = self.stride == 1 and in_channels == out_channels
        hidden_dim = int(round(in_channels * expand_ratio))

        layers = []
        if expand_ratio != 1:
            layers.append(
                ConvModule(
                    in_channels=in_channels,
                    out_channels=hidden_dim,
                    kernel_size=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    **kwargs))
        layers.extend([
            ConvModule(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                kernel_size=3,
                stride=stride,
                padding=dilation,
                dilation=dilation,
                groups=hidden_dim,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                **kwargs),
            ConvModule(
                in_channels=hidden_dim,
                out_channels=out_channels,
                kernel_size=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None,
                **kwargs)
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):

        def _inner_forward(x):
            if self.use_res_connect:
                return x + self.conv(x)
            else:
                return self.conv(x)

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out


class InvertedResidualV3(nn.Module):
    """Inverted Residual Block for MobileNetV3.

    Args:
        in_channels (int): The input channels of this Module.
        out_channels (int): The output channels of this Module.
        mid_channels (int): The input channels of the depthwise convolution.
        kernel_size (int): The kernel size of the depthwise convolution.
            Default: 3.
        stride (int): The stride of the depthwise convolution. Default: 1.
        se_cfg (dict): Config dict for se layer. Default: None, which means no
            se layer.
        with_expand_conv (bool): Use expand conv or not. If set False,
            mid_channels must be the same as in_channels. Default: True.
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.

    Returns:
        Tensor: The output tensor.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels,
                 kernel_size=3,
                 stride=1,
                 se_cfg=None,
                 with_expand_conv=True,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_cp=False):
        super().__init__()
        self.with_res_shortcut = (stride == 1 and in_channels == out_channels)
        assert stride in [1, 2]
        self.with_cp = with_cp
        self.with_se = se_cfg is not None
        self.with_expand_conv = with_expand_conv

        if self.with_se:
            assert isinstance(se_cfg, dict)
        if not self.with_expand_conv:
            assert mid_channels == in_channels

        if self.with_expand_conv:
            self.expand_conv = ConvModule(
                in_channels=in_channels,
                out_channels=mid_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        self.depthwise_conv = ConvModule(
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=mid_channels,
            conv_cfg=dict(
                type='Conv2dAdaptivePadding') if stride == 2 else conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        if self.with_se:
            self.se = SELayer(**se_cfg)

        self.linear_conv = ConvModule(
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

    def forward(self, x):

        def _inner_forward(x):
            out = x

            if self.with_expand_conv:
                out = self.expand_conv(out)

            out = self.depthwise_conv(out)

            if self.with_se:
                out = self.se(out)

            out = self.linear_conv(out)

            if self.with_res_shortcut:
                return x + out
            else:
                return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out
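A minimal shape sketch, assuming torch and mmcv are installed; the channel setup loosely follows a MobileNetV3 stage and is illustrative only:

    import torch
    from mmseg.models.utils import InvertedResidual, InvertedResidualV3

    x = torch.rand(2, 16, 32, 32)

    # MobileNetV2-style block with expansion and a residual connection.
    assert InvertedResidual(16, 16, stride=1, expand_ratio=4)(x).shape == (2, 16, 32, 32)

    # MobileNetV3-style block with an SE unit on the expanded features.
    block = InvertedResidualV3(16, 24, mid_channels=72, kernel_size=5,
                               stride=2, se_cfg=dict(channels=72, ratio=4))
    assert block(x).shape == (2, 24, 16, 16)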
28  finetune/mmseg/models/utils/make_divisible.py  Normal file
@@ -0,0 +1,28 @@
# Copyright (c) OpenMMLab. All rights reserved.
def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
    """Make divisible function.

    This function rounds the channel number to the nearest value that can be
    divisible by the divisor. It is taken from the original tf repo. It ensures
    that all layers have a channel number that is divisible by divisor. It can
    be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py  # noqa

    Args:
        value (int): The original channel number.
        divisor (int): The divisor to fully divide the channel number.
        min_value (int): The minimum value of the output channel.
            Default: None, which means the minimum value equals the divisor.
        min_ratio (float): The minimum ratio of the rounded channel number to
            the original channel number. Default: 0.9.

    Returns:
        int: The modified output channel number.
    """

    if min_value is None:
        min_value = divisor
    new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than (1-min_ratio).
    if new_value < min_ratio * value:
        new_value += divisor
    return new_value
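A few worked values, following the rounding rule above (round to the nearest multiple of `divisor`, then add one divisor back if the rounded value fell below `min_ratio` of the original):

    from mmseg.models.utils import make_divisible

    assert make_divisible(32, 8) == 32   # already divisible
    assert make_divisible(35, 8) == 32   # 32 >= 0.9 * 35, so the round-down is kept
    assert make_divisible(8, 6) == 12    # 6 < 0.9 * 8, so one divisor is added back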
88  finetune/mmseg/models/utils/point_sample.py  Normal file
@@ -0,0 +1,88 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.ops import point_sample
from torch import Tensor


def get_uncertainty(mask_preds: Tensor, labels: Tensor) -> Tensor:
    """Estimate uncertainty based on pred logits.

    We estimate the uncertainty as the L1 distance between 0.0 and the logit
    prediction in 'mask_preds' for the foreground class in `classes`.

    Args:
        mask_preds (Tensor): mask prediction logits, shape (num_rois,
            num_classes, mask_height, mask_width).

        labels (Tensor): Either predicted or ground truth label for
            each predicted mask, of length num_rois.

    Returns:
        scores (Tensor): Uncertainty scores with the most uncertain
            locations having the highest uncertainty score,
            shape (num_rois, 1, mask_height, mask_width)
    """
    if mask_preds.shape[1] == 1:
        gt_class_logits = mask_preds.clone()
    else:
        inds = torch.arange(mask_preds.shape[0], device=mask_preds.device)
        gt_class_logits = mask_preds[inds, labels].unsqueeze(1)
    return -torch.abs(gt_class_logits)


def get_uncertain_point_coords_with_randomness(
        mask_preds: Tensor, labels: Tensor, num_points: int,
        oversample_ratio: float, importance_sample_ratio: float) -> Tensor:
    """Get ``num_points`` most uncertain points with random points during
    train.

    Sample points in [0, 1] x [0, 1] coordinate space based on their
    uncertainty. The uncertainties are calculated for each point using
    the 'get_uncertainty()' function, which takes the point's logit
    prediction as input.

    Args:
        mask_preds (Tensor): A tensor of shape (num_rois, num_classes,
            mask_height, mask_width) for class-specific or class-agnostic
            prediction.
        labels (Tensor): The ground truth class for each instance.
        num_points (int): The number of points to sample.
        oversample_ratio (float): Oversampling parameter.
        importance_sample_ratio (float): Ratio of points that are sampled
            via importance sampling.

    Returns:
        point_coords (Tensor): A tensor of shape (num_rois, num_points, 2)
            that contains the coordinates of the sampled points.
    """
    assert oversample_ratio >= 1
    assert 0 <= importance_sample_ratio <= 1
    batch_size = mask_preds.shape[0]
    num_sampled = int(num_points * oversample_ratio)
    point_coords = torch.rand(
        batch_size, num_sampled, 2, device=mask_preds.device)
    point_logits = point_sample(mask_preds, point_coords)
    # It is crucial to calculate uncertainty based on the sampled
    # prediction value for the points. Calculating uncertainties of the
    # coarse predictions first and sampling them for points leads to
    # incorrect results. To illustrate this: assume uncertainty func(
    # logits)=-abs(logits), a sampled point between two coarse
    # predictions with -1 and 1 logits has 0 logits, and therefore 0
    # uncertainty value. However, if we calculate uncertainties for the
    # coarse predictions first, both will have -1 uncertainty,
    # and the sampled point will get -1 uncertainty.
    point_uncertainties = get_uncertainty(point_logits, labels)
    num_uncertain_points = int(importance_sample_ratio * num_points)
    num_random_points = num_points - num_uncertain_points
    idx = torch.topk(
        point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
    shift = num_sampled * torch.arange(
        batch_size, dtype=torch.long, device=mask_preds.device)
    idx += shift[:, None]
    point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
        batch_size, num_uncertain_points, 2)
    if num_random_points > 0:
        rand_roi_coords = torch.rand(
            batch_size, num_random_points, 2, device=mask_preds.device)
        point_coords = torch.cat((point_coords, rand_roi_coords), dim=1)
    return point_coords
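A minimal shape sketch of the sampling routine, assuming torch and mmcv (for `mmcv.ops.point_sample`) are installed; the sizes follow the docstring and are illustrative:

    import torch
    from mmseg.models.utils import get_uncertain_point_coords_with_randomness

    mask_preds = torch.randn(3, 5, 32, 32)   # (num_rois, num_classes, H, W)
    labels = torch.randint(0, 5, (3,))
    coords = get_uncertain_point_coords_with_randomness(
        mask_preds, labels, num_points=100,
        oversample_ratio=3.0, importance_sample_ratio=0.75)
    # 75 points picked by uncertainty from 300 oversampled candidates,
    # plus 25 uniformly random points.
    assert coords.shape == (3, 100, 2)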
193  finetune/mmseg/models/utils/ppm.py  Normal file
@@ -0,0 +1,193 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule, ModuleList, Sequential
from torch import Tensor


class DAPPM(BaseModule):
    """DAPPM module in `DDRNet <https://arxiv.org/abs/2101.06085>`_.

    Args:
        in_channels (int): Input channels.
        branch_channels (int): Branch channels.
        out_channels (int): Output channels.
        num_scales (int): Number of scales.
        kernel_sizes (list[int]): Kernel sizes of each scale.
        strides (list[int]): Strides of each scale.
        paddings (list[int]): Paddings of each scale.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN', momentum=0.1).
        act_cfg (dict): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU', inplace=True).
        conv_cfg (dict): Config dict for convolution layer in ConvModule.
            Default: dict(order=('norm', 'act', 'conv'), bias=False).
        upsample_mode (str): Upsample mode. Default: 'bilinear'.
    """

    def __init__(self,
                 in_channels: int,
                 branch_channels: int,
                 out_channels: int,
                 num_scales: int,
                 kernel_sizes: List[int] = [5, 9, 17],
                 strides: List[int] = [2, 4, 8],
                 paddings: List[int] = [2, 4, 8],
                 norm_cfg: Dict = dict(type='BN', momentum=0.1),
                 act_cfg: Dict = dict(type='ReLU', inplace=True),
                 conv_cfg: Dict = dict(
                     order=('norm', 'act', 'conv'), bias=False),
                 upsample_mode: str = 'bilinear'):
        super().__init__()

        self.num_scales = num_scales
        self.unsample_mode = upsample_mode
        self.in_channels = in_channels
        self.branch_channels = branch_channels
        self.out_channels = out_channels
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.conv_cfg = conv_cfg

        self.scales = ModuleList([
            ConvModule(
                in_channels,
                branch_channels,
                kernel_size=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                **conv_cfg)
        ])
        for i in range(1, num_scales - 1):
            self.scales.append(
                Sequential(*[
                    nn.AvgPool2d(
                        kernel_size=kernel_sizes[i - 1],
                        stride=strides[i - 1],
                        padding=paddings[i - 1]),
                    ConvModule(
                        in_channels,
                        branch_channels,
                        kernel_size=1,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg,
                        **conv_cfg)
                ]))
        self.scales.append(
            Sequential(*[
                nn.AdaptiveAvgPool2d((1, 1)),
                ConvModule(
                    in_channels,
                    branch_channels,
                    kernel_size=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    **conv_cfg)
            ]))
        self.processes = ModuleList()
        for i in range(num_scales - 1):
            self.processes.append(
                ConvModule(
                    branch_channels,
                    branch_channels,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    **conv_cfg))

        self.compression = ConvModule(
            branch_channels * num_scales,
            out_channels,
            kernel_size=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **conv_cfg)

        self.shortcut = ConvModule(
            in_channels,
            out_channels,
            kernel_size=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **conv_cfg)

    def forward(self, inputs: Tensor):
        feats = []
        feats.append(self.scales[0](inputs))

        for i in range(1, self.num_scales):
            feat_up = F.interpolate(
                self.scales[i](inputs),
                size=inputs.shape[2:],
                mode=self.unsample_mode)
            feats.append(self.processes[i - 1](feat_up + feats[i - 1]))

        return self.compression(torch.cat(feats,
                                          dim=1)) + self.shortcut(inputs)


class PAPPM(DAPPM):
    """PAPPM module in `PIDNet <https://arxiv.org/abs/2206.02066>`_.

    Args:
        in_channels (int): Input channels.
        branch_channels (int): Branch channels.
        out_channels (int): Output channels.
        num_scales (int): Number of scales.
        kernel_sizes (list[int]): Kernel sizes of each scale.
        strides (list[int]): Strides of each scale.
        paddings (list[int]): Paddings of each scale.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN', momentum=0.1).
        act_cfg (dict): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU', inplace=True).
        conv_cfg (dict): Config dict for convolution layer in ConvModule.
            Default: dict(order=('norm', 'act', 'conv'), bias=False).
        upsample_mode (str): Upsample mode. Default: 'bilinear'.
    """

    def __init__(self,
                 in_channels: int,
                 branch_channels: int,
                 out_channels: int,
                 num_scales: int,
                 kernel_sizes: List[int] = [5, 9, 17],
                 strides: List[int] = [2, 4, 8],
                 paddings: List[int] = [2, 4, 8],
                 norm_cfg: Dict = dict(type='BN', momentum=0.1),
                 act_cfg: Dict = dict(type='ReLU', inplace=True),
                 conv_cfg: Dict = dict(
                     order=('norm', 'act', 'conv'), bias=False),
                 upsample_mode: str = 'bilinear'):
        super().__init__(in_channels, branch_channels, out_channels,
                         num_scales, kernel_sizes, strides, paddings, norm_cfg,
                         act_cfg, conv_cfg, upsample_mode)

        self.processes = ConvModule(
            self.branch_channels * (self.num_scales - 1),
            self.branch_channels * (self.num_scales - 1),
            kernel_size=3,
            padding=1,
            groups=self.num_scales - 1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            **self.conv_cfg)

    def forward(self, inputs: Tensor):
        x_ = self.scales[0](inputs)
        feats = []
        for i in range(1, self.num_scales):
            feat_up = F.interpolate(
                self.scales[i](inputs),
                size=inputs.shape[2:],
                mode=self.unsample_mode,
                align_corners=False)
            feats.append(feat_up + x_)
        scale_out = self.processes(torch.cat(feats, dim=1))
        return self.compression(torch.cat([x_, scale_out],
                                          dim=1)) + self.shortcut(inputs)
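A minimal shape sketch for DAPPM/PAPPM, assuming torch, mmcv and mmengine are installed; the channel sizes loosely mirror DDRNet/PIDNet configs and are illustrative. The input resolution should be large enough for the pooling branches (kernel sizes 5/9/17 by default):

    import torch
    from mmseg.models.utils import DAPPM, PAPPM

    x = torch.rand(2, 128, 32, 32)

    dappm = DAPPM(in_channels=128, branch_channels=32, out_channels=64, num_scales=5)
    assert dappm(x).shape == (2, 64, 32, 32)

    pappm = PAPPM(in_channels=128, branch_channels=32, out_channels=64, num_scales=5)
    assert pappm(x).shape == (2, 64, 32, 32)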
96  finetune/mmseg/models/utils/res_layer.py  Normal file
@@ -0,0 +1,96 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmengine.model import Sequential
from torch import nn as nn


class ResLayer(Sequential):
    """ResLayer to build ResNet style backbone.

    Args:
        block (nn.Module): block used to build ResLayer.
        inplanes (int): inplanes of block.
        planes (int): planes of block.
        num_blocks (int): number of blocks.
        stride (int): stride of the first block. Default: 1
        dilation (int): dilation of the blocks. Default: 1
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Default: False
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: None
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        multi_grid (int | None): Multi grid dilation rates of last
            stage. Default: None
        contract_dilation (bool): Whether to contract the first dilation of
            each layer. Default: False
    """

    def __init__(self,
                 block,
                 inplanes,
                 planes,
                 num_blocks,
                 stride=1,
                 dilation=1,
                 avg_down=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 multi_grid=None,
                 contract_dilation=False,
                 **kwargs):
        self.block = block

        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = []
            conv_stride = stride
            if avg_down:
                conv_stride = 1
                downsample.append(
                    nn.AvgPool2d(
                        kernel_size=stride,
                        stride=stride,
                        ceil_mode=True,
                        count_include_pad=False))
            downsample.extend([
                build_conv_layer(
                    conv_cfg,
                    inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=conv_stride,
                    bias=False),
                build_norm_layer(norm_cfg, planes * block.expansion)[1]
            ])
            downsample = nn.Sequential(*downsample)

        layers = []
        if multi_grid is None:
            if dilation > 1 and contract_dilation:
                first_dilation = dilation // 2
            else:
                first_dilation = dilation
        else:
            first_dilation = multi_grid[0]
        layers.append(
            block(
                inplanes=inplanes,
                planes=planes,
                stride=stride,
                dilation=first_dilation,
                downsample=downsample,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                **kwargs))
        inplanes = planes * block.expansion
        for i in range(1, num_blocks):
            layers.append(
                block(
                    inplanes=inplanes,
                    planes=planes,
                    stride=1,
                    dilation=dilation if multi_grid is None else multi_grid[i],
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    **kwargs))
        super().__init__(*layers)
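Note that ResLayer forwards `inplanes`/`planes` keyword arguments to the block class, so it is intended for blocks with the ResNet-backbone signature (e.g. `mmseg.models.backbones.resnet.BasicBlock`), not the `BasicBlock` from `basic_block.py` above, which takes `in_channels`/`channels`. A minimal sketch, assuming the full mmseg package (including `mmseg.models.backbones.resnet`) is present in this tree:

    import torch
    from mmseg.models.backbones.resnet import BasicBlock
    from mmseg.models.utils import ResLayer

    # A stage of two blocks: 64 -> 64 channels, downsampling by 2 in the first block.
    layer = ResLayer(BasicBlock, inplanes=64, planes=64, num_blocks=2, stride=2)
    x = torch.rand(2, 64, 56, 56)
    assert layer(x).shape == (2, 64, 28, 28)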
418
finetune/mmseg/models/utils/san_layers.py
Normal file
418
finetune/mmseg/models/utils/san_layers.py
Normal file
@@ -0,0 +1,418 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# Modified from https://github.com/MendelXu/SAN/blob/main/san/model/attn_helper.py # noqa: E501
|
||||
# Copyright (c) 2023 MendelXu.
|
||||
# Licensed under the MIT License
|
||||
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from mmcv.cnn.bricks.transformer import BaseTransformerLayer
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def cross_attn_with_self_bias(
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
embed_dim_to_check: int,
|
||||
num_heads: int,
|
||||
in_proj_weight: Tensor,
|
||||
in_proj_bias: Tensor,
|
||||
bias_k: Optional[Tensor],
|
||||
bias_v: Optional[Tensor],
|
||||
add_zero_attn: bool,
|
||||
dropout_p: float,
|
||||
out_proj_weight: Tensor,
|
||||
out_proj_bias: Tensor,
|
||||
training: bool = True,
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
need_weights: bool = True,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
use_separate_proj_weight: bool = False,
|
||||
q_proj_weight: Optional[Tensor] = None,
|
||||
k_proj_weight: Optional[Tensor] = None,
|
||||
v_proj_weight: Optional[Tensor] = None,
|
||||
static_k: Optional[Tensor] = None,
|
||||
static_v: Optional[Tensor] = None,
|
||||
):
|
||||
"""Forward function of multi-head attention. Modified from
|
||||
multi_head_attention_forward in
|
||||
https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py.
|
||||
|
||||
Args:
|
||||
query, key, value: map a query and a set of key-value pairs to an output.
|
||||
See "Attention Is All You Need" for more details.
|
||||
embed_dim_to_check: total dimension of the model.
|
||||
num_heads: parallel attention heads.
|
||||
in_proj_weight, in_proj_bias: input projection weight and bias.
|
||||
bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
|
||||
add_zero_attn: add a new batch of zeros to the key and
|
||||
value sequences at dim=1.
|
||||
dropout_p: probability of an element to be zeroed.
|
||||
out_proj_weight, out_proj_bias: the output projection weight and bias.
|
||||
training: apply dropout if is ``True``.
|
||||
key_padding_mask: if provided, specified padding elements in the key will
|
||||
be ignored by the attention. This is an binary mask. When the value is True,
|
||||
the corresponding value on the attention layer will be filled with -inf.
|
||||
need_weights: output attn_output_weights.
|
||||
Default: `True`
|
||||
Note: `needs_weight` defaults to `True`, but should be set to `False`
|
||||
For best performance when attention weights are not needed.
|
||||
*Setting needs_weights to `True`
|
||||
leads to a significant performance degradation.*
|
||||
attn_mask: 2D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
||||
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
||||
use_separate_proj_weight: the function accept the proj. weights for query, key,
|
||||
and value in different forms. If false, in_proj_weight will be used, which is
|
||||
a combination of q_proj_weight, k_proj_weight, v_proj_weight.
|
||||
q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
|
||||
static_k, static_v: static key and value used for attention operators.
|
||||
""" # noqa: E501
|
||||
tgt_len, bsz, embed_dim = query.size()
|
||||
assert embed_dim == embed_dim_to_check
|
||||
# allow MHA to have different sizes for the feature dimension
|
||||
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
|
||||
|
||||
head_dim = embed_dim // num_heads
|
||||
assert head_dim * num_heads == embed_dim, \
|
||||
'embed_dim must be divisible by num_heads'
|
||||
scaling = float(head_dim)**-0.5
|
||||
|
||||
if not use_separate_proj_weight:
|
||||
if (query is key or torch.equal(
|
||||
query, key)) and (key is value or torch.equal(key, value)):
|
||||
# self-attention
|
||||
raise NotImplementedError('self-attention is not implemented')
|
||||
|
||||
elif key is value or torch.equal(key, value):
|
||||
# encoder-decoder attention
|
||||
# This is inline in_proj function
|
||||
# with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = 0
|
||||
_end = embed_dim
|
||||
_w = in_proj_weight[_start:_end, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:_end]
|
||||
q = F.linear(query, _w, _b)
|
||||
|
||||
if key is None:
|
||||
assert value is None
|
||||
k = None
|
||||
v = None
|
||||
q_k = None
|
||||
q_v = None
|
||||
else:
|
||||
# This is inline in_proj function with
|
||||
# in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = embed_dim
|
||||
_end = None
|
||||
_w = in_proj_weight[_start:, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:]
|
||||
k, v = F.linear(key, _w, _b).chunk(2, dim=-1)
|
||||
q_k, q_v = F.linear(query, _w, _b).chunk(2, dim=-1)
|
||||
else:
|
||||
# This is inline in_proj function with
|
||||
# in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = 0
|
||||
_end = embed_dim
|
||||
_w = in_proj_weight[_start:_end, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:_end]
|
||||
q = F.linear(query, _w, _b)
|
||||
|
||||
# This is inline in_proj function with
|
||||
# in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = embed_dim
|
||||
_end = embed_dim * 2
|
||||
_w = in_proj_weight[_start:_end, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:_end]
|
||||
k = F.linear(key, _w, _b)
|
||||
q_k = F.linear(query, _w, _b)
|
||||
# This is inline in_proj function with
|
||||
# in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = embed_dim * 2
|
||||
_end = None
|
||||
_w = in_proj_weight[_start:, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:]
|
||||
v = F.linear(value, _w, _b)
|
||||
q_v = F.linear(query, _w, _b)
|
||||
else:
|
||||
q_proj_weight_non_opt = \
|
||||
torch.jit._unwrap_optional(q_proj_weight)
|
||||
len1, len2 = q_proj_weight_non_opt.size()
|
||||
assert len1 == embed_dim and len2 == query.size(-1)
|
||||
|
||||
k_proj_weight_non_opt = \
|
||||
torch.jit._unwrap_optional(k_proj_weight)
|
||||
len1, len2 = k_proj_weight_non_opt.size()
|
||||
assert len1 == embed_dim and len2 == key.size(-1)
|
||||
|
||||
v_proj_weight_non_opt = \
|
||||
torch.jit._unwrap_optional(v_proj_weight)
|
||||
len1, len2 = v_proj_weight_non_opt.size()
|
||||
assert len1 == embed_dim and len2 == value.size(-1)
|
||||
|
||||
if in_proj_bias is not None:
|
||||
q = F.linear(query, q_proj_weight_non_opt,
|
||||
in_proj_bias[0:embed_dim])
|
||||
k = F.linear(key, k_proj_weight_non_opt,
|
||||
in_proj_bias[embed_dim:(embed_dim * 2)])
|
||||
v = F.linear(value, v_proj_weight_non_opt,
|
||||
in_proj_bias[(embed_dim * 2):])
|
||||
else:
|
||||
q = F.linear(query, q_proj_weight_non_opt, in_proj_bias)
|
||||
k = F.linear(key, k_proj_weight_non_opt, in_proj_bias)
|
||||
v = F.linear(value, v_proj_weight_non_opt, in_proj_bias)
|
||||
q = q * scaling
|
||||
|
||||
if attn_mask is not None:
|
||||
assert (
|
||||
attn_mask.dtype == torch.float32
|
||||
or attn_mask.dtype == torch.float64
|
||||
or attn_mask.dtype == torch.float16
|
||||
or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool
|
||||
), 'Only float, byte, and bool types are supported for ' \
|
||||
'attn_mask, not {}'.format(attn_mask.dtype)
|
||||
if attn_mask.dtype == torch.uint8:
|
||||
warnings.warn('Byte tensor for attn_mask in nn.MultiheadAttention '
|
||||
'is deprecated. Use bool tensor instead.')
|
||||
attn_mask = attn_mask.to(torch.bool)
|
||||
|
||||
if attn_mask.dim() == 2:
|
||||
attn_mask = attn_mask.unsqueeze(0)
|
||||
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
|
||||
raise RuntimeError(
|
||||
'The size of the 2D attn_mask is not correct.')
|
||||
elif attn_mask.dim() == 3:
|
||||
if list(attn_mask.size()) != [
|
||||
bsz * num_heads,
|
||||
query.size(0), key.size(0)
|
||||
]:
|
||||
raise RuntimeError(
|
||||
'The size of the 3D attn_mask is not correct.')
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"attn_mask's dimension {} is not supported".format(
|
||||
attn_mask.dim()))
|
||||
# attn_mask's dim is 3 now.
|
||||
|
||||
# convert ByteTensor key_padding_mask to bool
|
||||
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
|
||||
warnings.warn(
|
||||
'Byte tensor for key_padding_mask in nn.MultiheadAttention '
|
||||
'is deprecated. Use bool tensor instead.')
|
||||
key_padding_mask = key_padding_mask.to(torch.bool)
|
||||
|
||||
if bias_k is not None and bias_v is not None:
|
||||
if static_k is None and static_v is None:
|
||||
k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
|
||||
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
|
||||
if attn_mask is not None:
|
||||
attn_mask = F.pad(attn_mask, (0, 1))
|
||||
if key_padding_mask is not None:
|
||||
key_padding_mask = F.pad(key_padding_mask, (0, 1))
|
||||
else:
|
||||
assert static_k is None, 'bias cannot be added to static key.'
|
||||
assert static_v is None, 'bias cannot be added to static value.'
|
||||
else:
|
||||
assert bias_k is None
|
||||
assert bias_v is None
|
||||
|
||||
q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
|
||||
if k is not None:
|
||||
k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
|
||||
q_k = q_k.contiguous().view(tgt_len, bsz * num_heads,
|
||||
head_dim).transpose(0, 1)
|
||||
if v is not None:
|
||||
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
|
||||
q_v = q_v.contiguous().view(tgt_len, bsz * num_heads,
|
||||
head_dim).transpose(0, 1)
|
||||
|
||||
if static_k is not None:
|
||||
assert static_k.size(0) == bsz * num_heads
|
||||
assert static_k.size(2) == head_dim
|
||||
k = static_k
|
||||
|
||||
if static_v is not None:
|
||||
assert static_v.size(0) == bsz * num_heads
|
||||
assert static_v.size(2) == head_dim
|
||||
v = static_v
|
||||
|
||||
src_len = k.size(1)
|
||||
|
||||
if key_padding_mask is not None:
|
||||
assert key_padding_mask.size(0) == bsz
|
||||
assert key_padding_mask.size(1) == src_len
|
||||
|
||||
if add_zero_attn:
|
||||
src_len += 1
|
||||
k = torch.cat(
|
||||
[
|
||||
k,
|
||||
torch.zeros(
|
||||
(k.size(0), 1) + k.size()[2:],
|
||||
dtype=k.dtype,
|
||||
device=k.device),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
v = torch.cat(
|
||||
[
|
||||
v,
|
||||
torch.zeros(
|
||||
(v.size(0), 1) + v.size()[2:],
|
||||
dtype=v.dtype,
|
||||
device=v.device),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
if attn_mask is not None:
|
||||
attn_mask = F.pad(attn_mask, (0, 1))
|
||||
if key_padding_mask is not None:
|
||||
key_padding_mask = F.pad(key_padding_mask, (0, 1))
|
||||
|
||||
attn_output_weights = torch.bmm(q, k.transpose(1, 2))
|
||||
assert list(
|
||||
attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
|
||||
|
||||
if attn_mask is not None:
|
||||
if attn_mask.dtype == torch.bool:
|
||||
attn_output_weights.masked_fill_(attn_mask, float('-inf'))
|
||||
else:
|
||||
attn_output_weights += attn_mask
|
||||
|
||||
if key_padding_mask is not None:
|
||||
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len,
|
||||
src_len)
|
||||
attn_output_weights = attn_output_weights.masked_fill(
|
||||
key_padding_mask.unsqueeze(1).unsqueeze(2),
|
||||
float('-inf'),
|
||||
)
|
||||
attn_output_weights = attn_output_weights.view(bsz * num_heads,
|
||||
tgt_len, src_len)
|
||||
# attn_out_weights: [bsz * num_heads, tgt_len, src_len]
|
||||
# ->[bsz * num_heads, tgt_len, src_len+1]
|
||||
self_weight = (q * q_k).sum(
|
||||
dim=-1, keepdim=True) # [bsz * num_heads, tgt_len, 1]
|
||||
total_attn_output_weights = torch.cat([attn_output_weights, self_weight],
|
||||
dim=-1)
|
||||
total_attn_output_weights = F.softmax(total_attn_output_weights, dim=-1)
|
||||
total_attn_output_weights = F.dropout(
|
||||
total_attn_output_weights, p=dropout_p, training=training)
|
||||
attn_output_weights = \
|
||||
total_attn_output_weights[:, :, : -1]
|
||||
# [bsz * num_heads, tgt_len, src_len]
|
||||
self_weight = \
|
||||
total_attn_output_weights[:, :, -1:] # [bsz * num_heads, tgt_len, 1]
|
||||
|
||||
attn_output = torch.bmm(attn_output_weights,
|
||||
v) # [bsz * num_heads, tgt_len, head_dim]
|
||||
attn_output = (attn_output + self_weight * q_v
|
||||
) # [bsz * num_heads, tgt_len, head_dim]
|
||||
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
|
||||
attn_output = attn_output.transpose(0, 1).contiguous().view(
|
||||
tgt_len, bsz, embed_dim)
|
||||
attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
|
||||
|
||||
if need_weights:
|
||||
# average attention weights over heads
|
||||
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len,
|
||||
src_len)
|
||||
return attn_output, attn_output_weights # .sum(dim=1) / num_heads
|
||||
else:
|
||||
return attn_output, None
|
||||
|
||||
|
||||
def cross_attn_layer(tf_layer: BaseTransformerLayer, x, mem, attn_bias):
    """Implementation of a transformer layer with cross attention.

    The cross attention shares the embedding weights with the self-attention
    of ``tf_layer``.

    Args:
        tf_layer (BaseTransformerLayer): The module of the transformer layer.
        x (Tensor): query [K, N, C]
        mem (Tensor): key and value [L, N, C]
        attn_bias (Tensor): attention bias [N*num_head, K, L]

    Returns:
        x (Tensor): cross attention output [K, N, C]
    """
    self_attn_layer = tf_layer.attentions[0].attn
    attn_layer_paras = {
        'embed_dim_to_check': self_attn_layer.embed_dim,
        'num_heads': self_attn_layer.num_heads,
        'in_proj_weight': self_attn_layer.in_proj_weight,
        'in_proj_bias': self_attn_layer.in_proj_bias,
        'bias_k': self_attn_layer.bias_k,
        'bias_v': self_attn_layer.bias_v,
        'add_zero_attn': self_attn_layer.add_zero_attn,
        'dropout_p': self_attn_layer.dropout,
        'out_proj_weight': self_attn_layer.out_proj.weight,
        'out_proj_bias': self_attn_layer.out_proj.bias,
        'training': self_attn_layer.training
    }

    q_x = tf_layer.norms[0](x)
    k_x = v_x = tf_layer.norms[0](mem)
    x = x + cross_attn_with_self_bias(
        q_x,
        k_x,
        v_x,
        attn_mask=attn_bias,
        need_weights=False,
        **attn_layer_paras)[0]
    x = tf_layer.ffns[0](tf_layer.norms[1](x), identity=x)
    return x
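

# Illustrative usage sketch (not part of the upstream API): `_demo_cross_attn_layer`
# is a hypothetical helper that only exercises the tensor shapes documented
# above. `tf_layer` is assumed to be an already-built BaseTransformerLayer whose
# operation order exposes attentions[0], norms[0], norms[1] and ffns[0], with an
# embedding dimension matching `C`; all default sizes below are illustrative.
def _demo_cross_attn_layer(tf_layer: BaseTransformerLayer,
                           K: int = 100,
                           L: int = 196,
                           N: int = 2,
                           C: int = 256,
                           num_heads: int = 8):
    x = torch.randn(K, N, C)  # queries, [K, N, C]
    mem = torch.randn(L, N, C)  # keys and values, [L, N, C]
    attn_bias = torch.zeros(N * num_heads, K, L)  # additive attention bias
    out = cross_attn_layer(tf_layer, x, mem, attn_bias)
    assert out.shape == (K, N, C)
    return out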


class LayerNorm2d(nn.Module):
    """A LayerNorm variant, popularized by Transformers, that performs
    point-wise mean and variance normalization over the channel dimension for
    inputs that have shape (batch_size, channels, height, width).

    https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119  # noqa B950
    """

    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.normalized_shape = (normalized_shape, )

    def forward(self, x: torch.Tensor):
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x
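

# Illustrative usage sketch: `_demo_layer_norm_2d` is a hypothetical helper
# showing that LayerNorm2d normalizes over the channel dimension of an NCHW
# tensor; the sizes are arbitrary.
def _demo_layer_norm_2d():
    norm = LayerNorm2d(64)
    feat = torch.randn(2, 64, 32, 32)  # (batch, channels, height, width)
    out = norm(feat)
    assert out.shape == feat.shape
    return out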


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(self,
                 input_dim,
                 hidden_dim,
                 output_dim,
                 num_layers,
                 affine_func=nn.Linear):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            affine_func(n, k)
            for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x: torch.Tensor):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
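

# Illustrative usage sketch: `_demo_mlp` is a hypothetical helper building a
# 3-layer MLP with ReLU between hidden layers and a linear final layer; the
# dimensions are arbitrary.
def _demo_mlp():
    mlp = MLP(input_dim=256, hidden_dim=256, output_dim=80, num_layers=3)
    x = torch.randn(4, 100, 256)
    out = mlp(x)
    assert out.shape == (4, 100, 80)
    return out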
58
finetune/mmseg/models/utils/se_layer.py
Normal file
@@ -0,0 +1,58 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.utils import is_tuple_of

from .make_divisible import make_divisible


class SELayer(nn.Module):
    """Squeeze-and-Excitation Module.

    Args:
        channels (int): The input (and output) channels of the SE layer.
        ratio (int): Squeeze ratio in SELayer, the intermediate channel will be
            ``int(channels/ratio)``. Default: 16.
        conv_cfg (None or dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        act_cfg (dict or Sequence[dict]): Config dict for activation layer.
            If act_cfg is a dict, two activation layers will be configured
            by this dict. If act_cfg is a sequence of dicts, the first
            activation layer will be configured by the first dict and the
            second activation layer will be configured by the second dict.
            Default: (dict(type='ReLU'), dict(type='HSigmoid', bias=3.0,
            divisor=6.0)).
    """

    def __init__(self,
                 channels,
                 ratio=16,
                 conv_cfg=None,
                 act_cfg=(dict(type='ReLU'),
                          dict(type='HSigmoid', bias=3.0, divisor=6.0))):
        super().__init__()
        if isinstance(act_cfg, dict):
            act_cfg = (act_cfg, act_cfg)
        assert len(act_cfg) == 2
        assert is_tuple_of(act_cfg, dict)
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = ConvModule(
            in_channels=channels,
            out_channels=make_divisible(channels // ratio, 8),
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            act_cfg=act_cfg[0])
        self.conv2 = ConvModule(
            in_channels=make_divisible(channels // ratio, 8),
            out_channels=channels,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            act_cfg=act_cfg[1])

    def forward(self, x):
        out = self.global_avgpool(x)
        out = self.conv1(out)
        out = self.conv2(out)
        return x * out
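

# Illustrative usage sketch: `_demo_se_layer` is a hypothetical helper showing
# that SELayer squeezes each channel to a scalar via global average pooling and
# re-weights the input channel-wise; sizes are arbitrary and `torch` is
# imported locally since only `torch.nn` is imported at module level here.
def _demo_se_layer():
    import torch
    se = SELayer(channels=64, ratio=16)
    feat = torch.randn(2, 64, 56, 56)
    out = se(feat)  # same shape as the input; each channel scaled by [0, 1]
    assert out.shape == feat.shape
    return out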
161
finetune/mmseg/models/utils/self_attention_block.py
Normal file
@@ -0,0 +1,161 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn import ConvModule
from mmengine.model.weight_init import constant_init
from torch import nn as nn
from torch.nn import functional as F


class SelfAttentionBlock(nn.Module):
    """General self-attention block/non-local block.

    Please refer to https://arxiv.org/abs/1706.03762 for details about key,
    query and value.

    Args:
        key_in_channels (int): Input channels of key feature.
        query_in_channels (int): Input channels of query feature.
        channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
        share_key_query (bool): Whether to share projection weight between key
            and query projection.
        query_downsample (nn.Module): Query downsample module.
        key_downsample (nn.Module): Key downsample module.
        key_query_num_convs (int): Number of convs for key/query projection.
        value_out_num_convs (int): Number of convs for value/out projection.
        key_query_norm (bool): Whether to use norm in key/query projection.
        value_out_norm (bool): Whether to use norm in value/out projection.
        matmul_norm (bool): Whether to normalize the attention map with the
            square root of channels.
        with_out (bool): Whether to use the out projection.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict|None): Config of activation layers.
    """

    def __init__(self, key_in_channels, query_in_channels, channels,
                 out_channels, share_key_query, query_downsample,
                 key_downsample, key_query_num_convs, value_out_num_convs,
                 key_query_norm, value_out_norm, matmul_norm, with_out,
                 conv_cfg, norm_cfg, act_cfg):
        super().__init__()
        if share_key_query:
            assert key_in_channels == query_in_channels
        self.key_in_channels = key_in_channels
        self.query_in_channels = query_in_channels
        self.out_channels = out_channels
        self.channels = channels
        self.share_key_query = share_key_query
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.key_project = self.build_project(
            key_in_channels,
            channels,
            num_convs=key_query_num_convs,
            use_conv_module=key_query_norm,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        if share_key_query:
            self.query_project = self.key_project
        else:
            self.query_project = self.build_project(
                query_in_channels,
                channels,
                num_convs=key_query_num_convs,
                use_conv_module=key_query_norm,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        self.value_project = self.build_project(
            key_in_channels,
            channels if with_out else out_channels,
            num_convs=value_out_num_convs,
            use_conv_module=value_out_norm,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        if with_out:
            self.out_project = self.build_project(
                channels,
                out_channels,
                num_convs=value_out_num_convs,
                use_conv_module=value_out_norm,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        else:
            self.out_project = None

        self.query_downsample = query_downsample
        self.key_downsample = key_downsample
        self.matmul_norm = matmul_norm

        self.init_weights()

    def init_weights(self):
        """Initialize weights of the out projection layer."""
        if self.out_project is not None:
            if not isinstance(self.out_project, ConvModule):
                constant_init(self.out_project, 0)

    def build_project(self, in_channels, channels, num_convs, use_conv_module,
                      conv_cfg, norm_cfg, act_cfg):
        """Build projection layer for key/query/value/out."""
        if use_conv_module:
            convs = [
                ConvModule(
                    in_channels,
                    channels,
                    1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg)
            ]
            for _ in range(num_convs - 1):
                convs.append(
                    ConvModule(
                        channels,
                        channels,
                        1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg))
        else:
            convs = [nn.Conv2d(in_channels, channels, 1)]
            for _ in range(num_convs - 1):
                convs.append(nn.Conv2d(channels, channels, 1))
        if len(convs) > 1:
            convs = nn.Sequential(*convs)
        else:
            convs = convs[0]
        return convs

    def forward(self, query_feats, key_feats):
        """Forward function."""
        batch_size = query_feats.size(0)
        query = self.query_project(query_feats)
        if self.query_downsample is not None:
            query = self.query_downsample(query)
        query = query.reshape(*query.shape[:2], -1)
        query = query.permute(0, 2, 1).contiguous()

        key = self.key_project(key_feats)
        value = self.value_project(key_feats)
        if self.key_downsample is not None:
            key = self.key_downsample(key)
            value = self.key_downsample(value)
        key = key.reshape(*key.shape[:2], -1)
        value = value.reshape(*value.shape[:2], -1)
        value = value.permute(0, 2, 1).contiguous()

        sim_map = torch.matmul(query, key)
        if self.matmul_norm:
            sim_map = (self.channels**-.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)

        context = torch.matmul(sim_map, value)
        context = context.permute(0, 2, 1).contiguous()
        context = context.reshape(batch_size, -1, *query_feats.shape[2:])
        if self.out_project is not None:
            context = self.out_project(context)
        return context
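

# Illustrative usage sketch: `_demo_self_attention_block` is a hypothetical
# helper showing a non-shared key/query configuration with norms and an out
# projection; all channel and spatial sizes are arbitrary. The output keeps the
# spatial size of the query features and has `out_channels` channels.
def _demo_self_attention_block():
    block = SelfAttentionBlock(
        key_in_channels=64,
        query_in_channels=64,
        channels=32,
        out_channels=64,
        share_key_query=False,
        query_downsample=None,
        key_downsample=None,
        key_query_num_convs=2,
        value_out_num_convs=1,
        key_query_norm=True,
        value_out_norm=True,
        matmul_norm=True,
        with_out=True,
        conv_cfg=None,
        norm_cfg=dict(type='BN'),
        act_cfg=dict(type='ReLU'))
    query_feats = torch.randn(2, 64, 32, 32)
    key_feats = torch.randn(2, 64, 16, 16)
    context = block(query_feats, key_feats)
    assert context.shape == query_feats.shape  # [2, 64, 32, 32]
    return context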
107
finetune/mmseg/models/utils/shape_convert.py
Normal file
@@ -0,0 +1,107 @@
# Copyright (c) OpenMMLab. All rights reserved.
def nlc_to_nchw(x, hw_shape):
    """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor.

    Args:
        x (Tensor): The input tensor of shape [N, L, C] before conversion.
        hw_shape (Sequence[int]): The height and width of output feature map.

    Returns:
        Tensor: The output tensor of shape [N, C, H, W] after conversion.
    """
    H, W = hw_shape
    assert len(x.shape) == 3
    B, L, C = x.shape
    assert L == H * W, 'The seq_len doesn\'t match H, W'
    return x.transpose(1, 2).reshape(B, C, H, W)


def nchw_to_nlc(x):
    """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor.

    Args:
        x (Tensor): The input tensor of shape [N, C, H, W] before conversion.

    Returns:
        Tensor: The output tensor of shape [N, L, C] after conversion.
    """
    assert len(x.shape) == 4
    return x.flatten(2).transpose(1, 2).contiguous()
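

# Illustrative round-trip sketch: `_demo_shape_round_trip` is a hypothetical
# helper showing that nchw_to_nlc followed by nlc_to_nchw with the same (H, W)
# recovers the original tensor exactly; sizes are arbitrary and `torch` is
# imported locally because this module has no top-level imports.
def _demo_shape_round_trip():
    import torch
    x = torch.randn(2, 16, 7, 5)  # [N, C, H, W]
    seq = nchw_to_nlc(x)  # [N, H*W, C] == [2, 35, 16]
    restored = nlc_to_nchw(seq, (7, 5))
    assert torch.equal(restored, x)
    return seq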


def nchw2nlc2nchw(module, x, contiguous=False, **kwargs):
    """Flatten [N, C, H, W] shape tensor `x` to [N, L, C] shape tensor. Use the
    reshaped tensor as the input of `module`, and then convert the output of
    `module`, whose shape is [N, L, C], back to [N, C, H, W].

    Args:
        module (Callable): A callable object that takes a tensor
            with shape [N, L, C] as input.
        x (Tensor): The input tensor of shape [N, C, H, W].
        contiguous (bool): Whether to make the tensor contiguous
            after each shape transform.

    Returns:
        Tensor: The output tensor of shape [N, C, H, W].

    Example:
        >>> import torch
        >>> import torch.nn as nn
        >>> norm = nn.LayerNorm(4)
        >>> feature_map = torch.rand(4, 4, 5, 5)
        >>> output = nchw2nlc2nchw(norm, feature_map)
    """
    B, C, H, W = x.shape
    if not contiguous:
        x = x.flatten(2).transpose(1, 2)
        x = module(x, **kwargs)
        x = x.transpose(1, 2).reshape(B, C, H, W)
    else:
        x = x.flatten(2).transpose(1, 2).contiguous()
        x = module(x, **kwargs)
        x = x.transpose(1, 2).reshape(B, C, H, W).contiguous()
    return x


def nlc2nchw2nlc(module, x, hw_shape, contiguous=False, **kwargs):
    """Convert [N, L, C] shape tensor `x` to [N, C, H, W] shape tensor. Use the
    reshaped tensor as the input of `module`, and then convert the output of
    `module`, whose shape is [N, C, H, W], back to [N, L, C].

    Args:
        module (Callable): A callable object that takes a tensor
            with shape [N, C, H, W] as input.
        x (Tensor): The input tensor of shape [N, L, C].
        hw_shape (Sequence[int]): The height and width of the
            feature map with shape [N, C, H, W].
        contiguous (bool): Whether to make the tensor contiguous
            after each shape transform.

    Returns:
        Tensor: The output tensor of shape [N, L, C].

    Example:
        >>> import torch
        >>> import torch.nn as nn
        >>> conv = nn.Conv2d(16, 16, 3, 1, 1)
        >>> feature_map = torch.rand(4, 25, 16)
        >>> output = nlc2nchw2nlc(conv, feature_map, (5, 5))
    """
    H, W = hw_shape
    assert len(x.shape) == 3
    B, L, C = x.shape
    assert L == H * W, 'The seq_len doesn\'t match H, W'
    if not contiguous:
        x = x.transpose(1, 2).reshape(B, C, H, W)
        x = module(x, **kwargs)
        x = x.flatten(2).transpose(1, 2)
    else:
        x = x.transpose(1, 2).reshape(B, C, H, W).contiguous()
        x = module(x, **kwargs)
        x = x.flatten(2).transpose(1, 2).contiguous()
    return x
102
finetune/mmseg/models/utils/up_conv_block.py
Normal file
@@ -0,0 +1,102 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, build_upsample_layer


class UpConvBlock(nn.Module):
    """Upsample convolution block in decoder for UNet.

    This upsample convolution block consists of one upsample module
    followed by one convolution block. The upsample module expands the
    high-level low-resolution feature map and the convolution block fuses
    the upsampled high-level low-resolution feature map and the low-level
    high-resolution feature map from encoder.

    Args:
        conv_block (nn.Sequential): Sequential of convolutional layers.
        in_channels (int): Number of input channels of the high-level
            low-resolution feature map.
        skip_channels (int): Number of input channels of the low-level
            high-resolution feature map from encoder.
        out_channels (int): Number of output channels.
        num_convs (int): Number of convolutional layers in the conv_block.
            Default: 2.
        stride (int): Stride of convolutional layer in conv_block. Default: 1.
        dilation (int): Dilation rate of convolutional layer in conv_block.
            Default: 1.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        upsample_cfg (dict): The upsample config of the upsample module in
            decoder. Default: dict(type='InterpConv'). If the size of the
            high-level feature map is the same as that of the skip feature map
            (low-level feature map from encoder), there is no need to upsample
            the high-level feature map and upsample_cfg can be None.
        dcn (bool): Use deformable convolution in convolutional layer or not.
            Default: None.
        plugins (dict): plugins for convolutional layers. Default: None.
    """

    def __init__(self,
                 conv_block,
                 in_channels,
                 skip_channels,
                 out_channels,
                 num_convs=2,
                 stride=1,
                 dilation=1,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 upsample_cfg=dict(type='InterpConv'),
                 dcn=None,
                 plugins=None):
        super().__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'

        self.conv_block = conv_block(
            in_channels=2 * skip_channels,
            out_channels=out_channels,
            num_convs=num_convs,
            stride=stride,
            dilation=dilation,
            with_cp=with_cp,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            dcn=None,
            plugins=None)
        if upsample_cfg is not None:
            self.upsample = build_upsample_layer(
                cfg=upsample_cfg,
                in_channels=in_channels,
                out_channels=skip_channels,
                with_cp=with_cp,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        else:
            self.upsample = ConvModule(
                in_channels,
                skip_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

    def forward(self, skip, x):
        """Forward function."""

        x = self.upsample(x)
        out = torch.cat([skip, x], dim=1)
        out = self.conv_block(out)

        return out
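

# Illustrative usage sketch: `_demo_up_conv_block` is a hypothetical helper
# pairing UpConvBlock with a UNet-style conv block. The `BasicConvBlock` import
# path is an assumption based on this repo's mmseg layout, and all channel and
# spatial sizes are arbitrary.
def _demo_up_conv_block():
    from mmseg.models.backbones.unet import BasicConvBlock
    block = UpConvBlock(
        conv_block=BasicConvBlock,
        in_channels=128,  # high-level, low-resolution feature map
        skip_channels=64,  # low-level, high-resolution skip feature map
        out_channels=64)
    skip = torch.randn(2, 64, 32, 32)
    x = torch.randn(2, 128, 16, 16)  # doubled to 32x32 by the upsample module
    out = block(skip, x)
    assert out.shape == (2, 64, 32, 32)
    return out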
51
finetune/mmseg/models/utils/wrappers.py
Normal file
@@ -0,0 +1,51 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch.nn as nn
import torch.nn.functional as F


def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    if warning:
        if size is not None and align_corners:
            input_h, input_w = tuple(int(x) for x in input.shape[2:])
            output_h, output_w = tuple(int(x) for x in size)
            if output_h > input_h or output_w > input_w:
                if ((output_h > 1 and output_w > 1 and input_h > 1
                     and input_w > 1) and (output_h - 1) % (input_h - 1)
                        and (output_w - 1) % (input_w - 1)):
                    warnings.warn(
                        f'When align_corners={align_corners}, '
                        'the output would be more aligned if '
                        f'input size {(input_h, input_w)} is `x+1` and '
                        f'out size {(output_h, output_w)} is `nx+1`')
    return F.interpolate(input, size, scale_factor, mode, align_corners)


class Upsample(nn.Module):

    def __init__(self,
                 size=None,
                 scale_factor=None,
                 mode='nearest',
                 align_corners=None):
        super().__init__()
        self.size = size
        if isinstance(scale_factor, tuple):
            self.scale_factor = tuple(float(factor) for factor in scale_factor)
        else:
            self.scale_factor = float(scale_factor) if scale_factor else None
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        if not self.size:
            size = [int(t * self.scale_factor) for t in x.shape[-2:]]
        else:
            size = self.size
        return resize(x, size, None, self.mode, self.align_corners)
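

# Illustrative usage sketch: `_demo_upsample` is a hypothetical helper showing
# that this wrapper computes the target size from the input at forward time
# (when `size` is not given) and then calls `resize`; sizes are arbitrary and
# `torch` is imported locally because only `torch.nn` is imported at module
# level here.
def _demo_upsample():
    import torch
    up = Upsample(scale_factor=2, mode='bilinear', align_corners=False)
    feat = torch.randn(1, 8, 17, 23)
    out = up(feat)
    assert out.shape == (1, 8, 34, 46)
    return out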