init
finetune/mmseg/models/__init__.py (new file, 16 lines)
@@ -0,0 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .assigners import *  # noqa: F401,F403
from .backbones import *  # noqa: F401,F403
from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone,
                      build_head, build_loss, build_segmentor)
from .data_preprocessor import SegDataPreProcessor
from .decode_heads import *  # noqa: F401,F403
from .losses import *  # noqa: F401,F403
from .necks import *  # noqa: F401,F403
from .segmentors import *  # noqa: F401,F403
from .text_encoder import *  # noqa: F401,F403

__all__ = [
    'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone',
    'build_head', 'build_loss', 'build_segmentor', 'SegDataPreProcessor'
]
finetune/mmseg/models/assigners/__init__.py (new file, 12 lines)
@@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_assigner import BaseAssigner
from .hungarian_assigner import HungarianAssigner
from .match_cost import ClassificationCost, CrossEntropyLossCost, DiceCost

__all__ = [
    'BaseAssigner',
    'HungarianAssigner',
    'ClassificationCost',
    'CrossEntropyLossCost',
    'DiceCost',
]
finetune/mmseg/models/assigners/base_assigner.py (new file, 18 lines)
@@ -0,0 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Optional

from mmengine.structures import InstanceData


class BaseAssigner(metaclass=ABCMeta):
    """Base assigner that assigns masks to ground truth class labels."""

    @abstractmethod
    def assign(self,
               pred_instances: InstanceData,
               gt_instances: InstanceData,
               gt_instances_ignore: Optional[InstanceData] = None,
               **kwargs):
        """Assign masks to either a ground truth class label or a negative
        label."""
finetune/mmseg/models/assigners/hungarian_assigner.py (new file, 86 lines)
@@ -0,0 +1,86 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Union

import torch
from mmengine import ConfigDict
from mmengine.structures import InstanceData
from scipy.optimize import linear_sum_assignment
from torch.cuda.amp import autocast

from mmseg.registry import TASK_UTILS
from .base_assigner import BaseAssigner


@TASK_UTILS.register_module()
class HungarianAssigner(BaseAssigner):
    """Computes one-to-one matching between prediction masks and ground truth.

    This class uses bipartite matching-based assignment to compute an
    assignment between the prediction masks and the ground truth. The
    assignment result is based on the weighted sum of match costs. The
    Hungarian algorithm is used to calculate the best matching with the
    minimum cost. The prediction masks that are not matched are classified
    as background.

    Args:
        match_costs (ConfigDict|List[ConfigDict]): Match cost configs.
    """

    def __init__(
        self, match_costs: Union[List[Union[dict, ConfigDict]], dict,
                                 ConfigDict]
    ) -> None:

        if isinstance(match_costs, dict):
            match_costs = [match_costs]
        elif isinstance(match_costs, list):
            assert len(match_costs) > 0, \
                'match_costs must not be an empty list.'

        self.match_costs = [
            TASK_UTILS.build(match_cost) for match_cost in match_costs
        ]

    def assign(self, pred_instances: InstanceData, gt_instances: InstanceData,
               **kwargs):
        """Computes one-to-one matching based on the weighted costs.

        This method assigns each query prediction to a ground truth or
        background. The assignment first calculates the cost for each
        category assigned to each query mask, and then uses the
        Hungarian algorithm to calculate the minimum cost as the best
        match.

        Args:
            pred_instances (InstanceData): Instances of model
                predictions. It includes "masks", with shape
                (n, h, w) or (n, l), and "cls", with shape (n, num_classes+1)
            gt_instances (InstanceData): Ground truth of instance
                annotations. It includes "labels", with shape (k, ),
                and "masks", with shape (k, h, w) or (k, l).

        Returns:
            matched_quiery_inds (Tensor): The indexes of matched queries.
            matched_label_inds (Tensor): The indexes of matched labels.
        """
        # compute weighted cost
        cost_list = []
        with autocast(enabled=False):
            for match_cost in self.match_costs:
                cost = match_cost(
                    pred_instances=pred_instances, gt_instances=gt_instances)
                cost_list.append(cost)
            cost = torch.stack(cost_list).sum(dim=0)

        device = cost.device
        # do Hungarian matching on CPU using linear_sum_assignment
        cost = cost.detach().cpu()
        if linear_sum_assignment is None:
            raise ImportError('Please run "pip install scipy" '
                              'to install scipy first.')

        matched_quiery_inds, matched_label_inds = linear_sum_assignment(cost)
        matched_quiery_inds = torch.from_numpy(matched_quiery_inds).to(device)
        matched_label_inds = torch.from_numpy(matched_label_inds).to(device)

        return matched_quiery_inds, matched_label_inds
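A minimal usage sketch of the assigner above (not part of the commit): it builds two match costs from the next file directly, sums their weighted cost matrices, and runs scipy's Hungarian solver, mirroring HungarianAssigner.assign. The field names ('scores', 'masks', 'labels'), the random tensors, and the direct non-registry construction are illustrative assumptions.

# Usage sketch (illustrative only, not part of this commit).
import torch
from mmengine.structures import InstanceData
from scipy.optimize import linear_sum_assignment

from mmseg.models.assigners import ClassificationCost, DiceCost

num_queries, num_gt, num_classes = 5, 3, 4
pred_instances = InstanceData(
    scores=torch.randn(num_queries, num_classes + 1),    # class logits
    masks=torch.rand(num_queries, 16 * 16))               # flattened mask probabilities
gt_instances = InstanceData(
    labels=torch.randint(0, num_classes, (num_gt, )),
    masks=(torch.rand(num_gt, 16 * 16) > 0.5).float())    # binary GT masks

costs = [ClassificationCost(weight=2.0), DiceCost(weight=5.0, pred_act=False)]
cost = torch.stack([
    c(pred_instances=pred_instances, gt_instances=gt_instances) for c in costs
]).sum(dim=0)                                              # (num_queries, num_gt)

# Hungarian matching: each ground truth gets exactly one query.
query_inds, label_inds = linear_sum_assignment(cost.detach().cpu())
print(query_inds, label_inds)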
finetune/mmseg/models/assigners/match_cost.py (new file, 231 lines)
@@ -0,0 +1,231 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import abstractmethod
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from mmengine.structures import InstanceData
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import TASK_UTILS
|
||||
|
||||
|
||||
class BaseMatchCost:
|
||||
"""Base match cost class.
|
||||
|
||||
Args:
|
||||
weight (Union[float, int]): Cost weight. Defaults to 1.
|
||||
"""
|
||||
|
||||
def __init__(self, weight: Union[float, int] = 1.) -> None:
|
||||
self.weight = weight
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, pred_instances: InstanceData,
|
||||
gt_instances: InstanceData, **kwargs) -> Tensor:
|
||||
"""Compute match cost.
|
||||
|
||||
Args:
|
||||
pred_instances (InstanceData): Instances of model predictions.
|
||||
It often includes "labels" and "scores".
|
||||
gt_instances (InstanceData): Ground truth of instance
|
||||
annotations. It usually includes "labels".
|
||||
|
||||
Returns:
|
||||
Tensor: Match Cost matrix of shape (num_preds, num_gts).
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class ClassificationCost(BaseMatchCost):
|
||||
"""ClsSoftmaxCost.
|
||||
|
||||
Args:
|
||||
weight (Union[float, int]): Cost weight. Defaults to 1.
|
||||
|
||||
Examples:
|
||||
>>> from mmseg.models.assigners import ClassificationCost
|
||||
>>> import torch
|
||||
>>> self = ClassificationCost()
|
||||
>>> cls_pred = torch.rand(4, 3)
|
||||
>>> gt_labels = torch.tensor([0, 1, 2])
|
||||
>>> factor = torch.tensor([10, 8, 10, 8])
|
||||
>>> self(cls_pred, gt_labels)
|
||||
tensor([[-0.3430, -0.3525, -0.3045],
|
||||
[-0.3077, -0.2931, -0.3992],
|
||||
[-0.3664, -0.3455, -0.2881],
|
||||
[-0.3343, -0.2701, -0.3956]])
|
||||
"""
|
||||
|
||||
def __init__(self, weight: Union[float, int] = 1) -> None:
|
||||
super().__init__(weight=weight)
|
||||
|
||||
def __call__(self, pred_instances: InstanceData,
|
||||
gt_instances: InstanceData, **kwargs) -> Tensor:
|
||||
"""Compute match cost.
|
||||
|
||||
Args:
|
||||
pred_instances (InstanceData): "scores" inside is
|
||||
predicted classification logits, of shape
|
||||
(num_queries, num_class).
|
||||
gt_instances (InstanceData): "labels" inside should have
|
||||
shape (num_gt, ).
|
||||
|
||||
Returns:
|
||||
Tensor: Match Cost matrix of shape (num_preds, num_gts).
|
||||
"""
|
||||
assert hasattr(pred_instances, 'scores'), \
|
||||
"pred_instances must contain 'scores'"
|
||||
assert hasattr(gt_instances, 'labels'), \
|
||||
"gt_instances must contain 'labels'"
|
||||
pred_scores = pred_instances.scores
|
||||
gt_labels = gt_instances.labels
|
||||
|
||||
pred_scores = pred_scores.softmax(-1)
|
||||
cls_cost = -pred_scores[:, gt_labels]
|
||||
|
||||
return cls_cost * self.weight
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class DiceCost(BaseMatchCost):
|
||||
"""Cost of mask assignments based on dice losses.
|
||||
|
||||
Args:
|
||||
pred_act (bool): Whether to apply sigmoid to mask_pred.
|
||||
Defaults to False.
|
||||
eps (float): Defaults to 1e-3.
|
||||
naive_dice (bool): If True, use the naive dice loss
|
||||
in which the power of the number in the denominator is
|
||||
the first power. If False, use the second power that
|
||||
is adopted by K-Net and SOLO. Defaults to True.
|
||||
weight (Union[float, int]): Cost weight. Defaults to 1.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pred_act: bool = False,
|
||||
eps: float = 1e-3,
|
||||
naive_dice: bool = True,
|
||||
weight: Union[float, int] = 1.) -> None:
|
||||
super().__init__(weight=weight)
|
||||
self.pred_act = pred_act
|
||||
self.eps = eps
|
||||
self.naive_dice = naive_dice
|
||||
|
||||
def _binary_mask_dice_loss(self, mask_preds: Tensor,
|
||||
gt_masks: Tensor) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
mask_preds (Tensor): Mask prediction in shape (num_queries, *).
|
||||
gt_masks (Tensor): Ground truth in shape (num_gt, *)
|
||||
store 0 or 1, 0 for negative class and 1 for
|
||||
positive class.
|
||||
|
||||
Returns:
|
||||
Tensor: Dice cost matrix in shape (num_queries, num_gt).
|
||||
"""
|
||||
mask_preds = mask_preds.flatten(1)
|
||||
gt_masks = gt_masks.flatten(1).float()
|
||||
numerator = 2 * torch.einsum('nc,mc->nm', mask_preds, gt_masks)
|
||||
if self.naive_dice:
|
||||
denominator = mask_preds.sum(-1)[:, None] + \
|
||||
gt_masks.sum(-1)[None, :]
|
||||
else:
|
||||
denominator = mask_preds.pow(2).sum(1)[:, None] + \
|
||||
gt_masks.pow(2).sum(1)[None, :]
|
||||
loss = 1 - (numerator + self.eps) / (denominator + self.eps)
|
||||
return loss
|
||||
|
||||
def __call__(self, pred_instances: InstanceData,
|
||||
gt_instances: InstanceData, **kwargs) -> Tensor:
|
||||
"""Compute match cost.
|
||||
|
||||
Args:
|
||||
pred_instances (InstanceData): Predicted instances which
|
||||
must contain "masks".
|
||||
gt_instances (InstanceData): Ground truth which must contain
|
||||
"mask".
|
||||
|
||||
Returns:
|
||||
Tensor: Match Cost matrix of shape (num_preds, num_gts).
|
||||
"""
|
||||
assert hasattr(pred_instances, 'masks'), \
|
||||
"pred_instances must contain 'masks'"
|
||||
assert hasattr(gt_instances, 'masks'), \
|
||||
"gt_instances must contain 'masks'"
|
||||
pred_masks = pred_instances.masks
|
||||
gt_masks = gt_instances.masks
|
||||
|
||||
if self.pred_act:
|
||||
pred_masks = pred_masks.sigmoid()
|
||||
dice_cost = self._binary_mask_dice_loss(pred_masks, gt_masks)
|
||||
return dice_cost * self.weight
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class CrossEntropyLossCost(BaseMatchCost):
|
||||
"""CrossEntropyLossCost.
|
||||
|
||||
Args:
|
||||
use_sigmoid (bool): Whether the prediction uses sigmoid
|
||||
or softmax. Defaults to True.
|
||||
weight (Union[float, int]): Cost weight. Defaults to 1.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
use_sigmoid: bool = True,
|
||||
weight: Union[float, int] = 1.) -> None:
|
||||
super().__init__(weight=weight)
|
||||
self.use_sigmoid = use_sigmoid
|
||||
|
||||
def _binary_cross_entropy(self, cls_pred: Tensor,
|
||||
gt_labels: Tensor) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
cls_pred (Tensor): The prediction with shape (num_queries, 1, *) or
|
||||
(num_queries, *).
|
||||
gt_labels (Tensor): The learning label of prediction with
|
||||
shape (num_gt, *).
|
||||
|
||||
Returns:
|
||||
Tensor: Cross entropy cost matrix in shape (num_queries, num_gt).
|
||||
"""
|
||||
cls_pred = cls_pred.flatten(1).float()
|
||||
gt_labels = gt_labels.flatten(1).float()
|
||||
n = cls_pred.shape[1]
|
||||
pos = F.binary_cross_entropy_with_logits(
|
||||
cls_pred, torch.ones_like(cls_pred), reduction='none')
|
||||
neg = F.binary_cross_entropy_with_logits(
|
||||
cls_pred, torch.zeros_like(cls_pred), reduction='none')
|
||||
cls_cost = torch.einsum('nc,mc->nm', pos, gt_labels) + \
|
||||
torch.einsum('nc,mc->nm', neg, 1 - gt_labels)
|
||||
cls_cost = cls_cost / n
|
||||
|
||||
return cls_cost
|
||||
|
||||
def __call__(self, pred_instances: InstanceData,
|
||||
gt_instances: InstanceData, **kwargs) -> Tensor:
|
||||
"""Compute match cost.
|
||||
|
||||
Args:
|
||||
pred_instances (:obj:`InstanceData`): Predicted instances which
|
||||
must contain ``masks``.
|
||||
gt_instances (:obj:`InstanceData`): Ground truth which must contain
|
||||
``masks``.
|
||||
|
||||
Returns:
|
||||
Tensor: Match Cost matrix of shape (num_preds, num_gts).
|
||||
"""
|
||||
assert hasattr(pred_instances, 'masks'), \
|
||||
"pred_instances must contain 'masks'"
|
||||
assert hasattr(gt_instances, 'masks'), \
|
||||
"gt_instances must contain 'masks'"
|
||||
pred_masks = pred_instances.masks
|
||||
gt_masks = gt_instances.masks
|
||||
if self.use_sigmoid:
|
||||
cls_cost = self._binary_cross_entropy(pred_masks, gt_masks)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return cls_cost * self.weight
|
||||
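A short check (illustrative, not part of the commit) of what the einsum in CrossEntropyLossCost._binary_cross_entropy computes: for binary ground-truth masks, each entry of the cost matrix is the mean per-pixel binary cross-entropy of one prediction against one ground-truth mask, so a naive nested loop reproduces the same matrix. The tensor shapes below are assumptions for the demonstration.

# Verification sketch (illustrative only, not part of this commit).
import torch
import torch.nn.functional as F
from mmengine.structures import InstanceData

from mmseg.models.assigners import CrossEntropyLossCost

pred = InstanceData(masks=torch.randn(4, 32))                 # mask logits
gt = InstanceData(masks=(torch.rand(2, 32) > 0.5).float())    # binary masks

cost = CrossEntropyLossCost(use_sigmoid=True)(
    pred_instances=pred, gt_instances=gt)                     # (4, 2)

# Reference: mean BCE of every (prediction, ground truth) pair via a loop.
ref = torch.stack([
    torch.stack([
        F.binary_cross_entropy_with_logits(p, g) for g in gt.masks
    ]) for p in pred.masks
])
assert torch.allclose(cost, ref, atol=1e-5)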
finetune/mmseg/models/backbones/__init__.py (new file, 35 lines)
@@ -0,0 +1,35 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .beit import BEiT
|
||||
from .bisenetv1 import BiSeNetV1
|
||||
from .bisenetv2 import BiSeNetV2
|
||||
from .cgnet import CGNet
|
||||
from .ddrnet import DDRNet
|
||||
from .erfnet import ERFNet
|
||||
from .fast_scnn import FastSCNN
|
||||
from .hrnet import HRNet
|
||||
from .icnet import ICNet
|
||||
from .mae import MAE
|
||||
from .mit import MixVisionTransformer
|
||||
from .mobilenet_v2 import MobileNetV2
|
||||
from .mobilenet_v3 import MobileNetV3
|
||||
from .mscan import MSCAN
|
||||
from .pidnet import PIDNet
|
||||
from .resnest import ResNeSt
|
||||
from .resnet import ResNet, ResNetV1c, ResNetV1d
|
||||
from .resnext import ResNeXt
|
||||
from .stdc import STDCContextPathNet, STDCNet
|
||||
from .swin import SwinTransformer
|
||||
from .timm_backbone import TIMMBackbone
|
||||
from .twins import PCPVT, SVT
|
||||
from .unet import UNet
|
||||
from .vit import VisionTransformer
|
||||
from .vpd import VPD
|
||||
|
||||
__all__ = [
|
||||
'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
|
||||
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
|
||||
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
|
||||
'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT',
|
||||
'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet', 'MSCAN',
|
||||
'DDRNet', 'VPD'
|
||||
]
|
||||
finetune/mmseg/models/backbones/beit.py (new file, 554 lines)
@@ -0,0 +1,554 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.drop import build_dropout
|
||||
from mmengine.model import BaseModule, ModuleList
|
||||
from mmengine.model.weight_init import (constant_init, kaiming_init,
|
||||
trunc_normal_)
|
||||
from mmengine.runner.checkpoint import _load_checkpoint
|
||||
from scipy import interpolate
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
from torch.nn.modules.utils import _pair as to_2tuple
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import PatchEmbed
|
||||
from .vit import TransformerEncoderLayer as VisionTransformerEncoderLayer
|
||||
|
||||
|
||||
class BEiTAttention(BaseModule):
|
||||
"""Window based multi-head self-attention (W-MSA) module with relative
|
||||
position bias.
|
||||
|
||||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (tuple[int]): The height and width of the window.
|
||||
bias (bool): The option to add learnable bias for q, k, v. If bias is
|
||||
True, it will add learnable bias. If bias is 'qv_bias', it will only
|
||||
add learnable bias for q, v. If bias is False, it will not add bias
|
||||
for q, k, v. Default to 'qv_bias'.
|
||||
qk_scale (float | None, optional): Override default qk scale of
|
||||
head_dim ** -0.5 if set. Default: None.
|
||||
attn_drop_rate (float): Dropout ratio of attention weight.
|
||||
Default: 0.0
|
||||
proj_drop_rate (float): Dropout ratio of output. Default: 0.
|
||||
init_cfg (dict | None, optional): The Config for initialization.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
window_size,
|
||||
bias='qv_bias',
|
||||
qk_scale=None,
|
||||
attn_drop_rate=0.,
|
||||
proj_drop_rate=0.,
|
||||
init_cfg=None,
|
||||
**kwargs):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.embed_dims = embed_dims
|
||||
self.num_heads = num_heads
|
||||
head_embed_dims = embed_dims // num_heads
|
||||
self.bias = bias
|
||||
self.scale = qk_scale or head_embed_dims**-0.5
|
||||
|
||||
qkv_bias = bias
|
||||
if bias == 'qv_bias':
|
||||
self._init_qv_bias()
|
||||
qkv_bias = False
|
||||
|
||||
self.window_size = window_size
|
||||
self._init_rel_pos_embedding()
|
||||
|
||||
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop_rate)
|
||||
self.proj = nn.Linear(embed_dims, embed_dims)
|
||||
self.proj_drop = nn.Dropout(proj_drop_rate)
|
||||
|
||||
def _init_qv_bias(self):
|
||||
self.q_bias = nn.Parameter(torch.zeros(self.embed_dims))
|
||||
self.v_bias = nn.Parameter(torch.zeros(self.embed_dims))
|
||||
|
||||
def _init_rel_pos_embedding(self):
|
||||
Wh, Ww = self.window_size
|
||||
# cls to token & token to cls & cls to cls
|
||||
self.num_relative_distance = (2 * Wh - 1) * (2 * Ww - 1) + 3
|
||||
# relative_position_bias_table shape is (2*Wh-1 * 2*Ww-1 + 3, nH)
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros(self.num_relative_distance, self.num_heads))
|
||||
|
||||
# get pair-wise relative position index for
|
||||
# each token inside the window
|
||||
coords_h = torch.arange(Wh)
|
||||
coords_w = torch.arange(Ww)
|
||||
# coords shape is (2, Wh, Ww)
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
|
||||
# coords_flatten shape is (2, Wh*Ww)
|
||||
coords_flatten = torch.flatten(coords, 1)
|
||||
relative_coords = (
|
||||
coords_flatten[:, :, None] - coords_flatten[:, None, :])
|
||||
# relative_coords shape is (Wh*Ww, Wh*Ww, 2)
|
||||
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
|
||||
# shift to start from 0
|
||||
relative_coords[:, :, 0] += Wh - 1
|
||||
relative_coords[:, :, 1] += Ww - 1
|
||||
relative_coords[:, :, 0] *= 2 * Ww - 1
|
||||
relative_position_index = torch.zeros(
|
||||
size=(Wh * Ww + 1, ) * 2, dtype=relative_coords.dtype)
|
||||
# relative_position_index shape is (Wh*Ww, Wh*Ww)
|
||||
relative_position_index[1:, 1:] = relative_coords.sum(-1)
|
||||
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
||||
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
||||
relative_position_index[0, 0] = self.num_relative_distance - 1
|
||||
|
||||
self.register_buffer('relative_position_index',
|
||||
relative_position_index)
|
||||
|
||||
def init_weights(self):
|
||||
trunc_normal_(self.relative_position_bias_table, std=0.02)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x (tensor): input features with shape of (num_windows*B, N, C).
|
||||
"""
|
||||
B, N, C = x.shape
|
||||
|
||||
if self.bias == 'qv_bias':
|
||||
k_bias = torch.zeros_like(self.v_bias, requires_grad=False)
|
||||
qkv_bias = torch.cat((self.q_bias, k_bias, self.v_bias))
|
||||
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
||||
else:
|
||||
qkv = self.qkv(x)
|
||||
|
||||
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
q = q * self.scale
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
if self.relative_position_bias_table is not None:
|
||||
Wh = self.window_size[0]
|
||||
Ww = self.window_size[1]
|
||||
relative_position_bias = self.relative_position_bias_table[
|
||||
self.relative_position_index.view(-1)].view(
|
||||
Wh * Ww + 1, Wh * Ww + 1, -1)
|
||||
relative_position_bias = relative_position_bias.permute(
|
||||
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class BEiTTransformerEncoderLayer(VisionTransformerEncoderLayer):
|
||||
"""Implements one encoder layer in Vision Transformer.
|
||||
|
||||
Args:
|
||||
embed_dims (int): The feature dimension.
|
||||
num_heads (int): Parallel attention heads.
|
||||
feedforward_channels (int): The hidden dimension for FFNs.
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Default: 0.0.
|
||||
drop_path_rate (float): Stochastic depth rate. Default 0.0.
|
||||
num_fcs (int): The number of fully-connected layers for FFNs.
|
||||
Default: 2.
|
||||
bias (bool): The option to add learnable bias for q, k, v. If bias is
|
||||
True, it will add learnable bias. If bias is 'qv_bias', it will only
|
||||
add learnable bias for q, v. If bias is False, it will not add bias
|
||||
for q, k, v. Default to 'qv_bias'.
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Default: dict(type='GELU').
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN').
|
||||
window_size (tuple[int], optional): The height and width of the window.
|
||||
Default: None.
|
||||
init_values (float, optional): Initialize the values of BEiTAttention
|
||||
and FFN with learnable scaling. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
feedforward_channels,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
num_fcs=2,
|
||||
bias='qv_bias',
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN'),
|
||||
window_size=None,
|
||||
attn_cfg=dict(),
|
||||
ffn_cfg=dict(add_identity=False),
|
||||
init_values=None):
|
||||
attn_cfg.update(dict(window_size=window_size, qk_scale=None))
|
||||
|
||||
super().__init__(
|
||||
embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
feedforward_channels=feedforward_channels,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_path_rate=0.,
|
||||
drop_rate=0.,
|
||||
num_fcs=num_fcs,
|
||||
qkv_bias=bias,
|
||||
act_cfg=act_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
attn_cfg=attn_cfg,
|
||||
ffn_cfg=ffn_cfg)
|
||||
|
||||
# NOTE: drop path for stochastic depth, we shall see if
|
||||
# this is better than dropout here
|
||||
dropout_layer = dict(type='DropPath', drop_prob=drop_path_rate)
|
||||
self.drop_path = build_dropout(
|
||||
dropout_layer) if dropout_layer else nn.Identity()
|
||||
self.gamma_1 = nn.Parameter(
|
||||
init_values * torch.ones(embed_dims), requires_grad=True)
|
||||
self.gamma_2 = nn.Parameter(
|
||||
init_values * torch.ones(embed_dims), requires_grad=True)
|
||||
|
||||
def build_attn(self, attn_cfg):
|
||||
self.attn = BEiTAttention(**attn_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
|
||||
x = x + self.drop_path(self.gamma_2 * self.ffn(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class BEiT(BaseModule):
|
||||
"""BERT Pre-Training of Image Transformers.
|
||||
|
||||
Args:
|
||||
img_size (int | tuple): Input image size. Default: 224.
|
||||
patch_size (int): The patch size. Default: 16.
|
||||
in_channels (int): Number of input channels. Default: 3.
|
||||
embed_dims (int): Embedding dimension. Default: 768.
|
||||
num_layers (int): Depth of transformer. Default: 12.
|
||||
num_heads (int): Number of attention heads. Default: 12.
|
||||
mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
|
||||
Default: 4.
|
||||
out_indices (list | tuple | int): Output from which stages.
|
||||
Default: -1.
|
||||
qv_bias (bool): Enable bias for qv if True. Default: True.
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Default 0.0
|
||||
drop_path_rate (float): Stochastic depth rate. Default 0.0.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN')
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Default: dict(type='GELU').
|
||||
patch_norm (bool): Whether to add a norm in PatchEmbed Block.
|
||||
Default: False.
|
||||
final_norm (bool): Whether to add an additional layer to normalize
|
||||
final feature map. Default: False.
|
||||
num_fcs (int): The number of fully-connected layers for FFNs.
|
||||
Default: 2.
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Default: False.
|
||||
pretrained (str, optional): Model pretrained path. Default: None.
|
||||
init_values (float): Initialize the values of BEiTAttention and FFN
|
||||
with learnable scaling.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_channels=3,
|
||||
embed_dims=768,
|
||||
num_layers=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4,
|
||||
out_indices=-1,
|
||||
qv_bias=True,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
norm_cfg=dict(type='LN'),
|
||||
act_cfg=dict(type='GELU'),
|
||||
patch_norm=False,
|
||||
final_norm=False,
|
||||
num_fcs=2,
|
||||
norm_eval=False,
|
||||
pretrained=None,
|
||||
init_values=0.1,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
if isinstance(img_size, int):
|
||||
img_size = to_2tuple(img_size)
|
||||
elif isinstance(img_size, tuple):
|
||||
if len(img_size) == 1:
|
||||
img_size = to_2tuple(img_size[0])
|
||||
assert len(img_size) == 2, \
|
||||
f'The size of image should have length 1 or 2, ' \
|
||||
f'but got {len(img_size)}'
|
||||
|
||||
assert not (init_cfg and pretrained), \
|
||||
'init_cfg and pretrained cannot be set at the same time'
|
||||
if isinstance(pretrained, str):
|
||||
warnings.warn('DeprecationWarning: pretrained is deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
elif pretrained is not None:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.norm_eval = norm_eval
|
||||
self.pretrained = pretrained
|
||||
self.num_layers = num_layers
|
||||
self.embed_dims = embed_dims
|
||||
self.num_heads = num_heads
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.attn_drop_rate = attn_drop_rate
|
||||
self.drop_path_rate = drop_path_rate
|
||||
self.num_fcs = num_fcs
|
||||
self.qv_bias = qv_bias
|
||||
self.act_cfg = act_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.patch_norm = patch_norm
|
||||
self.init_values = init_values
|
||||
self.window_size = (img_size[0] // patch_size,
|
||||
img_size[1] // patch_size)
|
||||
self.patch_shape = self.window_size
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
|
||||
|
||||
self._build_patch_embedding()
|
||||
self._build_layers()
|
||||
|
||||
if isinstance(out_indices, int):
|
||||
if out_indices == -1:
|
||||
out_indices = num_layers - 1
|
||||
self.out_indices = [out_indices]
|
||||
elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
|
||||
self.out_indices = out_indices
|
||||
else:
|
||||
raise TypeError('out_indices must be type of int, list or tuple')
|
||||
|
||||
self.final_norm = final_norm
|
||||
if final_norm:
|
||||
self.norm1_name, norm1 = build_norm_layer(
|
||||
norm_cfg, embed_dims, postfix=1)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
|
||||
def _build_patch_embedding(self):
|
||||
"""Build patch embedding layer."""
|
||||
self.patch_embed = PatchEmbed(
|
||||
in_channels=self.in_channels,
|
||||
embed_dims=self.embed_dims,
|
||||
conv_type='Conv2d',
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size,
|
||||
padding=0,
|
||||
norm_cfg=self.norm_cfg if self.patch_norm else None,
|
||||
init_cfg=None)
|
||||
|
||||
def _build_layers(self):
|
||||
"""Build transformer encoding layers."""
|
||||
|
||||
dpr = [
|
||||
x.item()
|
||||
for x in torch.linspace(0, self.drop_path_rate, self.num_layers)
|
||||
]
|
||||
self.layers = ModuleList()
|
||||
for i in range(self.num_layers):
|
||||
self.layers.append(
|
||||
BEiTTransformerEncoderLayer(
|
||||
embed_dims=self.embed_dims,
|
||||
num_heads=self.num_heads,
|
||||
feedforward_channels=self.mlp_ratio * self.embed_dims,
|
||||
attn_drop_rate=self.attn_drop_rate,
|
||||
drop_path_rate=dpr[i],
|
||||
num_fcs=self.num_fcs,
|
||||
bias='qv_bias' if self.qv_bias else False,
|
||||
act_cfg=self.act_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
window_size=self.window_size,
|
||||
init_values=self.init_values))
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
return getattr(self, self.norm1_name)
|
||||
|
||||
def _geometric_sequence_interpolation(self, src_size, dst_size, sequence,
|
||||
num):
|
||||
"""Get new sequence via geometric sequence interpolation.
|
||||
|
||||
Args:
|
||||
src_size (int): Pos_embedding size in pre-trained model.
|
||||
dst_size (int): Pos_embedding size in the current model.
|
||||
sequence (tensor): The relative position bias of the pretrain
|
||||
model after removing the extra tokens.
|
||||
num (int): Number of attention heads.
|
||||
Returns:
|
||||
new_sequence (tensor): Geometric sequence interpolate the
|
||||
pre-trained relative position bias to the size of
|
||||
the current model.
|
||||
"""
|
||||
|
||||
def geometric_progression(a, r, n):
|
||||
return a * (1.0 - r**n) / (1.0 - r)
|
||||
|
||||
# Bisection search for the geometric progression ratio q.
|
||||
left, right = 1.01, 1.5
|
||||
while right - left > 1e-6:
|
||||
q = (left + right) / 2.0
|
||||
gp = geometric_progression(1, q, src_size // 2)
|
||||
if gp > dst_size // 2:
|
||||
right = q
|
||||
else:
|
||||
left = q
|
||||
# The position of each interpolated point is determined
|
||||
# by the ratio obtained by dichotomy.
|
||||
dis = []
|
||||
cur = 1
|
||||
for i in range(src_size // 2):
|
||||
dis.append(cur)
|
||||
cur += q**(i + 1)
|
||||
r_ids = [-_ for _ in reversed(dis)]
|
||||
x = r_ids + [0] + dis
|
||||
y = r_ids + [0] + dis
|
||||
t = dst_size // 2.0
|
||||
dx = np.arange(-t, t + 0.1, 1.0)
|
||||
dy = np.arange(-t, t + 0.1, 1.0)
|
||||
# Interpolation functions are being executed and called.
|
||||
new_sequence = []
|
||||
for i in range(num):
|
||||
z = sequence[:, i].view(src_size, src_size).float().numpy()
|
||||
f = interpolate.interp2d(x, y, z, kind='cubic')
|
||||
new_sequence.append(
|
||||
torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(sequence))
|
||||
new_sequence = torch.cat(new_sequence, dim=-1)
|
||||
return new_sequence
|
||||
|
||||
def resize_rel_pos_embed(self, checkpoint):
|
||||
"""Resize relative pos_embed weights.
|
||||
|
||||
This function is modified from
|
||||
https://github.com/microsoft/unilm/blob/master/beit/semantic_segmentation/mmcv_custom/checkpoint.py. # noqa: E501
|
||||
Copyright (c) Microsoft Corporation
|
||||
Licensed under the MIT License
|
||||
Args:
|
||||
checkpoint (dict): Key and value of the pretrain model.
|
||||
Returns:
|
||||
state_dict (dict): Interpolate the relative pos_embed weights
|
||||
in the pre-train model to the current model size.
|
||||
"""
|
||||
if 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
|
||||
all_keys = list(state_dict.keys())
|
||||
for key in all_keys:
|
||||
if 'relative_position_index' in key:
|
||||
state_dict.pop(key)
|
||||
# In order to keep the center of pos_bias as consistent as
|
||||
# possible after interpolation, and vice versa in the edge
|
||||
# area, the geometric sequence interpolation method is adopted.
|
||||
if 'relative_position_bias_table' in key:
|
||||
rel_pos_bias = state_dict[key]
|
||||
src_num_pos, num_attn_heads = rel_pos_bias.size()
|
||||
dst_num_pos, _ = self.state_dict()[key].size()
|
||||
dst_patch_shape = self.patch_shape
|
||||
if dst_patch_shape[0] != dst_patch_shape[1]:
|
||||
raise NotImplementedError()
|
||||
# Count the number of extra tokens.
|
||||
num_extra_tokens = dst_num_pos - (
|
||||
dst_patch_shape[0] * 2 - 1) * (
|
||||
dst_patch_shape[1] * 2 - 1)
|
||||
src_size = int((src_num_pos - num_extra_tokens)**0.5)
|
||||
dst_size = int((dst_num_pos - num_extra_tokens)**0.5)
|
||||
if src_size != dst_size:
|
||||
extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
|
||||
rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
|
||||
new_rel_pos_bias = self._geometric_sequence_interpolation(
|
||||
src_size, dst_size, rel_pos_bias, num_attn_heads)
|
||||
new_rel_pos_bias = torch.cat(
|
||||
(new_rel_pos_bias, extra_tokens), dim=0)
|
||||
state_dict[key] = new_rel_pos_bias
|
||||
|
||||
return state_dict
|
||||
|
||||
def init_weights(self):
|
||||
|
||||
def _init_weights(m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
self.apply(_init_weights)
|
||||
|
||||
if (isinstance(self.init_cfg, dict)
|
||||
and self.init_cfg.get('type') == 'Pretrained'):
|
||||
checkpoint = _load_checkpoint(
|
||||
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
|
||||
state_dict = self.resize_rel_pos_embed(checkpoint)
|
||||
self.load_state_dict(state_dict, False)
|
||||
elif self.init_cfg is not None:
|
||||
super().init_weights()
|
||||
else:
|
||||
# We only implement the 'jax_impl' initialization implemented at
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
|
||||
# Copyright 2019 Ross Wightman
|
||||
# Licensed under the Apache License, Version 2.0 (the "License")
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
for n, m in self.named_modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if m.bias is not None:
|
||||
if 'ffn' in n:
|
||||
nn.init.normal_(m.bias, mean=0., std=1e-6)
|
||||
else:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
kaiming_init(m, mode='fan_in', bias=0.)
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
|
||||
constant_init(m, val=1.0, bias=0.)
|
||||
|
||||
def forward(self, inputs):
|
||||
B = inputs.shape[0]
|
||||
|
||||
x, hw_shape = self.patch_embed(inputs)
|
||||
|
||||
# stole cls_tokens impl from Phil Wang, thanks
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
|
||||
outs = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
if i == len(self.layers) - 1:
|
||||
if self.final_norm:
|
||||
x = self.norm1(x)
|
||||
if i in self.out_indices:
|
||||
# Remove class token and reshape token for decoder head
|
||||
out = x[:, 1:]
|
||||
B, _, C = out.shape
|
||||
out = out.reshape(B, hw_shape[0], hw_shape[1],
|
||||
C).permute(0, 3, 1, 2).contiguous()
|
||||
outs.append(out)
|
||||
|
||||
return tuple(outs)
|
||||
|
||||
def train(self, mode=True):
|
||||
super().train(mode)
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.LayerNorm):
|
||||
m.eval()
|
||||
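A quick shape check for the BEiT backbone above (illustrative, not part of the commit). The reduced embed_dims, num_layers and num_heads are assumptions made only to keep the example light; the remaining arguments are the class defaults.

# Shape-check sketch (illustrative only, not part of this commit).
import torch
from mmseg.models.backbones import BEiT

backbone = BEiT(
    img_size=224,
    patch_size=16,
    embed_dims=192,      # reduced from the default 768 for a quick check
    num_layers=2,        # reduced from the default 12
    num_heads=3,
    out_indices=-1,
    init_values=0.1)
backbone.init_weights()

x = torch.randn(1, 3, 224, 224)
outs = backbone(x)
# A single stage is requested; the class token is dropped and the remaining
# tokens are reshaped to (B, C, H/16, W/16).
assert outs[0].shape == (1, 192, 14, 14)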
finetune/mmseg/models/backbones/bisenetv1.py (new file, 332 lines)
@@ -0,0 +1,332 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
class SpatialPath(BaseModule):
|
||||
"""Spatial Path to preserve the spatial size of the original input image
|
||||
and encode affluent spatial information.
|
||||
|
||||
Args:
|
||||
in_channels(int): The number of channels of input
|
||||
image. Default: 3.
|
||||
num_channels (Tuple[int]): The number of channels of
|
||||
each layers in Spatial Path.
|
||||
Default: (64, 64, 64, 128).
|
||||
Returns:
|
||||
x (torch.Tensor): Feature map for Feature Fusion Module.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
num_channels=(64, 64, 64, 128),
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
assert len(num_channels) == 4, 'Length of input channels \
|
||||
of Spatial Path must be 4!'
|
||||
|
||||
self.layers = []
|
||||
for i in range(len(num_channels)):
|
||||
layer_name = f'layer{i + 1}'
|
||||
self.layers.append(layer_name)
|
||||
if i == 0:
|
||||
self.add_module(
|
||||
layer_name,
|
||||
ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=num_channels[i],
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
elif i == len(num_channels) - 1:
|
||||
self.add_module(
|
||||
layer_name,
|
||||
ConvModule(
|
||||
in_channels=num_channels[i - 1],
|
||||
out_channels=num_channels[i],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
else:
|
||||
self.add_module(
|
||||
layer_name,
|
||||
ConvModule(
|
||||
in_channels=num_channels[i - 1],
|
||||
out_channels=num_channels[i],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer_name in enumerate(self.layers):
|
||||
layer_stage = getattr(self, layer_name)
|
||||
x = layer_stage(x)
|
||||
return x
|
||||
|
||||
|
||||
class AttentionRefinementModule(BaseModule):
|
||||
"""Attention Refinement Module (ARM) to refine the features of each stage.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channels.
|
||||
out_channels (int): The number of output channels.
|
||||
Returns:
|
||||
x_out (torch.Tensor): Feature map of Attention Refinement Module.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channel,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.conv_layer = ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channel,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
self.atten_conv_layer = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d((1, 1)),
|
||||
ConvModule(
|
||||
in_channels=out_channel,
|
||||
out_channels=out_channel,
|
||||
kernel_size=1,
|
||||
bias=False,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None), nn.Sigmoid())
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_layer(x)
|
||||
x_atten = self.atten_conv_layer(x)
|
||||
x_out = x * x_atten
|
||||
return x_out
|
||||
|
||||
|
||||
class ContextPath(BaseModule):
|
||||
"""Context Path to provide sufficient receptive field.
|
||||
|
||||
Args:
|
||||
backbone_cfg (dict): Config of backbone of
|
||||
Context Path.
|
||||
context_channels (Tuple[int]): The number of channel numbers
|
||||
of various modules in Context Path.
|
||||
Default: (128, 256, 512).
|
||||
align_corners (bool, optional): The align_corners argument of
|
||||
resize operation. Default: False.
|
||||
Returns:
|
||||
x_16_up, x_32_up (torch.Tensor, torch.Tensor): Two feature maps
|
||||
undergoing upsampling from 1/16 and 1/32 downsampling
|
||||
feature maps. These two feature maps are used for Feature
|
||||
Fusion Module and Auxiliary Head.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
backbone_cfg,
|
||||
context_channels=(128, 256, 512),
|
||||
align_corners=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
assert len(context_channels) == 3, 'Length of input channels \
|
||||
of Context Path must be 3!'
|
||||
|
||||
self.backbone = MODELS.build(backbone_cfg)
|
||||
|
||||
self.align_corners = align_corners
|
||||
self.arm16 = AttentionRefinementModule(context_channels[1],
|
||||
context_channels[0])
|
||||
self.arm32 = AttentionRefinementModule(context_channels[2],
|
||||
context_channels[0])
|
||||
self.conv_head32 = ConvModule(
|
||||
in_channels=context_channels[0],
|
||||
out_channels=context_channels[0],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
self.conv_head16 = ConvModule(
|
||||
in_channels=context_channels[0],
|
||||
out_channels=context_channels[0],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
self.gap_conv = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d((1, 1)),
|
||||
ConvModule(
|
||||
in_channels=context_channels[2],
|
||||
out_channels=context_channels[0],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
|
||||
def forward(self, x):
|
||||
x_4, x_8, x_16, x_32 = self.backbone(x)
|
||||
x_gap = self.gap_conv(x_32)
|
||||
|
||||
x_32_arm = self.arm32(x_32)
|
||||
x_32_sum = x_32_arm + x_gap
|
||||
x_32_up = resize(input=x_32_sum, size=x_16.shape[2:], mode='nearest')
|
||||
x_32_up = self.conv_head32(x_32_up)
|
||||
|
||||
x_16_arm = self.arm16(x_16)
|
||||
x_16_sum = x_16_arm + x_32_up
|
||||
x_16_up = resize(input=x_16_sum, size=x_8.shape[2:], mode='nearest')
|
||||
x_16_up = self.conv_head16(x_16_up)
|
||||
|
||||
return x_16_up, x_32_up
|
||||
|
||||
|
||||
class FeatureFusionModule(BaseModule):
|
||||
"""Feature Fusion Module to fuse low level output feature of Spatial Path
|
||||
and high level output feature of Context Path.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channels.
|
||||
out_channels (int): The number of output channels.
|
||||
Returns:
|
||||
x_out (torch.Tensor): Feature map of Feature Fusion Module.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.conv1 = ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
self.gap = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.conv_atten = nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=False,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg), nn.Sigmoid())
|
||||
|
||||
def forward(self, x_sp, x_cp):
|
||||
x_concat = torch.cat([x_sp, x_cp], dim=1)
|
||||
x_fuse = self.conv1(x_concat)
|
||||
x_atten = self.gap(x_fuse)
|
||||
# Note: No BN and more 1x1 conv in paper.
|
||||
x_atten = self.conv_atten(x_atten)
|
||||
x_atten = x_fuse * x_atten
|
||||
x_out = x_atten + x_fuse
|
||||
return x_out
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class BiSeNetV1(BaseModule):
|
||||
"""BiSeNetV1 backbone.
|
||||
|
||||
This backbone is the implementation of `BiSeNet: Bilateral
|
||||
Segmentation Network for Real-time Semantic
|
||||
Segmentation <https://arxiv.org/abs/1808.00897>`_.
|
||||
|
||||
Args:
|
||||
backbone_cfg (dict): Config of backbone of
|
||||
Context Path.
|
||||
in_channels (int): The number of channels of input
|
||||
image. Default: 3.
|
||||
spatial_channels (Tuple[int]): Size of channel numbers of
|
||||
various layers in Spatial Path.
|
||||
Default: (64, 64, 64, 128).
|
||||
context_channels (Tuple[int]): Size of channel numbers of
|
||||
various modules in Context Path.
|
||||
Default: (128, 256, 512).
|
||||
out_indices (Tuple[int] | int, optional): Output from which stages.
|
||||
Default: (0, 1, 2).
|
||||
align_corners (bool, optional): The align_corners argument of
|
||||
resize operation in Bilateral Guided Aggregation Layer.
|
||||
Default: False.
|
||||
out_channels(int): The number of channels of output.
|
||||
It must be the same with `in_channels` of decode_head.
|
||||
Default: 256.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
backbone_cfg,
|
||||
in_channels=3,
|
||||
spatial_channels=(64, 64, 64, 128),
|
||||
context_channels=(128, 256, 512),
|
||||
out_indices=(0, 1, 2),
|
||||
align_corners=False,
|
||||
out_channels=256,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
assert len(spatial_channels) == 4, 'Length of input channels \
|
||||
of Spatial Path must be 4!'
|
||||
|
||||
assert len(context_channels) == 3, 'Length of input channels \
|
||||
of Context Path must be 3!'
|
||||
|
||||
self.out_indices = out_indices
|
||||
self.align_corners = align_corners
|
||||
self.context_path = ContextPath(backbone_cfg, context_channels,
|
||||
self.align_corners)
|
||||
self.spatial_path = SpatialPath(in_channels, spatial_channels)
|
||||
self.ffm = FeatureFusionModule(context_channels[1], out_channels)
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
|
||||
def forward(self, x):
|
||||
# stole refactoring code from Coin Cheung, thanks
|
||||
x_context8, x_context16 = self.context_path(x)
|
||||
x_spatial = self.spatial_path(x)
|
||||
x_fuse = self.ffm(x_spatial, x_context8)
|
||||
|
||||
outs = [x_fuse, x_context8, x_context16]
|
||||
outs = [outs[i] for i in self.out_indices]
|
||||
return tuple(outs)
|
||||
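A usage sketch for BiSeNetV1 (illustrative, not part of the commit): the backbone_cfg below is a plausible ResNet-18 context-path config, not necessarily the one used by this project's training configs; any registered backbone returning four stages at 1/4, 1/8, 1/16 and 1/32 resolution fits the ContextPath above.

# Usage sketch (illustrative only, not part of this commit).
import torch
from mmseg.models.backbones import BiSeNetV1

backbone_cfg = dict(          # assumed ResNet-18 config for the context path
    type='ResNet',
    depth=18,
    num_stages=4,
    out_indices=(0, 1, 2, 3),
    strides=(1, 2, 2, 2),
    dilations=(1, 1, 1, 1),
    norm_cfg=dict(type='BN', requires_grad=True))

model = BiSeNetV1(backbone_cfg=backbone_cfg, out_indices=(0, 1, 2))
x = torch.randn(1, 3, 512, 1024)
x_fuse, x_context8, x_context16 = model(x)
# Fused features (spatial path + 1/8 context features) are at 1/8 resolution.
assert x_fuse.shape[2:] == (64, 128)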
finetune/mmseg/models/backbones/bisenetv2.py (new file, 622 lines)
@@ -0,0 +1,622 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule,
|
||||
build_activation_layer, build_norm_layer)
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
class DetailBranch(BaseModule):
|
||||
"""Detail Branch with wide channels and shallow layers to capture low-level
|
||||
details and generate high-resolution feature representation.
|
||||
|
||||
Args:
|
||||
detail_channels (Tuple[int]): Size of channel numbers of each stage
|
||||
in Detail Branch, in paper it has 3 stages.
|
||||
Default: (64, 64, 128).
|
||||
in_channels (int): Number of channels of input image. Default: 3.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
Returns:
|
||||
x (torch.Tensor): Feature map of Detail Branch.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
detail_channels=(64, 64, 128),
|
||||
in_channels=3,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
detail_branch = []
|
||||
for i in range(len(detail_channels)):
|
||||
if i == 0:
|
||||
detail_branch.append(
|
||||
nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=detail_channels[i],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg),
|
||||
ConvModule(
|
||||
in_channels=detail_channels[i],
|
||||
out_channels=detail_channels[i],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)))
|
||||
else:
|
||||
detail_branch.append(
|
||||
nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=detail_channels[i - 1],
|
||||
out_channels=detail_channels[i],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg),
|
||||
ConvModule(
|
||||
in_channels=detail_channels[i],
|
||||
out_channels=detail_channels[i],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg),
|
||||
ConvModule(
|
||||
in_channels=detail_channels[i],
|
||||
out_channels=detail_channels[i],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)))
|
||||
self.detail_branch = nn.ModuleList(detail_branch)
|
||||
|
||||
def forward(self, x):
|
||||
for stage in self.detail_branch:
|
||||
x = stage(x)
|
||||
return x
|
||||
|
||||
|
||||
class StemBlock(BaseModule):
|
||||
"""Stem Block at the beginning of Semantic Branch.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
Default: 3.
|
||||
out_channels (int): Number of output channels.
|
||||
Default: 16.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
Returns:
|
||||
x (torch.Tensor): First feature map in Semantic Branch.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
out_channels=16,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
self.conv_first = ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
self.convs = nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels // 2,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg),
|
||||
ConvModule(
|
||||
in_channels=out_channels // 2,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
self.pool = nn.MaxPool2d(
|
||||
kernel_size=3, stride=2, padding=1, ceil_mode=False)
|
||||
self.fuse_last = ConvModule(
|
||||
in_channels=out_channels * 2,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_first(x)
|
||||
x_left = self.convs(x)
|
||||
x_right = self.pool(x)
|
||||
x = self.fuse_last(torch.cat([x_left, x_right], dim=1))
|
||||
return x
|
||||
|
||||
|
||||
class GELayer(BaseModule):
|
||||
"""Gather-and-Expansion Layer.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
exp_ratio (int): Expansion ratio for middle channels.
|
||||
Default: 6.
|
||||
stride (int): Stride of GELayer. Default: 1
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
Returns:
|
||||
x (torch.Tensor): Intermediate feature map in
|
||||
Semantic Branch.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
exp_ratio=6,
|
||||
stride=1,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
mid_channel = in_channels * exp_ratio
|
||||
self.conv1 = ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
if stride == 1:
|
||||
self.dwconv = nn.Sequential(
|
||||
# ReLU in ConvModule not shown in paper
|
||||
ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=mid_channel,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
groups=in_channels,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
self.shortcut = None
|
||||
else:
|
||||
self.dwconv = nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=mid_channel,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
groups=in_channels,
|
||||
bias=False,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None),
|
||||
# ReLU in ConvModule not shown in paper
|
||||
ConvModule(
|
||||
in_channels=mid_channel,
|
||||
out_channels=mid_channel,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=mid_channel,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg),
|
||||
)
|
||||
self.shortcut = nn.Sequential(
|
||||
DepthwiseSeparableConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
dw_norm_cfg=norm_cfg,
|
||||
dw_act_cfg=None,
|
||||
pw_norm_cfg=norm_cfg,
|
||||
pw_act_cfg=None,
|
||||
))
|
||||
|
||||
self.conv2 = nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=mid_channel,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=False,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None,
|
||||
))
|
||||
|
||||
self.act = build_activation_layer(act_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
x = self.conv1(x)
|
||||
x = self.dwconv(x)
|
||||
x = self.conv2(x)
|
||||
if self.shortcut is not None:
|
||||
shortcut = self.shortcut(identity)
|
||||
x = x + shortcut
|
||||
else:
|
||||
x = x + identity
|
||||
x = self.act(x)
|
||||
return x
|
||||
|
||||
|
||||
class CEBlock(BaseModule):
|
||||
"""Context Embedding Block for large receptive filed in Semantic Branch.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
Default: 3.
|
||||
out_channels (int): Number of output channels.
|
||||
Default: 16.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
Returns:
|
||||
x (torch.Tensor): Last feature map in Semantic Branch.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
out_channels=16,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.gap = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d((1, 1)),
|
||||
build_norm_layer(norm_cfg, self.in_channels)[1])
|
||||
self.conv_gap = ConvModule(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
# Note: in the paper this is a plain conv2d without bn-relu
|
||||
self.conv_last = ConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
x = self.gap(x)
|
||||
x = self.conv_gap(x)
|
||||
x = identity + x
|
||||
x = self.conv_last(x)
|
||||
return x
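
# Illustrative sketch (added, not part of the original module): the global-
# average-pooled branch is broadcast back over the feature map, so the block
# keeps the spatial size; in_channels must equal out_channels for the residual
# add. Sizes are arbitrary assumptions; eval() avoids computing BatchNorm
# statistics on the 1x1 pooled tensor with a batch of one.
def _demo_ceblock():
    import torch
    block = CEBlock(in_channels=128, out_channels=128).eval()
    feat = torch.rand(1, 128, 16, 32)
    print(block(feat).shape)  # torch.Size([1, 128, 16, 32])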
|
||||
|
||||
|
||||
class SemanticBranch(BaseModule):
|
||||
"""Semantic Branch which is lightweight with narrow channels and deep
|
||||
layers to obtain high-level semantic context.
|
||||
|
||||
Args:
|
||||
semantic_channels (Tuple[int]): Size of channel numbers of
|
||||
various stages in Semantic Branch.
|
||||
Default: (16, 32, 64, 128).
|
||||
in_channels (int): Number of channels of input image. Default: 3.
|
||||
exp_ratio (int): Expansion ratio for middle channels.
|
||||
Default: 6.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
Returns:
|
||||
semantic_outs (List[torch.Tensor]): List of several feature maps
|
||||
for auxiliary heads (Booster) and Bilateral
|
||||
Guided Aggregation Layer.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
semantic_channels=(16, 32, 64, 128),
|
||||
in_channels=3,
|
||||
exp_ratio=6,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.in_channels = in_channels
|
||||
self.semantic_channels = semantic_channels
|
||||
self.semantic_stages = []
|
||||
for i in range(len(semantic_channels)):
|
||||
stage_name = f'stage{i + 1}'
|
||||
self.semantic_stages.append(stage_name)
|
||||
if i == 0:
|
||||
self.add_module(
|
||||
stage_name,
|
||||
StemBlock(self.in_channels, semantic_channels[i]))
|
||||
elif i == (len(semantic_channels) - 1):
|
||||
self.add_module(
|
||||
stage_name,
|
||||
nn.Sequential(
|
||||
GELayer(semantic_channels[i - 1], semantic_channels[i],
|
||||
exp_ratio, 2),
|
||||
GELayer(semantic_channels[i], semantic_channels[i],
|
||||
exp_ratio, 1),
|
||||
GELayer(semantic_channels[i], semantic_channels[i],
|
||||
exp_ratio, 1),
|
||||
GELayer(semantic_channels[i], semantic_channels[i],
|
||||
exp_ratio, 1)))
|
||||
else:
|
||||
self.add_module(
|
||||
stage_name,
|
||||
nn.Sequential(
|
||||
GELayer(semantic_channels[i - 1], semantic_channels[i],
|
||||
exp_ratio, 2),
|
||||
GELayer(semantic_channels[i], semantic_channels[i],
|
||||
exp_ratio, 1)))
|
||||
|
||||
self.add_module(f'stage{len(semantic_channels)}_CEBlock',
|
||||
CEBlock(semantic_channels[-1], semantic_channels[-1]))
|
||||
self.semantic_stages.append(f'stage{len(semantic_channels)}_CEBlock')
|
||||
|
||||
def forward(self, x):
|
||||
semantic_outs = []
|
||||
for stage_name in self.semantic_stages:
|
||||
semantic_stage = getattr(self, stage_name)
|
||||
x = semantic_stage(x)
|
||||
semantic_outs.append(x)
|
||||
return semantic_outs
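
# Illustrative sketch (added for this write-up, not part of the original
# module): with the default semantic_channels=(16, 32, 64, 128) the branch
# returns five feature maps, the four stages at roughly 1/4, 1/8, 1/16 and
# 1/32 of the input resolution plus the CEBlock output at 1/32. The 512x1024
# input size is an arbitrary assumption.
def _demo_semantic_branch():
    import torch
    branch = SemanticBranch().eval()
    with torch.no_grad():
        outs = branch(torch.rand(1, 3, 512, 1024))
    for out in outs:
        print(tuple(out.shape))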
|
||||
|
||||
|
||||
class BGALayer(BaseModule):
|
||||
"""Bilateral Guided Aggregation Layer to fuse the complementary information
|
||||
from both Detail Branch and Semantic Branch.
|
||||
|
||||
Args:
|
||||
out_channels (int): Number of output channels.
|
||||
Default: 128.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
Returns:
|
||||
output (torch.Tensor): Output feature map for Segment heads.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
out_channels=128,
|
||||
align_corners=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.out_channels = out_channels
|
||||
self.align_corners = align_corners
|
||||
self.detail_dwconv = nn.Sequential(
|
||||
DepthwiseSeparableConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dw_norm_cfg=norm_cfg,
|
||||
dw_act_cfg=None,
|
||||
pw_norm_cfg=None,
|
||||
pw_act_cfg=None,
|
||||
))
|
||||
self.detail_down = nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
bias=False,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None),
|
||||
nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False))
|
||||
self.semantic_conv = nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=False,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None))
|
||||
self.semantic_dwconv = nn.Sequential(
|
||||
DepthwiseSeparableConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dw_norm_cfg=norm_cfg,
|
||||
dw_act_cfg=None,
|
||||
pw_norm_cfg=None,
|
||||
pw_act_cfg=None,
|
||||
))
|
||||
self.conv = ConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
inplace=True,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
)
|
||||
|
||||
def forward(self, x_d, x_s):
|
||||
detail_dwconv = self.detail_dwconv(x_d)
|
||||
detail_down = self.detail_down(x_d)
|
||||
semantic_conv = self.semantic_conv(x_s)
|
||||
semantic_dwconv = self.semantic_dwconv(x_s)
|
||||
semantic_conv = resize(
|
||||
input=semantic_conv,
|
||||
size=detail_dwconv.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
fuse_1 = detail_dwconv * torch.sigmoid(semantic_conv)
|
||||
fuse_2 = detail_down * torch.sigmoid(semantic_dwconv)
|
||||
fuse_2 = resize(
|
||||
input=fuse_2,
|
||||
size=fuse_1.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
output = self.conv(fuse_1 + fuse_2)
|
||||
return output
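
# Illustrative sketch (added, not part of the original module): the Detail
# Branch feature x_d comes in at a higher resolution than the Semantic Branch
# feature x_s (1/8 vs. 1/32 of the input in BiSeNetV2), and the fused output
# is returned at the detail resolution. The sizes below are assumptions.
def _demo_bga_layer():
    import torch
    bga = BGALayer(out_channels=128)
    x_d = torch.rand(1, 128, 64, 128)  # detail feature, 1/8 resolution
    x_s = torch.rand(1, 128, 16, 32)   # semantic feature, 1/32 resolution
    print(bga(x_d, x_s).shape)         # torch.Size([1, 128, 64, 128])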
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class BiSeNetV2(BaseModule):
|
||||
"""BiSeNetV2: Bilateral Network with Guided Aggregation for
|
||||
Real-time Semantic Segmentation.
|
||||
|
||||
This backbone is the implementation of
|
||||
`BiSeNetV2 <https://arxiv.org/abs/2004.02147>`_.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of channels of the input image. Default: 3.
|
||||
detail_channels (Tuple[int], optional): Channels of each stage
|
||||
in Detail Branch. Default: (64, 64, 128).
|
||||
semantic_channels (Tuple[int], optional): Channels of each stage
|
||||
in Semantic Branch. Default: (16, 32, 64, 128).
|
||||
See Table 1 and Figure 3 of paper for more details.
|
||||
semantic_expansion_ratio (int, optional): The expansion factor
|
||||
for the middle channels in Semantic Branch.
|
||||
Default: 6.
|
||||
bga_channels (int, optional): Number of middle channels in
|
||||
Bilateral Guided Aggregation Layer. Default: 128.
|
||||
out_indices (Tuple[int] | int, optional): Output from which stages.
|
||||
Default: (0, 1, 2, 3, 4).
|
||||
align_corners (bool, optional): The align_corners argument of
|
||||
resize operation in Bilateral Guided Aggregation Layer.
|
||||
Default: False.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
detail_channels=(64, 64, 128),
|
||||
semantic_channels=(16, 32, 64, 128),
|
||||
semantic_expansion_ratio=6,
|
||||
bga_channels=128,
|
||||
out_indices=(0, 1, 2, 3, 4),
|
||||
align_corners=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
if init_cfg is None:
|
||||
init_cfg = [
|
||||
dict(type='Kaiming', layer='Conv2d'),
|
||||
dict(
|
||||
type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
|
||||
]
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.in_channels = in_channels
|
||||
self.out_indices = out_indices
|
||||
self.detail_channels = detail_channels
|
||||
self.semantic_channels = semantic_channels
|
||||
self.semantic_expansion_ratio = semantic_expansion_ratio
|
||||
self.bga_channels = bga_channels
|
||||
self.align_corners = align_corners
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
|
||||
self.detail = DetailBranch(self.detail_channels, self.in_channels)
|
||||
self.semantic = SemanticBranch(self.semantic_channels,
|
||||
self.in_channels,
|
||||
self.semantic_expansion_ratio)
|
||||
self.bga = BGALayer(self.bga_channels, self.align_corners)
|
||||
|
||||
def forward(self, x):
|
||||
# stole refactoring code from Coin Cheung, thanks
|
||||
x_detail = self.detail(x)
|
||||
x_semantic_lst = self.semantic(x)
|
||||
x_head = self.bga(x_detail, x_semantic_lst[-1])
|
||||
outs = [x_head] + x_semantic_lst[:-1]
|
||||
outs = [outs[i] for i in self.out_indices]
|
||||
return tuple(outs)
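
# Illustrative usage sketch (added for this write-up, not part of the original
# module): build the backbone with its defaults and inspect the multi-scale
# outputs. With out_indices=(0, 1, 2, 3, 4) the tuple holds the BGA output at
# 1/8 resolution followed by the first four Semantic Branch stages; the
# 512x1024 input size is an arbitrary assumption.
if __name__ == '__main__':
    import torch
    backbone = BiSeNetV2().eval()
    with torch.no_grad():
        feats = backbone(torch.rand(1, 3, 512, 1024))
    for i, feat in enumerate(feats):
        print(i, tuple(feat.shape))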
|
||||
372
finetune/mmseg/models/backbones/cgnet.py
Normal file
@@ -0,0 +1,372 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
class GlobalContextExtractor(nn.Module):
|
||||
"""Global Context Extractor for CGNet.
|
||||
|
||||
This class is employed to refine the joint feature of both local feature
|
||||
and surrounding context.
|
||||
|
||||
Args:
|
||||
channel (int): Number of input feature channels.
|
||||
reduction (int): Reductions for global context extractor. Default: 16.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self, channel, reduction=16, with_cp=False):
|
||||
super().__init__()
|
||||
self.channel = channel
|
||||
self.reduction = reduction
|
||||
assert reduction >= 1 and channel >= reduction
|
||||
self.with_cp = with_cp
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True),
|
||||
nn.Linear(channel // reduction, channel), nn.Sigmoid())
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
def _inner_forward(x):
|
||||
num_batch, num_channel = x.size()[:2]
|
||||
y = self.avg_pool(x).view(num_batch, num_channel)
|
||||
y = self.fc(y).view(num_batch, num_channel, 1, 1)
|
||||
return x * y
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
out = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
out = _inner_forward(x)
|
||||
|
||||
return out
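
# Illustrative sketch (added, not part of the original module): the extractor
# is a squeeze-and-excitation style gate, so the output shape equals the input
# shape and each channel is rescaled by a factor in (0, 1). Sizes below are
# arbitrary assumptions.
def _demo_global_context_extractor():
    import torch
    gce = GlobalContextExtractor(channel=64, reduction=16)
    feat = torch.rand(1, 64, 32, 32)
    print(gce(feat).shape)  # torch.Size([1, 64, 32, 32])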
|
||||
|
||||
|
||||
class ContextGuidedBlock(nn.Module):
|
||||
"""Context Guided Block for CGNet.
|
||||
|
||||
This class consists of four components: local feature extractor,
|
||||
surrounding feature extractor, joint feature extractor and global
|
||||
context extractor.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input feature channels.
|
||||
out_channels (int): Number of output feature channels.
|
||||
dilation (int): Dilation rate for surrounding context extractor.
|
||||
Default: 2.
|
||||
reduction (int): Reduction for global context extractor. Default: 16.
|
||||
skip_connect (bool): Add input to output or not. Default: True.
|
||||
downsample (bool): Downsample the input to 1/2 or not. Default: False.
|
||||
conv_cfg (dict): Config dict for convolution layer.
|
||||
Default: None, which means using conv2d.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN', requires_grad=True).
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Default: dict(type='PReLU').
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
dilation=2,
|
||||
reduction=16,
|
||||
skip_connect=True,
|
||||
downsample=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
act_cfg=dict(type='PReLU'),
|
||||
with_cp=False):
|
||||
super().__init__()
|
||||
self.with_cp = with_cp
|
||||
self.downsample = downsample
|
||||
|
||||
channels = out_channels if downsample else out_channels // 2
|
||||
if 'type' in act_cfg and act_cfg['type'] == 'PReLU':
|
||||
act_cfg['num_parameters'] = channels
|
||||
kernel_size = 3 if downsample else 1
|
||||
stride = 2 if downsample else 1
|
||||
padding = (kernel_size - 1) // 2
|
||||
|
||||
self.conv1x1 = ConvModule(
|
||||
in_channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
self.f_loc = build_conv_layer(
|
||||
conv_cfg,
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
groups=channels,
|
||||
bias=False)
|
||||
self.f_sur = build_conv_layer(
|
||||
conv_cfg,
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=3,
|
||||
padding=dilation,
|
||||
groups=channels,
|
||||
dilation=dilation,
|
||||
bias=False)
|
||||
|
||||
self.bn = build_norm_layer(norm_cfg, 2 * channels)[1]
|
||||
self.activate = nn.PReLU(2 * channels)
|
||||
|
||||
if downsample:
|
||||
self.bottleneck = build_conv_layer(
|
||||
conv_cfg,
|
||||
2 * channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
bias=False)
|
||||
|
||||
self.skip_connect = skip_connect and not downsample
|
||||
self.f_glo = GlobalContextExtractor(out_channels, reduction, with_cp)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
def _inner_forward(x):
|
||||
out = self.conv1x1(x)
|
||||
loc = self.f_loc(out)
|
||||
sur = self.f_sur(out)
|
||||
|
||||
joi_feat = torch.cat([loc, sur], 1) # the joint feature
|
||||
joi_feat = self.bn(joi_feat)
|
||||
joi_feat = self.activate(joi_feat)
|
||||
if self.downsample:
|
||||
joi_feat = self.bottleneck(joi_feat) # channel = out_channels
|
||||
# f_glo is employed to refine the joint feature
|
||||
out = self.f_glo(joi_feat)
|
||||
|
||||
if self.skip_connect:
|
||||
return x + out
|
||||
else:
|
||||
return out
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
out = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
out = _inner_forward(x)
|
||||
|
||||
return out
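
# Illustrative sketch (added, not part of the original module): with
# downsample=False the block keeps both the channel count and the spatial size
# (the residual add requires in_channels == out_channels); with downsample=True
# it halves the resolution and drops the skip connection. Sizes are assumptions.
def _demo_context_guided_block():
    import torch
    feat = torch.rand(1, 64, 32, 32)
    same = ContextGuidedBlock(64, 64, dilation=2, reduction=16)
    down = ContextGuidedBlock(64, 128, dilation=2, reduction=16,
                              downsample=True)
    print(same(feat).shape)  # torch.Size([1, 64, 32, 32])
    print(down(feat).shape)  # torch.Size([1, 128, 16, 16])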
|
||||
|
||||
|
||||
class InputInjection(nn.Module):
|
||||
"""Downsampling module for CGNet."""
|
||||
|
||||
def __init__(self, num_downsampling):
|
||||
super().__init__()
|
||||
self.pool = nn.ModuleList()
|
||||
for i in range(num_downsampling):
|
||||
self.pool.append(nn.AvgPool2d(3, stride=2, padding=1))
|
||||
|
||||
def forward(self, x):
|
||||
for pool in self.pool:
|
||||
x = pool(x)
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class CGNet(BaseModule):
|
||||
"""CGNet backbone.
|
||||
|
||||
This backbone is the implementation of `A Light-weight Context Guided
|
||||
Network for Semantic Segmentation <https://arxiv.org/abs/1811.08201>`_.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input image channels. Normally 3.
|
||||
num_channels (tuple[int]): Numbers of feature channels at each stage.
|
||||
Default: (32, 64, 128).
|
||||
num_blocks (tuple[int]): Numbers of CG blocks at stage 1 and stage 2.
|
||||
Default: (3, 21).
|
||||
dilations (tuple[int]): Dilation rate for surrounding context
|
||||
extractors at stage 1 and stage 2. Default: (2, 4).
|
||||
reductions (tuple[int]): Reductions for global context extractors at
|
||||
stage 1 and stage 2. Default: (8, 16).
|
||||
conv_cfg (dict): Config dict for convolution layer.
|
||||
Default: None, which means using conv2d.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN', requires_grad=True).
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Default: dict(type='PReLU').
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Default: False.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
pretrained (str, optional): model pretrained path. Default: None
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
num_channels=(32, 64, 128),
|
||||
num_blocks=(3, 21),
|
||||
dilations=(2, 4),
|
||||
reductions=(8, 16),
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
act_cfg=dict(type='PReLU'),
|
||||
norm_eval=False,
|
||||
with_cp=False,
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
|
||||
super().__init__(init_cfg)
|
||||
|
||||
assert not (init_cfg and pretrained), \
|
||||
'init_cfg and pretrained cannot be set at the same time'
|
||||
if isinstance(pretrained, str):
|
||||
warnings.warn('DeprecationWarning: pretrained is deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
elif pretrained is None:
|
||||
if init_cfg is None:
|
||||
self.init_cfg = [
|
||||
dict(type='Kaiming', layer=['Conv2d', 'Linear']),
|
||||
dict(
|
||||
type='Constant',
|
||||
val=1,
|
||||
layer=['_BatchNorm', 'GroupNorm']),
|
||||
dict(type='Constant', val=0, layer='PReLU')
|
||||
]
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.num_channels = num_channels
|
||||
assert isinstance(self.num_channels, tuple) and len(
|
||||
self.num_channels) == 3
|
||||
self.num_blocks = num_blocks
|
||||
assert isinstance(self.num_blocks, tuple) and len(self.num_blocks) == 2
|
||||
self.dilations = dilations
|
||||
assert isinstance(self.dilations, tuple) and len(self.dilations) == 2
|
||||
self.reductions = reductions
|
||||
assert isinstance(self.reductions, tuple) and len(self.reductions) == 2
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
if 'type' in self.act_cfg and self.act_cfg['type'] == 'PReLU':
|
||||
self.act_cfg['num_parameters'] = num_channels[0]
|
||||
self.norm_eval = norm_eval
|
||||
self.with_cp = with_cp
|
||||
|
||||
cur_channels = in_channels
|
||||
self.stem = nn.ModuleList()
|
||||
for i in range(3):
|
||||
self.stem.append(
|
||||
ConvModule(
|
||||
cur_channels,
|
||||
num_channels[0],
|
||||
3,
|
||||
2 if i == 0 else 1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
cur_channels = num_channels[0]
|
||||
|
||||
self.inject_2x = InputInjection(1) # down-sample for Input, factor=2
|
||||
self.inject_4x = InputInjection(2) # down-sample for Input, factor=4
|
||||
|
||||
cur_channels += in_channels
|
||||
self.norm_prelu_0 = nn.Sequential(
|
||||
build_norm_layer(norm_cfg, cur_channels)[1],
|
||||
nn.PReLU(cur_channels))
|
||||
|
||||
# stage 1
|
||||
self.level1 = nn.ModuleList()
|
||||
for i in range(num_blocks[0]):
|
||||
self.level1.append(
|
||||
ContextGuidedBlock(
|
||||
cur_channels if i == 0 else num_channels[1],
|
||||
num_channels[1],
|
||||
dilations[0],
|
||||
reductions[0],
|
||||
downsample=(i == 0),
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
with_cp=with_cp)) # CG block
|
||||
|
||||
cur_channels = 2 * num_channels[1] + in_channels
|
||||
self.norm_prelu_1 = nn.Sequential(
|
||||
build_norm_layer(norm_cfg, cur_channels)[1],
|
||||
nn.PReLU(cur_channels))
|
||||
|
||||
# stage 2
|
||||
self.level2 = nn.ModuleList()
|
||||
for i in range(num_blocks[1]):
|
||||
self.level2.append(
|
||||
ContextGuidedBlock(
|
||||
cur_channels if i == 0 else num_channels[2],
|
||||
num_channels[2],
|
||||
dilations[1],
|
||||
reductions[1],
|
||||
downsample=(i == 0),
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
with_cp=with_cp)) # CG block
|
||||
|
||||
cur_channels = 2 * num_channels[2]
|
||||
self.norm_prelu_2 = nn.Sequential(
|
||||
build_norm_layer(norm_cfg, cur_channels)[1],
|
||||
nn.PReLU(cur_channels))
|
||||
|
||||
def forward(self, x):
|
||||
output = []
|
||||
|
||||
# stage 0
|
||||
inp_2x = self.inject_2x(x)
|
||||
inp_4x = self.inject_4x(x)
|
||||
for layer in self.stem:
|
||||
x = layer(x)
|
||||
x = self.norm_prelu_0(torch.cat([x, inp_2x], 1))
|
||||
output.append(x)
|
||||
|
||||
# stage 1
|
||||
for i, layer in enumerate(self.level1):
|
||||
x = layer(x)
|
||||
if i == 0:
|
||||
down1 = x
|
||||
x = self.norm_prelu_1(torch.cat([x, down1, inp_4x], 1))
|
||||
output.append(x)
|
||||
|
||||
# stage 2
|
||||
for i, layer in enumerate(self.level2):
|
||||
x = layer(x)
|
||||
if i == 0:
|
||||
down2 = x
|
||||
x = self.norm_prelu_2(torch.cat([down2, x], 1))
|
||||
output.append(x)
|
||||
|
||||
return output
|
||||
|
||||
def train(self, mode=True):
|
||||
"""Convert the model into training mode will keeping the normalization
|
||||
layers frozen."""
|
||||
super().train(mode)
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
# trick: eval has an effect on BatchNorm only
|
||||
if isinstance(m, _BatchNorm):
|
||||
m.eval()
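
# Illustrative usage sketch (added for this write-up, not part of the original
# module): the backbone returns one feature map per stage, at roughly 1/2, 1/4
# and 1/8 of the input resolution with the default settings. The 512x1024
# input size is an arbitrary assumption.
if __name__ == '__main__':
    model = CGNet().eval()
    with torch.no_grad():
        outs = model(torch.rand(1, 3, 512, 1024))
    for out in outs:
        print(tuple(out.shape))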
|
||||
222
finetune/mmseg/models/backbones/ddrnet.py
Normal file
@@ -0,0 +1,222 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, build_norm_layer
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.models.utils import DAPPM, BasicBlock, Bottleneck, resize
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import OptConfigType
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DDRNet(BaseModule):
|
||||
"""DDRNet backbone.
|
||||
|
||||
This backbone is the implementation of `Deep Dual-resolution Networks for
|
||||
Real-time and Accurate Semantic Segmentation of Road Scenes
|
||||
<http://arxiv.org/abs/2101.06085>`_.
|
||||
Modified from https://github.com/ydhongHIT/DDRNet.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input image channels. Default: 3.
|
||||
channels (int): The base channels of DDRNet. Default: 32.
|
||||
ppm_channels (int): The channels of PPM module. Default: 128.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
norm_cfg (dict): Config dict to build norm layer.
|
||||
Default: dict(type='BN', requires_grad=True).
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Default: dict(type='ReLU', inplace=True).
|
||||
init_cfg (dict, optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int = 3,
|
||||
channels: int = 32,
|
||||
ppm_channels: int = 128,
|
||||
align_corners: bool = False,
|
||||
norm_cfg: OptConfigType = dict(type='BN', requires_grad=True),
|
||||
act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
|
||||
init_cfg: OptConfigType = None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.ppm_channels = ppm_channels
|
||||
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.align_corners = align_corners
|
||||
|
||||
# stage 0-2
|
||||
self.stem = self._make_stem_layer(in_channels, channels, num_blocks=2)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
# low resolution(context) branch
|
||||
self.context_branch_layers = nn.ModuleList()
|
||||
for i in range(3):
|
||||
self.context_branch_layers.append(
|
||||
self._make_layer(
|
||||
block=BasicBlock if i < 2 else Bottleneck,
|
||||
inplanes=channels * 2**(i + 1),
|
||||
planes=channels * 8 if i > 0 else channels * 4,
|
||||
num_blocks=2 if i < 2 else 1,
|
||||
stride=2))
|
||||
|
||||
# bilateral fusion
|
||||
self.compression_1 = ConvModule(
|
||||
channels * 4,
|
||||
channels * 2,
|
||||
kernel_size=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=None)
|
||||
self.down_1 = ConvModule(
|
||||
channels * 2,
|
||||
channels * 4,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=None)
|
||||
|
||||
self.compression_2 = ConvModule(
|
||||
channels * 8,
|
||||
channels * 2,
|
||||
kernel_size=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=None)
|
||||
self.down_2 = nn.Sequential(
|
||||
ConvModule(
|
||||
channels * 2,
|
||||
channels * 4,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg),
|
||||
ConvModule(
|
||||
channels * 4,
|
||||
channels * 8,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=None))
|
||||
|
||||
# high resolution(spatial) branch
|
||||
self.spatial_branch_layers = nn.ModuleList()
|
||||
for i in range(3):
|
||||
self.spatial_branch_layers.append(
|
||||
self._make_layer(
|
||||
block=BasicBlock if i < 2 else Bottleneck,
|
||||
inplanes=channels * 2,
|
||||
planes=channels * 2,
|
||||
num_blocks=2 if i < 2 else 1,
|
||||
))
|
||||
|
||||
self.spp = DAPPM(
|
||||
channels * 16, ppm_channels, channels * 4, num_scales=5)
|
||||
|
||||
def _make_stem_layer(self, in_channels, channels, num_blocks):
|
||||
layers = [
|
||||
ConvModule(
|
||||
in_channels,
|
||||
channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg),
|
||||
ConvModule(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
]
|
||||
|
||||
layers.extend([
|
||||
self._make_layer(BasicBlock, channels, channels, num_blocks),
|
||||
nn.ReLU(),
|
||||
self._make_layer(
|
||||
BasicBlock, channels, channels * 2, num_blocks, stride=2),
|
||||
nn.ReLU(),
|
||||
])
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _make_layer(self, block, inplanes, planes, num_blocks, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
inplanes,
|
||||
planes * block.expansion,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False),
|
||||
build_norm_layer(self.norm_cfg, planes * block.expansion)[1])
|
||||
|
||||
layers = [
|
||||
block(
|
||||
in_channels=inplanes,
|
||||
channels=planes,
|
||||
stride=stride,
|
||||
downsample=downsample)
|
||||
]
|
||||
inplanes = planes * block.expansion
|
||||
for i in range(1, num_blocks):
|
||||
layers.append(
|
||||
block(
|
||||
in_channels=inplanes,
|
||||
channels=planes,
|
||||
stride=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg_out=None if i == num_blocks - 1 else self.act_cfg))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
out_size = (x.shape[-2] // 8, x.shape[-1] // 8)
|
||||
|
||||
# stage 0-2
|
||||
x = self.stem(x)
|
||||
|
||||
# stage3
|
||||
x_c = self.context_branch_layers[0](x)
|
||||
x_s = self.spatial_branch_layers[0](x)
|
||||
comp_c = self.compression_1(self.relu(x_c))
|
||||
x_c += self.down_1(self.relu(x_s))
|
||||
x_s += resize(
|
||||
comp_c,
|
||||
size=out_size,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
if self.training:
|
||||
temp_context = x_s.clone()
|
||||
|
||||
# stage4
|
||||
x_c = self.context_branch_layers[1](self.relu(x_c))
|
||||
x_s = self.spatial_branch_layers[1](self.relu(x_s))
|
||||
comp_c = self.compression_2(self.relu(x_c))
|
||||
x_c += self.down_2(self.relu(x_s))
|
||||
x_s += resize(
|
||||
comp_c,
|
||||
size=out_size,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
|
||||
# stage5
|
||||
x_s = self.spatial_branch_layers[2](self.relu(x_s))
|
||||
x_c = self.context_branch_layers[2](self.relu(x_c))
|
||||
x_c = self.spp(x_c)
|
||||
x_c = resize(
|
||||
x_c,
|
||||
size=out_size,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
|
||||
return (temp_context, x_s + x_c) if self.training else x_s + x_c
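
# Illustrative usage sketch (added for this write-up, not part of the original
# module): in eval mode the backbone returns a single fused feature map at 1/8
# of the input resolution (channels * 4 feature channels with the defaults);
# in training mode it also returns the intermediate spatial-branch feature
# used by the auxiliary head. The input size below is an arbitrary assumption.
if __name__ == '__main__':
    import torch
    model = DDRNet().eval()
    with torch.no_grad():
        out = model(torch.rand(1, 3, 512, 1024))
    print(tuple(out.shape))  # (1, 128, 64, 128) with the defaults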
|
||||
329
finetune/mmseg/models/backbones/erfnet.py
Normal file
@@ -0,0 +1,329 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
class DownsamplerBlock(BaseModule):
|
||||
"""Downsampler block of ERFNet.
|
||||
|
||||
This module is a little different from the basic ConvModule.
|
||||
The features from Conv and MaxPool layers are
|
||||
concatenated before BatchNorm.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', eps=1e-3),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
|
||||
self.conv = build_conv_layer(
|
||||
self.conv_cfg,
|
||||
in_channels,
|
||||
out_channels - in_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1)
|
||||
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.bn = build_norm_layer(self.norm_cfg, out_channels)[1]
|
||||
self.act = build_activation_layer(self.act_cfg)
|
||||
|
||||
def forward(self, input):
|
||||
conv_out = self.conv(input)
|
||||
pool_out = self.pool(input)
|
||||
pool_out = resize(
|
||||
input=pool_out,
|
||||
size=conv_out.size()[2:],
|
||||
mode='bilinear',
|
||||
align_corners=False)
|
||||
output = torch.cat([conv_out, pool_out], 1)
|
||||
output = self.bn(output)
|
||||
output = self.act(output)
|
||||
return output
|
||||
|
||||
|
||||
class NonBottleneck1d(BaseModule):
|
||||
"""Non-bottleneck block of ERFNet.
|
||||
|
||||
Args:
|
||||
channels (int): Number of channels in Non-bottleneck block.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Default 0.
|
||||
dilation (int): Dilation rate for last two conv layers.
|
||||
Default 1.
|
||||
num_conv_layer (int): Number of 3x1 and 1x3 convolution layers.
|
||||
Default 2.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
channels,
|
||||
drop_rate=0,
|
||||
dilation=1,
|
||||
num_conv_layer=2,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', eps=1e-3),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.act = build_activation_layer(self.act_cfg)
|
||||
|
||||
self.convs_layers = nn.ModuleList()
|
||||
for conv_layer in range(num_conv_layer):
|
||||
first_conv_padding = (1, 0) if conv_layer == 0 else (dilation, 0)
|
||||
first_conv_dilation = 1 if conv_layer == 0 else (dilation, 1)
|
||||
second_conv_padding = (0, 1) if conv_layer == 0 else (0, dilation)
|
||||
second_conv_dilation = 1 if conv_layer == 0 else (1, dilation)
|
||||
|
||||
self.convs_layers.append(
|
||||
build_conv_layer(
|
||||
self.conv_cfg,
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=(3, 1),
|
||||
stride=1,
|
||||
padding=first_conv_padding,
|
||||
bias=True,
|
||||
dilation=first_conv_dilation))
|
||||
self.convs_layers.append(self.act)
|
||||
self.convs_layers.append(
|
||||
build_conv_layer(
|
||||
self.conv_cfg,
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=(1, 3),
|
||||
stride=1,
|
||||
padding=second_conv_padding,
|
||||
bias=True,
|
||||
dilation=second_conv_dilation))
|
||||
self.convs_layers.append(
|
||||
build_norm_layer(self.norm_cfg, channels)[1])
|
||||
if conv_layer == 0:
|
||||
self.convs_layers.append(self.act)
|
||||
else:
|
||||
self.convs_layers.append(nn.Dropout(p=drop_rate))
|
||||
|
||||
def forward(self, input):
|
||||
output = input
|
||||
for conv in self.convs_layers:
|
||||
output = conv(output)
|
||||
output = self.act(output + input)
|
||||
return output
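
# Illustrative sketch (added, not part of the original module): the block
# factorizes 3x3 convolutions into 3x1 and 1x3 pairs and keeps both the
# channel count and the spatial size, so it can be stacked freely in the
# encoder and decoder. The size below is an arbitrary assumption.
def _demo_non_bottleneck_1d():
    block = NonBottleneck1d(channels=64, drop_rate=0.1, dilation=2)
    feat = torch.rand(1, 64, 64, 128)
    print(block(feat).shape)  # torch.Size([1, 64, 64, 128])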
|
||||
|
||||
|
||||
class UpsamplerBlock(BaseModule):
|
||||
"""Upsampler block of ERFNet.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', eps=1e-3),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
|
||||
self.conv = nn.ConvTranspose2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
output_padding=1,
|
||||
bias=True)
|
||||
self.bn = build_norm_layer(self.norm_cfg, out_channels)[1]
|
||||
self.act = build_activation_layer(self.act_cfg)
|
||||
|
||||
def forward(self, input):
|
||||
output = self.conv(input)
|
||||
output = self.bn(output)
|
||||
output = self.act(output)
|
||||
return output
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ERFNet(BaseModule):
|
||||
"""ERFNet backbone.
|
||||
|
||||
This backbone is the implementation of `ERFNet: Efficient Residual
|
||||
Factorized ConvNet for Real-time Semantic Segmentation
|
||||
<https://ieeexplore.ieee.org/document/8063438>`_.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of channels of input
|
||||
image. Default: 3.
|
||||
enc_downsample_channels (Tuple[int]): Size of channel
|
||||
numbers of various Downsampler block in encoder.
|
||||
Default: (16, 64, 128).
|
||||
enc_stage_non_bottlenecks (Tuple[int]): Number of stages of
|
||||
Non-bottleneck block in encoder.
|
||||
Default: (5, 8).
|
||||
enc_non_bottleneck_dilations (Tuple[int]): Dilation rate of each
|
||||
stage of Non-bottleneck block of encoder.
|
||||
Default: (2, 4, 8, 16).
|
||||
enc_non_bottleneck_channels (Tuple[int]): Size of channel
|
||||
numbers of various Non-bottleneck block in encoder.
|
||||
Default: (64, 128).
|
||||
dec_upsample_channels (Tuple[int]): Size of channel numbers of
|
||||
various Deconvolution block in decoder.
|
||||
Default: (64, 16).
|
||||
dec_stages_non_bottleneck (Tuple[int]): Number of stages of
|
||||
Non-bottleneck block in decoder.
|
||||
Default: (2, 2).
|
||||
dec_non_bottleneck_channels (Tuple[int]): Size of channel
|
||||
numbers of various Non-bottleneck block in decoder.
|
||||
Default: (64, 16).
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Default 0.1.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
enc_downsample_channels=(16, 64, 128),
|
||||
enc_stage_non_bottlenecks=(5, 8),
|
||||
enc_non_bottleneck_dilations=(2, 4, 8, 16),
|
||||
enc_non_bottleneck_channels=(64, 128),
|
||||
dec_upsample_channels=(64, 16),
|
||||
dec_stages_non_bottleneck=(2, 2),
|
||||
dec_non_bottleneck_channels=(64, 16),
|
||||
dropout_ratio=0.1,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
assert len(enc_downsample_channels) \
|
||||
== len(dec_upsample_channels)+1, 'Number of downsample\
|
||||
block of encoder does not \
|
||||
match number of upsample block of decoder!'
|
||||
assert len(enc_downsample_channels) \
|
||||
== len(enc_stage_non_bottlenecks)+1, 'Number of \
|
||||
downsample block of encoder does not match \
|
||||
number of Non-bottleneck block of encoder!'
|
||||
assert len(enc_downsample_channels) \
|
||||
== len(enc_non_bottleneck_channels)+1, 'Number of \
|
||||
downsample block of encoder does not match \
|
||||
number of channels of Non-bottleneck block of encoder!'
|
||||
assert enc_stage_non_bottlenecks[-1] \
|
||||
% len(enc_non_bottleneck_dilations) == 0, 'Number of \
|
||||
Non-bottleneck blocks in the last encoder stage must be \
|
||||
a multiple of the number of dilation rates!'
|
||||
assert len(dec_upsample_channels) \
|
||||
== len(dec_stages_non_bottleneck), 'Number of \
|
||||
upsample block of decoder does not match \
|
||||
number of Non-bottleneck block of decoder!'
|
||||
assert len(dec_stages_non_bottleneck) \
|
||||
== len(dec_non_bottleneck_channels), 'Number of \
|
||||
Non-bottleneck block of decoder does not match \
|
||||
number of channels of Non-bottleneck block of decoder!'
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.enc_downsample_channels = enc_downsample_channels
|
||||
self.enc_stage_non_bottlenecks = enc_stage_non_bottlenecks
|
||||
self.enc_non_bottleneck_dilations = enc_non_bottleneck_dilations
|
||||
self.enc_non_bottleneck_channels = enc_non_bottleneck_channels
|
||||
self.dec_upsample_channels = dec_upsample_channels
|
||||
self.dec_stages_non_bottleneck = dec_stages_non_bottleneck
|
||||
self.dec_non_bottleneck_channels = dec_non_bottleneck_channels
|
||||
self.dropout_ratio = dropout_ratio
|
||||
|
||||
self.encoder = nn.ModuleList()
|
||||
self.decoder = nn.ModuleList()
|
||||
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
|
||||
self.encoder.append(
|
||||
DownsamplerBlock(self.in_channels, enc_downsample_channels[0]))
|
||||
|
||||
for i in range(len(enc_downsample_channels) - 1):
|
||||
self.encoder.append(
|
||||
DownsamplerBlock(enc_downsample_channels[i],
|
||||
enc_downsample_channels[i + 1]))
|
||||
# Last part of encoder is some dilated NonBottleneck1d blocks.
|
||||
if i == len(enc_downsample_channels) - 2:
|
||||
iteration_times = int(enc_stage_non_bottlenecks[-1] /
|
||||
len(enc_non_bottleneck_dilations))
|
||||
for j in range(iteration_times):
|
||||
for k in range(len(enc_non_bottleneck_dilations)):
|
||||
self.encoder.append(
|
||||
NonBottleneck1d(enc_downsample_channels[-1],
|
||||
self.dropout_ratio,
|
||||
enc_non_bottleneck_dilations[k]))
|
||||
else:
|
||||
for j in range(enc_stage_non_bottlenecks[i]):
|
||||
self.encoder.append(
|
||||
NonBottleneck1d(enc_downsample_channels[i + 1],
|
||||
self.dropout_ratio))
|
||||
|
||||
for i in range(len(dec_upsample_channels)):
|
||||
if i == 0:
|
||||
self.decoder.append(
|
||||
UpsamplerBlock(enc_downsample_channels[-1],
|
||||
dec_non_bottleneck_channels[i]))
|
||||
else:
|
||||
self.decoder.append(
|
||||
UpsamplerBlock(dec_non_bottleneck_channels[i - 1],
|
||||
dec_non_bottleneck_channels[i]))
|
||||
for j in range(dec_stages_non_bottleneck[i]):
|
||||
self.decoder.append(
|
||||
NonBottleneck1d(dec_non_bottleneck_channels[i]))
|
||||
|
||||
def forward(self, x):
|
||||
for enc in self.encoder:
|
||||
x = enc(x)
|
||||
for dec in self.decoder:
|
||||
x = dec(x)
|
||||
return [x]
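
# Illustrative usage sketch (added for this write-up, not part of the original
# module): the encoder downsamples three times (to 1/8) and the decoder
# upsamples twice, so the single returned feature map sits at 1/2 of the input
# resolution with dec_non_bottleneck_channels[-1] channels by default. The
# input size below is an arbitrary assumption.
if __name__ == '__main__':
    model = ERFNet().eval()
    with torch.no_grad():
        feats = model(torch.rand(1, 3, 512, 1024))
    print(tuple(feats[0].shape))  # (1, 16, 256, 512) with the defaults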
|
||||
408
finetune/mmseg/models/backbones/fast_scnn.py
Normal file
@@ -0,0 +1,408 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.models.decode_heads.psp_head import PPM
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import InvertedResidual, resize
|
||||
|
||||
|
||||
class LearningToDownsample(nn.Module):
|
||||
"""Learning to downsample module.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
dw_channels (tuple[int]): Number of output channels of the first and
|
||||
the second depthwise conv (dwconv) layers.
|
||||
out_channels (int): Number of output channels of the whole
|
||||
'learning to downsample' module.
|
||||
conv_cfg (dict | None): Config of conv layers. Default: None
|
||||
norm_cfg (dict | None): Config of norm layers. Default:
|
||||
dict(type='BN')
|
||||
act_cfg (dict): Config of activation layers. Default:
|
||||
dict(type='ReLU')
|
||||
dw_act_cfg (dict): In DepthwiseSeparableConvModule, activation config
|
||||
of depthwise ConvModule. If it is 'default', it will be the same
|
||||
as `act_cfg`. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
dw_channels,
|
||||
out_channels,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
dw_act_cfg=None):
|
||||
super().__init__()
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.dw_act_cfg = dw_act_cfg
|
||||
dw_channels1 = dw_channels[0]
|
||||
dw_channels2 = dw_channels[1]
|
||||
|
||||
self.conv = ConvModule(
|
||||
in_channels,
|
||||
dw_channels1,
|
||||
3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
self.dsconv1 = DepthwiseSeparableConvModule(
|
||||
dw_channels1,
|
||||
dw_channels2,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
dw_act_cfg=self.dw_act_cfg)
|
||||
|
||||
self.dsconv2 = DepthwiseSeparableConvModule(
|
||||
dw_channels2,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
dw_act_cfg=self.dw_act_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.dsconv1(x)
|
||||
x = self.dsconv2(x)
|
||||
return x
|
||||
|
||||
|
||||
class GlobalFeatureExtractor(nn.Module):
|
||||
"""Global feature extractor module.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels of the GFE module.
|
||||
Default: 64
|
||||
block_channels (tuple[int]): Tuple of ints. Each int specifies the
|
||||
number of output channels of each Inverted Residual module.
|
||||
Default: (64, 96, 128)
|
||||
out_channels(int): Number of output channels of the GFE module.
|
||||
Default: 128
|
||||
expand_ratio (int): Adjusts number of channels of the hidden layer
|
||||
in InvertedResidual by this amount.
|
||||
Default: 6
|
||||
num_blocks (tuple[int]): Tuple of ints. Each int specifies the
|
||||
number of times each Inverted Residual module is repeated.
|
||||
The repeated Inverted Residual modules are called a 'group'.
|
||||
Default: (3, 3, 3)
|
||||
strides (tuple[int]): Tuple of ints. Each int specifies
|
||||
the downsampling factor of each 'group'.
|
||||
Default: (2, 2, 1)
|
||||
pool_scales (tuple[int]): Tuple of ints. Each int specifies
|
||||
the parameter required in 'global average pooling' within PPM.
|
||||
Default: (1, 2, 3, 6)
|
||||
conv_cfg (dict | None): Config of conv layers. Default: None
|
||||
norm_cfg (dict | None): Config of norm layers. Default:
|
||||
dict(type='BN')
|
||||
act_cfg (dict): Config of activation layers. Default:
|
||||
dict(type='ReLU')
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=64,
|
||||
block_channels=(64, 96, 128),
|
||||
out_channels=128,
|
||||
expand_ratio=6,
|
||||
num_blocks=(3, 3, 3),
|
||||
strides=(2, 2, 1),
|
||||
pool_scales=(1, 2, 3, 6),
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
align_corners=False):
|
||||
super().__init__()
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
assert len(block_channels) == len(num_blocks) == 3
|
||||
self.bottleneck1 = self._make_layer(in_channels, block_channels[0],
|
||||
num_blocks[0], strides[0],
|
||||
expand_ratio)
|
||||
self.bottleneck2 = self._make_layer(block_channels[0],
|
||||
block_channels[1], num_blocks[1],
|
||||
strides[1], expand_ratio)
|
||||
self.bottleneck3 = self._make_layer(block_channels[1],
|
||||
block_channels[2], num_blocks[2],
|
||||
strides[2], expand_ratio)
|
||||
self.ppm = PPM(
|
||||
pool_scales,
|
||||
block_channels[2],
|
||||
block_channels[2] // 4,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
align_corners=align_corners)
|
||||
|
||||
self.out = ConvModule(
|
||||
block_channels[2] * 2,
|
||||
out_channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def _make_layer(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
blocks,
|
||||
stride=1,
|
||||
expand_ratio=6):
|
||||
layers = [
|
||||
InvertedResidual(
|
||||
in_channels,
|
||||
out_channels,
|
||||
stride,
|
||||
expand_ratio,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
]
|
||||
for i in range(1, blocks):
|
||||
layers.append(
|
||||
InvertedResidual(
|
||||
out_channels,
|
||||
out_channels,
|
||||
1,
|
||||
expand_ratio,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg))
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.bottleneck1(x)
|
||||
x = self.bottleneck2(x)
|
||||
x = self.bottleneck3(x)
|
||||
x = torch.cat([x, *self.ppm(x)], dim=1)
|
||||
x = self.out(x)
|
||||
return x
|
||||
|
||||
|
||||
class FeatureFusionModule(nn.Module):
|
||||
"""Feature fusion module.
|
||||
|
||||
Args:
|
||||
higher_in_channels (int): Number of input channels of the
|
||||
higher-resolution branch.
|
||||
lower_in_channels (int): Number of input channels of the
|
||||
lower-resolution branch.
|
||||
out_channels (int): Number of output channels.
|
||||
conv_cfg (dict | None): Config of conv layers. Default: None
|
||||
norm_cfg (dict | None): Config of norm layers. Default:
|
||||
dict(type='BN')
|
||||
dwconv_act_cfg (dict): Config of activation layers in 3x3 conv.
|
||||
Default: dict(type='ReLU').
|
||||
conv_act_cfg (dict): Config of activation layers in the two 1x1 conv.
|
||||
Default: None.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
higher_in_channels,
|
||||
lower_in_channels,
|
||||
out_channels,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
dwconv_act_cfg=dict(type='ReLU'),
|
||||
conv_act_cfg=None,
|
||||
align_corners=False):
|
||||
super().__init__()
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.dwconv_act_cfg = dwconv_act_cfg
|
||||
self.conv_act_cfg = conv_act_cfg
|
||||
self.align_corners = align_corners
|
||||
self.dwconv = ConvModule(
|
||||
lower_in_channels,
|
||||
out_channels,
|
||||
3,
|
||||
padding=1,
|
||||
groups=out_channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.dwconv_act_cfg)
|
||||
self.conv_lower_res = ConvModule(
|
||||
out_channels,
|
||||
out_channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.conv_act_cfg)
|
||||
|
||||
self.conv_higher_res = ConvModule(
|
||||
higher_in_channels,
|
||||
out_channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.conv_act_cfg)
|
||||
|
||||
self.relu = nn.ReLU(True)
|
||||
|
||||
def forward(self, higher_res_feature, lower_res_feature):
|
||||
lower_res_feature = resize(
|
||||
lower_res_feature,
|
||||
size=higher_res_feature.size()[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
lower_res_feature = self.dwconv(lower_res_feature)
|
||||
lower_res_feature = self.conv_lower_res(lower_res_feature)
|
||||
|
||||
higher_res_feature = self.conv_higher_res(higher_res_feature)
|
||||
out = higher_res_feature + lower_res_feature
|
||||
return self.relu(out)
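
# Illustrative sketch (added, not part of the original module): the lower
# resolution feature is upsampled to the higher resolution grid before the two
# 1x1 projections are summed, so the output keeps the higher resolution. The
# sizes below are arbitrary assumptions.
def _demo_feature_fusion_module():
    ffm = FeatureFusionModule(
        higher_in_channels=64, lower_in_channels=128, out_channels=128)
    higher = torch.rand(1, 64, 64, 128)  # e.g. 1/8 resolution feature
    lower = torch.rand(1, 128, 16, 32)   # e.g. 1/32 resolution feature
    print(ffm(higher, lower).shape)      # torch.Size([1, 128, 64, 128])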
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class FastSCNN(BaseModule):
|
||||
"""Fast-SCNN Backbone.
|
||||
|
||||
This backbone is the implementation of `Fast-SCNN: Fast Semantic
|
||||
Segmentation Network <https://arxiv.org/abs/1902.04502>`_.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input image channels. Default: 3.
|
||||
downsample_dw_channels (tuple[int]): Number of output channels after
|
||||
the first conv layer & the second conv layer in
|
||||
Learning-To-Downsample (LTD) module.
|
||||
Default: (32, 48).
|
||||
global_in_channels (int): Number of input channels of
|
||||
Global Feature Extractor(GFE).
|
||||
Equal to number of output channels of LTD.
|
||||
Default: 64.
|
||||
global_block_channels (tuple[int]): Tuple of integers that describe
|
||||
the output channels for each of the MobileNet-v2 bottleneck
|
||||
residual blocks in GFE.
|
||||
Default: (64, 96, 128).
|
||||
global_block_strides (tuple[int]): Tuple of integers
|
||||
that describe the strides (downsampling factors) for each of the
|
||||
MobileNet-v2 bottleneck residual blocks in GFE.
|
||||
Default: (2, 2, 1).
|
||||
global_out_channels (int): Number of output channels of GFE.
|
||||
Default: 128.
|
||||
higher_in_channels (int): Number of input channels of the higher
|
||||
resolution branch in FFM.
|
||||
Equal to global_in_channels.
|
||||
Default: 64.
|
||||
lower_in_channels (int): Number of input channels of the lower
|
||||
resolution branch in FFM.
|
||||
Equal to global_out_channels.
|
||||
Default: 128.
|
||||
fusion_out_channels (int): Number of output channels of FFM.
|
||||
Default: 128.
|
||||
out_indices (tuple): Tuple of indices of list
|
||||
[higher_res_features, lower_res_features, fusion_output].
|
||||
Often set to (0,1,2) to enable aux. heads.
|
||||
Default: (0, 1, 2).
|
||||
conv_cfg (dict | None): Config of conv layers. Default: None
|
||||
norm_cfg (dict | None): Config of norm layers. Default:
|
||||
dict(type='BN')
|
||||
act_cfg (dict): Config of activation layers. Default:
|
||||
dict(type='ReLU')
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False
|
||||
dw_act_cfg (dict): In DepthwiseSeparableConvModule, activation config
|
||||
of depthwise ConvModule. If it is 'default', it will be the same
|
||||
as `act_cfg`. Default: None.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
downsample_dw_channels=(32, 48),
|
||||
global_in_channels=64,
|
||||
global_block_channels=(64, 96, 128),
|
||||
global_block_strides=(2, 2, 1),
|
||||
global_out_channels=128,
|
||||
higher_in_channels=64,
|
||||
lower_in_channels=128,
|
||||
fusion_out_channels=128,
|
||||
out_indices=(0, 1, 2),
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
align_corners=False,
|
||||
dw_act_cfg=None,
|
||||
init_cfg=None):
|
||||
|
||||
super().__init__(init_cfg)
|
||||
|
||||
if init_cfg is None:
|
||||
self.init_cfg = [
|
||||
dict(type='Kaiming', layer='Conv2d'),
|
||||
dict(
|
||||
type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
|
||||
]
|
||||
|
||||
if global_in_channels != higher_in_channels:
|
||||
raise AssertionError('Global Input Channels must be the same \
|
||||
with Higher Input Channels!')
|
||||
elif global_out_channels != lower_in_channels:
|
||||
raise AssertionError('Global Output Channels must be the same \
|
||||
with Lower Input Channels!')
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.downsample_dw_channels1 = downsample_dw_channels[0]
|
||||
self.downsample_dw_channels2 = downsample_dw_channels[1]
|
||||
self.global_in_channels = global_in_channels
|
||||
self.global_block_channels = global_block_channels
|
||||
self.global_block_strides = global_block_strides
|
||||
self.global_out_channels = global_out_channels
|
||||
self.higher_in_channels = higher_in_channels
|
||||
self.lower_in_channels = lower_in_channels
|
||||
self.fusion_out_channels = fusion_out_channels
|
||||
self.out_indices = out_indices
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.align_corners = align_corners
|
||||
self.learning_to_downsample = LearningToDownsample(
|
||||
in_channels,
|
||||
downsample_dw_channels,
|
||||
global_in_channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
dw_act_cfg=dw_act_cfg)
|
||||
self.global_feature_extractor = GlobalFeatureExtractor(
|
||||
global_in_channels,
|
||||
global_block_channels,
|
||||
global_out_channels,
|
||||
strides=self.global_block_strides,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
align_corners=self.align_corners)
|
||||
self.feature_fusion = FeatureFusionModule(
|
||||
higher_in_channels,
|
||||
lower_in_channels,
|
||||
fusion_out_channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
dwconv_act_cfg=self.act_cfg,
|
||||
align_corners=self.align_corners)
|
||||
|
||||
def forward(self, x):
|
||||
higher_res_features = self.learning_to_downsample(x)
|
||||
lower_res_features = self.global_feature_extractor(higher_res_features)
|
||||
fusion_output = self.feature_fusion(higher_res_features,
|
||||
lower_res_features)
|
||||
|
||||
outs = [higher_res_features, lower_res_features, fusion_output]
|
||||
outs = [outs[i] for i in self.out_indices]
|
||||
return tuple(outs)
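
# Illustrative usage sketch (added for this write-up, not part of the original
# module): with the default out_indices=(0, 1, 2) the backbone returns the
# learning-to-downsample output at 1/8 resolution, the global feature at 1/32
# resolution and the fused feature back at 1/8 resolution. The 512x1024 input
# size is an arbitrary assumption.
if __name__ == '__main__':
    model = FastSCNN().eval()
    with torch.no_grad():
        outs = model(torch.rand(1, 3, 512, 1024))
    for out in outs:
        print(tuple(out.shape))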
|
||||
642
finetune/mmseg/models/backbones/hrnet.py
Normal file
@@ -0,0 +1,642 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import build_conv_layer, build_norm_layer
|
||||
from mmengine.model import BaseModule, ModuleList, Sequential
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import Upsample, resize
|
||||
from .resnet import BasicBlock, Bottleneck
|
||||
|
||||
|
||||
class HRModule(BaseModule):
|
||||
"""High-Resolution Module for HRNet.
|
||||
|
||||
In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange
|
||||
is in this module.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_branches,
|
||||
blocks,
|
||||
num_blocks,
|
||||
in_channels,
|
||||
num_channels,
|
||||
multiscale_output=True,
|
||||
with_cp=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
block_init_cfg=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
self.block_init_cfg = block_init_cfg
|
||||
self._check_branches(num_branches, num_blocks, in_channels,
|
||||
num_channels)
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.num_branches = num_branches
|
||||
|
||||
self.multiscale_output = multiscale_output
|
||||
self.norm_cfg = norm_cfg
|
||||
self.conv_cfg = conv_cfg
|
||||
self.with_cp = with_cp
|
||||
self.branches = self._make_branches(num_branches, blocks, num_blocks,
|
||||
num_channels)
|
||||
self.fuse_layers = self._make_fuse_layers()
|
||||
self.relu = nn.ReLU(inplace=False)
|
||||
|
||||
def _check_branches(self, num_branches, num_blocks, in_channels,
|
||||
num_channels):
|
||||
"""Check branches configuration."""
|
||||
if num_branches != len(num_blocks):
|
||||
error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_BLOCKS(' \
|
||||
f'{len(num_blocks)})'
|
||||
raise ValueError(error_msg)
|
||||
|
||||
if num_branches != len(num_channels):
|
||||
error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_CHANNELS(' \
|
||||
f'{len(num_channels)})'
|
||||
raise ValueError(error_msg)
|
||||
|
||||
if num_branches != len(in_channels):
|
||||
error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_INCHANNELS(' \
|
||||
f'{len(in_channels)})'
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def _make_one_branch(self,
|
||||
branch_index,
|
||||
block,
|
||||
num_blocks,
|
||||
num_channels,
|
||||
stride=1):
|
||||
"""Build one branch."""
|
||||
downsample = None
|
||||
if stride != 1 or \
|
||||
self.in_channels[branch_index] != \
|
||||
num_channels[branch_index] * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
build_conv_layer(
|
||||
self.conv_cfg,
|
||||
self.in_channels[branch_index],
|
||||
num_channels[branch_index] * block.expansion,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False),
|
||||
build_norm_layer(self.norm_cfg, num_channels[branch_index] *
|
||||
block.expansion)[1])
|
||||
|
||||
layers = []
|
||||
layers.append(
|
||||
block(
|
||||
self.in_channels[branch_index],
|
||||
num_channels[branch_index],
|
||||
stride,
|
||||
downsample=downsample,
|
||||
with_cp=self.with_cp,
|
||||
norm_cfg=self.norm_cfg,
|
||||
conv_cfg=self.conv_cfg,
|
||||
init_cfg=self.block_init_cfg))
|
||||
self.in_channels[branch_index] = \
|
||||
num_channels[branch_index] * block.expansion
|
||||
for i in range(1, num_blocks[branch_index]):
|
||||
layers.append(
|
||||
block(
|
||||
self.in_channels[branch_index],
|
||||
num_channels[branch_index],
|
||||
with_cp=self.with_cp,
|
||||
norm_cfg=self.norm_cfg,
|
||||
conv_cfg=self.conv_cfg,
|
||||
init_cfg=self.block_init_cfg))
|
||||
|
||||
return Sequential(*layers)
|
||||
|
||||
def _make_branches(self, num_branches, block, num_blocks, num_channels):
|
||||
"""Build multiple branch."""
|
||||
branches = []
|
||||
|
||||
for i in range(num_branches):
|
||||
branches.append(
|
||||
self._make_one_branch(i, block, num_blocks, num_channels))
|
||||
|
||||
return ModuleList(branches)
|
||||
|
||||
def _make_fuse_layers(self):
|
||||
"""Build fuse layer."""
|
||||
if self.num_branches == 1:
|
||||
return None
|
||||
|
||||
num_branches = self.num_branches
|
||||
in_channels = self.in_channels
|
||||
fuse_layers = []
|
||||
num_out_branches = num_branches if self.multiscale_output else 1
|
||||
for i in range(num_out_branches):
|
||||
fuse_layer = []
|
||||
for j in range(num_branches):
|
||||
if j > i:
|
||||
fuse_layer.append(
|
||||
nn.Sequential(
|
||||
build_conv_layer(
|
||||
self.conv_cfg,
|
||||
in_channels[j],
|
||||
in_channels[i],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=False),
|
||||
build_norm_layer(self.norm_cfg, in_channels[i])[1],
|
||||
# we set align_corners=False for HRNet
|
||||
Upsample(
|
||||
scale_factor=2**(j - i),
|
||||
mode='bilinear',
|
||||
align_corners=False)))
|
||||
elif j == i:
|
||||
fuse_layer.append(None)
|
||||
else:
|
||||
conv_downsamples = []
|
||||
for k in range(i - j):
|
||||
if k == i - j - 1:
|
||||
conv_downsamples.append(
|
||||
nn.Sequential(
|
||||
build_conv_layer(
|
||||
self.conv_cfg,
|
||||
in_channels[j],
|
||||
in_channels[i],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
bias=False),
|
||||
build_norm_layer(self.norm_cfg,
|
||||
in_channels[i])[1]))
|
||||
else:
|
||||
conv_downsamples.append(
|
||||
nn.Sequential(
|
||||
build_conv_layer(
|
||||
self.conv_cfg,
|
||||
in_channels[j],
|
||||
in_channels[j],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
bias=False),
|
||||
build_norm_layer(self.norm_cfg,
|
||||
in_channels[j])[1],
|
||||
nn.ReLU(inplace=False)))
|
||||
fuse_layer.append(nn.Sequential(*conv_downsamples))
|
||||
fuse_layers.append(nn.ModuleList(fuse_layer))
|
||||
|
||||
return nn.ModuleList(fuse_layers)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
if self.num_branches == 1:
|
||||
return [self.branches[0](x[0])]
|
||||
|
||||
for i in range(self.num_branches):
|
||||
x[i] = self.branches[i](x[i])
|
||||
|
||||
x_fuse = []
|
||||
for i in range(len(self.fuse_layers)):
|
||||
y = 0
|
||||
for j in range(self.num_branches):
|
||||
if i == j:
|
||||
y += x[j]
|
||||
elif j > i:
|
||||
y = y + resize(
|
||||
self.fuse_layers[i][j](x[j]),
|
||||
size=x[i].shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=False)
|
||||
else:
|
||||
y += self.fuse_layers[i][j](x[j])
|
||||
x_fuse.append(self.relu(y))
|
||||
return x_fuse
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class HRNet(BaseModule):
|
||||
"""HRNet backbone.
|
||||
|
||||
This backbone is the implementation of `High-Resolution Representations
|
||||
for Labeling Pixels and Regions <https://arxiv.org/abs/1904.04514>`_.
|
||||
|
||||
Args:
|
||||
extra (dict): Detailed configuration for each stage of HRNet.
|
||||
There must be 4 stages, the configuration for each stage must have
|
||||
5 keys:
|
||||
|
||||
- num_modules (int): The number of HRModule in this stage.
|
||||
- num_branches (int): The number of branches in the HRModule.
|
||||
- block (str): The type of convolution block.
|
||||
- num_blocks (tuple): The number of blocks in each branch.
|
||||
The length must be equal to num_branches.
|
||||
- num_channels (tuple): The number of channels in each branch.
|
||||
The length must be equal to num_branches.
|
||||
in_channels (int): Number of input image channels. Normally 3.
|
||||
conv_cfg (dict): Dictionary to construct and config conv layer.
|
||||
Default: None.
|
||||
norm_cfg (dict): Dictionary to construct and config norm layer.
|
||||
Use `BN` by default.
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Default: False.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
||||
-1 means not freezing any parameters. Default: -1.
|
||||
zero_init_residual (bool): Whether to use zero init for last norm layer
|
||||
in resblocks to let them behave as identity. Default: False.
|
||||
multiscale_output (bool): Whether to output multi-level features
|
||||
produced by multiple branches. If False, only the first level
|
||||
feature will be output. Default: True.
|
||||
pretrained (str, optional): Model pretrained path. Default: None.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
|
||||
Example:
|
||||
>>> from mmseg.models import HRNet
|
||||
>>> import torch
|
||||
>>> extra = dict(
|
||||
>>> stage1=dict(
|
||||
>>> num_modules=1,
|
||||
>>> num_branches=1,
|
||||
>>> block='BOTTLENECK',
|
||||
>>> num_blocks=(4, ),
|
||||
>>> num_channels=(64, )),
|
||||
>>> stage2=dict(
|
||||
>>> num_modules=1,
|
||||
>>> num_branches=2,
|
||||
>>> block='BASIC',
|
||||
>>> num_blocks=(4, 4),
|
||||
>>> num_channels=(32, 64)),
|
||||
>>> stage3=dict(
|
||||
>>> num_modules=4,
|
||||
>>> num_branches=3,
|
||||
>>> block='BASIC',
|
||||
>>> num_blocks=(4, 4, 4),
|
||||
>>> num_channels=(32, 64, 128)),
|
||||
>>> stage4=dict(
|
||||
>>> num_modules=3,
|
||||
>>> num_branches=4,
|
||||
>>> block='BASIC',
|
||||
>>> num_blocks=(4, 4, 4, 4),
|
||||
>>> num_channels=(32, 64, 128, 256)))
|
||||
>>> self = HRNet(extra, in_channels=1)
|
||||
>>> self.eval()
|
||||
>>> inputs = torch.rand(1, 1, 32, 32)
|
||||
>>> level_outputs = self.forward(inputs)
|
||||
>>> for level_out in level_outputs:
|
||||
... print(tuple(level_out.shape))
|
||||
(1, 32, 8, 8)
|
||||
(1, 64, 4, 4)
|
||||
(1, 128, 2, 2)
|
||||
(1, 256, 1, 1)
|
||||
"""
|
||||
|
||||
blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
|
||||
|
||||
def __init__(self,
|
||||
extra,
|
||||
in_channels=3,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
norm_eval=False,
|
||||
with_cp=False,
|
||||
frozen_stages=-1,
|
||||
zero_init_residual=False,
|
||||
multiscale_output=True,
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
self.pretrained = pretrained
|
||||
self.zero_init_residual = zero_init_residual
|
||||
assert not (init_cfg and pretrained), \
|
||||
            'init_cfg and pretrained cannot be set at the same time'
|
||||
if isinstance(pretrained, str):
|
||||
warnings.warn('DeprecationWarning: pretrained is deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
elif pretrained is None:
|
||||
if init_cfg is None:
|
||||
self.init_cfg = [
|
||||
dict(type='Kaiming', layer='Conv2d'),
|
||||
dict(
|
||||
type='Constant',
|
||||
val=1,
|
||||
layer=['_BatchNorm', 'GroupNorm'])
|
||||
]
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
# Assert configurations of 4 stages are in extra
|
||||
assert 'stage1' in extra and 'stage2' in extra \
|
||||
and 'stage3' in extra and 'stage4' in extra
|
||||
# Assert whether the length of `num_blocks` and `num_channels` are
|
||||
# equal to `num_branches`
|
||||
for i in range(4):
|
||||
cfg = extra[f'stage{i + 1}']
|
||||
assert len(cfg['num_blocks']) == cfg['num_branches'] and \
|
||||
len(cfg['num_channels']) == cfg['num_branches']
|
||||
|
||||
self.extra = extra
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.norm_eval = norm_eval
|
||||
self.with_cp = with_cp
|
||||
self.frozen_stages = frozen_stages
|
||||
|
||||
# stem net
|
||||
self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
|
||||
self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)
|
||||
|
||||
self.conv1 = build_conv_layer(
|
||||
self.conv_cfg,
|
||||
in_channels,
|
||||
64,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
bias=False)
|
||||
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
self.conv2 = build_conv_layer(
|
||||
self.conv_cfg,
|
||||
64,
|
||||
64,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
bias=False)
|
||||
|
||||
self.add_module(self.norm2_name, norm2)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
# stage 1
|
||||
self.stage1_cfg = self.extra['stage1']
|
||||
num_channels = self.stage1_cfg['num_channels'][0]
|
||||
block_type = self.stage1_cfg['block']
|
||||
num_blocks = self.stage1_cfg['num_blocks'][0]
|
||||
|
||||
block = self.blocks_dict[block_type]
|
||||
stage1_out_channels = num_channels * block.expansion
|
||||
self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
|
||||
|
||||
# stage 2
|
||||
self.stage2_cfg = self.extra['stage2']
|
||||
num_channels = self.stage2_cfg['num_channels']
|
||||
block_type = self.stage2_cfg['block']
|
||||
|
||||
block = self.blocks_dict[block_type]
|
||||
num_channels = [channel * block.expansion for channel in num_channels]
|
||||
self.transition1 = self._make_transition_layer([stage1_out_channels],
|
||||
num_channels)
|
||||
self.stage2, pre_stage_channels = self._make_stage(
|
||||
self.stage2_cfg, num_channels)
|
||||
|
||||
# stage 3
|
||||
self.stage3_cfg = self.extra['stage3']
|
||||
num_channels = self.stage3_cfg['num_channels']
|
||||
block_type = self.stage3_cfg['block']
|
||||
|
||||
block = self.blocks_dict[block_type]
|
||||
num_channels = [channel * block.expansion for channel in num_channels]
|
||||
self.transition2 = self._make_transition_layer(pre_stage_channels,
|
||||
num_channels)
|
||||
self.stage3, pre_stage_channels = self._make_stage(
|
||||
self.stage3_cfg, num_channels)
|
||||
|
||||
# stage 4
|
||||
self.stage4_cfg = self.extra['stage4']
|
||||
num_channels = self.stage4_cfg['num_channels']
|
||||
block_type = self.stage4_cfg['block']
|
||||
|
||||
block = self.blocks_dict[block_type]
|
||||
num_channels = [channel * block.expansion for channel in num_channels]
|
||||
self.transition3 = self._make_transition_layer(pre_stage_channels,
|
||||
num_channels)
|
||||
self.stage4, pre_stage_channels = self._make_stage(
|
||||
self.stage4_cfg, num_channels, multiscale_output=multiscale_output)
|
||||
|
||||
self._freeze_stages()
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
"""nn.Module: the normalization layer named "norm1" """
|
||||
return getattr(self, self.norm1_name)
|
||||
|
||||
@property
|
||||
def norm2(self):
|
||||
"""nn.Module: the normalization layer named "norm2" """
|
||||
return getattr(self, self.norm2_name)
|
||||
|
||||
def _make_transition_layer(self, num_channels_pre_layer,
|
||||
num_channels_cur_layer):
|
||||
"""Make transition layer."""
|
||||
num_branches_cur = len(num_channels_cur_layer)
|
||||
num_branches_pre = len(num_channels_pre_layer)
|
||||
|
||||
transition_layers = []
|
||||
for i in range(num_branches_cur):
|
||||
if i < num_branches_pre:
|
||||
if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
|
||||
transition_layers.append(
|
||||
nn.Sequential(
|
||||
build_conv_layer(
|
||||
self.conv_cfg,
|
||||
num_channels_pre_layer[i],
|
||||
num_channels_cur_layer[i],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=False),
|
||||
build_norm_layer(self.norm_cfg,
|
||||
num_channels_cur_layer[i])[1],
|
||||
nn.ReLU(inplace=True)))
|
||||
else:
|
||||
transition_layers.append(None)
|
||||
else:
|
||||
conv_downsamples = []
|
||||
for j in range(i + 1 - num_branches_pre):
|
||||
in_channels = num_channels_pre_layer[-1]
|
||||
out_channels = num_channels_cur_layer[i] \
|
||||
if j == i - num_branches_pre else in_channels
|
||||
conv_downsamples.append(
|
||||
nn.Sequential(
|
||||
build_conv_layer(
|
||||
self.conv_cfg,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
bias=False),
|
||||
build_norm_layer(self.norm_cfg, out_channels)[1],
|
||||
nn.ReLU(inplace=True)))
|
||||
transition_layers.append(nn.Sequential(*conv_downsamples))
|
||||
|
||||
return nn.ModuleList(transition_layers)
|
||||
|
||||
def _make_layer(self, block, inplanes, planes, blocks, stride=1):
|
||||
"""Make each layer."""
|
||||
downsample = None
|
||||
if stride != 1 or inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
build_conv_layer(
|
||||
self.conv_cfg,
|
||||
inplanes,
|
||||
planes * block.expansion,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False),
|
||||
build_norm_layer(self.norm_cfg, planes * block.expansion)[1])
|
||||
|
||||
layers = []
|
||||
block_init_cfg = None
|
||||
if self.pretrained is None and not hasattr(
|
||||
self, 'init_cfg') and self.zero_init_residual:
|
||||
if block is BasicBlock:
|
||||
block_init_cfg = dict(
|
||||
type='Constant', val=0, override=dict(name='norm2'))
|
||||
elif block is Bottleneck:
|
||||
block_init_cfg = dict(
|
||||
type='Constant', val=0, override=dict(name='norm3'))
|
||||
|
||||
layers.append(
|
||||
block(
|
||||
inplanes,
|
||||
planes,
|
||||
stride,
|
||||
downsample=downsample,
|
||||
with_cp=self.with_cp,
|
||||
norm_cfg=self.norm_cfg,
|
||||
conv_cfg=self.conv_cfg,
|
||||
init_cfg=block_init_cfg))
|
||||
inplanes = planes * block.expansion
|
||||
for i in range(1, blocks):
|
||||
layers.append(
|
||||
block(
|
||||
inplanes,
|
||||
planes,
|
||||
with_cp=self.with_cp,
|
||||
norm_cfg=self.norm_cfg,
|
||||
conv_cfg=self.conv_cfg,
|
||||
init_cfg=block_init_cfg))
|
||||
|
||||
return Sequential(*layers)
|
||||
|
||||
def _make_stage(self, layer_config, in_channels, multiscale_output=True):
|
||||
"""Make each stage."""
|
||||
num_modules = layer_config['num_modules']
|
||||
num_branches = layer_config['num_branches']
|
||||
num_blocks = layer_config['num_blocks']
|
||||
num_channels = layer_config['num_channels']
|
||||
block = self.blocks_dict[layer_config['block']]
|
||||
|
||||
hr_modules = []
|
||||
block_init_cfg = None
|
||||
if self.pretrained is None and not hasattr(
|
||||
self, 'init_cfg') and self.zero_init_residual:
|
||||
if block is BasicBlock:
|
||||
block_init_cfg = dict(
|
||||
type='Constant', val=0, override=dict(name='norm2'))
|
||||
elif block is Bottleneck:
|
||||
block_init_cfg = dict(
|
||||
type='Constant', val=0, override=dict(name='norm3'))
|
||||
|
||||
for i in range(num_modules):
|
||||
# multi_scale_output is only used for the last module
|
||||
if not multiscale_output and i == num_modules - 1:
|
||||
reset_multiscale_output = False
|
||||
else:
|
||||
reset_multiscale_output = True
|
||||
|
||||
hr_modules.append(
|
||||
HRModule(
|
||||
num_branches,
|
||||
block,
|
||||
num_blocks,
|
||||
in_channels,
|
||||
num_channels,
|
||||
reset_multiscale_output,
|
||||
with_cp=self.with_cp,
|
||||
norm_cfg=self.norm_cfg,
|
||||
conv_cfg=self.conv_cfg,
|
||||
block_init_cfg=block_init_cfg))
|
||||
|
||||
return Sequential(*hr_modules), in_channels
|
||||
|
||||
def _freeze_stages(self):
|
||||
"""Freeze stages param and norm stats."""
|
||||
if self.frozen_stages >= 0:
|
||||
|
||||
self.norm1.eval()
|
||||
self.norm2.eval()
|
||||
for m in [self.conv1, self.norm1, self.conv2, self.norm2]:
|
||||
for param in m.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
for i in range(1, self.frozen_stages + 1):
|
||||
if i == 1:
|
||||
m = getattr(self, f'layer{i}')
|
||||
t = getattr(self, f'transition{i}')
|
||||
elif i == 4:
|
||||
m = getattr(self, f'stage{i}')
|
||||
else:
|
||||
m = getattr(self, f'stage{i}')
|
||||
t = getattr(self, f'transition{i}')
|
||||
m.eval()
|
||||
for param in m.parameters():
|
||||
param.requires_grad = False
|
||||
t.eval()
|
||||
for param in t.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv2(x)
|
||||
x = self.norm2(x)
|
||||
x = self.relu(x)
|
||||
x = self.layer1(x)
|
||||
|
||||
x_list = []
|
||||
for i in range(self.stage2_cfg['num_branches']):
|
||||
if self.transition1[i] is not None:
|
||||
x_list.append(self.transition1[i](x))
|
||||
else:
|
||||
x_list.append(x)
|
||||
y_list = self.stage2(x_list)
|
||||
|
||||
x_list = []
|
||||
for i in range(self.stage3_cfg['num_branches']):
|
||||
if self.transition2[i] is not None:
|
||||
x_list.append(self.transition2[i](y_list[-1]))
|
||||
else:
|
||||
x_list.append(y_list[i])
|
||||
y_list = self.stage3(x_list)
|
||||
|
||||
x_list = []
|
||||
for i in range(self.stage4_cfg['num_branches']):
|
||||
if self.transition3[i] is not None:
|
||||
x_list.append(self.transition3[i](y_list[-1]))
|
||||
else:
|
||||
x_list.append(y_list[i])
|
||||
y_list = self.stage4(x_list)
|
||||
|
||||
return y_list
|
||||
|
||||
def train(self, mode=True):
|
||||
"""Convert the model into training mode will keeping the normalization
|
||||
layer freezed."""
|
||||
super().train(mode)
|
||||
self._freeze_stages()
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
# trick: eval have effect on BatchNorm only
|
||||
if isinstance(m, _BatchNorm):
|
||||
m.eval()
|
||||
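A hedged configuration sketch (assumed, not taken from this diff): the channel numbers below follow the common HRNet-W18 setting; `frozen_stages=1` freezes the stem plus layer1/transition1 via `_freeze_stages()`, and `norm_eval=True` keeps BatchNorm running stats fixed during training through the `train()` override above.

    hrnet_w18_cfg = dict(
        type='HRNet',
        norm_eval=True,
        frozen_stages=1,
        extra=dict(
            stage1=dict(num_modules=1, num_branches=1, block='BOTTLENECK',
                        num_blocks=(4, ), num_channels=(64, )),
            stage2=dict(num_modules=1, num_branches=2, block='BASIC',
                        num_blocks=(4, 4), num_channels=(18, 36)),
            stage3=dict(num_modules=4, num_branches=3, block='BASIC',
                        num_blocks=(4, 4, 4), num_channels=(18, 36, 72)),
            stage4=dict(num_modules=3, num_branches=4, block='BASIC',
                        num_blocks=(4, 4, 4, 4),
                        num_channels=(18, 36, 72, 144))))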
166
finetune/mmseg/models/backbones/icnet.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..decode_heads.psp_head import PPM
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ICNet(BaseModule):
|
||||
"""ICNet for Real-Time Semantic Segmentation on High-Resolution Images.
|
||||
|
||||
This backbone is the implementation of
|
||||
`ICNet <https://arxiv.org/abs/1704.08545>`_.
|
||||
|
||||
Args:
|
||||
backbone_cfg (dict): Config dict to build backbone. Usually it is
|
||||
ResNet but it can also be other backbones.
|
||||
in_channels (int): The number of input image channels. Default: 3.
|
||||
layer_channels (Sequence[int]): The numbers of feature channels at
|
||||
layer 2 and layer 4 in ResNet. It can also be other backbones.
|
||||
Default: (512, 2048).
|
||||
light_branch_middle_channels (int): The number of channels of the
|
||||
middle layer in light branch. Default: 32.
|
||||
psp_out_channels (int): The number of channels of the output of PSP
|
||||
module. Default: 512.
|
||||
out_channels (Sequence[int]): The numbers of output feature channels
|
||||
at each branches. Default: (64, 256, 256).
|
||||
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
||||
Module. Default: (1, 2, 3, 6).
|
||||
conv_cfg (dict): Dictionary to construct and config conv layer.
|
||||
Default: None.
|
||||
norm_cfg (dict): Dictionary to construct and config norm layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Dictionary to construct and config act layer.
|
||||
Default: dict(type='ReLU').
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
backbone_cfg,
|
||||
in_channels=3,
|
||||
layer_channels=(512, 2048),
|
||||
light_branch_middle_channels=32,
|
||||
psp_out_channels=512,
|
||||
out_channels=(64, 256, 256),
|
||||
pool_scales=(1, 2, 3, 6),
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
align_corners=False,
|
||||
init_cfg=None):
|
||||
if backbone_cfg is None:
|
||||
raise TypeError('backbone_cfg must be passed from config file!')
|
||||
if init_cfg is None:
|
||||
init_cfg = [
|
||||
dict(type='Kaiming', mode='fan_out', layer='Conv2d'),
|
||||
dict(type='Constant', val=1, layer='_BatchNorm'),
|
||||
dict(type='Normal', mean=0.01, layer='Linear')
|
||||
]
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.align_corners = align_corners
|
||||
self.backbone = MODELS.build(backbone_cfg)
|
||||
|
||||
# Note: Default `ceil_mode` is false in nn.MaxPool2d, set
|
||||
# `ceil_mode=True` to keep information in the corner of feature map.
|
||||
self.backbone.maxpool = nn.MaxPool2d(
|
||||
kernel_size=3, stride=2, padding=1, ceil_mode=True)
|
||||
|
||||
self.psp_modules = PPM(
|
||||
pool_scales=pool_scales,
|
||||
in_channels=layer_channels[1],
|
||||
channels=psp_out_channels,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
align_corners=align_corners)
|
||||
|
||||
self.psp_bottleneck = ConvModule(
|
||||
layer_channels[1] + len(pool_scales) * psp_out_channels,
|
||||
psp_out_channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
self.conv_sub1 = nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=light_branch_middle_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg),
|
||||
ConvModule(
|
||||
in_channels=light_branch_middle_channels,
|
||||
out_channels=light_branch_middle_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg),
|
||||
ConvModule(
|
||||
in_channels=light_branch_middle_channels,
|
||||
out_channels=out_channels[0],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg))
|
||||
|
||||
self.conv_sub2 = ConvModule(
|
||||
layer_channels[0],
|
||||
out_channels[1],
|
||||
1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg)
|
||||
|
||||
self.conv_sub4 = ConvModule(
|
||||
psp_out_channels,
|
||||
out_channels[2],
|
||||
1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
output = []
|
||||
|
||||
# sub 1
|
||||
output.append(self.conv_sub1(x))
|
||||
|
||||
# sub 2
|
||||
x = resize(
|
||||
x,
|
||||
scale_factor=0.5,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
x = self.backbone.stem(x)
|
||||
x = self.backbone.maxpool(x)
|
||||
x = self.backbone.layer1(x)
|
||||
x = self.backbone.layer2(x)
|
||||
output.append(self.conv_sub2(x))
|
||||
|
||||
# sub 4
|
||||
x = resize(
|
||||
x,
|
||||
scale_factor=0.5,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
x = self.backbone.layer3(x)
|
||||
x = self.backbone.layer4(x)
|
||||
psp_outs = self.psp_modules(x) + [x]
|
||||
psp_outs = torch.cat(psp_outs, dim=1)
|
||||
x = self.psp_bottleneck(psp_outs)
|
||||
|
||||
output.append(self.conv_sub4(x))
|
||||
|
||||
return output
|
||||
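A hedged configuration sketch (assumed, not taken from this diff): ICNet wraps another backbone via `backbone_cfg`; the ResNet-50 settings below mirror a common dilated configuration, and `layer_channels=(512, 2048)` matches the channel widths of ResNet layer2 and layer4 consumed by `conv_sub2` and the PSP branch.

    icnet_cfg = dict(
        type='ICNet',
        backbone_cfg=dict(
            type='ResNetV1c',
            depth=50,
            num_stages=4,
            out_indices=(0, 1, 2, 3),
            dilations=(1, 1, 2, 4),
            strides=(1, 2, 1, 1),
            norm_cfg=dict(type='BN', requires_grad=True)),
        layer_channels=(512, 2048),
        out_channels=(64, 256, 256))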
260
finetune/mmseg/models/backbones/mae.py
Normal file
@@ -0,0 +1,260 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.model import ModuleList
|
||||
from mmengine.model.weight_init import (constant_init, kaiming_init,
|
||||
trunc_normal_)
|
||||
from mmengine.runner.checkpoint import _load_checkpoint
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .beit import BEiT, BEiTAttention, BEiTTransformerEncoderLayer
|
||||
|
||||
|
||||
class MAEAttention(BEiTAttention):
|
||||
"""Multi-head self-attention with relative position bias used in MAE.
|
||||
|
||||
This module is different from ``BEiTAttention`` by initializing the
|
||||
relative bias table with zeros.
|
||||
"""
|
||||
|
||||
def init_weights(self):
|
||||
"""Initialize relative position bias with zeros."""
|
||||
|
||||
        # MAE initializes the relative position bias table with zeros,
        # whereas the parent BEiT class initializes it with `trunc_normal`.
        # `init_weights` therefore intentionally does nothing here.
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class MAETransformerEncoderLayer(BEiTTransformerEncoderLayer):
|
||||
"""Implements one encoder layer in Vision Transformer.
|
||||
|
||||
This module is different from ``BEiTTransformerEncoderLayer`` by replacing
|
||||
``BEiTAttention`` with ``MAEAttention``.
|
||||
"""
|
||||
|
||||
def build_attn(self, attn_cfg):
|
||||
self.attn = MAEAttention(**attn_cfg)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MAE(BEiT):
|
||||
"""VisionTransformer with support for patch.
|
||||
|
||||
Args:
|
||||
img_size (int | tuple): Input image size. Default: 224.
|
||||
patch_size (int): The patch size. Default: 16.
|
||||
in_channels (int): Number of input channels. Default: 3.
|
||||
embed_dims (int): embedding dimension. Default: 768.
|
||||
num_layers (int): depth of transformer. Default: 12.
|
||||
num_heads (int): number of attention heads. Default: 12.
|
||||
mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
|
||||
Default: 4.
|
||||
out_indices (list | tuple | int): Output from which stages.
|
||||
Default: -1.
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Default 0.0
|
||||
drop_path_rate (float): stochastic depth rate. Default 0.0.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN')
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Default: dict(type='GELU').
|
||||
patch_norm (bool): Whether to add a norm in PatchEmbed Block.
|
||||
Default: False.
|
||||
        final_norm (bool): Whether to add an additional layer to normalize
            the final feature map. Default: False.
|
||||
num_fcs (int): The number of fully-connected layers for FFNs.
|
||||
Default: 2.
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Default: False.
|
||||
pretrained (str, optional): model pretrained path. Default: None.
|
||||
init_values (float): Initialize the values of Attention and FFN
|
||||
with learnable scaling. Defaults to 0.1.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_channels=3,
|
||||
embed_dims=768,
|
||||
num_layers=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4,
|
||||
out_indices=-1,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
norm_cfg=dict(type='LN'),
|
||||
act_cfg=dict(type='GELU'),
|
||||
patch_norm=False,
|
||||
final_norm=False,
|
||||
num_fcs=2,
|
||||
norm_eval=False,
|
||||
pretrained=None,
|
||||
init_values=0.1,
|
||||
init_cfg=None):
|
||||
super().__init__(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dims=embed_dims,
|
||||
num_layers=num_layers,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
out_indices=out_indices,
|
||||
qv_bias=False,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_path_rate=drop_path_rate,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
patch_norm=patch_norm,
|
||||
final_norm=final_norm,
|
||||
num_fcs=num_fcs,
|
||||
norm_eval=norm_eval,
|
||||
pretrained=pretrained,
|
||||
init_values=init_values,
|
||||
init_cfg=init_cfg)
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
|
||||
|
||||
self.num_patches = self.patch_shape[0] * self.patch_shape[1]
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros(1, self.num_patches + 1, embed_dims))
|
||||
|
||||
def _build_layers(self):
|
||||
dpr = [
|
||||
x.item()
|
||||
for x in torch.linspace(0, self.drop_path_rate, self.num_layers)
|
||||
]
|
||||
self.layers = ModuleList()
|
||||
for i in range(self.num_layers):
|
||||
self.layers.append(
|
||||
MAETransformerEncoderLayer(
|
||||
embed_dims=self.embed_dims,
|
||||
num_heads=self.num_heads,
|
||||
feedforward_channels=self.mlp_ratio * self.embed_dims,
|
||||
attn_drop_rate=self.attn_drop_rate,
|
||||
drop_path_rate=dpr[i],
|
||||
num_fcs=self.num_fcs,
|
||||
bias=True,
|
||||
act_cfg=self.act_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
window_size=self.patch_shape,
|
||||
init_values=self.init_values))
|
||||
|
||||
def fix_init_weight(self):
|
||||
"""Rescale the initialization according to layer id.
|
||||
|
||||
This function is copied from https://github.com/microsoft/unilm/blob/master/beit/modeling_pretrain.py. # noqa: E501
|
||||
Copyright (c) Microsoft Corporation
|
||||
Licensed under the MIT License
|
||||
"""
|
||||
|
||||
def rescale(param, layer_id):
|
||||
param.div_(math.sqrt(2.0 * layer_id))
|
||||
|
||||
for layer_id, layer in enumerate(self.layers):
|
||||
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
||||
rescale(layer.ffn.layers[1].weight.data, layer_id + 1)
|
||||
|
||||
def init_weights(self):
|
||||
|
||||
def _init_weights(m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
self.apply(_init_weights)
|
||||
self.fix_init_weight()
|
||||
|
||||
if (isinstance(self.init_cfg, dict)
|
||||
and self.init_cfg.get('type') == 'Pretrained'):
|
||||
checkpoint = _load_checkpoint(
|
||||
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
|
||||
state_dict = self.resize_rel_pos_embed(checkpoint)
|
||||
state_dict = self.resize_abs_pos_embed(state_dict)
|
||||
self.load_state_dict(state_dict, False)
|
||||
elif self.init_cfg is not None:
|
||||
super().init_weights()
|
||||
else:
|
||||
            # We follow the 'jax_impl' initialization implemented at
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
|
||||
# Copyright 2019 Ross Wightman
|
||||
# Licensed under the Apache License, Version 2.0 (the "License")
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
for n, m in self.named_modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if m.bias is not None:
|
||||
if 'ffn' in n:
|
||||
nn.init.normal_(m.bias, mean=0., std=1e-6)
|
||||
else:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
kaiming_init(m, mode='fan_in', bias=0.)
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
|
||||
constant_init(m, val=1.0, bias=0.)
|
||||
|
||||
def resize_abs_pos_embed(self, state_dict):
|
||||
if 'pos_embed' in state_dict:
|
||||
pos_embed_checkpoint = state_dict['pos_embed']
|
||||
embedding_size = pos_embed_checkpoint.shape[-1]
|
||||
num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches
|
||||
# height (== width) for the checkpoint position embedding
|
||||
orig_size = int(
|
||||
(pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5)
|
||||
# height (== width) for the new position embedding
|
||||
new_size = int(self.num_patches**0.5)
|
||||
# class_token and dist_token are kept unchanged
|
||||
if orig_size != new_size:
|
||||
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
||||
# only the position tokens are interpolated
|
||||
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
||||
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size,
|
||||
embedding_size).permute(
|
||||
0, 3, 1, 2)
|
||||
pos_tokens = torch.nn.functional.interpolate(
|
||||
pos_tokens,
|
||||
size=(new_size, new_size),
|
||||
mode='bicubic',
|
||||
align_corners=False)
|
||||
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
||||
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
||||
state_dict['pos_embed'] = new_pos_embed
|
||||
return state_dict
|
||||
|
||||
def forward(self, inputs):
|
||||
B = inputs.shape[0]
|
||||
|
||||
x, hw_shape = self.patch_embed(inputs)
|
||||
|
||||
# stole cls_tokens impl from Phil Wang, thanks
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x = x + self.pos_embed
|
||||
|
||||
outs = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
if i == len(self.layers) - 1:
|
||||
if self.final_norm:
|
||||
x = self.norm1(x)
|
||||
if i in self.out_indices:
|
||||
out = x[:, 1:]
|
||||
B, _, C = out.shape
|
||||
out = out.reshape(B, hw_shape[0], hw_shape[1],
|
||||
C).permute(0, 3, 1, 2).contiguous()
|
||||
outs.append(out)
|
||||
|
||||
return tuple(outs)
|
||||
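A hedged usage sketch (assumed, not taken from this diff): the MAE backbone is typically built with a `Pretrained` init_cfg so that `init_weights()` loads an MAE checkpoint and resizes its relative and absolute position embeddings; the checkpoint path below is a placeholder and the other values follow common MAE fine-tuning configs.

    mae_cfg = dict(
        type='MAE',
        img_size=(512, 512),
        patch_size=16,
        num_layers=12,
        num_heads=12,
        out_indices=(3, 5, 7, 11),
        drop_path_rate=0.1,
        init_values=1.0,
        init_cfg=dict(type='Pretrained',
                      checkpoint='path/to/mae_pretrain_vit.pth'))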
450
finetune/mmseg/models/backbones/mit.py
Normal file
@@ -0,0 +1,450 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
|
||||
from mmcv.cnn.bricks.drop import build_dropout
|
||||
from mmcv.cnn.bricks.transformer import MultiheadAttention
|
||||
from mmengine.model import BaseModule, ModuleList, Sequential
|
||||
from mmengine.model.weight_init import (constant_init, normal_init,
|
||||
trunc_normal_init)
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw
|
||||
|
||||
|
||||
class MixFFN(BaseModule):
|
||||
"""An implementation of MixFFN of Segformer.
|
||||
|
||||
The differences between MixFFN & FFN:
|
||||
1. Use 1X1 Conv to replace Linear layer.
|
||||
2. Introduce 3X3 Conv to encode positional information.
|
||||
Args:
|
||||
embed_dims (int): The feature dimension. Same as
|
||||
`MultiheadAttention`. Defaults: 256.
|
||||
feedforward_channels (int): The hidden dimension of FFNs.
|
||||
Defaults: 1024.
|
||||
act_cfg (dict, optional): The activation config for FFNs.
|
||||
Default: dict(type='ReLU')
|
||||
ffn_drop (float, optional): Probability of an element to be
|
||||
zeroed in FFN. Default 0.0.
|
||||
dropout_layer (obj:`ConfigDict`): The dropout_layer used
|
||||
when adding the shortcut.
|
||||
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
feedforward_channels,
|
||||
act_cfg=dict(type='GELU'),
|
||||
ffn_drop=0.,
|
||||
dropout_layer=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
self.embed_dims = embed_dims
|
||||
self.feedforward_channels = feedforward_channels
|
||||
self.act_cfg = act_cfg
|
||||
self.activate = build_activation_layer(act_cfg)
|
||||
|
||||
in_channels = embed_dims
|
||||
fc1 = Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=feedforward_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
bias=True)
|
||||
# 3x3 depth wise conv to provide positional encode information
|
||||
pe_conv = Conv2d(
|
||||
in_channels=feedforward_channels,
|
||||
out_channels=feedforward_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=(3 - 1) // 2,
|
||||
bias=True,
|
||||
groups=feedforward_channels)
|
||||
fc2 = Conv2d(
|
||||
in_channels=feedforward_channels,
|
||||
out_channels=in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
bias=True)
|
||||
drop = nn.Dropout(ffn_drop)
|
||||
layers = [fc1, pe_conv, self.activate, drop, fc2, drop]
|
||||
self.layers = Sequential(*layers)
|
||||
self.dropout_layer = build_dropout(
|
||||
dropout_layer) if dropout_layer else torch.nn.Identity()
|
||||
|
||||
def forward(self, x, hw_shape, identity=None):
|
||||
out = nlc_to_nchw(x, hw_shape)
|
||||
out = self.layers(out)
|
||||
out = nchw_to_nlc(out)
|
||||
if identity is None:
|
||||
identity = x
|
||||
return identity + self.dropout_layer(out)
|
||||
|
||||
|
||||
class EfficientMultiheadAttention(MultiheadAttention):
|
||||
"""An implementation of Efficient Multi-head Attention of Segformer.
|
||||
|
||||
This module is modified from MultiheadAttention which is a module from
|
||||
mmcv.cnn.bricks.transformer.
|
||||
Args:
|
||||
embed_dims (int): The embedding dimension.
|
||||
num_heads (int): Parallel attention heads.
|
||||
attn_drop (float): A Dropout layer on attn_output_weights.
|
||||
Default: 0.0.
|
||||
proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
|
||||
Default: 0.0.
|
||||
dropout_layer (obj:`ConfigDict`): The dropout_layer used
|
||||
when adding the shortcut. Default: None.
|
||||
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
|
||||
Default: None.
|
||||
batch_first (bool): Key, Query and Value are shape of
|
||||
(batch, n, embed_dim)
|
||||
or (n, batch, embed_dim). Default: False.
|
||||
qkv_bias (bool): enable bias for qkv if True. Default True.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN').
|
||||
sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
|
||||
Attention of Segformer. Default: 1.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
attn_drop=0.,
|
||||
proj_drop=0.,
|
||||
dropout_layer=None,
|
||||
init_cfg=None,
|
||||
batch_first=True,
|
||||
qkv_bias=False,
|
||||
norm_cfg=dict(type='LN'),
|
||||
sr_ratio=1):
|
||||
super().__init__(
|
||||
embed_dims,
|
||||
num_heads,
|
||||
attn_drop,
|
||||
proj_drop,
|
||||
dropout_layer=dropout_layer,
|
||||
init_cfg=init_cfg,
|
||||
batch_first=batch_first,
|
||||
bias=qkv_bias)
|
||||
|
||||
self.sr_ratio = sr_ratio
|
||||
if sr_ratio > 1:
|
||||
self.sr = Conv2d(
|
||||
in_channels=embed_dims,
|
||||
out_channels=embed_dims,
|
||||
kernel_size=sr_ratio,
|
||||
stride=sr_ratio)
|
||||
# The ret[0] of build_norm_layer is norm name.
|
||||
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
|
||||
# handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa
|
||||
from mmseg import digit_version, mmcv_version
|
||||
if mmcv_version < digit_version('1.3.17'):
|
||||
            warnings.warn('The legacy version of the forward function in '
                          'EfficientMultiheadAttention is deprecated in '
                          'mmcv>=1.3.17 and will no longer be supported in '
                          'the future. Please upgrade your mmcv.')
|
||||
self.forward = self.legacy_forward
|
||||
|
||||
def forward(self, x, hw_shape, identity=None):
|
||||
|
||||
x_q = x
|
||||
if self.sr_ratio > 1:
|
||||
x_kv = nlc_to_nchw(x, hw_shape)
|
||||
x_kv = self.sr(x_kv)
|
||||
x_kv = nchw_to_nlc(x_kv)
|
||||
x_kv = self.norm(x_kv)
|
||||
else:
|
||||
x_kv = x
|
||||
|
||||
if identity is None:
|
||||
identity = x_q
|
||||
|
||||
# Because the dataflow('key', 'query', 'value') of
|
||||
# ``torch.nn.MultiheadAttention`` is (num_query, batch,
|
||||
# embed_dims), We should adjust the shape of dataflow from
|
||||
# batch_first (batch, num_query, embed_dims) to num_query_first
|
||||
# (num_query ,batch, embed_dims), and recover ``attn_output``
|
||||
# from num_query_first to batch_first.
|
||||
if self.batch_first:
|
||||
x_q = x_q.transpose(0, 1)
|
||||
x_kv = x_kv.transpose(0, 1)
|
||||
|
||||
out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
|
||||
|
||||
if self.batch_first:
|
||||
out = out.transpose(0, 1)
|
||||
|
||||
return identity + self.dropout_layer(self.proj_drop(out))
|
||||
|
||||
def legacy_forward(self, x, hw_shape, identity=None):
|
||||
"""multi head attention forward in mmcv version < 1.3.17."""
|
||||
|
||||
x_q = x
|
||||
if self.sr_ratio > 1:
|
||||
x_kv = nlc_to_nchw(x, hw_shape)
|
||||
x_kv = self.sr(x_kv)
|
||||
x_kv = nchw_to_nlc(x_kv)
|
||||
x_kv = self.norm(x_kv)
|
||||
else:
|
||||
x_kv = x
|
||||
|
||||
if identity is None:
|
||||
identity = x_q
|
||||
|
||||
# `need_weights=True` will let nn.MultiHeadAttention
|
||||
# `return attn_output, attn_output_weights.sum(dim=1) / num_heads`
|
||||
# The `attn_output_weights.sum(dim=1)` may cause cuda error. So, we set
|
||||
# `need_weights=False` to ignore `attn_output_weights.sum(dim=1)`.
|
||||
        # The issue `https://github.com/pytorch/pytorch/issues/37583` reports
        # that summing a large tensor may trigger a CUDA error.
|
||||
out = self.attn(query=x_q, key=x_kv, value=x_kv, need_weights=False)[0]
|
||||
|
||||
return identity + self.dropout_layer(self.proj_drop(out))
|
||||
|
||||
|
||||
class TransformerEncoderLayer(BaseModule):
|
||||
"""Implements one encoder layer in Segformer.
|
||||
|
||||
Args:
|
||||
embed_dims (int): The feature dimension.
|
||||
num_heads (int): Parallel attention heads.
|
||||
feedforward_channels (int): The hidden dimension for FFNs.
|
||||
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Default 0.0.
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Default 0.0.
|
||||
drop_path_rate (float): stochastic depth rate. Default 0.0.
|
||||
qkv_bias (bool): enable bias for qkv if True.
|
||||
Default: True.
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Default: dict(type='GELU').
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN').
|
||||
batch_first (bool): Key, Query and Value are shape of
|
||||
(batch, n, embed_dim)
|
||||
or (n, batch, embed_dim). Default: False.
|
||||
init_cfg (dict, optional): Initialization config dict.
|
||||
Default:None.
|
||||
sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
|
||||
Attention of Segformer. Default: 1.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save
|
||||
some memory while slowing down the training speed. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
feedforward_channels,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
qkv_bias=True,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN'),
|
||||
batch_first=True,
|
||||
sr_ratio=1,
|
||||
with_cp=False):
|
||||
super().__init__()
|
||||
|
||||
# The ret[0] of build_norm_layer is norm name.
|
||||
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
|
||||
self.attn = EfficientMultiheadAttention(
|
||||
embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
attn_drop=attn_drop_rate,
|
||||
proj_drop=drop_rate,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
|
||||
batch_first=batch_first,
|
||||
qkv_bias=qkv_bias,
|
||||
norm_cfg=norm_cfg,
|
||||
sr_ratio=sr_ratio)
|
||||
|
||||
# The ret[0] of build_norm_layer is norm name.
|
||||
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
|
||||
self.ffn = MixFFN(
|
||||
embed_dims=embed_dims,
|
||||
feedforward_channels=feedforward_channels,
|
||||
ffn_drop=drop_rate,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
|
||||
act_cfg=act_cfg)
|
||||
|
||||
self.with_cp = with_cp
|
||||
|
||||
def forward(self, x, hw_shape):
|
||||
|
||||
def _inner_forward(x):
|
||||
x = self.attn(self.norm1(x), hw_shape, identity=x)
|
||||
x = self.ffn(self.norm2(x), hw_shape, identity=x)
|
||||
return x
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
x = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
x = _inner_forward(x)
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MixVisionTransformer(BaseModule):
|
||||
"""The backbone of Segformer.
|
||||
|
||||
This backbone is the implementation of `SegFormer: Simple and
|
||||
Efficient Design for Semantic Segmentation with
|
||||
Transformers <https://arxiv.org/abs/2105.15203>`_.
|
||||
Args:
|
||||
in_channels (int): Number of input channels. Default: 3.
|
||||
embed_dims (int): Embedding dimension. Default: 768.
|
||||
        num_stages (int): The number of stages. Default: 4.
|
||||
        num_layers (Sequence[int]): The number of transformer encoder layers
            in each stage. Default: [3, 4, 6, 3].
        num_heads (Sequence[int]): The number of attention heads in each
            stage's transformer encoder layers. Default: [1, 2, 4, 8].
|
||||
patch_sizes (Sequence[int]): The patch_size of each overlapped patch
|
||||
embedding. Default: [7, 3, 3, 3].
|
||||
strides (Sequence[int]): The stride of each overlapped patch embedding.
|
||||
Default: [4, 2, 2, 2].
|
||||
sr_ratios (Sequence[int]): The spatial reduction rate of each
|
||||
transformer encode layer. Default: [8, 4, 2, 1].
|
||||
out_indices (Sequence[int] | int): Output from which stages.
|
||||
Default: (0, 1, 2, 3).
|
||||
mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
|
||||
Default: 4.
|
||||
qkv_bias (bool): Enable bias for qkv if True. Default: True.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Default 0.0
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Default 0.0
|
||||
drop_path_rate (float): stochastic depth rate. Default 0.0
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN')
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Default: dict(type='GELU').
|
||||
pretrained (str, optional): model pretrained path. Default: None.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save
|
||||
some memory while slowing down the training speed. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
embed_dims=64,
|
||||
num_stages=4,
|
||||
num_layers=[3, 4, 6, 3],
|
||||
num_heads=[1, 2, 4, 8],
|
||||
patch_sizes=[7, 3, 3, 3],
|
||||
strides=[4, 2, 2, 2],
|
||||
sr_ratios=[8, 4, 2, 1],
|
||||
out_indices=(0, 1, 2, 3),
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN', eps=1e-6),
|
||||
pretrained=None,
|
||||
init_cfg=None,
|
||||
with_cp=False):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
assert not (init_cfg and pretrained), \
|
||||
'init_cfg and pretrained cannot be set at the same time'
|
||||
if isinstance(pretrained, str):
|
||||
warnings.warn('DeprecationWarning: pretrained is deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
elif pretrained is not None:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
self.embed_dims = embed_dims
|
||||
self.num_stages = num_stages
|
||||
self.num_layers = num_layers
|
||||
self.num_heads = num_heads
|
||||
self.patch_sizes = patch_sizes
|
||||
self.strides = strides
|
||||
self.sr_ratios = sr_ratios
|
||||
self.with_cp = with_cp
|
||||
assert num_stages == len(num_layers) == len(num_heads) \
|
||||
== len(patch_sizes) == len(strides) == len(sr_ratios)
|
||||
|
||||
self.out_indices = out_indices
|
||||
assert max(out_indices) < self.num_stages
|
||||
|
||||
# transformer encoder
|
||||
dpr = [
|
||||
x.item()
|
||||
for x in torch.linspace(0, drop_path_rate, sum(num_layers))
|
||||
] # stochastic num_layer decay rule
|
||||
|
||||
cur = 0
|
||||
self.layers = ModuleList()
|
||||
for i, num_layer in enumerate(num_layers):
|
||||
embed_dims_i = embed_dims * num_heads[i]
|
||||
patch_embed = PatchEmbed(
|
||||
in_channels=in_channels,
|
||||
embed_dims=embed_dims_i,
|
||||
kernel_size=patch_sizes[i],
|
||||
stride=strides[i],
|
||||
padding=patch_sizes[i] // 2,
|
||||
norm_cfg=norm_cfg)
|
||||
layer = ModuleList([
|
||||
TransformerEncoderLayer(
|
||||
embed_dims=embed_dims_i,
|
||||
num_heads=num_heads[i],
|
||||
feedforward_channels=mlp_ratio * embed_dims_i,
|
||||
drop_rate=drop_rate,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_path_rate=dpr[cur + idx],
|
||||
qkv_bias=qkv_bias,
|
||||
act_cfg=act_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
with_cp=with_cp,
|
||||
sr_ratio=sr_ratios[i]) for idx in range(num_layer)
|
||||
])
|
||||
in_channels = embed_dims_i
|
||||
# The ret[0] of build_norm_layer is norm name.
|
||||
norm = build_norm_layer(norm_cfg, embed_dims_i)[1]
|
||||
self.layers.append(ModuleList([patch_embed, layer, norm]))
|
||||
cur += num_layer
|
||||
|
||||
def init_weights(self):
|
||||
if self.init_cfg is None:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_init(m, std=.02, bias=0.)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
constant_init(m, val=1.0, bias=0.)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[
|
||||
1] * m.out_channels
|
||||
fan_out //= m.groups
|
||||
normal_init(
|
||||
m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
|
||||
else:
|
||||
super().init_weights()
|
||||
|
||||
def forward(self, x):
|
||||
outs = []
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
x, hw_shape = layer[0](x)
|
||||
for block in layer[1]:
|
||||
x = block(x, hw_shape)
|
||||
x = layer[2](x)
|
||||
x = nlc_to_nchw(x, hw_shape)
|
||||
if i in self.out_indices:
|
||||
outs.append(x)
|
||||
|
||||
return outs
|
||||
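A small illustrative sketch (assumes the default arguments of the class above): each stage's embedding width is `embed_dims * num_heads[i]`, and one stochastic-depth rate is generated per encoder layer across all stages.

    import torch

    embed_dims, num_heads = 64, [1, 2, 4, 8]
    num_layers = [3, 4, 6, 3]
    stage_dims = [embed_dims * h for h in num_heads]   # [64, 128, 256, 512]
    dpr = [x.item() for x in torch.linspace(0, 0.1, sum(num_layers))]
    print(stage_dims, len(dpr))                        # 16 per-layer rates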
197
finetune/mmseg/models/backbones/mobilenet_v2.py
Normal file
@@ -0,0 +1,197 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model import BaseModule
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import InvertedResidual, make_divisible
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MobileNetV2(BaseModule):
|
||||
"""MobileNetV2 backbone.
|
||||
|
||||
This backbone is the implementation of
|
||||
`MobileNetV2: Inverted Residuals and Linear Bottlenecks
|
||||
<https://arxiv.org/abs/1801.04381>`_.
|
||||
|
||||
Args:
|
||||
widen_factor (float): Width multiplier, multiply number of
|
||||
channels in each layer by this amount. Default: 1.0.
|
||||
strides (Sequence[int], optional): Strides of the first block of each
|
||||
layer. If not specified, default config in ``arch_setting`` will
|
||||
be used.
|
||||
dilations (Sequence[int]): Dilation of each layer.
|
||||
out_indices (None or Sequence[int]): Output from which stages.
|
||||
Default: (7, ).
|
||||
frozen_stages (int): Stages to be frozen (all param fixed).
|
||||
Default: -1, which means not freezing any parameters.
|
||||
conv_cfg (dict): Config dict for convolution layer.
|
||||
Default: None, which means using conv2d.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Default: dict(type='ReLU6').
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Default: False.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
pretrained (str, optional): model pretrained path. Default: None
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None
|
||||
"""
|
||||
|
||||
# Parameters to build layers. 3 parameters are needed to construct a
|
||||
# layer, from left to right: expand_ratio, channel, num_blocks.
|
||||
arch_settings = [[1, 16, 1], [6, 24, 2], [6, 32, 3], [6, 64, 4],
|
||||
[6, 96, 3], [6, 160, 3], [6, 320, 1]]
|
||||
|
||||
def __init__(self,
|
||||
widen_factor=1.,
|
||||
strides=(1, 2, 2, 2, 1, 2, 1),
|
||||
dilations=(1, 1, 1, 1, 1, 1, 1),
|
||||
out_indices=(1, 2, 4, 6),
|
||||
frozen_stages=-1,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU6'),
|
||||
norm_eval=False,
|
||||
with_cp=False,
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
self.pretrained = pretrained
|
||||
assert not (init_cfg and pretrained), \
|
||||
            'init_cfg and pretrained cannot be set at the same time'
|
||||
if isinstance(pretrained, str):
|
||||
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
|
||||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
elif pretrained is None:
|
||||
if init_cfg is None:
|
||||
self.init_cfg = [
|
||||
dict(type='Kaiming', layer='Conv2d'),
|
||||
dict(
|
||||
type='Constant',
|
||||
val=1,
|
||||
layer=['_BatchNorm', 'GroupNorm'])
|
||||
]
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
self.widen_factor = widen_factor
|
||||
self.strides = strides
|
||||
self.dilations = dilations
|
||||
assert len(strides) == len(dilations) == len(self.arch_settings)
|
||||
self.out_indices = out_indices
|
||||
for index in out_indices:
|
||||
if index not in range(0, 7):
|
||||
raise ValueError('the item in out_indices must in '
|
||||
f'range(0, 7). But received {index}')
|
||||
|
||||
if frozen_stages not in range(-1, 7):
|
||||
raise ValueError('frozen_stages must be in range(-1, 7). '
|
||||
f'But received {frozen_stages}')
|
||||
self.out_indices = out_indices
|
||||
self.frozen_stages = frozen_stages
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.norm_eval = norm_eval
|
||||
self.with_cp = with_cp
|
||||
|
||||
self.in_channels = make_divisible(32 * widen_factor, 8)
|
||||
|
||||
self.conv1 = ConvModule(
|
||||
in_channels=3,
|
||||
out_channels=self.in_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
self.layers = []
|
||||
|
||||
for i, layer_cfg in enumerate(self.arch_settings):
|
||||
expand_ratio, channel, num_blocks = layer_cfg
|
||||
stride = self.strides[i]
|
||||
dilation = self.dilations[i]
|
||||
out_channels = make_divisible(channel * widen_factor, 8)
|
||||
inverted_res_layer = self.make_layer(
|
||||
out_channels=out_channels,
|
||||
num_blocks=num_blocks,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
expand_ratio=expand_ratio)
|
||||
layer_name = f'layer{i + 1}'
|
||||
self.add_module(layer_name, inverted_res_layer)
|
||||
self.layers.append(layer_name)
|
||||
|
||||
def make_layer(self, out_channels, num_blocks, stride, dilation,
|
||||
expand_ratio):
|
||||
"""Stack InvertedResidual blocks to build a layer for MobileNetV2.
|
||||
|
||||
Args:
|
||||
out_channels (int): out_channels of block.
|
||||
num_blocks (int): Number of blocks.
|
||||
stride (int): Stride of the first block.
|
||||
dilation (int): Dilation of the first block.
|
||||
expand_ratio (int): Expand the number of channels of the
|
||||
hidden layer in InvertedResidual by this ratio.
|
||||
"""
|
||||
layers = []
|
||||
for i in range(num_blocks):
|
||||
layers.append(
|
||||
InvertedResidual(
|
||||
self.in_channels,
|
||||
out_channels,
|
||||
stride if i == 0 else 1,
|
||||
expand_ratio=expand_ratio,
|
||||
dilation=dilation if i == 0 else 1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
with_cp=self.with_cp))
|
||||
self.in_channels = out_channels
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
|
||||
outs = []
|
||||
for i, layer_name in enumerate(self.layers):
|
||||
layer = getattr(self, layer_name)
|
||||
x = layer(x)
|
||||
if i in self.out_indices:
|
||||
outs.append(x)
|
||||
|
||||
if len(outs) == 1:
|
||||
return outs[0]
|
||||
else:
|
||||
return tuple(outs)
|
||||
|
||||
def _freeze_stages(self):
|
||||
if self.frozen_stages >= 0:
|
||||
for param in self.conv1.parameters():
|
||||
param.requires_grad = False
|
||||
for i in range(1, self.frozen_stages + 1):
|
||||
layer = getattr(self, f'layer{i}')
|
||||
layer.eval()
|
||||
for param in layer.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def train(self, mode=True):
|
||||
super().train(mode)
|
||||
self._freeze_stages()
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
if isinstance(m, _BatchNorm):
|
||||
m.eval()
|
||||
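The snippet below is a minimal smoke test for the MobileNetV2 backbone above and is not part of the original file: it assumes the bundled finetune/mmseg package is importable as `mmseg` (as in upstream MMSegmentation) and that torch is installed; the input size and the commented output shapes are illustrative.

import torch

from mmseg.models.backbones import MobileNetV2

# Default config: four output stages, taken after layers 2, 3, 5 and 7.
backbone = MobileNetV2(widen_factor=1., out_indices=(1, 2, 4, 6))
backbone.eval()

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))

# Expected strides 4/8/16/32, i.e. roughly (1, 24, 56, 56), (1, 32, 28, 28),
# (1, 96, 14, 14) and (1, 320, 7, 7) for a 224x224 input.
for feat in feats:
    print(tuple(feat.shape))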
267
finetune/mmseg/models/backbones/mobilenet_v3.py
Normal file
@@ -0,0 +1,267 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmcv.cnn.bricks import Conv2dAdaptivePadding
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.utils import is_tuple_of
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import InvertedResidualV3 as InvertedResidual
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MobileNetV3(BaseModule):
|
||||
"""MobileNetV3 backbone.
|
||||
|
||||
This backbone is the improved implementation of `Searching for MobileNetV3
|
||||
<https://ieeexplore.ieee.org/document/9008835>`_.
|
||||
|
||||
Args:
|
||||
arch (str): Architecture of MobileNetV3, from {'small', 'large'}.
|
||||
Default: 'small'.
|
||||
conv_cfg (dict): Config dict for convolution layer.
|
||||
Default: None, which means using conv2d.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
out_indices (tuple[int]): Output from which layer.
|
||||
Default: (0, 1, 12).
|
||||
frozen_stages (int): Stages to be frozen (all param fixed).
|
||||
Default: -1, which means not freezing any parameters.
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Default: False.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save
|
||||
some memory while slowing down the training speed.
|
||||
Default: False.
|
||||
pretrained (str, optional): model pretrained path. Default: None
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None
|
||||
"""
|
||||
# Parameters to build each block:
|
||||
# [kernel size, mid channels, out channels, with_se, act type, stride]
|
||||
arch_settings = {
|
||||
'small': [[3, 16, 16, True, 'ReLU', 2], # block0 layer1 os=4
|
||||
[3, 72, 24, False, 'ReLU', 2], # block1 layer2 os=8
|
||||
[3, 88, 24, False, 'ReLU', 1],
|
||||
[5, 96, 40, True, 'HSwish', 2], # block2 layer4 os=16
|
||||
[5, 240, 40, True, 'HSwish', 1],
|
||||
[5, 240, 40, True, 'HSwish', 1],
|
||||
[5, 120, 48, True, 'HSwish', 1], # block3 layer7 os=16
|
||||
[5, 144, 48, True, 'HSwish', 1],
|
||||
[5, 288, 96, True, 'HSwish', 2], # block4 layer9 os=32
|
||||
[5, 576, 96, True, 'HSwish', 1],
|
||||
[5, 576, 96, True, 'HSwish', 1]],
|
||||
'large': [[3, 16, 16, False, 'ReLU', 1], # block0 layer1 os=2
|
||||
[3, 64, 24, False, 'ReLU', 2], # block1 layer2 os=4
|
||||
[3, 72, 24, False, 'ReLU', 1],
|
||||
[5, 72, 40, True, 'ReLU', 2], # block2 layer4 os=8
|
||||
[5, 120, 40, True, 'ReLU', 1],
|
||||
[5, 120, 40, True, 'ReLU', 1],
|
||||
[3, 240, 80, False, 'HSwish', 2], # block3 layer7 os=16
|
||||
[3, 200, 80, False, 'HSwish', 1],
|
||||
[3, 184, 80, False, 'HSwish', 1],
|
||||
[3, 184, 80, False, 'HSwish', 1],
|
||||
[3, 480, 112, True, 'HSwish', 1], # block4 layer11 os=16
|
||||
[3, 672, 112, True, 'HSwish', 1],
|
||||
[5, 672, 160, True, 'HSwish', 2], # block5 layer13 os=32
|
||||
[5, 960, 160, True, 'HSwish', 1],
|
||||
[5, 960, 160, True, 'HSwish', 1]]
|
||||
} # yapf: disable
|
||||
|
||||
def __init__(self,
|
||||
arch='small',
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
out_indices=(0, 1, 12),
|
||||
frozen_stages=-1,
|
||||
reduction_factor=1,
|
||||
norm_eval=False,
|
||||
with_cp=False,
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
self.pretrained = pretrained
|
||||
assert not (init_cfg and pretrained), \
|
||||
'init_cfg and pretrained cannot be set at the same time'
|
||||
if isinstance(pretrained, str):
|
||||
warnings.warn('DeprecationWarning: pretrained is deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
elif pretrained is None:
|
||||
if init_cfg is None:
|
||||
self.init_cfg = [
|
||||
dict(type='Kaiming', layer='Conv2d'),
|
||||
dict(
|
||||
type='Constant',
|
||||
val=1,
|
||||
layer=['_BatchNorm', 'GroupNorm'])
|
||||
]
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
assert arch in self.arch_settings
|
||||
assert isinstance(reduction_factor, int) and reduction_factor > 0
|
||||
assert is_tuple_of(out_indices, int)
|
||||
for index in out_indices:
|
||||
if index not in range(0, len(self.arch_settings[arch]) + 2):
|
||||
raise ValueError(
|
||||
'the item in out_indices must be in '
|
||||
f'range(0, {len(self.arch_settings[arch])+2}). '
|
||||
f'But received {index}')
|
||||
|
||||
if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2):
|
||||
raise ValueError('frozen_stages must be in range(-1, '
|
||||
f'{len(self.arch_settings[arch])+2}). '
|
||||
f'But received {frozen_stages}')
|
||||
self.arch = arch
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.out_indices = out_indices
|
||||
self.frozen_stages = frozen_stages
|
||||
self.reduction_factor = reduction_factor
|
||||
self.norm_eval = norm_eval
|
||||
self.with_cp = with_cp
|
||||
self.layers = self._make_layer()
|
||||
|
||||
def _make_layer(self):
|
||||
layers = []
|
||||
|
||||
# build the first layer (layer0)
|
||||
in_channels = 16
|
||||
layer = ConvModule(
|
||||
in_channels=3,
|
||||
out_channels=in_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=dict(type='Conv2dAdaptivePadding'),
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=dict(type='HSwish'))
|
||||
self.add_module('layer0', layer)
|
||||
layers.append('layer0')
|
||||
|
||||
layer_setting = self.arch_settings[self.arch]
|
||||
for i, params in enumerate(layer_setting):
|
||||
(kernel_size, mid_channels, out_channels, with_se, act,
|
||||
stride) = params
|
||||
|
||||
if self.arch == 'large' and i >= 12 or self.arch == 'small' and \
|
||||
i >= 8:
|
||||
mid_channels = mid_channels // self.reduction_factor
|
||||
out_channels = out_channels // self.reduction_factor
|
||||
|
||||
if with_se:
|
||||
se_cfg = dict(
|
||||
channels=mid_channels,
|
||||
ratio=4,
|
||||
act_cfg=(dict(type='ReLU'),
|
||||
dict(type='HSigmoid', bias=3.0, divisor=6.0)))
|
||||
else:
|
||||
se_cfg = None
|
||||
|
||||
layer = InvertedResidual(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
mid_channels=mid_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
se_cfg=se_cfg,
|
||||
with_expand_conv=(in_channels != mid_channels),
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=dict(type=act),
|
||||
with_cp=self.with_cp)
|
||||
in_channels = out_channels
|
||||
layer_name = f'layer{i + 1}'
|
||||
self.add_module(layer_name, layer)
|
||||
layers.append(layer_name)
|
||||
|
||||
# build the last layer
|
||||
# block5 layer12 os=32 for small model
|
||||
# block6 layer16 os=32 for large model
|
||||
layer = ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=576 if self.arch == 'small' else 960,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
dilation=4,
|
||||
padding=0,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=dict(type='HSwish'))
|
||||
layer_name = f'layer{len(layer_setting) + 1}'
|
||||
self.add_module(layer_name, layer)
|
||||
layers.append(layer_name)
|
||||
|
||||
# next, convert backbone MobileNetV3 to a semantic segmentation version
|
||||
if self.arch == 'small':
|
||||
self.layer4.depthwise_conv.conv.stride = (1, 1)
|
||||
self.layer9.depthwise_conv.conv.stride = (1, 1)
|
||||
for i in range(4, len(layers)):
|
||||
layer = getattr(self, layers[i])
|
||||
if isinstance(layer, InvertedResidual):
|
||||
modified_module = layer.depthwise_conv.conv
|
||||
else:
|
||||
modified_module = layer.conv
|
||||
|
||||
if i < 9:
|
||||
modified_module.dilation = (2, 2)
|
||||
pad = 2
|
||||
else:
|
||||
modified_module.dilation = (4, 4)
|
||||
pad = 4
|
||||
|
||||
if not isinstance(modified_module, Conv2dAdaptivePadding):
|
||||
# Adjust padding
|
||||
pad *= (modified_module.kernel_size[0] - 1) // 2
|
||||
modified_module.padding = (pad, pad)
|
||||
else:
|
||||
self.layer7.depthwise_conv.conv.stride = (1, 1)
|
||||
self.layer13.depthwise_conv.conv.stride = (1, 1)
|
||||
for i in range(7, len(layers)):
|
||||
layer = getattr(self, layers[i])
|
||||
if isinstance(layer, InvertedResidual):
|
||||
modified_module = layer.depthwise_conv.conv
|
||||
else:
|
||||
modified_module = layer.conv
|
||||
|
||||
if i < 13:
|
||||
modified_module.dilation = (2, 2)
|
||||
pad = 2
|
||||
else:
|
||||
modified_module.dilation = (4, 4)
|
||||
pad = 4
|
||||
|
||||
if not isinstance(modified_module, Conv2dAdaptivePadding):
|
||||
# Adjust padding
|
||||
pad *= (modified_module.kernel_size[0] - 1) // 2
|
||||
modified_module.padding = (pad, pad)
|
||||
|
||||
return layers
|
||||
|
||||
def forward(self, x):
|
||||
outs = []
|
||||
for i, layer_name in enumerate(self.layers):
|
||||
layer = getattr(self, layer_name)
|
||||
x = layer(x)
|
||||
if i in self.out_indices:
|
||||
outs.append(x)
|
||||
return outs
|
||||
|
||||
def _freeze_stages(self):
|
||||
for i in range(self.frozen_stages + 1):
|
||||
layer = getattr(self, f'layer{i}')
|
||||
layer.eval()
|
||||
for param in layer.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def train(self, mode=True):
|
||||
super().train(mode)
|
||||
self._freeze_stages()
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
if isinstance(m, _BatchNorm):
|
||||
m.eval()
|
||||
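A comparable hedged sketch for the MobileNetV3 backbone above ('small' variant with the default out_indices); the commented shapes assume the stride/dilation surgery above caps the output stride of the last stage at 8, and the import path is assumed as in upstream MMSegmentation.

import torch

from mmseg.models.backbones import MobileNetV3

backbone = MobileNetV3(arch='small', out_indices=(0, 1, 12))
backbone.eval()

with torch.no_grad():
    outs = backbone(torch.randn(1, 3, 224, 224))

# forward() returns a list; for arch='small' the three maps should be
# roughly (1, 16, 112, 112), (1, 16, 56, 56) and (1, 576, 28, 28).
for out in outs:
    print(tuple(out.shape))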
467
finetune/mmseg/models/backbones/mscan.py
Normal file
@@ -0,0 +1,467 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# Originally from https://github.com/visual-attention-network/segnext
|
||||
# Licensed under the Apache License, Version 2.0 (the "License")
|
||||
import math
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import build_activation_layer, build_norm_layer
|
||||
from mmcv.cnn.bricks import DropPath
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.model.weight_init import (constant_init, normal_init,
|
||||
trunc_normal_init)
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
class Mlp(BaseModule):
|
||||
"""Multi Layer Perceptron (MLP) Module.
|
||||
|
||||
Args:
|
||||
in_features (int): The dimension of input features.
|
||||
hidden_features (int): The dimension of hidden features.
|
||||
Defaults: None.
|
||||
out_features (int): The dimension of output features.
|
||||
Defaults: None.
|
||||
act_cfg (dict): Config dict for activation layer in block.
|
||||
Default: dict(type='GELU').
|
||||
drop (float): The dropout rate in the MLP block.
|
||||
Defaults: 0.0.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_cfg=dict(type='GELU'),
|
||||
drop=0.):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
|
||||
self.dwconv = nn.Conv2d(
|
||||
hidden_features,
|
||||
hidden_features,
|
||||
3,
|
||||
1,
|
||||
1,
|
||||
bias=True,
|
||||
groups=hidden_features)
|
||||
self.act = build_activation_layer(act_cfg)
|
||||
self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
|
||||
x = self.fc1(x)
|
||||
|
||||
x = self.dwconv(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class StemConv(BaseModule):
|
||||
"""Stem Block at the beginning of Semantic Branch.
|
||||
|
||||
Args:
|
||||
in_channels (int): The dimension of input channels.
|
||||
out_channels (int): The dimension of output channels.
|
||||
act_cfg (dict): Config dict for activation layer in block.
|
||||
Default: dict(type='GELU').
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults: dict(type='SyncBN', requires_grad=True).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='SyncBN', requires_grad=True)):
|
||||
super().__init__()
|
||||
|
||||
self.proj = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels // 2,
|
||||
kernel_size=(3, 3),
|
||||
stride=(2, 2),
|
||||
padding=(1, 1)),
|
||||
build_norm_layer(norm_cfg, out_channels // 2)[1],
|
||||
build_activation_layer(act_cfg),
|
||||
nn.Conv2d(
|
||||
out_channels // 2,
|
||||
out_channels,
|
||||
kernel_size=(3, 3),
|
||||
stride=(2, 2),
|
||||
padding=(1, 1)),
|
||||
build_norm_layer(norm_cfg, out_channels)[1],
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
|
||||
x = self.proj(x)
|
||||
_, _, H, W = x.size()
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
return x, H, W
|
||||
|
||||
|
||||
class MSCAAttention(BaseModule):
|
||||
"""Attention Module in Multi-Scale Convolutional Attention Module (MSCA).
|
||||
|
||||
Args:
|
||||
channels (int): The dimension of channels.
|
||||
kernel_sizes (list): The size of attention
|
||||
kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]].
|
||||
paddings (list): The
|
||||
corresponding padding values in the attention module.
|
||||
Defaults: [2, [0, 3], [0, 5], [0, 10]].
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
channels,
|
||||
kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
|
||||
paddings=[2, [0, 3], [0, 5], [0, 10]]):
|
||||
super().__init__()
|
||||
self.conv0 = nn.Conv2d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=kernel_sizes[0],
|
||||
padding=paddings[0],
|
||||
groups=channels)
|
||||
for i, (kernel_size,
|
||||
padding) in enumerate(zip(kernel_sizes[1:], paddings[1:])):
|
||||
kernel_size_ = [kernel_size, kernel_size[::-1]]
|
||||
padding_ = [padding, padding[::-1]]
|
||||
conv_name = [f'conv{i}_1', f'conv{i}_2']
|
||||
for i_kernel, i_pad, i_conv in zip(kernel_size_, padding_,
|
||||
conv_name):
|
||||
self.add_module(
|
||||
i_conv,
|
||||
nn.Conv2d(
|
||||
channels,
|
||||
channels,
|
||||
tuple(i_kernel),
|
||||
padding=i_pad,
|
||||
groups=channels))
|
||||
self.conv3 = nn.Conv2d(channels, channels, 1)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
|
||||
u = x.clone()
|
||||
|
||||
attn = self.conv0(x)
|
||||
|
||||
# Multi-Scale Feature extraction
|
||||
attn_0 = self.conv0_1(attn)
|
||||
attn_0 = self.conv0_2(attn_0)
|
||||
|
||||
attn_1 = self.conv1_1(attn)
|
||||
attn_1 = self.conv1_2(attn_1)
|
||||
|
||||
attn_2 = self.conv2_1(attn)
|
||||
attn_2 = self.conv2_2(attn_2)
|
||||
|
||||
attn = attn + attn_0 + attn_1 + attn_2
|
||||
# Channel Mixing
|
||||
attn = self.conv3(attn)
|
||||
|
||||
# Convolutional Attention
|
||||
x = attn * u
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class MSCASpatialAttention(BaseModule):
|
||||
"""Spatial Attention Module in Multi-Scale Convolutional Attention Module
|
||||
(MSCA).
|
||||
|
||||
Args:
|
||||
in_channels (int): The dimension of channels.
|
||||
attention_kernel_sizes (list): The size of attention
|
||||
kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]].
|
||||
attention_kernel_paddings (list): The
|
||||
corresponding padding values in the attention module.
|
||||
Defaults: [2, [0, 3], [0, 5], [0, 10]].
|
||||
act_cfg (dict): Config dict for activation layer in block.
|
||||
Default: dict(type='GELU').
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
|
||||
attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
|
||||
act_cfg=dict(type='GELU')):
|
||||
super().__init__()
|
||||
self.proj_1 = nn.Conv2d(in_channels, in_channels, 1)
|
||||
self.activation = build_activation_layer(act_cfg)
|
||||
self.spatial_gating_unit = MSCAAttention(in_channels,
|
||||
attention_kernel_sizes,
|
||||
attention_kernel_paddings)
|
||||
self.proj_2 = nn.Conv2d(in_channels, in_channels, 1)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
|
||||
shortcut = x.clone()
|
||||
x = self.proj_1(x)
|
||||
x = self.activation(x)
|
||||
x = self.spatial_gating_unit(x)
|
||||
x = self.proj_2(x)
|
||||
x = x + shortcut
|
||||
return x
|
||||
|
||||
|
||||
class MSCABlock(BaseModule):
|
||||
"""Basic Multi-Scale Convolutional Attention Block. It leverage the large-
|
||||
kernel attention (LKA) mechanism to build both channel and spatial
|
||||
attention. In each branch, it uses two depth-wise strip convolutions to
|
||||
approximate standard depth-wise convolutions with large kernels. The kernel
|
||||
size for each branch is set to 7, 11, and 21, respectively.
|
||||
|
||||
Args:
|
||||
channels (int): The dimension of channels.
|
||||
attention_kernel_sizes (list): The size of attention
|
||||
kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]].
|
||||
attention_kernel_paddings (list): The
|
||||
corresponding padding values in the attention module.
|
||||
Defaults: [2, [0, 3], [0, 5], [0, 10]].
|
||||
mlp_ratio (float): The ratio of multiple input dimension to
|
||||
calculate hidden feature in MLP layer. Defaults: 4.0.
|
||||
drop (float): The dropout rate in the MLP block.
|
||||
Defaults: 0.0.
|
||||
drop_path (float): The ratio of drop paths.
|
||||
Defaults: 0.0.
|
||||
act_cfg (dict): Config dict for activation layer in block.
|
||||
Default: dict(type='GELU').
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults: dict(type='SyncBN', requires_grad=True).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
channels,
|
||||
attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
|
||||
attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
|
||||
mlp_ratio=4.,
|
||||
drop=0.,
|
||||
drop_path=0.,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='SyncBN', requires_grad=True)):
|
||||
super().__init__()
|
||||
self.norm1 = build_norm_layer(norm_cfg, channels)[1]
|
||||
self.attn = MSCASpatialAttention(channels, attention_kernel_sizes,
|
||||
attention_kernel_paddings, act_cfg)
|
||||
self.drop_path = DropPath(
|
||||
drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = build_norm_layer(norm_cfg, channels)[1]
|
||||
mlp_hidden_channels = int(channels * mlp_ratio)
|
||||
self.mlp = Mlp(
|
||||
in_features=channels,
|
||||
hidden_features=mlp_hidden_channels,
|
||||
act_cfg=act_cfg,
|
||||
drop=drop)
|
||||
layer_scale_init_value = 1e-2
|
||||
self.layer_scale_1 = nn.Parameter(
|
||||
layer_scale_init_value * torch.ones(channels), requires_grad=True)
|
||||
self.layer_scale_2 = nn.Parameter(
|
||||
layer_scale_init_value * torch.ones(channels), requires_grad=True)
|
||||
|
||||
def forward(self, x, H, W):
|
||||
"""Forward function."""
|
||||
|
||||
B, N, C = x.shape
|
||||
x = x.permute(0, 2, 1).view(B, C, H, W)
|
||||
x = x + self.drop_path(
|
||||
self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) *
|
||||
self.attn(self.norm1(x)))
|
||||
x = x + self.drop_path(
|
||||
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) *
|
||||
self.mlp(self.norm2(x)))
|
||||
x = x.view(B, C, N).permute(0, 2, 1)
|
||||
return x
|
||||
|
||||
|
||||
class OverlapPatchEmbed(BaseModule):
|
||||
"""Image to Patch Embedding.
|
||||
|
||||
Args:
|
||||
patch_size (int): The patch size.
|
||||
Defaults: 7.
|
||||
stride (int): Stride of the convolutional layer.
|
||||
Default: 4.
|
||||
in_channels (int): The number of input channels.
|
||||
Defaults: 3.
|
||||
embed_dims (int): The dimensions of embedding.
|
||||
Defaults: 768.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults: dict(type='SyncBN', requires_grad=True).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
patch_size=7,
|
||||
stride=4,
|
||||
in_channels=3,
|
||||
embed_dim=768,
|
||||
norm_cfg=dict(type='SyncBN', requires_grad=True)):
|
||||
super().__init__()
|
||||
|
||||
self.proj = nn.Conv2d(
|
||||
in_channels,
|
||||
embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=stride,
|
||||
padding=patch_size // 2)
|
||||
self.norm = build_norm_layer(norm_cfg, embed_dim)[1]
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
|
||||
x = self.proj(x)
|
||||
_, _, H, W = x.shape
|
||||
x = self.norm(x)
|
||||
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
return x, H, W
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MSCAN(BaseModule):
|
||||
"""SegNeXt Multi-Scale Convolutional Attention Network (MCSAN) backbone.
|
||||
|
||||
This backbone is the implementation of `SegNeXt: Rethinking
|
||||
Convolutional Attention Design for Semantic
|
||||
Segmentation <https://arxiv.org/abs/2209.08575>`_.
|
||||
Inspiration from https://github.com/visual-attention-network/segnext.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channels. Defaults: 3.
|
||||
embed_dims (list[int]): Embedding dimension.
|
||||
Defaults: [64, 128, 256, 512].
|
||||
mlp_ratios (list[int]): Ratio of mlp hidden dim to embedding dim.
|
||||
Defaults: [4, 4, 4, 4].
|
||||
drop_rate (float): Dropout rate. Defaults: 0.
|
||||
drop_path_rate (float): Stochastic depth rate. Defaults: 0.
|
||||
depths (list[int]): Depths of each MSCAN stage.
|
||||
Default: [3, 4, 6, 3].
|
||||
num_stages (int): MSCAN stages. Default: 4.
|
||||
attention_kernel_sizes (list): Size of attention kernel in
|
||||
Attention Module (Figure 2(b) of original paper).
|
||||
Defaults: [5, [1, 7], [1, 11], [1, 21]].
|
||||
attention_kernel_paddings (list): Size of attention paddings
|
||||
in Attention Module (Figure 2(b) of original paper).
|
||||
Defaults: [2, [0, 3], [0, 5], [0, 10]].
|
||||
norm_cfg (dict): Config of norm layers.
|
||||
Defaults: dict(type='SyncBN', requires_grad=True).
|
||||
pretrained (str, optional): model pretrained path.
|
||||
Default: None.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
embed_dims=[64, 128, 256, 512],
|
||||
mlp_ratios=[4, 4, 4, 4],
|
||||
drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
depths=[3, 4, 6, 3],
|
||||
num_stages=4,
|
||||
attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
|
||||
attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='SyncBN', requires_grad=True),
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
assert not (init_cfg and pretrained), \
|
||||
'init_cfg and pretrained cannot be set at the same time'
|
||||
if isinstance(pretrained, str):
|
||||
warnings.warn('DeprecationWarning: pretrained is deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
elif pretrained is not None:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
self.depths = depths
|
||||
self.num_stages = num_stages
|
||||
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
|
||||
] # stochastic depth decay rule
|
||||
cur = 0
|
||||
|
||||
for i in range(num_stages):
|
||||
if i == 0:
|
||||
patch_embed = StemConv(3, embed_dims[0], norm_cfg=norm_cfg)
|
||||
else:
|
||||
patch_embed = OverlapPatchEmbed(
|
||||
patch_size=7 if i == 0 else 3,
|
||||
stride=4 if i == 0 else 2,
|
||||
in_channels=in_channels if i == 0 else embed_dims[i - 1],
|
||||
embed_dim=embed_dims[i],
|
||||
norm_cfg=norm_cfg)
|
||||
|
||||
block = nn.ModuleList([
|
||||
MSCABlock(
|
||||
channels=embed_dims[i],
|
||||
attention_kernel_sizes=attention_kernel_sizes,
|
||||
attention_kernel_paddings=attention_kernel_paddings,
|
||||
mlp_ratio=mlp_ratios[i],
|
||||
drop=drop_rate,
|
||||
drop_path=dpr[cur + j],
|
||||
act_cfg=act_cfg,
|
||||
norm_cfg=norm_cfg) for j in range(depths[i])
|
||||
])
|
||||
norm = nn.LayerNorm(embed_dims[i])
|
||||
cur += depths[i]
|
||||
|
||||
setattr(self, f'patch_embed{i + 1}', patch_embed)
|
||||
setattr(self, f'block{i + 1}', block)
|
||||
setattr(self, f'norm{i + 1}', norm)
|
||||
|
||||
def init_weights(self):
|
||||
"""Initialize modules of MSCAN."""
|
||||
|
||||
print('init cfg', self.init_cfg)
|
||||
if self.init_cfg is None:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_init(m, std=.02, bias=0.)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
constant_init(m, val=1.0, bias=0.)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[
|
||||
1] * m.out_channels
|
||||
fan_out //= m.groups
|
||||
normal_init(
|
||||
m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
|
||||
else:
|
||||
super().init_weights()
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
|
||||
B = x.shape[0]
|
||||
outs = []
|
||||
|
||||
for i in range(self.num_stages):
|
||||
patch_embed = getattr(self, f'patch_embed{i + 1}')
|
||||
block = getattr(self, f'block{i + 1}')
|
||||
norm = getattr(self, f'norm{i + 1}')
|
||||
x, H, W = patch_embed(x)
|
||||
for blk in block:
|
||||
x = blk(x, H, W)
|
||||
x = norm(x)
|
||||
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
|
||||
outs.append(x)
|
||||
|
||||
return outs
|
||||
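A forward-pass sketch for MSCAN with its default (base-size) settings; this is an illustrative example rather than part of the commit, and norm_cfg is switched from the SyncBN default to plain BN so it can run in a single, non-distributed process.

import torch

from mmseg.models.backbones import MSCAN

# SyncBN (the default) needs a distributed process group, so use BN here.
backbone = MSCAN(
    embed_dims=[64, 128, 256, 512],
    depths=[3, 4, 6, 3],
    norm_cfg=dict(type='BN', requires_grad=True))
backbone.eval()

with torch.no_grad():
    outs = backbone(torch.randn(1, 3, 512, 512))

# Four stages at strides 4, 8, 16, 32 with the embed_dims channel counts,
# i.e. from (1, 64, 128, 128) down to (1, 512, 16, 16) for a 512x512 input.
for out in outs:
    print(tuple(out.shape))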
522
finetune/mmseg/models/backbones/pidnet.py
Normal file
@@ -0,0 +1,522 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.runner import CheckpointLoader
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import OptConfigType
|
||||
from ..utils import DAPPM, PAPPM, BasicBlock, Bottleneck
|
||||
|
||||
|
||||
class PagFM(BaseModule):
|
||||
"""Pixel-attention-guided fusion module.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channels.
|
||||
channels (int): The number of channels.
|
||||
after_relu (bool): Whether to use ReLU before attention.
|
||||
Default: False.
|
||||
with_channel (bool): Whether to use channel attention.
|
||||
Default: False.
|
||||
upsample_mode (str): The mode of upsample. Default: 'bilinear'.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Default: dict(type='ReLU', inplace=True).
|
||||
init_cfg (dict): Config dict for initialization. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
channels: int,
|
||||
after_relu: bool = False,
|
||||
with_channel: bool = False,
|
||||
upsample_mode: str = 'bilinear',
|
||||
norm_cfg: OptConfigType = dict(type='BN'),
|
||||
act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
|
||||
init_cfg: OptConfigType = None):
|
||||
super().__init__(init_cfg)
|
||||
self.after_relu = after_relu
|
||||
self.with_channel = with_channel
|
||||
self.upsample_mode = upsample_mode
|
||||
self.f_i = ConvModule(
|
||||
in_channels, channels, 1, norm_cfg=norm_cfg, act_cfg=None)
|
||||
self.f_p = ConvModule(
|
||||
in_channels, channels, 1, norm_cfg=norm_cfg, act_cfg=None)
|
||||
if with_channel:
|
||||
self.up = ConvModule(
|
||||
channels, in_channels, 1, norm_cfg=norm_cfg, act_cfg=None)
|
||||
if after_relu:
|
||||
self.relu = MODELS.build(act_cfg)
|
||||
|
||||
def forward(self, x_p: Tensor, x_i: Tensor) -> Tensor:
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
x_p (Tensor): The feature map from P branch.
|
||||
x_i (Tensor): The feature map from I branch.
|
||||
|
||||
Returns:
|
||||
Tensor: The feature map with pixel-attention-guided fusion.
|
||||
"""
|
||||
if self.after_relu:
|
||||
x_p = self.relu(x_p)
|
||||
x_i = self.relu(x_i)
|
||||
|
||||
f_i = self.f_i(x_i)
|
||||
f_i = F.interpolate(
|
||||
f_i,
|
||||
size=x_p.shape[2:],
|
||||
mode=self.upsample_mode,
|
||||
align_corners=False)
|
||||
|
||||
f_p = self.f_p(x_p)
|
||||
|
||||
if self.with_channel:
|
||||
sigma = torch.sigmoid(self.up(f_p * f_i))
|
||||
else:
|
||||
sigma = torch.sigmoid(torch.sum(f_p * f_i, dim=1).unsqueeze(1))
|
||||
|
||||
x_i = F.interpolate(
|
||||
x_i,
|
||||
size=x_p.shape[2:],
|
||||
mode=self.upsample_mode,
|
||||
align_corners=False)
|
||||
|
||||
out = sigma * x_i + (1 - sigma) * x_p
|
||||
return out
|
||||
|
||||
|
||||
class Bag(BaseModule):
|
||||
"""Boundary-attention-guided fusion module.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channels.
|
||||
out_channels (int): The number of output channels.
|
||||
kernel_size (int): The kernel size of the convolution. Default: 3.
|
||||
padding (int): The padding of the convolution. Default: 1.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Default: dict(type='ReLU', inplace=True).
|
||||
conv_cfg (dict): Config dict for convolution layer.
|
||||
Default: dict(order=('norm', 'act', 'conv')).
|
||||
init_cfg (dict): Config dict for initialization. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int = 3,
|
||||
padding: int = 1,
|
||||
norm_cfg: OptConfigType = dict(type='BN'),
|
||||
act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
|
||||
conv_cfg: OptConfigType = dict(order=('norm', 'act', 'conv')),
|
||||
init_cfg: OptConfigType = None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
self.conv = ConvModule(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
padding=padding,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
**conv_cfg)
|
||||
|
||||
def forward(self, x_p: Tensor, x_i: Tensor, x_d: Tensor) -> Tensor:
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
x_p (Tensor): The feature map from P branch.
|
||||
x_i (Tensor): The feature map from I branch.
|
||||
x_d (Tensor): The feature map from D branch.
|
||||
|
||||
Returns:
|
||||
Tensor: The feature map with boundary-attention-guided fusion.
|
||||
"""
|
||||
sigma = torch.sigmoid(x_d)
|
||||
return self.conv(sigma * x_p + (1 - sigma) * x_i)
|
||||
|
||||
|
||||
class LightBag(BaseModule):
|
||||
"""Light Boundary-attention-guided fusion module.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channels.
|
||||
out_channels (int): The number of output channels.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config dict for activation layer. Default: None.
|
||||
init_cfg (dict): Config dict for initialization. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
norm_cfg: OptConfigType = dict(type='BN'),
|
||||
act_cfg: OptConfigType = None,
|
||||
init_cfg: OptConfigType = None):
|
||||
super().__init__(init_cfg)
|
||||
self.f_p = ConvModule(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
self.f_i = ConvModule(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
def forward(self, x_p: Tensor, x_i: Tensor, x_d: Tensor) -> Tensor:
|
||||
"""Forward function.
|
||||
Args:
|
||||
x_p (Tensor): The feature map from P branch.
|
||||
x_i (Tensor): The feature map from I branch.
|
||||
x_d (Tensor): The feature map from D branch.
|
||||
|
||||
Returns:
|
||||
Tensor: The feature map with light boundary-attention-guided
|
||||
fusion.
|
||||
"""
|
||||
sigma = torch.sigmoid(x_d)
|
||||
|
||||
f_p = self.f_p((1 - sigma) * x_i + x_p)
|
||||
f_i = self.f_i(x_i + sigma * x_p)
|
||||
|
||||
return f_p + f_i
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class PIDNet(BaseModule):
|
||||
"""PIDNet backbone.
|
||||
|
||||
This backbone is the implementation of `PIDNet: A Real-time Semantic
|
||||
Segmentation Network Inspired from PID Controller
|
||||
<https://arxiv.org/abs/2206.02066>`_.
|
||||
Modified from https://github.com/XuJiacong/PIDNet.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channels. Default: 3.
|
||||
channels (int): The number of channels in the stem layer. Default: 64.
|
||||
ppm_channels (int): The number of channels in the PPM layer.
|
||||
Default: 96.
|
||||
num_stem_blocks (int): The number of blocks in the stem layer.
|
||||
Default: 2.
|
||||
num_branch_blocks (int): The number of blocks in the branch layer.
|
||||
Default: 3.
|
||||
align_corners (bool): The align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Default: dict(type='ReLU', inplace=True).
|
||||
init_cfg (dict): Config dict for initialization. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int = 3,
|
||||
channels: int = 64,
|
||||
ppm_channels: int = 96,
|
||||
num_stem_blocks: int = 2,
|
||||
num_branch_blocks: int = 3,
|
||||
align_corners: bool = False,
|
||||
norm_cfg: OptConfigType = dict(type='BN'),
|
||||
act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
|
||||
init_cfg: OptConfigType = None,
|
||||
**kwargs):
|
||||
super().__init__(init_cfg)
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.align_corners = align_corners
|
||||
|
||||
# stem layer
|
||||
self.stem = self._make_stem_layer(in_channels, channels,
|
||||
num_stem_blocks)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
# I Branch
|
||||
self.i_branch_layers = nn.ModuleList()
|
||||
for i in range(3):
|
||||
self.i_branch_layers.append(
|
||||
self._make_layer(
|
||||
block=BasicBlock if i < 2 else Bottleneck,
|
||||
in_channels=channels * 2**(i + 1),
|
||||
channels=channels * 8 if i > 0 else channels * 4,
|
||||
num_blocks=num_branch_blocks if i < 2 else 2,
|
||||
stride=2))
|
||||
|
||||
# P Branch
|
||||
self.p_branch_layers = nn.ModuleList()
|
||||
for i in range(3):
|
||||
self.p_branch_layers.append(
|
||||
self._make_layer(
|
||||
block=BasicBlock if i < 2 else Bottleneck,
|
||||
in_channels=channels * 2,
|
||||
channels=channels * 2,
|
||||
num_blocks=num_stem_blocks if i < 2 else 1))
|
||||
self.compression_1 = ConvModule(
|
||||
channels * 4,
|
||||
channels * 2,
|
||||
kernel_size=1,
|
||||
bias=False,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None)
|
||||
self.compression_2 = ConvModule(
|
||||
channels * 8,
|
||||
channels * 2,
|
||||
kernel_size=1,
|
||||
bias=False,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None)
|
||||
self.pag_1 = PagFM(channels * 2, channels)
|
||||
self.pag_2 = PagFM(channels * 2, channels)
|
||||
|
||||
# D Branch
|
||||
if num_stem_blocks == 2:
|
||||
self.d_branch_layers = nn.ModuleList([
|
||||
self._make_single_layer(BasicBlock, channels * 2, channels),
|
||||
self._make_layer(Bottleneck, channels, channels, 1)
|
||||
])
|
||||
channel_expand = 1
|
||||
spp_module = PAPPM
|
||||
dfm_module = LightBag
|
||||
act_cfg_dfm = None
|
||||
else:
|
||||
self.d_branch_layers = nn.ModuleList([
|
||||
self._make_single_layer(BasicBlock, channels * 2,
|
||||
channels * 2),
|
||||
self._make_single_layer(BasicBlock, channels * 2, channels * 2)
|
||||
])
|
||||
channel_expand = 2
|
||||
spp_module = DAPPM
|
||||
dfm_module = Bag
|
||||
act_cfg_dfm = act_cfg
|
||||
|
||||
self.diff_1 = ConvModule(
|
||||
channels * 4,
|
||||
channels * channel_expand,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
bias=False,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None)
|
||||
self.diff_2 = ConvModule(
|
||||
channels * 8,
|
||||
channels * 2,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
bias=False,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None)
|
||||
|
||||
self.spp = spp_module(
|
||||
channels * 16, ppm_channels, channels * 4, num_scales=5)
|
||||
self.dfm = dfm_module(
|
||||
channels * 4, channels * 4, norm_cfg=norm_cfg, act_cfg=act_cfg_dfm)
|
||||
|
||||
self.d_branch_layers.append(
|
||||
self._make_layer(Bottleneck, channels * 2, channels * 2, 1))
|
||||
|
||||
def _make_stem_layer(self, in_channels: int, channels: int,
|
||||
num_blocks: int) -> nn.Sequential:
|
||||
"""Make stem layer.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
channels (int): Number of output channels.
|
||||
num_blocks (int): Number of blocks.
|
||||
|
||||
Returns:
|
||||
nn.Sequential: The stem layer.
|
||||
"""
|
||||
|
||||
layers = [
|
||||
ConvModule(
|
||||
in_channels,
|
||||
channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg),
|
||||
ConvModule(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
]
|
||||
|
||||
layers.append(
|
||||
self._make_layer(BasicBlock, channels, channels, num_blocks))
|
||||
layers.append(nn.ReLU())
|
||||
layers.append(
|
||||
self._make_layer(
|
||||
BasicBlock, channels, channels * 2, num_blocks, stride=2))
|
||||
layers.append(nn.ReLU())
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _make_layer(self,
|
||||
block: BasicBlock,
|
||||
in_channels: int,
|
||||
channels: int,
|
||||
num_blocks: int,
|
||||
stride: int = 1) -> nn.Sequential:
|
||||
"""Make layer for PIDNet backbone.
|
||||
Args:
|
||||
block (BasicBlock): Basic block.
|
||||
in_channels (int): Number of input channels.
|
||||
channels (int): Number of output channels.
|
||||
num_blocks (int): Number of blocks.
|
||||
stride (int): Stride of the first block. Default: 1.
|
||||
|
||||
Returns:
|
||||
nn.Sequential: The Branch Layer.
|
||||
"""
|
||||
downsample = None
|
||||
if stride != 1 or in_channels != channels * block.expansion:
|
||||
downsample = ConvModule(
|
||||
in_channels,
|
||||
channels * block.expansion,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=None)
|
||||
|
||||
layers = [block(in_channels, channels, stride, downsample)]
|
||||
in_channels = channels * block.expansion
|
||||
for i in range(1, num_blocks):
|
||||
layers.append(
|
||||
block(
|
||||
in_channels,
|
||||
channels,
|
||||
stride=1,
|
||||
act_cfg_out=None if i == num_blocks - 1 else self.act_cfg))
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _make_single_layer(self,
|
||||
block: Union[BasicBlock, Bottleneck],
|
||||
in_channels: int,
|
||||
channels: int,
|
||||
stride: int = 1) -> nn.Module:
|
||||
"""Make single layer for PIDNet backbone.
|
||||
Args:
|
||||
block (BasicBlock or Bottleneck): Basic block or Bottleneck.
|
||||
in_channels (int): Number of input channels.
|
||||
channels (int): Number of output channels.
|
||||
stride (int): Stride of the first block. Default: 1.
|
||||
|
||||
Returns:
|
||||
nn.Module
|
||||
"""
|
||||
|
||||
downsample = None
|
||||
if stride != 1 or in_channels != channels * block.expansion:
|
||||
downsample = ConvModule(
|
||||
in_channels,
|
||||
channels * block.expansion,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=None)
|
||||
return block(
|
||||
in_channels, channels, stride, downsample, act_cfg_out=None)
|
||||
|
||||
def init_weights(self):
|
||||
"""Initialize the weights in backbone.
|
||||
|
||||
Since the D branch is not initialized by the pre-trained model, we
|
||||
initialize it with the same method as the ResNet.
|
||||
"""
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
if self.init_cfg is not None:
|
||||
assert 'checkpoint' in self.init_cfg, f'Only support ' \
|
||||
f'specifying `Pretrained` in ' \
|
||||
f'`init_cfg` in ' \
|
||||
f'{self.__class__.__name__} '
|
||||
ckpt = CheckpointLoader.load_checkpoint(
|
||||
self.init_cfg['checkpoint'], map_location='cpu')
|
||||
self.load_state_dict(ckpt, strict=False)
|
||||
|
||||
def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor]]:
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor with shape (B, C, H, W).
|
||||
|
||||
Returns:
|
||||
Tensor or tuple[Tensor]: If self.training is True, return
|
||||
tuple[Tensor], else return Tensor.
|
||||
"""
|
||||
w_out = x.shape[-1] // 8
|
||||
h_out = x.shape[-2] // 8
|
||||
|
||||
# stage 0-2
|
||||
x = self.stem(x)
|
||||
|
||||
# stage 3
|
||||
x_i = self.relu(self.i_branch_layers[0](x))
|
||||
x_p = self.p_branch_layers[0](x)
|
||||
x_d = self.d_branch_layers[0](x)
|
||||
|
||||
comp_i = self.compression_1(x_i)
|
||||
x_p = self.pag_1(x_p, comp_i)
|
||||
diff_i = self.diff_1(x_i)
|
||||
x_d += F.interpolate(
|
||||
diff_i,
|
||||
size=[h_out, w_out],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
if self.training:
|
||||
temp_p = x_p.clone()
|
||||
|
||||
# stage 4
|
||||
x_i = self.relu(self.i_branch_layers[1](x_i))
|
||||
x_p = self.p_branch_layers[1](self.relu(x_p))
|
||||
x_d = self.d_branch_layers[1](self.relu(x_d))
|
||||
|
||||
comp_i = self.compression_2(x_i)
|
||||
x_p = self.pag_2(x_p, comp_i)
|
||||
diff_i = self.diff_2(x_i)
|
||||
x_d += F.interpolate(
|
||||
diff_i,
|
||||
size=[h_out, w_out],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
if self.training:
|
||||
temp_d = x_d.clone()
|
||||
|
||||
# stage 5
|
||||
x_i = self.i_branch_layers[2](x_i)
|
||||
x_p = self.p_branch_layers[2](self.relu(x_p))
|
||||
x_d = self.d_branch_layers[2](self.relu(x_d))
|
||||
|
||||
x_i = self.spp(x_i)
|
||||
x_i = F.interpolate(
|
||||
x_i,
|
||||
size=[h_out, w_out],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
out = self.dfm(x_p, x_i, x_d)
|
||||
return (temp_p, out, temp_d) if self.training else out
|
||||
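A hedged smoke test for PIDNet; the channel widths below mimic a PIDNet-S style configuration and are an assumption, not taken from this commit. Eval mode makes forward() return the single fused feature map instead of the (P, fused, D) training triple, and the input size is kept divisible by 64 so all branch resolutions line up.

import torch

from mmseg.models.backbones import PIDNet

# num_stem_blocks=2 selects the PAPPM context module and the LightBag fusion.
backbone = PIDNet(
    in_channels=3,
    channels=32,
    ppm_channels=96,
    num_stem_blocks=2,
    num_branch_blocks=3)
backbone.eval()

with torch.no_grad():
    out = backbone(torch.randn(1, 3, 256, 256))

# A single tensor at output stride 8 with channels * 4 channels,
# i.e. (1, 128, 32, 32) for this configuration.
print(tuple(out.shape))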
318
finetune/mmseg/models/backbones/resnest.py
Normal file
@@ -0,0 +1,318 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import build_conv_layer, build_norm_layer
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import ResLayer
|
||||
from .resnet import Bottleneck as _Bottleneck
|
||||
from .resnet import ResNetV1d
|
||||
|
||||
|
||||
class RSoftmax(nn.Module):
|
||||
"""Radix Softmax module in ``SplitAttentionConv2d``.
|
||||
|
||||
Args:
|
||||
radix (int): Radix of input.
|
||||
groups (int): Groups of input.
|
||||
"""
|
||||
|
||||
def __init__(self, radix, groups):
|
||||
super().__init__()
|
||||
self.radix = radix
|
||||
self.groups = groups
|
||||
|
||||
def forward(self, x):
|
||||
batch = x.size(0)
|
||||
if self.radix > 1:
|
||||
x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2)
|
||||
x = F.softmax(x, dim=1)
|
||||
x = x.reshape(batch, -1)
|
||||
else:
|
||||
x = torch.sigmoid(x)
|
||||
return x
|
||||
|
||||
|
||||
class SplitAttentionConv2d(nn.Module):
|
||||
"""Split-Attention Conv2d in ResNeSt.
|
||||
|
||||
Args:
|
||||
in_channels (int): Same as nn.Conv2d.
|
||||
channels (int): Same as nn.Conv2d's out_channels.
|
||||
kernel_size (int | tuple[int]): Same as nn.Conv2d.
|
||||
stride (int | tuple[int]): Same as nn.Conv2d.
|
||||
padding (int | tuple[int]): Same as nn.Conv2d.
|
||||
dilation (int | tuple[int]): Same as nn.Conv2d.
|
||||
groups (int): Same as nn.Conv2d.
|
||||
radix (int): Radix of SplitAttentionConv2d. Default: 2
|
||||
reduction_factor (int): Reduction factor of inter_channels. Default: 4.
|
||||
conv_cfg (dict): Config dict for convolution layer. Default: None,
|
||||
which means using conv2d.
|
||||
norm_cfg (dict): Config dict for normalization layer. Default: None.
|
||||
dcn (dict): Config dict for DCN. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
radix=2,
|
||||
reduction_factor=4,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
dcn=None):
|
||||
super().__init__()
|
||||
inter_channels = max(in_channels * radix // reduction_factor, 32)
|
||||
self.radix = radix
|
||||
self.groups = groups
|
||||
self.channels = channels
|
||||
self.with_dcn = dcn is not None
|
||||
self.dcn = dcn
|
||||
fallback_on_stride = False
|
||||
if self.with_dcn:
|
||||
fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
|
||||
if self.with_dcn and not fallback_on_stride:
|
||||
assert conv_cfg is None, 'conv_cfg must be None for DCN'
|
||||
conv_cfg = dcn
|
||||
self.conv = build_conv_layer(
|
||||
conv_cfg,
|
||||
in_channels,
|
||||
channels * radix,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups * radix,
|
||||
bias=False)
|
||||
self.norm0_name, norm0 = build_norm_layer(
|
||||
norm_cfg, channels * radix, postfix=0)
|
||||
self.add_module(self.norm0_name, norm0)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.fc1 = build_conv_layer(
|
||||
None, channels, inter_channels, 1, groups=self.groups)
|
||||
self.norm1_name, norm1 = build_norm_layer(
|
||||
norm_cfg, inter_channels, postfix=1)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
self.fc2 = build_conv_layer(
|
||||
None, inter_channels, channels * radix, 1, groups=self.groups)
|
||||
self.rsoftmax = RSoftmax(radix, groups)
|
||||
|
||||
@property
|
||||
def norm0(self):
|
||||
"""nn.Module: the normalization layer named "norm0" """
|
||||
return getattr(self, self.norm0_name)
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
"""nn.Module: the normalization layer named "norm1" """
|
||||
return getattr(self, self.norm1_name)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.norm0(x)
|
||||
x = self.relu(x)
|
||||
|
||||
batch, rchannel = x.shape[:2]
|
||||
batch = x.size(0)
|
||||
if self.radix > 1:
|
||||
splits = x.view(batch, self.radix, -1, *x.shape[2:])
|
||||
gap = splits.sum(dim=1)
|
||||
else:
|
||||
gap = x
|
||||
gap = F.adaptive_avg_pool2d(gap, 1)
|
||||
gap = self.fc1(gap)
|
||||
|
||||
gap = self.norm1(gap)
|
||||
gap = self.relu(gap)
|
||||
|
||||
atten = self.fc2(gap)
|
||||
atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
|
||||
|
||||
if self.radix > 1:
|
||||
attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
|
||||
out = torch.sum(attens * splits, dim=1)
|
||||
else:
|
||||
out = atten * x
|
||||
return out.contiguous()
|
||||
|
||||
|
||||
class Bottleneck(_Bottleneck):
|
||||
"""Bottleneck block for ResNeSt.
|
||||
|
||||
Args:
|
||||
inplanes (int): Input planes of this block.
|
||||
planes (int): Middle planes of this block.
|
||||
groups (int): Groups of conv2.
|
||||
width_per_group (int): Width per group of conv2. 64x4d indicates
|
||||
``groups=64, width_per_group=4`` and 32x8d indicates
|
||||
``groups=32, width_per_group=8``.
|
||||
radix (int): Radix of SplitAttentionConv2d. Default: 2
|
||||
reduction_factor (int): Reduction factor of inter_channels in
|
||||
SplitAttentionConv2d. Default: 4.
|
||||
avg_down_stride (bool): Whether to use average pool for stride in
|
||||
Bottleneck. Default: True.
|
||||
kwargs (dict): Key word arguments for base class.
|
||||
"""
|
||||
expansion = 4
|
||||
|
||||
def __init__(self,
|
||||
inplanes,
|
||||
planes,
|
||||
groups=1,
|
||||
base_width=4,
|
||||
base_channels=64,
|
||||
radix=2,
|
||||
reduction_factor=4,
|
||||
avg_down_stride=True,
|
||||
**kwargs):
|
||||
"""Bottleneck block for ResNeSt."""
|
||||
super().__init__(inplanes, planes, **kwargs)
|
||||
|
||||
if groups == 1:
|
||||
width = self.planes
|
||||
else:
|
||||
width = math.floor(self.planes *
|
||||
(base_width / base_channels)) * groups
|
||||
|
||||
self.avg_down_stride = avg_down_stride and self.conv2_stride > 1
|
||||
|
||||
self.norm1_name, norm1 = build_norm_layer(
|
||||
self.norm_cfg, width, postfix=1)
|
||||
self.norm3_name, norm3 = build_norm_layer(
|
||||
self.norm_cfg, self.planes * self.expansion, postfix=3)
|
||||
|
||||
self.conv1 = build_conv_layer(
|
||||
self.conv_cfg,
|
||||
self.inplanes,
|
||||
width,
|
||||
kernel_size=1,
|
||||
stride=self.conv1_stride,
|
||||
bias=False)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
self.with_modulated_dcn = False
|
||||
self.conv2 = SplitAttentionConv2d(
|
||||
width,
|
||||
width,
|
||||
kernel_size=3,
|
||||
stride=1 if self.avg_down_stride else self.conv2_stride,
|
||||
padding=self.dilation,
|
||||
dilation=self.dilation,
|
||||
groups=groups,
|
||||
radix=radix,
|
||||
reduction_factor=reduction_factor,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
dcn=self.dcn)
|
||||
delattr(self, self.norm2_name)
|
||||
|
||||
if self.avg_down_stride:
|
||||
self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)
|
||||
|
||||
self.conv3 = build_conv_layer(
|
||||
self.conv_cfg,
|
||||
width,
|
||||
self.planes * self.expansion,
|
||||
kernel_size=1,
|
||||
bias=False)
|
||||
self.add_module(self.norm3_name, norm3)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
def _inner_forward(x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.norm1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
if self.with_plugins:
|
||||
out = self.forward_plugin(out, self.after_conv1_plugin_names)
|
||||
|
||||
out = self.conv2(out)
|
||||
|
||||
if self.avg_down_stride:
|
||||
out = self.avd_layer(out)
|
||||
|
||||
if self.with_plugins:
|
||||
out = self.forward_plugin(out, self.after_conv2_plugin_names)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.norm3(out)
|
||||
|
||||
if self.with_plugins:
|
||||
out = self.forward_plugin(out, self.after_conv3_plugin_names)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
|
||||
return out
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
out = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
out = _inner_forward(x)
|
||||
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ResNeSt(ResNetV1d):
|
||||
"""ResNeSt backbone.
|
||||
|
||||
This backbone is the implementation of `ResNeSt:
|
||||
Split-Attention Networks <https://arxiv.org/abs/2004.08955>`_.
|
||||
|
||||
Args:
|
||||
groups (int): Number of groups of Bottleneck. Default: 1
|
||||
base_width (int): Base width of Bottleneck. Default: 4
|
||||
radix (int): Radix of SplitAttentionConv2d. Default: 2
|
||||
reduction_factor (int): Reduction factor of inter_channels in
|
||||
SplitAttentionConv2d. Default: 4.
|
||||
avg_down_stride (bool): Whether to use average pool for stride in
|
||||
Bottleneck. Default: True.
|
||||
kwargs (dict): Keyword arguments for ResNet.
|
||||
"""
|
||||
|
||||
arch_settings = {
|
||||
50: (Bottleneck, (3, 4, 6, 3)),
|
||||
101: (Bottleneck, (3, 4, 23, 3)),
|
||||
152: (Bottleneck, (3, 8, 36, 3)),
|
||||
200: (Bottleneck, (3, 24, 36, 3))
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
groups=1,
|
||||
base_width=4,
|
||||
radix=2,
|
||||
reduction_factor=4,
|
||||
avg_down_stride=True,
|
||||
**kwargs):
|
||||
self.groups = groups
|
||||
self.base_width = base_width
|
||||
self.radix = radix
|
||||
self.reduction_factor = reduction_factor
|
||||
self.avg_down_stride = avg_down_stride
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def make_res_layer(self, **kwargs):
|
||||
"""Pack all blocks in a stage into a ``ResLayer``."""
|
||||
return ResLayer(
|
||||
groups=self.groups,
|
||||
base_width=self.base_width,
|
||||
base_channels=self.base_channels,
|
||||
radix=self.radix,
|
||||
reduction_factor=self.reduction_factor,
|
||||
avg_down_stride=self.avg_down_stride,
|
||||
**kwargs)
|
||||
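Finally, a usage sketch for ResNeSt, assuming mmseg's stock ResNet defaults (out_indices=(0, 1, 2, 3), output strides 4/8/16/32) and the upstream import path; depth=50 picks the (3, 4, 6, 3) stage layout from arch_settings above.

import torch

from mmseg.models.backbones import ResNeSt

backbone = ResNeSt(
    depth=50,
    radix=2,
    reduction_factor=4,
    avg_down_stride=True)
backbone.eval()

with torch.no_grad():
    outs = backbone(torch.randn(1, 3, 224, 224))

# With the default out_indices the four stages should come out roughly as
# (1, 256, 56, 56), (1, 512, 28, 28), (1, 1024, 14, 14), (1, 2048, 7, 7).
for out in outs:
    print(tuple(out.shape))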
712
finetune/mmseg/models/backbones/resnet.py
Normal file
@@ -0,0 +1,712 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import build_conv_layer, build_norm_layer, build_plugin_layer
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import ResLayer
|
||||
|
||||
|
||||
class BasicBlock(BaseModule):
|
||||
"""Basic block for ResNet."""
|
||||
|
||||
expansion = 1
|
||||
|
||||
def __init__(self,
|
||||
inplanes,
|
||||
planes,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
downsample=None,
|
||||
style='pytorch',
|
||||
with_cp=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
dcn=None,
|
||||
plugins=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
assert dcn is None, 'Not implemented yet.'
|
||||
assert plugins is None, 'Not implemented yet.'
|
||||
|
||||
self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
|
||||
self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
|
||||
|
||||
self.conv1 = build_conv_layer(
|
||||
conv_cfg,
|
||||
inplanes,
|
||||
planes,
|
||||
3,
|
||||
stride=stride,
|
||||
padding=dilation,
|
||||
dilation=dilation,
|
||||
bias=False)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
self.conv2 = build_conv_layer(
|
||||
conv_cfg, planes, planes, 3, padding=1, bias=False)
|
||||
self.add_module(self.norm2_name, norm2)
|
||||
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
self.dilation = dilation
|
||||
self.with_cp = with_cp
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
"""nn.Module: normalization layer after the first convolution layer"""
|
||||
return getattr(self, self.norm1_name)
|
||||
|
||||
@property
|
||||
def norm2(self):
|
||||
"""nn.Module: normalization layer after the second convolution layer"""
|
||||
return getattr(self, self.norm2_name)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
|
||||
def _inner_forward(x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.norm1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.norm2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
|
||||
return out
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
out = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
out = _inner_forward(x)
|
||||
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(BaseModule):
|
||||
"""Bottleneck block for ResNet.
|
||||
|
||||
If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is
|
||||
"caffe", the stride-two layer is the first 1x1 conv layer.
|
||||
"""
|
||||
|
||||
expansion = 4
|
||||
|
||||
def __init__(self,
|
||||
inplanes,
|
||||
planes,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
downsample=None,
|
||||
style='pytorch',
|
||||
with_cp=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
dcn=None,
|
||||
plugins=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
assert style in ['pytorch', 'caffe']
|
||||
assert dcn is None or isinstance(dcn, dict)
|
||||
assert plugins is None or isinstance(plugins, list)
|
||||
if plugins is not None:
|
||||
allowed_position = ['after_conv1', 'after_conv2', 'after_conv3']
|
||||
assert all(p['position'] in allowed_position for p in plugins)
|
||||
|
||||
self.inplanes = inplanes
|
||||
self.planes = planes
|
||||
self.stride = stride
|
||||
self.dilation = dilation
|
||||
self.style = style
|
||||
self.with_cp = with_cp
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.dcn = dcn
|
||||
self.with_dcn = dcn is not None
|
||||
self.plugins = plugins
|
||||
self.with_plugins = plugins is not None
|
||||
|
||||
if self.with_plugins:
|
||||
# collect plugins for conv1/conv2/conv3
|
||||
self.after_conv1_plugins = [
|
||||
plugin['cfg'] for plugin in plugins
|
||||
if plugin['position'] == 'after_conv1'
|
||||
]
|
||||
self.after_conv2_plugins = [
|
||||
plugin['cfg'] for plugin in plugins
|
||||
if plugin['position'] == 'after_conv2'
|
||||
]
|
||||
self.after_conv3_plugins = [
|
||||
plugin['cfg'] for plugin in plugins
|
||||
if plugin['position'] == 'after_conv3'
|
||||
]
|
||||
|
||||
if self.style == 'pytorch':
|
||||
self.conv1_stride = 1
|
||||
self.conv2_stride = stride
|
||||
else:
|
||||
self.conv1_stride = stride
|
||||
self.conv2_stride = 1
|
||||
|
||||
self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
|
||||
self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
|
||||
self.norm3_name, norm3 = build_norm_layer(
|
||||
norm_cfg, planes * self.expansion, postfix=3)
|
||||
|
||||
self.conv1 = build_conv_layer(
|
||||
conv_cfg,
|
||||
inplanes,
|
||||
planes,
|
||||
kernel_size=1,
|
||||
stride=self.conv1_stride,
|
||||
bias=False)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
fallback_on_stride = False
|
||||
if self.with_dcn:
|
||||
fallback_on_stride = dcn.pop('fallback_on_stride', False)
|
||||
if not self.with_dcn or fallback_on_stride:
|
||||
self.conv2 = build_conv_layer(
|
||||
conv_cfg,
|
||||
planes,
|
||||
planes,
|
||||
kernel_size=3,
|
||||
stride=self.conv2_stride,
|
||||
padding=dilation,
|
||||
dilation=dilation,
|
||||
bias=False)
|
||||
else:
|
||||
assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
|
||||
self.conv2 = build_conv_layer(
|
||||
dcn,
|
||||
planes,
|
||||
planes,
|
||||
kernel_size=3,
|
||||
stride=self.conv2_stride,
|
||||
padding=dilation,
|
||||
dilation=dilation,
|
||||
bias=False)
|
||||
|
||||
self.add_module(self.norm2_name, norm2)
|
||||
self.conv3 = build_conv_layer(
|
||||
conv_cfg,
|
||||
planes,
|
||||
planes * self.expansion,
|
||||
kernel_size=1,
|
||||
bias=False)
|
||||
self.add_module(self.norm3_name, norm3)
|
||||
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
|
||||
if self.with_plugins:
|
||||
self.after_conv1_plugin_names = self.make_block_plugins(
|
||||
planes, self.after_conv1_plugins)
|
||||
self.after_conv2_plugin_names = self.make_block_plugins(
|
||||
planes, self.after_conv2_plugins)
|
||||
self.after_conv3_plugin_names = self.make_block_plugins(
|
||||
planes * self.expansion, self.after_conv3_plugins)
|
||||
|
||||
def make_block_plugins(self, in_channels, plugins):
|
||||
"""make plugins for block.
|
||||
|
||||
Args:
|
||||
in_channels (int): Input channels of plugin.
|
||||
plugins (list[dict]): List of plugins cfg to build.
|
||||
|
||||
Returns:
|
||||
list[str]: List of the names of plugin.
|
||||
"""
|
||||
assert isinstance(plugins, list)
|
||||
plugin_names = []
|
||||
for plugin in plugins:
|
||||
plugin = plugin.copy()
|
||||
name, layer = build_plugin_layer(
|
||||
plugin,
|
||||
in_channels=in_channels,
|
||||
postfix=plugin.pop('postfix', ''))
|
||||
assert not hasattr(self, name), f'duplicate plugin {name}'
|
||||
self.add_module(name, layer)
|
||||
plugin_names.append(name)
|
||||
return plugin_names
|
||||
|
||||
def forward_plugin(self, x, plugin_names):
|
||||
"""Forward function for plugins."""
|
||||
out = x
|
||||
for name in plugin_names:
|
||||
out = getattr(self, name)(out)
|
||||
return out
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
"""nn.Module: normalization layer after the first convolution layer"""
|
||||
return getattr(self, self.norm1_name)
|
||||
|
||||
@property
|
||||
def norm2(self):
|
||||
"""nn.Module: normalization layer after the second convolution layer"""
|
||||
return getattr(self, self.norm2_name)
|
||||
|
||||
@property
|
||||
def norm3(self):
|
||||
"""nn.Module: normalization layer after the third convolution layer"""
|
||||
return getattr(self, self.norm3_name)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
|
||||
def _inner_forward(x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.norm1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
if self.with_plugins:
|
||||
out = self.forward_plugin(out, self.after_conv1_plugin_names)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.norm2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
if self.with_plugins:
|
||||
out = self.forward_plugin(out, self.after_conv2_plugin_names)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.norm3(out)
|
||||
|
||||
if self.with_plugins:
|
||||
out = self.forward_plugin(out, self.after_conv3_plugin_names)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
|
||||
return out
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
out = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
out = _inner_forward(x)
|
||||
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ResNet(BaseModule):
|
||||
"""ResNet backbone.
|
||||
|
||||
This backbone is the improved implementation of `Deep Residual Learning
|
||||
for Image Recognition <https://arxiv.org/abs/1512.03385>`_.
|
||||
|
||||
Args:
|
||||
depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
|
||||
in_channels (int): Number of input image channels. Default: 3.
|
||||
stem_channels (int): Number of stem channels. Default: 64.
|
||||
base_channels (int): Number of base channels of res layer. Default: 64.
|
||||
num_stages (int): Resnet stages, normally 4. Default: 4.
|
||||
strides (Sequence[int]): Strides of the first block of each stage.
|
||||
Default: (1, 2, 2, 2).
|
||||
dilations (Sequence[int]): Dilation of each stage.
|
||||
Default: (1, 1, 1, 1).
|
||||
out_indices (Sequence[int]): Output from which stages.
|
||||
Default: (0, 1, 2, 3).
|
||||
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
|
||||
layer is the 3x3 conv layer, otherwise the stride-two layer is
|
||||
the first 1x1 conv layer. Default: 'pytorch'.
|
||||
deep_stem (bool): Replace the 7x7 conv in the input stem with three 3x3 convs.
|
||||
Default: False.
|
||||
avg_down (bool): Use AvgPool instead of stride conv when
|
||||
downsampling in the bottleneck. Default: False.
|
||||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
||||
-1 means not freezing any parameters. Default: -1.
|
||||
conv_cfg (dict | None): Dictionary to construct and config conv layer.
|
||||
When conv_cfg is None, cfg will be set to dict(type='Conv2d').
|
||||
Default: None.
|
||||
norm_cfg (dict): Dictionary to construct and config norm layer.
|
||||
Default: dict(type='BN', requires_grad=True).
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Default: False.
|
||||
dcn (dict | None): Dictionary to construct and config DCN conv layer.
|
||||
When dcn is not None, conv_cfg must be None. Default: None.
|
||||
stage_with_dcn (Sequence[bool]): Whether to set DCN conv for each
|
||||
stage. The length of stage_with_dcn is equal to num_stages.
|
||||
Default: (False, False, False, False).
|
||||
plugins (list[dict]): List of plugins for stages, each dict contains:
|
||||
|
||||
- cfg (dict, required): Cfg dict to build plugin.
|
||||
|
||||
- position (str, required): Position inside block to insert plugin,
|
||||
options: 'after_conv1', 'after_conv2', 'after_conv3'.
|
||||
|
||||
- stages (tuple[bool], optional): Stages to apply plugin, length
|
||||
should be same as 'num_stages'.
|
||||
Default: None.
|
||||
multi_grid (Sequence[int]|None): Multi grid dilation rates of last
|
||||
stage. Default: None.
|
||||
contract_dilation (bool): Whether to contract the first dilation of each layer.
|
||||
Default: False.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
zero_init_residual (bool): Whether to use zero init for last norm layer
|
||||
in resblocks to let them behave as identity. Default: True.
|
||||
pretrained (str, optional): model pretrained path. Default: None.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
|
||||
Example:
|
||||
>>> from mmseg.models import ResNet
|
||||
>>> import torch
|
||||
>>> self = ResNet(depth=18)
|
||||
>>> self.eval()
|
||||
>>> inputs = torch.rand(1, 3, 32, 32)
|
||||
>>> level_outputs = self.forward(inputs)
|
||||
>>> for level_out in level_outputs:
|
||||
... print(tuple(level_out.shape))
|
||||
(1, 64, 8, 8)
|
||||
(1, 128, 4, 4)
|
||||
(1, 256, 2, 2)
|
||||
(1, 512, 1, 1)
|
||||
"""
|
||||
|
||||
arch_settings = {
|
||||
18: (BasicBlock, (2, 2, 2, 2)),
|
||||
34: (BasicBlock, (3, 4, 6, 3)),
|
||||
50: (Bottleneck, (3, 4, 6, 3)),
|
||||
101: (Bottleneck, (3, 4, 23, 3)),
|
||||
152: (Bottleneck, (3, 8, 36, 3))
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
depth,
|
||||
in_channels=3,
|
||||
stem_channels=64,
|
||||
base_channels=64,
|
||||
num_stages=4,
|
||||
strides=(1, 2, 2, 2),
|
||||
dilations=(1, 1, 1, 1),
|
||||
out_indices=(0, 1, 2, 3),
|
||||
style='pytorch',
|
||||
deep_stem=False,
|
||||
avg_down=False,
|
||||
frozen_stages=-1,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
norm_eval=False,
|
||||
dcn=None,
|
||||
stage_with_dcn=(False, False, False, False),
|
||||
plugins=None,
|
||||
multi_grid=None,
|
||||
contract_dilation=False,
|
||||
with_cp=False,
|
||||
zero_init_residual=True,
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
if depth not in self.arch_settings:
|
||||
raise KeyError(f'invalid depth {depth} for resnet')
|
||||
|
||||
self.pretrained = pretrained
|
||||
self.zero_init_residual = zero_init_residual
|
||||
block_init_cfg = None
|
||||
assert not (init_cfg and pretrained), \
|
||||
'init_cfg and pretrained cannot be setting at the same time'
|
||||
if isinstance(pretrained, str):
|
||||
warnings.warn('DeprecationWarning: pretrained is deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
elif pretrained is None:
|
||||
if init_cfg is None:
|
||||
self.init_cfg = [
|
||||
dict(type='Kaiming', layer='Conv2d'),
|
||||
dict(
|
||||
type='Constant',
|
||||
val=1,
|
||||
layer=['_BatchNorm', 'GroupNorm'])
|
||||
]
|
||||
block = self.arch_settings[depth][0]
|
||||
if self.zero_init_residual:
|
||||
if block is BasicBlock:
|
||||
block_init_cfg = dict(
|
||||
type='Constant',
|
||||
val=0,
|
||||
override=dict(name='norm2'))
|
||||
elif block is Bottleneck:
|
||||
block_init_cfg = dict(
|
||||
type='Constant',
|
||||
val=0,
|
||||
override=dict(name='norm3'))
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
self.depth = depth
|
||||
self.stem_channels = stem_channels
|
||||
self.base_channels = base_channels
|
||||
self.num_stages = num_stages
|
||||
assert num_stages >= 1 and num_stages <= 4
|
||||
self.strides = strides
|
||||
self.dilations = dilations
|
||||
assert len(strides) == len(dilations) == num_stages
|
||||
self.out_indices = out_indices
|
||||
assert max(out_indices) < num_stages
|
||||
self.style = style
|
||||
self.deep_stem = deep_stem
|
||||
self.avg_down = avg_down
|
||||
self.frozen_stages = frozen_stages
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.with_cp = with_cp
|
||||
self.norm_eval = norm_eval
|
||||
self.dcn = dcn
|
||||
self.stage_with_dcn = stage_with_dcn
|
||||
if dcn is not None:
|
||||
assert len(stage_with_dcn) == num_stages
|
||||
self.plugins = plugins
|
||||
self.multi_grid = multi_grid
|
||||
self.contract_dilation = contract_dilation
|
||||
self.block, stage_blocks = self.arch_settings[depth]
|
||||
self.stage_blocks = stage_blocks[:num_stages]
|
||||
self.inplanes = stem_channels
|
||||
|
||||
self._make_stem_layer(in_channels, stem_channels)
|
||||
|
||||
self.res_layers = []
|
||||
for i, num_blocks in enumerate(self.stage_blocks):
|
||||
stride = strides[i]
|
||||
dilation = dilations[i]
|
||||
dcn = self.dcn if self.stage_with_dcn[i] else None
|
||||
if plugins is not None:
|
||||
stage_plugins = self.make_stage_plugins(plugins, i)
|
||||
else:
|
||||
stage_plugins = None
|
||||
# multi grid is applied to last layer only
|
||||
stage_multi_grid = multi_grid if i == len(
|
||||
self.stage_blocks) - 1 else None
|
||||
planes = base_channels * 2**i
|
||||
res_layer = self.make_res_layer(
|
||||
block=self.block,
|
||||
inplanes=self.inplanes,
|
||||
planes=planes,
|
||||
num_blocks=num_blocks,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
style=self.style,
|
||||
avg_down=self.avg_down,
|
||||
with_cp=with_cp,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
dcn=dcn,
|
||||
plugins=stage_plugins,
|
||||
multi_grid=stage_multi_grid,
|
||||
contract_dilation=contract_dilation,
|
||||
init_cfg=block_init_cfg)
|
||||
self.inplanes = planes * self.block.expansion
|
||||
layer_name = f'layer{i+1}'
|
||||
self.add_module(layer_name, res_layer)
|
||||
self.res_layers.append(layer_name)
|
||||
|
||||
self._freeze_stages()
|
||||
|
||||
self.feat_dim = self.block.expansion * base_channels * 2**(
|
||||
len(self.stage_blocks) - 1)
|
||||
|
||||
def make_stage_plugins(self, plugins, stage_idx):
|
||||
"""make plugins for ResNet 'stage_idx'th stage .
|
||||
|
||||
Currently we support inserting 'context_block',
|
||||
'empirical_attention_block', 'nonlocal_block' into the backbone like
|
||||
ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of
|
||||
Bottleneck.
|
||||
|
||||
An example of plugins format could be :
|
||||
>>> plugins=[
|
||||
... dict(cfg=dict(type='xxx', arg1='xxx'),
|
||||
... stages=(False, True, True, True),
|
||||
... position='after_conv2'),
|
||||
... dict(cfg=dict(type='yyy'),
|
||||
... stages=(True, True, True, True),
|
||||
... position='after_conv3'),
|
||||
... dict(cfg=dict(type='zzz', postfix='1'),
|
||||
... stages=(True, True, True, True),
|
||||
... position='after_conv3'),
|
||||
... dict(cfg=dict(type='zzz', postfix='2'),
|
||||
... stages=(True, True, True, True),
|
||||
... position='after_conv3')
|
||||
... ]
|
||||
>>> self = ResNet(depth=18)
|
||||
>>> stage_plugins = self.make_stage_plugins(plugins, 0)
|
||||
>>> assert len(stage_plugins) == 3
|
||||
|
||||
Suppose 'stage_idx=0', the structure of blocks in the stage would be:
|
||||
conv1-> conv2->conv3->yyy->zzz1->zzz2
|
||||
Suppose 'stage_idx=1', the structure of blocks in the stage would be:
|
||||
conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2
|
||||
|
||||
If stages is missing, the plugin would be applied to all stages.
|
||||
|
||||
Args:
|
||||
plugins (list[dict]): List of plugins cfg to build. The postfix is
|
||||
required if multiple same type plugins are inserted.
|
||||
stage_idx (int): Index of stage to build
|
||||
|
||||
Returns:
|
||||
list[dict]: Plugins for current stage
|
||||
"""
|
||||
stage_plugins = []
|
||||
for plugin in plugins:
|
||||
plugin = plugin.copy()
|
||||
stages = plugin.pop('stages', None)
|
||||
assert stages is None or len(stages) == self.num_stages
|
||||
# whether to insert plugin into current stage
|
||||
if stages is None or stages[stage_idx]:
|
||||
stage_plugins.append(plugin)
|
||||
|
||||
return stage_plugins
|
||||
|
||||
def make_res_layer(self, **kwargs):
|
||||
"""Pack all blocks in a stage into a ``ResLayer``."""
|
||||
return ResLayer(**kwargs)
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
"""nn.Module: the normalization layer named "norm1" """
|
||||
return getattr(self, self.norm1_name)
|
||||
|
||||
def _make_stem_layer(self, in_channels, stem_channels):
|
||||
"""Make stem layer for ResNet."""
|
||||
if self.deep_stem:
|
||||
self.stem = nn.Sequential(
|
||||
build_conv_layer(
|
||||
self.conv_cfg,
|
||||
in_channels,
|
||||
stem_channels // 2,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
bias=False),
|
||||
build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
|
||||
nn.ReLU(inplace=True),
|
||||
build_conv_layer(
|
||||
self.conv_cfg,
|
||||
stem_channels // 2,
|
||||
stem_channels // 2,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=False),
|
||||
build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
|
||||
nn.ReLU(inplace=True),
|
||||
build_conv_layer(
|
||||
self.conv_cfg,
|
||||
stem_channels // 2,
|
||||
stem_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=False),
|
||||
build_norm_layer(self.norm_cfg, stem_channels)[1],
|
||||
nn.ReLU(inplace=True))
|
||||
else:
|
||||
self.conv1 = build_conv_layer(
|
||||
self.conv_cfg,
|
||||
in_channels,
|
||||
stem_channels,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
bias=False)
|
||||
self.norm1_name, norm1 = build_norm_layer(
|
||||
self.norm_cfg, stem_channels, postfix=1)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
def _freeze_stages(self):
|
||||
"""Freeze stages param and norm stats."""
|
||||
if self.frozen_stages >= 0:
|
||||
if self.deep_stem:
|
||||
self.stem.eval()
|
||||
for param in self.stem.parameters():
|
||||
param.requires_grad = False
|
||||
else:
|
||||
self.norm1.eval()
|
||||
for m in [self.conv1, self.norm1]:
|
||||
for param in m.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
for i in range(1, self.frozen_stages + 1):
|
||||
m = getattr(self, f'layer{i}')
|
||||
m.eval()
|
||||
for param in m.parameters():
|
||||
param.requires_grad = False
|
||||
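# Note (added for clarity, not in the original source): frozen_stages=-1
# freezes nothing; frozen_stages=0 freezes only the stem (conv1/norm1 or the
# deep stem); frozen_stages=k additionally freezes layer1..layerk, i.e. their
# parameters get requires_grad=False and the frozen modules are kept in
# eval() mode so BatchNorm statistics stop updating.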
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
if self.deep_stem:
|
||||
x = self.stem(x)
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
outs = []
|
||||
for i, layer_name in enumerate(self.res_layers):
|
||||
res_layer = getattr(self, layer_name)
|
||||
x = res_layer(x)
|
||||
if i in self.out_indices:
|
||||
outs.append(x)
|
||||
return tuple(outs)
|
||||
|
||||
def train(self, mode=True):
|
||||
"""Convert the model into training mode while keep normalization layer
|
||||
freezed."""
|
||||
super().train(mode)
|
||||
self._freeze_stages()
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
# trick: eval() has an effect on BatchNorm only
|
||||
if isinstance(m, _BatchNorm):
|
||||
m.eval()
|
||||
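# Worked example (added for clarity, not in the original file): with the
# default base_channels=64, stage i uses planes = 64 * 2**i. A depth-50
# backbone (Bottleneck, expansion=4) therefore outputs 256, 512, 1024 and
# 2048 channels per stage, and feat_dim = 4 * 64 * 2**3 = 2048; the docstring
# example above shows the 64/128/256/512 progression of the depth-18
# BasicBlock variant.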
|
||||
|
||||
@MODELS.register_module()
|
||||
class ResNetV1c(ResNet):
|
||||
"""ResNetV1c variant described in [1]_.
|
||||
|
||||
Compared with the default ResNet (ResNetV1b), ResNetV1c replaces the 7x7 conv in
|
||||
the input stem with three 3x3 convs. For more details please refer to `Bag
|
||||
of Tricks for Image Classification with Convolutional Neural Networks
|
||||
<https://arxiv.org/abs/1812.01187>`_.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(deep_stem=True, avg_down=False, **kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ResNetV1d(ResNet):
|
||||
"""ResNetV1d variant described in [1]_.
|
||||
|
||||
Compared with the default ResNet (ResNetV1b), ResNetV1d replaces the 7x7 conv in
|
||||
the input stem with three 3x3 convs. And in the downsampling block, a 2x2
|
||||
avg_pool with stride 2 is added before conv, whose stride is changed to 1.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(deep_stem=True, avg_down=True, **kwargs)
|
||||
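# --- Usage sketch (added note, not part of the original file) ---
# ResNetV1c and ResNetV1d only change the stem and downsampling behaviour of
# the base ResNet; a hedged example of selecting one of them by its registered
# name:
#
#     from mmseg.registry import MODELS
#     backbone = MODELS.build(dict(type='ResNetV1c', depth=50))  # deep stem
#     # or type='ResNetV1d' to also average-pool before the strided 1x1 conv
#     # in the downsample branch.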
150
finetune/mmseg/models/backbones/resnext.py
Normal file
@@ -0,0 +1,150 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
|
||||
from mmcv.cnn import build_conv_layer, build_norm_layer
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import ResLayer
|
||||
from .resnet import Bottleneck as _Bottleneck
|
||||
from .resnet import ResNet
|
||||
|
||||
|
||||
class Bottleneck(_Bottleneck):
|
||||
"""Bottleneck block for ResNeXt.
|
||||
|
||||
If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is
|
||||
"caffe", the stride-two layer is the first 1x1 conv layer.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
inplanes,
|
||||
planes,
|
||||
groups=1,
|
||||
base_width=4,
|
||||
base_channels=64,
|
||||
**kwargs):
|
||||
super().__init__(inplanes, planes, **kwargs)
|
||||
|
||||
if groups == 1:
|
||||
width = self.planes
|
||||
else:
|
||||
width = math.floor(self.planes *
|
||||
(base_width / base_channels)) * groups
|
||||
|
||||
self.norm1_name, norm1 = build_norm_layer(
|
||||
self.norm_cfg, width, postfix=1)
|
||||
self.norm2_name, norm2 = build_norm_layer(
|
||||
self.norm_cfg, width, postfix=2)
|
||||
self.norm3_name, norm3 = build_norm_layer(
|
||||
self.norm_cfg, self.planes * self.expansion, postfix=3)
|
||||
|
||||
self.conv1 = build_conv_layer(
|
||||
self.conv_cfg,
|
||||
self.inplanes,
|
||||
width,
|
||||
kernel_size=1,
|
||||
stride=self.conv1_stride,
|
||||
bias=False)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
fallback_on_stride = False
|
||||
self.with_modulated_dcn = False
|
||||
if self.with_dcn:
|
||||
fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
|
||||
if not self.with_dcn or fallback_on_stride:
|
||||
self.conv2 = build_conv_layer(
|
||||
self.conv_cfg,
|
||||
width,
|
||||
width,
|
||||
kernel_size=3,
|
||||
stride=self.conv2_stride,
|
||||
padding=self.dilation,
|
||||
dilation=self.dilation,
|
||||
groups=groups,
|
||||
bias=False)
|
||||
else:
|
||||
assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
|
||||
self.conv2 = build_conv_layer(
|
||||
self.dcn,
|
||||
width,
|
||||
width,
|
||||
kernel_size=3,
|
||||
stride=self.conv2_stride,
|
||||
padding=self.dilation,
|
||||
dilation=self.dilation,
|
||||
groups=groups,
|
||||
bias=False)
|
||||
|
||||
self.add_module(self.norm2_name, norm2)
|
||||
self.conv3 = build_conv_layer(
|
||||
self.conv_cfg,
|
||||
width,
|
||||
self.planes * self.expansion,
|
||||
kernel_size=1,
|
||||
bias=False)
|
||||
self.add_module(self.norm3_name, norm3)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ResNeXt(ResNet):
|
||||
"""ResNeXt backbone.
|
||||
|
||||
This backbone is the implementation of `Aggregated
|
||||
Residual Transformations for Deep Neural
|
||||
Networks <https://arxiv.org/abs/1611.05431>`_.
|
||||
|
||||
Args:
|
||||
depth (int): Depth of ResNeXt, from {50, 101, 152}.
|
||||
in_channels (int): Number of input image channels. Normally 3.
|
||||
num_stages (int): Resnet stages, normally 4.
|
||||
groups (int): Group of resnext.
|
||||
base_width (int): Base width of resnext.
|
||||
strides (Sequence[int]): Strides of the first block of each stage.
|
||||
dilations (Sequence[int]): Dilation of each stage.
|
||||
out_indices (Sequence[int]): Output from which stages.
|
||||
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
|
||||
layer is the 3x3 conv layer, otherwise the stride-two layer is
|
||||
the first 1x1 conv layer.
|
||||
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
|
||||
not freezing any parameters.
|
||||
norm_cfg (dict): dictionary to construct and config norm layer.
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed.
|
||||
zero_init_residual (bool): whether to use zero init for last norm layer
|
||||
in resblocks to let them behave as identity.
|
||||
|
||||
Example:
|
||||
>>> from mmseg.models import ResNeXt
|
||||
>>> import torch
|
||||
>>> self = ResNeXt(depth=50)
|
||||
>>> self.eval()
|
||||
>>> inputs = torch.rand(1, 3, 32, 32)
|
||||
>>> level_outputs = self.forward(inputs)
|
||||
>>> for level_out in level_outputs:
|
||||
... print(tuple(level_out.shape))
|
||||
(1, 256, 8, 8)
|
||||
(1, 512, 4, 4)
|
||||
(1, 1024, 2, 2)
|
||||
(1, 2048, 1, 1)
|
||||
"""
|
||||
|
||||
arch_settings = {
|
||||
50: (Bottleneck, (3, 4, 6, 3)),
|
||||
101: (Bottleneck, (3, 4, 23, 3)),
|
||||
152: (Bottleneck, (3, 8, 36, 3))
|
||||
}
|
||||
|
||||
def __init__(self, groups=1, base_width=4, **kwargs):
|
||||
self.groups = groups
|
||||
self.base_width = base_width
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def make_res_layer(self, **kwargs):
|
||||
"""Pack all blocks in a stage into a ``ResLayer``"""
|
||||
return ResLayer(
|
||||
groups=self.groups,
|
||||
base_width=self.base_width,
|
||||
base_channels=self.base_channels,
|
||||
**kwargs)
|
||||
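# Worked example (added for clarity, not in the original file): the grouped
# bottleneck width above is width = floor(planes * base_width / base_channels)
# * groups. For the common ResNeXt-50 32x4d setting (groups=32, base_width=4,
# base_channels=64), the first stage has planes=64, so
# width = floor(64 * 4 / 64) * 32 = 128 channels in the grouped conv2.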
422
finetune/mmseg/models/backbones/stdc.py
Normal file
@@ -0,0 +1,422 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
"""Modified from https://github.com/MichaelFan01/STDC-Seg."""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model import BaseModule, ModuleList, Sequential
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
from .bisenetv1 import AttentionRefinementModule
|
||||
|
||||
|
||||
class STDCModule(BaseModule):
|
||||
"""STDCModule.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channels.
|
||||
out_channels (int): The number of output channels before scaling.
|
||||
stride (int): The number of stride for the first conv layer.
|
||||
norm_cfg (dict): Config dict for normalization layer. Default: None.
|
||||
act_cfg (dict): The activation config for conv layers.
|
||||
num_convs (int): Numbers of conv layers.
|
||||
fusion_type (str): Type of fusion operation. Default: 'add'.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
stride,
|
||||
norm_cfg=None,
|
||||
act_cfg=None,
|
||||
num_convs=4,
|
||||
fusion_type='add',
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
assert num_convs > 1
|
||||
assert fusion_type in ['add', 'cat']
|
||||
self.stride = stride
|
||||
self.with_downsample = True if self.stride == 2 else False
|
||||
self.fusion_type = fusion_type
|
||||
|
||||
self.layers = ModuleList()
|
||||
conv_0 = ConvModule(
|
||||
in_channels, out_channels // 2, kernel_size=1, norm_cfg=norm_cfg)
|
||||
|
||||
if self.with_downsample:
|
||||
self.downsample = ConvModule(
|
||||
out_channels // 2,
|
||||
out_channels // 2,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
groups=out_channels // 2,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None)
|
||||
|
||||
if self.fusion_type == 'add':
|
||||
self.layers.append(nn.Sequential(conv_0, self.downsample))
|
||||
self.skip = Sequential(
|
||||
ConvModule(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
groups=in_channels,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None),
|
||||
ConvModule(
|
||||
in_channels,
|
||||
out_channels,
|
||||
1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None))
|
||||
else:
|
||||
self.layers.append(conv_0)
|
||||
self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
|
||||
else:
|
||||
self.layers.append(conv_0)
|
||||
|
||||
for i in range(1, num_convs):
|
||||
out_factor = 2**(i + 1) if i != num_convs - 1 else 2**i
|
||||
self.layers.append(
|
||||
ConvModule(
|
||||
out_channels // 2**i,
|
||||
out_channels // out_factor,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.fusion_type == 'add':
|
||||
out = self.forward_add(inputs)
|
||||
else:
|
||||
out = self.forward_cat(inputs)
|
||||
return out
|
||||
|
||||
def forward_add(self, inputs):
|
||||
layer_outputs = []
|
||||
x = inputs.clone()
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
layer_outputs.append(x)
|
||||
if self.with_downsample:
|
||||
inputs = self.skip(inputs)
|
||||
|
||||
return torch.cat(layer_outputs, dim=1) + inputs
|
||||
|
||||
def forward_cat(self, inputs):
|
||||
x0 = self.layers[0](inputs)
|
||||
layer_outputs = [x0]
|
||||
for i, layer in enumerate(self.layers[1:]):
|
||||
if i == 0:
|
||||
if self.with_downsample:
|
||||
x = layer(self.downsample(x0))
|
||||
else:
|
||||
x = layer(x0)
|
||||
else:
|
||||
x = layer(x)
|
||||
layer_outputs.append(x)
|
||||
if self.with_downsample:
|
||||
layer_outputs[0] = self.skip(x0)
|
||||
return torch.cat(layer_outputs, dim=1)
|
||||
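# Note (added for clarity, not in the original file): with the default
# num_convs=4 the branch widths produced above are out_channels//2,
# out_channels//4, out_channels//8 and out_channels//8, which sum back to
# out_channels, so both `forward_cat` and the residual add in `forward_add`
# keep the declared number of output channels.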
|
||||
|
||||
class FeatureFusionModule(BaseModule):
|
||||
"""Feature Fusion Module. This module is different from FeatureFusionModule
|
||||
in BiSeNetV1. It uses two ConvModules in `self.attention` whose inter
|
||||
channel number is calculated by given `scale_factor`, while
|
||||
FeatureFusionModule in BiSeNetV1 only uses one ConvModule in
|
||||
`self.conv_atten`.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channels.
|
||||
out_channels (int): The number of output channels.
|
||||
scale_factor (int): The number of channel scale factor.
|
||||
Default: 4.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): The activation config for conv layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
scale_factor=4,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
channels = out_channels // scale_factor
|
||||
self.conv0 = ConvModule(
|
||||
in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg)
|
||||
self.attention = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d((1, 1)),
|
||||
ConvModule(
|
||||
out_channels,
|
||||
channels,
|
||||
1,
|
||||
norm_cfg=None,
|
||||
bias=False,
|
||||
act_cfg=act_cfg),
|
||||
ConvModule(
|
||||
channels,
|
||||
out_channels,
|
||||
1,
|
||||
norm_cfg=None,
|
||||
bias=False,
|
||||
act_cfg=None), nn.Sigmoid())
|
||||
|
||||
def forward(self, spatial_inputs, context_inputs):
|
||||
inputs = torch.cat([spatial_inputs, context_inputs], dim=1)
|
||||
x = self.conv0(inputs)
|
||||
attn = self.attention(x)
|
||||
x_attn = x * attn
|
||||
return x_attn + x
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class STDCNet(BaseModule):
|
||||
"""This backbone is the implementation of `Rethinking BiSeNet For Real-time
|
||||
Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.
|
||||
|
||||
Args:
|
||||
stdc_type (str): The type of backbone structure,
|
||||
`STDCNet1` and `STDCNet2` denote the two main backbones in the paper,
|
||||
whose FLOPs are 813M and 1446M, respectively.
|
||||
in_channels (int): The num of input_channels.
|
||||
channels (tuple[int]): The output channels for each stage.
|
||||
bottleneck_type (str): The type of STDC Module type, the value must
|
||||
be 'add' or 'cat'.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
act_cfg (dict): The activation config for conv layers.
|
||||
num_convs (int): Numbers of conv layer at each STDC Module.
|
||||
Default: 4.
|
||||
with_final_conv (bool): Whether to add a conv layer at the module output.
|
||||
Default: False.
|
||||
pretrained (str, optional): Model pretrained path. Default: None.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
|
||||
Example:
|
||||
>>> import torch
|
||||
>>> stdc_type = 'STDCNet1'
|
||||
>>> in_channels = 3
|
||||
>>> channels = (32, 64, 256, 512, 1024)
|
||||
>>> bottleneck_type = 'cat'
|
||||
>>> inputs = torch.rand(1, 3, 1024, 2048)
|
||||
>>> self = STDCNet(stdc_type, in_channels,
|
||||
... channels, bottleneck_type).eval()
|
||||
>>> outputs = self.forward(inputs)
|
||||
>>> for i in range(len(outputs)):
|
||||
... print(f'outputs[{i}].shape = {outputs[i].shape}')
|
||||
outputs[0].shape = torch.Size([1, 256, 128, 256])
|
||||
outputs[1].shape = torch.Size([1, 512, 64, 128])
|
||||
outputs[2].shape = torch.Size([1, 1024, 32, 64])
|
||||
"""
|
||||
|
||||
arch_settings = {
|
||||
'STDCNet1': [(2, 1), (2, 1), (2, 1)],
|
||||
'STDCNet2': [(2, 1, 1, 1), (2, 1, 1, 1, 1), (2, 1, 1)]
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
stdc_type,
|
||||
in_channels,
|
||||
channels,
|
||||
bottleneck_type,
|
||||
norm_cfg,
|
||||
act_cfg,
|
||||
num_convs=4,
|
||||
with_final_conv=False,
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
assert stdc_type in self.arch_settings, \
|
||||
f'invalid structure {stdc_type} for STDCNet.'
|
||||
assert bottleneck_type in ['add', 'cat'],\
|
||||
f'bottleneck_type must be `add` or `cat`, got {bottleneck_type}'
|
||||
|
||||
assert len(channels) == 5,\
|
||||
f'invalid channels length {len(channels)} for STDCNet.'
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.channels = channels
|
||||
self.stage_strides = self.arch_settings[stdc_type]
|
||||
self.pretrained = pretrained
|
||||
self.num_convs = num_convs
|
||||
self.with_final_conv = with_final_conv
|
||||
|
||||
self.stages = ModuleList([
|
||||
ConvModule(
|
||||
self.in_channels,
|
||||
self.channels[0],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg),
|
||||
ConvModule(
|
||||
self.channels[0],
|
||||
self.channels[1],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
])
|
||||
# `self.num_shallow_features` is the number of shallow modules in
|
||||
# `STDCNet`, which is noted as `Stage1` and `Stage2` in original paper.
|
||||
# They are both not used for following modules like Attention
|
||||
# Refinement Module and Feature Fusion Module.
|
||||
# Thus they would be cut from `outs`. Please refer to Figure 4
|
||||
# of original paper for more details.
|
||||
self.num_shallow_features = len(self.stages)
|
||||
|
||||
for strides in self.stage_strides:
|
||||
idx = len(self.stages) - 1
|
||||
self.stages.append(
|
||||
self._make_stage(self.channels[idx], self.channels[idx + 1],
|
||||
strides, norm_cfg, act_cfg, bottleneck_type))
|
||||
# After appending, `self.stages` is a ModuleList including several
|
||||
# shallow modules and STDCModules.
|
||||
# (len(self.stages) ==
|
||||
# self.num_shallow_features + len(self.stage_strides))
|
||||
if self.with_final_conv:
|
||||
self.final_conv = ConvModule(
|
||||
self.channels[-1],
|
||||
max(1024, self.channels[-1]),
|
||||
1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
def _make_stage(self, in_channels, out_channels, strides, norm_cfg,
|
||||
act_cfg, bottleneck_type):
|
||||
layers = []
|
||||
for i, stride in enumerate(strides):
|
||||
layers.append(
|
||||
STDCModule(
|
||||
in_channels if i == 0 else out_channels,
|
||||
out_channels,
|
||||
stride,
|
||||
norm_cfg,
|
||||
act_cfg,
|
||||
num_convs=self.num_convs,
|
||||
fusion_type=bottleneck_type))
|
||||
return Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
outs = []
|
||||
for stage in self.stages:
|
||||
x = stage(x)
|
||||
outs.append(x)
|
||||
if self.with_final_conv:
|
||||
outs[-1] = self.final_conv(outs[-1])
|
||||
outs = outs[self.num_shallow_features:]
|
||||
return tuple(outs)
|
||||
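# Note (added for clarity, not in the original file): in `arch_settings` each
# inner tuple describes one stage and each entry is the stride of one
# STDCModule, e.g. 'STDCNet1' builds three stages of (2, 1): the leading
# stride-2 module halves the spatial resolution and the following stride-1
# modules refine features at that scale.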
|
||||
|
||||
@MODELS.register_module()
|
||||
class STDCContextPathNet(BaseModule):
|
||||
"""STDCNet with Context Path. The `outs` below is a list of three feature
|
||||
maps from deep to shallow, whose height and width is from small to big,
|
||||
respectively. The biggest feature map of `outs` is outputted for
|
||||
`STDCHead`, where Detail Loss would be calculated by Detail Ground-truth.
|
||||
The other two feature maps are used for Attention Refinement Module,
|
||||
respectively. Besides, the biggest feature map of `outs` and the last
|
||||
output of Attention Refinement Module are concatenated for Feature Fusion
|
||||
Module. Then, this fusion feature map `feat_fuse` would be outputted for
|
||||
`decode_head`. More details please refer to Figure 4 of original paper.
|
||||
|
||||
Args:
|
||||
backbone_cfg (dict): Config dict for stdc backbone.
|
||||
last_in_channels (tuple(int)): The number of channels of the last
|
||||
two feature maps from stdc backbone. Default: (1024, 512).
|
||||
out_channels (int): The channels of output feature maps.
|
||||
Default: 128.
|
||||
ffm_cfg (dict): Config dict for Feature Fusion Module. Default:
|
||||
`dict(in_channels=512, out_channels=256, scale_factor=4)`.
|
||||
upsample_mode (str): Algorithm used for upsampling:
|
||||
``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
|
||||
``'trilinear'``. Default: ``'nearest'``.
|
||||
align_corners (bool | None): align_corners argument of F.interpolate. It
|
||||
must be `None` if upsample_mode is ``'nearest'``. Default: None.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
|
||||
Return:
|
||||
outputs (tuple): The tuple of list of output feature map for
|
||||
auxiliary heads and decoder head.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
backbone_cfg,
|
||||
last_in_channels=(1024, 512),
|
||||
out_channels=128,
|
||||
ffm_cfg=dict(
|
||||
in_channels=512, out_channels=256, scale_factor=4),
|
||||
upsample_mode='nearest',
|
||||
align_corners=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.backbone = MODELS.build(backbone_cfg)
|
||||
self.arms = ModuleList()
|
||||
self.convs = ModuleList()
|
||||
for channels in last_in_channels:
|
||||
self.arms.append(AttentionRefinementModule(channels, out_channels))
|
||||
self.convs.append(
|
||||
ConvModule(
|
||||
out_channels,
|
||||
out_channels,
|
||||
3,
|
||||
padding=1,
|
||||
norm_cfg=norm_cfg))
|
||||
self.conv_avg = ConvModule(
|
||||
last_in_channels[0], out_channels, 1, norm_cfg=norm_cfg)
|
||||
|
||||
self.ffm = FeatureFusionModule(**ffm_cfg)
|
||||
|
||||
self.upsample_mode = upsample_mode
|
||||
self.align_corners = align_corners
|
||||
|
||||
def forward(self, x):
|
||||
outs = list(self.backbone(x))
|
||||
avg = F.adaptive_avg_pool2d(outs[-1], 1)
|
||||
avg_feat = self.conv_avg(avg)
|
||||
|
||||
feature_up = resize(
|
||||
avg_feat,
|
||||
size=outs[-1].shape[2:],
|
||||
mode=self.upsample_mode,
|
||||
align_corners=self.align_corners)
|
||||
arms_out = []
|
||||
for i in range(len(self.arms)):
|
||||
x_arm = self.arms[i](outs[len(outs) - 1 - i]) + feature_up
|
||||
feature_up = resize(
|
||||
x_arm,
|
||||
size=outs[len(outs) - 1 - i - 1].shape[2:],
|
||||
mode=self.upsample_mode,
|
||||
align_corners=self.align_corners)
|
||||
feature_up = self.convs[i](feature_up)
|
||||
arms_out.append(feature_up)
|
||||
|
||||
feat_fuse = self.ffm(outs[0], arms_out[1])
|
||||
|
||||
# The `outputs` has four feature maps.
|
||||
# `outs[0]` is outputted for `STDCHead` auxiliary head.
|
||||
# Two feature maps of `arms_out` are outputted for auxiliary head.
|
||||
# `feat_fuse` is outputted for decoder head.
|
||||
outputs = [outs[0]] + list(arms_out) + [feat_fuse]
|
||||
return tuple(outputs)
|
||||
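# --- Usage sketch (added note, not part of the original file) ---
# A hedged example of wrapping an STDCNet inside STDCContextPathNet through
# the registry; the channel numbers are assumptions chosen to match the
# defaults documented above, not a verified config:
#
#     from mmseg.registry import MODELS
#     backbone = MODELS.build(
#         dict(
#             type='STDCContextPathNet',
#             backbone_cfg=dict(
#                 type='STDCNet',
#                 stdc_type='STDCNet1',
#                 in_channels=3,
#                 channels=(32, 64, 256, 512, 1024),
#                 bottleneck_type='cat',
#                 norm_cfg=dict(type='BN'),
#                 act_cfg=dict(type='ReLU')),
#             last_in_channels=(1024, 512),
#             out_channels=128))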
757
finetune/mmseg/models/backbones/swin.py
Normal file
@@ -0,0 +1,757 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.transformer import FFN, build_dropout
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.model import BaseModule, ModuleList
|
||||
from mmengine.model.weight_init import (constant_init, trunc_normal_,
|
||||
trunc_normal_init)
|
||||
from mmengine.runner import CheckpointLoader
|
||||
from mmengine.utils import to_2tuple
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils.embed import PatchEmbed, PatchMerging
|
||||
|
||||
|
||||
class WindowMSA(BaseModule):
|
||||
"""Window based multi-head self-attention (W-MSA) module with relative
|
||||
position bias.
|
||||
|
||||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (tuple[int]): The height and width of the window.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
|
||||
Default: True.
|
||||
qk_scale (float | None, optional): Override default qk scale of
|
||||
head_dim ** -0.5 if set. Default: None.
|
||||
attn_drop_rate (float, optional): Dropout ratio of attention weight.
|
||||
Default: 0.0
|
||||
proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
|
||||
init_cfg (dict | None, optional): The Config for initialization.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
window_size,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
attn_drop_rate=0.,
|
||||
proj_drop_rate=0.,
|
||||
init_cfg=None):
|
||||
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.embed_dims = embed_dims
|
||||
self.window_size = window_size # Wh, Ww
|
||||
self.num_heads = num_heads
|
||||
head_embed_dims = embed_dims // num_heads
|
||||
self.scale = qk_scale or head_embed_dims**-0.5
|
||||
|
||||
# define a parameter table of relative position bias
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
|
||||
num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
||||
|
||||
# About 2x faster than original impl
|
||||
Wh, Ww = self.window_size
|
||||
rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
|
||||
rel_position_index = rel_index_coords + rel_index_coords.T
|
||||
rel_position_index = rel_position_index.flip(1).contiguous()
|
||||
self.register_buffer('relative_position_index', rel_position_index)
|
||||
|
||||
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop_rate)
|
||||
self.proj = nn.Linear(embed_dims, embed_dims)
|
||||
self.proj_drop = nn.Dropout(proj_drop_rate)
|
||||
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def init_weights(self):
|
||||
trunc_normal_(self.relative_position_bias_table, std=0.02)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
"""
|
||||
Args:
|
||||
|
||||
x (tensor): input features with shape of (num_windows*B, N, C)
|
||||
mask (tensor | None, Optional): mask with shape of (num_windows,
|
||||
Wh*Ww, Wh*Ww), value should be between (-inf, 0].
|
||||
"""
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
|
||||
C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
# make torchscript happy (cannot use tensor as tuple)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
q = q * self.scale
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
|
||||
relative_position_bias = self.relative_position_bias_table[
|
||||
self.relative_position_index.view(-1)].view(
|
||||
self.window_size[0] * self.window_size[1],
|
||||
self.window_size[0] * self.window_size[1],
|
||||
-1) # Wh*Ww,Wh*Ww,nH
|
||||
relative_position_bias = relative_position_bias.permute(
|
||||
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
|
||||
if mask is not None:
|
||||
nW = mask.shape[0]
|
||||
attn = attn.view(B // nW, nW, self.num_heads, N,
|
||||
N) + mask.unsqueeze(1).unsqueeze(0)
|
||||
attn = attn.view(-1, self.num_heads, N, N)
|
||||
attn = self.softmax(attn)
|
||||
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def double_step_seq(step1, len1, step2, len2):
|
||||
seq1 = torch.arange(0, step1 * len1, step1)
|
||||
seq2 = torch.arange(0, step2 * len2, step2)
|
||||
return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
|
||||
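# Note (added for clarity, not in the original file): the bias table above has
# (2*Wh - 1) * (2*Ww - 1) rows, one per possible relative offset inside a
# window, and `relative_position_index` maps each of the (Wh*Ww) x (Wh*Ww)
# query/key pairs to one of those rows; for the usual 7x7 window that is a
# 169-entry table indexed by a 49x49 matrix.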
|
||||
|
||||
class ShiftWindowMSA(BaseModule):
|
||||
"""Shifted Window Multihead Self-Attention Module.
|
||||
|
||||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): The height and width of the window.
|
||||
shift_size (int, optional): The shift step of each window towards
|
||||
right-bottom. If zero, act as regular window-msa. Defaults to 0.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
|
||||
Default: True
|
||||
qk_scale (float | None, optional): Override default qk scale of
|
||||
head_dim ** -0.5 if set. Defaults: None.
|
||||
attn_drop_rate (float, optional): Dropout ratio of attention weight.
|
||||
Defaults: 0.
|
||||
proj_drop_rate (float, optional): Dropout ratio of output.
|
||||
Defaults: 0.
|
||||
dropout_layer (dict, optional): The dropout_layer used before output.
|
||||
Defaults: dict(type='DropPath', drop_prob=0.).
|
||||
init_cfg (dict, optional): The extra config for initialization.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
window_size,
|
||||
shift_size=0,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
attn_drop_rate=0,
|
||||
proj_drop_rate=0,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=0.),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
self.window_size = window_size
|
||||
self.shift_size = shift_size
|
||||
assert 0 <= self.shift_size < self.window_size
|
||||
|
||||
self.w_msa = WindowMSA(
|
||||
embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
window_size=to_2tuple(window_size),
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
proj_drop_rate=proj_drop_rate,
|
||||
init_cfg=None)
|
||||
|
||||
self.drop = build_dropout(dropout_layer)
|
||||
|
||||
def forward(self, query, hw_shape):
|
||||
B, L, C = query.shape
|
||||
H, W = hw_shape
|
||||
assert L == H * W, 'input feature has wrong size'
|
||||
query = query.view(B, H, W, C)
|
||||
|
||||
# pad feature maps to multiples of window size
|
||||
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
||||
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
||||
query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))
|
||||
H_pad, W_pad = query.shape[1], query.shape[2]
|
||||
|
||||
# cyclic shift
|
||||
if self.shift_size > 0:
|
||||
shifted_query = torch.roll(
|
||||
query,
|
||||
shifts=(-self.shift_size, -self.shift_size),
|
||||
dims=(1, 2))
|
||||
|
||||
# calculate attention mask for SW-MSA
|
||||
img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device)
|
||||
h_slices = (slice(0, -self.window_size),
|
||||
slice(-self.window_size,
|
||||
-self.shift_size), slice(-self.shift_size, None))
|
||||
w_slices = (slice(0, -self.window_size),
|
||||
slice(-self.window_size,
|
||||
-self.shift_size), slice(-self.shift_size, None))
|
||||
cnt = 0
|
||||
for h in h_slices:
|
||||
for w in w_slices:
|
||||
img_mask[:, h, w, :] = cnt
|
||||
cnt += 1
|
||||
|
||||
# nW, window_size, window_size, 1
|
||||
mask_windows = self.window_partition(img_mask)
|
||||
mask_windows = mask_windows.view(
|
||||
-1, self.window_size * self.window_size)
|
||||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||||
attn_mask = attn_mask.masked_fill(attn_mask != 0,
|
||||
float(-100.0)).masked_fill(
|
||||
attn_mask == 0, float(0.0))
|
||||
else:
|
||||
shifted_query = query
|
||||
attn_mask = None
|
||||
|
||||
# nW*B, window_size, window_size, C
|
||||
query_windows = self.window_partition(shifted_query)
|
||||
# nW*B, window_size*window_size, C
|
||||
query_windows = query_windows.view(-1, self.window_size**2, C)
|
||||
|
||||
# W-MSA/SW-MSA (nW*B, window_size*window_size, C)
|
||||
attn_windows = self.w_msa(query_windows, mask=attn_mask)
|
||||
|
||||
# merge windows
|
||||
attn_windows = attn_windows.view(-1, self.window_size,
|
||||
self.window_size, C)
|
||||
|
||||
# B H' W' C
|
||||
shifted_x = self.window_reverse(attn_windows, H_pad, W_pad)
|
||||
# reverse cyclic shift
|
||||
if self.shift_size > 0:
|
||||
x = torch.roll(
|
||||
shifted_x,
|
||||
shifts=(self.shift_size, self.shift_size),
|
||||
dims=(1, 2))
|
||||
else:
|
||||
x = shifted_x
|
||||
|
||||
if pad_r > 0 or pad_b:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
|
||||
x = x.view(B, H * W, C)
|
||||
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
def window_reverse(self, windows, H, W):
|
||||
"""
|
||||
Args:
|
||||
windows: (num_windows*B, window_size, window_size, C)
|
||||
H (int): Height of image
|
||||
W (int): Width of image
|
||||
Returns:
|
||||
x: (B, H, W, C)
|
||||
"""
|
||||
window_size = self.window_size
|
||||
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
||||
x = windows.view(B, H // window_size, W // window_size, window_size,
|
||||
window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
||||
return x
|
||||
|
||||
def window_partition(self, x):
|
||||
"""
|
||||
Args:
|
||||
x: (B, H, W, C)
|
||||
Returns:
|
||||
windows: (num_windows*B, window_size, window_size, C)
|
||||
"""
|
||||
B, H, W, C = x.shape
|
||||
window_size = self.window_size
|
||||
x = x.view(B, H // window_size, window_size, W // window_size,
|
||||
window_size, C)
|
||||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
|
||||
windows = windows.view(-1, window_size, window_size, C)
|
||||
return windows
|
||||
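# Note (added for clarity, not in the original file): `window_partition` maps
# (B, H, W, C) to (num_windows * B, window_size, window_size, C) with
# num_windows = (H // window_size) * (W // window_size); `window_reverse` is
# its inverse, so attention is computed independently per window and then
# stitched back into the padded feature map.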
|
||||
|
||||
class SwinBlock(BaseModule):
|
||||
""""
|
||||
Args:
|
||||
embed_dims (int): The feature dimension.
|
||||
num_heads (int): Parallel attention heads.
|
||||
feedforward_channels (int): The hidden dimension for FFNs.
|
||||
window_size (int, optional): The local window scale. Default: 7.
|
||||
shift (bool, optional): Whether to shift the window or not. Default: False.
|
||||
qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
|
||||
qk_scale (float | None, optional): Override default qk scale of
|
||||
head_dim ** -0.5 if set. Default: None.
|
||||
drop_rate (float, optional): Dropout rate. Default: 0.
|
||||
attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
|
||||
drop_path_rate (float, optional): Stochastic depth rate. Default: 0.
|
||||
act_cfg (dict, optional): The config dict of activation function.
|
||||
Default: dict(type='GELU').
|
||||
norm_cfg (dict, optional): The config dict of normalization.
|
||||
Default: dict(type='LN').
|
||||
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
|
||||
will save some memory while slowing down the training speed.
|
||||
Default: False.
|
||||
init_cfg (dict | list | None, optional): The init config.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
feedforward_channels,
|
||||
window_size=7,
|
||||
shift=False,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN'),
                 with_cp=False,
                 init_cfg=None):

        super().__init__(init_cfg=init_cfg)

        self.with_cp = with_cp

        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.attn = ShiftWindowMSA(
            embed_dims=embed_dims,
            num_heads=num_heads,
            window_size=window_size,
            shift_size=window_size // 2 if shift else 0,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop_rate=attn_drop_rate,
            proj_drop_rate=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            init_cfg=None)

        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=2,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg,
            add_identity=True,
            init_cfg=None)

    def forward(self, x, hw_shape):

        def _inner_forward(x):
            identity = x
            x = self.norm1(x)
            x = self.attn(x, hw_shape)

            x = x + identity

            identity = x
            x = self.norm2(x)
            x = self.ffn(x, identity=identity)

            return x

        if self.with_cp and x.requires_grad:
            x = cp.checkpoint(_inner_forward, x)
        else:
            x = _inner_forward(x)

        return x


class SwinBlockSequence(BaseModule):
    """Implements one stage in Swin Transformer.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        depth (int): The number of blocks in this stage.
        window_size (int, optional): The local window scale. Default: 7.
        qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        drop_rate (float, optional): Dropout rate. Default: 0.
        attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
        drop_path_rate (float | list[float], optional): Stochastic depth
            rate. Default: 0.
        downsample (BaseModule | None, optional): The downsample operation
            module. Default: None.
        act_cfg (dict, optional): The config dict of activation function.
            Default: dict(type='GELU').
        norm_cfg (dict, optional): The config dict of normalization.
            Default: dict(type='LN').
        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
            will save some memory while slowing down the training speed.
            Default: False.
        init_cfg (dict | list | None, optional): The init config.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 depth,
                 window_size=7,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 downsample=None,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 with_cp=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        if isinstance(drop_path_rate, list):
            drop_path_rates = drop_path_rate
            assert len(drop_path_rates) == depth
        else:
            drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)]

        self.blocks = ModuleList()
        for i in range(depth):
            block = SwinBlock(
                embed_dims=embed_dims,
                num_heads=num_heads,
                feedforward_channels=feedforward_channels,
                window_size=window_size,
                shift=False if i % 2 == 0 else True,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop_rate=drop_rate,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=drop_path_rates[i],
                act_cfg=act_cfg,
                norm_cfg=norm_cfg,
                with_cp=with_cp,
                init_cfg=None)
            self.blocks.append(block)

        self.downsample = downsample

    def forward(self, x, hw_shape):
        for block in self.blocks:
            x = block(x, hw_shape)

        if self.downsample:
            x_down, down_hw_shape = self.downsample(x, hw_shape)
            return x_down, down_hw_shape, x, hw_shape
        else:
            return x, hw_shape, x, hw_shape


@MODELS.register_module()
class SwinTransformer(BaseModule):
    """Swin Transformer backbone.

    This backbone is the implementation of `Swin Transformer:
    Hierarchical Vision Transformer using Shifted
    Windows <https://arxiv.org/abs/2103.14030>`_.
    Inspiration from https://github.com/microsoft/Swin-Transformer.

    Args:
        pretrain_img_size (int | tuple[int]): The size of input image when
            pretrain. Defaults: 224.
        in_channels (int): The num of input channels.
            Defaults: 3.
        embed_dims (int): The feature dimension. Default: 96.
        patch_size (int | tuple[int]): Patch size. Default: 4.
        window_size (int): Window size. Default: 7.
        mlp_ratio (int | float): Ratio of mlp hidden dim to embedding dim.
            Default: 4.
        depths (tuple[int]): Depths of each Swin Transformer stage.
            Default: (2, 2, 6, 2).
        num_heads (tuple[int]): Parallel attention heads of each Swin
            Transformer stage. Default: (3, 6, 12, 24).
        strides (tuple[int]): The patch merging or patch embedding stride of
            each Swin Transformer stage. (In swin, we set kernel size equal to
            stride.) Default: (4, 2, 2, 2).
        out_indices (tuple[int]): Output from which stages.
            Default: (0, 1, 2, 3).
        qkv_bias (bool, optional): If True, add a learnable bias to query, key,
            value. Default: True
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        patch_norm (bool): If add a norm layer for patch embed and patch
            merging. Default: True.
        drop_rate (float): Dropout rate. Defaults: 0.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Defaults: 0.1.
        use_abs_pos_embed (bool): If True, add absolute position embedding to
            the patch embedding. Defaults: False.
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer at
            output of backbone. Defaults: dict(type='LN').
        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
            will save some memory while slowing down the training speed.
            Default: False.
        pretrained (str, optional): model pretrained path. Default: None.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 pretrain_img_size=224,
                 in_channels=3,
                 embed_dims=96,
                 patch_size=4,
                 window_size=7,
                 mlp_ratio=4,
                 depths=(2, 2, 6, 2),
                 num_heads=(3, 6, 12, 24),
                 strides=(4, 2, 2, 2),
                 out_indices=(0, 1, 2, 3),
                 qkv_bias=True,
                 qk_scale=None,
                 patch_norm=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 use_abs_pos_embed=False,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 with_cp=False,
                 pretrained=None,
                 frozen_stages=-1,
                 init_cfg=None):
        self.frozen_stages = frozen_stages

        if isinstance(pretrain_img_size, int):
            pretrain_img_size = to_2tuple(pretrain_img_size)
        elif isinstance(pretrain_img_size, tuple):
            if len(pretrain_img_size) == 1:
                pretrain_img_size = to_2tuple(pretrain_img_size[0])
            assert len(pretrain_img_size) == 2, \
                f'The size of image should have length 1 or 2, ' \
                f'but got {len(pretrain_img_size)}'

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be specified at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is None:
            init_cfg = init_cfg
        else:
            raise TypeError('pretrained must be a str or None')

        super().__init__(init_cfg=init_cfg)

        num_layers = len(depths)
        self.out_indices = out_indices
        self.use_abs_pos_embed = use_abs_pos_embed

        assert strides[0] == patch_size, 'Use non-overlapping patch embed.'

        self.patch_embed = PatchEmbed(
            in_channels=in_channels,
            embed_dims=embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size,
            stride=strides[0],
            padding='corner',
            norm_cfg=norm_cfg if patch_norm else None,
            init_cfg=None)

        if self.use_abs_pos_embed:
            patch_row = pretrain_img_size[0] // patch_size
            patch_col = pretrain_img_size[1] // patch_size
            num_patches = patch_row * patch_col
            self.absolute_pos_embed = nn.Parameter(
                torch.zeros((1, num_patches, embed_dims)))

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        # set stochastic depth decay rule
        total_depth = sum(depths)
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
        ]

        self.stages = ModuleList()
        in_channels = embed_dims
        for i in range(num_layers):
            if i < num_layers - 1:
                downsample = PatchMerging(
                    in_channels=in_channels,
                    out_channels=2 * in_channels,
                    stride=strides[i + 1],
                    norm_cfg=norm_cfg if patch_norm else None,
                    init_cfg=None)
            else:
                downsample = None

            stage = SwinBlockSequence(
                embed_dims=in_channels,
                num_heads=num_heads[i],
                feedforward_channels=int(mlp_ratio * in_channels),
                depth=depths[i],
                window_size=window_size,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop_rate=drop_rate,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=dpr[sum(depths[:i]):sum(depths[:i + 1])],
                downsample=downsample,
                act_cfg=act_cfg,
                norm_cfg=norm_cfg,
                with_cp=with_cp,
                init_cfg=None)
            self.stages.append(stage)
            if downsample:
                in_channels = downsample.out_channels

        self.num_features = [int(embed_dims * 2**i) for i in range(num_layers)]
        # Add a norm layer for each output
        for i in out_indices:
            layer = build_norm_layer(norm_cfg, self.num_features[i])[1]
            layer_name = f'norm{i}'
            self.add_module(layer_name, layer)

    def train(self, mode=True):
        """Convert the model into training mode while keeping layers frozen."""
        super().train(mode)
        self._freeze_stages()

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False
            if self.use_abs_pos_embed:
                self.absolute_pos_embed.requires_grad = False
            self.drop_after_pos.eval()

        for i in range(1, self.frozen_stages + 1):

            if (i - 1) in self.out_indices:
                norm_layer = getattr(self, f'norm{i-1}')
                norm_layer.eval()
                for param in norm_layer.parameters():
                    param.requires_grad = False

            m = self.stages[i - 1]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self):
        if self.init_cfg is None:
            print_log(f'No pre-trained weights for '
                      f'{self.__class__.__name__}, '
                      f'training start from scratch')
            if self.use_abs_pos_embed:
                trunc_normal_(self.absolute_pos_embed, std=0.02)
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, val=1.0, bias=0.)
        else:
            assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                                  f'specify `Pretrained` in ' \
                                                  f'`init_cfg` in ' \
                                                  f'{self.__class__.__name__} '
            ckpt = CheckpointLoader.load_checkpoint(
                self.init_cfg['checkpoint'], logger=None, map_location='cpu')
            if 'state_dict' in ckpt:
                _state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                _state_dict = ckpt['model']
            else:
                _state_dict = ckpt

            state_dict = OrderedDict()
            for k, v in _state_dict.items():
                if k.startswith('backbone.'):
                    state_dict[k[9:]] = v
                else:
                    state_dict[k] = v

            # strip prefix of state_dict
            if list(state_dict.keys())[0].startswith('module.'):
                state_dict = {k[7:]: v for k, v in state_dict.items()}

            # reshape absolute position embedding
            if state_dict.get('absolute_pos_embed') is not None:
                absolute_pos_embed = state_dict['absolute_pos_embed']
                N1, L, C1 = absolute_pos_embed.size()
                N2, C2, H, W = self.absolute_pos_embed.size()
                if N1 != N2 or C1 != C2 or L != H * W:
                    print_log('Error in loading absolute_pos_embed, pass')
                else:
                    state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
                        N2, H, W, C2).permute(0, 3, 1, 2).contiguous()

            # interpolate position bias table if needed
            relative_position_bias_table_keys = [
                k for k in state_dict.keys()
                if 'relative_position_bias_table' in k
            ]
            for table_key in relative_position_bias_table_keys:
                table_pretrained = state_dict[table_key]
                if table_key in self.state_dict():
                    table_current = self.state_dict()[table_key]
                    L1, nH1 = table_pretrained.size()
                    L2, nH2 = table_current.size()
                    if nH1 != nH2:
                        print_log(f'Error in loading {table_key}, pass')
                    elif L1 != L2:
                        S1 = int(L1**0.5)
                        S2 = int(L2**0.5)
                        table_pretrained_resized = F.interpolate(
                            table_pretrained.permute(1, 0).reshape(
                                1, nH1, S1, S1),
                            size=(S2, S2),
                            mode='bicubic')
                        state_dict[table_key] = table_pretrained_resized.view(
                            nH2, L2).permute(1, 0).contiguous()

            # load state_dict
            self.load_state_dict(state_dict, strict=False)

    def forward(self, x):
        x, hw_shape = self.patch_embed(x)

        if self.use_abs_pos_embed:
            x = x + self.absolute_pos_embed
        x = self.drop_after_pos(x)

        outs = []
        for i, stage in enumerate(self.stages):
            x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                out = norm_layer(out)
                out = out.view(-1, *out_hw_shape,
                               self.num_features[i]).permute(0, 3, 1,
                                                             2).contiguous()
                outs.append(out)

        return outs
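As a quick illustration of how the stages above compose, the following is a minimal usage sketch (not part of the committed file) assuming the default Swin-T configuration; it only checks the multi-scale output shapes.

# Illustrative sketch: with the default Swin-T settings, the backbone returns
# one feature map per entry in out_indices, with embed_dims * 2**i channels
# at stride 4 * 2**i.
import torch

backbone = SwinTransformer(pretrain_img_size=224, embed_dims=96,
                           depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24))
backbone.init_weights()
backbone.eval()
with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))
for i, feat in enumerate(feats):
    # expected: (1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)
    print(i, tuple(feat.shape))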
63
finetune/mmseg/models/backbones/timm_backbone.py
Normal file
63
finetune/mmseg/models/backbones/timm_backbone.py
Normal file
@@ -0,0 +1,63 @@
# Copyright (c) OpenMMLab. All rights reserved.
try:
    import timm
except ImportError:
    timm = None

from mmengine.model import BaseModule
from mmengine.registry import MODELS as MMENGINE_MODELS

from mmseg.registry import MODELS


@MODELS.register_module()
class TIMMBackbone(BaseModule):
    """Wrapper to use backbones from timm library. More details can be found in
    `timm <https://github.com/rwightman/pytorch-image-models>`_ .

    Args:
        model_name (str): Name of timm model to instantiate.
        pretrained (bool): Load pretrained weights if True.
        checkpoint_path (str): Path of checkpoint to load after
            model is initialized.
        in_channels (int): Number of input image channels. Default: 3.
        init_cfg (dict, optional): Initialization config dict
        **kwargs: Other timm & model specific arguments.
    """

    def __init__(
        self,
        model_name,
        features_only=True,
        pretrained=True,
        checkpoint_path='',
        in_channels=3,
        init_cfg=None,
        **kwargs,
    ):
        if timm is None:
            raise RuntimeError('timm is not installed')
        super().__init__(init_cfg)
        if 'norm_layer' in kwargs:
            kwargs['norm_layer'] = MMENGINE_MODELS.get(kwargs['norm_layer'])
        self.timm_model = timm.create_model(
            model_name=model_name,
            features_only=features_only,
            pretrained=pretrained,
            in_chans=in_channels,
            checkpoint_path=checkpoint_path,
            **kwargs,
        )

        # Make unused parameters None
        self.timm_model.global_pool = None
        self.timm_model.fc = None
        self.timm_model.classifier = None

        # Hack to use pretrained weights from timm
        if pretrained or checkpoint_path:
            self._is_init = True

    def forward(self, x):
        features = self.timm_model(x)
        return features
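A minimal usage sketch for the wrapper above (illustrative only; it assumes the timm package is installed and that the chosen model name exists in the local timm version).

# Illustrative sketch: features_only=True makes timm return a list of pyramid
# feature maps (e.g. strides 2/4/8/16/32 for a resnet18 trunk).
import torch

backbone = TIMMBackbone(model_name='resnet18', features_only=True,
                        pretrained=False, in_channels=3)
feats = backbone(torch.randn(1, 3, 224, 224))
print([tuple(f.shape) for f in feats])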
588
finetune/mmseg/models/backbones/twins.py
Normal file
588
finetune/mmseg/models/backbones/twins.py
Normal file
@@ -0,0 +1,588 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import FFN
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import (constant_init, normal_init,
                                        trunc_normal_init)
from torch.nn.modules.batchnorm import _BatchNorm

from mmseg.models.backbones.mit import EfficientMultiheadAttention
from mmseg.registry import MODELS
from ..utils.embed import PatchEmbed


class GlobalSubsampledAttention(EfficientMultiheadAttention):
    """Global Sub-sampled Attention (Spatial Reduction Attention)

    This module is modified from EfficientMultiheadAttention,
    which is a module from mmseg.models.backbones.mit.py.
    Specifically, there is no difference between
    `GlobalSubsampledAttention` and `EfficientMultiheadAttention`;
    `GlobalSubsampledAttention` is built as a brand new class
    because it is renamed as `Global sub-sampled attention (GSA)`
    in the paper.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Default: 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut. Default: None.
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dims)
            or (n, batch, embed_dims). Default: False.
        qkv_bias (bool): enable bias for qkv if True. Default: True.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        sr_ratio (int): The ratio of spatial reduction of GSA of PCPVT.
            Default: 1.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=None,
                 batch_first=True,
                 qkv_bias=True,
                 norm_cfg=dict(type='LN'),
                 sr_ratio=1,
                 init_cfg=None):
        super().__init__(
            embed_dims,
            num_heads,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            dropout_layer=dropout_layer,
            batch_first=batch_first,
            qkv_bias=qkv_bias,
            norm_cfg=norm_cfg,
            sr_ratio=sr_ratio,
            init_cfg=init_cfg)


class GSAEncoderLayer(BaseModule):
    """Implements one encoder layer with GSA.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Default: 0.0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default: 0.0.
        drop_path_rate (float): Stochastic depth rate. Default 0.0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default: 2.
        qkv_bias (bool): Enable bias for qkv if True. Default: True
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        sr_ratio (float): Kernel_size of conv in Attention modules. Default: 1.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 sr_ratio=1.,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1]
        self.attn = GlobalSubsampledAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            qkv_bias=qkv_bias,
            norm_cfg=norm_cfg,
            sr_ratio=sr_ratio)

        self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1]
        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg,
            add_identity=False)

        self.drop_path = build_dropout(
            dict(type='DropPath', drop_prob=drop_path_rate)
        ) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x, hw_shape):
        x = x + self.drop_path(self.attn(self.norm1(x), hw_shape, identity=0.))
        x = x + self.drop_path(self.ffn(self.norm2(x)))
        return x


class LocallyGroupedSelfAttention(BaseModule):
    """Locally-grouped Self Attention (LSA) module.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads. Default: 8
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: False.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Default: 0.0
        proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
        window_size(int): Window size of LSA. Default: 1.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_rate=0.,
                 proj_drop_rate=0.,
                 window_size=1,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        assert embed_dims % num_heads == 0, f'dim {embed_dims} should be ' \
                                            f'divided by num_heads ' \
                                            f'{num_heads}.'
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        head_dim = embed_dims // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_rate)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(proj_drop_rate)
        self.window_size = window_size

    def forward(self, x, hw_shape):
        b, n, c = x.shape
        h, w = hw_shape
        x = x.view(b, h, w, c)

        # pad feature maps to multiples of Local-groups
        pad_l = pad_t = 0
        pad_r = (self.window_size - w % self.window_size) % self.window_size
        pad_b = (self.window_size - h % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))

        # calculate attention mask for LSA
        Hp, Wp = x.shape[1:-1]
        _h, _w = Hp // self.window_size, Wp // self.window_size
        mask = torch.zeros((1, Hp, Wp), device=x.device)
        mask[:, -pad_b:, :].fill_(1)
        mask[:, :, -pad_r:].fill_(1)

        # [B, _h, _w, window_size, window_size, C]
        x = x.reshape(b, _h, self.window_size, _w, self.window_size,
                      c).transpose(2, 3)
        mask = mask.reshape(1, _h, self.window_size, _w,
                            self.window_size).transpose(2, 3).reshape(
                                1, _h * _w,
                                self.window_size * self.window_size)
        # [1, _h*_w, window_size*window_size, window_size*window_size]
        attn_mask = mask.unsqueeze(2) - mask.unsqueeze(3)
        attn_mask = attn_mask.masked_fill(attn_mask != 0,
                                          float(-1000.0)).masked_fill(
                                              attn_mask == 0, float(0.0))

        # [3, B, _w*_h, nhead, window_size*window_size, dim]
        qkv = self.qkv(x).reshape(b, _h * _w,
                                  self.window_size * self.window_size, 3,
                                  self.num_heads, c // self.num_heads).permute(
                                      3, 0, 1, 4, 2, 5)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # [B, _h*_w, n_head, window_size*window_size, window_size*window_size]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn + attn_mask.unsqueeze(2)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        attn = (attn @ v).transpose(2, 3).reshape(b, _h, _w, self.window_size,
                                                  self.window_size, c)
        x = attn.transpose(2, 3).reshape(b, _h * self.window_size,
                                         _w * self.window_size, c)
        if pad_r > 0 or pad_b > 0:
            x = x[:, :h, :w, :].contiguous()

        x = x.reshape(b, n, c)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class LSAEncoderLayer(BaseModule):
    """Implements one encoder layer in Twins-SVT.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Default: 0.0.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Default: 0.0
        drop_path_rate (float): Stochastic depth rate. Default 0.0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default: 2.
        qkv_bias (bool): Enable bias for qkv if True. Default: True
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        window_size (int): Window size of LSA. Default: 1.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=True,
                 qk_scale=None,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 window_size=1,
                 init_cfg=None):

        super().__init__(init_cfg=init_cfg)

        self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1]
        self.attn = LocallyGroupedSelfAttention(embed_dims, num_heads,
                                                qkv_bias, qk_scale,
                                                attn_drop_rate, drop_rate,
                                                window_size)

        self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1]
        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg,
            add_identity=False)

        self.drop_path = build_dropout(
            dict(type='DropPath', drop_prob=drop_path_rate)
        ) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x, hw_shape):
        x = x + self.drop_path(self.attn(self.norm1(x), hw_shape))
        x = x + self.drop_path(self.ffn(self.norm2(x)))
        return x


class ConditionalPositionEncoding(BaseModule):
    """The Conditional Position Encoding (CPE) module.

    The CPE is the implementation of 'Conditional Positional Encodings
    for Vision Transformers <https://arxiv.org/abs/2102.10882>'_.

    Args:
        in_channels (int): Number of input channels.
        embed_dims (int): The feature dimension. Default: 768.
        stride (int): Stride of conv layer. Default: 1.
    """

    def __init__(self, in_channels, embed_dims=768, stride=1, init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.proj = nn.Conv2d(
            in_channels,
            embed_dims,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=True,
            groups=embed_dims)
        self.stride = stride

    def forward(self, x, hw_shape):
        b, n, c = x.shape
        h, w = hw_shape
        feat_token = x
        cnn_feat = feat_token.transpose(1, 2).view(b, c, h, w)
        if self.stride == 1:
            x = self.proj(cnn_feat) + cnn_feat
        else:
            x = self.proj(cnn_feat)
        x = x.flatten(2).transpose(1, 2)
        return x


@MODELS.register_module()
class PCPVT(BaseModule):
    """The backbone of Twins-PCPVT.

    This backbone is the implementation of `Twins: Revisiting the Design
    of Spatial Attention in Vision Transformers
    <https://arxiv.org/abs/2104.13840>`_.

    Args:
        in_channels (int): Number of input channels. Default: 3.
        embed_dims (list): Embedding dimension. Default: [64, 128, 256, 512].
        patch_sizes (list): The patch sizes. Default: [4, 2, 2, 2].
        strides (list): The strides. Default: [4, 2, 2, 2].
        num_heads (int): Number of attention heads. Default: [1, 2, 4, 8].
        mlp_ratios (int): Ratio of mlp hidden dim to embedding dim.
            Default: [4, 4, 4, 4].
        out_indices (tuple[int]): Output from which stages.
            Default: (0, 1, 2, 3).
        qkv_bias (bool): Enable bias for qkv if True. Default: False.
        drop_rate (float): Probability of an element to be zeroed.
            Default 0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default 0.0
        drop_path_rate (float): Stochastic depth rate. Default 0.0
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN')
        depths (list): Depths of each stage. Default [3, 4, 6, 3]
        sr_ratios (list): Kernel_size of conv in each Attn module in
            Transformer encoder layer. Default: [8, 4, 2, 1].
        norm_after_stage(bool): Add extra norm. Default False.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 in_channels=3,
                 embed_dims=[64, 128, 256, 512],
                 patch_sizes=[4, 2, 2, 2],
                 strides=[4, 2, 2, 2],
                 num_heads=[1, 2, 4, 8],
                 mlp_ratios=[4, 4, 4, 4],
                 out_indices=(0, 1, 2, 3),
                 qkv_bias=False,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_cfg=dict(type='LN'),
                 depths=[3, 4, 6, 3],
                 sr_ratios=[8, 4, 2, 1],
                 norm_after_stage=False,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is not None:
            raise TypeError('pretrained must be a str or None')
        self.depths = depths

        # patch_embed
        self.patch_embeds = ModuleList()
        self.position_encoding_drops = ModuleList()
        self.layers = ModuleList()

        for i in range(len(depths)):
            self.patch_embeds.append(
                PatchEmbed(
                    in_channels=in_channels if i == 0 else embed_dims[i - 1],
                    embed_dims=embed_dims[i],
                    conv_type='Conv2d',
                    kernel_size=patch_sizes[i],
                    stride=strides[i],
                    padding='corner',
                    norm_cfg=norm_cfg))

            self.position_encoding_drops.append(nn.Dropout(p=drop_rate))

        self.position_encodings = ModuleList([
            ConditionalPositionEncoding(embed_dim, embed_dim)
            for embed_dim in embed_dims
        ])

        # transformer encoder
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule
        cur = 0

        for k in range(len(depths)):
            _block = ModuleList([
                GSAEncoderLayer(
                    embed_dims=embed_dims[k],
                    num_heads=num_heads[k],
                    feedforward_channels=mlp_ratios[k] * embed_dims[k],
                    attn_drop_rate=attn_drop_rate,
                    drop_rate=drop_rate,
                    drop_path_rate=dpr[cur + i],
                    num_fcs=2,
                    qkv_bias=qkv_bias,
                    act_cfg=dict(type='GELU'),
                    norm_cfg=dict(type='LN'),
                    sr_ratio=sr_ratios[k]) for i in range(depths[k])
            ])
            self.layers.append(_block)
            cur += depths[k]

        self.norm_name, norm = build_norm_layer(
            norm_cfg, embed_dims[-1], postfix=1)

        self.out_indices = out_indices
        self.norm_after_stage = norm_after_stage
        if self.norm_after_stage:
            self.norm_list = ModuleList()
            for dim in embed_dims:
                self.norm_list.append(build_norm_layer(norm_cfg, dim)[1])

    def init_weights(self):
        if self.init_cfg is not None:
            super().init_weights()
        else:
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                    constant_init(m, val=1.0, bias=0.)
                elif isinstance(m, nn.Conv2d):
                    fan_out = m.kernel_size[0] * m.kernel_size[
                        1] * m.out_channels
                    fan_out //= m.groups
                    normal_init(
                        m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)

    def forward(self, x):
        outputs = list()

        b = x.shape[0]

        for i in range(len(self.depths)):
            x, hw_shape = self.patch_embeds[i](x)
            h, w = hw_shape
            x = self.position_encoding_drops[i](x)
            for j, blk in enumerate(self.layers[i]):
                x = blk(x, hw_shape)
                if j == 0:
                    x = self.position_encodings[i](x, hw_shape)
            if self.norm_after_stage:
                x = self.norm_list[i](x)
            x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous()

            if i in self.out_indices:
                outputs.append(x)

        return tuple(outputs)


@MODELS.register_module()
class SVT(PCPVT):
    """The backbone of Twins-SVT.

    This backbone is the implementation of `Twins: Revisiting the Design
    of Spatial Attention in Vision Transformers
    <https://arxiv.org/abs/2104.13840>`_.

    Args:
        in_channels (int): Number of input channels. Default: 3.
        embed_dims (list): Embedding dimension. Default: [64, 128, 256, 512].
        patch_sizes (list): The patch sizes. Default: [4, 2, 2, 2].
        strides (list): The strides. Default: [4, 2, 2, 2].
        num_heads (int): Number of attention heads. Default: [1, 2, 4].
        mlp_ratios (int): Ratio of mlp hidden dim to embedding dim.
            Default: [4, 4, 4].
        out_indices (tuple[int]): Output from which stages.
            Default: (0, 1, 2, 3).
        qkv_bias (bool): Enable bias for qkv if True. Default: False.
        drop_rate (float): Dropout rate. Default 0.
        attn_drop_rate (float): Dropout ratio of attention weight.
            Default 0.0
        drop_path_rate (float): Stochastic depth rate. Default 0.2.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN')
        depths (list): Depths of each stage. Default [4, 4, 4].
        sr_ratios (list): Kernel_size of conv in each Attn module in
            Transformer encoder layer. Default: [4, 2, 1].
        windiow_sizes (list): Window size of LSA. Default: [7, 7, 7],
        input_features_slice(bool): Input features need slice. Default: False.
        norm_after_stage(bool): Add extra norm. Default False.
        strides (list): Strides in patch-Embedding modules. Default: (2, 2, 2)
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 in_channels=3,
                 embed_dims=[64, 128, 256],
                 patch_sizes=[4, 2, 2, 2],
                 strides=[4, 2, 2, 2],
                 num_heads=[1, 2, 4],
                 mlp_ratios=[4, 4, 4],
                 out_indices=(0, 1, 2, 3),
                 qkv_bias=False,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.2,
                 norm_cfg=dict(type='LN'),
                 depths=[4, 4, 4],
                 sr_ratios=[4, 2, 1],
                 windiow_sizes=[7, 7, 7],
                 norm_after_stage=True,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(in_channels, embed_dims, patch_sizes, strides,
                         num_heads, mlp_ratios, out_indices, qkv_bias,
                         drop_rate, attn_drop_rate, drop_path_rate, norm_cfg,
                         depths, sr_ratios, norm_after_stage, pretrained,
                         init_cfg)
        # transformer encoder
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule

        for k in range(len(depths)):
            for i in range(depths[k]):
                if i % 2 == 0:
                    self.layers[k][i] = \
                        LSAEncoderLayer(
                            embed_dims=embed_dims[k],
                            num_heads=num_heads[k],
                            feedforward_channels=mlp_ratios[k] * embed_dims[k],
                            drop_rate=drop_rate,
                            attn_drop_rate=attn_drop_rate,
                            drop_path_rate=dpr[sum(depths[:k]) + i],
                            qkv_bias=qkv_bias,
                            window_size=windiow_sizes[k])
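A minimal sketch of how the two Twins variants relate (illustrative only; parameter values here are assumptions, not a config from this commit): SVT builds the PCPVT stages and then swaps every even-indexed block for an LSAEncoderLayer, so each stage alternates local (LSA) and global (GSA) attention.

# Illustrative sketch: inspect the alternating block types and output strides.
import torch

model = SVT(embed_dims=[64, 128, 256], num_heads=[2, 4, 8],
            mlp_ratios=[4, 4, 4], depths=[2, 2, 2],
            sr_ratios=[4, 2, 1], windiow_sizes=[7, 7, 7],
            out_indices=(0, 1, 2))
print(type(model.layers[0][0]).__name__)  # LSAEncoderLayer (local attention)
print(type(model.layers[0][1]).__name__)  # GSAEncoderLayer (global attention)
outs = model(torch.randn(1, 3, 224, 224))
print([tuple(o.shape) for o in outs])     # strides 4, 8, 16 for the three stages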
436
finetune/mmseg/models/backbones/unet.py
Normal file
436
finetune/mmseg/models/backbones/unet.py
Normal file
@@ -0,0 +1,436 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from mmengine.model import BaseModule
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm

from mmseg.registry import MODELS
from ..utils import UpConvBlock, Upsample


class BasicConvBlock(nn.Module):
    """Basic convolutional block for UNet.

    This module consists of several plain convolutional layers.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        num_convs (int): Number of convolutional layers. Default: 2.
        stride (int): Whether use stride convolution to downsample
            the input feature map. If stride=2, it only uses stride convolution
            in the first convolutional layer to downsample the input feature
            map. Options are 1 or 2. Default: 1.
        dilation (int): Whether use dilated convolution to expand the
            receptive field. Set dilation rate of each convolutional layer and
            the dilation rate of the first convolutional layer is always 1.
            Default: 1.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        dcn (bool): Use deformable convolution in convolutional layer or not.
            Default: None.
        plugins (dict): plugins for convolutional layers. Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_convs=2,
                 stride=1,
                 dilation=1,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 dcn=None,
                 plugins=None):
        super().__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'

        self.with_cp = with_cp
        convs = []
        for i in range(num_convs):
            convs.append(
                ConvModule(
                    in_channels=in_channels if i == 0 else out_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=stride if i == 0 else 1,
                    dilation=1 if i == 0 else dilation,
                    padding=1 if i == 0 else dilation,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))

        self.convs = nn.Sequential(*convs)

    def forward(self, x):
        """Forward function."""

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(self.convs, x)
        else:
            out = self.convs(x)
        return out


@MODELS.register_module()
class DeconvModule(nn.Module):
    """Deconvolution upsample module in decoder for UNet (2X upsample).

    This module uses deconvolution to upsample feature map in the decoder
    of UNet.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        kernel_size (int): Kernel size of the convolutional layer. Default: 4.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 with_cp=False,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 *,
                 kernel_size=4,
                 scale_factor=2):
        super().__init__()

        assert (kernel_size - scale_factor >= 0) and\
            (kernel_size - scale_factor) % 2 == 0,\
            f'kernel_size should be greater than or equal to scale_factor '\
            f'and (kernel_size - scale_factor) should be even numbers, '\
            f'while the kernel size is {kernel_size} and scale_factor is '\
            f'{scale_factor}.'

        stride = scale_factor
        padding = (kernel_size - scale_factor) // 2
        self.with_cp = with_cp
        deconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding)

        norm_name, norm = build_norm_layer(norm_cfg, out_channels)
        activate = build_activation_layer(act_cfg)
        self.deconv_upsamping = nn.Sequential(deconv, norm, activate)

    def forward(self, x):
        """Forward function."""

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(self.deconv_upsamping, x)
        else:
            out = self.deconv_upsamping(x)
        return out


@MODELS.register_module()
class InterpConv(nn.Module):
    """Interpolation upsample module in decoder for UNet.

    This module uses interpolation to upsample feature map in the decoder
    of UNet. It consists of one interpolation upsample layer and one
    convolutional layer. It can be one interpolation upsample layer followed
    by one convolutional layer (conv_first=False) or one convolutional layer
    followed by one interpolation upsample layer (conv_first=True).

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        conv_first (bool): Whether convolutional layer or interpolation
            upsample layer first. Default: False. It means interpolation
            upsample layer followed by one convolutional layer.
        kernel_size (int): Kernel size of the convolutional layer. Default: 1.
        stride (int): Stride of the convolutional layer. Default: 1.
        padding (int): Padding of the convolutional layer. Default: 1.
        upsample_cfg (dict): Interpolation config of the upsample layer.
            Default: dict(
                scale_factor=2, mode='bilinear', align_corners=False).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 with_cp=False,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 *,
                 conv_cfg=None,
                 conv_first=False,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 upsample_cfg=dict(
                     scale_factor=2, mode='bilinear', align_corners=False)):
        super().__init__()

        self.with_cp = with_cp
        conv = ConvModule(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        upsample = Upsample(**upsample_cfg)
        if conv_first:
            self.interp_upsample = nn.Sequential(conv, upsample)
        else:
            self.interp_upsample = nn.Sequential(upsample, conv)

    def forward(self, x):
        """Forward function."""

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(self.interp_upsample, x)
        else:
            out = self.interp_upsample(x)
        return out


@MODELS.register_module()
class UNet(BaseModule):
    """UNet backbone.

    This backbone is the implementation of `U-Net: Convolutional Networks
    for Biomedical Image Segmentation <https://arxiv.org/abs/1505.04597>`_.

    Args:
        in_channels (int): Number of input image channels. Default: 3.
        base_channels (int): Number of base channels of each stage.
            The output channels of the first stage. Default: 64.
        num_stages (int): Number of stages in encoder, normally 5. Default: 5.
        strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
            len(strides) is equal to num_stages. Normally the stride of the
            first stage in encoder is 1. If strides[i]=2, it uses stride
            convolution to downsample in the correspondence encoder stage.
            Default: (1, 1, 1, 1, 1).
        enc_num_convs (Sequence[int]): Number of convolutional layers in the
            convolution block of the correspondence encoder stage.
            Default: (2, 2, 2, 2, 2).
        dec_num_convs (Sequence[int]): Number of convolutional layers in the
            convolution block of the correspondence decoder stage.
            Default: (2, 2, 2, 2).
        downsamples (Sequence[int]): Whether use MaxPool to downsample the
            feature map after the first stage of encoder
            (stages: [1, num_stages)). If the correspondence encoder stage use
            stride convolution (strides[i]=2), it will never use MaxPool to
            downsample, even downsamples[i-1]=True.
            Default: (True, True, True, True).
        enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
            Default: (1, 1, 1, 1, 1).
        dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
            Default: (1, 1, 1, 1).
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        upsample_cfg (dict): The upsample config of the upsample module in
            decoder. Default: dict(type='InterpConv').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        dcn (bool): Use deformable convolution in convolutional layer or not.
            Default: None.
        plugins (dict): plugins for convolutional layers. Default: None.
        pretrained (str, optional): model pretrained path. Default: None
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None

    Notice:
        The input image size should be divisible by the whole downsample rate
        of the encoder. More detail of the whole downsample rate can be found
        in UNet._check_input_divisible.
    """

    def __init__(self,
                 in_channels=3,
                 base_channels=64,
                 num_stages=5,
                 strides=(1, 1, 1, 1, 1),
                 enc_num_convs=(2, 2, 2, 2, 2),
                 dec_num_convs=(2, 2, 2, 2),
                 downsamples=(True, True, True, True),
                 enc_dilations=(1, 1, 1, 1, 1),
                 dec_dilations=(1, 1, 1, 1),
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 upsample_cfg=dict(type='InterpConv'),
                 norm_eval=False,
                 dcn=None,
                 plugins=None,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg)

        self.pretrained = pretrained
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is None:
            if init_cfg is None:
                self.init_cfg = [
                    dict(type='Kaiming', layer='Conv2d'),
                    dict(
                        type='Constant',
                        val=1,
                        layer=['_BatchNorm', 'GroupNorm'])
                ]
        else:
            raise TypeError('pretrained must be a str or None')

        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'
        assert len(strides) == num_stages, \
            'The length of strides should be equal to num_stages, '\
            f'while the strides is {strides}, the length of '\
            f'strides is {len(strides)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(enc_num_convs) == num_stages, \
            'The length of enc_num_convs should be equal to num_stages, '\
            f'while the enc_num_convs is {enc_num_convs}, the length of '\
            f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(dec_num_convs) == (num_stages-1), \
            'The length of dec_num_convs should be equal to (num_stages-1), '\
            f'while the dec_num_convs is {dec_num_convs}, the length of '\
            f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(downsamples) == (num_stages-1), \
            'The length of downsamples should be equal to (num_stages-1), '\
            f'while the downsamples is {downsamples}, the length of '\
            f'downsamples is {len(downsamples)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(enc_dilations) == num_stages, \
            'The length of enc_dilations should be equal to num_stages, '\
            f'while the enc_dilations is {enc_dilations}, the length of '\
            f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(dec_dilations) == (num_stages-1), \
            'The length of dec_dilations should be equal to (num_stages-1), '\
            f'while the dec_dilations is {dec_dilations}, the length of '\
            f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
            f'{num_stages}.'
        self.num_stages = num_stages
        self.strides = strides
        self.downsamples = downsamples
        self.norm_eval = norm_eval
        self.base_channels = base_channels

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        for i in range(num_stages):
            enc_conv_block = []
            if i != 0:
                if strides[i] == 1 and downsamples[i - 1]:
                    enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
                upsample = (strides[i] != 1 or downsamples[i - 1])
                self.decoder.append(
                    UpConvBlock(
                        conv_block=BasicConvBlock,
                        in_channels=base_channels * 2**i,
                        skip_channels=base_channels * 2**(i - 1),
                        out_channels=base_channels * 2**(i - 1),
                        num_convs=dec_num_convs[i - 1],
                        stride=1,
                        dilation=dec_dilations[i - 1],
                        with_cp=with_cp,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg,
                        upsample_cfg=upsample_cfg if upsample else None,
                        dcn=None,
                        plugins=None))

            enc_conv_block.append(
                BasicConvBlock(
                    in_channels=in_channels,
                    out_channels=base_channels * 2**i,
                    num_convs=enc_num_convs[i],
                    stride=strides[i],
                    dilation=enc_dilations[i],
                    with_cp=with_cp,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    dcn=None,
                    plugins=None))
            self.encoder.append(nn.Sequential(*enc_conv_block))
            in_channels = base_channels * 2**i

    def forward(self, x):
        self._check_input_divisible(x)
        enc_outs = []
        for enc in self.encoder:
            x = enc(x)
            enc_outs.append(x)
        dec_outs = [x]
        for i in reversed(range(len(self.decoder))):
            x = self.decoder[i](enc_outs[i], x)
            dec_outs.append(x)

        return dec_outs

    def train(self, mode=True):
        """Convert the model into training mode while keeping the
        normalization layers frozen."""
        super().train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval has an effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()

    def _check_input_divisible(self, x):
        h, w = x.shape[-2:]
        whole_downsample_rate = 1
        for i in range(1, self.num_stages):
            if self.strides[i] == 2 or self.downsamples[i - 1]:
                whole_downsample_rate *= 2
        assert (h % whole_downsample_rate == 0) \
            and (w % whole_downsample_rate == 0),\
            f'The input image size {(h, w)} should be divisible by the whole '\
            f'downsample rate {whole_downsample_rate}, when num_stages is '\
            f'{self.num_stages}, strides is {self.strides}, and downsamples '\
            f'is {self.downsamples}.'
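A minimal usage sketch for the UNet backbone above (illustrative only; base_channels=4 is an assumption chosen to keep the example small): with the default strides/downsamples the whole downsample rate is 16, so the input must be a multiple of 16, and the decoder returns one feature map per scale, coarsest first.

# Illustrative sketch: check the divisibility requirement and output scales.
import torch

unet = UNet(in_channels=3, base_channels=4, num_stages=5,
            strides=(1, 1, 1, 1, 1), downsamples=(True, True, True, True))
dec_outs = unet(torch.randn(1, 3, 64, 64))
print([tuple(o.shape) for o in dec_outs])
# expected: (1, 64, 4, 4), (1, 32, 8, 8), (1, 16, 16, 16), (1, 8, 32, 32), (1, 4, 64, 64)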
501
finetune/mmseg/models/backbones/vit.py
Normal file
501
finetune/mmseg/models/backbones/vit.py
Normal file
@@ -0,0 +1,501 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.model import BaseModule, ModuleList
|
||||
from mmengine.model.weight_init import (constant_init, kaiming_init,
|
||||
trunc_normal_)
|
||||
from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
from torch.nn.modules.utils import _pair as to_2tuple
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import PatchEmbed, resize
|
||||
|
||||
|
||||
class TransformerEncoderLayer(BaseModule):
|
||||
"""Implements one encoder layer in Vision Transformer.
|
||||
|
||||
Args:
|
||||
embed_dims (int): The feature dimension.
|
||||
num_heads (int): Parallel attention heads.
|
||||
feedforward_channels (int): The hidden dimension for FFNs.
|
||||
drop_rate (float): Probability of an element to be zeroed
|
||||
after the feed forward layer. Default: 0.0.
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Default: 0.0.
|
||||
drop_path_rate (float): stochastic depth rate. Default 0.0.
|
||||
num_fcs (int): The number of fully-connected layers for FFNs.
|
||||
Default: 2.
|
||||
qkv_bias (bool): enable bias for qkv if True. Default: True
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Default: dict(type='GELU').
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN').
|
||||
batch_first (bool): Key, Query and Value are shape of
|
||||
(batch, n, embed_dim)
|
||||
or (n, batch, embed_dim). Default: True.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save
|
||||
some memory while slowing down the training speed. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
feedforward_channels,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
num_fcs=2,
|
||||
qkv_bias=True,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN'),
|
||||
batch_first=True,
|
||||
attn_cfg=dict(),
|
||||
ffn_cfg=dict(),
|
||||
with_cp=False):
|
||||
super().__init__()
|
||||
|
||||
self.norm1_name, norm1 = build_norm_layer(
|
||||
norm_cfg, embed_dims, postfix=1)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
|
||||
attn_cfg.update(
|
||||
dict(
|
||||
embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
attn_drop=attn_drop_rate,
|
||||
proj_drop=drop_rate,
|
||||
batch_first=batch_first,
|
||||
bias=qkv_bias))
|
||||
|
||||
self.build_attn(attn_cfg)
|
||||
|
||||
self.norm2_name, norm2 = build_norm_layer(
|
||||
norm_cfg, embed_dims, postfix=2)
|
||||
self.add_module(self.norm2_name, norm2)
|
||||
|
||||
ffn_cfg.update(
|
||||
dict(
|
||||
embed_dims=embed_dims,
|
||||
feedforward_channels=feedforward_channels,
|
||||
num_fcs=num_fcs,
|
||||
ffn_drop=drop_rate,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate)
|
||||
if drop_path_rate > 0 else None,
|
||||
act_cfg=act_cfg))
|
||||
self.build_ffn(ffn_cfg)
|
||||
self.with_cp = with_cp
|
||||
|
||||
def build_attn(self, attn_cfg):
|
||||
self.attn = MultiheadAttention(**attn_cfg)
|
||||
|
||||
def build_ffn(self, ffn_cfg):
|
||||
self.ffn = FFN(**ffn_cfg)
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
return getattr(self, self.norm1_name)
|
||||
|
||||
@property
|
||||
def norm2(self):
|
||||
return getattr(self, self.norm2_name)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
def _inner_forward(x):
|
||||
x = self.attn(self.norm1(x), identity=x)
|
||||
x = self.ffn(self.norm2(x), identity=x)
|
||||
return x
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
x = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
x = _inner_forward(x)
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class VisionTransformer(BaseModule):
|
||||
"""Vision Transformer.
|
||||
|
||||
This backbone is the implementation of `An Image is Worth 16x16 Words:
|
||||
Transformers for Image Recognition at
|
||||
Scale <https://arxiv.org/abs/2010.11929>`_.
|
||||
|
||||
Args:
|
||||
img_size (int | tuple): Input image size. Default: 224.
|
||||
patch_size (int): The patch size. Default: 16.
|
||||
patch_pad (str | int | None): The padding method in patch embedding.
|
||||
Default: 'corner'.
|
||||
in_channels (int): Number of input channels. Default: 3.
|
||||
embed_dims (int): embedding dimension. Default: 768.
|
||||
num_layers (int): depth of transformer. Default: 12.
|
||||
num_heads (int): number of attention heads. Default: 12.
|
||||
mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
|
||||
Default: 4.
|
||||
out_origin (bool): Whether to output the original input embedding.
|
||||
Default: False
|
||||
out_indices (list | tuple | int): Output from which stages.
|
||||
Default: -1.
|
||||
qkv_bias (bool): enable bias for qkv if True. Default: True.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Default 0.0
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Default 0.0
|
||||
drop_path_rate (float): stochastic depth rate. Default 0.0
|
||||
with_cls_token (bool): Whether concatenating class token into image
|
||||
tokens as transformer input. Default: True.
|
||||
output_cls_token (bool): Whether output the cls_token. If set True,
|
||||
`with_cls_token` must be True. Default: False.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN')
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Default: dict(type='GELU').
|
||||
patch_bias (bool): Whether to use bias in the convolution of the
PatchEmbed block. Default: False.
|
||||
patch_norm (bool): Whether to add a norm in PatchEmbed Block.
|
||||
Default: False.
|
||||
pre_norm (bool): Whether to add a norm before Transformer Layers.
|
||||
Default: False.
|
||||
final_norm (bool): Whether to add an additional layer to normalize the
final feature map. Default: False.
|
||||
interpolate_mode (str): Select the interpolation mode for position
embedding vector resize. Default: bicubic.
|
||||
num_fcs (int): The number of fully-connected layers for FFNs.
|
||||
Default: 2.
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Default: False.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save
|
||||
some memory while slowing down the training speed. Default: False.
|
||||
frozen_exclude (List): List of parameters that are not to be frozen.
|
||||
Default: ["all"], "all" means there are no frozen parameters.
|
||||
pretrained (str, optional): model pretrained path. Default: None.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
patch_pad='corner',
|
||||
in_channels=3,
|
||||
embed_dims=768,
|
||||
num_layers=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4,
|
||||
out_origin=False,
|
||||
out_indices=-1,
|
||||
qkv_bias=True,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
with_cls_token=True,
|
||||
output_cls_token=False,
|
||||
norm_cfg=dict(type='LN'),
|
||||
act_cfg=dict(type='GELU'),
|
||||
patch_norm=False,
|
||||
patch_bias=False,
|
||||
pre_norm=False,
|
||||
final_norm=False,
|
||||
interpolate_mode='bicubic',
|
||||
num_fcs=2,
|
||||
norm_eval=False,
|
||||
with_cp=False,
|
||||
frozen_exclude=['all'],
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
if isinstance(img_size, int):
|
||||
img_size = to_2tuple(img_size)
|
||||
elif isinstance(img_size, tuple):
|
||||
if len(img_size) == 1:
|
||||
img_size = to_2tuple(img_size[0])
|
||||
assert len(img_size) == 2, \
|
||||
f'The size of image should have length 1 or 2, ' \
|
||||
f'but got {len(img_size)}'
|
||||
|
||||
if output_cls_token:
|
||||
assert with_cls_token is True, f'with_cls_token must be True if ' \
f'set output_cls_token to True, but got {with_cls_token}'
|
||||
|
||||
assert not (init_cfg and pretrained), \
|
||||
'init_cfg and pretrained cannot be set at the same time'
|
||||
if isinstance(pretrained, str):
|
||||
warnings.warn('DeprecationWarning: pretrained is deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
elif pretrained is not None:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.interpolate_mode = interpolate_mode
|
||||
self.norm_eval = norm_eval
|
||||
self.with_cp = with_cp
|
||||
self.pretrained = pretrained
|
||||
self.out_origin = out_origin
|
||||
self.frozen_exclude = frozen_exclude
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
in_channels=in_channels,
|
||||
embed_dims=embed_dims,
|
||||
conv_type='Conv2d',
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
padding=patch_pad,
|
||||
bias=patch_bias,
|
||||
norm_cfg=norm_cfg if patch_norm else None,
|
||||
init_cfg=None,
|
||||
)
|
||||
|
||||
num_patches = (img_size[0] // patch_size) * \
|
||||
(img_size[1] // patch_size)
|
||||
|
||||
self.with_cls_token = with_cls_token
|
||||
self.output_cls_token = output_cls_token
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros(1, num_patches + 1, embed_dims))
|
||||
self.drop_after_pos = nn.Dropout(p=drop_rate)
|
||||
self.pre_norm = pre_norm
|
||||
|
||||
if self.pre_norm:
|
||||
self.pre_ln_name, pre_ln = build_norm_layer(
|
||||
norm_cfg, embed_dims, postfix='_pre')
|
||||
self.add_module(self.pre_ln_name, pre_ln)
|
||||
|
||||
if isinstance(out_indices, int):
|
||||
if out_indices == -1:
|
||||
out_indices = num_layers - 1
|
||||
self.out_indices = [out_indices]
|
||||
elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
|
||||
self.out_indices = out_indices
|
||||
else:
|
||||
raise TypeError('out_indices must be type of int, list or tuple')
|
||||
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, num_layers)
|
||||
] # stochastic depth decay rule
|
||||
|
||||
self.layers = ModuleList()
|
||||
for i in range(num_layers):
|
||||
self.layers.append(
|
||||
TransformerEncoderLayer(
|
||||
embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
feedforward_channels=mlp_ratio * embed_dims,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_rate=drop_rate,
|
||||
drop_path_rate=dpr[i],
|
||||
num_fcs=num_fcs,
|
||||
qkv_bias=qkv_bias,
|
||||
act_cfg=act_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
with_cp=with_cp,
|
||||
batch_first=True))
|
||||
|
||||
self.final_norm = final_norm
|
||||
if final_norm:
|
||||
self.norm1_name, norm1 = build_norm_layer(
|
||||
norm_cfg, embed_dims, postfix=1)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
|
||||
self._freeze()
|
||||
|
||||
@property
|
||||
def pre_ln(self):
|
||||
return getattr(self, self.pre_ln_name)
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
return getattr(self, self.norm1_name)
|
||||
|
||||
def init_weights(self):
|
||||
if isinstance(self.init_cfg, dict) and \
|
||||
self.init_cfg.get('type') in ['Pretrained', 'Pretrained_Part']:
|
||||
checkpoint = CheckpointLoader.load_checkpoint(
|
||||
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
|
||||
|
||||
if self.init_cfg.get('type') == 'Pretrained':
|
||||
if 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
|
||||
elif self.init_cfg.get('type') == 'Pretrained_Part':
|
||||
state_dict = checkpoint.copy()
|
||||
para_prefix = 'image_encoder'
|
||||
prefix_len = len(para_prefix) + 1
|
||||
for k, v in checkpoint.items():
|
||||
state_dict.pop(k)
|
||||
if para_prefix in k:
|
||||
state_dict[k[prefix_len:]] = v
|
||||
|
||||
if 'pos_embed' in state_dict.keys():
|
||||
if self.pos_embed.shape != state_dict['pos_embed'].shape:
|
||||
print_log(msg=f'Resize the pos_embed shape from '
|
||||
f'{state_dict["pos_embed"].shape} to '
|
||||
f'{self.pos_embed.shape}')
|
||||
h, w = self.img_size
|
||||
pos_size = int(
|
||||
math.sqrt(state_dict['pos_embed'].shape[1] - 1))
|
||||
state_dict['pos_embed'] = self.resize_pos_embed(
|
||||
state_dict['pos_embed'],
|
||||
(h // self.patch_size, w // self.patch_size),
|
||||
(pos_size, pos_size), self.interpolate_mode)
|
||||
|
||||
load_state_dict(self, state_dict, strict=False, logger=None)
|
||||
elif self.init_cfg is not None:
|
||||
super().init_weights()
|
||||
else:
|
||||
# We only implement the 'jax_impl' initialization implemented at
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
|
||||
trunc_normal_(self.pos_embed, std=.02)
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
for n, m in self.named_modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if m.bias is not None:
|
||||
if 'ffn' in n:
|
||||
nn.init.normal_(m.bias, mean=0., std=1e-6)
|
||||
else:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
kaiming_init(m, mode='fan_in', bias=0.)
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
|
||||
constant_init(m, val=1.0, bias=0.)
|
||||
|
||||
def _freeze(self):
|
||||
if 'all' in self.frozen_exclude:
|
||||
return
|
||||
for name, param in self.named_parameters():
|
||||
if not any([exclude in name for exclude in self.frozen_exclude]):
|
||||
param.requires_grad = False
|
||||
|
||||
def _pos_embeding(self, patched_img, hw_shape, pos_embed):
|
||||
"""Positioning embeding method.
|
||||
|
||||
Resize the pos_embed, if the input image size doesn't match
|
||||
the training size.
|
||||
Args:
|
||||
patched_img (torch.Tensor): The patched image, it should be
|
||||
shape of [B, L1, C].
|
||||
hw_shape (tuple): The downsampled image resolution.
|
||||
pos_embed (torch.Tensor): The pos_embed weights; it should be
of shape [B, L2, C].
|
||||
Return:
|
||||
torch.Tensor: The pos encoded image feature.
|
||||
"""
|
||||
assert patched_img.ndim == 3 and pos_embed.ndim == 3, \
|
||||
'the shapes of patched_img and pos_embed must be [B, L, C]'
|
||||
x_len, pos_len = patched_img.shape[1], pos_embed.shape[1]
|
||||
if x_len != pos_len:
|
||||
if pos_len == (self.img_size[0] // self.patch_size) * (
|
||||
self.img_size[1] // self.patch_size) + 1:
|
||||
pos_h = self.img_size[0] // self.patch_size
|
||||
pos_w = self.img_size[1] // self.patch_size
|
||||
else:
|
||||
raise ValueError(
|
||||
'Unexpected shape of pos_embed, got {}.'.format(
|
||||
pos_embed.shape))
|
||||
pos_embed = self.resize_pos_embed(pos_embed, hw_shape,
|
||||
(pos_h, pos_w),
|
||||
self.interpolate_mode)
|
||||
return self.drop_after_pos(patched_img + pos_embed)
|
||||
|
||||
@staticmethod
|
||||
def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode):
|
||||
"""Resize pos_embed weights.
|
||||
|
||||
Resize pos_embed using bicubic interpolate method.
|
||||
Args:
|
||||
pos_embed (torch.Tensor): Position embedding weights.
|
||||
input_shpae (tuple): Tuple for (downsampled input image height,
|
||||
downsampled input image width).
|
||||
pos_shape (tuple): The resolution of downsampled origin training
|
||||
image.
|
||||
mode (str): Algorithm used for upsampling:
|
||||
``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
|
||||
``'trilinear'``. Default: ``'nearest'``
|
||||
Return:
|
||||
torch.Tensor: The resized pos_embed of shape [B, L_new, C]
|
||||
"""
|
||||
assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
|
||||
pos_h, pos_w = pos_shape
|
||||
cls_token_weight = pos_embed[:, 0]
|
||||
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
|
||||
pos_embed_weight = pos_embed_weight.reshape(
|
||||
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
|
||||
pos_embed_weight = resize(
|
||||
pos_embed_weight, size=input_shpae, align_corners=False, mode=mode)
|
||||
cls_token_weight = cls_token_weight.unsqueeze(1)
|
||||
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
|
||||
pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
|
||||
return pos_embed
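# Shape sketch for resize_pos_embed (illustrative assumption: a checkpoint
# trained at 224x224 with patch_size 16, fine-tuned at 512x512):
#
#   old pos_embed: [1, 1 + 14*14, C]  -> grid part reshaped to [1, C, 14, 14]
#   interpolated to the new grid      -> [1, C, 32, 32]
#   flattened, cls token re-prepended -> [1, 1 + 32*32, C]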
|
||||
|
||||
def forward(self, inputs):
|
||||
B = inputs.shape[0]
|
||||
|
||||
x, hw_shape = self.patch_embed(inputs)
|
||||
|
||||
# stole cls_tokens impl from Phil Wang, thanks
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x = self._pos_embeding(x, hw_shape, self.pos_embed)
|
||||
|
||||
if not self.with_cls_token:
|
||||
# Remove class token for transformer encoder input
|
||||
x = x[:, 1:]
|
||||
|
||||
if self.pre_norm:
|
||||
x = self.pre_ln(x)
|
||||
|
||||
outs = []
|
||||
if self.out_origin:
|
||||
if self.with_cls_token:
|
||||
# Remove class token and reshape token for decoder head
|
||||
out = x[:, 1:]
|
||||
else:
|
||||
out = x
|
||||
B, _, C = out.shape
|
||||
out = out.reshape(B, hw_shape[0], hw_shape[1],
|
||||
C).permute(0, 3, 1, 2).contiguous()
|
||||
if self.output_cls_token:
|
||||
out = [out, x[:, 0]]
|
||||
outs.append(out)
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
if i == len(self.layers) - 1:
|
||||
if self.final_norm:
|
||||
x = self.norm1(x)
|
||||
if i in self.out_indices:
|
||||
if self.with_cls_token:
|
||||
# Remove class token and reshape token for decoder head
|
||||
out = x[:, 1:]
|
||||
else:
|
||||
out = x
|
||||
B, _, C = out.shape
|
||||
out = out.reshape(B, hw_shape[0], hw_shape[1],
|
||||
C).permute(0, 3, 1, 2).contiguous()
|
||||
if self.output_cls_token:
|
||||
out = [out, x[:, 0]]
|
||||
outs.append(out)
|
||||
|
||||
return tuple(outs)
|
||||
|
||||
def train(self, mode=True):
|
||||
super().train(mode)
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.LayerNorm):
|
||||
m.eval()
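# Hedged smoke-test sketch for the backbone (not from any shipped config;
# the tiny hyper-parameters are assumptions to keep it cheap):
#
#   vit = VisionTransformer(
#       img_size=224, patch_size=16, embed_dims=192, num_layers=2,
#       num_heads=3, out_indices=-1)
#   vit.init_weights()
#   feats = vit(torch.randn(1, 3, 224, 224))
#   # one output stage of shape (1, 192, 14, 14): 224 / 16 = 14 per side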
|
||||
395
finetune/mmseg/models/backbones/vpd.py
Normal file
395
finetune/mmseg/models/backbones/vpd.py
Normal file
@@ -0,0 +1,395 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# ------------------------------------------------------------------------------
|
||||
# Adapted from https://github.com/wl-zhao/VPD/blob/main/vpd/models.py
|
||||
# Original licence: MIT License
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.runner import CheckpointLoader, load_checkpoint
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import ConfigType, OptConfigType
|
||||
|
||||
try:
|
||||
from ldm.modules.diffusionmodules.util import timestep_embedding
|
||||
from ldm.util import instantiate_from_config
|
||||
has_ldm = True
|
||||
except ImportError:
|
||||
has_ldm = False
|
||||
|
||||
|
||||
def register_attention_control(model, controller):
|
||||
"""Registers a control function to manage attention within a model.
|
||||
|
||||
Args:
|
||||
model: The model to which attention is to be registered.
|
||||
controller: The control function responsible for managing attention.
|
||||
"""
|
||||
|
||||
def ca_forward(self, place_in_unet):
|
||||
"""Custom forward method for attention.
|
||||
|
||||
Args:
|
||||
self: Reference to the current object.
|
||||
place_in_unet: The location in UNet (down/mid/up).
|
||||
|
||||
Returns:
|
||||
The modified forward method.
|
||||
"""
|
||||
|
||||
def forward(x, context=None, mask=None):
|
||||
h = self.heads
|
||||
is_cross = context is not None
|
||||
context = x if context is None else context  # if context is None, use x
|
||||
|
||||
q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
|
||||
q, k, v = (
|
||||
tensor.view(tensor.shape[0] * h, tensor.shape[1],
|
||||
tensor.shape[2] // h) for tensor in [q, k, v])
|
||||
|
||||
sim = torch.matmul(q, k.transpose(-2, -1)) * self.scale
|
||||
|
||||
if mask is not None:
|
||||
mask = mask.flatten(1).unsqueeze(1).repeat(h, 1, 1)
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
|
||||
attn = sim.softmax(dim=-1)
|
||||
attn_mean = attn.view(h, attn.shape[0] // h,
|
||||
*attn.shape[1:]).mean(0)
|
||||
controller(attn_mean, is_cross, place_in_unet)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = out.view(out.shape[0] // h, out.shape[1], out.shape[2] * h)
|
||||
return self.to_out(out)
|
||||
|
||||
return forward
|
||||
|
||||
def register_recr(net_, count, place_in_unet):
|
||||
"""Recursive function to register the custom forward method to all
|
||||
CrossAttention layers.
|
||||
|
||||
Args:
|
||||
net_: The network layer currently being processed.
|
||||
count: The current count of layers processed.
|
||||
place_in_unet: The location in UNet (down/mid/up).
|
||||
|
||||
Returns:
|
||||
The updated count of layers processed.
|
||||
"""
|
||||
if net_.__class__.__name__ == 'CrossAttention':
|
||||
net_.forward = ca_forward(net_, place_in_unet)
|
||||
return count + 1
|
||||
if hasattr(net_, 'children'):
|
||||
return sum(
|
||||
register_recr(child, 0, place_in_unet)
|
||||
for child in net_.children())
|
||||
return count
|
||||
|
||||
# Register the custom forward on every CrossAttention layer found under
# the down / mid / up blocks of the diffusion UNet.
cross_att_count = 0
for name, child in model.diffusion_model.named_children():
    if 'input_blocks' in name:
        place = 'down'
    elif 'output_blocks' in name:
        place = 'up'
    elif 'middle_block' in name:
        place = 'mid'
    else:
        continue
    cross_att_count += register_recr(child, 0, place)
|
||||
|
||||
controller.num_att_layers = cross_att_count
|
||||
|
||||
|
||||
class AttentionStore:
|
||||
"""A class for storing attention information in the UNet model.
|
||||
|
||||
Attributes:
|
||||
base_size (int): Base size for storing attention information.
|
||||
max_size (int): Maximum size for storing attention information.
|
||||
"""
|
||||
|
||||
def __init__(self, base_size=64, max_size=None):
|
||||
"""Initialize AttentionStore with default or custom sizes."""
|
||||
self.reset()
|
||||
self.base_size = base_size
|
||||
self.max_size = max_size or (base_size // 2)
|
||||
self.num_att_layers = -1
|
||||
|
||||
@staticmethod
|
||||
def get_empty_store():
|
||||
"""Returns an empty store for holding attention values."""
|
||||
return {
|
||||
key: []
|
||||
for key in [
|
||||
'down_cross', 'mid_cross', 'up_cross', 'down_self', 'mid_self',
|
||||
'up_self'
|
||||
]
|
||||
}
|
||||
|
||||
def reset(self):
|
||||
"""Resets the step and attention stores to their initial states."""
|
||||
self.cur_step = 0
|
||||
self.cur_att_layer = 0
|
||||
self.step_store = self.get_empty_store()
|
||||
self.attention_store = {}
|
||||
|
||||
def forward(self, attn, is_cross: bool, place_in_unet: str):
|
||||
"""Processes a single forward step, storing the attention.
|
||||
|
||||
Args:
|
||||
attn: The attention tensor.
|
||||
is_cross (bool): Whether it's cross attention.
|
||||
place_in_unet (str): The location in UNet (down/mid/up).
|
||||
|
||||
Returns:
|
||||
The unmodified attention tensor.
|
||||
"""
|
||||
key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
|
||||
if attn.shape[1] <= (self.max_size)**2:
|
||||
self.step_store[key].append(attn)
|
||||
return attn
|
||||
|
||||
def between_steps(self):
|
||||
"""Processes and stores attention information between steps."""
|
||||
if not self.attention_store:
|
||||
self.attention_store = self.step_store
|
||||
else:
|
||||
for key in self.attention_store:
|
||||
self.attention_store[key] = [
|
||||
stored + step for stored, step in zip(
|
||||
self.attention_store[key], self.step_store[key])
|
||||
]
|
||||
self.step_store = self.get_empty_store()
|
||||
|
||||
def get_average_attention(self):
|
||||
"""Calculates and returns the average attention across all steps."""
|
||||
return {
|
||||
key: [item for item in self.step_store[key]]
|
||||
for key in self.step_store
|
||||
}
|
||||
|
||||
def __call__(self, attn, is_cross: bool, place_in_unet: str):
|
||||
"""Allows the class instance to be callable."""
|
||||
return self.forward(attn, is_cross, place_in_unet)
|
||||
|
||||
@property
|
||||
def num_uncond_att_layers(self):
|
||||
"""Returns the number of unconditional attention layers (default is
|
||||
0)."""
|
||||
return 0
|
||||
|
||||
def step_callback(self, x_t):
|
||||
"""A placeholder for a step callback.
|
||||
|
||||
Returns the input unchanged.
|
||||
"""
|
||||
return x_t
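# Minimal sketch of how the store is driven (illustrative only): the patched
# CrossAttention.forward calls the store once per attention layer, and the
# wrapper reads the accumulated maps back afterwards.
#
#   store = AttentionStore(base_size=64, max_size=32)
#   store(attn, is_cross=True, place_in_unet='up')    # called inside forward
#   maps = store.get_average_attention()['up_cross']  # list of stored maps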
|
||||
|
||||
|
||||
class UNetWrapper(nn.Module):
|
||||
"""A wrapper for UNet with optional attention mechanisms.
|
||||
|
||||
Args:
|
||||
unet (nn.Module): The UNet model to wrap
|
||||
use_attn (bool): Whether to use attention. Defaults to True
|
||||
base_size (int): Base size for the attention store. Defaults to 512
|
||||
max_attn_size (int, optional): Maximum size for the attention store.
|
||||
Defaults to None
|
||||
attn_selector (str): The types of attention to use.
|
||||
Defaults to 'up_cross+down_cross'
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
unet,
|
||||
use_attn=True,
|
||||
base_size=512,
|
||||
max_attn_size=None,
|
||||
attn_selector='up_cross+down_cross'):
|
||||
super().__init__()
|
||||
|
||||
assert has_ldm, 'To use UNetWrapper, please install required ' \
|
||||
'packages via `pip install -r requirements/optional.txt`.'
|
||||
|
||||
self.unet = unet
|
||||
self.attention_store = AttentionStore(
|
||||
base_size=base_size // 8, max_size=max_attn_size)
|
||||
self.attn_selector = attn_selector.split('+')
|
||||
self.use_attn = use_attn
|
||||
self.init_sizes(base_size)
|
||||
if self.use_attn:
|
||||
register_attention_control(unet, self.attention_store)
|
||||
|
||||
def init_sizes(self, base_size):
|
||||
"""Initialize sizes based on the base size."""
|
||||
self.size16 = base_size // 32
|
||||
self.size32 = base_size // 16
|
||||
self.size64 = base_size // 8
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
|
||||
"""Forward pass through the model."""
|
||||
diffusion_model = self.unet.diffusion_model
|
||||
if self.use_attn:
|
||||
self.attention_store.reset()
|
||||
hs, emb, out_list = self._unet_forward(x, timesteps, context, y,
|
||||
diffusion_model)
|
||||
if self.use_attn:
|
||||
self._append_attn_to_output(out_list)
|
||||
return out_list[::-1]
|
||||
|
||||
def _unet_forward(self, x, timesteps, context, y, diffusion_model):
|
||||
hs = []
|
||||
t_emb = timestep_embedding(
|
||||
timesteps, diffusion_model.model_channels, repeat_only=False)
|
||||
emb = diffusion_model.time_embed(t_emb)
|
||||
h = x.type(diffusion_model.dtype)
|
||||
for module in diffusion_model.input_blocks:
|
||||
h = module(h, emb, context)
|
||||
hs.append(h)
|
||||
h = diffusion_model.middle_block(h, emb, context)
|
||||
out_list = []
|
||||
for i_out, module in enumerate(diffusion_model.output_blocks):
|
||||
h = torch.cat([h, hs.pop()], dim=1)
|
||||
h = module(h, emb, context)
|
||||
if i_out in [1, 4, 7]:
|
||||
out_list.append(h)
|
||||
h = h.type(x.dtype)
|
||||
out_list.append(h)
|
||||
return hs, emb, out_list
|
||||
|
||||
def _append_attn_to_output(self, out_list):
|
||||
avg_attn = self.attention_store.get_average_attention()
|
||||
attns = {self.size16: [], self.size32: [], self.size64: []}
|
||||
for k in self.attn_selector:
|
||||
for up_attn in avg_attn[k]:
|
||||
size = int(math.sqrt(up_attn.shape[1]))
|
||||
up_attn = up_attn.transpose(-1, -2).reshape(
|
||||
*up_attn.shape[:2], size, -1)
|
||||
attns[size].append(up_attn)
|
||||
attn16 = torch.stack(attns[self.size16]).mean(0)
|
||||
attn32 = torch.stack(attns[self.size32]).mean(0)
|
||||
attn64 = torch.stack(attns[self.size64]).mean(0) if len(
|
||||
attns[self.size64]) > 0 else None
|
||||
out_list[1] = torch.cat([out_list[1], attn16], dim=1)
|
||||
out_list[2] = torch.cat([out_list[2], attn32], dim=1)
|
||||
if attn64 is not None:
|
||||
out_list[3] = torch.cat([out_list[3], attn64], dim=1)
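# Usage sketch (an assumption, not a documented contract): the wrapper runs a
# single UNet forward pass at the given timestep and returns four multi-scale
# feature maps, with the stored cross-attention maps concatenated onto the
# matching scales when `use_attn` is True.
#
#   wrapper = UNetWrapper(sd_model.model, use_attn=True, base_size=512)
#   feats = wrapper(latents, t, context=c_crossattn)  # list of 4 tensors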
|
||||
|
||||
|
||||
class TextAdapter(nn.Module):
|
||||
"""A PyTorch Module that serves as a text adapter.
|
||||
|
||||
This module takes text embeddings and adjusts them based on a scaling
|
||||
factor gamma.
|
||||
"""
|
||||
|
||||
def __init__(self, text_dim=768):
|
||||
super().__init__()
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(text_dim, text_dim), nn.GELU(),
|
||||
nn.Linear(text_dim, text_dim))
|
||||
|
||||
def forward(self, texts, gamma):
|
||||
texts_after = self.fc(texts)
|
||||
texts = texts + gamma * texts_after
|
||||
return texts
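# Behaviour sketch: the adapter is a gated residual refinement of the class
# embeddings, texts + gamma * MLP(texts), so a gamma near zero keeps the
# pretrained text embeddings almost untouched at the start of training.
#
#   adapter = TextAdapter(text_dim=768)
#   refined = adapter(torch.randn(150, 768), gamma=1e-4)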
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class VPD(BaseModule):
|
||||
"""VPD (Visual Perception Diffusion) model.
|
||||
|
||||
.. _`VPD`: https://arxiv.org/abs/2303.02153
|
||||
|
||||
Args:
|
||||
diffusion_cfg (dict): Configuration for diffusion model.
|
||||
class_embed_path (str): Path for class embeddings.
|
||||
unet_cfg (dict, optional): Configuration for U-Net.
|
||||
gamma (float, optional): Gamma for text adaptation. Defaults to 1e-4.
|
||||
class_embed_select (bool, optional): If True, enables class embedding
|
||||
selection. Defaults to False.
|
||||
pad_shape (Optional[Union[int, List[int]]], optional): Padding shape.
|
||||
Defaults to None.
|
||||
pad_val (Union[int, List[int]], optional): Padding value.
|
||||
Defaults to 0.
|
||||
init_cfg (dict, optional): Configuration for network initialization.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
diffusion_cfg: ConfigType,
|
||||
class_embed_path: str,
|
||||
unet_cfg: OptConfigType = dict(),
|
||||
gamma: float = 1e-4,
|
||||
class_embed_select=False,
|
||||
pad_shape: Optional[Union[int, List[int]]] = None,
|
||||
pad_val: Union[int, List[int]] = 0,
|
||||
init_cfg: OptConfigType = None):
|
||||
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
assert has_ldm, 'To use VPD model, please install required packages' \
|
||||
' via `pip install -r requirements/optional.txt`.'
|
||||
|
||||
if pad_shape is not None:
|
||||
if not isinstance(pad_shape, (list, tuple)):
|
||||
pad_shape = (pad_shape, pad_shape)
|
||||
|
||||
self.pad_shape = pad_shape
|
||||
self.pad_val = pad_val
|
||||
|
||||
# diffusion model
|
||||
diffusion_checkpoint = diffusion_cfg.pop('checkpoint', None)
|
||||
sd_model = instantiate_from_config(diffusion_cfg)
|
||||
if diffusion_checkpoint is not None:
|
||||
load_checkpoint(sd_model, diffusion_checkpoint, strict=False)
|
||||
|
||||
self.encoder_vq = sd_model.first_stage_model
|
||||
self.unet = UNetWrapper(sd_model.model, **unet_cfg)
|
||||
|
||||
# class embeddings & text adapter
|
||||
class_embeddings = CheckpointLoader.load_checkpoint(class_embed_path)
|
||||
text_dim = class_embeddings.size(-1)
|
||||
self.text_adapter = TextAdapter(text_dim=text_dim)
|
||||
self.class_embed_select = class_embed_select
|
||||
if class_embed_select:
|
||||
class_embeddings = torch.cat(
|
||||
(class_embeddings, class_embeddings.mean(dim=0,
|
||||
keepdims=True)),
|
||||
dim=0)
|
||||
self.register_buffer('class_embeddings', class_embeddings)
|
||||
self.gamma = nn.Parameter(torch.ones(text_dim) * gamma)
|
||||
|
||||
def forward(self, x):
|
||||
"""Extract features from images."""
|
||||
|
||||
# calculate cross-attn map
|
||||
if self.class_embed_select:
|
||||
if isinstance(x, (tuple, list)):
|
||||
x, class_ids = x[:2]
|
||||
class_ids = class_ids.tolist()
|
||||
else:
|
||||
class_ids = [-1] * x.size(0)
|
||||
class_embeddings = self.class_embeddings[class_ids]
|
||||
c_crossattn = self.text_adapter(class_embeddings, self.gamma)
|
||||
c_crossattn = c_crossattn.unsqueeze(1)
|
||||
else:
|
||||
class_embeddings = self.class_embeddings
|
||||
c_crossattn = self.text_adapter(class_embeddings, self.gamma)
|
||||
c_crossattn = c_crossattn.unsqueeze(0).repeat(x.size(0), 1, 1)
|
||||
|
||||
# pad to required input shape for pretrained diffusion model
|
||||
if self.pad_shape is not None:
|
||||
pad_width = max(0, self.pad_shape[1] - x.shape[-1])
|
||||
pad_height = max(0, self.pad_shape[0] - x.shape[-2])
|
||||
x = F.pad(x, (0, pad_width, 0, pad_height), value=self.pad_val)
|
||||
|
||||
# forward the denoising model
|
||||
with torch.no_grad():
|
||||
latents = self.encoder_vq.encode(x).mode().detach()
|
||||
t = torch.ones((x.shape[0], ), device=x.device).long()
|
||||
outs = self.unet(latents, t, context=c_crossattn)
|
||||
|
||||
return outs
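# Hedged config-style sketch for building this backbone through the registry
# (the checkpoint paths and the stable-diffusion target below are
# placeholders, not files shipped with this repo):
#
#   backbone = MODELS.build(
#       dict(
#           type='VPD',
#           diffusion_cfg=dict(
#               target='ldm.models.diffusion.ddpm.LatentDiffusion',
#               checkpoint='path/to/stable-diffusion.ckpt',
#               params=dict(...)),
#           class_embed_path='path/to/class_embeddings.pth',
#           pad_shape=512))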
|
||||
52
finetune/mmseg/models/builder.py
Normal file
52
finetune/mmseg/models/builder.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
BACKBONES = MODELS
|
||||
NECKS = MODELS
|
||||
HEADS = MODELS
|
||||
LOSSES = MODELS
|
||||
SEGMENTORS = MODELS
|
||||
|
||||
|
||||
def build_backbone(cfg):
|
||||
"""Build backbone."""
|
||||
warnings.warn('``build_backbone`` would be deprecated soon, please use '
|
||||
'``mmseg.registry.MODELS.build()`` ')
|
||||
return BACKBONES.build(cfg)
|
||||
|
||||
|
||||
def build_neck(cfg):
|
||||
"""Build neck."""
|
||||
warnings.warn('``build_neck`` would be deprecated soon, please use '
|
||||
'``mmseg.registry.MODELS.build()`` ')
|
||||
return NECKS.build(cfg)
|
||||
|
||||
|
||||
def build_head(cfg):
|
||||
"""Build head."""
|
||||
warnings.warn('``build_head`` would be deprecated soon, please use '
|
||||
'``mmseg.registry.MODELS.build()`` ')
|
||||
return HEADS.build(cfg)
|
||||
|
||||
|
||||
def build_loss(cfg):
|
||||
"""Build loss."""
|
||||
warnings.warn('``build_loss`` would be deprecated soon, please use '
|
||||
'``mmseg.registry.MODELS.build()`` ')
|
||||
return LOSSES.build(cfg)
|
||||
|
||||
|
||||
def build_segmentor(cfg, train_cfg=None, test_cfg=None):
|
||||
"""Build segmentor."""
|
||||
if train_cfg is not None or test_cfg is not None:
|
||||
warnings.warn(
|
||||
'train_cfg and test_cfg is deprecated, '
|
||||
'please specify them in model', UserWarning)
|
||||
assert cfg.get('train_cfg') is None or train_cfg is None, \
|
||||
'train_cfg specified in both outer field and model field '
|
||||
assert cfg.get('test_cfg') is None or test_cfg is None, \
|
||||
'test_cfg specified in both outer field and model field '
|
||||
return SEGMENTORS.build(
|
||||
cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
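# Usage note (sketch): these wrappers only forward to the single MODELS
# registry, so new code can build components directly, e.g.:
#
#   from mmseg.registry import MODELS
#   backbone = MODELS.build(dict(type='VisionTransformer', img_size=224))
#
# while build_backbone(...) / build_segmentor(...) remain for backwards
# compatibility and emit the deprecation warnings above.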
|
||||
151
finetune/mmseg/models/data_preprocessor.py
Normal file
151
finetune/mmseg/models/data_preprocessor.py
Normal file
@@ -0,0 +1,151 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from numbers import Number
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
import torch
|
||||
from mmengine.model import BaseDataPreprocessor
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import stack_batch
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class SegDataPreProcessor(BaseDataPreprocessor):
|
||||
"""Image pre-processor for segmentation tasks.
|
||||
|
||||
Compared with :class:`mmengine.ImgDataPreprocessor`,
|
||||
|
||||
1. It won't do normalization if ``mean`` is not specified.
|
||||
2. It does normalization and color space conversion after stacking batch.
|
||||
3. It supports batch augmentations like mixup and cutmix.
|
||||
|
||||
|
||||
It provides the data pre-processing as follows
|
||||
|
||||
- Collate and move data to the target device.
|
||||
- Pad inputs to the input size with defined ``pad_val``, and pad seg map
|
||||
with defined ``seg_pad_val``.
|
||||
- Stack inputs to batch_inputs.
|
||||
- Convert inputs from bgr to rgb if the shape of input is (3, H, W).
|
||||
- Normalize image with defined std and mean.
|
||||
- Do batch augmentations like Mixup and Cutmix during training.
|
||||
|
||||
Args:
|
||||
mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
|
||||
Defaults to None.
|
||||
std (Sequence[Number], optional): The pixel standard deviation of
|
||||
R, G, B channels. Defaults to None.
|
||||
size (tuple, optional): Fixed padding size.
|
||||
size_divisor (int, optional): The divisor of padded size.
|
||||
pad_val (float, optional): Padding value. Default: 0.
|
||||
seg_pad_val (float, optional): Padding value of segmentation map.
|
||||
Default: 255.
|
||||
padding_mode (str): Type of padding. Default: constant.
|
||||
- constant: pads with a constant value, this value is specified
|
||||
with pad_val.
|
||||
bgr_to_rgb (bool): whether to convert image from BGR to RGB.
|
||||
Defaults to False.
|
||||
rgb_to_bgr (bool): whether to convert image from RGB to BGR.
|
||||
Defaults to False.
|
||||
batch_augments (list[dict], optional): Batch-level augmentations
|
||||
test_cfg (dict, optional): The padding size config in testing, if not
|
||||
specify, will use `size` and `size_divisor` params as default.
|
||||
Defaults to None, only supports keys `size` or `size_divisor`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mean: Sequence[Number] = None,
|
||||
std: Sequence[Number] = None,
|
||||
size: Optional[tuple] = None,
|
||||
size_divisor: Optional[int] = None,
|
||||
pad_val: Number = 0,
|
||||
seg_pad_val: Number = 255,
|
||||
bgr_to_rgb: bool = False,
|
||||
rgb_to_bgr: bool = False,
|
||||
batch_augments: Optional[List[dict]] = None,
|
||||
test_cfg: dict = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.size = size
|
||||
self.size_divisor = size_divisor
|
||||
self.pad_val = pad_val
|
||||
self.seg_pad_val = seg_pad_val
|
||||
|
||||
assert not (bgr_to_rgb and rgb_to_bgr), (
|
||||
'`bgr2rgb` and `rgb2bgr` cannot be set to True at the same time')
|
||||
self.channel_conversion = rgb_to_bgr or bgr_to_rgb
|
||||
|
||||
if mean is not None:
|
||||
assert std is not None, 'To enable the normalization in ' \
|
||||
'preprocessing, please specify both ' \
|
||||
'`mean` and `std`.'
|
||||
# Enable the normalization in preprocessing.
|
||||
self._enable_normalize = True
|
||||
self.register_buffer('mean',
|
||||
torch.tensor(mean).view(-1, 1, 1), False)
|
||||
self.register_buffer('std',
|
||||
torch.tensor(std).view(-1, 1, 1), False)
|
||||
else:
|
||||
self._enable_normalize = False
|
||||
|
||||
# TODO: support batch augmentations.
|
||||
self.batch_augments = batch_augments
|
||||
|
||||
# Support different padding methods in testing
|
||||
self.test_cfg = test_cfg
|
||||
|
||||
def forward(self, data: dict, training: bool = False) -> Dict[str, Any]:
|
||||
"""Perform normalization、padding and bgr2rgb conversion based on
|
||||
``BaseDataPreprocessor``.
|
||||
|
||||
Args:
|
||||
data (dict): data sampled from dataloader.
|
||||
training (bool): Whether to enable training time augmentation.
|
||||
|
||||
Returns:
|
||||
Dict: Data in the same format as the model input.
|
||||
"""
|
||||
data = self.cast_data(data) # type: ignore
|
||||
inputs = data['inputs']
|
||||
data_samples = data.get('data_samples', None)
|
||||
# TODO: whether normalize should be after stack_batch
|
||||
if self.channel_conversion and inputs[0].size(0) == 3:
|
||||
inputs = [_input[[2, 1, 0], ...] for _input in inputs]
|
||||
|
||||
inputs = [_input.float() for _input in inputs]
|
||||
if self._enable_normalize:
|
||||
inputs = [(_input - self.mean) / self.std for _input in inputs]
|
||||
|
||||
if training:
|
||||
assert data_samples is not None, ('During training, '
'`data_samples` must be defined.')
|
||||
inputs, data_samples = stack_batch(
|
||||
inputs=inputs,
|
||||
data_samples=data_samples,
|
||||
size=self.size,
|
||||
size_divisor=self.size_divisor,
|
||||
pad_val=self.pad_val,
|
||||
seg_pad_val=self.seg_pad_val)
|
||||
|
||||
if self.batch_augments is not None:
|
||||
inputs, data_samples = self.batch_augments(
|
||||
inputs, data_samples)
|
||||
else:
|
||||
img_size = inputs[0].shape[1:]
|
||||
assert all(input_.shape[1:] == img_size for input_ in inputs), \
|
||||
'The image size in a batch should be the same.'
|
||||
# pad images when testing
|
||||
if self.test_cfg:
|
||||
inputs, padded_samples = stack_batch(
|
||||
inputs=inputs,
|
||||
size=self.test_cfg.get('size', None),
|
||||
size_divisor=self.test_cfg.get('size_divisor', None),
|
||||
pad_val=self.pad_val,
|
||||
seg_pad_val=self.seg_pad_val)
|
||||
for data_sample, pad_info in zip(data_samples, padded_samples):
|
||||
data_sample.set_metainfo({**pad_info})
|
||||
else:
|
||||
inputs = torch.stack(inputs, dim=0)
|
||||
|
||||
return dict(inputs=inputs, data_samples=data_samples)
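# Example preprocessor config sketch (the ImageNet mean/std values below are
# a common choice, assumed rather than taken from a config in this repo):
#
#   data_preprocessor = dict(
#       type='SegDataPreProcessor',
#       mean=[123.675, 116.28, 103.53],
#       std=[58.395, 57.12, 57.375],
#       bgr_to_rgb=True,
#       pad_val=0,
#       seg_pad_val=255,
#       size=(512, 512))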
|
||||
48
finetune/mmseg/models/decode_heads/__init__.py
Normal file
48
finetune/mmseg/models/decode_heads/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .ann_head import ANNHead
|
||||
from .apc_head import APCHead
|
||||
from .aspp_head import ASPPHead
|
||||
from .cc_head import CCHead
|
||||
from .da_head import DAHead
|
||||
from .ddr_head import DDRHead
|
||||
from .dm_head import DMHead
|
||||
from .dnl_head import DNLHead
|
||||
from .dpt_head import DPTHead
|
||||
from .ema_head import EMAHead
|
||||
from .enc_head import EncHead
|
||||
from .fcn_head import FCNHead
|
||||
from .fpn_head import FPNHead
|
||||
from .gc_head import GCHead
|
||||
from .ham_head import LightHamHead
|
||||
from .isa_head import ISAHead
|
||||
from .knet_head import IterativeDecodeHead, KernelUpdateHead, KernelUpdator
|
||||
from .lraspp_head import LRASPPHead
|
||||
from .mask2former_head import Mask2FormerHead
|
||||
from .maskformer_head import MaskFormerHead
|
||||
from .nl_head import NLHead
|
||||
from .ocr_head import OCRHead
|
||||
from .pid_head import PIDHead
|
||||
from .point_head import PointHead
|
||||
from .psa_head import PSAHead
|
||||
from .psp_head import PSPHead
|
||||
from .san_head import SideAdapterCLIPHead
|
||||
from .segformer_head import SegformerHead
|
||||
from .segmenter_mask_head import SegmenterMaskTransformerHead
|
||||
from .sep_aspp_head import DepthwiseSeparableASPPHead
|
||||
from .sep_fcn_head import DepthwiseSeparableFCNHead
|
||||
from .setr_mla_head import SETRMLAHead
|
||||
from .setr_up_head import SETRUPHead
|
||||
from .stdc_head import STDCHead
|
||||
from .uper_head import UPerHead
|
||||
from .vpd_depth_head import VPDDepthHead
|
||||
|
||||
__all__ = [
|
||||
'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
|
||||
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
|
||||
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
|
||||
'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
|
||||
'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead',
|
||||
'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead',
|
||||
'KernelUpdateHead', 'KernelUpdator', 'MaskFormerHead', 'Mask2FormerHead',
|
||||
'LightHamHead', 'PIDHead', 'DDRHead', 'VPDDepthHead', 'SideAdapterCLIPHead'
|
||||
]
|
||||
245
finetune/mmseg/models/decode_heads/ann_head.py
Normal file
245
finetune/mmseg/models/decode_heads/ann_head.py
Normal file
@@ -0,0 +1,245 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class PPMConcat(nn.ModuleList):
|
||||
"""Pyramid Pooling Module that only concat the features of each layer.
|
||||
|
||||
Args:
|
||||
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
||||
Module.
|
||||
"""
|
||||
|
||||
def __init__(self, pool_scales=(1, 3, 6, 8)):
|
||||
super().__init__(
|
||||
[nn.AdaptiveAvgPool2d(pool_scale) for pool_scale in pool_scales])
|
||||
|
||||
def forward(self, feats):
|
||||
"""Forward function."""
|
||||
ppm_outs = []
|
||||
for ppm in self:
|
||||
ppm_out = ppm(feats)
|
||||
ppm_outs.append(ppm_out.view(*feats.shape[:2], -1))
|
||||
concat_outs = torch.cat(ppm_outs, dim=2)
|
||||
return concat_outs
|
||||
|
||||
|
||||
class SelfAttentionBlock(_SelfAttentionBlock):
|
||||
"""Make a ANN used SelfAttentionBlock.
|
||||
|
||||
Args:
|
||||
low_in_channels (int): Input channels of lower level feature,
|
||||
which is the key feature for self-attention.
|
||||
high_in_channels (int): Input channels of higher level feature,
|
||||
which is the query feature for self-attention.
|
||||
channels (int): Output channels of key/query transform.
|
||||
out_channels (int): Output channels.
|
||||
share_key_query (bool): Whether share projection weight between key
|
||||
and query projection.
|
||||
query_scale (int): The scale of query feature map.
|
||||
key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
||||
Module of key feature.
|
||||
conv_cfg (dict|None): Config of conv layers.
|
||||
norm_cfg (dict|None): Config of norm layers.
|
||||
act_cfg (dict|None): Config of activation layers.
|
||||
"""
|
||||
|
||||
def __init__(self, low_in_channels, high_in_channels, channels,
|
||||
out_channels, share_key_query, query_scale, key_pool_scales,
|
||||
conv_cfg, norm_cfg, act_cfg):
|
||||
key_psp = PPMConcat(key_pool_scales)
|
||||
if query_scale > 1:
|
||||
query_downsample = nn.MaxPool2d(kernel_size=query_scale)
|
||||
else:
|
||||
query_downsample = None
|
||||
super().__init__(
|
||||
key_in_channels=low_in_channels,
|
||||
query_in_channels=high_in_channels,
|
||||
channels=channels,
|
||||
out_channels=out_channels,
|
||||
share_key_query=share_key_query,
|
||||
query_downsample=query_downsample,
|
||||
key_downsample=key_psp,
|
||||
key_query_num_convs=1,
|
||||
key_query_norm=True,
|
||||
value_out_num_convs=1,
|
||||
value_out_norm=False,
|
||||
matmul_norm=True,
|
||||
with_out=True,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
|
||||
class AFNB(nn.Module):
|
||||
"""Asymmetric Fusion Non-local Block(AFNB)
|
||||
|
||||
Args:
|
||||
low_in_channels (int): Input channels of lower level feature,
|
||||
which is the key feature for self-attention.
|
||||
high_in_channels (int): Input channels of higher level feature,
|
||||
which is the query feature for self-attention.
|
||||
channels (int): Output channels of key/query transform.
|
||||
out_channels (int): Output channels.
|
||||
and query projection.
|
||||
query_scales (tuple[int]): The scales of query feature map.
|
||||
Default: (1,)
|
||||
key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
||||
Module of key feature.
|
||||
conv_cfg (dict|None): Config of conv layers.
|
||||
norm_cfg (dict|None): Config of norm layers.
|
||||
act_cfg (dict|None): Config of activation layers.
|
||||
"""
|
||||
|
||||
def __init__(self, low_in_channels, high_in_channels, channels,
|
||||
out_channels, query_scales, key_pool_scales, conv_cfg,
|
||||
norm_cfg, act_cfg):
|
||||
super().__init__()
|
||||
self.stages = nn.ModuleList()
|
||||
for query_scale in query_scales:
|
||||
self.stages.append(
|
||||
SelfAttentionBlock(
|
||||
low_in_channels=low_in_channels,
|
||||
high_in_channels=high_in_channels,
|
||||
channels=channels,
|
||||
out_channels=out_channels,
|
||||
share_key_query=False,
|
||||
query_scale=query_scale,
|
||||
key_pool_scales=key_pool_scales,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
self.bottleneck = ConvModule(
|
||||
out_channels + high_in_channels,
|
||||
out_channels,
|
||||
1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None)
|
||||
|
||||
def forward(self, low_feats, high_feats):
|
||||
"""Forward function."""
|
||||
priors = [stage(high_feats, low_feats) for stage in self.stages]
|
||||
context = torch.stack(priors, dim=0).sum(dim=0)
|
||||
output = self.bottleneck(torch.cat([context, high_feats], 1))
|
||||
return output
|
||||
|
||||
|
||||
class APNB(nn.Module):
|
||||
"""Asymmetric Pyramid Non-local Block (APNB)
|
||||
|
||||
Args:
|
||||
in_channels (int): Input channels of key/query feature,
|
||||
which is the key feature for self-attention.
|
||||
channels (int): Output channels of key/query transform.
|
||||
out_channels (int): Output channels.
|
||||
query_scales (tuple[int]): The scales of query feature map.
|
||||
Default: (1,)
|
||||
key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
||||
Module of key feature.
|
||||
conv_cfg (dict|None): Config of conv layers.
|
||||
norm_cfg (dict|None): Config of norm layers.
|
||||
act_cfg (dict|None): Config of activation layers.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, channels, out_channels, query_scales,
|
||||
key_pool_scales, conv_cfg, norm_cfg, act_cfg):
|
||||
super().__init__()
|
||||
self.stages = nn.ModuleList()
|
||||
for query_scale in query_scales:
|
||||
self.stages.append(
|
||||
SelfAttentionBlock(
|
||||
low_in_channels=in_channels,
|
||||
high_in_channels=in_channels,
|
||||
channels=channels,
|
||||
out_channels=out_channels,
|
||||
share_key_query=True,
|
||||
query_scale=query_scale,
|
||||
key_pool_scales=key_pool_scales,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
self.bottleneck = ConvModule(
|
||||
2 * in_channels,
|
||||
out_channels,
|
||||
1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
def forward(self, feats):
|
||||
"""Forward function."""
|
||||
priors = [stage(feats, feats) for stage in self.stages]
|
||||
context = torch.stack(priors, dim=0).sum(dim=0)
|
||||
output = self.bottleneck(torch.cat([context, feats], 1))
|
||||
return output
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ANNHead(BaseDecodeHead):
|
||||
"""Asymmetric Non-local Neural Networks for Semantic Segmentation.
|
||||
|
||||
This head is the implementation of `ANNNet
|
||||
<https://arxiv.org/abs/1908.07678>`_.
|
||||
|
||||
Args:
|
||||
project_channels (int): Projection channels for Nonlocal.
|
||||
query_scales (tuple[int]): The scales of query feature map.
|
||||
Default: (1,)
|
||||
key_pool_scales (tuple[int]): The pooling scales of key feature map.
|
||||
Default: (1, 3, 6, 8).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
project_channels,
|
||||
query_scales=(1, ),
|
||||
key_pool_scales=(1, 3, 6, 8),
|
||||
**kwargs):
|
||||
super().__init__(input_transform='multiple_select', **kwargs)
|
||||
assert len(self.in_channels) == 2
|
||||
low_in_channels, high_in_channels = self.in_channels
|
||||
self.project_channels = project_channels
|
||||
self.fusion = AFNB(
|
||||
low_in_channels=low_in_channels,
|
||||
high_in_channels=high_in_channels,
|
||||
out_channels=high_in_channels,
|
||||
channels=project_channels,
|
||||
query_scales=query_scales,
|
||||
key_pool_scales=key_pool_scales,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.bottleneck = ConvModule(
|
||||
high_in_channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.context = APNB(
|
||||
in_channels=self.channels,
|
||||
out_channels=self.channels,
|
||||
channels=project_channels,
|
||||
query_scales=query_scales,
|
||||
key_pool_scales=key_pool_scales,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
low_feats, high_feats = self._transform_inputs(inputs)
|
||||
output = self.fusion(low_feats, high_feats)
|
||||
output = self.dropout(output)
|
||||
output = self.bottleneck(output)
|
||||
output = self.context(output)
|
||||
output = self.cls_seg(output)
|
||||
|
||||
return output
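# Example head config sketch (channel numbers assume a ResNet-style backbone
# whose last two stages are selected; values are illustrative, not from a
# shipped config):
#
#   decode_head = dict(
#       type='ANNHead',
#       in_channels=[1024, 2048],
#       in_index=[2, 3],
#       channels=512,
#       project_channels=256,
#       num_classes=19,
#       norm_cfg=dict(type='SyncBN', requires_grad=True),
#       loss_decode=dict(type='CrossEntropyLoss', loss_weight=1.0))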
|
||||
159
finetune/mmseg/models/decode_heads/apc_head.py
Normal file
159
finetune/mmseg/models/decode_heads/apc_head.py
Normal file
@@ -0,0 +1,159 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class ACM(nn.Module):
|
||||
"""Adaptive Context Module used in APCNet.
|
||||
|
||||
Args:
|
||||
pool_scale (int): Pooling scale used in Adaptive Context
|
||||
Module to extract region features.
|
||||
fusion (bool): Add one conv to fuse residual feature.
|
||||
in_channels (int): Input channels.
|
||||
channels (int): Channels after modules, before conv_seg.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
act_cfg (dict): Config of activation layers.
|
||||
"""
|
||||
|
||||
def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg,
|
||||
norm_cfg, act_cfg):
|
||||
super().__init__()
|
||||
self.pool_scale = pool_scale
|
||||
self.fusion = fusion
|
||||
self.in_channels = in_channels
|
||||
self.channels = channels
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.pooled_redu_conv = ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
self.input_redu_conv = ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
self.global_info = ConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
self.gla = nn.Conv2d(self.channels, self.pool_scale**2, 1, 1, 0)
|
||||
|
||||
self.residual_conv = ConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
if self.fusion:
|
||||
self.fusion_conv = ConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
pooled_x = F.adaptive_avg_pool2d(x, self.pool_scale)
|
||||
# [batch_size, channels, h, w]
|
||||
x = self.input_redu_conv(x)
|
||||
# [batch_size, channels, pool_scale, pool_scale]
|
||||
pooled_x = self.pooled_redu_conv(pooled_x)
|
||||
batch_size = x.size(0)
|
||||
# [batch_size, pool_scale * pool_scale, channels]
|
||||
pooled_x = pooled_x.view(batch_size, self.channels,
|
||||
-1).permute(0, 2, 1).contiguous()
|
||||
# [batch_size, h * w, pool_scale * pool_scale]
|
||||
affinity_matrix = self.gla(x + resize(
|
||||
self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:])
|
||||
).permute(0, 2, 3, 1).reshape(
|
||||
batch_size, -1, self.pool_scale**2)
|
||||
affinity_matrix = F.sigmoid(affinity_matrix)
|
||||
# [batch_size, h * w, channels]
|
||||
z_out = torch.matmul(affinity_matrix, pooled_x)
|
||||
# [batch_size, channels, h * w]
|
||||
z_out = z_out.permute(0, 2, 1).contiguous()
|
||||
# [batch_size, channels, h, w]
|
||||
z_out = z_out.view(batch_size, self.channels, x.size(2), x.size(3))
|
||||
z_out = self.residual_conv(z_out)
|
||||
z_out = F.relu(z_out + x)
|
||||
if self.fusion:
|
||||
z_out = self.fusion_conv(z_out)
|
||||
|
||||
return z_out
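# Shape walk-through for ACM.forward (illustrative assumption: channels=512,
# pool_scale=2 and a 64x64 input feature map):
#
#   pooled_x        : [B, 512, 2, 2]  -> flattened to [B, 4, 512]
#   affinity_matrix : [B, 64*64, 4]   (sigmoid-gated, one weight per region)
#   z_out           : [B, 64*64, 512] -> reshaped back to [B, 512, 64, 64]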
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class APCHead(BaseDecodeHead):
|
||||
"""Adaptive Pyramid Context Network for Semantic Segmentation.
|
||||
|
||||
This head is the implementation of
|
||||
`APCNet <https://openaccess.thecvf.com/content_CVPR_2019/papers/\
|
||||
He_Adaptive_Pyramid_Context_Network_for_Semantic_Segmentation_\
|
||||
CVPR_2019_paper.pdf>`_.
|
||||
|
||||
Args:
|
||||
pool_scales (tuple[int]): Pooling scales used in Adaptive Context
|
||||
Module. Default: (1, 2, 3, 6).
|
||||
fusion (bool): Add one conv to fuse residual feature.
|
||||
"""
|
||||
|
||||
def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
assert isinstance(pool_scales, (list, tuple))
|
||||
self.pool_scales = pool_scales
|
||||
self.fusion = fusion
|
||||
acm_modules = []
|
||||
for pool_scale in self.pool_scales:
|
||||
acm_modules.append(
|
||||
ACM(pool_scale,
|
||||
self.fusion,
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg))
|
||||
self.acm_modules = nn.ModuleList(acm_modules)
|
||||
self.bottleneck = ConvModule(
|
||||
self.in_channels + len(pool_scales) * self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
x = self._transform_inputs(inputs)
|
||||
acm_outs = [x]
|
||||
for acm_module in self.acm_modules:
|
||||
acm_outs.append(acm_module(x))
|
||||
acm_outs = torch.cat(acm_outs, dim=1)
|
||||
output = self.bottleneck(acm_outs)
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
122
finetune/mmseg/models/decode_heads/aspp_head.py
Normal file
122
finetune/mmseg/models/decode_heads/aspp_head.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class ASPPModule(nn.ModuleList):
|
||||
"""Atrous Spatial Pyramid Pooling (ASPP) Module.
|
||||
|
||||
Args:
|
||||
dilations (tuple[int]): Dilation rate of each layer.
|
||||
in_channels (int): Input channels.
|
||||
channels (int): Channels after modules, before conv_seg.
|
||||
conv_cfg (dict|None): Config of conv layers.
|
||||
norm_cfg (dict|None): Config of norm layers.
|
||||
act_cfg (dict): Config of activation layers.
|
||||
"""
|
||||
|
||||
def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg,
|
||||
act_cfg):
|
||||
super().__init__()
|
||||
self.dilations = dilations
|
||||
self.in_channels = in_channels
|
||||
self.channels = channels
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
for dilation in dilations:
|
||||
self.append(
|
||||
ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
1 if dilation == 1 else 3,
|
||||
dilation=dilation,
|
||||
padding=0 if dilation == 1 else dilation,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg))
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
aspp_outs = []
|
||||
for aspp_module in self:
|
||||
aspp_outs.append(aspp_module(x))
|
||||
|
||||
return aspp_outs
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ASPPHead(BaseDecodeHead):
|
||||
"""Rethinking Atrous Convolution for Semantic Image Segmentation.
|
||||
|
||||
This head is the implementation of `DeepLabV3
|
||||
<https://arxiv.org/abs/1706.05587>`_.
|
||||
|
||||
Args:
|
||||
dilations (tuple[int]): Dilation rates for ASPP module.
|
||||
Default: (1, 6, 12, 18).
|
||||
"""
|
||||
|
||||
def __init__(self, dilations=(1, 6, 12, 18), **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
assert isinstance(dilations, (list, tuple))
|
||||
self.dilations = dilations
|
||||
self.image_pool = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d(1),
|
||||
ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg))
|
||||
self.aspp_modules = ASPPModule(
|
||||
dilations,
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.bottleneck = ConvModule(
|
||||
(len(dilations) + 1) * self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def _forward_feature(self, inputs):
|
||||
"""Forward function for feature maps before classifying each pixel with
|
||||
``self.cls_seg`` fc.
|
||||
|
||||
Args:
|
||||
inputs (list[Tensor]): List of multi-level img features.
|
||||
|
||||
Returns:
|
||||
feats (Tensor): A tensor of shape (batch_size, self.channels,
|
||||
H, W) which is feature map for last layer of decoder head.
|
||||
"""
|
||||
x = self._transform_inputs(inputs)
|
||||
aspp_outs = [
|
||||
resize(
|
||||
self.image_pool(x),
|
||||
size=x.size()[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
]
|
||||
aspp_outs.extend(self.aspp_modules(x))
|
||||
aspp_outs = torch.cat(aspp_outs, dim=1)
|
||||
feats = self.bottleneck(aspp_outs)
|
||||
return feats
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
output = self._forward_feature(inputs)
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
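Illustrative usage sketch (not part of the committed files; the import path, channel sizes and class count are assumed): running a randomly initialized ASPPHead on a single 2048-channel feature map shows the image pool plus the four dilated branches being concatenated, bottlenecked and classified.

import torch
from mmseg.models.decode_heads import ASPPHead

head = ASPPHead(
    dilations=(1, 6, 12, 18),
    in_channels=2048,
    channels=512,
    num_classes=19,
    norm_cfg=dict(type='BN')).eval()

feats = [torch.randn(1, 2048, 64, 64)]  # list of multi-level features; in_index=-1 picks the last
with torch.no_grad():
    logits = head(feats)
print(logits.shape)                     # torch.Size([1, 19, 64, 64])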
62
finetune/mmseg/models/decode_heads/cascade_decode_head.py
Normal file
@@ -0,0 +1,62 @@
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import List
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.utils import ConfigType
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
|
||||
"""Base class for cascade decode head used in
|
||||
:class:`CascadeEncoderDecoder."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, inputs, prev_output):
|
||||
"""Placeholder of forward function."""
|
||||
pass
|
||||
|
||||
def loss(self, inputs: List[Tensor], prev_output: Tensor,
|
||||
batch_data_samples: List[dict], train_cfg: ConfigType) -> Tensor:
|
||||
"""Forward function for training.
|
||||
|
||||
Args:
|
||||
inputs (List[Tensor]): List of multi-level img features.
|
||||
prev_output (Tensor): The output of previous decode head.
|
||||
batch_data_samples (List[:obj:`SegDataSample`]): The seg
|
||||
data samples. It usually includes information such
|
||||
as `metainfo` and `gt_sem_seg`.
|
||||
train_cfg (dict): The training config.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: a dictionary of loss components
|
||||
"""
|
||||
seg_logits = self.forward(inputs, prev_output)
|
||||
losses = self.loss_by_feat(seg_logits, batch_data_samples)
|
||||
|
||||
return losses
|
||||
|
||||
def predict(self, inputs: List[Tensor], prev_output: Tensor,
batch_img_metas: List[dict], test_cfg: ConfigType):
|
||||
"""Forward function for testing.
|
||||
|
||||
Args:
|
||||
inputs (List[Tensor]): List of multi-level img features.
|
||||
prev_output (Tensor): The output of previous decode head.
|
||||
batch_img_metas (list[dict]): List of image info where each dict may also
|
||||
contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
|
||||
'ori_shape', and 'pad_shape'.
|
||||
For details on the values of these keys see
|
||||
`mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
|
||||
test_cfg (dict): The testing config.
|
||||
|
||||
Returns:
|
||||
Tensor: Output segmentation map.
|
||||
"""
|
||||
seg_logits = self.forward(inputs, prev_output)
|
||||
|
||||
return self.predict_by_feat(seg_logits, batch_img_metas)
|
||||
43
finetune/mmseg/models/decode_heads/cc_head.py
Normal file
@@ -0,0 +1,43 @@
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .fcn_head import FCNHead
|
||||
|
||||
try:
|
||||
from mmcv.ops import CrissCrossAttention
|
||||
except ModuleNotFoundError:
|
||||
CrissCrossAttention = None
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class CCHead(FCNHead):
|
||||
"""CCNet: Criss-Cross Attention for Semantic Segmentation.
|
||||
|
||||
This head is the implementation of `CCNet
|
||||
<https://arxiv.org/abs/1811.11721>`_.
|
||||
|
||||
Args:
|
||||
recurrence (int): Number of recurrence of Criss Cross Attention
|
||||
module. Default: 2.
|
||||
"""
|
||||
|
||||
def __init__(self, recurrence=2, **kwargs):
|
||||
if CrissCrossAttention is None:
|
||||
raise RuntimeError('Please install mmcv-full for '
|
||||
'CrissCrossAttention ops')
|
||||
super().__init__(num_convs=2, **kwargs)
|
||||
self.recurrence = recurrence
|
||||
self.cca = CrissCrossAttention(self.channels)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
x = self._transform_inputs(inputs)
|
||||
output = self.convs[0](x)
|
||||
for _ in range(self.recurrence):
|
||||
output = self.cca(output)
|
||||
output = self.convs[1](output)
|
||||
if self.concat_input:
|
||||
output = self.conv_cat(torch.cat([x, output], dim=1))
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
184
finetune/mmseg/models/decode_heads/da_head.py
Normal file
@@ -0,0 +1,184 @@
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule, Scale
|
||||
from torch import Tensor, nn
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import SampleList, add_prefix
|
||||
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class PAM(_SelfAttentionBlock):
|
||||
"""Position Attention Module (PAM)
|
||||
|
||||
Args:
|
||||
in_channels (int): Input channels of key/query feature.
|
||||
channels (int): Output channels of key/query transform.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, channels):
|
||||
super().__init__(
|
||||
key_in_channels=in_channels,
|
||||
query_in_channels=in_channels,
|
||||
channels=channels,
|
||||
out_channels=in_channels,
|
||||
share_key_query=False,
|
||||
query_downsample=None,
|
||||
key_downsample=None,
|
||||
key_query_num_convs=1,
|
||||
key_query_norm=False,
|
||||
value_out_num_convs=1,
|
||||
value_out_norm=False,
|
||||
matmul_norm=False,
|
||||
with_out=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=None,
|
||||
act_cfg=None)
|
||||
|
||||
self.gamma = Scale(0)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
out = super().forward(x, x)
|
||||
|
||||
out = self.gamma(out) + x
|
||||
return out
|
||||
|
||||
|
||||
class CAM(nn.Module):
|
||||
"""Channel Attention Module (CAM)"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.gamma = Scale(0)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
batch_size, channels, height, width = x.size()
|
||||
proj_query = x.view(batch_size, channels, -1)
|
||||
proj_key = x.view(batch_size, channels, -1).permute(0, 2, 1)
|
||||
energy = torch.bmm(proj_query, proj_key)
|
||||
energy_new = torch.max(
|
||||
energy, -1, keepdim=True)[0].expand_as(energy) - energy
|
||||
attention = F.softmax(energy_new, dim=-1)
|
||||
proj_value = x.view(batch_size, channels, -1)
|
||||
|
||||
out = torch.bmm(attention, proj_value)
|
||||
out = out.view(batch_size, channels, height, width)
|
||||
|
||||
out = self.gamma(out) + x
|
||||
return out
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DAHead(BaseDecodeHead):
|
||||
"""Dual Attention Network for Scene Segmentation.
|
||||
|
||||
This head is the implementation of `DANet
|
||||
<https://arxiv.org/abs/1809.02983>`_.
|
||||
|
||||
Args:
|
||||
pam_channels (int): The channels of Position Attention Module (PAM).
|
||||
"""
|
||||
|
||||
def __init__(self, pam_channels, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.pam_channels = pam_channels
|
||||
self.pam_in_conv = ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.pam = PAM(self.channels, pam_channels)
|
||||
self.pam_out_conv = ConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.pam_conv_seg = nn.Conv2d(
|
||||
self.channels, self.num_classes, kernel_size=1)
|
||||
|
||||
self.cam_in_conv = ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.cam = CAM()
|
||||
self.cam_out_conv = ConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.cam_conv_seg = nn.Conv2d(
|
||||
self.channels, self.num_classes, kernel_size=1)
|
||||
|
||||
def pam_cls_seg(self, feat):
|
||||
"""PAM feature classification."""
|
||||
if self.dropout is not None:
|
||||
feat = self.dropout(feat)
|
||||
output = self.pam_conv_seg(feat)
|
||||
return output
|
||||
|
||||
def cam_cls_seg(self, feat):
|
||||
"""CAM feature classification."""
|
||||
if self.dropout is not None:
|
||||
feat = self.dropout(feat)
|
||||
output = self.cam_conv_seg(feat)
|
||||
return output
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
x = self._transform_inputs(inputs)
|
||||
pam_feat = self.pam_in_conv(x)
|
||||
pam_feat = self.pam(pam_feat)
|
||||
pam_feat = self.pam_out_conv(pam_feat)
|
||||
pam_out = self.pam_cls_seg(pam_feat)
|
||||
|
||||
cam_feat = self.cam_in_conv(x)
|
||||
cam_feat = self.cam(cam_feat)
|
||||
cam_feat = self.cam_out_conv(cam_feat)
|
||||
cam_out = self.cam_cls_seg(cam_feat)
|
||||
|
||||
feat_sum = pam_feat + cam_feat
|
||||
pam_cam_out = self.cls_seg(feat_sum)
|
||||
|
||||
return pam_cam_out, pam_out, cam_out
|
||||
|
||||
def predict(self, inputs, batch_img_metas: List[dict], test_cfg,
|
||||
**kwargs) -> List[Tensor]:
|
||||
"""Forward function for testing, only ``pam_cam`` is used."""
|
||||
seg_logits = self.forward(inputs)[0]
|
||||
return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs)
|
||||
|
||||
def loss_by_feat(self, seg_logit: Tuple[Tensor],
|
||||
batch_data_samples: SampleList, **kwargs) -> dict:
|
||||
"""Compute ``pam_cam``, ``pam``, ``cam`` loss."""
|
||||
pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit
|
||||
loss = dict()
|
||||
loss.update(
|
||||
add_prefix(
|
||||
super().loss_by_feat(pam_cam_seg_logit, batch_data_samples),
|
||||
'pam_cam'))
|
||||
loss.update(
|
||||
add_prefix(super().loss_by_feat(pam_seg_logit, batch_data_samples),
|
||||
'pam'))
|
||||
loss.update(
|
||||
add_prefix(super().loss_by_feat(cam_seg_logit, batch_data_samples),
|
||||
'cam'))
|
||||
return loss
|
||||
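Illustrative sketch for the two attention blocks above (not part of the committed files; sizes are assumed): both PAM and CAM preserve the input shape, and since `gamma` is a `Scale(0)`, a freshly initialized block behaves as an identity mapping.

import torch
from mmseg.models.decode_heads.da_head import CAM, PAM

x = torch.randn(2, 64, 32, 32)
cam = CAM()
pam = PAM(in_channels=64, channels=16)

assert pam(x).shape == x.shape          # attention residual keeps channel and spatial shape
assert cam(x).shape == x.shape
assert torch.allclose(cam(x), x)        # gamma starts at 0, so the block starts as identity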
116
finetune/mmseg/models/decode_heads/ddr_head.py
Normal file
@@ -0,0 +1,116 @@
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
from mmseg.models.losses import accuracy
|
||||
from mmseg.models.utils import resize
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import OptConfigType, SampleList
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DDRHead(BaseDecodeHead):
|
||||
"""Decode head for DDRNet.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
channels (int): Number of output channels.
|
||||
num_classes (int): Number of classes.
|
||||
norm_cfg (dict, optional): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict, optional): Config dict for activation layer.
|
||||
Default: dict(type='ReLU', inplace=True).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
channels: int,
|
||||
num_classes: int,
|
||||
norm_cfg: OptConfigType = dict(type='BN'),
|
||||
act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
|
||||
**kwargs):
|
||||
super().__init__(
|
||||
in_channels,
|
||||
channels,
|
||||
num_classes=num_classes,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
**kwargs)
|
||||
|
||||
self.head = self._make_base_head(self.in_channels, self.channels)
|
||||
self.aux_head = self._make_base_head(self.in_channels // 2,
|
||||
self.channels)
|
||||
self.aux_cls_seg = nn.Conv2d(
|
||||
self.channels, self.out_channels, kernel_size=1)
|
||||
|
||||
def init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: Union[Tensor,
|
||||
Tuple[Tensor]]) -> Union[Tensor, Tuple[Tensor]]:
|
||||
if self.training:
|
||||
c3_feat, c5_feat = inputs
|
||||
x_c = self.head(c5_feat)
|
||||
x_c = self.cls_seg(x_c)
|
||||
x_s = self.aux_head(c3_feat)
|
||||
x_s = self.aux_cls_seg(x_s)
|
||||
|
||||
return x_c, x_s
|
||||
else:
|
||||
x_c = self.head(inputs)
|
||||
x_c = self.cls_seg(x_c)
|
||||
return x_c
|
||||
|
||||
def _make_base_head(self, in_channels: int,
|
||||
channels: int) -> nn.Sequential:
|
||||
layers = [
|
||||
ConvModule(
|
||||
in_channels,
|
||||
channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
order=('norm', 'act', 'conv')),
|
||||
build_norm_layer(self.norm_cfg, channels)[1],
|
||||
build_activation_layer(self.act_cfg),
|
||||
]
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def loss_by_feat(self, seg_logits: Tuple[Tensor],
|
||||
batch_data_samples: SampleList) -> dict:
|
||||
loss = dict()
|
||||
context_logit, spatial_logit = seg_logits
|
||||
seg_label = self._stack_batch_gt(batch_data_samples)
|
||||
|
||||
context_logit = resize(
|
||||
context_logit,
|
||||
size=seg_label.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
spatial_logit = resize(
|
||||
spatial_logit,
|
||||
size=seg_label.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
seg_label = seg_label.squeeze(1)
|
||||
|
||||
loss['loss_context'] = self.loss_decode[0](context_logit, seg_label)
|
||||
loss['loss_spatial'] = self.loss_decode[1](spatial_logit, seg_label)
|
||||
loss['acc_seg'] = accuracy(
|
||||
context_logit, seg_label, ignore_index=self.ignore_index)
|
||||
|
||||
return loss
|
||||
366
finetune/mmseg/models/decode_heads/decode_head.py
Normal file
@@ -0,0 +1,366 @@
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.model import BaseModule
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.structures import build_pixel_sampler
|
||||
from mmseg.utils import ConfigType, SampleList
|
||||
from ..losses import accuracy
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
|
||||
"""Base class for BaseDecodeHead.
|
||||
|
||||
1. The ``init_weights`` method is used to initialize decode_head's
|
||||
model parameters. After segmentor initialization, ``init_weights``
|
||||
is triggered when ``segmentor.init_weights()`` is called externally.
|
||||
|
||||
2. The ``loss`` method is used to calculate the loss of decode_head,
|
||||
which includes two steps: (1) the decode_head model performs forward
|
||||
propagation to obtain the feature maps (2) The ``loss_by_feat`` method
|
||||
is called based on the feature maps to calculate the loss.
|
||||
|
||||
.. code:: text
|
||||
|
||||
loss(): forward() -> loss_by_feat()
|
||||
|
||||
3. The ``predict`` method is used to predict segmentation results,
|
||||
which includes two steps: (1) the decode_head model performs forward
|
||||
propagation to obtain the feature maps (2) The ``predict_by_feat`` method
|
||||
is called based on the feature maps to predict segmentation results
|
||||
including post-processing.
|
||||
|
||||
.. code:: text
|
||||
|
||||
predict(): forward() -> predict_by_feat()
|
||||
|
||||
Args:
|
||||
in_channels (int|Sequence[int]): Input channels.
|
||||
channels (int): Channels after modules, before conv_seg.
|
||||
num_classes (int): Number of classes.
|
||||
out_channels (int): Output channels of conv_seg. Default: None.
|
||||
threshold (float): Threshold for binary segmentation in the case of
|
||||
`num_classes==1`. Default: None.
|
||||
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
|
||||
conv_cfg (dict|None): Config of conv layers. Default: None.
|
||||
norm_cfg (dict|None): Config of norm layers. Default: None.
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU')
|
||||
in_index (int|Sequence[int]): Input feature index. Default: -1
|
||||
input_transform (str|None): Transformation type of input features.
|
||||
Options: 'resize_concat', 'multiple_select', None.
|
||||
'resize_concat': Multiple feature maps will be resized to the
same size as the first one and then concatenated together.
Usually used in FCN head of HRNet.
'multiple_select': Multiple feature maps will be bundled into
a list and passed into decode head.
|
||||
None: Only one select feature map is allowed.
|
||||
Default: None.
|
||||
loss_decode (dict | Sequence[dict]): Config of decode loss.
|
||||
The `loss_name` is property of corresponding loss function which
|
||||
could be shown in training log. If you want this loss
|
||||
item to be included into the backward graph, `loss_` must be the
|
||||
prefix of the name. Defaults to 'loss_ce'.
|
||||
e.g. dict(type='CrossEntropyLoss'),
|
||||
[dict(type='CrossEntropyLoss', loss_name='loss_ce'),
|
||||
dict(type='DiceLoss', loss_name='loss_dice')]
|
||||
Default: dict(type='CrossEntropyLoss').
|
||||
ignore_index (int | None): The label index to be ignored. When using
|
||||
masked BCE loss, ignore_index should be set to None. Default: 255.
|
||||
sampler (dict|None): The config of segmentation map sampler.
|
||||
Default: None.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
channels,
|
||||
*,
|
||||
num_classes,
|
||||
out_channels=None,
|
||||
threshold=None,
|
||||
dropout_ratio=0.1,
|
||||
conv_cfg=None,
|
||||
norm_cfg=None,
|
||||
act_cfg=dict(type='ReLU'),
|
||||
in_index=-1,
|
||||
input_transform=None,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_sigmoid=False,
|
||||
loss_weight=1.0),
|
||||
ignore_index=255,
|
||||
sampler=None,
|
||||
align_corners=False,
|
||||
init_cfg=dict(
|
||||
type='Normal', std=0.01, override=dict(name='conv_seg'))):
|
||||
super().__init__(init_cfg)
|
||||
self._init_inputs(in_channels, in_index, input_transform)
|
||||
self.channels = channels
|
||||
self.dropout_ratio = dropout_ratio
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.in_index = in_index
|
||||
|
||||
self.ignore_index = ignore_index
|
||||
self.align_corners = align_corners
|
||||
|
||||
if out_channels is None:
|
||||
if num_classes == 2:
|
||||
warnings.warn('For binary segmentation, we suggest using '
'`out_channels = 1` to define the output '
'channels of segmentor, and use `threshold` '
'to convert `seg_logits` into a prediction '
'applying a threshold')
|
||||
out_channels = num_classes
|
||||
|
||||
if out_channels != num_classes and out_channels != 1:
|
||||
raise ValueError(
'out_channels should be equal to num_classes, '
'except binary segmentation set out_channels == 1 and '
f'num_classes == 2, but got out_channels={out_channels} '
f'and num_classes={num_classes}')
|
||||
|
||||
if out_channels == 1 and threshold is None:
|
||||
threshold = 0.3
|
||||
warnings.warn('threshold is not defined for binary, and defaults '
'to 0.3')
|
||||
self.num_classes = num_classes
|
||||
self.out_channels = out_channels
|
||||
self.threshold = threshold
|
||||
|
||||
if isinstance(loss_decode, dict):
|
||||
self.loss_decode = MODELS.build(loss_decode)
|
||||
elif isinstance(loss_decode, (list, tuple)):
|
||||
self.loss_decode = nn.ModuleList()
|
||||
for loss in loss_decode:
|
||||
self.loss_decode.append(MODELS.build(loss))
|
||||
else:
|
||||
raise TypeError(f'loss_decode must be a dict or sequence of dict,\
|
||||
but got {type(loss_decode)}')
|
||||
|
||||
if sampler is not None:
|
||||
self.sampler = build_pixel_sampler(sampler, context=self)
|
||||
else:
|
||||
self.sampler = None
|
||||
|
||||
self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
|
||||
if dropout_ratio > 0:
|
||||
self.dropout = nn.Dropout2d(dropout_ratio)
|
||||
else:
|
||||
self.dropout = None
|
||||
|
||||
def extra_repr(self):
|
||||
"""Extra repr."""
|
||||
s = f'input_transform={self.input_transform}, ' \
|
||||
f'ignore_index={self.ignore_index}, ' \
|
||||
f'align_corners={self.align_corners}'
|
||||
return s
|
||||
|
||||
def _init_inputs(self, in_channels, in_index, input_transform):
|
||||
"""Check and initialize input transforms.
|
||||
|
||||
The in_channels, in_index and input_transform must match.
Specifically, when input_transform is None, only single feature map
will be selected. So in_channels and in_index must be of type int.
When input_transform is not None, in_channels and in_index must be
lists (or tuples) of the same length.
|
||||
|
||||
Args:
|
||||
in_channels (int|Sequence[int]): Input channels.
|
||||
in_index (int|Sequence[int]): Input feature index.
|
||||
input_transform (str|None): Transformation type of input features.
|
||||
Options: 'resize_concat', 'multiple_select', None.
|
||||
'resize_concat': Multiple feature maps will be resized to the
same size as the first one and then concatenated together.
Usually used in FCN head of HRNet.
'multiple_select': Multiple feature maps will be bundled into
a list and passed into decode head.
|
||||
None: Only one select feature map is allowed.
|
||||
"""
|
||||
|
||||
if input_transform is not None:
|
||||
assert input_transform in ['resize_concat', 'multiple_select']
|
||||
self.input_transform = input_transform
|
||||
self.in_index = in_index
|
||||
if input_transform is not None:
|
||||
assert isinstance(in_channels, (list, tuple))
|
||||
assert isinstance(in_index, (list, tuple))
|
||||
assert len(in_channels) == len(in_index)
|
||||
if input_transform == 'resize_concat':
|
||||
self.in_channels = sum(in_channels)
|
||||
else:
|
||||
self.in_channels = in_channels
|
||||
else:
|
||||
assert isinstance(in_channels, int)
|
||||
assert isinstance(in_index, int)
|
||||
self.in_channels = in_channels
|
||||
|
||||
def _transform_inputs(self, inputs):
|
||||
"""Transform inputs for decoder.
|
||||
|
||||
Args:
|
||||
inputs (list[Tensor]): List of multi-level img features.
|
||||
|
||||
Returns:
|
||||
Tensor: The transformed inputs
|
||||
"""
|
||||
|
||||
if self.input_transform == 'resize_concat':
|
||||
inputs = [inputs[i] for i in self.in_index]
|
||||
upsampled_inputs = [
|
||||
resize(
|
||||
input=x,
|
||||
size=inputs[0].shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners) for x in inputs
|
||||
]
|
||||
inputs = torch.cat(upsampled_inputs, dim=1)
|
||||
elif self.input_transform == 'multiple_select':
|
||||
inputs = [inputs[i] for i in self.in_index]
|
||||
else:
|
||||
inputs = inputs[self.in_index]
|
||||
|
||||
return inputs
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, inputs):
|
||||
"""Placeholder of forward function."""
|
||||
pass
|
||||
|
||||
def cls_seg(self, feat):
|
||||
"""Classify each pixel."""
|
||||
if self.dropout is not None:
|
||||
feat = self.dropout(feat)
|
||||
output = self.conv_seg(feat)
|
||||
return output
|
||||
|
||||
def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList,
|
||||
train_cfg: ConfigType) -> dict:
|
||||
"""Forward function for training.
|
||||
|
||||
Args:
|
||||
inputs (Tuple[Tensor]): List of multi-level img features.
|
||||
batch_data_samples (list[:obj:`SegDataSample`]): The seg
|
||||
data samples. It usually includes information such
|
||||
as `img_metas` or `gt_semantic_seg`.
|
||||
train_cfg (dict): The training config.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: a dictionary of loss components
|
||||
"""
|
||||
seg_logits = self.forward(inputs)
|
||||
losses = self.loss_by_feat(seg_logits, batch_data_samples)
|
||||
return losses
|
||||
|
||||
def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
|
||||
test_cfg: ConfigType) -> Tensor:
|
||||
"""Forward function for prediction.
|
||||
|
||||
Args:
|
||||
inputs (Tuple[Tensor]): List of multi-level img features.
|
||||
batch_img_metas (list[dict]): List of image info where each dict may also
|
||||
contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
|
||||
'ori_shape', and 'pad_shape'.
|
||||
For details on the values of these keys see
|
||||
`mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
|
||||
test_cfg (dict): The testing config.
|
||||
|
||||
Returns:
|
||||
Tensor: Outputs segmentation logits map.
|
||||
"""
|
||||
seg_logits = self.forward(inputs)
|
||||
|
||||
return self.predict_by_feat(seg_logits, batch_img_metas)
|
||||
|
||||
def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor:
|
||||
gt_semantic_segs = [
|
||||
data_sample.gt_sem_seg.data for data_sample in batch_data_samples
|
||||
]
|
||||
return torch.stack(gt_semantic_segs, dim=0)
|
||||
|
||||
def loss_by_feat(self, seg_logits: Tensor,
|
||||
batch_data_samples: SampleList) -> dict:
|
||||
"""Compute segmentation loss.
|
||||
|
||||
Args:
|
||||
seg_logits (Tensor): The output from decode head forward function.
|
||||
batch_data_samples (List[:obj:`SegDataSample`]): The seg
|
||||
data samples. It usually includes information such
|
||||
as `metainfo` and `gt_sem_seg`.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: a dictionary of loss components
|
||||
"""
|
||||
|
||||
seg_label = self._stack_batch_gt(batch_data_samples)
|
||||
loss = dict()
|
||||
seg_logits = resize(
|
||||
input=seg_logits,
|
||||
size=seg_label.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
if self.sampler is not None:
|
||||
seg_weight = self.sampler.sample(seg_logits, seg_label)
|
||||
else:
|
||||
seg_weight = None
|
||||
seg_label = seg_label.squeeze(1)
|
||||
|
||||
if not isinstance(self.loss_decode, nn.ModuleList):
|
||||
losses_decode = [self.loss_decode]
|
||||
else:
|
||||
losses_decode = self.loss_decode
|
||||
for loss_decode in losses_decode:
|
||||
if loss_decode.loss_name not in loss:
|
||||
loss[loss_decode.loss_name] = loss_decode(
|
||||
seg_logits,
|
||||
seg_label,
|
||||
weight=seg_weight,
|
||||
ignore_index=self.ignore_index)
|
||||
else:
|
||||
loss[loss_decode.loss_name] += loss_decode(
|
||||
seg_logits,
|
||||
seg_label,
|
||||
weight=seg_weight,
|
||||
ignore_index=self.ignore_index)
|
||||
|
||||
loss['acc_seg'] = accuracy(
|
||||
seg_logits, seg_label, ignore_index=self.ignore_index)
|
||||
return loss
|
||||
|
||||
def predict_by_feat(self, seg_logits: Tensor,
|
||||
batch_img_metas: List[dict]) -> Tensor:
|
||||
"""Transform a batch of output seg_logits to the input shape.
|
||||
|
||||
Args:
|
||||
seg_logits (Tensor): The output from decode head forward function.
|
||||
batch_img_metas (list[dict]): Meta information of each image, e.g.,
|
||||
image size, scaling factor, etc.
|
||||
|
||||
Returns:
|
||||
Tensor: Outputs segmentation logits map.
|
||||
"""
|
||||
|
||||
if isinstance(batch_img_metas[0]['img_shape'], torch.Size):
|
||||
# slide inference
|
||||
size = batch_img_metas[0]['img_shape']
|
||||
elif 'pad_shape' in batch_img_metas[0]:
|
||||
size = batch_img_metas[0]['pad_shape'][:2]
|
||||
else:
|
||||
size = batch_img_metas[0]['img_shape']
|
||||
|
||||
seg_logits = resize(
|
||||
input=seg_logits,
|
||||
size=size,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
return seg_logits
|
||||
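Illustrative sketch of the `input_transform` modes documented above (not part of the committed files; the toy subclass name and sizes are made up): with 'resize_concat', the second feature map is upsampled to the size of the first and the channels are concatenated before classification.

import torch
from mmseg.models.decode_heads.decode_head import BaseDecodeHead


class IdentityHead(BaseDecodeHead):
    """Toy head that only classifies the transformed inputs."""

    def forward(self, inputs):
        x = self._transform_inputs(inputs)
        return self.cls_seg(x)


head = IdentityHead(
    in_channels=[32, 64],
    in_index=[0, 1],
    channels=96,                        # 32 + 64 after 'resize_concat'
    num_classes=19,
    input_transform='resize_concat').eval()

feats = [torch.randn(1, 32, 64, 64), torch.randn(1, 64, 32, 32)]
with torch.no_grad():
    logits = head(feats)
print(logits.shape)                     # torch.Size([1, 19, 64, 64])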
141
finetune/mmseg/models/decode_heads/dm_head.py
Normal file
@@ -0,0 +1,141 @@
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class DCM(nn.Module):
|
||||
"""Dynamic Convolutional Module used in DMNet.
|
||||
|
||||
Args:
|
||||
filter_size (int): The filter size of generated convolution kernel
|
||||
used in Dynamic Convolutional Module.
|
||||
fusion (bool): Add one conv to fuse DCM output feature.
|
||||
in_channels (int): Input channels.
|
||||
channels (int): Channels after modules, before conv_seg.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
act_cfg (dict): Config of activation layers.
|
||||
"""
|
||||
|
||||
def __init__(self, filter_size, fusion, in_channels, channels, conv_cfg,
|
||||
norm_cfg, act_cfg):
|
||||
super().__init__()
|
||||
self.filter_size = filter_size
|
||||
self.fusion = fusion
|
||||
self.in_channels = in_channels
|
||||
self.channels = channels
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.filter_gen_conv = nn.Conv2d(self.in_channels, self.channels, 1, 1,
|
||||
0)
|
||||
|
||||
self.input_redu_conv = ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
if self.norm_cfg is not None:
|
||||
self.norm = build_norm_layer(self.norm_cfg, self.channels)[1]
|
||||
else:
|
||||
self.norm = None
|
||||
self.activate = build_activation_layer(self.act_cfg)
|
||||
|
||||
if self.fusion:
|
||||
self.fusion_conv = ConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
generated_filter = self.filter_gen_conv(
|
||||
F.adaptive_avg_pool2d(x, self.filter_size))
|
||||
x = self.input_redu_conv(x)
|
||||
b, c, h, w = x.shape
|
||||
# [1, b * c, h, w], c = self.channels
|
||||
x = x.view(1, b * c, h, w)
|
||||
# [b * c, 1, filter_size, filter_size]
|
||||
generated_filter = generated_filter.view(b * c, 1, self.filter_size,
|
||||
self.filter_size)
|
||||
pad = (self.filter_size - 1) // 2
|
||||
if (self.filter_size - 1) % 2 == 0:
|
||||
p2d = (pad, pad, pad, pad)
|
||||
else:
|
||||
p2d = (pad + 1, pad, pad + 1, pad)
|
||||
x = F.pad(input=x, pad=p2d, mode='constant', value=0)
|
||||
# [1, b * c, h, w]
|
||||
output = F.conv2d(input=x, weight=generated_filter, groups=b * c)
|
||||
# [b, c, h, w]
|
||||
output = output.view(b, c, h, w)
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
output = self.activate(output)
|
||||
|
||||
if self.fusion:
|
||||
output = self.fusion_conv(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DMHead(BaseDecodeHead):
|
||||
"""Dynamic Multi-scale Filters for Semantic Segmentation.
|
||||
|
||||
This head is the implementation of
|
||||
`DMNet <https://openaccess.thecvf.com/content_ICCV_2019/papers/\
|
||||
He_Dynamic_Multi-Scale_Filters_for_Semantic_Segmentation_\
|
||||
ICCV_2019_paper.pdf>`_.
|
||||
|
||||
Args:
|
||||
filter_sizes (tuple[int]): The size of generated convolutional filters
|
||||
used in Dynamic Convolutional Module. Default: (1, 3, 5, 7).
|
||||
fusion (bool): Add one conv to fuse DCM output feature.
|
||||
"""
|
||||
|
||||
def __init__(self, filter_sizes=(1, 3, 5, 7), fusion=False, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
assert isinstance(filter_sizes, (list, tuple))
|
||||
self.filter_sizes = filter_sizes
|
||||
self.fusion = fusion
|
||||
dcm_modules = []
|
||||
for filter_size in self.filter_sizes:
|
||||
dcm_modules.append(
|
||||
DCM(filter_size,
|
||||
self.fusion,
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg))
|
||||
self.dcm_modules = nn.ModuleList(dcm_modules)
|
||||
self.bottleneck = ConvModule(
|
||||
self.in_channels + len(filter_sizes) * self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
x = self._transform_inputs(inputs)
|
||||
dcm_outs = [x]
|
||||
for dcm_module in self.dcm_modules:
|
||||
dcm_outs.append(dcm_module(x))
|
||||
dcm_outs = torch.cat(dcm_outs, dim=1)
|
||||
output = self.bottleneck(dcm_outs)
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
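Illustrative sketch for the DCM padding logic above (not part of the committed files; sizes are assumed): the asymmetric pad used for even filter sizes keeps the output spatial size equal to the input.

import torch
from mmseg.models.decode_heads.dm_head import DCM

dcm = DCM(
    filter_size=4,                      # even size exercises the (pad + 1, pad, ...) branch
    fusion=False,
    in_channels=64,
    channels=32,
    conv_cfg=None,
    norm_cfg=dict(type='BN'),
    act_cfg=dict(type='ReLU')).eval()

x = torch.randn(2, 64, 17, 23)
with torch.no_grad():
    out = dcm(x)
print(out.shape)                        # torch.Size([2, 32, 17, 23]) -- spatial size preserved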
137
finetune/mmseg/models/decode_heads/dnl_head.py
Normal file
@@ -0,0 +1,137 @@
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from mmcv.cnn import NonLocal2d
|
||||
from torch import nn
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .fcn_head import FCNHead
|
||||
|
||||
|
||||
class DisentangledNonLocal2d(NonLocal2d):
|
||||
"""Disentangled Non-Local Blocks.
|
||||
|
||||
Args:
|
||||
temperature (float): Temperature to adjust attention. Default: 0.05
|
||||
"""
|
||||
|
||||
def __init__(self, *arg, temperature, **kwargs):
|
||||
super().__init__(*arg, **kwargs)
|
||||
self.temperature = temperature
|
||||
self.conv_mask = nn.Conv2d(self.in_channels, 1, kernel_size=1)
|
||||
|
||||
def embedded_gaussian(self, theta_x, phi_x):
|
||||
"""Embedded gaussian with temperature."""
|
||||
|
||||
# NonLocal2d pairwise_weight: [N, HxW, HxW]
|
||||
pairwise_weight = torch.matmul(theta_x, phi_x)
|
||||
if self.use_scale:
|
||||
# theta_x.shape[-1] is `self.inter_channels`
|
||||
pairwise_weight /= torch.tensor(
|
||||
theta_x.shape[-1],
|
||||
dtype=torch.float,
|
||||
device=pairwise_weight.device)**torch.tensor(
|
||||
0.5, device=pairwise_weight.device)
|
||||
pairwise_weight /= torch.tensor(
|
||||
self.temperature, device=pairwise_weight.device)
|
||||
pairwise_weight = pairwise_weight.softmax(dim=-1)
|
||||
return pairwise_weight
|
||||
|
||||
def forward(self, x):
|
||||
# x: [N, C, H, W]
|
||||
n = x.size(0)
|
||||
|
||||
# g_x: [N, HxW, C]
|
||||
g_x = self.g(x).view(n, self.inter_channels, -1)
|
||||
g_x = g_x.permute(0, 2, 1)
|
||||
|
||||
# theta_x: [N, HxW, C], phi_x: [N, C, HxW]
|
||||
if self.mode == 'gaussian':
|
||||
theta_x = x.view(n, self.in_channels, -1)
|
||||
theta_x = theta_x.permute(0, 2, 1)
|
||||
if self.sub_sample:
|
||||
phi_x = self.phi(x).view(n, self.in_channels, -1)
|
||||
else:
|
||||
phi_x = x.view(n, self.in_channels, -1)
|
||||
elif self.mode == 'concatenation':
|
||||
theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
|
||||
phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
|
||||
else:
|
||||
theta_x = self.theta(x).view(n, self.inter_channels, -1)
|
||||
theta_x = theta_x.permute(0, 2, 1)
|
||||
phi_x = self.phi(x).view(n, self.inter_channels, -1)
|
||||
|
||||
# subtract mean
|
||||
theta_x -= theta_x.mean(dim=-2, keepdim=True)
|
||||
phi_x -= phi_x.mean(dim=-1, keepdim=True)
|
||||
|
||||
pairwise_func = getattr(self, self.mode)
|
||||
# pairwise_weight: [N, HxW, HxW]
|
||||
pairwise_weight = pairwise_func(theta_x, phi_x)
|
||||
|
||||
# y: [N, HxW, C]
|
||||
y = torch.matmul(pairwise_weight, g_x)
|
||||
# y: [N, C, H, W]
|
||||
y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
|
||||
*x.size()[2:])
|
||||
|
||||
# unary_mask: [N, 1, HxW]
|
||||
unary_mask = self.conv_mask(x)
|
||||
unary_mask = unary_mask.view(n, 1, -1)
|
||||
unary_mask = unary_mask.softmax(dim=-1)
|
||||
# unary_x: [N, 1, C]
|
||||
unary_x = torch.matmul(unary_mask, g_x)
|
||||
# unary_x: [N, C, 1, 1]
|
||||
unary_x = unary_x.permute(0, 2, 1).contiguous().reshape(
|
||||
n, self.inter_channels, 1, 1)
|
||||
|
||||
output = x + self.conv_out(y + unary_x)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DNLHead(FCNHead):
|
||||
"""Disentangled Non-Local Neural Networks.
|
||||
|
||||
This head is the implementation of `DNLNet
|
||||
<https://arxiv.org/abs/2006.06668>`_.
|
||||
|
||||
Args:
|
||||
reduction (int): Reduction factor of projection transform. Default: 2.
|
||||
use_scale (bool): Whether to scale pairwise_weight by
sqrt(1/inter_channels). Default: True.
mode (str): The nonlocal mode. Options are 'embedded_gaussian',
'dot_product'. Default: 'embedded_gaussian'.
|
||||
temperature (float): Temperature to adjust attention. Default: 0.05
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
reduction=2,
|
||||
use_scale=True,
|
||||
mode='embedded_gaussian',
|
||||
temperature=0.05,
|
||||
**kwargs):
|
||||
super().__init__(num_convs=2, **kwargs)
|
||||
self.reduction = reduction
|
||||
self.use_scale = use_scale
|
||||
self.mode = mode
|
||||
self.temperature = temperature
|
||||
self.dnl_block = DisentangledNonLocal2d(
|
||||
in_channels=self.channels,
|
||||
reduction=self.reduction,
|
||||
use_scale=self.use_scale,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
mode=self.mode,
|
||||
temperature=self.temperature)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
x = self._transform_inputs(inputs)
|
||||
output = self.convs[0](x)
|
||||
output = self.dnl_block(output)
|
||||
output = self.convs[1](output)
|
||||
if self.concat_input:
|
||||
output = self.conv_cat(torch.cat([x, output], dim=1))
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
294
finetune/mmseg/models/decode_heads/dpt_head.py
Normal file
@@ -0,0 +1,294 @@
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, Linear, build_activation_layer
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class ReassembleBlocks(BaseModule):
|
||||
"""ViTPostProcessBlock, process cls_token in ViT backbone output and
|
||||
rearrange the feature vector to feature map.
|
||||
|
||||
Args:
|
||||
in_channels (int): ViT feature channels. Default: 768.
|
||||
out_channels (List): output channels of each stage.
|
||||
Default: [96, 192, 384, 768].
|
||||
readout_type (str): Type of readout operation. Default: 'ignore'.
|
||||
patch_size (int): The patch size. Default: 16.
|
||||
init_cfg (dict, optional): Initialization config dict. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=768,
|
||||
out_channels=[96, 192, 384, 768],
|
||||
readout_type='ignore',
|
||||
patch_size=16,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
assert readout_type in ['ignore', 'add', 'project']
|
||||
self.readout_type = readout_type
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.projects = nn.ModuleList([
|
||||
ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channel,
|
||||
kernel_size=1,
|
||||
act_cfg=None,
|
||||
) for out_channel in out_channels
|
||||
])
|
||||
|
||||
self.resize_layers = nn.ModuleList([
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=out_channels[0],
|
||||
out_channels=out_channels[0],
|
||||
kernel_size=4,
|
||||
stride=4,
|
||||
padding=0),
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=out_channels[1],
|
||||
out_channels=out_channels[1],
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
padding=0),
|
||||
nn.Identity(),
|
||||
nn.Conv2d(
|
||||
in_channels=out_channels[3],
|
||||
out_channels=out_channels[3],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1)
|
||||
])
|
||||
if self.readout_type == 'project':
|
||||
self.readout_projects = nn.ModuleList()
|
||||
for _ in range(len(self.projects)):
|
||||
self.readout_projects.append(
|
||||
nn.Sequential(
|
||||
Linear(2 * in_channels, in_channels),
|
||||
build_activation_layer(dict(type='GELU'))))
|
||||
|
||||
def forward(self, inputs):
|
||||
assert isinstance(inputs, list)
|
||||
out = []
|
||||
for i, x in enumerate(inputs):
|
||||
assert len(x) == 2
|
||||
x, cls_token = x[0], x[1]
|
||||
feature_shape = x.shape
|
||||
if self.readout_type == 'project':
|
||||
x = x.flatten(2).permute((0, 2, 1))
|
||||
readout = cls_token.unsqueeze(1).expand_as(x)
|
||||
x = self.readout_projects[i](torch.cat((x, readout), -1))
|
||||
x = x.permute(0, 2, 1).reshape(feature_shape)
|
||||
elif self.readout_type == 'add':
|
||||
x = x.flatten(2) + cls_token.unsqueeze(-1)
|
||||
x = x.reshape(feature_shape)
|
||||
else:
|
||||
pass
|
||||
x = self.projects[i](x)
|
||||
x = self.resize_layers[i](x)
|
||||
out.append(x)
|
||||
return out
|
||||
|
||||
|
||||
class PreActResidualConvUnit(BaseModule):
|
||||
"""ResidualConvUnit, pre-activate residual unit.
|
||||
|
||||
Args:
|
||||
in_channels (int): number of channels in the input feature map.
|
||||
act_cfg (dict): dictionary to construct and config activation layer.
|
||||
norm_cfg (dict): dictionary to construct and config norm layer.
|
||||
stride (int): stride of the first block. Default: 1
|
||||
dilation (int): dilation rate for convs layers. Default: 1.
|
||||
init_cfg (dict, optional): Initialization config dict. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
act_cfg,
|
||||
norm_cfg,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
self.conv1 = ConvModule(
|
||||
in_channels,
|
||||
in_channels,
|
||||
3,
|
||||
stride=stride,
|
||||
padding=dilation,
|
||||
dilation=dilation,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
bias=False,
|
||||
order=('act', 'conv', 'norm'))
|
||||
|
||||
self.conv2 = ConvModule(
|
||||
in_channels,
|
||||
in_channels,
|
||||
3,
|
||||
padding=1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
bias=False,
|
||||
order=('act', 'conv', 'norm'))
|
||||
|
||||
def forward(self, inputs):
|
||||
inputs_ = inputs.clone()
|
||||
x = self.conv1(inputs)
|
||||
x = self.conv2(x)
|
||||
return x + inputs_
|
||||
|
||||
|
||||
class FeatureFusionBlock(BaseModule):
|
||||
"""FeatureFusionBlock, merge feature map from different stages.
|
||||
|
||||
Args:
|
||||
in_channels (int): Input channels.
|
||||
act_cfg (dict): The activation config for ResidualConvUnit.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
expand (bool): Whether to expand the channels in post process block.
|
||||
Default: False.
|
||||
align_corners (bool): align_corner setting for bilinear upsample.
|
||||
Default: True.
|
||||
init_cfg (dict, optional): Initialization config dict. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
act_cfg,
|
||||
norm_cfg,
|
||||
expand=False,
|
||||
align_corners=True,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.expand = expand
|
||||
self.align_corners = align_corners
|
||||
|
||||
self.out_channels = in_channels
|
||||
if self.expand:
|
||||
self.out_channels = in_channels // 2
|
||||
|
||||
self.project = ConvModule(
|
||||
self.in_channels,
|
||||
self.out_channels,
|
||||
kernel_size=1,
|
||||
act_cfg=None,
|
||||
bias=True)
|
||||
|
||||
self.res_conv_unit1 = PreActResidualConvUnit(
|
||||
in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
|
||||
self.res_conv_unit2 = PreActResidualConvUnit(
|
||||
in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
|
||||
|
||||
def forward(self, *inputs):
|
||||
x = inputs[0]
|
||||
if len(inputs) == 2:
|
||||
if x.shape != inputs[1].shape:
|
||||
res = resize(
|
||||
inputs[1],
|
||||
size=(x.shape[2], x.shape[3]),
|
||||
mode='bilinear',
|
||||
align_corners=False)
|
||||
else:
|
||||
res = inputs[1]
|
||||
x = x + self.res_conv_unit1(res)
|
||||
x = self.res_conv_unit2(x)
|
||||
x = resize(
|
||||
x,
|
||||
scale_factor=2,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
x = self.project(x)
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DPTHead(BaseDecodeHead):
|
||||
"""Vision Transformers for Dense Prediction.
|
||||
|
||||
This head is the implementation of `DPT <https://arxiv.org/abs/2103.13413>`_.
|
||||
|
||||
Args:
|
||||
embed_dims (int): The embed dimension of the ViT backbone.
|
||||
Default: 768.
|
||||
post_process_channels (List): Out channels of post process conv
|
||||
layers. Default: [96, 192, 384, 768].
|
||||
readout_type (str): Type of readout operation. Default: 'ignore'.
|
||||
patch_size (int): The patch size. Default: 16.
|
||||
expand_channels (bool): Whether expand the channels in post process
|
||||
block. Default: False.
|
||||
act_cfg (dict): The activation config for residual conv unit.
|
||||
Default dict(type='ReLU').
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims=768,
|
||||
post_process_channels=[96, 192, 384, 768],
|
||||
readout_type='ignore',
|
||||
patch_size=16,
|
||||
expand_channels=False,
|
||||
act_cfg=dict(type='ReLU'),
|
||||
norm_cfg=dict(type='BN'),
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.in_channels = self.in_channels
|
||||
self.expand_channels = expand_channels
|
||||
self.reassemble_blocks = ReassembleBlocks(embed_dims,
|
||||
post_process_channels,
|
||||
readout_type, patch_size)
|
||||
|
||||
self.post_process_channels = [
|
||||
channel * math.pow(2, i) if expand_channels else channel
|
||||
for i, channel in enumerate(post_process_channels)
|
||||
]
|
||||
self.convs = nn.ModuleList()
|
||||
for channel in self.post_process_channels:
|
||||
self.convs.append(
|
||||
ConvModule(
|
||||
channel,
|
||||
self.channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
act_cfg=None,
|
||||
bias=False))
|
||||
self.fusion_blocks = nn.ModuleList()
|
||||
for _ in range(len(self.convs)):
|
||||
self.fusion_blocks.append(
|
||||
FeatureFusionBlock(self.channels, act_cfg, norm_cfg))
|
||||
self.fusion_blocks[0].res_conv_unit1 = None
|
||||
self.project = ConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
norm_cfg=norm_cfg)
|
||||
self.num_fusion_blocks = len(self.fusion_blocks)
|
||||
self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
|
||||
self.num_post_process_channels = len(self.post_process_channels)
|
||||
assert self.num_fusion_blocks == self.num_reassemble_blocks
|
||||
assert self.num_reassemble_blocks == self.num_post_process_channels
|
||||
|
||||
def forward(self, inputs):
|
||||
assert len(inputs) == self.num_reassemble_blocks
|
||||
x = self._transform_inputs(inputs)
|
||||
x = self.reassemble_blocks(x)
|
||||
x = [self.convs[i](feature) for i, feature in enumerate(x)]
|
||||
out = self.fusion_blocks[0](x[-1])
|
||||
for i in range(1, len(self.fusion_blocks)):
|
||||
out = self.fusion_blocks[i](out, x[-(i + 1)])
|
||||
out = self.project(out)
|
||||
out = self.cls_seg(out)
|
||||
return out
|
||||
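Illustrative sketch for ReassembleBlocks above (not part of the committed files; a 14x14 ViT token grid and the default channel widths are assumed): the four stages are projected to 96/192/384/768 channels and rescaled by 4x, 2x, 1x and 0.5x.

import torch
from mmseg.models.decode_heads.dpt_head import ReassembleBlocks

blocks = ReassembleBlocks()             # defaults: in_channels=768, readout_type='ignore'

# Four (feature map, cls_token) pairs, e.g. from a ViT backbone.
inputs = [(torch.randn(1, 768, 14, 14), torch.randn(1, 768)) for _ in range(4)]
with torch.no_grad():
    outs = blocks(inputs)
print([o.shape for o in outs])
# [torch.Size([1, 96, 56, 56]), torch.Size([1, 192, 28, 28]),
#  torch.Size([1, 384, 14, 14]), torch.Size([1, 768, 7, 7])]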
169
finetune/mmseg/models/decode_heads/ema_head.py
Normal file
@@ -0,0 +1,169 @@
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
def reduce_mean(tensor):
|
||||
"""Reduce mean when distributed training."""
|
||||
if not (dist.is_available() and dist.is_initialized()):
|
||||
return tensor
|
||||
tensor = tensor.clone()
|
||||
dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
|
||||
return tensor
|
||||
|
||||
|
||||
class EMAModule(nn.Module):
|
||||
"""Expectation Maximization Attention Module used in EMANet.
|
||||
|
||||
Args:
|
||||
channels (int): Channels of the whole module.
|
||||
num_bases (int): Number of bases.
|
||||
num_stages (int): Number of the EM iterations.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, num_bases, num_stages, momentum):
|
||||
super().__init__()
|
||||
assert num_stages >= 1, 'num_stages must be at least 1!'
|
||||
self.num_bases = num_bases
|
||||
self.num_stages = num_stages
|
||||
self.momentum = momentum
|
||||
|
||||
bases = torch.zeros(1, channels, self.num_bases)
|
||||
bases.normal_(0, math.sqrt(2. / self.num_bases))
|
||||
# [1, channels, num_bases]
|
||||
bases = F.normalize(bases, dim=1, p=2)
|
||||
self.register_buffer('bases', bases)
|
||||
|
||||
def forward(self, feats):
|
||||
"""Forward function."""
|
||||
batch_size, channels, height, width = feats.size()
|
||||
# [batch_size, channels, height*width]
|
||||
feats = feats.view(batch_size, channels, height * width)
|
||||
# [batch_size, channels, num_bases]
|
||||
bases = self.bases.repeat(batch_size, 1, 1)
|
||||
|
||||
with torch.no_grad():
|
||||
for i in range(self.num_stages):
|
||||
# [batch_size, height*width, num_bases]
|
||||
attention = torch.einsum('bcn,bck->bnk', feats, bases)
|
||||
attention = F.softmax(attention, dim=2)
|
||||
# l1 norm
|
||||
attention_normed = F.normalize(attention, dim=1, p=1)
|
||||
# [batch_size, channels, num_bases]
|
||||
bases = torch.einsum('bcn,bnk->bck', feats, attention_normed)
|
||||
# l2 norm
|
||||
bases = F.normalize(bases, dim=1, p=2)
|
||||
|
||||
feats_recon = torch.einsum('bck,bnk->bcn', bases, attention)
|
||||
feats_recon = feats_recon.view(batch_size, channels, height, width)
|
||||
|
||||
if self.training:
|
||||
bases = bases.mean(dim=0, keepdim=True)
|
||||
bases = reduce_mean(bases)
|
||||
# l2 norm
|
||||
bases = F.normalize(bases, dim=1, p=2)
|
||||
self.bases = (1 -
|
||||
self.momentum) * self.bases + self.momentum * bases
|
||||
|
||||
return feats_recon
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class EMAHead(BaseDecodeHead):
|
||||
"""Expectation Maximization Attention Networks for Semantic Segmentation.
|
||||
|
||||
This head is the implementation of `EMANet
|
||||
<https://arxiv.org/abs/1907.13426>`_.
|
||||
|
||||
Args:
|
||||
ema_channels (int): EMA module channels
|
||||
num_bases (int): Number of bases.
|
||||
num_stages (int): Number of the EM iterations.
|
||||
concat_input (bool): Whether to concat the input and output of convs
before classification layer. Default: True.
|
||||
momentum (float): Momentum to update the base. Default: 0.1.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
ema_channels,
|
||||
num_bases,
|
||||
num_stages,
|
||||
concat_input=True,
|
||||
momentum=0.1,
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.ema_channels = ema_channels
|
||||
self.num_bases = num_bases
|
||||
self.num_stages = num_stages
|
||||
self.concat_input = concat_input
|
||||
self.momentum = momentum
|
||||
self.ema_module = EMAModule(self.ema_channels, self.num_bases,
|
||||
self.num_stages, self.momentum)
|
||||
|
||||
self.ema_in_conv = ConvModule(
|
||||
self.in_channels,
|
||||
self.ema_channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
# project (0, inf) -> (-inf, inf)
|
||||
self.ema_mid_conv = ConvModule(
|
||||
self.ema_channels,
|
||||
self.ema_channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=None,
|
||||
act_cfg=None)
|
||||
for param in self.ema_mid_conv.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
self.ema_out_conv = ConvModule(
|
||||
self.ema_channels,
|
||||
self.ema_channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=None)
|
||||
self.bottleneck = ConvModule(
|
||||
self.ema_channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
if self.concat_input:
|
||||
self.conv_cat = ConvModule(
|
||||
self.in_channels + self.channels,
|
||||
self.channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
x = self._transform_inputs(inputs)
|
||||
feats = self.ema_in_conv(x)
|
||||
identity = feats
|
||||
feats = self.ema_mid_conv(feats)
|
||||
recon = self.ema_module(feats)
|
||||
recon = F.relu(recon, inplace=True)
|
||||
recon = self.ema_out_conv(recon)
|
||||
output = F.relu(identity + recon, inplace=True)
|
||||
output = self.bottleneck(output)
|
||||
if self.concat_input:
|
||||
output = self.conv_cat(torch.cat([x, output], dim=1))
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
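Illustrative sketch for the EM attention module above (not part of the committed files; sizes are assumed): the reconstruction keeps the input shape, and in training mode the `bases` buffer is updated with momentum, which is what `reduce_mean` averages across workers when training is distributed.

import torch
from mmseg.models.decode_heads.ema_head import EMAModule

ema = EMAModule(channels=64, num_bases=16, num_stages=3, momentum=0.1)

feats = torch.randn(2, 64, 32, 32)
recon = ema(feats)                      # the EM iterations themselves run under torch.no_grad()
print(recon.shape)                      # torch.Size([2, 64, 32, 32])
print(ema.bases.shape)                  # torch.Size([1, 64, 16]) -- momentum-updated in training mode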
196
finetune/mmseg/models/decode_heads/enc_head.py
Normal file
@@ -0,0 +1,196 @@
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule, build_norm_layer
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import ConfigType, SampleList
|
||||
from ..utils import Encoding, resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class EncModule(nn.Module):
|
||||
"""Encoding Module used in EncNet.
|
||||
|
||||
Args:
|
||||
in_channels (int): Input channels.
|
||||
num_codes (int): Number of code words.
|
||||
conv_cfg (dict|None): Config of conv layers.
|
||||
norm_cfg (dict|None): Config of norm layers.
|
||||
act_cfg (dict): Config of activation layers.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg):
|
||||
super().__init__()
|
||||
self.encoding_project = ConvModule(
|
||||
in_channels,
|
||||
in_channels,
|
||||
1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
# TODO: resolve this hack
|
||||
# change to 1d
|
||||
if norm_cfg is not None:
|
||||
encoding_norm_cfg = norm_cfg.copy()
|
||||
if encoding_norm_cfg['type'] in ['BN', 'IN']:
|
||||
encoding_norm_cfg['type'] += '1d'
|
||||
else:
|
||||
encoding_norm_cfg['type'] = encoding_norm_cfg['type'].replace(
|
||||
'2d', '1d')
|
||||
else:
|
||||
# fallback to BN1d
|
||||
encoding_norm_cfg = dict(type='BN1d')
|
||||
self.encoding = nn.Sequential(
|
||||
Encoding(channels=in_channels, num_codes=num_codes),
|
||||
build_norm_layer(encoding_norm_cfg, num_codes)[1],
|
||||
nn.ReLU(inplace=True))
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(in_channels, in_channels), nn.Sigmoid())
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
encoding_projection = self.encoding_project(x)
|
||||
encoding_feat = self.encoding(encoding_projection).mean(dim=1)
|
||||
batch_size, channels, _, _ = x.size()
|
||||
gamma = self.fc(encoding_feat)
|
||||
y = gamma.view(batch_size, channels, 1, 1)
|
||||
output = F.relu_(x + x * y)
|
||||
return encoding_feat, output
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class EncHead(BaseDecodeHead):
|
||||
"""Context Encoding for Semantic Segmentation.
|
||||
|
||||
This head is the implementation of `EncNet
|
||||
<https://arxiv.org/abs/1803.08904>`_.
|
||||
|
||||
Args:
|
||||
num_codes (int): Number of code words. Default: 32.
|
||||
use_se_loss (bool): Whether to use Semantic Encoding Loss (SE-loss) to
regularize the training. Default: True.
add_lateral (bool): Whether to use lateral connection to fuse features.
|
||||
Default: False.
|
||||
loss_se_decode (dict): Config of decode loss.
|
||||
Default: dict(type='CrossEntropyLoss', use_sigmoid=True).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_codes=32,
|
||||
use_se_loss=True,
|
||||
add_lateral=False,
|
||||
loss_se_decode=dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_sigmoid=True,
|
||||
loss_weight=0.2),
|
||||
**kwargs):
|
||||
super().__init__(input_transform='multiple_select', **kwargs)
|
||||
self.use_se_loss = use_se_loss
|
||||
self.add_lateral = add_lateral
|
||||
self.num_codes = num_codes
|
||||
self.bottleneck = ConvModule(
|
||||
self.in_channels[-1],
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
if add_lateral:
|
||||
self.lateral_convs = nn.ModuleList()
|
||||
for in_channels in self.in_channels[:-1]: # skip the last one
|
||||
self.lateral_convs.append(
|
||||
ConvModule(
|
||||
in_channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg))
|
||||
self.fusion = ConvModule(
|
||||
len(self.in_channels) * self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.enc_module = EncModule(
|
||||
self.channels,
|
||||
num_codes=num_codes,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
if self.use_se_loss:
|
||||
self.loss_se_decode = MODELS.build(loss_se_decode)
|
||||
self.se_layer = nn.Linear(self.channels, self.num_classes)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
inputs = self._transform_inputs(inputs)
|
||||
feat = self.bottleneck(inputs[-1])
|
||||
if self.add_lateral:
|
||||
laterals = [
|
||||
resize(
|
||||
lateral_conv(inputs[i]),
|
||||
size=feat.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
for i, lateral_conv in enumerate(self.lateral_convs)
|
||||
]
|
||||
feat = self.fusion(torch.cat([feat, *laterals], 1))
|
||||
encode_feat, output = self.enc_module(feat)
|
||||
output = self.cls_seg(output)
|
||||
if self.use_se_loss:
|
||||
se_output = self.se_layer(encode_feat)
|
||||
return output, se_output
|
||||
else:
|
||||
return output
|
||||
|
||||
def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
|
||||
test_cfg: ConfigType):
|
||||
"""Forward function for testing, ignore se_loss."""
|
||||
if self.use_se_loss:
|
||||
seg_logits = self.forward(inputs)[0]
|
||||
else:
|
||||
seg_logits = self.forward(inputs)
|
||||
return self.predict_by_feat(seg_logits, batch_img_metas)
|
||||
|
||||
@staticmethod
|
||||
def _convert_to_onehot_labels(seg_label, num_classes):
|
||||
"""Convert segmentation label to onehot.
|
||||
|
||||
Args:
|
||||
seg_label (Tensor): Segmentation label of shape (N, H, W).
|
||||
num_classes (int): Number of classes.
|
||||
|
||||
Returns:
|
||||
Tensor: Onehot labels of shape (N, num_classes).
|
||||
"""
|
||||
|
||||
batch_size = seg_label.size(0)
|
||||
onehot_labels = seg_label.new_zeros((batch_size, num_classes))
|
||||
for i in range(batch_size):
|
||||
hist = seg_label[i].float().histc(
|
||||
bins=num_classes, min=0, max=num_classes - 1)
|
||||
onehot_labels[i] = hist > 0
|
||||
return onehot_labels
|
||||
|
||||
def loss_by_feat(self, seg_logit: Tuple[Tensor],
|
||||
batch_data_samples: SampleList, **kwargs) -> dict:
|
||||
"""Compute segmentation and semantic encoding loss."""
|
||||
seg_logit, se_seg_logit = seg_logit
|
||||
loss = dict()
|
||||
loss.update(super().loss_by_feat(seg_logit, batch_data_samples))
|
||||
|
||||
seg_label = self._stack_batch_gt(batch_data_samples)
|
||||
se_loss = self.loss_se_decode(
|
||||
se_seg_logit,
|
||||
self._convert_to_onehot_labels(seg_label, self.num_classes))
|
||||
loss['loss_se'] = se_loss
|
||||
return loss
|
||||
96
finetune/mmseg/models/decode_heads/fcn_head.py
Normal file
96
finetune/mmseg/models/decode_heads/fcn_head.py
Normal file
@@ -0,0 +1,96 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class FCNHead(BaseDecodeHead):
|
||||
"""Fully Convolution Networks for Semantic Segmentation.
|
||||
|
||||
This head is implemented of `FCNNet <https://arxiv.org/abs/1411.4038>`_.
|
||||
|
||||
Args:
|
||||
num_convs (int): Number of convs in the head. Default: 2.
|
||||
kernel_size (int): The kernel size for convs in the head. Default: 3.
|
||||
concat_input (bool): Whether concat the input and output of convs
|
||||
before classification layer.
|
||||
dilation (int): The dilation rate for convs in the head. Default: 1.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_convs=2,
|
||||
kernel_size=3,
|
||||
concat_input=True,
|
||||
dilation=1,
|
||||
**kwargs):
|
||||
assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int)
|
||||
self.num_convs = num_convs
|
||||
self.concat_input = concat_input
|
||||
self.kernel_size = kernel_size
|
||||
super().__init__(**kwargs)
|
||||
if num_convs == 0:
|
||||
assert self.in_channels == self.channels
|
||||
|
||||
conv_padding = (kernel_size // 2) * dilation
|
||||
convs = []
|
||||
convs.append(
|
||||
ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
kernel_size=kernel_size,
|
||||
padding=conv_padding,
|
||||
dilation=dilation,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg))
|
||||
for i in range(num_convs - 1):
|
||||
convs.append(
|
||||
ConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
kernel_size=kernel_size,
|
||||
padding=conv_padding,
|
||||
dilation=dilation,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg))
|
||||
if num_convs == 0:
|
||||
self.convs = nn.Identity()
|
||||
else:
|
||||
self.convs = nn.Sequential(*convs)
|
||||
if self.concat_input:
|
||||
self.conv_cat = ConvModule(
|
||||
self.in_channels + self.channels,
|
||||
self.channels,
|
||||
kernel_size=kernel_size,
|
||||
padding=kernel_size // 2,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def _forward_feature(self, inputs):
|
||||
"""Forward function for feature maps before classifying each pixel with
|
||||
``self.cls_seg`` fc.
|
||||
|
||||
Args:
|
||||
inputs (list[Tensor]): List of multi-level img features.
|
||||
|
||||
Returns:
|
||||
feats (Tensor): A tensor of shape (batch_size, self.channels,
|
||||
H, W) which is feature map for last layer of decoder head.
|
||||
"""
|
||||
x = self._transform_inputs(inputs)
|
||||
feats = self.convs(x)
|
||||
if self.concat_input:
|
||||
feats = self.conv_cat(torch.cat([x, feats], dim=1))
|
||||
return feats
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
output = self._forward_feature(inputs)
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
68
finetune/mmseg/models/decode_heads/fpn_head.py
Normal file
68
finetune/mmseg/models/decode_heads/fpn_head.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import Upsample, resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class FPNHead(BaseDecodeHead):
|
||||
"""Panoptic Feature Pyramid Networks.
|
||||
|
||||
This head is the implementation of `Semantic FPN
|
||||
<https://arxiv.org/abs/1901.02446>`_.
|
||||
|
||||
Args:
|
||||
feature_strides (tuple[int]): The strides for input feature maps.
|
||||
stack_lateral. All strides suppose to be power of 2. The first
|
||||
one is of largest resolution.
|
||||
"""
|
||||
|
||||
def __init__(self, feature_strides, **kwargs):
|
||||
super().__init__(input_transform='multiple_select', **kwargs)
|
||||
assert len(feature_strides) == len(self.in_channels)
|
||||
assert min(feature_strides) == feature_strides[0]
|
||||
self.feature_strides = feature_strides
|
||||
|
||||
self.scale_heads = nn.ModuleList()
|
||||
for i in range(len(feature_strides)):
|
||||
head_length = max(
|
||||
1,
|
||||
int(np.log2(feature_strides[i]) - np.log2(feature_strides[0])))
|
||||
scale_head = []
|
||||
for k in range(head_length):
|
||||
scale_head.append(
|
||||
ConvModule(
|
||||
self.in_channels[i] if k == 0 else self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg))
|
||||
if feature_strides[i] != feature_strides[0]:
|
||||
scale_head.append(
|
||||
Upsample(
|
||||
scale_factor=2,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners))
|
||||
self.scale_heads.append(nn.Sequential(*scale_head))
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
x = self._transform_inputs(inputs)
|
||||
|
||||
output = self.scale_heads[0](x[0])
|
||||
for i in range(1, len(self.feature_strides)):
|
||||
# non inplace
|
||||
output = output + resize(
|
||||
self.scale_heads[i](x[i]),
|
||||
size=output.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
48
finetune/mmseg/models/decode_heads/gc_head.py
Normal file
48
finetune/mmseg/models/decode_heads/gc_head.py
Normal file
@@ -0,0 +1,48 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from mmcv.cnn import ContextBlock
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .fcn_head import FCNHead
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class GCHead(FCNHead):
|
||||
"""GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond.
|
||||
|
||||
This head is the implementation of `GCNet
|
||||
<https://arxiv.org/abs/1904.11492>`_.
|
||||
|
||||
Args:
|
||||
ratio (float): Multiplier of channels ratio. Default: 1/4.
|
||||
pooling_type (str): The pooling type of context aggregation.
|
||||
Options are 'att', 'avg'. Default: 'avg'.
|
||||
fusion_types (tuple[str]): The fusion type for feature fusion.
|
||||
Options are 'channel_add', 'channel_mul'. Default: ('channel_add',)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
ratio=1 / 4.,
|
||||
pooling_type='att',
|
||||
fusion_types=('channel_add', ),
|
||||
**kwargs):
|
||||
super().__init__(num_convs=2, **kwargs)
|
||||
self.ratio = ratio
|
||||
self.pooling_type = pooling_type
|
||||
self.fusion_types = fusion_types
|
||||
self.gc_block = ContextBlock(
|
||||
in_channels=self.channels,
|
||||
ratio=self.ratio,
|
||||
pooling_type=self.pooling_type,
|
||||
fusion_types=self.fusion_types)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
x = self._transform_inputs(inputs)
|
||||
output = self.convs[0](x)
|
||||
output = self.gc_block(output)
|
||||
output = self.convs[1](output)
|
||||
if self.concat_input:
|
||||
output = self.conv_cat(torch.cat([x, output], dim=1))
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
255
finetune/mmseg/models/decode_heads/ham_head.py
Normal file
255
finetune/mmseg/models/decode_heads/ham_head.py
Normal file
@@ -0,0 +1,255 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# Originally from https://github.com/visual-attention-network/segnext
|
||||
# Licensed under the Apache License, Version 2.0 (the "License")
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.device import get_device
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class Matrix_Decomposition_2D_Base(nn.Module):
|
||||
"""Base class of 2D Matrix Decomposition.
|
||||
|
||||
Args:
|
||||
MD_S (int): The number of spatial coefficient in
|
||||
Matrix Decomposition, it may be used for calculation
|
||||
of the number of latent dimension D in Matrix
|
||||
Decomposition. Defaults: 1.
|
||||
MD_R (int): The number of latent dimension R in
|
||||
Matrix Decomposition. Defaults: 64.
|
||||
train_steps (int): The number of iteration steps in
|
||||
Multiplicative Update (MU) rule to solve Non-negative
|
||||
Matrix Factorization (NMF) in training. Defaults: 6.
|
||||
eval_steps (int): The number of iteration steps in
|
||||
Multiplicative Update (MU) rule to solve Non-negative
|
||||
Matrix Factorization (NMF) in evaluation. Defaults: 7.
|
||||
inv_t (int): Inverted multiple number to make coefficient
|
||||
smaller in softmax. Defaults: 100.
|
||||
rand_init (bool): Whether to initialize randomly.
|
||||
Defaults: True.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
MD_S=1,
|
||||
MD_R=64,
|
||||
train_steps=6,
|
||||
eval_steps=7,
|
||||
inv_t=100,
|
||||
rand_init=True):
|
||||
super().__init__()
|
||||
|
||||
self.S = MD_S
|
||||
self.R = MD_R
|
||||
|
||||
self.train_steps = train_steps
|
||||
self.eval_steps = eval_steps
|
||||
|
||||
self.inv_t = inv_t
|
||||
|
||||
self.rand_init = rand_init
|
||||
|
||||
def _build_bases(self, B, S, D, R, device=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def local_step(self, x, bases, coef):
|
||||
raise NotImplementedError
|
||||
|
||||
def local_inference(self, x, bases):
|
||||
# (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
|
||||
coef = torch.bmm(x.transpose(1, 2), bases)
|
||||
coef = F.softmax(self.inv_t * coef, dim=-1)
|
||||
|
||||
steps = self.train_steps if self.training else self.eval_steps
|
||||
for _ in range(steps):
|
||||
bases, coef = self.local_step(x, bases, coef)
|
||||
|
||||
return bases, coef
|
||||
|
||||
def compute_coef(self, x, bases, coef):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, x, return_bases=False):
|
||||
"""Forward Function."""
|
||||
B, C, H, W = x.shape
|
||||
|
||||
# (B, C, H, W) -> (B * S, D, N)
|
||||
D = C // self.S
|
||||
N = H * W
|
||||
x = x.view(B * self.S, D, N)
|
||||
if not self.rand_init and not hasattr(self, 'bases'):
|
||||
bases = self._build_bases(1, self.S, D, self.R, device=x.device)
|
||||
self.register_buffer('bases', bases)
|
||||
|
||||
# (S, D, R) -> (B * S, D, R)
|
||||
if self.rand_init:
|
||||
bases = self._build_bases(B, self.S, D, self.R, device=x.device)
|
||||
else:
|
||||
bases = self.bases.repeat(B, 1, 1)
|
||||
|
||||
bases, coef = self.local_inference(x, bases)
|
||||
|
||||
# (B * S, N, R)
|
||||
coef = self.compute_coef(x, bases, coef)
|
||||
|
||||
# (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N)
|
||||
x = torch.bmm(bases, coef.transpose(1, 2))
|
||||
|
||||
# (B * S, D, N) -> (B, C, H, W)
|
||||
x = x.view(B, C, H, W)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class NMF2D(Matrix_Decomposition_2D_Base):
|
||||
"""Non-negative Matrix Factorization (NMF) module.
|
||||
|
||||
It is inherited from ``Matrix_Decomposition_2D_Base`` module.
|
||||
"""
|
||||
|
||||
def __init__(self, args=dict()):
|
||||
super().__init__(**args)
|
||||
|
||||
self.inv_t = 1
|
||||
|
||||
def _build_bases(self, B, S, D, R, device=None):
|
||||
"""Build bases in initialization."""
|
||||
if device is None:
|
||||
device = get_device()
|
||||
bases = torch.rand((B * S, D, R)).to(device)
|
||||
bases = F.normalize(bases, dim=1)
|
||||
|
||||
return bases
|
||||
|
||||
def local_step(self, x, bases, coef):
|
||||
"""Local step in iteration to renew bases and coefficient."""
|
||||
# (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
|
||||
numerator = torch.bmm(x.transpose(1, 2), bases)
|
||||
# (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R)
|
||||
denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
|
||||
# Multiplicative Update
|
||||
coef = coef * numerator / (denominator + 1e-6)
|
||||
|
||||
# (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)
|
||||
numerator = torch.bmm(x, coef)
|
||||
# (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R)
|
||||
denominator = bases.bmm(coef.transpose(1, 2).bmm(coef))
|
||||
# Multiplicative Update
|
||||
bases = bases * numerator / (denominator + 1e-6)
|
||||
|
||||
return bases, coef
|
||||
|
||||
def compute_coef(self, x, bases, coef):
|
||||
"""Compute coefficient."""
|
||||
# (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
|
||||
numerator = torch.bmm(x.transpose(1, 2), bases)
|
||||
# (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R)
|
||||
denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
|
||||
# multiplication update
|
||||
coef = coef * numerator / (denominator + 1e-6)
|
||||
|
||||
return coef
|
||||
|
||||
|
||||
class Hamburger(nn.Module):
|
||||
"""Hamburger Module. It consists of one slice of "ham" (matrix
|
||||
decomposition) and two slices of "bread" (linear transformation).
|
||||
|
||||
Args:
|
||||
ham_channels (int): Input and output channels of feature.
|
||||
ham_kwargs (dict): Config of matrix decomposition module.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
ham_channels=512,
|
||||
ham_kwargs=dict(),
|
||||
norm_cfg=None,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.ham_in = ConvModule(
|
||||
ham_channels, ham_channels, 1, norm_cfg=None, act_cfg=None)
|
||||
|
||||
self.ham = NMF2D(ham_kwargs)
|
||||
|
||||
self.ham_out = ConvModule(
|
||||
ham_channels, ham_channels, 1, norm_cfg=norm_cfg, act_cfg=None)
|
||||
|
||||
def forward(self, x):
|
||||
enjoy = self.ham_in(x)
|
||||
enjoy = F.relu(enjoy, inplace=True)
|
||||
enjoy = self.ham(enjoy)
|
||||
enjoy = self.ham_out(enjoy)
|
||||
ham = F.relu(x + enjoy, inplace=True)
|
||||
|
||||
return ham
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class LightHamHead(BaseDecodeHead):
|
||||
"""SegNeXt decode head.
|
||||
|
||||
This decode head is the implementation of `SegNeXt: Rethinking
|
||||
Convolutional Attention Design for Semantic
|
||||
Segmentation <https://arxiv.org/abs/2209.08575>`_.
|
||||
Inspiration from https://github.com/visual-attention-network/segnext.
|
||||
|
||||
Specifically, LightHamHead is inspired by HamNet from
|
||||
`Is Attention Better Than Matrix Decomposition?
|
||||
<https://arxiv.org/abs/2109.04553>`.
|
||||
|
||||
Args:
|
||||
ham_channels (int): input channels for Hamburger.
|
||||
Defaults: 512.
|
||||
ham_kwargs (int): kwagrs for Ham. Defaults: dict().
|
||||
"""
|
||||
|
||||
def __init__(self, ham_channels=512, ham_kwargs=dict(), **kwargs):
|
||||
super().__init__(input_transform='multiple_select', **kwargs)
|
||||
self.ham_channels = ham_channels
|
||||
|
||||
self.squeeze = ConvModule(
|
||||
sum(self.in_channels),
|
||||
self.ham_channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
self.hamburger = Hamburger(ham_channels, ham_kwargs, **kwargs)
|
||||
|
||||
self.align = ConvModule(
|
||||
self.ham_channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
inputs = self._transform_inputs(inputs)
|
||||
|
||||
inputs = [
|
||||
resize(
|
||||
level,
|
||||
size=inputs[0].shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners) for level in inputs
|
||||
]
|
||||
|
||||
inputs = torch.cat(inputs, dim=1)
|
||||
# apply a conv block to squeeze feature map
|
||||
x = self.squeeze(inputs)
|
||||
# apply hamburger module
|
||||
x = self.hamburger(x)
|
||||
|
||||
# apply a conv block to align feature map
|
||||
output = self.align(x)
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
143
finetune/mmseg/models/decode_heads/isa_head.py
Normal file
143
finetune/mmseg/models/decode_heads/isa_head.py
Normal file
@@ -0,0 +1,143 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class SelfAttentionBlock(_SelfAttentionBlock):
|
||||
"""Self-Attention Module.
|
||||
|
||||
Args:
|
||||
in_channels (int): Input channels of key/query feature.
|
||||
channels (int): Output channels of key/query transform.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
act_cfg (dict | None): Config of activation layers.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, channels, conv_cfg, norm_cfg, act_cfg):
|
||||
super().__init__(
|
||||
key_in_channels=in_channels,
|
||||
query_in_channels=in_channels,
|
||||
channels=channels,
|
||||
out_channels=in_channels,
|
||||
share_key_query=False,
|
||||
query_downsample=None,
|
||||
key_downsample=None,
|
||||
key_query_num_convs=2,
|
||||
key_query_norm=True,
|
||||
value_out_num_convs=1,
|
||||
value_out_norm=False,
|
||||
matmul_norm=True,
|
||||
with_out=False,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
self.output_project = self.build_project(
|
||||
in_channels,
|
||||
in_channels,
|
||||
num_convs=1,
|
||||
use_conv_module=True,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
context = super().forward(x, x)
|
||||
return self.output_project(context)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ISAHead(BaseDecodeHead):
|
||||
"""Interlaced Sparse Self-Attention for Semantic Segmentation.
|
||||
|
||||
This head is the implementation of `ISA
|
||||
<https://arxiv.org/abs/1907.12273>`_.
|
||||
|
||||
Args:
|
||||
isa_channels (int): The channels of ISA Module.
|
||||
down_factor (tuple[int]): The local group size of ISA.
|
||||
"""
|
||||
|
||||
def __init__(self, isa_channels, down_factor=(8, 8), **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.down_factor = down_factor
|
||||
|
||||
self.in_conv = ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.global_relation = SelfAttentionBlock(
|
||||
self.channels,
|
||||
isa_channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.local_relation = SelfAttentionBlock(
|
||||
self.channels,
|
||||
isa_channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.out_conv = ConvModule(
|
||||
self.channels * 2,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
x_ = self._transform_inputs(inputs)
|
||||
x = self.in_conv(x_)
|
||||
residual = x
|
||||
|
||||
n, c, h, w = x.size()
|
||||
loc_h, loc_w = self.down_factor # size of local group in H- and W-axes
|
||||
glb_h, glb_w = math.ceil(h / loc_h), math.ceil(w / loc_w)
|
||||
pad_h, pad_w = glb_h * loc_h - h, glb_w * loc_w - w
|
||||
if pad_h > 0 or pad_w > 0: # pad if the size is not divisible
|
||||
padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
|
||||
pad_h - pad_h // 2)
|
||||
x = F.pad(x, padding)
|
||||
|
||||
# global relation
|
||||
x = x.view(n, c, glb_h, loc_h, glb_w, loc_w)
|
||||
# do permutation to gather global group
|
||||
x = x.permute(0, 3, 5, 1, 2, 4) # (n, loc_h, loc_w, c, glb_h, glb_w)
|
||||
x = x.reshape(-1, c, glb_h, glb_w)
|
||||
# apply attention within each global group
|
||||
x = self.global_relation(x) # (n * loc_h * loc_w, c, glb_h, glb_w)
|
||||
|
||||
# local relation
|
||||
x = x.view(n, loc_h, loc_w, c, glb_h, glb_w)
|
||||
# do permutation to gather local group
|
||||
x = x.permute(0, 4, 5, 3, 1, 2) # (n, glb_h, glb_w, c, loc_h, loc_w)
|
||||
x = x.reshape(-1, c, loc_h, loc_w)
|
||||
# apply attention within each local group
|
||||
x = self.local_relation(x) # (n * glb_h * glb_w, c, loc_h, loc_w)
|
||||
|
||||
# permute each pixel back to its original position
|
||||
x = x.view(n, glb_h, glb_w, c, loc_h, loc_w)
|
||||
x = x.permute(0, 3, 1, 4, 2, 5) # (n, c, glb_h, loc_h, glb_w, loc_w)
|
||||
x = x.reshape(n, c, glb_h * loc_h, glb_w * loc_w)
|
||||
if pad_h > 0 or pad_w > 0: # remove padding
|
||||
x = x[:, :, pad_h // 2:pad_h // 2 + h, pad_w // 2:pad_w // 2 + w]
|
||||
|
||||
x = self.out_conv(torch.cat([x, residual], dim=1))
|
||||
out = self.cls_seg(x)
|
||||
|
||||
return out
|
||||
461
finetune/mmseg/models/decode_heads/knet_head.py
Normal file
461
finetune/mmseg/models/decode_heads/knet_head.py
Normal file
@@ -0,0 +1,461 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
|
||||
from mmcv.cnn.bricks.transformer import (FFN, MultiheadAttention,
|
||||
build_transformer_layer)
|
||||
from mmengine.logging import print_log
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import SampleList
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class KernelUpdator(nn.Module):
|
||||
"""Dynamic Kernel Updator in Kernel Update Head.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of channels of input feature map.
|
||||
Default: 256.
|
||||
feat_channels (int): The number of middle-stage channels in
|
||||
the kernel updator. Default: 64.
|
||||
out_channels (int): The number of output channels.
|
||||
gate_sigmoid (bool): Whether use sigmoid function in gate
|
||||
mechanism. Default: True.
|
||||
gate_norm_act (bool): Whether add normalization and activation
|
||||
layer in gate mechanism. Default: False.
|
||||
activate_out: Whether add activation after gate mechanism.
|
||||
Default: False.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='LN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=256,
|
||||
feat_channels=64,
|
||||
out_channels=None,
|
||||
gate_sigmoid=True,
|
||||
gate_norm_act=False,
|
||||
activate_out=False,
|
||||
norm_cfg=dict(type='LN'),
|
||||
act_cfg=dict(type='ReLU', inplace=True),
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.feat_channels = feat_channels
|
||||
self.out_channels_raw = out_channels
|
||||
self.gate_sigmoid = gate_sigmoid
|
||||
self.gate_norm_act = gate_norm_act
|
||||
self.activate_out = activate_out
|
||||
self.act_cfg = act_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.out_channels = out_channels if out_channels else in_channels
|
||||
|
||||
self.num_params_in = self.feat_channels
|
||||
self.num_params_out = self.feat_channels
|
||||
self.dynamic_layer = nn.Linear(
|
||||
self.in_channels, self.num_params_in + self.num_params_out)
|
||||
self.input_layer = nn.Linear(self.in_channels,
|
||||
self.num_params_in + self.num_params_out,
|
||||
1)
|
||||
self.input_gate = nn.Linear(self.in_channels, self.feat_channels, 1)
|
||||
self.update_gate = nn.Linear(self.in_channels, self.feat_channels, 1)
|
||||
if self.gate_norm_act:
|
||||
self.gate_norm = build_norm_layer(norm_cfg, self.feat_channels)[1]
|
||||
|
||||
self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
|
||||
self.norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]
|
||||
self.input_norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
|
||||
self.input_norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]
|
||||
|
||||
self.activation = build_activation_layer(act_cfg)
|
||||
|
||||
self.fc_layer = nn.Linear(self.feat_channels, self.out_channels, 1)
|
||||
self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]
|
||||
|
||||
def forward(self, update_feature, input_feature):
|
||||
"""Forward function of KernelUpdator.
|
||||
|
||||
Args:
|
||||
update_feature (torch.Tensor): Feature map assembled from
|
||||
each group. It would be reshaped with last dimension
|
||||
shape: `self.in_channels`.
|
||||
input_feature (torch.Tensor): Intermediate feature
|
||||
with shape: (N, num_classes, conv_kernel_size**2, channels).
|
||||
Returns:
|
||||
Tensor: The output tensor of shape (N*C1/C2, K*K, C2), where N is
|
||||
the number of classes, C1 and C2 are the feature map channels of
|
||||
KernelUpdateHead and KernelUpdator, respectively.
|
||||
"""
|
||||
|
||||
update_feature = update_feature.reshape(-1, self.in_channels)
|
||||
num_proposals = update_feature.size(0)
|
||||
# dynamic_layer works for
|
||||
# phi_1 and psi_3 in Eq.(4) and (5) of K-Net paper
|
||||
parameters = self.dynamic_layer(update_feature)
|
||||
param_in = parameters[:, :self.num_params_in].view(
|
||||
-1, self.feat_channels)
|
||||
param_out = parameters[:, -self.num_params_out:].view(
|
||||
-1, self.feat_channels)
|
||||
|
||||
# input_layer works for
|
||||
# phi_2 and psi_4 in Eq.(4) and (5) of K-Net paper
|
||||
input_feats = self.input_layer(
|
||||
input_feature.reshape(num_proposals, -1, self.feat_channels))
|
||||
input_in = input_feats[..., :self.num_params_in]
|
||||
input_out = input_feats[..., -self.num_params_out:]
|
||||
|
||||
# `gate_feats` is F^G in K-Net paper
|
||||
gate_feats = input_in * param_in.unsqueeze(-2)
|
||||
if self.gate_norm_act:
|
||||
gate_feats = self.activation(self.gate_norm(gate_feats))
|
||||
|
||||
input_gate = self.input_norm_in(self.input_gate(gate_feats))
|
||||
update_gate = self.norm_in(self.update_gate(gate_feats))
|
||||
if self.gate_sigmoid:
|
||||
input_gate = input_gate.sigmoid()
|
||||
update_gate = update_gate.sigmoid()
|
||||
param_out = self.norm_out(param_out)
|
||||
input_out = self.input_norm_out(input_out)
|
||||
|
||||
if self.activate_out:
|
||||
param_out = self.activation(param_out)
|
||||
input_out = self.activation(input_out)
|
||||
|
||||
# Gate mechanism. Eq.(5) in original paper.
|
||||
# param_out has shape (batch_size, feat_channels, out_channels)
|
||||
features = update_gate * param_out.unsqueeze(
|
||||
-2) + input_gate * input_out
|
||||
|
||||
features = self.fc_layer(features)
|
||||
features = self.fc_norm(features)
|
||||
features = self.activation(features)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class KernelUpdateHead(nn.Module):
|
||||
"""Kernel Update Head in K-Net.
|
||||
|
||||
Args:
|
||||
num_classes (int): Number of classes. Default: 150.
|
||||
num_ffn_fcs (int): The number of fully-connected layers in
|
||||
FFNs. Default: 2.
|
||||
num_heads (int): The number of parallel attention heads.
|
||||
Default: 8.
|
||||
num_mask_fcs (int): The number of fully connected layers for
|
||||
mask prediction. Default: 3.
|
||||
feedforward_channels (int): The hidden dimension of FFNs.
|
||||
Defaults: 2048.
|
||||
in_channels (int): The number of channels of input feature map.
|
||||
Default: 256.
|
||||
out_channels (int): The number of output channels.
|
||||
Default: 256.
|
||||
dropout (float): The Probability of an element to be
|
||||
zeroed in MultiheadAttention and FFN. Default 0.0.
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
ffn_act_cfg (dict): Config of activation layers in FFN.
|
||||
Default: dict(type='ReLU').
|
||||
conv_kernel_size (int): The kernel size of convolution in
|
||||
Kernel Update Head for dynamic kernel updation.
|
||||
Default: 1.
|
||||
feat_transform_cfg (dict | None): Config of feature transform.
|
||||
Default: None.
|
||||
kernel_init (bool): Whether initiate mask kernel in mask head.
|
||||
Default: False.
|
||||
with_ffn (bool): Whether add FFN in kernel update head.
|
||||
Default: True.
|
||||
feat_gather_stride (int): Stride of convolution in feature transform.
|
||||
Default: 1.
|
||||
mask_transform_stride (int): Stride of mask transform.
|
||||
Default: 1.
|
||||
kernel_updator_cfg (dict): Config of kernel updator.
|
||||
Default: dict(
|
||||
type='DynamicConv',
|
||||
in_channels=256,
|
||||
feat_channels=64,
|
||||
out_channels=256,
|
||||
act_cfg=dict(type='ReLU', inplace=True),
|
||||
norm_cfg=dict(type='LN')).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_classes=150,
|
||||
num_ffn_fcs=2,
|
||||
num_heads=8,
|
||||
num_mask_fcs=3,
|
||||
feedforward_channels=2048,
|
||||
in_channels=256,
|
||||
out_channels=256,
|
||||
dropout=0.0,
|
||||
act_cfg=dict(type='ReLU', inplace=True),
|
||||
ffn_act_cfg=dict(type='ReLU', inplace=True),
|
||||
conv_kernel_size=1,
|
||||
feat_transform_cfg=None,
|
||||
kernel_init=False,
|
||||
with_ffn=True,
|
||||
feat_gather_stride=1,
|
||||
mask_transform_stride=1,
|
||||
kernel_updator_cfg=dict(
|
||||
type='DynamicConv',
|
||||
in_channels=256,
|
||||
feat_channels=64,
|
||||
out_channels=256,
|
||||
act_cfg=dict(type='ReLU', inplace=True),
|
||||
norm_cfg=dict(type='LN'))):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.fp16_enabled = False
|
||||
self.dropout = dropout
|
||||
self.num_heads = num_heads
|
||||
self.kernel_init = kernel_init
|
||||
self.with_ffn = with_ffn
|
||||
self.conv_kernel_size = conv_kernel_size
|
||||
self.feat_gather_stride = feat_gather_stride
|
||||
self.mask_transform_stride = mask_transform_stride
|
||||
|
||||
self.attention = MultiheadAttention(in_channels * conv_kernel_size**2,
|
||||
num_heads, dropout)
|
||||
self.attention_norm = build_norm_layer(
|
||||
dict(type='LN'), in_channels * conv_kernel_size**2)[1]
|
||||
self.kernel_update_conv = build_transformer_layer(kernel_updator_cfg)
|
||||
|
||||
if feat_transform_cfg is not None:
|
||||
kernel_size = feat_transform_cfg.pop('kernel_size', 1)
|
||||
transform_channels = in_channels
|
||||
self.feat_transform = ConvModule(
|
||||
transform_channels,
|
||||
in_channels,
|
||||
kernel_size,
|
||||
stride=feat_gather_stride,
|
||||
padding=int(feat_gather_stride // 2),
|
||||
**feat_transform_cfg)
|
||||
else:
|
||||
self.feat_transform = None
|
||||
|
||||
if self.with_ffn:
|
||||
self.ffn = FFN(
|
||||
in_channels,
|
||||
feedforward_channels,
|
||||
num_ffn_fcs,
|
||||
act_cfg=ffn_act_cfg,
|
||||
dropout=dropout)
|
||||
self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1]
|
||||
|
||||
self.mask_fcs = nn.ModuleList()
|
||||
for _ in range(num_mask_fcs):
|
||||
self.mask_fcs.append(
|
||||
nn.Linear(in_channels, in_channels, bias=False))
|
||||
self.mask_fcs.append(
|
||||
build_norm_layer(dict(type='LN'), in_channels)[1])
|
||||
self.mask_fcs.append(build_activation_layer(act_cfg))
|
||||
|
||||
self.fc_mask = nn.Linear(in_channels, out_channels)
|
||||
|
||||
def init_weights(self):
|
||||
"""Use xavier initialization for all weight parameter and set
|
||||
classification head bias as a specific value when use focal loss."""
|
||||
for p in self.parameters():
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_uniform_(p)
|
||||
else:
|
||||
# adopt the default initialization for
|
||||
# the weight and bias of the layer norm
|
||||
pass
|
||||
if self.kernel_init:
|
||||
print_log(
|
||||
'mask kernel in mask head is normal initialized by std 0.01')
|
||||
nn.init.normal_(self.fc_mask.weight, mean=0, std=0.01)
|
||||
|
||||
def forward(self, x, proposal_feat, mask_preds, mask_shape=None):
|
||||
"""Forward function of Dynamic Instance Interactive Head.
|
||||
|
||||
Args:
|
||||
x (Tensor): Feature map from FPN with shape
|
||||
(batch_size, feature_dimensions, H , W).
|
||||
proposal_feat (Tensor): Intermediate feature get from
|
||||
diihead in last stage, has shape
|
||||
(batch_size, num_proposals, feature_dimensions)
|
||||
mask_preds (Tensor): mask prediction from the former stage in shape
|
||||
(batch_size, num_proposals, H, W).
|
||||
|
||||
Returns:
|
||||
Tuple: The first tensor is predicted mask with shape
|
||||
(N, num_classes, H, W), the second tensor is dynamic kernel
|
||||
with shape (N, num_classes, channels, K, K).
|
||||
"""
|
||||
N, num_proposals = proposal_feat.shape[:2]
|
||||
if self.feat_transform is not None:
|
||||
x = self.feat_transform(x)
|
||||
|
||||
C, H, W = x.shape[-3:]
|
||||
|
||||
mask_h, mask_w = mask_preds.shape[-2:]
|
||||
if mask_h != H or mask_w != W:
|
||||
gather_mask = F.interpolate(
|
||||
mask_preds, (H, W), align_corners=False, mode='bilinear')
|
||||
else:
|
||||
gather_mask = mask_preds
|
||||
|
||||
sigmoid_masks = gather_mask.softmax(dim=1)
|
||||
|
||||
# Group Feature Assembling. Eq.(3) in original paper.
|
||||
# einsum is faster than bmm by 30%
|
||||
x_feat = torch.einsum('bnhw,bchw->bnc', sigmoid_masks, x)
|
||||
|
||||
# obj_feat in shape [B, N, C, K, K] -> [B, N, C, K*K] -> [B, N, K*K, C]
|
||||
proposal_feat = proposal_feat.reshape(N, num_proposals,
|
||||
self.in_channels,
|
||||
-1).permute(0, 1, 3, 2)
|
||||
obj_feat = self.kernel_update_conv(x_feat, proposal_feat)
|
||||
|
||||
# [B, N, K*K, C] -> [B, N, K*K*C] -> [N, B, K*K*C]
|
||||
obj_feat = obj_feat.reshape(N, num_proposals, -1).permute(1, 0, 2)
|
||||
obj_feat = self.attention_norm(self.attention(obj_feat))
|
||||
# [N, B, K*K*C] -> [B, N, K*K*C]
|
||||
obj_feat = obj_feat.permute(1, 0, 2)
|
||||
|
||||
# obj_feat in shape [B, N, K*K*C] -> [B, N, K*K, C]
|
||||
obj_feat = obj_feat.reshape(N, num_proposals, -1, self.in_channels)
|
||||
|
||||
# FFN
|
||||
if self.with_ffn:
|
||||
obj_feat = self.ffn_norm(self.ffn(obj_feat))
|
||||
|
||||
mask_feat = obj_feat
|
||||
|
||||
for reg_layer in self.mask_fcs:
|
||||
mask_feat = reg_layer(mask_feat)
|
||||
|
||||
# [B, N, K*K, C] -> [B, N, C, K*K]
|
||||
mask_feat = self.fc_mask(mask_feat).permute(0, 1, 3, 2)
|
||||
|
||||
if (self.mask_transform_stride == 2 and self.feat_gather_stride == 1):
|
||||
mask_x = F.interpolate(
|
||||
x, scale_factor=0.5, mode='bilinear', align_corners=False)
|
||||
H, W = mask_x.shape[-2:]
|
||||
else:
|
||||
mask_x = x
|
||||
# group conv is 5x faster than unfold and uses about 1/5 memory
|
||||
# Group conv vs. unfold vs. concat batch, 2.9ms :13.5ms :3.8ms
|
||||
# Group conv vs. unfold vs. concat batch, 278 : 1420 : 369
|
||||
# but in real training group conv is slower than concat batch
|
||||
# so we keep using concat batch.
|
||||
# fold_x = F.unfold(
|
||||
# mask_x,
|
||||
# self.conv_kernel_size,
|
||||
# padding=int(self.conv_kernel_size // 2))
|
||||
# mask_feat = mask_feat.reshape(N, num_proposals, -1)
|
||||
# new_mask_preds = torch.einsum('bnc,bcl->bnl', mask_feat, fold_x)
|
||||
# [B, N, C, K*K] -> [B*N, C, K, K]
|
||||
mask_feat = mask_feat.reshape(N, num_proposals, C,
|
||||
self.conv_kernel_size,
|
||||
self.conv_kernel_size)
|
||||
# [B, C, H, W] -> [1, B*C, H, W]
|
||||
new_mask_preds = []
|
||||
for i in range(N):
|
||||
new_mask_preds.append(
|
||||
F.conv2d(
|
||||
mask_x[i:i + 1],
|
||||
mask_feat[i],
|
||||
padding=int(self.conv_kernel_size // 2)))
|
||||
|
||||
new_mask_preds = torch.cat(new_mask_preds, dim=0)
|
||||
new_mask_preds = new_mask_preds.reshape(N, num_proposals, H, W)
|
||||
if self.mask_transform_stride == 2:
|
||||
new_mask_preds = F.interpolate(
|
||||
new_mask_preds,
|
||||
scale_factor=2,
|
||||
mode='bilinear',
|
||||
align_corners=False)
|
||||
|
||||
if mask_shape is not None and mask_shape[0] != H:
|
||||
new_mask_preds = F.interpolate(
|
||||
new_mask_preds,
|
||||
mask_shape,
|
||||
align_corners=False,
|
||||
mode='bilinear')
|
||||
|
||||
return new_mask_preds, obj_feat.permute(0, 1, 3, 2).reshape(
|
||||
N, num_proposals, self.in_channels, self.conv_kernel_size,
|
||||
self.conv_kernel_size)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class IterativeDecodeHead(BaseDecodeHead):
|
||||
"""K-Net: Towards Unified Image Segmentation.
|
||||
|
||||
This head is the implementation of
|
||||
`K-Net: <https://arxiv.org/abs/2106.14855>`_.
|
||||
|
||||
Args:
|
||||
num_stages (int): The number of stages (kernel update heads)
|
||||
in IterativeDecodeHead. Default: 3.
|
||||
kernel_generate_head:(dict): Config of kernel generate head which
|
||||
generate mask predictions, dynamic kernels and class predictions
|
||||
for next kernel update heads.
|
||||
kernel_update_head (dict): Config of kernel update head which refine
|
||||
dynamic kernels and class predictions iteratively.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, num_stages, kernel_generate_head, kernel_update_head,
|
||||
**kwargs):
|
||||
# ``IterativeDecodeHead`` would skip initialization of
|
||||
# ``BaseDecodeHead`` which would be called when building
|
||||
# ``self.kernel_generate_head``.
|
||||
super(BaseDecodeHead, self).__init__(**kwargs)
|
||||
assert num_stages == len(kernel_update_head)
|
||||
self.num_stages = num_stages
|
||||
self.kernel_generate_head = MODELS.build(kernel_generate_head)
|
||||
self.kernel_update_head = nn.ModuleList()
|
||||
self.align_corners = self.kernel_generate_head.align_corners
|
||||
self.num_classes = self.kernel_generate_head.num_classes
|
||||
self.input_transform = self.kernel_generate_head.input_transform
|
||||
self.ignore_index = self.kernel_generate_head.ignore_index
|
||||
self.out_channels = self.num_classes
|
||||
|
||||
for head_cfg in kernel_update_head:
|
||||
self.kernel_update_head.append(MODELS.build(head_cfg))
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
feats = self.kernel_generate_head._forward_feature(inputs)
|
||||
sem_seg = self.kernel_generate_head.cls_seg(feats)
|
||||
seg_kernels = self.kernel_generate_head.conv_seg.weight.clone()
|
||||
seg_kernels = seg_kernels[None].expand(
|
||||
feats.size(0), *seg_kernels.size())
|
||||
|
||||
stage_segs = [sem_seg]
|
||||
for i in range(self.num_stages):
|
||||
sem_seg, seg_kernels = self.kernel_update_head[i](feats,
|
||||
seg_kernels,
|
||||
sem_seg)
|
||||
stage_segs.append(sem_seg)
|
||||
if self.training:
|
||||
return stage_segs
|
||||
# only return the prediction of the last stage during testing
|
||||
return stage_segs[-1]
|
||||
|
||||
def loss_by_feat(self, seg_logits: List[Tensor],
|
||||
batch_data_samples: SampleList, **kwargs) -> dict:
|
||||
losses = dict()
|
||||
for i, logit in enumerate(seg_logits):
|
||||
loss = self.kernel_generate_head.loss_by_feat(
|
||||
logit, batch_data_samples)
|
||||
for k, v in loss.items():
|
||||
losses[f'{k}.s{i}'] = v
|
||||
|
||||
return losses
|
||||
91
finetune/mmseg/models/decode_heads/lraspp_head.py
Normal file
91
finetune/mmseg/models/decode_heads/lraspp_head.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.utils import is_tuple_of
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class LRASPPHead(BaseDecodeHead):
|
||||
"""Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3.
|
||||
|
||||
This head is the improved implementation of `Searching for MobileNetV3
|
||||
<https://ieeexplore.ieee.org/document/9008835>`_.
|
||||
|
||||
Args:
|
||||
branch_channels (tuple[int]): The number of output channels in every
|
||||
each branch. Default: (32, 64).
|
||||
"""
|
||||
|
||||
def __init__(self, branch_channels=(32, 64), **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if self.input_transform != 'multiple_select':
|
||||
raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform '
|
||||
f'must be \'multiple_select\'. But received '
|
||||
f'\'{self.input_transform}\'')
|
||||
assert is_tuple_of(branch_channels, int)
|
||||
assert len(branch_channels) == len(self.in_channels) - 1
|
||||
self.branch_channels = branch_channels
|
||||
|
||||
self.convs = nn.Sequential()
|
||||
self.conv_ups = nn.Sequential()
|
||||
for i in range(len(branch_channels)):
|
||||
self.convs.add_module(
|
||||
f'conv{i}',
|
||||
nn.Conv2d(
|
||||
self.in_channels[i], branch_channels[i], 1, bias=False))
|
||||
self.conv_ups.add_module(
|
||||
f'conv_up{i}',
|
||||
ConvModule(
|
||||
self.channels + branch_channels[i],
|
||||
self.channels,
|
||||
1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
bias=False))
|
||||
|
||||
self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1)
|
||||
|
||||
self.aspp_conv = ConvModule(
|
||||
self.in_channels[-1],
|
||||
self.channels,
|
||||
1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
bias=False)
|
||||
self.image_pool = nn.Sequential(
|
||||
nn.AvgPool2d(kernel_size=49, stride=(16, 20)),
|
||||
ConvModule(
|
||||
self.in_channels[2],
|
||||
self.channels,
|
||||
1,
|
||||
act_cfg=dict(type='Sigmoid'),
|
||||
bias=False))
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
inputs = self._transform_inputs(inputs)
|
||||
|
||||
x = inputs[-1]
|
||||
|
||||
x = self.aspp_conv(x) * resize(
|
||||
self.image_pool(x),
|
||||
size=x.size()[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
x = self.conv_up_input(x)
|
||||
|
||||
for i in range(len(self.branch_channels) - 1, -1, -1):
|
||||
x = resize(
|
||||
x,
|
||||
size=inputs[i].size()[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
x = torch.cat([x, self.convs[i](inputs[i])], 1)
|
||||
x = self.conv_ups[i](x)
|
||||
|
||||
return self.cls_seg(x)
|
||||
163
finetune/mmseg/models/decode_heads/mask2former_head.py
Normal file
163
finetune/mmseg/models/decode_heads/mask2former_head.py
Normal file
@@ -0,0 +1,163 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
try:
|
||||
from mmdet.models.dense_heads import \
|
||||
Mask2FormerHead as MMDET_Mask2FormerHead
|
||||
except ModuleNotFoundError:
|
||||
MMDET_Mask2FormerHead = BaseModule
|
||||
|
||||
from mmengine.structures import InstanceData
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.structures.seg_data_sample import SegDataSample
|
||||
from mmseg.utils import ConfigType, SampleList
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class Mask2FormerHead(MMDET_Mask2FormerHead):
|
||||
"""Implements the Mask2Former head.
|
||||
|
||||
See `Mask2Former: Masked-attention Mask Transformer for Universal Image
|
||||
Segmentation <https://arxiv.org/abs/2112.01527>`_ for details.
|
||||
|
||||
Args:
|
||||
num_classes (int): Number of classes. Default: 150.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
ignore_index (int): The label index to be ignored. Default: 255.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_classes,
|
||||
align_corners=False,
|
||||
ignore_index=255,
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.num_classes = num_classes
|
||||
self.align_corners = align_corners
|
||||
self.out_channels = num_classes
|
||||
self.ignore_index = ignore_index
|
||||
|
||||
feat_channels = kwargs['feat_channels']
|
||||
self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
|
||||
|
||||
def _seg_data_to_instance_data(self, batch_data_samples: SampleList):
|
||||
"""Perform forward propagation to convert paradigm from MMSegmentation
|
||||
to MMDetection to ensure ``MMDET_Mask2FormerHead`` could be called
|
||||
normally. Specifically, ``batch_gt_instances`` would be added.
|
||||
|
||||
Args:
|
||||
batch_data_samples (List[:obj:`SegDataSample`]): The Data
|
||||
Samples. It usually includes information such as
|
||||
`gt_sem_seg`.
|
||||
|
||||
Returns:
|
||||
tuple[Tensor]: A tuple contains two lists.
|
||||
|
||||
- batch_gt_instances (list[:obj:`InstanceData`]): Batch of
|
||||
gt_instance. It usually includes ``labels``, each is
|
||||
unique ground truth label id of images, with
|
||||
shape (num_gt, ) and ``masks``, each is ground truth
|
||||
masks of each instances of a image, shape (num_gt, h, w).
|
||||
- batch_img_metas (list[dict]): List of image meta information.
|
||||
"""
|
||||
batch_img_metas = []
|
||||
batch_gt_instances = []
|
||||
|
||||
for data_sample in batch_data_samples:
|
||||
batch_img_metas.append(data_sample.metainfo)
|
||||
gt_sem_seg = data_sample.gt_sem_seg.data
|
||||
classes = torch.unique(
|
||||
gt_sem_seg,
|
||||
sorted=False,
|
||||
return_inverse=False,
|
||||
return_counts=False)
|
||||
|
||||
# remove ignored region
|
||||
gt_labels = classes[classes != self.ignore_index]
|
||||
|
||||
masks = []
|
||||
for class_id in gt_labels:
|
||||
masks.append(gt_sem_seg == class_id)
|
||||
|
||||
if len(masks) == 0:
|
||||
gt_masks = torch.zeros(
|
||||
(0, gt_sem_seg.shape[-2],
|
||||
gt_sem_seg.shape[-1])).to(gt_sem_seg).long()
|
||||
else:
|
||||
gt_masks = torch.stack(masks).squeeze(1).long()
|
||||
|
||||
instance_data = InstanceData(labels=gt_labels, masks=gt_masks)
|
||||
batch_gt_instances.append(instance_data)
|
||||
return batch_gt_instances, batch_img_metas
|
||||
|
||||
def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList,
|
||||
train_cfg: ConfigType) -> dict:
|
||||
"""Perform forward propagation and loss calculation of the decoder head
|
||||
on the features of the upstream network.
|
||||
|
||||
Args:
|
||||
x (tuple[Tensor]): Multi-level features from the upstream
|
||||
network, each is a 4D-tensor.
|
||||
batch_data_samples (List[:obj:`SegDataSample`]): The Data
|
||||
Samples. It usually includes information such as
|
||||
`gt_sem_seg`.
|
||||
train_cfg (ConfigType): Training config.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: a dictionary of loss components.
|
||||
"""
|
||||
# batch SegDataSample to InstanceDataSample
|
||||
batch_gt_instances, batch_img_metas = self._seg_data_to_instance_data(
|
||||
batch_data_samples)
|
||||
|
||||
# forward
|
||||
all_cls_scores, all_mask_preds = self(x, batch_data_samples)
|
||||
|
||||
# loss
|
||||
losses = self.loss_by_feat(all_cls_scores, all_mask_preds,
|
||||
batch_gt_instances, batch_img_metas)
|
||||
|
||||
return losses
|
||||
|
||||
def predict(self, x: Tuple[Tensor], batch_img_metas: List[dict],
|
||||
test_cfg: ConfigType) -> Tuple[Tensor]:
|
||||
"""Test without augmentaton.
|
||||
|
||||
Args:
|
||||
x (tuple[Tensor]): Multi-level features from the
|
||||
upstream network, each is a 4D-tensor.
|
||||
batch_img_metas (List[:obj:`SegDataSample`]): The Data
|
||||
Samples. It usually includes information such as
|
||||
`gt_sem_seg`.
|
||||
test_cfg (ConfigType): Test config.
|
||||
|
||||
Returns:
|
||||
Tensor: A tensor of segmentation mask.
|
||||
"""
|
||||
batch_data_samples = [
|
||||
SegDataSample(metainfo=metainfo) for metainfo in batch_img_metas
|
||||
]
|
||||
|
||||
all_cls_scores, all_mask_preds = self(x, batch_data_samples)
|
||||
mask_cls_results = all_cls_scores[-1]
|
||||
mask_pred_results = all_mask_preds[-1]
|
||||
if 'pad_shape' in batch_img_metas[0]:
|
||||
size = batch_img_metas[0]['pad_shape']
|
||||
else:
|
||||
size = batch_img_metas[0]['img_shape']
|
||||
# upsample mask
|
||||
mask_pred_results = F.interpolate(
|
||||
mask_pred_results, size=size, mode='bilinear', align_corners=False)
|
||||
cls_score = F.softmax(mask_cls_results, dim=-1)[..., :-1]
|
||||
mask_pred = mask_pred_results.sigmoid()
|
||||
seg_logits = torch.einsum('bqc, bqhw->bchw', cls_score, mask_pred)
|
||||
return seg_logits
|
||||
174
finetune/mmseg/models/decode_heads/maskformer_head.py
Normal file
174
finetune/mmseg/models/decode_heads/maskformer_head.py
Normal file
@@ -0,0 +1,174 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
try:
|
||||
from mmdet.models.dense_heads import MaskFormerHead as MMDET_MaskFormerHead
|
||||
except ModuleNotFoundError:
|
||||
MMDET_MaskFormerHead = BaseModule
|
||||
|
||||
from mmengine.structures import InstanceData
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.structures.seg_data_sample import SegDataSample
|
||||
from mmseg.utils import ConfigType, SampleList
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MaskFormerHead(MMDET_MaskFormerHead):
|
||||
"""Implements the MaskFormer head.
|
||||
|
||||
See `Per-Pixel Classification is Not All You Need for Semantic Segmentation
|
||||
<https://arxiv.org/pdf/2107.06278>`_ for details.
|
||||
|
||||
Args:
|
||||
num_classes (int): Number of classes. Default: 150.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
ignore_index (int): The label index to be ignored. Default: 255.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_classes: int = 150,
|
||||
align_corners: bool = False,
|
||||
ignore_index: int = 255,
|
||||
**kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.out_channels = kwargs['out_channels']
|
||||
self.align_corners = True
|
||||
self.num_classes = num_classes
|
||||
self.align_corners = align_corners
|
||||
self.out_channels = num_classes
|
||||
self.ignore_index = ignore_index
|
||||
|
||||
feat_channels = kwargs['feat_channels']
|
||||
self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
|
||||
|
||||
def _seg_data_to_instance_data(self, batch_data_samples: SampleList):
|
||||
"""Perform forward propagation to convert paradigm from MMSegmentation
|
||||
to MMDetection to ensure ``MMDET_MaskFormerHead`` could be called
|
||||
normally. Specifically, ``batch_gt_instances`` would be added.
|
||||
|
||||
Args:
|
||||
batch_data_samples (List[:obj:`SegDataSample`]): The Data
|
||||
Samples. It usually includes information such as
|
||||
`gt_sem_seg`.
|
||||
|
||||
Returns:
|
||||
tuple[Tensor]: A tuple contains two lists.
|
||||
|
||||
- batch_gt_instances (list[:obj:`InstanceData`]): Batch of
|
||||
gt_instance. It usually includes ``labels``, each is
|
||||
unique ground truth label id of images, with
|
||||
shape (num_gt, ) and ``masks``, each is ground truth
|
||||
masks of each instances of a image, shape (num_gt, h, w).
|
||||
- batch_img_metas (list[dict]): List of image meta information.
|
||||
"""
|
||||
batch_img_metas = []
|
||||
batch_gt_instances = []
|
||||
for data_sample in batch_data_samples:
|
||||
# Add `batch_input_shape` in metainfo of data_sample, which would
|
||||
# be used in MaskFormerHead of MMDetection.
|
||||
metainfo = data_sample.metainfo
|
||||
metainfo['batch_input_shape'] = metainfo['img_shape']
|
||||
data_sample.set_metainfo(metainfo)
|
||||
batch_img_metas.append(data_sample.metainfo)
|
||||
gt_sem_seg = data_sample.gt_sem_seg.data
|
||||
classes = torch.unique(
|
||||
gt_sem_seg,
|
||||
sorted=False,
|
||||
return_inverse=False,
|
||||
return_counts=False)
|
||||
|
||||
# remove ignored region
|
||||
gt_labels = classes[classes != self.ignore_index]
|
||||
|
||||
masks = []
|
||||
for class_id in gt_labels:
|
||||
masks.append(gt_sem_seg == class_id)
|
||||
|
||||
if len(masks) == 0:
|
||||
gt_masks = torch.zeros((0, gt_sem_seg.shape[-2],
|
||||
gt_sem_seg.shape[-1])).to(gt_sem_seg)
|
||||
else:
|
||||
gt_masks = torch.stack(masks).squeeze(1)
|
||||
|
||||
instance_data = InstanceData(
|
||||
labels=gt_labels, masks=gt_masks.long())
|
||||
batch_gt_instances.append(instance_data)
|
||||
return batch_gt_instances, batch_img_metas
|
||||
|
||||
def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList,
|
||||
train_cfg: ConfigType) -> dict:
|
||||
"""Perform forward propagation and loss calculation of the decoder head
|
||||
on the features of the upstream network.
|
||||
|
||||
Args:
|
||||
x (tuple[Tensor]): Multi-level features from the upstream
|
||||
network, each is a 4D-tensor.
|
||||
batch_data_samples (List[:obj:`SegDataSample`]): The Data
|
||||
Samples. It usually includes information such as
|
||||
`gt_sem_seg`.
|
||||
train_cfg (ConfigType): Training config.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: a dictionary of loss components.
|
||||
"""
|
||||
# batch SegDataSample to InstanceDataSample
|
||||
batch_gt_instances, batch_img_metas = self._seg_data_to_instance_data(
|
||||
batch_data_samples)
|
||||
|
||||
# forward
|
||||
all_cls_scores, all_mask_preds = self(x, batch_data_samples)
|
||||
|
||||
# loss
|
||||
losses = self.loss_by_feat(all_cls_scores, all_mask_preds,
|
||||
batch_gt_instances, batch_img_metas)
|
||||
|
||||
return losses
|
||||
|
||||
def predict(self, x: Tuple[Tensor], batch_img_metas: List[dict],
|
||||
test_cfg: ConfigType) -> Tuple[Tensor]:
|
||||
"""Test without augmentaton.
|
||||
|
||||
Args:
|
||||
x (tuple[Tensor]): Multi-level features from the
|
||||
upstream network, each is a 4D-tensor.
|
||||
batch_img_metas (List[dict]): List of image meta information,
e.g. `img_shape`.
|
||||
test_cfg (ConfigType): Test config.
|
||||
|
||||
Returns:
|
||||
Tensor: A tensor of segmentation mask.
|
||||
"""
|
||||
|
||||
batch_data_samples = []
|
||||
for metainfo in batch_img_metas:
|
||||
metainfo['batch_input_shape'] = metainfo['img_shape']
|
||||
batch_data_samples.append(SegDataSample(metainfo=metainfo))
|
||||
# The forward function of MaskFormerHead from MMDetection needs
# 'batch_data_samples' as input, which is only used for the image shapes here.
|
||||
all_cls_scores, all_mask_preds = self(x, batch_data_samples)
|
||||
mask_cls_results = all_cls_scores[-1]
|
||||
mask_pred_results = all_mask_preds[-1]
|
||||
|
||||
# upsample masks
|
||||
img_shape = batch_img_metas[0]['batch_input_shape']
|
||||
mask_pred_results = F.interpolate(
|
||||
mask_pred_results,
|
||||
size=img_shape,
|
||||
mode='bilinear',
|
||||
align_corners=False)
|
||||
|
||||
# semantic inference
|
||||
cls_score = F.softmax(mask_cls_results, dim=-1)[..., :-1]
|
||||
mask_pred = mask_pred_results.sigmoid()
|
||||
seg_logits = torch.einsum('bqc,bqhw->bchw', cls_score, mask_pred)
|
||||
return seg_logits
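The last three lines combine per-query class scores and mask predictions into dense per-class logits. A minimal sketch with assumed sizes (2 images, 100 queries, 19 classes, 128x128 masks):

# Hedged sketch of the semantic-inference step above, plain PyTorch only.
import torch
import torch.nn.functional as F

mask_cls_results = torch.randn(2, 100, 19 + 1)      # per-query class logits (+1 no-object)
mask_pred_results = torch.randn(2, 100, 128, 128)   # per-query mask logits

cls_score = F.softmax(mask_cls_results, dim=-1)[..., :-1]  # drop the no-object column
mask_pred = mask_pred_results.sigmoid()
seg_logits = torch.einsum('bqc,bqhw->bchw', cls_score, mask_pred)
assert seg_logits.shape == (2, 19, 128, 128)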
|
||||
50
finetune/mmseg/models/decode_heads/nl_head.py
Normal file
50
finetune/mmseg/models/decode_heads/nl_head.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from mmcv.cnn import NonLocal2d
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .fcn_head import FCNHead
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class NLHead(FCNHead):
|
||||
"""Non-local Neural Networks.
|
||||
|
||||
This head is the implementation of `NLNet
|
||||
<https://arxiv.org/abs/1711.07971>`_.
|
||||
|
||||
Args:
|
||||
reduction (int): Reduction factor of projection transform. Default: 2.
|
||||
use_scale (bool): Whether to scale pairwise_weight by
|
||||
sqrt(1/inter_channels). Default: True.
|
||||
mode (str): The nonlocal mode. Options are 'embedded_gaussian',
|
||||
'dot_product'. Default: 'embedded_gaussian'.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
reduction=2,
|
||||
use_scale=True,
|
||||
mode='embedded_gaussian',
|
||||
**kwargs):
|
||||
super().__init__(num_convs=2, **kwargs)
|
||||
self.reduction = reduction
|
||||
self.use_scale = use_scale
|
||||
self.mode = mode
|
||||
self.nl_block = NonLocal2d(
|
||||
in_channels=self.channels,
|
||||
reduction=self.reduction,
|
||||
use_scale=self.use_scale,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
mode=self.mode)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
x = self._transform_inputs(inputs)
|
||||
output = self.convs[0](x)
|
||||
output = self.nl_block(output)
|
||||
output = self.convs[1](output)
|
||||
if self.concat_input:
|
||||
output = self.conv_cat(torch.cat([x, output], dim=1))
|
||||
output = self.cls_seg(output)
|
||||
return output
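A minimal usage sketch for the head above, assuming mmseg is importable and a single 2048-channel feature map; all sizes are illustrative, not taken from a config:

# Hedged usage sketch; the feature shape and class count are assumptions.
import torch
from mmseg.models.decode_heads import NLHead

head = NLHead(in_channels=2048, channels=512, num_classes=19)
feats = [torch.randn(1, 2048, 64, 64)]   # single-level input, picked by in_index=-1
out = head(feats)                        # (1, 19, 64, 64) per-pixel class logits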
|
||||
127
finetune/mmseg/models/decode_heads/ocr_head.py
Normal file
127
finetune/mmseg/models/decode_heads/ocr_head.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
|
||||
from ..utils import resize
|
||||
from .cascade_decode_head import BaseCascadeDecodeHead
|
||||
|
||||
|
||||
class SpatialGatherModule(nn.Module):
|
||||
"""Aggregate the context features according to the initial predicted
|
||||
probability distribution.
|
||||
|
||||
Employ the soft-weighted method to aggregate the context.
|
||||
"""
|
||||
|
||||
def __init__(self, scale):
|
||||
super().__init__()
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, feats, probs):
|
||||
"""Forward function."""
|
||||
batch_size, num_classes, height, width = probs.size()
|
||||
channels = feats.size(1)
|
||||
probs = probs.view(batch_size, num_classes, -1)
|
||||
feats = feats.view(batch_size, channels, -1)
|
||||
# [batch_size, height*width, channels]
feats = feats.permute(0, 2, 1)
# [batch_size, num_classes, height*width]
probs = F.softmax(self.scale * probs, dim=2)
# [batch_size, num_classes, channels]
ocr_context = torch.matmul(probs, feats)
|
||||
ocr_context = ocr_context.permute(0, 2, 1).contiguous().unsqueeze(3)
|
||||
return ocr_context
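The gather step pools one context vector per class, weighted by the coarse class probabilities. A plain-PyTorch shape walk-through with assumed sizes (19 classes, 512 channels, 64x64 maps):

# Hedged sketch re-creating the math of SpatialGatherModule.forward above.
import torch
import torch.nn.functional as F

feats = torch.randn(2, 512, 64, 64)    # (B, C, H, W) pixel features
probs = torch.randn(2, 19, 64, 64)     # (B, K, H, W) coarse class logits

probs = F.softmax(probs.flatten(2), dim=2)               # (B, K, H*W)
feats = feats.flatten(2).permute(0, 2, 1)                # (B, H*W, C)
ocr_context = torch.matmul(probs, feats)                 # (B, K, C)
ocr_context = ocr_context.permute(0, 2, 1).unsqueeze(3)  # (B, C, K, 1)
assert ocr_context.shape == (2, 512, 19, 1)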
|
||||
|
||||
|
||||
class ObjectAttentionBlock(_SelfAttentionBlock):
|
||||
"""Make a OCR used SelfAttentionBlock."""
|
||||
|
||||
def __init__(self, in_channels, channels, scale, conv_cfg, norm_cfg,
|
||||
act_cfg):
|
||||
if scale > 1:
|
||||
query_downsample = nn.MaxPool2d(kernel_size=scale)
|
||||
else:
|
||||
query_downsample = None
|
||||
super().__init__(
|
||||
key_in_channels=in_channels,
|
||||
query_in_channels=in_channels,
|
||||
channels=channels,
|
||||
out_channels=in_channels,
|
||||
share_key_query=False,
|
||||
query_downsample=query_downsample,
|
||||
key_downsample=None,
|
||||
key_query_num_convs=2,
|
||||
key_query_norm=True,
|
||||
value_out_num_convs=1,
|
||||
value_out_norm=True,
|
||||
matmul_norm=True,
|
||||
with_out=True,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
self.bottleneck = ConvModule(
|
||||
in_channels * 2,
|
||||
in_channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def forward(self, query_feats, key_feats):
|
||||
"""Forward function."""
|
||||
context = super().forward(query_feats, key_feats)
|
||||
output = self.bottleneck(torch.cat([context, query_feats], dim=1))
|
||||
if self.query_downsample is not None:
|
||||
output = resize(query_feats)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class OCRHead(BaseCascadeDecodeHead):
|
||||
"""Object-Contextual Representations for Semantic Segmentation.
|
||||
|
||||
This head is the implementation of `OCRNet
|
||||
<https://arxiv.org/abs/1909.11065>`_.
|
||||
|
||||
Args:
|
||||
ocr_channels (int): The intermediate channels of OCR block.
|
||||
scale (int): The scale of the probability map in SpatialGatherModule.
Default: 1.
|
||||
"""
|
||||
|
||||
def __init__(self, ocr_channels, scale=1, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.ocr_channels = ocr_channels
|
||||
self.scale = scale
|
||||
self.object_context_block = ObjectAttentionBlock(
|
||||
self.channels,
|
||||
self.ocr_channels,
|
||||
self.scale,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.spatial_gather_module = SpatialGatherModule(self.scale)
|
||||
|
||||
self.bottleneck = ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def forward(self, inputs, prev_output):
|
||||
"""Forward function."""
|
||||
x = self._transform_inputs(inputs)
|
||||
feats = self.bottleneck(x)
|
||||
context = self.spatial_gather_module(feats, prev_output)
|
||||
object_context = self.object_context_block(feats, context)
|
||||
output = self.cls_seg(object_context)
|
||||
|
||||
return output
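A minimal cascade sketch, assuming mmseg is importable: a plain FCN head provides the coarse prediction that OCRHead refines. All channel sizes below are assumptions, not values from a config in this diff:

# Hedged cascade sketch with assumed sizes.
import torch
from mmseg.models.decode_heads import FCNHead, OCRHead

feat = [torch.randn(1, 2048, 64, 64)]                 # assumed backbone output
coarse_head = FCNHead(in_channels=2048, channels=256, num_classes=19)
ocr_head = OCRHead(ocr_channels=256, in_channels=2048, channels=512,
                   num_classes=19)

prev_output = coarse_head(feat)                       # (1, 19, 64, 64) coarse logits
refined = ocr_head(feat, prev_output)                 # (1, 19, 64, 64) refined logits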
|
||||
183
finetune/mmseg/models/decode_heads/pid_head.py
Normal file
183
finetune/mmseg/models/decode_heads/pid_head.py
Normal file
@@ -0,0 +1,183 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
|
||||
from mmengine.model import BaseModule
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
from mmseg.models.losses import accuracy
|
||||
from mmseg.models.utils import resize
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import OptConfigType, SampleList
|
||||
|
||||
|
||||
class BasePIDHead(BaseModule):
|
||||
"""Base class for PID head.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
channels (int): Number of output channels.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Default: dict(type='ReLU', inplace=True).
|
||||
init_cfg (dict or list[dict], optional): Init config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
channels: int,
|
||||
norm_cfg: OptConfigType = dict(type='BN'),
|
||||
act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
|
||||
init_cfg: OptConfigType = None):
|
||||
super().__init__(init_cfg)
|
||||
self.conv = ConvModule(
|
||||
in_channels,
|
||||
channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
order=('norm', 'act', 'conv'))
|
||||
_, self.norm = build_norm_layer(norm_cfg, num_features=channels)
|
||||
self.act = build_activation_layer(act_cfg)
|
||||
|
||||
def forward(self, x: Tensor, cls_seg: Optional[nn.Module]) -> Tensor:
|
||||
"""Forward function.
|
||||
Args:
|
||||
x (Tensor): Input tensor.
|
||||
cls_seg (nn.Module, optional): The classification head.
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor.
|
||||
"""
|
||||
x = self.conv(x)
|
||||
x = self.norm(x)
|
||||
x = self.act(x)
|
||||
if cls_seg is not None:
|
||||
x = cls_seg(x)
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class PIDHead(BaseDecodeHead):
|
||||
"""Decode head for PIDNet.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
channels (int): Number of output channels.
|
||||
num_classes (int): Number of classes.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Default: dict(type='ReLU', inplace=True).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
channels: int,
|
||||
num_classes: int,
|
||||
norm_cfg: OptConfigType = dict(type='BN'),
|
||||
act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
|
||||
**kwargs):
|
||||
super().__init__(
|
||||
in_channels,
|
||||
channels,
|
||||
num_classes=num_classes,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
**kwargs)
|
||||
self.i_head = BasePIDHead(in_channels, channels, norm_cfg, act_cfg)
|
||||
self.p_head = BasePIDHead(in_channels // 2, channels, norm_cfg,
|
||||
act_cfg)
|
||||
self.d_head = BasePIDHead(
|
||||
in_channels // 2,
|
||||
in_channels // 4,
|
||||
norm_cfg,
|
||||
)
|
||||
self.p_cls_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
|
||||
self.d_cls_seg = nn.Conv2d(in_channels // 4, 1, kernel_size=1)
|
||||
|
||||
def init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: Union[Tensor,
|
||||
Tuple[Tensor]]) -> Union[Tensor, Tuple[Tensor]]:
|
||||
"""Forward function.
|
||||
Args:
|
||||
inputs (Tensor | tuple[Tensor]): Input tensor or tuple of
|
||||
Tensor. When training, the input is a tuple of three tensors,
|
||||
(p_feat, i_feat, d_feat), and the output is a tuple of three
|
||||
tensors, (p_seg_logit, i_seg_logit, d_seg_logit).
|
||||
At inference time, only the head of the integral branch is used;
the input is a tensor of the integral feature map, and the output
is the segmentation logit.
|
||||
|
||||
Returns:
|
||||
Tensor | tuple[Tensor]: Output tensor or tuple of tensors.
|
||||
"""
|
||||
if self.training:
|
||||
x_p, x_i, x_d = inputs
|
||||
x_p = self.p_head(x_p, self.p_cls_seg)
|
||||
x_i = self.i_head(x_i, self.cls_seg)
|
||||
x_d = self.d_head(x_d, self.d_cls_seg)
|
||||
return x_p, x_i, x_d
|
||||
else:
|
||||
return self.i_head(inputs, self.cls_seg)
|
||||
|
||||
def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tuple[Tensor]:
|
||||
gt_semantic_segs = [
|
||||
data_sample.gt_sem_seg.data for data_sample in batch_data_samples
|
||||
]
|
||||
gt_edge_segs = [
|
||||
data_sample.gt_edge_map.data for data_sample in batch_data_samples
|
||||
]
|
||||
gt_sem_segs = torch.stack(gt_semantic_segs, dim=0)
|
||||
gt_edge_segs = torch.stack(gt_edge_segs, dim=0)
|
||||
return gt_sem_segs, gt_edge_segs
|
||||
|
||||
def loss_by_feat(self, seg_logits: Tuple[Tensor],
|
||||
batch_data_samples: SampleList) -> dict:
|
||||
loss = dict()
|
||||
p_logit, i_logit, d_logit = seg_logits
|
||||
sem_label, bd_label = self._stack_batch_gt(batch_data_samples)
|
||||
p_logit = resize(
|
||||
input=p_logit,
|
||||
size=sem_label.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
i_logit = resize(
|
||||
input=i_logit,
|
||||
size=sem_label.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
d_logit = resize(
|
||||
input=d_logit,
|
||||
size=bd_label.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
sem_label = sem_label.squeeze(1)
|
||||
bd_label = bd_label.squeeze(1)
|
||||
loss['loss_sem_p'] = self.loss_decode[0](
|
||||
p_logit, sem_label, ignore_index=self.ignore_index)
|
||||
loss['loss_sem_i'] = self.loss_decode[1](i_logit, sem_label)
|
||||
loss['loss_bd'] = self.loss_decode[2](d_logit, bd_label)
|
||||
filler = torch.ones_like(sem_label) * self.ignore_index
|
||||
sem_bd_label = torch.where(
|
||||
torch.sigmoid(d_logit[:, 0, :, :]) > 0.8, sem_label, filler)
|
||||
loss['loss_sem_bd'] = self.loss_decode[3](i_logit, sem_bd_label)
|
||||
loss['acc_seg'] = accuracy(
|
||||
i_logit, sem_label, ignore_index=self.ignore_index)
|
||||
return loss
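The boundary-aware term keeps the semantic label only where the predicted boundary probability exceeds 0.8 and ignores everything else. A toy plain-PyTorch example (values are assumptions):

# Hedged toy example of the sem_bd_label construction above.
import torch

ignore_index = 255
sem_label = torch.tensor([[1, 1, 2, 2]])
d_logit = torch.tensor([[[3.0, -3.0, 3.0, -3.0]]])   # (B, 1, N) boundary logits

filler = torch.ones_like(sem_label) * ignore_index
sem_bd_label = torch.where(torch.sigmoid(d_logit[:, 0]) > 0.8, sem_label, filler)
# -> tensor([[  1, 255,   2, 255]]); only confident boundary pixels keep labels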
|
||||
367
finetune/mmseg/models/decode_heads/point_head.py
Normal file
367
finetune/mmseg/models/decode_heads/point_head.py
Normal file
@@ -0,0 +1,367 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py # noqa
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
try:
|
||||
from mmcv.ops import point_sample
|
||||
except ModuleNotFoundError:
|
||||
point_sample = None
|
||||
|
||||
from typing import List
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import SampleList
|
||||
from ..losses import accuracy
|
||||
from ..utils import resize
|
||||
from .cascade_decode_head import BaseCascadeDecodeHead
|
||||
|
||||
|
||||
def calculate_uncertainty(seg_logits):
|
||||
"""Estimate uncertainty based on seg logits.
|
||||
|
||||
For each location of the prediction ``seg_logits`` we estimate the
uncertainty as the difference between the top-2 and the top-1
predicted logits, so a margin close to zero means high uncertainty.
|
||||
|
||||
Args:
|
||||
seg_logits (Tensor): Semantic segmentation logits,
|
||||
shape (batch_size, num_classes, height, width).
|
||||
|
||||
Returns:
|
||||
scores (Tensor): Uncertainty scores, where the most uncertain
locations have the highest score, with shape
(batch_size, 1, height, width).
|
||||
"""
|
||||
top2_scores = torch.topk(seg_logits, k=2, dim=1)[0]
|
||||
return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)
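A quick numeric check of the uncertainty score above (toy logits, assumed values): the score is top-2 minus top-1, so it is never positive and is closest to zero where the two best classes are nearly tied:

# Hedged numeric check of calculate_uncertainty.
import torch

seg_logits = torch.tensor([[[[4.0]], [[3.9]], [[-1.0]]],    # sample A: 4.0 vs 3.9
                           [[[5.0]], [[0.0]], [[-2.0]]]])   # sample B: 5.0 vs 0.0
top2 = torch.topk(seg_logits, k=2, dim=1)[0]
uncertainty = (top2[:, 1] - top2[:, 0]).unsqueeze(1)
# -> tensor([[[[-0.1000]]], [[[-5.0000]]]]); sample A is the more uncertain one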
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class PointHead(BaseCascadeDecodeHead):
|
||||
"""A mask point head use in PointRend.
|
||||
|
||||
This head is implemented of `PointRend: Image Segmentation as
|
||||
Rendering <https://arxiv.org/abs/1912.08193>`_.
|
||||
``PointHead`` use shared multi-layer perceptron (equivalent to
|
||||
nn.Conv1d) to predict the logit of input points. The fine-grained feature
|
||||
and coarse feature will be concatenate together for predication.
|
||||
|
||||
Args:
|
||||
num_fcs (int): Number of fc layers in the head. Default: 3.
|
||||
in_channels (int): Number of input channels. Default: 256.
|
||||
fc_channels (int): Number of fc channels. Default: 256.
|
||||
num_classes (int): Number of classes for logits. Default: 80.
|
||||
class_agnostic (bool): Whether use class agnostic classification.
|
||||
If so, the output channels of logits will be 1. Default: False.
|
||||
coarse_pred_each_layer (bool): Whether concatenate coarse feature with
|
||||
the output of each fc layer. Default: True.
|
||||
conv_cfg (dict|None): Dictionary to construct and config conv layer.
|
||||
Default: dict(type='Conv1d'))
|
||||
norm_cfg (dict|None): Dictionary to construct and config norm layer.
|
||||
Default: None.
|
||||
loss_point (dict): Dictionary to construct and config loss layer of
|
||||
point head. Default: dict(type='CrossEntropyLoss', use_mask=True,
|
||||
loss_weight=1.0).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_fcs=3,
|
||||
coarse_pred_each_layer=True,
|
||||
conv_cfg=dict(type='Conv1d'),
|
||||
norm_cfg=None,
|
||||
act_cfg=dict(type='ReLU', inplace=False),
|
||||
**kwargs):
|
||||
super().__init__(
|
||||
input_transform='multiple_select',
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
init_cfg=dict(
|
||||
type='Normal', std=0.01, override=dict(name='fc_seg')),
|
||||
**kwargs)
|
||||
if point_sample is None:
|
||||
raise RuntimeError('Please install mmcv-full for '
|
||||
'point_sample ops')
|
||||
|
||||
self.num_fcs = num_fcs
|
||||
self.coarse_pred_each_layer = coarse_pred_each_layer
|
||||
|
||||
fc_in_channels = sum(self.in_channels) + self.num_classes
|
||||
fc_channels = self.channels
|
||||
self.fcs = nn.ModuleList()
|
||||
for k in range(num_fcs):
|
||||
fc = ConvModule(
|
||||
fc_in_channels,
|
||||
fc_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
self.fcs.append(fc)
|
||||
fc_in_channels = fc_channels
|
||||
fc_in_channels += self.num_classes if self.coarse_pred_each_layer \
|
||||
else 0
|
||||
self.fc_seg = nn.Conv1d(
|
||||
fc_in_channels,
|
||||
self.num_classes,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
if self.dropout_ratio > 0:
|
||||
self.dropout = nn.Dropout(self.dropout_ratio)
|
||||
delattr(self, 'conv_seg')
|
||||
|
||||
def cls_seg(self, feat):
|
||||
"""Classify each pixel with fc."""
|
||||
if self.dropout is not None:
|
||||
feat = self.dropout(feat)
|
||||
output = self.fc_seg(feat)
|
||||
return output
|
||||
|
||||
def forward(self, fine_grained_point_feats, coarse_point_feats):
|
||||
x = torch.cat([fine_grained_point_feats, coarse_point_feats], dim=1)
|
||||
for fc in self.fcs:
|
||||
x = fc(x)
|
||||
if self.coarse_pred_each_layer:
|
||||
x = torch.cat((x, coarse_point_feats), dim=1)
|
||||
return self.cls_seg(x)
|
||||
|
||||
def _get_fine_grained_point_feats(self, x, points):
|
||||
"""Sample from fine grained features.
|
||||
|
||||
Args:
|
||||
x (list[Tensor]): Feature pyramid from by neck or backbone.
|
||||
points (Tensor): Point coordinates, shape (batch_size,
|
||||
num_points, 2).
|
||||
|
||||
Returns:
|
||||
fine_grained_feats (Tensor): Sampled fine grained feature,
|
||||
shape (batch_size, sum(channels of x), num_points).
|
||||
"""
|
||||
|
||||
fine_grained_feats_list = [
|
||||
point_sample(_, points, align_corners=self.align_corners)
|
||||
for _ in x
|
||||
]
|
||||
if len(fine_grained_feats_list) > 1:
|
||||
fine_grained_feats = torch.cat(fine_grained_feats_list, dim=1)
|
||||
else:
|
||||
fine_grained_feats = fine_grained_feats_list[0]
|
||||
|
||||
return fine_grained_feats
|
||||
|
||||
def _get_coarse_point_feats(self, prev_output, points):
|
||||
"""Sample from fine grained features.
|
||||
|
||||
Args:
|
||||
prev_output (list[Tensor]): Prediction of previous decode head.
|
||||
points (Tensor): Point coordinates, shape (batch_size,
|
||||
num_points, 2).
|
||||
|
||||
Returns:
|
||||
coarse_feats (Tensor): Sampled coarse feature, shape (batch_size,
|
||||
num_classes, num_points).
|
||||
"""
|
||||
|
||||
coarse_feats = point_sample(
|
||||
prev_output, points, align_corners=self.align_corners)
|
||||
|
||||
return coarse_feats
|
||||
|
||||
def loss(self, inputs, prev_output, batch_data_samples: SampleList,
|
||||
train_cfg, **kwargs):
|
||||
"""Forward function for training.
|
||||
Args:
|
||||
inputs (list[Tensor]): List of multi-level img features.
|
||||
prev_output (Tensor): The output of previous decode head.
|
||||
batch_data_samples (list[:obj:`SegDataSample`]): The seg
|
||||
data samples. It usually includes information such
|
||||
as `img_metas` or `gt_semantic_seg`.
|
||||
train_cfg (dict): The training config.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: a dictionary of loss components
|
||||
"""
|
||||
x = self._transform_inputs(inputs)
|
||||
with torch.no_grad():
|
||||
points = self.get_points_train(
|
||||
prev_output, calculate_uncertainty, cfg=train_cfg)
|
||||
fine_grained_point_feats = self._get_fine_grained_point_feats(
|
||||
x, points)
|
||||
coarse_point_feats = self._get_coarse_point_feats(prev_output, points)
|
||||
point_logits = self.forward(fine_grained_point_feats,
|
||||
coarse_point_feats)
|
||||
|
||||
losses = self.loss_by_feat(point_logits, points, batch_data_samples)
|
||||
|
||||
return losses
|
||||
|
||||
def predict(self, inputs, prev_output, batch_img_metas: List[dict],
|
||||
test_cfg, **kwargs):
|
||||
"""Forward function for testing.
|
||||
|
||||
Args:
|
||||
inputs (list[Tensor]): List of multi-level img features.
|
||||
prev_output (Tensor): The output of previous decode head.
|
||||
batch_img_metas (list[dict]): List of image info dicts where each dict
|
||||
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
||||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
||||
For details on the values of these keys see
|
||||
`mmseg/datasets/pipelines/formatting.py:Collect`.
|
||||
test_cfg (dict): The testing config.
|
||||
|
||||
Returns:
|
||||
Tensor: Output segmentation map.
|
||||
"""
|
||||
|
||||
x = self._transform_inputs(inputs)
|
||||
refined_seg_logits = prev_output.clone()
|
||||
for _ in range(test_cfg.subdivision_steps):
|
||||
refined_seg_logits = resize(
|
||||
refined_seg_logits,
|
||||
scale_factor=test_cfg.scale_factor,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
batch_size, channels, height, width = refined_seg_logits.shape
|
||||
point_indices, points = self.get_points_test(
|
||||
refined_seg_logits, calculate_uncertainty, cfg=test_cfg)
|
||||
fine_grained_point_feats = self._get_fine_grained_point_feats(
|
||||
x, points)
|
||||
coarse_point_feats = self._get_coarse_point_feats(
|
||||
prev_output, points)
|
||||
point_logits = self.forward(fine_grained_point_feats,
|
||||
coarse_point_feats)
|
||||
|
||||
point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1)
|
||||
refined_seg_logits = refined_seg_logits.reshape(
|
||||
batch_size, channels, height * width)
|
||||
refined_seg_logits = refined_seg_logits.scatter_(
|
||||
2, point_indices, point_logits)
|
||||
refined_seg_logits = refined_seg_logits.view(
|
||||
batch_size, channels, height, width)
|
||||
|
||||
return self.predict_by_feat(refined_seg_logits, batch_img_metas,
|
||||
**kwargs)
|
||||
|
||||
def loss_by_feat(self, point_logits, points, batch_data_samples, **kwargs):
|
||||
"""Compute segmentation loss."""
|
||||
gt_semantic_seg = self._stack_batch_gt(batch_data_samples)
|
||||
point_label = point_sample(
|
||||
gt_semantic_seg.float(),
|
||||
points,
|
||||
mode='nearest',
|
||||
align_corners=self.align_corners)
|
||||
point_label = point_label.squeeze(1).long()
|
||||
|
||||
loss = dict()
|
||||
if not isinstance(self.loss_decode, nn.ModuleList):
|
||||
losses_decode = [self.loss_decode]
|
||||
else:
|
||||
losses_decode = self.loss_decode
|
||||
for loss_module in losses_decode:
|
||||
loss['point' + loss_module.loss_name] = loss_module(
|
||||
point_logits, point_label, ignore_index=self.ignore_index)
|
||||
|
||||
loss['acc_point'] = accuracy(
|
||||
point_logits, point_label, ignore_index=self.ignore_index)
|
||||
return loss
|
||||
|
||||
def get_points_train(self, seg_logits, uncertainty_func, cfg):
|
||||
"""Sample points for training.
|
||||
|
||||
Sample points in [0, 1] x [0, 1] coordinate space based on their
|
||||
uncertainty. The uncertainties are calculated for each point using
|
||||
'uncertainty_func' function that takes point's logit prediction as
|
||||
input.
|
||||
|
||||
Args:
|
||||
seg_logits (Tensor): Semantic segmentation logits, shape (
|
||||
batch_size, num_classes, height, width).
|
||||
uncertainty_func (func): uncertainty calculation function.
|
||||
cfg (dict): Training config of point head.
|
||||
|
||||
Returns:
|
||||
point_coords (Tensor): A tensor of shape (batch_size, num_points,
|
||||
2) that contains the coordinates of ``num_points`` sampled
|
||||
points.
|
||||
"""
|
||||
num_points = cfg.num_points
|
||||
oversample_ratio = cfg.oversample_ratio
|
||||
importance_sample_ratio = cfg.importance_sample_ratio
|
||||
assert oversample_ratio >= 1
|
||||
assert 0 <= importance_sample_ratio <= 1
|
||||
batch_size = seg_logits.shape[0]
|
||||
num_sampled = int(num_points * oversample_ratio)
|
||||
point_coords = torch.rand(
|
||||
batch_size, num_sampled, 2, device=seg_logits.device)
|
||||
point_logits = point_sample(seg_logits, point_coords)
|
||||
# It is crucial to calculate uncertainty based on the sampled
|
||||
# prediction value for the points. Calculating uncertainties of the
|
||||
# coarse predictions first and sampling them for points leads to
|
||||
# incorrect results. To illustrate this: assume uncertainty func(
|
||||
# logits)=-abs(logits), a sampled point between two coarse
|
||||
# predictions with -1 and 1 logits has 0 logits, and therefore 0
|
||||
# uncertainty value. However, if we calculate uncertainties for the
|
||||
# coarse predictions first, both will have -1 uncertainty,
|
||||
# and sampled point will get -1 uncertainty.
|
||||
point_uncertainties = uncertainty_func(point_logits)
|
||||
num_uncertain_points = int(importance_sample_ratio * num_points)
|
||||
num_random_points = num_points - num_uncertain_points
|
||||
idx = torch.topk(
|
||||
point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
|
||||
shift = num_sampled * torch.arange(
|
||||
batch_size, dtype=torch.long, device=seg_logits.device)
|
||||
idx += shift[:, None]
|
||||
point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
|
||||
batch_size, num_uncertain_points, 2)
|
||||
if num_random_points > 0:
|
||||
rand_point_coords = torch.rand(
|
||||
batch_size, num_random_points, 2, device=seg_logits.device)
|
||||
point_coords = torch.cat((point_coords, rand_point_coords), dim=1)
|
||||
return point_coords
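The sampling budget splits into an oversampled candidate pool, an importance-sampled subset, and a uniform remainder. The arithmetic with commonly used PointRend-style numbers (assumed, not read from a config):

# Hedged arithmetic for the train-time point sampling above.
num_points = 2048
oversample_ratio = 3
importance_sample_ratio = 0.75

num_sampled = int(num_points * oversample_ratio)                  # 6144 candidates
num_uncertain_points = int(importance_sample_ratio * num_points)  # 1536 picked by uncertainty
num_random_points = num_points - num_uncertain_points             # 512 drawn uniformly
assert num_uncertain_points + num_random_points == num_points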
|
||||
|
||||
def get_points_test(self, seg_logits, uncertainty_func, cfg):
|
||||
"""Sample points for testing.
|
||||
|
||||
Find ``num_points`` most uncertain points from ``uncertainty_map``.
|
||||
|
||||
Args:
|
||||
seg_logits (Tensor): A tensor of shape (batch_size, num_classes,
|
||||
height, width) for class-specific or class-agnostic prediction.
|
||||
uncertainty_func (func): uncertainty calculation function.
|
||||
cfg (dict): Testing config of point head.
|
||||
|
||||
Returns:
|
||||
point_indices (Tensor): A tensor of shape (batch_size, num_points)
|
||||
that contains indices from [0, height x width) of the most
|
||||
uncertain points.
|
||||
point_coords (Tensor): A tensor of shape (batch_size, num_points,
|
||||
2) that contains [0, 1] x [0, 1] normalized coordinates of the
|
||||
most uncertain points from the ``height x width`` grid .
|
||||
"""
|
||||
|
||||
num_points = cfg.subdivision_num_points
|
||||
uncertainty_map = uncertainty_func(seg_logits)
|
||||
batch_size, _, height, width = uncertainty_map.shape
|
||||
h_step = 1.0 / height
|
||||
w_step = 1.0 / width
|
||||
|
||||
uncertainty_map = uncertainty_map.view(batch_size, height * width)
|
||||
num_points = min(height * width, num_points)
|
||||
point_indices = uncertainty_map.topk(num_points, dim=1)[1]
|
||||
point_coords = torch.zeros(
|
||||
batch_size,
|
||||
num_points,
|
||||
2,
|
||||
dtype=torch.float,
|
||||
device=seg_logits.device)
|
||||
point_coords[:, :, 0] = w_step / 2.0 + (point_indices %
|
||||
width).float() * w_step
|
||||
point_coords[:, :, 1] = h_step / 2.0 + (point_indices //
|
||||
width).float() * h_step
|
||||
return point_indices, point_coords
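Each selected flat index is converted to the normalized center coordinate of its grid cell. A toy example on an assumed 4x4 grid:

# Hedged sketch of the index-to-coordinate conversion above.
import torch

height, width = 4, 4
h_step, w_step = 1.0 / height, 1.0 / width
point_indices = torch.tensor([[6]])                             # (batch, num_points)

x = w_step / 2.0 + (point_indices % width).float() * w_step     # 0.125 + 2 * 0.25 = 0.625
y = h_step / 2.0 + (point_indices // width).float() * h_step    # 0.125 + 1 * 0.25 = 0.375
point_coords = torch.stack([x, y], dim=-1)                      # (1, 1, 2) -> [[[0.625, 0.375]]]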
|
||||
197
finetune/mmseg/models/decode_heads/psa_head.py
Normal file
197
finetune/mmseg/models/decode_heads/psa_head.py
Normal file
@@ -0,0 +1,197 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
try:
|
||||
from mmcv.ops import PSAMask
|
||||
except ModuleNotFoundError:
|
||||
PSAMask = None
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class PSAHead(BaseDecodeHead):
|
||||
"""Point-wise Spatial Attention Network for Scene Parsing.
|
||||
|
||||
This head is the implementation of `PSANet
|
||||
<https://hszhao.github.io/papers/eccv18_psanet.pdf>`_.
|
||||
|
||||
Args:
|
||||
mask_size (tuple[int]): The PSA mask size. It usually equals input
|
||||
size.
|
||||
psa_type (str): The type of psa module. Options are 'collect',
|
||||
'distribute', 'bi-direction'. Default: 'bi-direction'
|
||||
compact (bool): Whether to use a compact map for 'collect' mode.
Default: False.
|
||||
shrink_factor (int): The downsample factors of psa mask. Default: 2.
|
||||
normalization_factor (float): The normalization factor of the attention.
psa_softmax (bool): Whether to use softmax for the attention.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
mask_size,
|
||||
psa_type='bi-direction',
|
||||
compact=False,
|
||||
shrink_factor=2,
|
||||
normalization_factor=1.0,
|
||||
psa_softmax=True,
|
||||
**kwargs):
|
||||
if PSAMask is None:
|
||||
raise RuntimeError('Please install mmcv-full for PSAMask ops')
|
||||
super().__init__(**kwargs)
|
||||
assert psa_type in ['collect', 'distribute', 'bi-direction']
|
||||
self.psa_type = psa_type
|
||||
self.compact = compact
|
||||
self.shrink_factor = shrink_factor
|
||||
self.mask_size = mask_size
|
||||
mask_h, mask_w = mask_size
|
||||
self.psa_softmax = psa_softmax
|
||||
if normalization_factor is None:
|
||||
normalization_factor = mask_h * mask_w
|
||||
self.normalization_factor = normalization_factor
|
||||
|
||||
self.reduce = ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
kernel_size=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.attention = nn.Sequential(
|
||||
ConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
kernel_size=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg),
|
||||
nn.Conv2d(
|
||||
self.channels, mask_h * mask_w, kernel_size=1, bias=False))
|
||||
if psa_type == 'bi-direction':
|
||||
self.reduce_p = ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
kernel_size=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.attention_p = nn.Sequential(
|
||||
ConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
kernel_size=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg),
|
||||
nn.Conv2d(
|
||||
self.channels, mask_h * mask_w, kernel_size=1, bias=False))
|
||||
self.psamask_collect = PSAMask('collect', mask_size)
|
||||
self.psamask_distribute = PSAMask('distribute', mask_size)
|
||||
else:
|
||||
self.psamask = PSAMask(psa_type, mask_size)
|
||||
self.proj = ConvModule(
|
||||
self.channels * (2 if psa_type == 'bi-direction' else 1),
|
||||
self.in_channels,
|
||||
kernel_size=1,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.bottleneck = ConvModule(
|
||||
self.in_channels * 2,
|
||||
self.channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
x = self._transform_inputs(inputs)
|
||||
identity = x
|
||||
align_corners = self.align_corners
|
||||
if self.psa_type in ['collect', 'distribute']:
|
||||
out = self.reduce(x)
|
||||
n, c, h, w = out.size()
|
||||
if self.shrink_factor != 1:
|
||||
if h % self.shrink_factor and w % self.shrink_factor:
|
||||
h = (h - 1) // self.shrink_factor + 1
|
||||
w = (w - 1) // self.shrink_factor + 1
|
||||
align_corners = True
|
||||
else:
|
||||
h = h // self.shrink_factor
|
||||
w = w // self.shrink_factor
|
||||
align_corners = False
|
||||
out = resize(
|
||||
out,
|
||||
size=(h, w),
|
||||
mode='bilinear',
|
||||
align_corners=align_corners)
|
||||
y = self.attention(out)
|
||||
if self.compact:
|
||||
if self.psa_type == 'collect':
|
||||
y = y.view(n, h * w,
|
||||
h * w).transpose(1, 2).view(n, h * w, h, w)
|
||||
else:
|
||||
y = self.psamask(y)
|
||||
if self.psa_softmax:
|
||||
y = F.softmax(y, dim=1)
|
||||
out = torch.bmm(
|
||||
out.view(n, c, h * w), y.view(n, h * w, h * w)).view(
|
||||
n, c, h, w) * (1.0 / self.normalization_factor)
|
||||
else:
|
||||
x_col = self.reduce(x)
|
||||
x_dis = self.reduce_p(x)
|
||||
n, c, h, w = x_col.size()
|
||||
if self.shrink_factor != 1:
|
||||
if h % self.shrink_factor and w % self.shrink_factor:
|
||||
h = (h - 1) // self.shrink_factor + 1
|
||||
w = (w - 1) // self.shrink_factor + 1
|
||||
align_corners = True
|
||||
else:
|
||||
h = h // self.shrink_factor
|
||||
w = w // self.shrink_factor
|
||||
align_corners = False
|
||||
x_col = resize(
|
||||
x_col,
|
||||
size=(h, w),
|
||||
mode='bilinear',
|
||||
align_corners=align_corners)
|
||||
x_dis = resize(
|
||||
x_dis,
|
||||
size=(h, w),
|
||||
mode='bilinear',
|
||||
align_corners=align_corners)
|
||||
y_col = self.attention(x_col)
|
||||
y_dis = self.attention_p(x_dis)
|
||||
if self.compact:
|
||||
y_dis = y_dis.view(n, h * w,
|
||||
h * w).transpose(1, 2).view(n, h * w, h, w)
|
||||
else:
|
||||
y_col = self.psamask_collect(y_col)
|
||||
y_dis = self.psamask_distribute(y_dis)
|
||||
if self.psa_softmax:
|
||||
y_col = F.softmax(y_col, dim=1)
|
||||
y_dis = F.softmax(y_dis, dim=1)
|
||||
x_col = torch.bmm(
|
||||
x_col.view(n, c, h * w), y_col.view(n, h * w, h * w)).view(
|
||||
n, c, h, w) * (1.0 / self.normalization_factor)
|
||||
x_dis = torch.bmm(
|
||||
x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w)).view(
|
||||
n, c, h, w) * (1.0 / self.normalization_factor)
|
||||
out = torch.cat([x_col, x_dis], 1)
|
||||
out = self.proj(out)
|
||||
out = resize(
|
||||
out,
|
||||
size=identity.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=align_corners)
|
||||
out = self.bottleneck(torch.cat((identity, out), dim=1))
|
||||
out = self.cls_seg(out)
|
||||
return out
|
||||
117
finetune/mmseg/models/decode_heads/psp_head.py
Normal file
117
finetune/mmseg/models/decode_heads/psp_head.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class PPM(nn.ModuleList):
|
||||
"""Pooling Pyramid Module used in PSPNet.
|
||||
|
||||
Args:
|
||||
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
||||
Module.
|
||||
in_channels (int): Input channels.
|
||||
channels (int): Channels after modules, before conv_seg.
|
||||
conv_cfg (dict|None): Config of conv layers.
|
||||
norm_cfg (dict|None): Config of norm layers.
|
||||
act_cfg (dict): Config of activation layers.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
"""
|
||||
|
||||
def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
|
||||
act_cfg, align_corners, **kwargs):
|
||||
super().__init__()
|
||||
self.pool_scales = pool_scales
|
||||
self.align_corners = align_corners
|
||||
self.in_channels = in_channels
|
||||
self.channels = channels
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
for pool_scale in pool_scales:
|
||||
self.append(
|
||||
nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d(pool_scale),
|
||||
ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
**kwargs)))
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
ppm_outs = []
|
||||
for ppm in self:
|
||||
ppm_out = ppm(x)
|
||||
upsampled_ppm_out = resize(
|
||||
ppm_out,
|
||||
size=x.size()[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
ppm_outs.append(upsampled_ppm_out)
|
||||
return ppm_outs
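Each PPM branch pools to s x s, projects with a 1x1 conv and upsamples back to the input resolution. A plain-PyTorch sketch with assumed sizes (the 1x1 conv below stands in for the ConvModule):

# Hedged shape sketch of one PPM pass; sizes are assumptions.
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(2, 2048, 64, 64)
outs = []
for s in (1, 2, 3, 6):                                    # assumed pool scales
    pooled = F.adaptive_avg_pool2d(x, s)                  # (2, 2048, s, s)
    proj = nn.Conv2d(2048, 512, 1)(pooled)                # stand-in for the ConvModule
    outs.append(F.interpolate(proj, size=x.shape[2:], mode='bilinear',
                              align_corners=False))       # back to (2, 512, 64, 64)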
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class PSPHead(BaseDecodeHead):
|
||||
"""Pyramid Scene Parsing Network.
|
||||
|
||||
This head is the implementation of
|
||||
`PSPNet <https://arxiv.org/abs/1612.01105>`_.
|
||||
|
||||
Args:
|
||||
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
||||
Module. Default: (1, 2, 3, 6).
|
||||
"""
|
||||
|
||||
def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
assert isinstance(pool_scales, (list, tuple))
|
||||
self.pool_scales = pool_scales
|
||||
self.psp_modules = PPM(
|
||||
self.pool_scales,
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
align_corners=self.align_corners)
|
||||
self.bottleneck = ConvModule(
|
||||
self.in_channels + len(pool_scales) * self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def _forward_feature(self, inputs):
|
||||
"""Forward function for feature maps before classifying each pixel with
|
||||
``self.cls_seg`` fc.
|
||||
|
||||
Args:
|
||||
inputs (list[Tensor]): List of multi-level img features.
|
||||
|
||||
Returns:
|
||||
feats (Tensor): A tensor of shape (batch_size, self.channels,
|
||||
H, W) which is feature map for last layer of decoder head.
|
||||
"""
|
||||
x = self._transform_inputs(inputs)
|
||||
psp_outs = [x]
|
||||
psp_outs.extend(self.psp_modules(x))
|
||||
psp_outs = torch.cat(psp_outs, dim=1)
|
||||
feats = self.bottleneck(psp_outs)
|
||||
return feats
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
output = self._forward_feature(inputs)
|
||||
output = self.cls_seg(output)
|
||||
return output
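A minimal usage sketch, assuming mmseg is importable; the channel sizes are assumptions in the spirit of common PSPNet settings, not values from this diff:

# Hedged usage sketch; the bottleneck then sees 2048 + 4 * 512 = 4096 channels.
import torch
from mmseg.models.decode_heads import PSPHead

head = PSPHead(in_channels=2048, channels=512, num_classes=19,
               pool_scales=(1, 2, 3, 6))
out = head([torch.randn(1, 2048, 64, 64)])   # (1, 19, 64, 64) per-pixel class logits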
|
||||
736
finetune/mmseg/models/decode_heads/san_head.py
Normal file
736
finetune/mmseg/models/decode_heads/san_head.py
Normal file
@@ -0,0 +1,736 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from functools import partial
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, build_norm_layer
|
||||
from mmcv.cnn.bricks.transformer import BaseTransformerLayer
|
||||
from mmcv.ops import point_sample
|
||||
from mmengine.dist import all_reduce
|
||||
from mmengine.model.weight_init import (caffe2_xavier_init, normal_init,
|
||||
trunc_normal_)
|
||||
from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
|
||||
from mmengine.structures import InstanceData
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
|
||||
from mmseg.models.backbones.vit import TransformerEncoderLayer
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import (ConfigType, MatchMasks, SampleList,
|
||||
seg_data_to_instance_data)
|
||||
from ..utils import (MLP, LayerNorm2d, PatchEmbed, cross_attn_layer,
|
||||
get_uncertain_point_coords_with_randomness, resize)
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class MLPMaskDecoder(nn.Module):
|
||||
"""Module for decoding query and visual features with MLP layers to
|
||||
generate the attention biases and the mask proposals."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
in_channels: int,
|
||||
total_heads: int = 1,
|
||||
total_layers: int = 1,
|
||||
embed_channels: int = 256,
|
||||
mlp_channels: int = 256,
|
||||
mlp_num_layers: int = 3,
|
||||
rescale_attn_bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.total_heads = total_heads
|
||||
self.total_layers = total_layers
|
||||
|
||||
dense_affine_func = partial(nn.Conv2d, kernel_size=1)
|
||||
# Query Branch
|
||||
self.query_mlp = MLP(in_channels, mlp_channels, embed_channels,
|
||||
mlp_num_layers)
|
||||
# Pixel Branch
|
||||
self.pix_mlp = MLP(
|
||||
in_channels,
|
||||
mlp_channels,
|
||||
embed_channels,
|
||||
mlp_num_layers,
|
||||
affine_func=dense_affine_func,
|
||||
)
|
||||
# Attention Bias Branch
|
||||
self.attn_mlp = MLP(
|
||||
in_channels,
|
||||
mlp_channels,
|
||||
embed_channels * self.total_heads * self.total_layers,
|
||||
mlp_num_layers,
|
||||
affine_func=dense_affine_func,
|
||||
)
|
||||
if rescale_attn_bias:
|
||||
self.bias_scaling = nn.Linear(1, 1)
|
||||
else:
|
||||
self.bias_scaling = nn.Identity()
|
||||
|
||||
def forward(self, query: torch.Tensor,
|
||||
x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
"""Forward function.
|
||||
Args:
|
||||
query (Tensor): Query Tokens [B,N,C].
|
||||
x (Tensor): Visual features [B,C,H,W]
|
||||
|
||||
Return:
|
||||
mask_preds (Tensor): Mask proposals.
|
||||
attn_bias (List[Tensor]): List of attention bias.
|
||||
"""
|
||||
query = self.query_mlp(query)
|
||||
pix = self.pix_mlp(x)
|
||||
b, c, h, w = pix.shape
|
||||
# predict mask
|
||||
mask_preds = torch.einsum('bqc,bchw->bqhw', query, pix)
|
||||
# generate attn bias
|
||||
attn = self.attn_mlp(x)
|
||||
attn = attn.reshape(b, self.total_layers, self.total_heads, c, h, w)
|
||||
attn_bias = torch.einsum('bqc,blnchw->blnqhw', query, attn)
|
||||
attn_bias = self.bias_scaling(attn_bias[..., None]).squeeze(-1)
|
||||
attn_bias = attn_bias.chunk(self.total_layers, dim=1)
|
||||
attn_bias = [attn.squeeze(1) for attn in attn_bias]
|
||||
return mask_preds, attn_bias
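A shape walk-through of the two einsums above with assumed toy sizes (2 images, 100 queries, 256 embed channels, 1 decoder layer, 12 heads, 40x40 features):

# Hedged shape sketch of the mask and attention-bias einsums.
import torch

query = torch.randn(2, 100, 256)                  # (B, N, C) after query_mlp
pix = torch.randn(2, 256, 40, 40)                 # (B, C, H, W) after pix_mlp
attn = torch.randn(2, 1, 12, 256, 40, 40)         # (B, layers, heads, C, H, W) after attn_mlp

mask_preds = torch.einsum('bqc,bchw->bqhw', query, pix)       # (2, 100, 40, 40)
attn_bias = torch.einsum('bqc,blnchw->blnqhw', query, attn)   # (2, 1, 12, 100, 40, 40)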
|
||||
|
||||
|
||||
class SideAdapterNetwork(nn.Module):
|
||||
"""Side Adapter Network for predicting mask proposals and attention bias.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels. Default: 3.
|
||||
clip_channels (int): Number of channels of visual features.
|
||||
Default: 768.
|
||||
embed_dims (int): embedding dimension. Default: 240.
|
||||
patch_size (int): The patch size. Default: 16.
|
||||
patch_bias (bool): Whether use bias in patch embedding.
|
||||
Default: True.
|
||||
num_queries (int): Number of queries for mask proposals.
|
||||
Default: 100.
|
||||
fusion_index (List[int]): The layer number of the encode
|
||||
transformer to fuse with the CLIP feature.
|
||||
Default: [0, 1, 2, 3].
|
||||
cfg_encoder (ConfigType): Configs for the encode layers.
|
||||
cfg_decoder (ConfigType): Configs for the decode layers.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN').
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
clip_channels: int = 768,
|
||||
embed_dims: int = 240,
|
||||
patch_size: int = 16,
|
||||
patch_bias: bool = True,
|
||||
num_queries: int = 100,
|
||||
fusion_index: list = [0, 1, 2, 3],
|
||||
cfg_encoder: ConfigType = ...,
|
||||
cfg_decoder: ConfigType = ...,
|
||||
norm_cfg: dict = dict(type='LN'),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
in_channels=in_channels,
|
||||
embed_dims=embed_dims,
|
||||
conv_type='Conv2d',
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
padding=0,
|
||||
input_size=(640, 640),
|
||||
bias=patch_bias,
|
||||
norm_cfg=None,
|
||||
init_cfg=None,
|
||||
)
|
||||
ori_h, ori_w = self.patch_embed.init_out_size
|
||||
num_patches = ori_h * ori_w
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.randn(1, num_patches, embed_dims) * .02)
|
||||
self.query_pos_embed = nn.Parameter(
|
||||
torch.zeros(1, num_queries, embed_dims))
|
||||
self.query_embed = nn.Parameter(
|
||||
torch.zeros(1, num_queries, embed_dims))
|
||||
encode_layers = []
|
||||
for i in range(cfg_encoder.num_encode_layer):
|
||||
encode_layers.append(
|
||||
TransformerEncoderLayer(
|
||||
embed_dims=embed_dims,
|
||||
num_heads=cfg_encoder.num_heads,
|
||||
feedforward_channels=cfg_encoder.mlp_ratio * embed_dims,
|
||||
norm_cfg=norm_cfg))
|
||||
self.encode_layers = nn.ModuleList(encode_layers)
|
||||
conv_clips = []
|
||||
for i in range(len(fusion_index)):
|
||||
conv_clips.append(
|
||||
nn.Sequential(
|
||||
LayerNorm2d(clip_channels),
|
||||
ConvModule(
|
||||
clip_channels,
|
||||
embed_dims,
|
||||
kernel_size=1,
|
||||
norm_cfg=None,
|
||||
act_cfg=None)))
|
||||
self.conv_clips = nn.ModuleList(conv_clips)
|
||||
self.fusion_index = fusion_index
|
||||
self.mask_decoder = MLPMaskDecoder(
|
||||
in_channels=embed_dims,
|
||||
total_heads=cfg_decoder.num_heads,
|
||||
total_layers=cfg_decoder.num_layers,
|
||||
embed_channels=cfg_decoder.embed_channels,
|
||||
mlp_channels=cfg_decoder.mlp_channels,
|
||||
mlp_num_layers=cfg_decoder.num_mlp,
|
||||
rescale_attn_bias=cfg_decoder.rescale)
|
||||
|
||||
def init_weights(self):
|
||||
trunc_normal_(self.pos_embed, std=0.02)
|
||||
nn.init.normal_(self.query_embed, std=0.02)
|
||||
nn.init.normal_(self.query_pos_embed, std=0.02)
|
||||
for i in range(len(self.conv_clips)):
|
||||
caffe2_xavier_init(self.conv_clips[i][1].conv)
|
||||
|
||||
def fuse_clip(self, fused_index: int, x: torch.Tensor,
|
||||
clip_feature: torch.Tensor, hwshape: Tuple[int,
|
||||
int], L: int):
|
||||
"""Fuse CLIP feature and visual tokens."""
|
||||
fused_clip = (resize(
|
||||
self.conv_clips[fused_index](clip_feature.contiguous()),
|
||||
size=hwshape,
|
||||
mode='bilinear',
|
||||
align_corners=False)).permute(0, 2, 3, 1).reshape(x[:, -L:,
|
||||
...].shape)
|
||||
x = torch.cat([x[:, :-L, ...], x[:, -L:, ...] + fused_clip], dim=1)
|
||||
return x
|
||||
|
||||
def encode_feature(self, image: torch.Tensor,
|
||||
clip_features: List[torch.Tensor],
|
||||
deep_supervision_idxs: List[int]) -> List[List]:
|
||||
"""Encode images by a lightweight vision transformer."""
|
||||
assert len(self.fusion_index) == len(clip_features)
|
||||
x, hwshape = self.patch_embed(image)
|
||||
ori_h, ori_w = self.patch_embed.init_out_size
|
||||
pos_embed = self.pos_embed
|
||||
if self.pos_embed.shape[1] != x.shape[1]:
|
||||
# resize the position embedding
|
||||
pos_embed = (
|
||||
resize(
|
||||
self.pos_embed.reshape(1, ori_h, ori_w,
|
||||
-1).permute(0, 3, 1, 2),
|
||||
size=hwshape,
|
||||
mode='bicubic',
|
||||
align_corners=False,
|
||||
).flatten(2).permute(0, 2, 1))
|
||||
pos_embed = torch.cat([
|
||||
self.query_pos_embed.expand(pos_embed.shape[0], -1, -1), pos_embed
|
||||
],
|
||||
dim=1)
|
||||
x = torch.cat([self.query_embed.expand(x.shape[0], -1, -1), x], dim=1)
|
||||
x = x + pos_embed
|
||||
L = hwshape[0] * hwshape[1]
|
||||
fused_index = 0
|
||||
if self.fusion_index[fused_index] == 0:
|
||||
x = self.fuse_clip(fused_index, x, clip_features[0][0], hwshape, L)
|
||||
fused_index += 1
|
||||
outs = []
|
||||
for index, block in enumerate(self.encode_layers, start=1):
|
||||
x = block(x)
|
||||
if index < len(self.fusion_index
|
||||
) and index == self.fusion_index[fused_index]:
|
||||
x = self.fuse_clip(fused_index, x,
|
||||
clip_features[fused_index][0], hwshape, L)
|
||||
fused_index += 1
|
||||
x_query = x[:, :-L, ...]
|
||||
x_feat = x[:, -L:, ...].permute(0, 2, 1)\
|
||||
.reshape(x.shape[0], x.shape[-1], hwshape[0], hwshape[1])
|
||||
|
||||
if index in deep_supervision_idxs or index == len(
|
||||
self.encode_layers):
|
||||
outs.append({'query': x_query, 'x': x_feat})
|
||||
|
||||
if index < len(self.encode_layers):
|
||||
x = x + pos_embed
|
||||
return outs
|
||||
|
||||
def decode_feature(self, features):
|
||||
mask_embeds = []
|
||||
attn_biases = []
|
||||
for feature in features:
|
||||
mask_embed, attn_bias = self.mask_decoder(**feature)
|
||||
mask_embeds.append(mask_embed)
|
||||
attn_biases.append(attn_bias)
|
||||
return mask_embeds, attn_biases
|
||||
|
||||
def forward(
|
||||
self, image: torch.Tensor, clip_features: List[torch.Tensor],
|
||||
deep_supervision_idxs: List[int]
|
||||
) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]:
|
||||
"""Forward function."""
|
||||
features = self.encode_feature(image, clip_features,
|
||||
deep_supervision_idxs)
|
||||
mask_embeds, attn_biases = self.decode_feature(features)
|
||||
return mask_embeds, attn_biases
|
||||
|
||||
|
||||
class RecWithAttnbias(nn.Module):
|
||||
"""Mask recognition module by applying the attention biases to rest deeper
|
||||
CLIP layers.
|
||||
|
||||
Args:
|
||||
sos_token_format (str): The format of sos token. It should be
|
||||
chosen from ["cls_token", "learnable_token", "pos_embedding"].
|
||||
Default: 'cls_token'.
|
||||
sos_token_num (int): Number of sos tokens. It should be equal to
the number of queries. Default: 100.
num_layers (int): Number of remaining CLIP layers used for mask
recognition.
|
||||
Default: 3.
|
||||
cross_attn (bool): Whether use cross attention to update sos token.
|
||||
Default: False.
|
||||
embed_dims (int): The feature dimension of CLIP layers.
|
||||
Default: 768.
|
||||
num_heads (int): Parallel attention heads of CLIP layers.
|
||||
Default: 12.
|
||||
mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
|
||||
Default: 4.
|
||||
qkv_bias (bool): Whether to use bias in multihead-attention.
|
||||
Default: True.
|
||||
out_dims (int): Number of channels of the output mask proposals.
|
||||
It should be equal to the out_dims of text_encoder.
|
||||
Default: 512.
|
||||
final_norm (bool): Whether to use a norm layer for the sos token. Default: True.
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Default: dict(type='GELU').
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN').
|
||||
frozen_exclude (List): List of parameters that are not to be frozen.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
sos_token_format: str = 'cls_token',
|
||||
sos_token_num: int = 100,
|
||||
num_layers: int = 3,
|
||||
cross_attn: bool = False,
|
||||
embed_dims: int = 768,
|
||||
num_heads: int = 12,
|
||||
mlp_ratio: int = 4,
|
||||
num_fcs: int = 2,
|
||||
qkv_bias: bool = True,
|
||||
out_dims: int = 512,
|
||||
final_norm: bool = True,
|
||||
act_cfg: dict = dict(type='GELU'),
|
||||
norm_cfg: dict = dict(type='LN'),
|
||||
frozen_exclude: List = []):
|
||||
super().__init__()
|
||||
|
||||
assert sos_token_format in [
|
||||
'cls_token', 'learnable_token', 'pos_embedding'
|
||||
]
|
||||
self.sos_token_format = sos_token_format
|
||||
self.sos_token_num = sos_token_num
|
||||
self.frozen_exclude = frozen_exclude
|
||||
self.cross_attn = cross_attn
|
||||
self.num_layers = num_layers
|
||||
self.num_heads = num_heads
|
||||
if sos_token_format in ['learnable_token', 'pos_embedding']:
|
||||
self.sos_token = nn.Parameter(
|
||||
torch.randn(sos_token_num, 1, self.proj.shape[0]))
|
||||
self.frozen.append('sos_token')
|
||||
|
||||
layers = []
|
||||
for i in range(num_layers):
|
||||
layers.append(
|
||||
BaseTransformerLayer(
|
||||
attn_cfgs=dict(
|
||||
type='MultiheadAttention',
|
||||
embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
batch_first=False,
|
||||
bias=qkv_bias),
|
||||
ffn_cfgs=dict(
|
||||
type='FFN',
|
||||
embed_dims=embed_dims,
|
||||
feedforward_channels=mlp_ratio * embed_dims,
|
||||
act_cfg=act_cfg),
|
||||
operation_order=('norm', 'self_attn', 'norm', 'ffn')))
|
||||
self.layers = nn.ModuleList(layers)
|
||||
|
||||
self.ln_post = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
self.proj = nn.Linear(embed_dims, out_dims, bias=False)
|
||||
|
||||
self.final_norm = final_norm
|
||||
self._freeze()
|
||||
|
||||
def init_weights(self, rec_state_dict):
|
||||
if hasattr(self, 'sos_token'):
|
||||
normal_init(self.sos_token, std=0.02)
|
||||
if rec_state_dict is not None:
|
||||
load_state_dict(self, rec_state_dict, strict=False, logger=None)
|
||||
else:
|
||||
super().init_weights()
|
||||
|
||||
def _freeze(self):
|
||||
if 'all' in self.frozen_exclude:
|
||||
return
|
||||
for name, param in self.named_parameters():
|
||||
if not any([exclude in name for exclude in self.frozen_exclude]):
|
||||
param.requires_grad = False
|
||||
|
||||
def _build_attn_biases(self, attn_biases, target_shape):
|
||||
formatted_attn_biases = []
|
||||
for attn_bias in attn_biases:
|
||||
# convert it to proper format: N*num_head,L,L
|
||||
# attn_bias: [N, num_head/1, num_sos,H,W]
|
||||
n, num_head, num_sos, h, w = attn_bias.shape
|
||||
# reshape and downsample
|
||||
attn_bias = F.adaptive_max_pool2d(
|
||||
attn_bias.reshape(n, num_head * num_sos, h, w),
|
||||
output_size=target_shape)
|
||||
attn_bias = attn_bias.reshape(n, num_head, num_sos, *target_shape)
|
||||
|
||||
true_num_head = self.num_heads
|
||||
assert (num_head == 1 or num_head
|
||||
== true_num_head), f'num_head={num_head} is not supported.'
|
||||
if num_head == 1:
|
||||
attn_bias = attn_bias.repeat(1, true_num_head, 1, 1, 1)
|
||||
attn_bias = attn_bias.reshape(n * true_num_head, num_sos, -1)
|
||||
L = attn_bias.shape[-1]
|
||||
if self.cross_attn:
|
||||
# [n*num_head, num_sos, L]
|
||||
formatted_attn_biases.append(attn_bias)
|
||||
else:
|
||||
# [n*num_head, num_sos+1+L, num_sos+1+L]
|
||||
new_attn_bias = attn_bias.new_zeros(num_sos + 1 + L,
|
||||
num_sos + 1 + L)
|
||||
new_attn_bias[:, :num_sos] = -100
|
||||
new_attn_bias[torch.arange(num_sos), torch.arange(num_sos)] = 0
|
||||
new_attn_bias[:num_sos, num_sos] = -100
|
||||
new_attn_bias = (
|
||||
new_attn_bias[None, ...].expand(n * true_num_head, -1,
|
||||
-1).clone())
|
||||
new_attn_bias[..., :num_sos, -L:] = attn_bias
|
||||
formatted_attn_biases.append(new_attn_bias)
|
||||
|
||||
if len(formatted_attn_biases) == 1:
|
||||
formatted_attn_biases = [
|
||||
formatted_attn_biases[0] for _ in range(self.num_layers)
|
||||
]
|
||||
return formatted_attn_biases
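For the non-cross-attention path, the bias matrix lets sos tokens attend only to themselves and to image tokens, while the cls token and image tokens cannot attend to sos tokens. A miniature of that layout with assumed toy sizes:

# Hedged miniature of the (num_sos + 1 + L) x (num_sos + 1 + L) bias built above.
import torch

num_sos, L = 2, 3                                 # toy sizes
attn_bias = torch.zeros(num_sos, L)               # per-query spatial bias (already pooled)

new_bias = torch.zeros(num_sos + 1 + L, num_sos + 1 + L)
new_bias[:, :num_sos] = -100                      # nobody attends to sos tokens ...
new_bias[torch.arange(num_sos), torch.arange(num_sos)] = 0   # ... except themselves
new_bias[:num_sos, num_sos] = -100                # sos tokens do not attend to the cls token
new_bias[:num_sos, -L:] = attn_bias               # sos -> image-token bias from the SAN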
|
||||
|
||||
def forward(self, bias: List[Tensor], feature: List[Tensor]):
|
||||
"""Forward function to recognize the category of masks
|
||||
Args:
|
||||
bias (List[Tensor]): Attention bias for transformer layers
|
||||
feature (List[Tensor]): Output of the image encoder,
|
||||
including cls_token and img_feature.
|
||||
"""
|
||||
cls_token = feature[1].unsqueeze(0)
|
||||
img_feature = feature[0]
|
||||
b, c, h, w = img_feature.shape
|
||||
# construct clip shadow features
|
||||
x = torch.cat(
|
||||
[cls_token,
|
||||
img_feature.reshape(b, c, -1).permute(2, 0, 1)])
|
||||
|
||||
# construct sos token
|
||||
if self.sos_token_format == 'cls_token':
|
||||
sos_token = cls_token.repeat(self.sos_token_num, 1, 1)
|
||||
elif self.sos_token_format == 'learnable_token':
|
||||
sos_token = self.sos_token.expand(-1, b, -1)
|
||||
elif self.sos_token_format == 'pos_embedding':
|
||||
sos_token = self.sos_token.expand(-1, b, -1) + cls_token
|
||||
|
||||
# construct attn bias
|
||||
attn_biases = self._build_attn_biases(bias, target_shape=(h, w))
|
||||
|
||||
if self.cross_attn:
|
||||
for i, block in enumerate(self.layers):
|
||||
if self.cross_attn:
|
||||
sos_token = cross_attn_layer(
|
||||
block,
|
||||
sos_token,
|
||||
x[1:, ],
|
||||
attn_biases[i],
|
||||
)
|
||||
if i < len(self.layers) - 1:
|
||||
x = block(x)
|
||||
else:
|
||||
x = torch.cat([sos_token, x], dim=0)
|
||||
for i, block in enumerate(self.layers):
|
||||
x = block(x, attn_masks=[attn_biases[i]])
|
||||
sos_token = x[:self.sos_token_num]
|
||||
|
||||
sos_token = sos_token.permute(1, 0, 2) # LND -> NLD
|
||||
sos_token = self.ln_post(sos_token)
|
||||
sos_token = self.proj(sos_token)
|
||||
if self.final_norm:
|
||||
sos_token = F.normalize(sos_token, dim=-1)
|
||||
return sos_token
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class SideAdapterCLIPHead(BaseDecodeHead):
|
||||
"""Side Adapter Network (SAN) for open-vocabulary semantic segmentation
|
||||
with pre-trained vision-language model.
|
||||
|
||||
This decode head is the implementation of `Side Adapter Network
for Open-Vocabulary Semantic Segmentation
<https://arxiv.org/abs/2302.12242>`_.
|
||||
Modified from https://github.com/MendelXu/SAN/blob/main/san/model/side_adapter/side_adapter.py # noqa:E501
|
||||
Copyright (c) 2023 MendelXu.
|
||||
Licensed under the MIT License
|
||||
|
||||
Args:
|
||||
num_classes (int): The number of classes.
san_cfg (ConfigType): Configs for the SideAdapterNetwork module.
maskgen_cfg (ConfigType): Configs for the RecWithAttnbias module.
deep_supervision_idxs (List[int]): Indices of the SAN layers whose
    outputs receive deep supervision during training.
train_cfg (ConfigType): Training config, including the mask-matching
    assigner and point-sampling settings.
|
||||
"""
|
||||
|
||||
def __init__(self, num_classes: int, san_cfg: ConfigType,
|
||||
maskgen_cfg: ConfigType, deep_supervision_idxs: List[int],
|
||||
train_cfg: ConfigType, **kwargs):
|
||||
super().__init__(
|
||||
in_channels=san_cfg.in_channels,
|
||||
channels=san_cfg.embed_dims,
|
||||
num_classes=num_classes,
|
||||
**kwargs)
|
||||
assert san_cfg.num_queries == maskgen_cfg.sos_token_num, \
|
||||
'num_queries in san_cfg should be equal to sos_token_num ' \
|
||||
'in maskgen_cfg'
|
||||
del self.conv_seg
|
||||
self.side_adapter_network = SideAdapterNetwork(**san_cfg)
|
||||
self.rec_with_attnbias = RecWithAttnbias(**maskgen_cfg)
|
||||
self.deep_supervision_idxs = deep_supervision_idxs
|
||||
self.train_cfg = train_cfg
|
||||
if train_cfg:
|
||||
self.match_masks = MatchMasks(
|
||||
num_points=train_cfg.num_points,
|
||||
num_queries=san_cfg.num_queries,
|
||||
num_classes=num_classes,
|
||||
assigner=train_cfg.assigner)
|
||||
|
||||
def init_weights(self):
|
||||
|
||||
rec_state_dict = None
|
||||
if isinstance(self.init_cfg, dict) and \
|
||||
self.init_cfg.get('type') == 'Pretrained_Part':
|
||||
checkpoint = CheckpointLoader.load_checkpoint(
|
||||
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
|
||||
|
||||
rec_state_dict = checkpoint.copy()
|
||||
para_prefix = 'decode_head.rec_with_attnbias'
|
||||
prefix_len = len(para_prefix) + 1
|
||||
for k, v in checkpoint.items():
|
||||
rec_state_dict.pop(k)
|
||||
if para_prefix in k:
|
||||
rec_state_dict[k[prefix_len:]] = v
|
||||
|
||||
self.side_adapter_network.init_weights()
|
||||
self.rec_with_attnbias.init_weights(rec_state_dict)
|
||||
|
||||
def forward(self, inputs: Tuple[Tensor],
|
||||
deep_supervision_idxs) -> Tuple[List]:
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
inputs (Tuple[Tensor]): A triplet including images,
|
||||
list of multi-level visual features from image encoder and
|
||||
class embeddings from text_encoder.
deep_supervision_idxs (List[int]): Indices of the SAN layers used
    for deep supervision; an empty list is passed at inference.
|
||||
|
||||
Returns:
|
||||
mask_props (List[Tensor]): Mask proposals predicted by SAN.
|
||||
mask_logits (List[Tensor]): Class logits of mask proposals.
|
||||
"""
|
||||
imgs, clip_feature, class_embeds = inputs
|
||||
# predict mask proposals and attention bias
|
||||
mask_props, attn_biases = self.side_adapter_network(
|
||||
imgs, clip_feature, deep_supervision_idxs)
|
||||
|
||||
# mask recognition with attention bias
|
||||
mask_embeds = [
|
||||
self.rec_with_attnbias(att_bias, clip_feature[-1])
|
||||
for att_bias in attn_biases
|
||||
]
|
||||
# Obtain class prediction of masks by comparing the similarity
|
||||
# between the image token and the text embedding of class names.
|
||||
mask_logits = [
|
||||
torch.einsum('bqc,nc->bqn', mask_embed, class_embeds)
|
||||
for mask_embed in mask_embeds
|
||||
]
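# Shape sketch (B images, Q queries, N class embeddings): each mask_props[i]
# is (B, Q, H, W) mask proposal logits and each mask_logits[i] is (B, Q, N)
# per-query class scores.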
|
||||
return mask_props, mask_logits
|
||||
|
||||
def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
|
||||
test_cfg: ConfigType) -> Tensor:
|
||||
"""Forward function for prediction.
|
||||
|
||||
Args:
|
||||
inputs (Tuple[Tensor]): Images, visual features from image encoder
|
||||
and class embedding from text encoder.
|
||||
batch_img_metas (List[dict]): List of image info dicts, where each
    dict may also contain: 'img_shape', 'scale_factor', 'flip',
    'img_path', 'ori_shape', and 'pad_shape'.
|
||||
For details on the values of these keys see
|
||||
`mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
|
||||
test_cfg (dict): The testing config.
|
||||
|
||||
Returns:
|
||||
Tensor: Outputs segmentation logits map.
|
||||
"""
|
||||
mask_props, mask_logits = self.forward(inputs, [])
|
||||
|
||||
return self.predict_by_feat([mask_props[-1], mask_logits[-1]],
|
||||
batch_img_metas)
|
||||
|
||||
def predict_by_feat(self, seg_logits: List[Tensor],
|
||||
batch_img_metas: List[dict]) -> Tensor:
|
||||
"""1. Transform a batch of mask proposals to the input shape.
|
||||
2. Generate segmentation map with mask proposals and class logits.
|
||||
"""
|
||||
mask_pred = seg_logits[0]
|
||||
cls_score = seg_logits[1]
|
||||
if isinstance(batch_img_metas[0]['img_shape'], torch.Size):
|
||||
# slide inference
|
||||
size = batch_img_metas[0]['img_shape']
|
||||
elif 'pad_shape' in batch_img_metas[0]:
|
||||
size = batch_img_metas[0]['pad_shape'][:2]
|
||||
else:
|
||||
size = batch_img_metas[0]['img_shape']
|
||||
# upsample mask
|
||||
mask_pred = F.interpolate(
|
||||
mask_pred, size=size, mode='bilinear', align_corners=False)
|
||||
|
||||
mask_cls = F.softmax(cls_score, dim=-1)[..., :-1]
|
||||
mask_pred = mask_pred.sigmoid()
|
||||
seg_logits = torch.einsum('bqc,bqhw->bchw', mask_cls, mask_pred)
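# i.e. seg_logits[b, c, h, w] = sum_q mask_cls[b, q, c] * mask_pred[b, q, h, w]:
# class probabilities weight the sigmoid mask proposals into per-class maps.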
|
||||
return seg_logits
|
||||
|
||||
def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList,
|
||||
train_cfg: ConfigType) -> dict:
|
||||
"""Perform forward propagation and loss calculation of the decoder head
|
||||
on the features of the upstream network.
|
||||
|
||||
Args:
|
||||
x (tuple[Tensor]): Multi-level features from the upstream
|
||||
network, each is a 4D-tensor.
|
||||
batch_data_samples (List[:obj:`SegDataSample`]): The Data
|
||||
Samples. It usually includes information such as
|
||||
`gt_sem_seg`.
|
||||
train_cfg (ConfigType): Training config.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: a dictionary of loss components.
|
||||
"""
|
||||
# batch SegDataSample to InstanceDataSample
|
||||
batch_gt_instances = seg_data_to_instance_data(self.ignore_index,
|
||||
batch_data_samples)
|
||||
|
||||
# forward
|
||||
all_mask_props, all_mask_logits = self.forward(
|
||||
x, self.deep_supervision_idxs)
|
||||
|
||||
# loss
|
||||
losses = self.loss_by_feat(all_mask_logits, all_mask_props,
|
||||
batch_gt_instances)
|
||||
|
||||
return losses
|
||||
|
||||
def loss_by_feat(
|
||||
self, all_cls_scores: Tensor, all_mask_preds: Tensor,
|
||||
batch_gt_instances: List[InstanceData]) -> Dict[str, Tensor]:
|
||||
"""Loss function.
|
||||
|
||||
Args:
|
||||
all_cls_scores (Tensor): Classification scores for all decoder
|
||||
layers with shape (num_decoder, batch_size, num_queries,
|
||||
cls_out_channels). Note `cls_out_channels` should include the
|
||||
background.
|
||||
all_mask_preds (Tensor): Mask scores for all decoder layers with
|
||||
shape (num_decoder, batch_size, num_queries, h, w).
|
||||
batch_gt_instances (list[obj:`InstanceData`]): each contains
|
||||
``labels`` and ``masks``.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: A dictionary of loss components.
|
||||
"""
|
||||
num_dec_layers = len(all_cls_scores)
|
||||
batch_gt_instances_list = [
|
||||
batch_gt_instances for _ in range(num_dec_layers)
|
||||
]
|
||||
|
||||
losses = []
|
||||
for i in range(num_dec_layers):
|
||||
cls_scores = all_cls_scores[i]
|
||||
mask_preds = all_mask_preds[i]
|
||||
# matching N mask predictions to K category labels
|
||||
(labels, mask_targets, mask_weights,
|
||||
avg_factor) = self.match_masks.get_targets(
|
||||
cls_scores, mask_preds, batch_gt_instances_list[i])
|
||||
cls_scores = cls_scores.flatten(0, 1)
|
||||
labels = labels.flatten(0, 1)
|
||||
num_total_masks = cls_scores.new_tensor([avg_factor],
|
||||
dtype=torch.float)
|
||||
all_reduce(num_total_masks, op='mean')
|
||||
num_total_masks = max(num_total_masks, 1)
|
||||
|
||||
# extract positive ones
|
||||
# shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
|
||||
mask_preds = mask_preds[mask_weights > 0]
|
||||
|
||||
if mask_targets.shape[0] != 0:
|
||||
with torch.no_grad():
|
||||
points_coords = get_uncertain_point_coords_with_randomness(
|
||||
mask_preds.unsqueeze(1), None,
|
||||
self.train_cfg.num_points,
|
||||
self.train_cfg.oversample_ratio,
|
||||
self.train_cfg.importance_sample_ratio)
|
||||
# shape (num_total_gts, h, w)
|
||||
# -> (num_total_gts, num_points)
|
||||
mask_point_targets = point_sample(
|
||||
mask_targets.unsqueeze(1).float(),
|
||||
points_coords).squeeze(1)
|
||||
# shape (num_queries, h, w) -> (num_queries, num_points)
|
||||
mask_point_preds = point_sample(
|
||||
mask_preds.unsqueeze(1), points_coords).squeeze(1)
|
||||
|
||||
if not isinstance(self.loss_decode, nn.ModuleList):
|
||||
losses_decode = [self.loss_decode]
|
||||
else:
|
||||
losses_decode = self.loss_decode
|
||||
loss = dict()
|
||||
for loss_decode in losses_decode:
|
||||
if 'loss_cls' in loss_decode.loss_name:
|
||||
if loss_decode.loss_name == 'loss_cls_ce':
|
||||
loss[loss_decode.loss_name] = loss_decode(
|
||||
cls_scores, labels)
|
||||
else:
|
||||
assert False, "Only support 'CrossEntropyLoss' in" \
|
||||
' classification loss'
|
||||
|
||||
elif 'loss_mask' in loss_decode.loss_name:
|
||||
if mask_targets.shape[0] == 0:
|
||||
loss[loss_decode.loss_name] = mask_preds.sum()
|
||||
elif loss_decode.loss_name == 'loss_mask_ce':
|
||||
loss[loss_decode.loss_name] = loss_decode(
|
||||
mask_point_preds,
|
||||
mask_point_targets,
|
||||
avg_factor=num_total_masks *
|
||||
self.train_cfg.num_points)
|
||||
elif loss_decode.loss_name == 'loss_mask_dice':
|
||||
loss[loss_decode.loss_name] = loss_decode(
|
||||
mask_point_preds,
|
||||
mask_point_targets,
|
||||
avg_factor=num_total_masks)
|
||||
else:
|
||||
assert False, "Only support 'CrossEntropyLoss' and" \
|
||||
" 'DiceLoss' in mask loss"
|
||||
else:
|
||||
assert False, "Only support for 'loss_cls' and 'loss_mask'"
|
||||
|
||||
losses.append(loss)
|
||||
|
||||
loss_dict = dict()
|
||||
# loss from the last decoder layer
|
||||
loss_dict.update(losses[-1])
|
||||
# loss from other decoder layers
|
||||
for i, loss in enumerate(losses[:-1]):
|
||||
for k, v in loss.items():
|
||||
loss_dict[f'd{self.deep_supervision_idxs[i]}.{k}'] = v
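# e.g. with deep_supervision_idxs=[7], the auxiliary losses would be logged
# as 'd7.loss_cls_ce', 'd7.loss_mask_ce', ... (names depend on the
# configured loss_decode entries; shown here only for illustration).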
|
||||
return loss_dict
|
||||
66
finetune/mmseg/models/decode_heads/segformer_head.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class SegformerHead(BaseDecodeHead):
|
||||
"""The all mlp Head of segformer.
|
||||
|
||||
This head is the implementation of
|
||||
`Segformer <https://arxiv.org/abs/2105.15203>` _.
|
||||
|
||||
Args:
|
||||
interpolate_mode: The interpolate mode of MLP head upsample operation.
|
||||
Default: 'bilinear'.
|
||||
"""
|
||||
|
||||
def __init__(self, interpolate_mode='bilinear', **kwargs):
|
||||
super().__init__(input_transform='multiple_select', **kwargs)
|
||||
|
||||
self.interpolate_mode = interpolate_mode
|
||||
num_inputs = len(self.in_channels)
|
||||
|
||||
assert num_inputs == len(self.in_index)
|
||||
|
||||
self.convs = nn.ModuleList()
|
||||
for i in range(num_inputs):
|
||||
self.convs.append(
|
||||
ConvModule(
|
||||
in_channels=self.in_channels[i],
|
||||
out_channels=self.channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg))
|
||||
|
||||
self.fusion_conv = ConvModule(
|
||||
in_channels=self.channels * num_inputs,
|
||||
out_channels=self.channels,
|
||||
kernel_size=1,
|
||||
norm_cfg=self.norm_cfg)
|
||||
|
||||
def forward(self, inputs):
|
||||
# Receive 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32
|
||||
inputs = self._transform_inputs(inputs)
|
||||
outs = []
|
||||
for idx in range(len(inputs)):
|
||||
x = inputs[idx]
|
||||
conv = self.convs[idx]
|
||||
outs.append(
|
||||
resize(
|
||||
input=conv(x),
|
||||
size=inputs[0].shape[2:],
|
||||
mode=self.interpolate_mode,
|
||||
align_corners=self.align_corners))
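# All stage features are resized to the spatial size of inputs[0] (the
# highest-resolution, 1/4-scale map in the usual 4-stage setting) before
# channel-wise concatenation and fusion.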
|
||||
|
||||
out = self.fusion_conv(torch.cat(outs, dim=1))
|
||||
|
||||
out = self.cls_seg(out)
|
||||
|
||||
return out
|
||||
132
finetune/mmseg/models/decode_heads/segmenter_mask_head.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmengine.model import ModuleList
|
||||
from mmengine.model.weight_init import (constant_init, trunc_normal_,
|
||||
trunc_normal_init)
|
||||
|
||||
from mmseg.models.backbones.vit import TransformerEncoderLayer
|
||||
from mmseg.registry import MODELS
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class SegmenterMaskTransformerHead(BaseDecodeHead):
|
||||
"""Segmenter: Transformer for Semantic Segmentation.
|
||||
|
||||
This head is the implementation of
|
||||
`Segmenter: <https://arxiv.org/abs/2105.05633>`_.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of channels of input image.
|
||||
num_layers (int): The depth of transformer.
|
||||
num_heads (int): The number of attention heads.
|
||||
embed_dims (int): The number of embedding dimension.
|
||||
mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
|
||||
Default: 4.
|
||||
drop_path_rate (float): stochastic depth rate. Default 0.1.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Default 0.0
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Default 0.0
|
||||
num_fcs (int): The number of fully-connected layers for FFNs.
|
||||
Default: 2.
|
||||
qkv_bias (bool): Enable bias for qkv if True. Default: True.
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Default: dict(type='GELU').
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN')
|
||||
init_std (float): The value of std in weight initialization.
|
||||
Default: 0.02.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
num_layers,
|
||||
num_heads,
|
||||
embed_dims,
|
||||
mlp_ratio=4,
|
||||
drop_path_rate=0.1,
|
||||
drop_rate=0.0,
|
||||
attn_drop_rate=0.0,
|
||||
num_fcs=2,
|
||||
qkv_bias=True,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN'),
|
||||
init_std=0.02,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(in_channels=in_channels, **kwargs)
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]
|
||||
self.layers = ModuleList()
|
||||
for i in range(num_layers):
|
||||
self.layers.append(
|
||||
TransformerEncoderLayer(
|
||||
embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
feedforward_channels=mlp_ratio * embed_dims,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_rate=drop_rate,
|
||||
drop_path_rate=dpr[i],
|
||||
num_fcs=num_fcs,
|
||||
qkv_bias=qkv_bias,
|
||||
act_cfg=act_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
batch_first=True,
|
||||
))
|
||||
|
||||
self.dec_proj = nn.Linear(in_channels, embed_dims)
|
||||
|
||||
self.cls_emb = nn.Parameter(
|
||||
torch.randn(1, self.num_classes, embed_dims))
|
||||
self.patch_proj = nn.Linear(embed_dims, embed_dims, bias=False)
|
||||
self.classes_proj = nn.Linear(embed_dims, embed_dims, bias=False)
|
||||
|
||||
self.decoder_norm = build_norm_layer(
|
||||
norm_cfg, embed_dims, postfix=1)[1]
|
||||
self.mask_norm = build_norm_layer(
|
||||
norm_cfg, self.num_classes, postfix=2)[1]
|
||||
|
||||
self.init_std = init_std
|
||||
|
||||
delattr(self, 'conv_seg')
|
||||
|
||||
def init_weights(self):
|
||||
trunc_normal_(self.cls_emb, std=self.init_std)
|
||||
trunc_normal_init(self.patch_proj, std=self.init_std)
|
||||
trunc_normal_init(self.classes_proj, std=self.init_std)
|
||||
for n, m in self.named_modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_init(m, std=self.init_std, bias=0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
constant_init(m, val=1.0, bias=0.0)
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self._transform_inputs(inputs)
|
||||
b, c, h, w = x.shape
|
||||
x = x.permute(0, 2, 3, 1).contiguous().view(b, -1, c)
|
||||
|
||||
x = self.dec_proj(x)
|
||||
cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
|
||||
x = torch.cat((x, cls_emb), 1)
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
x = self.decoder_norm(x)
|
||||
|
||||
patches = self.patch_proj(x[:, :-self.num_classes])
|
||||
cls_seg_feat = self.classes_proj(x[:, -self.num_classes:])
|
||||
|
||||
patches = F.normalize(patches, dim=2, p=2)
|
||||
cls_seg_feat = F.normalize(cls_seg_feat, dim=2, p=2)
|
||||
|
||||
masks = patches @ cls_seg_feat.transpose(1, 2)
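# (b, hw, num_classes) similarity logits between the L2-normalized patch
# tokens and class tokens (cosine similarities), normalized by mask_norm
# below.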
|
||||
masks = self.mask_norm(masks)
|
||||
masks = masks.permute(0, 2, 1).contiguous().view(b, -1, h, w)
|
||||
|
||||
return masks
|
||||
102
finetune/mmseg/models/decode_heads/sep_aspp_head.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
from .aspp_head import ASPPHead, ASPPModule
|
||||
|
||||
|
||||
class DepthwiseSeparableASPPModule(ASPPModule):
|
||||
"""Atrous Spatial Pyramid Pooling (ASPP) Module with depthwise separable
|
||||
conv."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for i, dilation in enumerate(self.dilations):
|
||||
if dilation > 1:
|
||||
self[i] = DepthwiseSeparableConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
3,
|
||||
dilation=dilation,
|
||||
padding=dilation,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DepthwiseSeparableASPPHead(ASPPHead):
|
||||
"""Encoder-Decoder with Atrous Separable Convolution for Semantic Image
|
||||
Segmentation.
|
||||
|
||||
This head is the implementation of `DeepLabV3+
|
||||
<https://arxiv.org/abs/1802.02611>`_.
|
||||
|
||||
Args:
|
||||
c1_in_channels (int): The input channels of c1 decoder. If it is 0,
    no decoder will be used.
|
||||
c1_channels (int): The intermediate channels of c1 decoder.
|
||||
"""
|
||||
|
||||
def __init__(self, c1_in_channels, c1_channels, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
assert c1_in_channels >= 0
|
||||
self.aspp_modules = DepthwiseSeparableASPPModule(
|
||||
dilations=self.dilations,
|
||||
in_channels=self.in_channels,
|
||||
channels=self.channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
if c1_in_channels > 0:
|
||||
self.c1_bottleneck = ConvModule(
|
||||
c1_in_channels,
|
||||
c1_channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
else:
|
||||
self.c1_bottleneck = None
|
||||
self.sep_bottleneck = nn.Sequential(
|
||||
DepthwiseSeparableConvModule(
|
||||
self.channels + c1_channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg),
|
||||
DepthwiseSeparableConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg))
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
x = self._transform_inputs(inputs)
|
||||
aspp_outs = [
|
||||
resize(
|
||||
self.image_pool(x),
|
||||
size=x.size()[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
]
|
||||
aspp_outs.extend(self.aspp_modules(x))
|
||||
aspp_outs = torch.cat(aspp_outs, dim=1)
|
||||
output = self.bottleneck(aspp_outs)
|
||||
if self.c1_bottleneck is not None:
|
||||
c1_output = self.c1_bottleneck(inputs[0])
|
||||
output = resize(
|
||||
input=output,
|
||||
size=c1_output.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
output = torch.cat([output, c1_output], dim=1)
|
||||
output = self.sep_bottleneck(output)
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
60
finetune/mmseg/models/decode_heads/sep_fcn_head.py
Normal file
@@ -0,0 +1,60 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmcv.cnn import DepthwiseSeparableConvModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .fcn_head import FCNHead
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DepthwiseSeparableFCNHead(FCNHead):
|
||||
"""Depthwise-Separable Fully Convolutional Network for Semantic
|
||||
Segmentation.
|
||||
|
||||
This head is implemented according to `Fast-SCNN: Fast Semantic
|
||||
Segmentation Network <https://arxiv.org/abs/1902.04502>`_.
|
||||
|
||||
Args:
|
||||
in_channels(int): Number of output channels of FFM.
|
||||
channels(int): Number of middle-stage channels in the decode head.
|
||||
concat_input(bool): Whether to concatenate original decode input into
|
||||
the result of several consecutive convolution layers.
|
||||
Default: True.
|
||||
num_classes(int): Used to determine the dimension of
|
||||
final prediction tensor.
|
||||
in_index (int): Corresponds to 'out_indices' in the FastSCNN backbone.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
loss_decode(dict): Config of loss type and some
|
||||
relevant additional options.
|
||||
dw_act_cfg (dict): Activation config of depthwise ConvModule. If it is
|
||||
'default', it will be the same as `act_cfg`. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self, dw_act_cfg=None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.convs[0] = DepthwiseSeparableConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
kernel_size=self.kernel_size,
|
||||
padding=self.kernel_size // 2,
|
||||
norm_cfg=self.norm_cfg,
|
||||
dw_act_cfg=dw_act_cfg)
|
||||
|
||||
for i in range(1, self.num_convs):
|
||||
self.convs[i] = DepthwiseSeparableConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
kernel_size=self.kernel_size,
|
||||
padding=self.kernel_size // 2,
|
||||
norm_cfg=self.norm_cfg,
|
||||
dw_act_cfg=dw_act_cfg)
|
||||
|
||||
if self.concat_input:
|
||||
self.conv_cat = DepthwiseSeparableConvModule(
|
||||
self.in_channels + self.channels,
|
||||
self.channels,
|
||||
kernel_size=self.kernel_size,
|
||||
padding=self.kernel_size // 2,
|
||||
norm_cfg=self.norm_cfg,
|
||||
dw_act_cfg=dw_act_cfg)
|
||||
62
finetune/mmseg/models/decode_heads/setr_mla_head.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import Upsample
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class SETRMLAHead(BaseDecodeHead):
|
||||
"""Multi level feature aggretation head of SETR.
|
||||
|
||||
MLA head of `SETR <https://arxiv.org/pdf/2012.15840.pdf>`_.
|
||||
|
||||
Args:
|
||||
mla_channels (int): Channels of conv-conv-4x of multi-level feature
    aggregation. Default: 128.
up_scale (int): The scale factor of interpolate. Default: 4.
|
||||
"""
|
||||
|
||||
def __init__(self, mla_channels=128, up_scale=4, **kwargs):
|
||||
super().__init__(input_transform='multiple_select', **kwargs)
|
||||
self.mla_channels = mla_channels
|
||||
|
||||
num_inputs = len(self.in_channels)
|
||||
|
||||
# Refer to self.cls_seg settings of BaseDecodeHead
|
||||
assert self.channels == num_inputs * mla_channels
|
||||
|
||||
self.up_convs = nn.ModuleList()
|
||||
for i in range(num_inputs):
|
||||
self.up_convs.append(
|
||||
nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=self.in_channels[i],
|
||||
out_channels=mla_channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg),
|
||||
ConvModule(
|
||||
in_channels=mla_channels,
|
||||
out_channels=mla_channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg),
|
||||
Upsample(
|
||||
scale_factor=up_scale,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)))
|
||||
|
||||
def forward(self, inputs):
|
||||
inputs = self._transform_inputs(inputs)
|
||||
outs = []
|
||||
for x, up_conv in zip(inputs, self.up_convs):
|
||||
outs.append(up_conv(x))
|
||||
out = torch.cat(outs, dim=1)
|
||||
out = self.cls_seg(out)
|
||||
return out
|
||||
81
finetune/mmseg/models/decode_heads/setr_up_head.py
Normal file
@@ -0,0 +1,81 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, build_norm_layer
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import Upsample
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class SETRUPHead(BaseDecodeHead):
|
||||
"""Naive upsampling head and Progressive upsampling head of SETR.
|
||||
|
||||
Naive or PUP head of `SETR <https://arxiv.org/pdf/2012.15840.pdf>`_.
|
||||
|
||||
Args:
|
||||
norm_layer (dict): Config dict for input normalization.
|
||||
Default: norm_layer=dict(type='LN', eps=1e-6, requires_grad=True).
|
||||
num_convs (int): Number of decoder convolutions. Default: 1.
|
||||
up_scale (int): The scale factor of interpolate. Default:4.
|
||||
kernel_size (int): The kernel size of convolution when decoding
|
||||
feature information from backbone. Default: 3.
|
||||
init_cfg (dict | list[dict] | None): Initialization config dict.
|
||||
Default: dict(
|
||||
type='Constant', val=1.0, bias=0, layer='LayerNorm').
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
norm_layer=dict(type='LN', eps=1e-6, requires_grad=True),
|
||||
num_convs=1,
|
||||
up_scale=4,
|
||||
kernel_size=3,
|
||||
init_cfg=[
|
||||
dict(type='Constant', val=1.0, bias=0, layer='LayerNorm'),
|
||||
dict(
|
||||
type='Normal',
|
||||
std=0.01,
|
||||
override=dict(name='conv_seg'))
|
||||
],
|
||||
**kwargs):
|
||||
|
||||
assert kernel_size in [1, 3], 'kernel_size must be 1 or 3.'
|
||||
|
||||
super().__init__(init_cfg=init_cfg, **kwargs)
|
||||
|
||||
assert isinstance(self.in_channels, int)
|
||||
|
||||
_, self.norm = build_norm_layer(norm_layer, self.in_channels)
|
||||
|
||||
self.up_convs = nn.ModuleList()
|
||||
in_channels = self.in_channels
|
||||
out_channels = self.channels
|
||||
for _ in range(num_convs):
|
||||
self.up_convs.append(
|
||||
nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=1,
|
||||
padding=int(kernel_size - 1) // 2,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg),
|
||||
Upsample(
|
||||
scale_factor=up_scale,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)))
|
||||
in_channels = out_channels
|
||||
|
||||
def forward(self, x):
|
||||
x = self._transform_inputs(x)
|
||||
|
||||
n, c, h, w = x.shape
|
||||
x = x.reshape(n, c, h * w).transpose(2, 1).contiguous()
|
||||
x = self.norm(x)
|
||||
x = x.transpose(1, 2).reshape(n, c, h, w).contiguous()
|
||||
|
||||
for up_conv in self.up_convs:
|
||||
x = up_conv(x)
|
||||
out = self.cls_seg(x)
|
||||
return out
|
||||
97
finetune/mmseg/models/decode_heads/stdc_head.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from mmengine.structures import PixelData
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.structures import SegDataSample
|
||||
from mmseg.utils import SampleList
|
||||
from .fcn_head import FCNHead
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class STDCHead(FCNHead):
|
||||
"""This head is the implementation of `Rethinking BiSeNet For Real-time
|
||||
Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.
|
||||
|
||||
Args:
|
||||
boundary_threshold (float): The threshold of calculating boundary.
|
||||
Default: 0.1.
|
||||
"""
|
||||
|
||||
def __init__(self, boundary_threshold=0.1, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.boundary_threshold = boundary_threshold
|
||||
# Using register buffer to make laplacian kernel on the same
|
||||
# device of `seg_label`.
|
||||
self.register_buffer(
|
||||
'laplacian_kernel',
|
||||
torch.tensor([-1, -1, -1, -1, 8, -1, -1, -1, -1],
|
||||
dtype=torch.float32,
|
||||
requires_grad=False).reshape((1, 1, 3, 3)))
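# Fixed (non-trainable) fusion weights 0.6/0.3/0.1 later combine the 1x,
# 1/2x and 1/4x boundary maps; see the note in `loss_by_feat`.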
|
||||
self.fusion_kernel = torch.nn.Parameter(
|
||||
torch.tensor([[6. / 10], [3. / 10], [1. / 10]],
|
||||
dtype=torch.float32).reshape(1, 3, 1, 1),
|
||||
requires_grad=False)
|
||||
|
||||
def loss_by_feat(self, seg_logits: Tensor,
|
||||
batch_data_samples: SampleList) -> dict:
|
||||
"""Compute Detail Aggregation Loss."""
|
||||
# Note: The paper claims `fusion_kernel` is a trainable 1x1 conv
|
||||
# parameter. However, it is kept constant here, as in the original repo
# and other codebases, because it would not be added to the computation
# graph after the threshold operation.
|
||||
seg_label = self._stack_batch_gt(batch_data_samples).to(
|
||||
self.laplacian_kernel)
|
||||
boundary_targets = F.conv2d(
|
||||
seg_label, self.laplacian_kernel, padding=1)
|
||||
boundary_targets = boundary_targets.clamp(min=0)
|
||||
boundary_targets[boundary_targets > self.boundary_threshold] = 1
|
||||
boundary_targets[boundary_targets <= self.boundary_threshold] = 0
|
||||
|
||||
boundary_targets_x2 = F.conv2d(
|
||||
seg_label, self.laplacian_kernel, stride=2, padding=1)
|
||||
boundary_targets_x2 = boundary_targets_x2.clamp(min=0)
|
||||
|
||||
boundary_targets_x4 = F.conv2d(
|
||||
seg_label, self.laplacian_kernel, stride=4, padding=1)
|
||||
boundary_targets_x4 = boundary_targets_x4.clamp(min=0)
|
||||
|
||||
boundary_targets_x4_up = F.interpolate(
|
||||
boundary_targets_x4, boundary_targets.shape[2:], mode='nearest')
|
||||
boundary_targets_x2_up = F.interpolate(
|
||||
boundary_targets_x2, boundary_targets.shape[2:], mode='nearest')
|
||||
|
||||
boundary_targets_x2_up[
|
||||
boundary_targets_x2_up > self.boundary_threshold] = 1
|
||||
boundary_targets_x2_up[
|
||||
boundary_targets_x2_up <= self.boundary_threshold] = 0
|
||||
|
||||
boundary_targets_x4_up[
|
||||
boundary_targets_x4_up > self.boundary_threshold] = 1
|
||||
boundary_targets_x4_up[
|
||||
boundary_targets_x4_up <= self.boundary_threshold] = 0
|
||||
|
||||
boundary_targets_pyramids = torch.stack(
|
||||
(boundary_targets, boundary_targets_x2_up, boundary_targets_x4_up),
|
||||
dim=1)
|
||||
|
||||
boundary_targets_pyramids = boundary_targets_pyramids.squeeze(2)
|
||||
boudary_targets_pyramid = F.conv2d(boundary_targets_pyramids,
|
||||
self.fusion_kernel)
|
||||
|
||||
boudary_targets_pyramid[
|
||||
boudary_targets_pyramid > self.boundary_threshold] = 1
|
||||
boudary_targets_pyramid[
|
||||
boudary_targets_pyramid <= self.boundary_threshold] = 0
|
||||
|
||||
seg_labels = boudary_targets_pyramid.long()
|
||||
batch_sample_list = []
|
||||
for label in seg_labels:
|
||||
seg_data_sample = SegDataSample()
|
||||
seg_data_sample.gt_sem_seg = PixelData(data=label)
|
||||
batch_sample_list.append(seg_data_sample)
|
||||
|
||||
loss = super().loss_by_feat(seg_logits, batch_sample_list)
|
||||
return loss
|
||||
139
finetune/mmseg/models/decode_heads/uper_head.py
Normal file
@@ -0,0 +1,139 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
from .psp_head import PPM
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class UPerHead(BaseDecodeHead):
|
||||
"""Unified Perceptual Parsing for Scene Understanding.
|
||||
|
||||
This head is the implementation of `UPerNet
|
||||
<https://arxiv.org/abs/1807.10221>`_.
|
||||
|
||||
Args:
|
||||
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
||||
Module applied on the last feature. Default: (1, 2, 3, 6).
|
||||
"""
|
||||
|
||||
def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
|
||||
super().__init__(input_transform='multiple_select', **kwargs)
|
||||
# PSP Module
|
||||
self.psp_modules = PPM(
|
||||
pool_scales,
|
||||
self.in_channels[-1],
|
||||
self.channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
align_corners=self.align_corners)
|
||||
self.bottleneck = ConvModule(
|
||||
self.in_channels[-1] + len(pool_scales) * self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
# FPN Module
|
||||
self.lateral_convs = nn.ModuleList()
|
||||
self.fpn_convs = nn.ModuleList()
|
||||
for in_channels in self.in_channels[:-1]: # skip the top layer
|
||||
l_conv = ConvModule(
|
||||
in_channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
inplace=False)
|
||||
fpn_conv = ConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
inplace=False)
|
||||
self.lateral_convs.append(l_conv)
|
||||
self.fpn_convs.append(fpn_conv)
|
||||
|
||||
self.fpn_bottleneck = ConvModule(
|
||||
len(self.in_channels) * self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def psp_forward(self, inputs):
|
||||
"""Forward function of PSP module."""
|
||||
x = inputs[-1]
|
||||
psp_outs = [x]
|
||||
psp_outs.extend(self.psp_modules(x))
|
||||
psp_outs = torch.cat(psp_outs, dim=1)
|
||||
output = self.bottleneck(psp_outs)
|
||||
|
||||
return output
|
||||
|
||||
def _forward_feature(self, inputs):
|
||||
"""Forward function for feature maps before classifying each pixel with
|
||||
``self.cls_seg`` fc.
|
||||
|
||||
Args:
|
||||
inputs (list[Tensor]): List of multi-level img features.
|
||||
|
||||
Returns:
|
||||
feats (Tensor): A tensor of shape (batch_size, self.channels,
|
||||
H, W) which is feature map for last layer of decoder head.
|
||||
"""
|
||||
inputs = self._transform_inputs(inputs)
|
||||
|
||||
# build laterals
|
||||
laterals = [
|
||||
lateral_conv(inputs[i])
|
||||
for i, lateral_conv in enumerate(self.lateral_convs)
|
||||
]
|
||||
|
||||
laterals.append(self.psp_forward(inputs))
|
||||
|
||||
# build top-down path
|
||||
used_backbone_levels = len(laterals)
|
||||
for i in range(used_backbone_levels - 1, 0, -1):
|
||||
prev_shape = laterals[i - 1].shape[2:]
|
||||
laterals[i - 1] = laterals[i - 1] + resize(
|
||||
laterals[i],
|
||||
size=prev_shape,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
|
||||
# build outputs
|
||||
fpn_outs = [
|
||||
self.fpn_convs[i](laterals[i])
|
||||
for i in range(used_backbone_levels - 1)
|
||||
]
|
||||
# append psp feature
|
||||
fpn_outs.append(laterals[-1])
|
||||
|
||||
for i in range(used_backbone_levels - 1, 0, -1):
|
||||
fpn_outs[i] = resize(
|
||||
fpn_outs[i],
|
||||
size=fpn_outs[0].shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
fpn_outs = torch.cat(fpn_outs, dim=1)
|
||||
feats = self.fpn_bottleneck(fpn_outs)
|
||||
return feats
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
output = self._forward_feature(inputs)
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
253
finetune/mmseg/models/decode_heads/vpd_depth_head.py
Normal file
@@ -0,0 +1,253 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, List, Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import build_conv_layer, build_norm_layer, build_upsample_layer
|
||||
from mmengine.model import BaseModule
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import SampleList
|
||||
from ..utils import resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class VPDDepthDecoder(BaseModule):
|
||||
"""VPD Depth Decoder class.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
num_deconv_layers (int): Number of deconvolution layers.
|
||||
num_deconv_filters (List[int]): List of output channels for
|
||||
deconvolution layers.
|
||||
init_cfg (Optional[Union[Dict, List[Dict]]], optional): Configuration
|
||||
for weight initialization. Defaults to Normal for Conv2d and
|
||||
ConvTranspose2d layers.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_deconv_layers: int,
|
||||
num_deconv_filters: List[int],
|
||||
init_cfg: Optional[Union[Dict, List[Dict]]] = dict(
|
||||
type='Normal',
|
||||
std=0.001,
|
||||
layer=['Conv2d', 'ConvTranspose2d'])):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.deconv_layers = self._make_deconv_layer(
|
||||
num_deconv_layers,
|
||||
num_deconv_filters,
|
||||
)
|
||||
|
||||
conv_layers = []
|
||||
conv_layers.append(
|
||||
build_conv_layer(
|
||||
dict(type='Conv2d'),
|
||||
in_channels=num_deconv_filters[-1],
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1))
|
||||
conv_layers.append(build_norm_layer(dict(type='BN'), out_channels)[1])
|
||||
conv_layers.append(nn.ReLU(inplace=True))
|
||||
self.conv_layers = nn.Sequential(*conv_layers)
|
||||
|
||||
self.up_sample = nn.Upsample(
|
||||
scale_factor=2, mode='bilinear', align_corners=False)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass through the decoder network."""
|
||||
out = self.deconv_layers(x)
|
||||
out = self.conv_layers(out)
|
||||
|
||||
out = self.up_sample(out)
|
||||
out = self.up_sample(out)
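# Two successive 2x bilinear upsamples add a 4x factor on top of the
# 2x-per-layer upsampling of the deconvolution stack.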
|
||||
|
||||
return out
|
||||
|
||||
def _make_deconv_layer(self, num_layers, num_deconv_filters):
|
||||
"""Make deconv layers."""
|
||||
|
||||
layers = []
|
||||
in_channels = self.in_channels
|
||||
for i in range(num_layers):
|
||||
|
||||
num_channels = num_deconv_filters[i]
|
||||
layers.append(
|
||||
build_upsample_layer(
|
||||
dict(type='deconv'),
|
||||
in_channels=in_channels,
|
||||
out_channels=num_channels,
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
padding=0,
|
||||
output_padding=0,
|
||||
bias=False))
|
||||
layers.append(nn.BatchNorm2d(num_channels))
|
||||
layers.append(nn.ReLU(inplace=True))
|
||||
in_channels = num_channels
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class VPDDepthHead(BaseDecodeHead):
|
||||
"""Depth Prediction Head for VPD.
|
||||
|
||||
.. _`VPD`: https://arxiv.org/abs/2303.02153
|
||||
|
||||
Args:
|
||||
max_depth (float): Maximum depth value. Defaults to 10.0.
|
||||
in_channels (Sequence[int]): Number of input channels for each
|
||||
convolutional layer.
|
||||
embed_dim (int): Dimension of embedding. Defaults to 192.
|
||||
feature_dim (int): Dimension of aggregated feature. Defaults to 1536.
|
||||
num_deconv_layers (int): Number of deconvolution layers in the
|
||||
decoder. Defaults to 3.
|
||||
num_deconv_filters (Sequence[int]): Number of filters for each deconv
|
||||
layer. Defaults to (32, 32, 32).
|
||||
fmap_border (Union[int, Sequence[int]]): Feature map border for
|
||||
cropping. Defaults to 0.
|
||||
align_corners (bool): Flag for align_corners in interpolation.
|
||||
Defaults to False.
|
||||
loss_decode (dict): Configurations for the loss function. Defaults to
|
||||
dict(type='SiLogLoss').
|
||||
init_cfg (dict): Initialization configurations. Defaults to
|
||||
dict(type='TruncNormal', std=0.02, layer=['Conv2d', 'Linear']).
|
||||
"""
|
||||
|
||||
num_classes = 1
|
||||
out_channels = 1
|
||||
input_transform = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_depth: float = 10.0,
|
||||
in_channels: Sequence[int] = [320, 640, 1280, 1280],
|
||||
embed_dim: int = 192,
|
||||
feature_dim: int = 1536,
|
||||
num_deconv_layers: int = 3,
|
||||
num_deconv_filters: Sequence[int] = (32, 32, 32),
|
||||
fmap_border: Union[int, Sequence[int]] = 0,
|
||||
align_corners: bool = False,
|
||||
loss_decode: dict = dict(type='SiLogLoss'),
|
||||
init_cfg=dict(
|
||||
type='TruncNormal', std=0.02, layer=['Conv2d', 'Linear']),
|
||||
):
|
||||
|
||||
super(BaseDecodeHead, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
# initialize parameters
|
||||
self.in_channels = in_channels
|
||||
self.max_depth = max_depth
|
||||
self.align_corners = align_corners
|
||||
|
||||
# feature map border
|
||||
if isinstance(fmap_border, int):
|
||||
fmap_border = (fmap_border, fmap_border)
|
||||
self.fmap_border = fmap_border
|
||||
|
||||
# define network layers
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv2d(in_channels[0], in_channels[0], 3, stride=2, padding=1),
|
||||
nn.GroupNorm(16, in_channels[0]),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels[0], in_channels[0], 3, stride=2, padding=1),
|
||||
)
|
||||
self.conv2 = nn.Conv2d(
|
||||
in_channels[1], in_channels[1], 3, stride=2, padding=1)
|
||||
|
||||
self.conv_aggregation = nn.Sequential(
|
||||
nn.Conv2d(sum(in_channels), feature_dim, 1),
|
||||
nn.GroupNorm(16, feature_dim),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
self.decoder = VPDDepthDecoder(
|
||||
in_channels=embed_dim * 8,
|
||||
out_channels=embed_dim,
|
||||
num_deconv_layers=num_deconv_layers,
|
||||
num_deconv_filters=num_deconv_filters)
|
||||
|
||||
self.depth_pred_layer = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
embed_dim, embed_dim, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(embed_dim, 1, kernel_size=3, stride=1, padding=1))
|
||||
|
||||
# build loss
|
||||
if isinstance(loss_decode, dict):
|
||||
self.loss_decode = MODELS.build(loss_decode)
|
||||
elif isinstance(loss_decode, (list, tuple)):
|
||||
self.loss_decode = nn.ModuleList()
|
||||
for loss in loss_decode:
|
||||
self.loss_decode.append(MODELS.build(loss))
|
||||
else:
|
||||
raise TypeError(f'loss_decode must be a dict or sequence of dict,\
|
||||
but got {type(loss_decode)}')
|
||||
|
||||
def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor:
|
||||
gt_depth_maps = [
|
||||
data_sample.gt_depth_map.data for data_sample in batch_data_samples
|
||||
]
|
||||
return torch.stack(gt_depth_maps, dim=0)
|
||||
|
||||
def forward(self, x):
|
||||
x = [
|
||||
x[0], x[1],
|
||||
torch.cat([x[2], F.interpolate(x[3], scale_factor=2)], dim=1)
|
||||
]
|
||||
x = torch.cat([self.conv1(x[0]), self.conv2(x[1]), x[2]], dim=1)
|
||||
x = self.conv_aggregation(x)
|
||||
|
||||
x = x[:, :, :x.size(2) - self.fmap_border[0], :x.size(3) -
|
||||
self.fmap_border[1]].contiguous()
|
||||
x = self.decoder(x)
|
||||
out = self.depth_pred_layer(x)
|
||||
|
||||
depth = torch.sigmoid(out) * self.max_depth
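# sigmoid squashes the prediction to [0, 1]; scaling by max_depth yields
# metric depth in [0, max_depth].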
|
||||
|
||||
return depth
|
||||
|
||||
def loss_by_feat(self, pred_depth_map: Tensor,
|
||||
batch_data_samples: SampleList) -> dict:
|
||||
"""Compute depth estimation loss.
|
||||
|
||||
Args:
|
||||
pred_depth_map (Tensor): The output from decode head forward
|
||||
function.
|
||||
batch_data_samples (List[:obj:`SegDataSample`]): The seg
|
||||
data samples. It usually includes information such
|
||||
as `metainfo` and `gt_depth_map`.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: a dictionary of loss components
|
||||
"""
|
||||
|
||||
gt_depth_map = self._stack_batch_gt(batch_data_samples)
|
||||
loss = dict()
|
||||
pred_depth_map = resize(
|
||||
input=pred_depth_map,
|
||||
size=gt_depth_map.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
|
||||
if not isinstance(self.loss_decode, nn.ModuleList):
|
||||
losses_decode = [self.loss_decode]
|
||||
else:
|
||||
losses_decode = self.loss_decode
|
||||
for loss_decode in losses_decode:
|
||||
if loss_decode.loss_name not in loss:
|
||||
loss[loss_decode.loss_name] = loss_decode(
|
||||
pred_depth_map, gt_depth_map)
|
||||
else:
|
||||
loss[loss_decode.loss_name] += loss_decode(
|
||||
pred_depth_map, gt_depth_map)
|
||||
|
||||
return loss
|
||||
21
finetune/mmseg/models/losses/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .accuracy import Accuracy, accuracy
|
||||
from .boundary_loss import BoundaryLoss
|
||||
from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
|
||||
cross_entropy, mask_cross_entropy)
|
||||
from .dice_loss import DiceLoss
|
||||
from .focal_loss import FocalLoss
|
||||
from .huasdorff_distance_loss import HuasdorffDisstanceLoss
|
||||
from .lovasz_loss import LovaszLoss
|
||||
from .ohem_cross_entropy_loss import OhemCrossEntropy
|
||||
from .silog_loss import SiLogLoss
|
||||
from .tversky_loss import TverskyLoss
|
||||
from .utils import reduce_loss, weight_reduce_loss, weighted_loss
|
||||
|
||||
__all__ = [
|
||||
'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
|
||||
'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
|
||||
'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss',
|
||||
'FocalLoss', 'TverskyLoss', 'OhemCrossEntropy', 'BoundaryLoss',
|
||||
'HuasdorffDisstanceLoss', 'SiLogLoss'
|
||||
]
|
||||
92
finetune/mmseg/models/losses/accuracy.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def accuracy(pred, target, topk=1, thresh=None, ignore_index=None):
|
||||
"""Calculate accuracy according to the prediction and target.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The model prediction, shape (N, num_class, ...)
|
||||
target (torch.Tensor): The target of each prediction, shape (N, ...)
|
||||
ignore_index (int | None): The label index to be ignored. Default: None
|
||||
topk (int | tuple[int], optional): If the predictions in ``topk``
|
||||
matches the target, the predictions will be regarded as
|
||||
correct ones. Defaults to 1.
|
||||
thresh (float, optional): If not None, predictions with scores under
|
||||
this threshold are considered incorrect. Default to None.
|
||||
|
||||
Returns:
|
||||
float | tuple[float]: If the input ``topk`` is a single integer,
|
||||
the function will return a single float as accuracy. If
|
||||
``topk`` is a tuple containing multiple integers, the
|
||||
function will return a tuple containing accuracies of
|
||||
each ``topk`` number.
|
||||
"""
|
||||
assert isinstance(topk, (int, tuple))
|
||||
if isinstance(topk, int):
|
||||
topk = (topk, )
|
||||
return_single = True
|
||||
else:
|
||||
return_single = False
|
||||
|
||||
maxk = max(topk)
|
||||
if pred.size(0) == 0:
|
||||
accu = [pred.new_tensor(0.) for i in range(len(topk))]
|
||||
return accu[0] if return_single else accu
|
||||
assert pred.ndim == target.ndim + 1
|
||||
assert pred.size(0) == target.size(0)
|
||||
assert maxk <= pred.size(1), \
|
||||
f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
|
||||
pred_value, pred_label = pred.topk(maxk, dim=1)
|
||||
# transpose to shape (maxk, N, ...)
|
||||
pred_label = pred_label.transpose(0, 1)
|
||||
correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label))
|
||||
if thresh is not None:
|
||||
# Only prediction values larger than thresh are counted as correct
|
||||
correct = correct & (pred_value > thresh).t()
|
||||
if ignore_index is not None:
|
||||
correct = correct[:, target != ignore_index]
|
||||
res = []
|
||||
eps = torch.finfo(torch.float32).eps
|
||||
for k in topk:
|
||||
# Avoid causing ZeroDivisionError when all pixels
|
||||
# of an image are ignored
|
||||
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + eps
|
||||
if ignore_index is not None:
|
||||
total_num = target[target != ignore_index].numel() + eps
|
||||
else:
|
||||
total_num = target.numel() + eps
|
||||
res.append(correct_k.mul_(100.0 / total_num))
|
||||
return res[0] if return_single else res
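# Illustrative usage: for pred of shape (N, C, H, W) and target (N, H, W),
# accuracy(pred, target, topk=(1, 5)) returns the top-1 and top-5 pixel
# accuracies in percent.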
|
||||
|
||||
|
||||
class Accuracy(nn.Module):
|
||||
"""Accuracy calculation module."""
|
||||
|
||||
def __init__(self, topk=(1, ), thresh=None, ignore_index=None):
|
||||
"""Module to calculate the accuracy.
|
||||
|
||||
Args:
|
||||
topk (tuple, optional): The criterion used to calculate the
|
||||
accuracy. Defaults to (1,).
|
||||
thresh (float, optional): If not None, predictions with scores
|
||||
under this threshold are considered incorrect. Default to None.
|
||||
"""
|
||||
super().__init__()
|
||||
self.topk = topk
|
||||
self.thresh = thresh
|
||||
self.ignore_index = ignore_index
|
||||
|
||||
def forward(self, pred, target):
|
||||
"""Forward function to calculate accuracy.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): Prediction of models.
|
||||
target (torch.Tensor): Target for each prediction.
|
||||
|
||||
Returns:
|
||||
tuple[float]: The accuracies under different topk criterions.
|
||||
"""
|
||||
return accuracy(pred, target, self.topk, self.thresh,
|
||||
self.ignore_index)
|
||||
62
finetune/mmseg/models/losses/boundary_loss.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class BoundaryLoss(nn.Module):
|
||||
"""Boundary loss.
|
||||
|
||||
This function is modified from
|
||||
`PIDNet <https://github.com/XuJiacong/PIDNet/blob/main/utils/criterion.py#L122>`_. # noqa
|
||||
Licensed under the MIT License.
|
||||
|
||||
|
||||
Args:
|
||||
loss_weight (float): Weight of the loss. Defaults to 1.0.
|
||||
loss_name (str): Name of the loss item. If you want this loss
|
||||
item to be included into the backward graph, `loss_` must be the
|
||||
prefix of the name. Defaults to 'loss_boundary'.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
loss_weight: float = 1.0,
|
||||
loss_name: str = 'loss_boundary'):
|
||||
super().__init__()
|
||||
self.loss_weight = loss_weight
|
||||
self.loss_name_ = loss_name
|
||||
|
||||
def forward(self, bd_pre: Tensor, bd_gt: Tensor) -> Tensor:
|
||||
"""Forward function.
|
||||
Args:
|
||||
bd_pre (Tensor): Predictions of the boundary head.
|
||||
bd_gt (Tensor): Ground truth of the boundary.
|
||||
|
||||
Returns:
|
||||
Tensor: Loss tensor.
|
||||
"""
|
||||
log_p = bd_pre.permute(0, 2, 3, 1).contiguous().view(1, -1)
|
||||
target_t = bd_gt.view(1, -1).float()
|
||||
|
||||
pos_index = (target_t == 1)
|
||||
neg_index = (target_t == 0)
|
||||
|
||||
weight = torch.zeros_like(log_p)
|
||||
pos_num = pos_index.sum()
|
||||
neg_num = neg_index.sum()
|
||||
sum_num = pos_num + neg_num
|
||||
weight[pos_index] = neg_num * 1.0 / sum_num
|
||||
weight[neg_index] = pos_num * 1.0 / sum_num
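# Class-balanced BCE weights: positive (boundary) pixels are weighted by the
# negative fraction and negatives by the positive fraction, so both classes
# contribute comparably despite boundary sparsity.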
|
||||
|
||||
loss = F.binary_cross_entropy_with_logits(
|
||||
log_p, target_t, weight, reduction='mean')
|
||||
|
||||
return self.loss_weight * loss
|
||||
|
||||
@property
|
||||
def loss_name(self):
|
||||
return self.loss_name_
|
||||
311
finetune/mmseg/models/losses/cross_entropy_loss.py
Normal file
@@ -0,0 +1,311 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .utils import get_class_weight, weight_reduce_loss
|
||||
|
||||
|
||||
def cross_entropy(pred,
|
||||
label,
|
||||
weight=None,
|
||||
class_weight=None,
|
||||
reduction='mean',
|
||||
avg_factor=None,
|
||||
ignore_index=-100,
|
||||
avg_non_ignore=False):
|
||||
"""cross_entropy. The wrapper function for :func:`F.cross_entropy`
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The prediction with shape (N, C) or (N, C, H, W).
|
||||
label (torch.Tensor): The learning label of the prediction.
|
||||
weight (torch.Tensor, optional): Sample-wise loss weight.
|
||||
Default: None.
|
||||
class_weight (list[float], optional): The weight for each class.
|
||||
Default: None.
|
||||
reduction (str, optional): The method used to reduce the loss.
|
||||
Options are 'none', 'mean' and 'sum'. Default: 'mean'.
|
||||
avg_factor (int, optional): Average factor that is used to average
|
||||
the loss. Default: None.
|
||||
ignore_index (int): Specifies a target value that is ignored and
|
||||
does not contribute to the input gradients. When
|
||||
``avg_non_ignore`` is ``True`` and the ``reduction`` is
``'mean'``, the loss is averaged over non-ignored targets.
|
||||
Defaults: -100.
|
||||
avg_non_ignore (bool): Whether the loss is only averaged over
    non-ignored targets. Default: False.
|
||||
`New in version 0.23.0.`
|
||||
"""
|
||||
|
||||
# class_weight is a manual rescaling weight given to each class.
|
||||
# If given, has to be a Tensor of size C element-wise losses
|
||||
loss = F.cross_entropy(
|
||||
pred,
|
||||
label,
|
||||
weight=class_weight,
|
||||
reduction='none',
|
||||
ignore_index=ignore_index)
|
||||
|
||||
# apply weights and do the reduction
|
||||
# average loss over non-ignored elements
|
||||
# pytorch's official cross_entropy average loss over non-ignored elements
|
||||
# refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
|
||||
if (avg_factor is None) and reduction == 'mean':
|
||||
if class_weight is None:
|
||||
if avg_non_ignore:
|
||||
avg_factor = label.numel() - (label
|
||||
== ignore_index).sum().item()
|
||||
else:
|
||||
avg_factor = label.numel()
|
||||
|
||||
else:
|
||||
# the average factor should take the class weights into account
|
||||
label_weights = torch.stack([class_weight[cls] for cls in label
|
||||
]).to(device=class_weight.device)
|
||||
|
||||
if avg_non_ignore:
|
||||
label_weights[label == ignore_index] = 0
|
||||
avg_factor = label_weights.sum()
|
||||
|
||||
if weight is not None:
|
||||
weight = weight.float()
|
||||
loss = weight_reduce_loss(
|
||||
loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
|
||||
"""Expand onehot labels to match the size of prediction."""
|
||||
bin_labels = labels.new_zeros(target_shape)
|
||||
valid_mask = (labels >= 0) & (labels != ignore_index)
|
||||
inds = torch.nonzero(valid_mask, as_tuple=True)
|
||||
|
||||
if inds[0].numel() > 0:
|
||||
if labels.dim() == 3:
|
||||
bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
|
||||
else:
|
||||
bin_labels[inds[0], labels[valid_mask]] = 1
|
||||
|
||||
valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
|
||||
|
||||
if label_weights is None:
|
||||
bin_label_weights = valid_mask
|
||||
else:
|
||||
bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
|
||||
bin_label_weights = bin_label_weights * valid_mask
|
||||
|
||||
return bin_labels, bin_label_weights, valid_mask
|
||||
|
||||
|
||||
def binary_cross_entropy(pred,
|
||||
label,
|
||||
weight=None,
|
||||
reduction='mean',
|
||||
avg_factor=None,
|
||||
class_weight=None,
|
||||
ignore_index=-100,
|
||||
avg_non_ignore=False,
|
||||
**kwargs):
|
||||
"""Calculate the binary CrossEntropy loss.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The prediction with shape (N, 1).
|
||||
label (torch.Tensor): The learning label of the prediction.
|
||||
Note: In bce loss, label < 0 is invalid.
|
||||
weight (torch.Tensor, optional): Sample-wise loss weight.
|
||||
reduction (str, optional): The method used to reduce the loss.
|
||||
Options are "none", "mean" and "sum".
|
||||
avg_factor (int, optional): Average factor that is used to average
|
||||
the loss. Defaults to None.
|
||||
class_weight (list[float], optional): The weight for each class.
|
||||
ignore_index (int): The label index to be ignored. Default: -100.
|
||||
        avg_non_ignore (bool): If ``True``, the loss is only averaged over
            non-ignored targets. Default: False.
            `New in version 0.23.0.`
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The calculated loss
|
||||
"""
|
||||
if pred.size(1) == 1:
|
||||
# For binary class segmentation, the shape of pred is
|
||||
# [N, 1, H, W] and that of label is [N, H, W].
|
||||
        # Since ignore_index is often set to 255, the binary class label
        # check below should mask out ignore_index first.
|
||||
assert label[label != ignore_index].max() <= 1, \
|
||||
'For pred with shape [N, 1, H, W], its label must have at ' \
|
||||
'most 2 classes'
|
||||
pred = pred.squeeze(1)
|
||||
if pred.dim() != label.dim():
|
||||
assert (pred.dim() == 2 and label.dim() == 1) or (
|
||||
pred.dim() == 4 and label.dim() == 3), \
|
||||
'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
|
||||
'H, W], label shape [N, H, W] are supported'
|
||||
# `weight` returned from `_expand_onehot_labels`
|
||||
# has been treated for valid (non-ignore) pixels
|
||||
label, weight, valid_mask = _expand_onehot_labels(
|
||||
label, weight, pred.shape, ignore_index)
|
||||
else:
|
||||
# should mask out the ignored elements
|
||||
valid_mask = ((label >= 0) & (label != ignore_index)).float()
|
||||
if weight is not None:
|
||||
weight = weight * valid_mask
|
||||
else:
|
||||
weight = valid_mask
|
||||
# average loss over non-ignored and valid elements
|
||||
if reduction == 'mean' and avg_factor is None and avg_non_ignore:
|
||||
avg_factor = valid_mask.sum().item()
|
||||
|
||||
loss = F.binary_cross_entropy_with_logits(
|
||||
pred, label.float(), pos_weight=class_weight, reduction='none')
|
||||
# do the reduction for the weighted loss
|
||||
loss = weight_reduce_loss(
|
||||
loss, weight, reduction=reduction, avg_factor=avg_factor)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
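# Illustrative sketch (not part of the original file): the two input layouts
# accepted by ``binary_cross_entropy`` above. The shapes below are made up
# for demonstration only.
def _demo_binary_cross_entropy_shapes():
    import torch

    # Case 1: binary segmentation, pred [N, 1, H, W] vs. label [N, H, W]
    pred = torch.randn(2, 1, 4, 4)
    label = torch.randint(0, 2, (2, 4, 4))
    loss_a = binary_cross_entropy(pred, label, ignore_index=255)

    # Case 2: per-class sigmoid, pred [N, C, H, W] vs. index label [N, H, W];
    # the label is expanded to one-hot internally.
    pred = torch.randn(2, 3, 4, 4)
    label = torch.randint(0, 3, (2, 4, 4))
    loss_b = binary_cross_entropy(pred, label, ignore_index=255)
    return loss_a, loss_b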
def mask_cross_entropy(pred,
|
||||
target,
|
||||
label,
|
||||
reduction='mean',
|
||||
avg_factor=None,
|
||||
class_weight=None,
|
||||
ignore_index=None,
|
||||
**kwargs):
|
||||
"""Calculate the CrossEntropy loss for masks.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The prediction with shape (N, C), C is the number
|
||||
of classes.
|
||||
target (torch.Tensor): The learning label of the prediction.
|
||||
        label (torch.Tensor): ``label`` indicates the class label of the
            mask's corresponding object. It will be used to select the mask
            of the class which the object belongs to when the mask
            prediction is not class-agnostic.
|
||||
reduction (str, optional): The method used to reduce the loss.
|
||||
Options are "none", "mean" and "sum".
|
||||
avg_factor (int, optional): Average factor that is used to average
|
||||
the loss. Defaults to None.
|
||||
class_weight (list[float], optional): The weight for each class.
|
||||
ignore_index (None): Placeholder, to be consistent with other loss.
|
||||
Default: None.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The calculated loss
|
||||
"""
|
||||
assert ignore_index is None, 'BCE loss does not support ignore_index'
|
||||
# TODO: handle these two reserved arguments
|
||||
assert reduction == 'mean' and avg_factor is None
|
||||
num_rois = pred.size()[0]
|
||||
inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
|
||||
pred_slice = pred[inds, label].squeeze(1)
|
||||
return F.binary_cross_entropy_with_logits(
|
||||
pred_slice, target, weight=class_weight, reduction='mean')[None]
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class CrossEntropyLoss(nn.Module):
|
||||
"""CrossEntropyLoss.
|
||||
|
||||
Args:
|
||||
        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
            instead of softmax. Defaults to False.
        use_mask (bool, optional): Whether to use mask cross entropy loss.
            Defaults to False.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to 'mean'.
        class_weight (list[float] | str, optional): Weight of each class. If in
            str format, read them from a file. Defaults to None.
        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
        loss_name (str, optional): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_ce'.
        avg_non_ignore (bool): If ``True``, the loss is only averaged over
            non-ignored targets. Default: False.
            `New in version 0.23.0.`
    """
|
||||
|
||||
def __init__(self,
|
||||
use_sigmoid=False,
|
||||
use_mask=False,
|
||||
reduction='mean',
|
||||
class_weight=None,
|
||||
loss_weight=1.0,
|
||||
loss_name='loss_ce',
|
||||
avg_non_ignore=False):
|
||||
super().__init__()
|
||||
assert (use_sigmoid is False) or (use_mask is False)
|
||||
self.use_sigmoid = use_sigmoid
|
||||
self.use_mask = use_mask
|
||||
self.reduction = reduction
|
||||
self.loss_weight = loss_weight
|
||||
self.class_weight = get_class_weight(class_weight)
|
||||
self.avg_non_ignore = avg_non_ignore
|
||||
if not self.avg_non_ignore and self.reduction == 'mean':
|
||||
            warnings.warn(
                'Default ``avg_non_ignore`` is False. If you would like to '
                'ignore certain labels and average the loss only over '
                'non-ignored labels, which matches the behavior of the '
                'official PyTorch cross_entropy, set ``avg_non_ignore=True``.')
|
||||
|
||||
if self.use_sigmoid:
|
||||
self.cls_criterion = binary_cross_entropy
|
||||
elif self.use_mask:
|
||||
self.cls_criterion = mask_cross_entropy
|
||||
else:
|
||||
self.cls_criterion = cross_entropy
|
||||
self._loss_name = loss_name
|
||||
|
||||
def extra_repr(self):
|
||||
"""Extra repr."""
|
||||
s = f'avg_non_ignore={self.avg_non_ignore}'
|
||||
return s
|
||||
|
||||
def forward(self,
|
||||
cls_score,
|
||||
label,
|
||||
weight=None,
|
||||
avg_factor=None,
|
||||
reduction_override=None,
|
||||
ignore_index=-100,
|
||||
**kwargs):
|
||||
"""Forward function."""
|
||||
assert reduction_override in (None, 'none', 'mean', 'sum')
|
||||
reduction = (
|
||||
reduction_override if reduction_override else self.reduction)
|
||||
if self.class_weight is not None:
|
||||
class_weight = cls_score.new_tensor(self.class_weight)
|
||||
else:
|
||||
class_weight = None
|
||||
# Note: for BCE loss, label < 0 is invalid.
|
||||
loss_cls = self.loss_weight * self.cls_criterion(
|
||||
cls_score,
|
||||
label,
|
||||
weight,
|
||||
class_weight=class_weight,
|
||||
reduction=reduction,
|
||||
avg_factor=avg_factor,
|
||||
avg_non_ignore=self.avg_non_ignore,
|
||||
ignore_index=ignore_index,
|
||||
**kwargs)
|
||||
return loss_cls
|
||||
|
||||
@property
|
||||
def loss_name(self):
|
||||
"""Loss Name.
|
||||
|
||||
This function must be implemented and will return the name of this
|
||||
loss function. This name will be used to combine different loss items
|
||||
by simple sum operation. In addition, if you want this loss item to be
|
||||
included into the backward graph, `loss_` must be the prefix of the
|
||||
name.
|
||||
|
||||
Returns:
|
||||
str: The name of this loss item.
|
||||
"""
|
||||
return self._loss_name
|
||||
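# Illustrative sketch (not part of the original file): typical usage of
# ``CrossEntropyLoss``. The dict form in the trailing comment is how it
# usually appears inside a decode-head config; the surrounding config keys
# are assumptions, and the shapes below are made up for demonstration only.
def _demo_cross_entropy_loss_usage():
    import torch

    # softmax CE over 4 classes, averaged over non-ignored pixels
    loss_fn = CrossEntropyLoss(use_sigmoid=False, avg_non_ignore=True)
    cls_score = torch.randn(2, 4, 8, 8)
    label = torch.randint(0, 4, (2, 8, 8))
    loss = loss_fn(cls_score, label, ignore_index=255)

    # equivalent dict form, e.g. as ``loss_decode`` in a decode head:
    # loss_decode = dict(
    #     type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0,
    #     avg_non_ignore=True)
    return loss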
202
finetune/mmseg/models/losses/dice_loss.py
Normal file
202
finetune/mmseg/models/losses/dice_loss.py
Normal file
@@ -0,0 +1,202 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .utils import weight_reduce_loss
|
||||
|
||||
|
||||
def _expand_onehot_labels_dice(pred: torch.Tensor,
|
||||
target: torch.Tensor) -> torch.Tensor:
|
||||
"""Expand onehot labels to match the size of prediction.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The prediction, has a shape (N, num_class, H, W).
|
||||
target (torch.Tensor): The learning label of the prediction,
|
||||
has a shape (N, H, W).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The target after one-hot encoding,
|
||||
has a shape (N, num_class, H, W).
|
||||
"""
|
||||
num_classes = pred.shape[1]
|
||||
one_hot_target = torch.clamp(target, min=0, max=num_classes)
|
||||
one_hot_target = torch.nn.functional.one_hot(one_hot_target,
|
||||
num_classes + 1)
|
||||
one_hot_target = one_hot_target[..., :num_classes].permute(0, 3, 1, 2)
|
||||
return one_hot_target
|
||||
|
||||
|
||||
def dice_loss(pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
weight: Union[torch.Tensor, None],
|
||||
eps: float = 1e-3,
|
||||
reduction: Union[str, None] = 'mean',
|
||||
naive_dice: Union[bool, None] = False,
|
||||
avg_factor: Union[int, None] = None,
|
||||
ignore_index: Union[int, None] = 255) -> float:
|
||||
"""Calculate dice loss, there are two forms of dice loss is supported:
|
||||
|
||||
- the one proposed in `V-Net: Fully Convolutional Neural
|
||||
Networks for Volumetric Medical Image Segmentation
|
||||
<https://arxiv.org/abs/1606.04797>`_.
|
||||
- the dice loss in which the power of the number in the
|
||||
denominator is the first power instead of the second
|
||||
power.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The prediction, has a shape (n, *)
|
||||
target (torch.Tensor): The learning label of the prediction,
|
||||
shape (n, *), same shape of pred.
|
||||
weight (torch.Tensor, optional): The weight of loss for each
|
||||
prediction, has a shape (n,). Defaults to None.
|
||||
eps (float): Avoid dividing by zero. Default: 1e-3.
|
||||
reduction (str, optional): The method used to reduce the loss into
|
||||
a scalar. Defaults to 'mean'.
|
||||
Options are "none", "mean" and "sum".
|
||||
        naive_dice (bool, optional): If False, use the dice loss defined in
            the V-Net paper; otherwise, use the naive dice loss in which the
            terms in the denominator are raised to the first power instead
            of the second power. Defaults to False.
|
||||
avg_factor (int, optional): Average factor that is used to average
|
||||
the loss. Defaults to None.
|
||||
ignore_index (int, optional): The label index to be ignored.
|
||||
Defaults to 255.
|
||||
"""
|
||||
if ignore_index is not None:
|
||||
num_classes = pred.shape[1]
|
||||
pred = pred[:, torch.arange(num_classes) != ignore_index, :, :]
|
||||
target = target[:, torch.arange(num_classes) != ignore_index, :, :]
|
||||
assert pred.shape[1] != 0 # if the ignored index is the only class
|
||||
input = pred.flatten(1)
|
||||
target = target.flatten(1).float()
|
||||
a = torch.sum(input * target, 1)
|
||||
if naive_dice:
|
||||
b = torch.sum(input, 1)
|
||||
c = torch.sum(target, 1)
|
||||
d = (2 * a + eps) / (b + c + eps)
|
||||
else:
|
||||
b = torch.sum(input * input, 1) + eps
|
||||
c = torch.sum(target * target, 1) + eps
|
||||
d = (2 * a) / (b + c)
|
||||
|
||||
loss = 1 - d
|
||||
if weight is not None:
|
||||
assert weight.ndim == loss.ndim
|
||||
assert len(weight) == len(pred)
|
||||
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
|
||||
return loss
|
||||
|
||||
|
||||
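# Illustrative sketch (not part of the original file): the dice coefficient
# that ``dice_loss`` above is built on, written out for a single flattened
# prediction/target pair. The values below are made up for demonstration.
def _demo_dice_formula():
    import torch

    pred = torch.tensor([[0.9, 0.1, 0.8, 0.2]])   # probabilities, shape (1, 4)
    target = torch.tensor([[1., 0., 1., 0.]])
    eps = 1e-3

    a = (pred * target).sum(1)                     # intersection
    b = (pred * pred).sum(1) + eps
    c = (target * target).sum(1) + eps
    vnet_dice = 2 * a / (b + c)                    # naive_dice=False
    naive_dice = (2 * a + eps) / (pred.sum(1) + target.sum(1) + eps)
    return 1 - vnet_dice, 1 - naive_dice           # the corresponding losses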
@MODELS.register_module()
|
||||
class DiceLoss(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
use_sigmoid=True,
|
||||
activate=True,
|
||||
reduction='mean',
|
||||
naive_dice=False,
|
||||
loss_weight=1.0,
|
||||
ignore_index=255,
|
||||
eps=1e-3,
|
||||
loss_name='loss_dice'):
|
||||
"""Compute dice loss.
|
||||
|
||||
Args:
|
||||
            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
                instead of softmax. Defaults to True.
            activate (bool): Whether to apply the activation (sigmoid or
                softmax) to the predictions inside the loss. If False, the
                predictions are assumed to be probabilities already.
                Defaults to True.
|
||||
reduction (str, optional): The method used
|
||||
to reduce the loss. Options are "none",
|
||||
"mean" and "sum". Defaults to 'mean'.
|
||||
naive_dice (bool, optional): If false, use the dice
|
||||
loss defined in the V-Net paper, otherwise, use the
|
||||
naive dice loss in which the power of the number in the
|
||||
denominator is the first power instead of the second
|
||||
power. Defaults to False.
|
||||
loss_weight (float, optional): Weight of loss. Defaults to 1.0.
|
||||
ignore_index (int, optional): The label index to be ignored.
|
||||
Default: 255.
|
||||
eps (float): Avoid dividing by zero. Defaults to 1e-3.
|
||||
loss_name (str, optional): Name of the loss item. If you want this
|
||||
loss item to be included into the backward graph, `loss_` must
|
||||
be the prefix of the name. Defaults to 'loss_dice'.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.use_sigmoid = use_sigmoid
|
||||
self.reduction = reduction
|
||||
self.naive_dice = naive_dice
|
||||
self.loss_weight = loss_weight
|
||||
self.eps = eps
|
||||
self.activate = activate
|
||||
self.ignore_index = ignore_index
|
||||
self._loss_name = loss_name
|
||||
|
||||
def forward(self,
|
||||
pred,
|
||||
target,
|
||||
weight=None,
|
||||
avg_factor=None,
|
||||
reduction_override=None,
|
||||
ignore_index=255,
|
||||
**kwargs):
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The prediction, has a shape (n, *).
|
||||
target (torch.Tensor): The label of the prediction,
|
||||
shape (n, *), same shape of pred.
|
||||
weight (torch.Tensor, optional): The weight of loss for each
|
||||
prediction, has a shape (n,). Defaults to None.
|
||||
avg_factor (int, optional): Average factor that is used to average
|
||||
the loss. Defaults to None.
|
||||
reduction_override (str, optional): The reduction method used to
|
||||
override the original reduction method of the loss.
|
||||
Options are "none", "mean" and "sum".
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The calculated loss
|
||||
"""
|
||||
one_hot_target = target
|
||||
if (pred.shape != target.shape):
|
||||
one_hot_target = _expand_onehot_labels_dice(pred, target)
|
||||
assert reduction_override in (None, 'none', 'mean', 'sum')
|
||||
reduction = (
|
||||
reduction_override if reduction_override else self.reduction)
|
||||
if self.activate:
|
||||
if self.use_sigmoid:
|
||||
pred = pred.sigmoid()
|
||||
elif pred.shape[1] != 1:
|
||||
# softmax does not work when there is only 1 class
|
||||
pred = pred.softmax(dim=1)
|
||||
loss = self.loss_weight * dice_loss(
|
||||
pred,
|
||||
one_hot_target,
|
||||
weight,
|
||||
eps=self.eps,
|
||||
reduction=reduction,
|
||||
naive_dice=self.naive_dice,
|
||||
avg_factor=avg_factor,
|
||||
ignore_index=self.ignore_index)
|
||||
|
||||
return loss
|
||||
|
||||
@property
|
||||
def loss_name(self):
|
||||
"""Loss Name.
|
||||
|
||||
This function must be implemented and will return the name of this
|
||||
loss function. This name will be used to combine different loss items
|
||||
by simple sum operation. In addition, if you want this loss item to be
|
||||
included into the backward graph, `loss_` must be the prefix of the
|
||||
name.
|
||||
Returns:
|
||||
str: The name of this loss item.
|
||||
"""
|
||||
return self._loss_name
|
||||
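# Illustrative sketch (not part of the original file): driving ``DiceLoss``
# directly on logits. The one-hot expansion and softmax happen inside
# ``forward`` when ``activate=True``. The shapes below are made up for
# demonstration only.
def _demo_dice_loss_usage():
    import torch

    loss_fn = DiceLoss(use_sigmoid=False, activate=True, ignore_index=255)
    pred = torch.randn(2, 4, 8, 8)              # logits for 4 classes
    target = torch.randint(0, 4, (2, 8, 8))     # index labels
    return loss_fn(pred, target)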
337
finetune/mmseg/models/losses/focal_loss.py
Normal file
337
finetune/mmseg/models/losses/focal_loss.py
Normal file
@@ -0,0 +1,337 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# Modified from https://github.com/open-mmlab/mmdetection
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .utils import weight_reduce_loss
|
||||
|
||||
|
||||
# This method is used when cuda is not available
|
||||
def py_sigmoid_focal_loss(pred,
|
||||
target,
|
||||
one_hot_target=None,
|
||||
weight=None,
|
||||
gamma=2.0,
|
||||
alpha=0.5,
|
||||
class_weight=None,
|
||||
valid_mask=None,
|
||||
reduction='mean',
|
||||
avg_factor=None):
|
||||
"""PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The prediction with shape (N, C), C is the
|
||||
number of classes
|
||||
target (torch.Tensor): The learning label of the prediction with
|
||||
shape (N, C)
|
||||
one_hot_target (None): Placeholder. It should be None.
|
||||
weight (torch.Tensor, optional): Sample-wise loss weight.
|
||||
gamma (float, optional): The gamma for calculating the modulating
|
||||
factor. Defaults to 2.0.
|
||||
alpha (float | list[float], optional): A balanced form for Focal Loss.
|
||||
Defaults to 0.5.
|
||||
class_weight (list[float], optional): Weight of each class.
|
||||
Defaults to None.
|
||||
        valid_mask (torch.Tensor, optional): A mask in which 1 marks valid
            samples and 0 marks ignored samples. Default: None.
|
||||
reduction (str, optional): The method used to reduce the loss into
|
||||
a scalar. Defaults to 'mean'.
|
||||
avg_factor (int, optional): Average factor that is used to average
|
||||
the loss. Defaults to None.
|
||||
"""
|
||||
if isinstance(alpha, list):
|
||||
alpha = pred.new_tensor(alpha)
|
||||
pred_sigmoid = pred.sigmoid()
|
||||
target = target.type_as(pred)
|
||||
one_minus_pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
|
||||
focal_weight = (alpha * target + (1 - alpha) *
|
||||
(1 - target)) * one_minus_pt.pow(gamma)
|
||||
|
||||
loss = F.binary_cross_entropy_with_logits(
|
||||
pred, target, reduction='none') * focal_weight
|
||||
final_weight = torch.ones(1, pred.size(1)).type_as(loss)
|
||||
if weight is not None:
|
||||
if weight.shape != loss.shape and weight.size(0) == loss.size(0):
|
||||
# For most cases, weight is of shape (N, ),
|
||||
# which means it does not have the second axis num_class
|
||||
weight = weight.view(-1, 1)
|
||||
assert weight.dim() == loss.dim()
|
||||
final_weight = final_weight * weight
|
||||
if class_weight is not None:
|
||||
final_weight = final_weight * pred.new_tensor(class_weight)
|
||||
if valid_mask is not None:
|
||||
final_weight = final_weight * valid_mask
|
||||
loss = weight_reduce_loss(loss, final_weight, reduction, avg_factor)
|
||||
return loss
|
||||
|
||||
|
||||
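# Illustrative sketch (not part of the original file): the focal modulating
# factor used in ``py_sigmoid_focal_loss`` above. With gamma=2 a confident
# correct prediction contributes far less than a hard one. Values are made up.
def _demo_focal_weight():
    import torch

    pred = torch.tensor([[3.0], [-0.1]])    # logits for one foreground class
    target = torch.tensor([[1.0], [1.0]])   # both samples are positives
    gamma, alpha = 2.0, 0.5

    p = pred.sigmoid()
    one_minus_pt = (1 - p) * target + p * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * one_minus_pt.pow(gamma)
    # focal_weight[0] << focal_weight[1]: the easy example is down-weighted
    return focal_weight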
def sigmoid_focal_loss(pred,
|
||||
target,
|
||||
one_hot_target,
|
||||
weight=None,
|
||||
gamma=2.0,
|
||||
alpha=0.5,
|
||||
class_weight=None,
|
||||
valid_mask=None,
|
||||
reduction='mean',
|
||||
avg_factor=None):
|
||||
r"""A wrapper of cuda version `Focal Loss
|
||||
<https://arxiv.org/abs/1708.02002>`_.
|
||||
Args:
|
||||
pred (torch.Tensor): The prediction with shape (N, C), C is the number
|
||||
of classes.
|
||||
        target (torch.Tensor): The learning label of the prediction. Its
            shape should be (N, ).
|
||||
one_hot_target (torch.Tensor): The learning label with shape (N, C)
|
||||
weight (torch.Tensor, optional): Sample-wise loss weight.
|
||||
gamma (float, optional): The gamma for calculating the modulating
|
||||
factor. Defaults to 2.0.
|
||||
alpha (float | list[float], optional): A balanced form for Focal Loss.
|
||||
Defaults to 0.5.
|
||||
class_weight (list[float], optional): Weight of each class.
|
||||
Defaults to None.
|
||||
        valid_mask (torch.Tensor, optional): A mask in which 1 marks valid
            samples and 0 marks ignored samples. Default: None.
|
||||
reduction (str, optional): The method used to reduce the loss into
|
||||
a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum".
|
||||
avg_factor (int, optional): Average factor that is used to average
|
||||
the loss. Defaults to None.
|
||||
"""
|
||||
# Function.apply does not accept keyword arguments, so the decorator
|
||||
# "weighted_loss" is not applicable
|
||||
final_weight = torch.ones(1, pred.size(1)).type_as(pred)
|
||||
if isinstance(alpha, list):
|
||||
# _sigmoid_focal_loss doesn't accept alpha of list type. Therefore, if
|
||||
# a list is given, we set the input alpha as 0.5. This means setting
|
||||
# equal weight for foreground class and background class. By
|
||||
# multiplying the loss by 2, the effect of setting alpha as 0.5 is
|
||||
# undone. The alpha of type list is used to regulate the loss in the
|
||||
# post-processing process.
|
||||
loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(),
|
||||
gamma, 0.5, None, 'none') * 2
|
||||
alpha = pred.new_tensor(alpha)
|
||||
final_weight = final_weight * (
|
||||
alpha * one_hot_target + (1 - alpha) * (1 - one_hot_target))
|
||||
else:
|
||||
loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(),
|
||||
gamma, alpha, None, 'none')
|
||||
if weight is not None:
|
||||
if weight.shape != loss.shape and weight.size(0) == loss.size(0):
|
||||
# For most cases, weight is of shape (N, ),
|
||||
# which means it does not have the second axis num_class
|
||||
weight = weight.view(-1, 1)
|
||||
assert weight.dim() == loss.dim()
|
||||
final_weight = final_weight * weight
|
||||
if class_weight is not None:
|
||||
final_weight = final_weight * pred.new_tensor(class_weight)
|
||||
if valid_mask is not None:
|
||||
final_weight = final_weight * valid_mask
|
||||
loss = weight_reduce_loss(loss, final_weight, reduction, avg_factor)
|
||||
return loss
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class FocalLoss(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
use_sigmoid=True,
|
||||
gamma=2.0,
|
||||
alpha=0.5,
|
||||
reduction='mean',
|
||||
class_weight=None,
|
||||
loss_weight=1.0,
|
||||
loss_name='loss_focal'):
|
||||
"""`Focal Loss <https://arxiv.org/abs/1708.02002>`_
|
||||
Args:
|
||||
            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
                instead of softmax. Defaults to True.
|
||||
gamma (float, optional): The gamma for calculating the modulating
|
||||
factor. Defaults to 2.0.
|
||||
alpha (float | list[float], optional): A balanced form for Focal
|
||||
Loss. Defaults to 0.5. When a list is provided, the length
|
||||
of the list should be equal to the number of classes.
|
||||
Please be careful that this parameter is not the
|
||||
class-wise weight but the weight of a binary classification
|
||||
problem. This binary classification problem regards the
|
||||
pixels which belong to one class as the foreground
|
||||
and the other pixels as the background, each element in
|
||||
the list is the weight of the corresponding foreground class.
|
||||
The value of alpha or each element of alpha should be a float
|
||||
in the interval [0, 1]. If you want to specify the class-wise
|
||||
weight, please use `class_weight` parameter.
|
||||
reduction (str, optional): The method used to reduce the loss into
|
||||
a scalar. Defaults to 'mean'. Options are "none", "mean" and
|
||||
"sum".
|
||||
class_weight (list[float], optional): Weight of each class.
|
||||
Defaults to None.
|
||||
loss_weight (float, optional): Weight of loss. Defaults to 1.0.
|
||||
loss_name (str, optional): Name of the loss item. If you want this
|
||||
loss item to be included into the backward graph, `loss_` must
|
||||
be the prefix of the name. Defaults to 'loss_focal'.
|
||||
"""
|
||||
super().__init__()
|
||||
assert use_sigmoid is True, \
|
||||
'AssertionError: Only sigmoid focal loss supported now.'
|
||||
assert reduction in ('none', 'mean', 'sum'), \
|
||||
"AssertionError: reduction should be 'none', 'mean' or " \
|
||||
"'sum'"
|
||||
        assert isinstance(alpha, (float, list)), \
            'AssertionError: alpha should be of type float or list'
|
||||
assert isinstance(gamma, float), \
|
||||
'AssertionError: gamma should be of type float'
|
||||
assert isinstance(loss_weight, float), \
|
||||
'AssertionError: loss_weight should be of type float'
|
||||
assert isinstance(loss_name, str), \
|
||||
'AssertionError: loss_name should be of type str'
|
||||
assert isinstance(class_weight, list) or class_weight is None, \
|
||||
'AssertionError: class_weight must be None or of type list'
|
||||
self.use_sigmoid = use_sigmoid
|
||||
self.gamma = gamma
|
||||
self.alpha = alpha
|
||||
self.reduction = reduction
|
||||
self.class_weight = class_weight
|
||||
self.loss_weight = loss_weight
|
||||
self._loss_name = loss_name
|
||||
|
||||
def forward(self,
|
||||
pred,
|
||||
target,
|
||||
weight=None,
|
||||
avg_factor=None,
|
||||
reduction_override=None,
|
||||
ignore_index=255,
|
||||
**kwargs):
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The prediction with shape
|
||||
(N, C) where C = number of classes, or
|
||||
(N, C, d_1, d_2, ..., d_K) with K≥1 in the
|
||||
case of K-dimensional loss.
|
||||
target (torch.Tensor): The ground truth. If containing class
|
||||
indices, shape (N) where each value is 0≤targets[i]≤C−1,
|
||||
or (N, d_1, d_2, ..., d_K) with K≥1 in the case of
|
||||
K-dimensional loss. If containing class probabilities,
|
||||
same shape as the input.
|
||||
weight (torch.Tensor, optional): The weight of loss for each
|
||||
prediction. Defaults to None.
|
||||
avg_factor (int, optional): Average factor that is used to
|
||||
average the loss. Defaults to None.
|
||||
reduction_override (str, optional): The reduction method used
|
||||
to override the original reduction method of the loss.
|
||||
Options are "none", "mean" and "sum".
|
||||
ignore_index (int, optional): The label index to be ignored.
|
||||
Default: 255
|
||||
Returns:
|
||||
torch.Tensor: The calculated loss
|
||||
"""
|
||||
assert isinstance(ignore_index, int), \
|
||||
'ignore_index must be of type int'
|
||||
assert reduction_override in (None, 'none', 'mean', 'sum'), \
|
||||
"AssertionError: reduction should be 'none', 'mean' or " \
|
||||
"'sum'"
|
||||
assert pred.shape == target.shape or \
|
||||
(pred.size(0) == target.size(0) and
|
||||
pred.shape[2:] == target.shape[1:]), \
|
||||
"The shape of pred doesn't match the shape of target"
|
||||
|
||||
original_shape = pred.shape
|
||||
|
||||
# [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k]
|
||||
pred = pred.transpose(0, 1)
|
||||
# [C, B, d_1, d_2, ..., d_k] -> [C, N]
|
||||
pred = pred.reshape(pred.size(0), -1)
|
||||
# [C, N] -> [N, C]
|
||||
pred = pred.transpose(0, 1).contiguous()
|
||||
|
||||
if original_shape == target.shape:
|
||||
# target with shape [B, C, d_1, d_2, ...]
|
||||
            # transform its shape into [N, C]
|
||||
# [B, C, d_1, d_2, ...] -> [C, B, d_1, d_2, ..., d_k]
|
||||
target = target.transpose(0, 1)
|
||||
# [C, B, d_1, d_2, ..., d_k] -> [C, N]
|
||||
target = target.reshape(target.size(0), -1)
|
||||
# [C, N] -> [N, C]
|
||||
target = target.transpose(0, 1).contiguous()
|
||||
else:
|
||||
# target with shape [B, d_1, d_2, ...]
|
||||
            # transform its shape into [N, ]
|
||||
target = target.view(-1).contiguous()
|
||||
valid_mask = (target != ignore_index).view(-1, 1)
|
||||
# avoid raising error when using F.one_hot()
|
||||
target = torch.where(target == ignore_index, target.new_tensor(0),
|
||||
target)
|
||||
|
||||
reduction = (
|
||||
reduction_override if reduction_override else self.reduction)
|
||||
if self.use_sigmoid:
|
||||
num_classes = pred.size(1)
|
||||
if torch.cuda.is_available() and pred.is_cuda:
|
||||
if target.dim() == 1:
|
||||
one_hot_target = F.one_hot(
|
||||
target, num_classes=num_classes + 1)
|
||||
if num_classes == 1:
|
||||
one_hot_target = one_hot_target[:, 1]
|
||||
target = 1 - target
|
||||
else:
|
||||
one_hot_target = one_hot_target[:, :num_classes]
|
||||
else:
|
||||
one_hot_target = target
|
||||
target = target.argmax(dim=1)
|
||||
valid_mask = (target != ignore_index).view(-1, 1)
|
||||
calculate_loss_func = sigmoid_focal_loss
|
||||
else:
|
||||
one_hot_target = None
|
||||
if target.dim() == 1:
|
||||
target = F.one_hot(target, num_classes=num_classes + 1)
|
||||
if num_classes == 1:
|
||||
target = target[:, 1]
|
||||
else:
|
||||
                        target = target[:, :num_classes]
|
||||
else:
|
||||
valid_mask = (target.argmax(dim=1) != ignore_index).view(
|
||||
-1, 1)
|
||||
calculate_loss_func = py_sigmoid_focal_loss
|
||||
|
||||
loss_cls = self.loss_weight * calculate_loss_func(
|
||||
pred,
|
||||
target,
|
||||
one_hot_target,
|
||||
weight,
|
||||
gamma=self.gamma,
|
||||
alpha=self.alpha,
|
||||
class_weight=self.class_weight,
|
||||
valid_mask=valid_mask,
|
||||
reduction=reduction,
|
||||
avg_factor=avg_factor)
|
||||
|
||||
if reduction == 'none':
|
||||
# [N, C] -> [C, N]
|
||||
loss_cls = loss_cls.transpose(0, 1)
|
||||
# [C, N] -> [C, B, d1, d2, ...]
|
||||
# original_shape: [B, C, d1, d2, ...]
|
||||
loss_cls = loss_cls.reshape(original_shape[1],
|
||||
original_shape[0],
|
||||
*original_shape[2:])
|
||||
# [C, B, d1, d2, ...] -> [B, C, d1, d2, ...]
|
||||
loss_cls = loss_cls.transpose(0, 1).contiguous()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return loss_cls
|
||||
|
||||
@property
|
||||
def loss_name(self):
|
||||
"""Loss Name.
|
||||
|
||||
This function must be implemented and will return the name of this
|
||||
loss function. This name will be used to combine different loss items
|
||||
by simple sum operation. In addition, if you want this loss item to be
|
||||
included into the backward graph, `loss_` must be the prefix of the
|
||||
name.
|
||||
Returns:
|
||||
str: The name of this loss item.
|
||||
"""
|
||||
return self._loss_name
|
||||
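# Illustrative sketch (not part of the original file): minimal usage of
# ``FocalLoss`` on CPU, which falls back to ``py_sigmoid_focal_loss``.
# The shapes and values below are made up for demonstration only.
def _demo_focal_loss_usage():
    import torch

    loss_fn = FocalLoss(gamma=2.0, alpha=0.5, loss_weight=1.0)
    pred = torch.randn(2, 4, 8, 8)              # logits for 4 classes
    target = torch.randint(0, 4, (2, 8, 8))     # index labels
    return loss_fn(pred, target, ignore_index=255)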
160
finetune/mmseg/models/losses/huasdorff_distance_loss.py
Normal file
160
finetune/mmseg/models/losses/huasdorff_distance_loss.py
Normal file
@@ -0,0 +1,160 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
"""Modified from https://github.com/JunMa11/SegWithDistMap/blob/
|
||||
master/code/train_LA_HD.py (Apache-2.0 License)"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from scipy.ndimage import distance_transform_edt as distance
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .utils import get_class_weight, weighted_loss
|
||||
|
||||
|
||||
def compute_dtm(img_gt: Tensor, pred: Tensor) -> Tensor:
|
||||
"""
|
||||
compute the distance transform map of foreground in mask
|
||||
Args:
|
||||
img_gt: Ground truth of the image, (b, h, w)
|
||||
pred: Predictions of the segmentation head after softmax, (b, c, h, w)
|
||||
|
||||
Returns:
|
||||
        output: the foreground distance transform map (DTM), where
            dtm(x) = 0 if x lies on the segmentation boundary, and
            dtm(x) = inf_{y in boundary} |x - y| if x lies inside the
            segmentation
    """
|
||||
|
||||
fg_dtm = torch.zeros_like(pred)
|
||||
out_shape = pred.shape
|
||||
for b in range(out_shape[0]): # batch size
|
||||
for c in range(1, out_shape[1]): # default 0 channel is background
|
||||
posmask = img_gt[b].byte()
|
||||
if posmask.any():
|
||||
posdis = distance(posmask)
|
||||
fg_dtm[b][c] = torch.from_numpy(posdis)
|
||||
|
||||
return fg_dtm
|
||||
|
||||
|
||||
@weighted_loss
|
||||
def hd_loss(seg_soft: Tensor,
|
||||
gt: Tensor,
|
||||
seg_dtm: Tensor,
|
||||
gt_dtm: Tensor,
|
||||
class_weight=None,
|
||||
ignore_index=255) -> Tensor:
|
||||
"""
|
||||
    Compute the Hausdorff distance loss for segmentation.
|
||||
Args:
|
||||
seg_soft: softmax results, shape=(b,c,x,y)
|
||||
gt: ground truth, shape=(b,x,y)
|
||||
seg_dtm: segmentation distance transform map, shape=(b,c,x,y)
|
||||
gt_dtm: ground truth distance transform map, shape=(b,c,x,y)
|
||||
|
||||
Returns:
|
||||
output: hd_loss
|
||||
"""
|
||||
assert seg_soft.shape[0] == gt.shape[0]
|
||||
total_loss = 0
|
||||
num_class = seg_soft.shape[1]
|
||||
if class_weight is not None:
|
||||
assert class_weight.ndim == num_class
|
||||
for i in range(1, num_class):
|
||||
if i != ignore_index:
|
||||
delta_s = (seg_soft[:, i, ...] - gt.float())**2
|
||||
s_dtm = seg_dtm[:, i, ...]**2
|
||||
g_dtm = gt_dtm[:, i, ...]**2
|
||||
dtm = s_dtm + g_dtm
|
||||
multiplied = torch.einsum('bxy, bxy->bxy', delta_s, dtm)
|
||||
hd_loss = multiplied.mean()
|
||||
if class_weight is not None:
|
||||
hd_loss *= class_weight[i]
|
||||
total_loss += hd_loss
|
||||
|
||||
return total_loss / num_class
|
||||
|
||||
|
||||
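# Illustrative sketch (not part of the original file): what
# ``distance_transform_edt`` contributes inside ``compute_dtm``. Each
# foreground pixel gets its Euclidean distance to the nearest background
# pixel. The tiny mask below is made up for demonstration only.
def _demo_distance_transform():
    import numpy as np
    from scipy.ndimage import distance_transform_edt as distance

    mask = np.array([[0, 0, 0, 0],
                     [0, 1, 1, 0],
                     [0, 1, 1, 0],
                     [0, 0, 0, 0]])
    dtm = distance(mask)
    # dtm is 0 on the background and 1.0 on the foreground pixels here,
    # because every foreground pixel touches the background.
    return dtm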
@MODELS.register_module()
|
||||
class HuasdorffDisstanceLoss(nn.Module):
|
||||
"""HuasdorffDisstanceLoss. This loss is proposed in `How Distance Transform
|
||||
Maps Boost Segmentation CNNs: An Empirical Study.
|
||||
|
||||
<http://proceedings.mlr.press/v121/ma20b.html>`_.
|
||||
Args:
|
||||
reduction (str, optional): The method used to reduce the loss into
|
||||
a scalar. Defaults to 'mean'.
|
||||
class_weight (list[float] | str, optional): Weight of each class. If in
|
||||
str format, read them from a file. Defaults to None.
|
||||
loss_weight (float): Weight of the loss. Defaults to 1.0.
|
||||
ignore_index (int | None): The label index to be ignored. Default: 255.
|
||||
        loss_name (str): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_huasdorff_disstance'.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
reduction='mean',
|
||||
class_weight=None,
|
||||
loss_weight=1.0,
|
||||
ignore_index=255,
|
||||
loss_name='loss_huasdorff_disstance',
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
self.reduction = reduction
|
||||
self.loss_weight = loss_weight
|
||||
self.class_weight = get_class_weight(class_weight)
|
||||
self._loss_name = loss_name
|
||||
self.ignore_index = ignore_index
|
||||
|
||||
def forward(self,
|
||||
pred: Tensor,
|
||||
target: Tensor,
|
||||
avg_factor=None,
|
||||
reduction_override=None,
|
||||
**kwargs) -> Tensor:
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
pred (Tensor): Predictions of the segmentation head. (B, C, H, W)
|
||||
target (Tensor): Ground truth of the image. (B, H, W)
|
||||
avg_factor (int, optional): Average factor that is used to
|
||||
average the loss. Defaults to None.
|
||||
reduction_override (str, optional): The reduction method used
|
||||
to override the original reduction method of the loss.
|
||||
Options are "none", "mean" and "sum".
|
||||
Returns:
|
||||
Tensor: Loss tensor.
|
||||
"""
|
||||
assert reduction_override in (None, 'none', 'mean', 'sum')
|
||||
reduction = (
|
||||
reduction_override if reduction_override else self.reduction)
|
||||
if self.class_weight is not None:
|
||||
class_weight = pred.new_tensor(self.class_weight)
|
||||
else:
|
||||
class_weight = None
|
||||
|
||||
pred_soft = F.softmax(pred, dim=1)
|
||||
valid_mask = (target != self.ignore_index).long()
|
||||
target = target * valid_mask
|
||||
|
||||
with torch.no_grad():
|
||||
gt_dtm = compute_dtm(target.cpu(), pred_soft)
|
||||
gt_dtm = gt_dtm.float()
|
||||
seg_dtm2 = compute_dtm(
|
||||
pred_soft.argmax(dim=1, keepdim=False).cpu(), pred_soft)
|
||||
seg_dtm2 = seg_dtm2.float()
|
||||
|
||||
loss_hd = self.loss_weight * hd_loss(
|
||||
pred_soft,
|
||||
target,
|
||||
seg_dtm=seg_dtm2,
|
||||
gt_dtm=gt_dtm,
|
||||
reduction=reduction,
|
||||
avg_factor=avg_factor,
|
||||
class_weight=class_weight,
|
||||
ignore_index=self.ignore_index)
|
||||
return loss_hd
|
||||
|
||||
@property
|
||||
def loss_name(self):
|
||||
return self._loss_name
|
||||
99
finetune/mmseg/models/losses/kldiv_loss.py
Normal file
99
finetune/mmseg/models/losses/kldiv_loss.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class KLDivLoss(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
temperature: float = 1.0,
|
||||
reduction: str = 'mean',
|
||||
loss_name: str = 'loss_kld'):
|
||||
"""Kullback-Leibler divergence Loss.
|
||||
|
||||
<https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>
|
||||
|
||||
Args:
|
||||
            temperature (float, optional): Temperature used to soften the
                distributions before computing the KL divergence.
                Default: 1.0.
|
||||
reduction (str, optional): The method to reduce the loss into a
|
||||
scalar. Default is "mean". Options are "none", "sum",
|
||||
and "mean"
|
||||
"""
|
||||
|
||||
assert isinstance(temperature, (float, int)), \
|
||||
            'Expected temperature to be ' \
|
||||
f'float or int, but got {temperature.__class__.__name__} instead'
|
||||
assert temperature != 0., 'Temperature must not be zero'
|
||||
|
||||
assert reduction in ['mean', 'none', 'sum'], \
|
||||
'Reduction must be one of the options ("mean", ' \
|
||||
f'"sum", "none"), but got {reduction}'
|
||||
|
||||
super().__init__()
|
||||
self.temperature = temperature
|
||||
self.reduction = reduction
|
||||
self._loss_name = loss_name
|
||||
|
||||
def forward(self, input: torch.Tensor, target: torch.Tensor):
|
||||
"""Forward function. Calculate KL divergence Loss.
|
||||
|
||||
Args:
|
||||
input (Tensor): Logit tensor,
|
||||
the data type is float32 or float64.
|
||||
The shape is (N, C) where N is batchsize and C is number of
|
||||
channels.
|
||||
                If there are more than 2 dimensions, the shape is
                (N, C, D1, D2, ..., Dk), k >= 1.
|
||||
target (Tensor): Logit tensor,
|
||||
the data type is float32 or float64.
|
||||
input and target must be with the same shape.
|
||||
|
||||
Returns:
|
||||
(Tensor): Reduced loss.
|
||||
"""
|
||||
        assert isinstance(input, torch.Tensor), 'Expected input to ' \
            f'be Tensor, but got {input.__class__.__name__} instead'
        assert isinstance(target, torch.Tensor), 'Expected target to ' \
            f'be Tensor, but got {target.__class__.__name__} instead'

        assert input.shape == target.shape, 'Input and target ' \
            'must have same shape, ' \
            f'but got shapes {input.shape} and {target.shape}'
|
||||
|
||||
        input = F.log_softmax(input / self.temperature, dim=1)
|
||||
target = F.softmax(target / self.temperature, dim=1)
|
||||
|
||||
loss = F.kl_div(input, target, reduction='none', log_target=False)
|
||||
loss = loss * self.temperature**2
|
||||
|
||||
batch_size = input.shape[0]
|
||||
|
||||
if self.reduction == 'sum':
|
||||
# Change view to calculate instance-wise sum
|
||||
loss = loss.view(batch_size, -1)
|
||||
return torch.sum(loss, dim=1)
|
||||
|
||||
elif self.reduction == 'mean':
|
||||
# Change view to calculate instance-wise mean
|
||||
loss = loss.view(batch_size, -1)
|
||||
return torch.mean(loss, dim=1)
|
||||
|
||||
return loss
|
||||
|
||||
@property
|
||||
def loss_name(self):
|
||||
"""Loss Name.
|
||||
|
||||
This function must be implemented and will return the name of this
|
||||
loss function. This name will be used to combine different loss items
|
||||
by simple sum operation. In addition, if you want this loss item to be
|
||||
included into the backward graph, `loss_` must be the prefix of the
|
||||
name.
|
||||
Returns:
|
||||
str: The name of this loss item.
|
||||
"""
|
||||
return self._loss_name
|
||||
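# Illustrative sketch (not part of the original file): minimal usage of
# ``KLDivLoss`` for distilling one set of logits towards another. The shapes
# and temperature below are made up for demonstration only.
def _demo_kldiv_loss_usage():
    import torch

    loss_fn = KLDivLoss(temperature=2.0, reduction='mean')
    student_logits = torch.randn(4, 10)
    teacher_logits = torch.randn(4, 10)
    # returns one value per instance, since the reduction is applied over
    # the class dimension only
    return loss_fn(student_logits, teacher_logits)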
323
finetune/mmseg/models/losses/lovasz_loss.py
Normal file
323
finetune/mmseg/models/losses/lovasz_loss.py
Normal file
@@ -0,0 +1,323 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
"""Modified from https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytor
|
||||
ch/lovasz_losses.py Lovasz-Softmax and Jaccard hinge loss in PyTorch Maxim
|
||||
Berman 2018 ESAT-PSI KU Leuven (MIT License)"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmengine.utils import is_list_of
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .utils import get_class_weight, weight_reduce_loss
|
||||
|
||||
|
||||
def lovasz_grad(gt_sorted):
|
||||
"""Computes gradient of the Lovasz extension w.r.t sorted errors.
|
||||
|
||||
See Alg. 1 in paper.
|
||||
"""
|
||||
p = len(gt_sorted)
|
||||
gts = gt_sorted.sum()
|
||||
intersection = gts - gt_sorted.float().cumsum(0)
|
||||
union = gts + (1 - gt_sorted).float().cumsum(0)
|
||||
jaccard = 1. - intersection / union
|
||||
if p > 1: # cover 1-pixel case
|
||||
jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
|
||||
return jaccard
|
||||
|
||||
|
||||
def flatten_binary_logits(logits, labels, ignore_index=None):
|
||||
"""Flattens predictions in the batch (binary case) Remove labels equal to
|
||||
'ignore_index'."""
|
||||
logits = logits.view(-1)
|
||||
labels = labels.view(-1)
|
||||
if ignore_index is None:
|
||||
return logits, labels
|
||||
valid = (labels != ignore_index)
|
||||
vlogits = logits[valid]
|
||||
vlabels = labels[valid]
|
||||
return vlogits, vlabels
|
||||
|
||||
|
||||
def flatten_probs(probs, labels, ignore_index=None):
|
||||
"""Flattens predictions in the batch."""
|
||||
if probs.dim() == 3:
|
||||
# assumes output of a sigmoid layer
|
||||
B, H, W = probs.size()
|
||||
probs = probs.view(B, 1, H, W)
|
||||
B, C, H, W = probs.size()
|
||||
probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, C) # B*H*W, C=P,C
|
||||
labels = labels.view(-1)
|
||||
if ignore_index is None:
|
||||
return probs, labels
|
||||
valid = (labels != ignore_index)
|
||||
vprobs = probs[valid.nonzero().squeeze()]
|
||||
vlabels = labels[valid]
|
||||
return vprobs, vlabels
|
||||
|
||||
|
||||
def lovasz_hinge_flat(logits, labels):
|
||||
"""Binary Lovasz hinge loss.
|
||||
|
||||
Args:
|
||||
logits (torch.Tensor): [P], logits at each prediction
|
||||
(between -infty and +infty).
|
||||
labels (torch.Tensor): [P], binary ground truth labels (0 or 1).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The calculated loss.
|
||||
"""
|
||||
if len(labels) == 0:
|
||||
# only void pixels, the gradients should be 0
|
||||
return logits.sum() * 0.
|
||||
signs = 2. * labels.float() - 1.
|
||||
errors = (1. - logits * signs)
|
||||
errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
|
||||
perm = perm.data
|
||||
gt_sorted = labels[perm]
|
||||
grad = lovasz_grad(gt_sorted)
|
||||
loss = torch.dot(F.relu(errors_sorted), grad)
|
||||
return loss
|
||||
|
||||
|
||||
def lovasz_hinge(logits,
|
||||
labels,
|
||||
classes='present',
|
||||
per_image=False,
|
||||
class_weight=None,
|
||||
reduction='mean',
|
||||
avg_factor=None,
|
||||
ignore_index=255):
|
||||
"""Binary Lovasz hinge loss.
|
||||
|
||||
Args:
|
||||
logits (torch.Tensor): [B, H, W], logits at each pixel
|
||||
(between -infty and +infty).
|
||||
labels (torch.Tensor): [B, H, W], binary ground truth masks (0 or 1).
|
||||
classes (str | list[int], optional): Placeholder, to be consistent with
|
||||
other loss. Default: None.
|
||||
per_image (bool, optional): If per_image is True, compute the loss per
|
||||
image instead of per batch. Default: False.
|
||||
class_weight (list[float], optional): Placeholder, to be consistent
|
||||
with other loss. Default: None.
|
||||
reduction (str, optional): The method used to reduce the loss. Options
|
||||
are "none", "mean" and "sum". This parameter only works when
|
||||
per_image is True. Default: 'mean'.
|
||||
avg_factor (int, optional): Average factor that is used to average
|
||||
the loss. This parameter only works when per_image is True.
|
||||
Default: None.
|
||||
ignore_index (int | None): The label index to be ignored. Default: 255.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The calculated loss.
|
||||
"""
|
||||
if per_image:
|
||||
loss = [
|
||||
lovasz_hinge_flat(*flatten_binary_logits(
|
||||
logit.unsqueeze(0), label.unsqueeze(0), ignore_index))
|
||||
for logit, label in zip(logits, labels)
|
||||
]
|
||||
loss = weight_reduce_loss(
|
||||
torch.stack(loss), None, reduction, avg_factor)
|
||||
else:
|
||||
loss = lovasz_hinge_flat(
|
||||
*flatten_binary_logits(logits, labels, ignore_index))
|
||||
return loss
|
||||
|
||||
|
||||
def lovasz_softmax_flat(probs, labels, classes='present', class_weight=None):
|
||||
"""Multi-class Lovasz-Softmax loss.
|
||||
|
||||
Args:
|
||||
probs (torch.Tensor): [P, C], class probabilities at each prediction
|
||||
(between 0 and 1).
|
||||
labels (torch.Tensor): [P], ground truth labels (between 0 and C - 1).
|
||||
classes (str | list[int], optional): Classes chosen to calculate loss.
|
||||
'all' for all classes, 'present' for classes present in labels, or
|
||||
a list of classes to average. Default: 'present'.
|
||||
class_weight (list[float], optional): The weight for each class.
|
||||
Default: None.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The calculated loss.
|
||||
"""
|
||||
if probs.numel() == 0:
|
||||
# only void pixels, the gradients should be 0
|
||||
return probs * 0.
|
||||
C = probs.size(1)
|
||||
losses = []
|
||||
class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
|
||||
for c in class_to_sum:
|
||||
fg = (labels == c).float() # foreground for class c
|
||||
if (classes == 'present' and fg.sum() == 0):
|
||||
continue
|
||||
if C == 1:
|
||||
if len(classes) > 1:
|
||||
raise ValueError('Sigmoid output possible only with 1 class')
|
||||
class_pred = probs[:, 0]
|
||||
else:
|
||||
class_pred = probs[:, c]
|
||||
errors = (fg - class_pred).abs()
|
||||
errors_sorted, perm = torch.sort(errors, 0, descending=True)
|
||||
perm = perm.data
|
||||
fg_sorted = fg[perm]
|
||||
loss = torch.dot(errors_sorted, lovasz_grad(fg_sorted))
|
||||
if class_weight is not None:
|
||||
loss *= class_weight[c]
|
||||
losses.append(loss)
|
||||
return torch.stack(losses).mean()
|
||||
|
||||
|
||||
def lovasz_softmax(probs,
|
||||
labels,
|
||||
classes='present',
|
||||
per_image=False,
|
||||
class_weight=None,
|
||||
reduction='mean',
|
||||
avg_factor=None,
|
||||
ignore_index=255):
|
||||
"""Multi-class Lovasz-Softmax loss.
|
||||
|
||||
Args:
|
||||
probs (torch.Tensor): [B, C, H, W], class probabilities at each
|
||||
prediction (between 0 and 1).
|
||||
labels (torch.Tensor): [B, H, W], ground truth labels (between 0 and
|
||||
C - 1).
|
||||
classes (str | list[int], optional): Classes chosen to calculate loss.
|
||||
'all' for all classes, 'present' for classes present in labels, or
|
||||
a list of classes to average. Default: 'present'.
|
||||
per_image (bool, optional): If per_image is True, compute the loss per
|
||||
image instead of per batch. Default: False.
|
||||
class_weight (list[float], optional): The weight for each class.
|
||||
Default: None.
|
||||
reduction (str, optional): The method used to reduce the loss. Options
|
||||
are "none", "mean" and "sum". This parameter only works when
|
||||
per_image is True. Default: 'mean'.
|
||||
avg_factor (int, optional): Average factor that is used to average
|
||||
the loss. This parameter only works when per_image is True.
|
||||
Default: None.
|
||||
ignore_index (int | None): The label index to be ignored. Default: 255.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The calculated loss.
|
||||
"""
|
||||
|
||||
if per_image:
|
||||
loss = [
|
||||
lovasz_softmax_flat(
|
||||
*flatten_probs(
|
||||
prob.unsqueeze(0), label.unsqueeze(0), ignore_index),
|
||||
classes=classes,
|
||||
class_weight=class_weight)
|
||||
for prob, label in zip(probs, labels)
|
||||
]
|
||||
loss = weight_reduce_loss(
|
||||
torch.stack(loss), None, reduction, avg_factor)
|
||||
else:
|
||||
loss = lovasz_softmax_flat(
|
||||
*flatten_probs(probs, labels, ignore_index),
|
||||
classes=classes,
|
||||
class_weight=class_weight)
|
||||
return loss
|
||||
|
||||
|
||||
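# Illustrative sketch (not part of the original file): the Jaccard-extension
# gradient from ``lovasz_grad`` on a tiny sorted ground-truth vector. The
# values below are made up for demonstration only.
def _demo_lovasz_grad():
    import torch

    # ground-truth labels sorted by descending prediction error
    gt_sorted = torch.tensor([1, 0, 1, 0])
    grad = lovasz_grad(gt_sorted)
    # grad holds per-position weights derived from the Jaccard index; the
    # dot product of these weights with the sorted errors is the Lovasz loss.
    return grad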
@MODELS.register_module()
|
||||
class LovaszLoss(nn.Module):
|
||||
"""LovaszLoss.
|
||||
|
||||
This loss is proposed in `The Lovasz-Softmax loss: A tractable surrogate
|
||||
for the optimization of the intersection-over-union measure in neural
|
||||
networks <https://arxiv.org/abs/1705.08790>`_.
|
||||
|
||||
Args:
|
||||
loss_type (str, optional): Binary or multi-class loss.
|
||||
Default: 'multi_class'. Options are "binary" and "multi_class".
|
||||
classes (str | list[int], optional): Classes chosen to calculate loss.
|
||||
'all' for all classes, 'present' for classes present in labels, or
|
||||
a list of classes to average. Default: 'present'.
|
||||
per_image (bool, optional): If per_image is True, compute the loss per
|
||||
image instead of per batch. Default: False.
|
||||
reduction (str, optional): The method used to reduce the loss. Options
|
||||
are "none", "mean" and "sum". This parameter only works when
|
||||
per_image is True. Default: 'mean'.
|
||||
class_weight (list[float] | str, optional): Weight of each class. If in
|
||||
str format, read them from a file. Defaults to None.
|
||||
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
|
||||
loss_name (str, optional): Name of the loss item. If you want this loss
|
||||
item to be included into the backward graph, `loss_` must be the
|
||||
prefix of the name. Defaults to 'loss_lovasz'.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
loss_type='multi_class',
|
||||
classes='present',
|
||||
per_image=False,
|
||||
reduction='mean',
|
||||
class_weight=None,
|
||||
loss_weight=1.0,
|
||||
loss_name='loss_lovasz'):
|
||||
super().__init__()
|
||||
assert loss_type in ('binary', 'multi_class'), "loss_type should be \
|
||||
'binary' or 'multi_class'."
|
||||
|
||||
if loss_type == 'binary':
|
||||
self.cls_criterion = lovasz_hinge
|
||||
else:
|
||||
self.cls_criterion = lovasz_softmax
|
||||
assert classes in ('all', 'present') or is_list_of(classes, int)
|
||||
if not per_image:
|
||||
assert reduction == 'none', "reduction should be 'none' when \
|
||||
per_image is False."
|
||||
|
||||
self.classes = classes
|
||||
self.per_image = per_image
|
||||
self.reduction = reduction
|
||||
self.loss_weight = loss_weight
|
||||
self.class_weight = get_class_weight(class_weight)
|
||||
self._loss_name = loss_name
|
||||
|
||||
def forward(self,
|
||||
cls_score,
|
||||
label,
|
||||
weight=None,
|
||||
avg_factor=None,
|
||||
reduction_override=None,
|
||||
**kwargs):
|
||||
"""Forward function."""
|
||||
assert reduction_override in (None, 'none', 'mean', 'sum')
|
||||
reduction = (
|
||||
reduction_override if reduction_override else self.reduction)
|
||||
if self.class_weight is not None:
|
||||
class_weight = cls_score.new_tensor(self.class_weight)
|
||||
else:
|
||||
class_weight = None
|
||||
|
||||
# if multi-class loss, transform logits to probs
|
||||
if self.cls_criterion == lovasz_softmax:
|
||||
cls_score = F.softmax(cls_score, dim=1)
|
||||
|
||||
loss_cls = self.loss_weight * self.cls_criterion(
|
||||
cls_score,
|
||||
label,
|
||||
self.classes,
|
||||
self.per_image,
|
||||
class_weight=class_weight,
|
||||
reduction=reduction,
|
||||
avg_factor=avg_factor,
|
||||
**kwargs)
|
||||
return loss_cls
|
||||
|
||||
@property
|
||||
def loss_name(self):
|
||||
"""Loss Name.
|
||||
|
||||
This function must be implemented and will return the name of this
|
||||
loss function. This name will be used to combine different loss items
|
||||
by simple sum operation. In addition, if you want this loss item to be
|
||||
included into the backward graph, `loss_` must be the prefix of the
|
||||
name.
|
||||
Returns:
|
||||
str: The name of this loss item.
|
||||
"""
|
||||
return self._loss_name
|
||||
94
finetune/mmseg/models/losses/ohem_cross_entropy_loss.py
Normal file
94
finetune/mmseg/models/losses/ohem_cross_entropy_loss.py
Normal file
@@ -0,0 +1,94 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class OhemCrossEntropy(nn.Module):
|
||||
"""OhemCrossEntropy loss.
|
||||
|
||||
This func is modified from
|
||||
`PIDNet <https://github.com/XuJiacong/PIDNet/blob/main/utils/criterion.py#L43>`_. # noqa
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Args:
|
||||
ignore_label (int): Labels to ignore when computing the loss.
|
||||
Default: 255
|
||||
        thresh (float, optional): The threshold for hard example selection.
            Predictions below this threshold are regarded as low-confidence
            hard examples. If not specified, the hard examples are the
            pixels with the top ``min_kept`` losses. Default: 0.7.
|
||||
min_kept (int, optional): The minimum number of predictions to keep.
|
||||
Default: 100000.
|
||||
loss_weight (float): Weight of the loss. Defaults to 1.0.
|
||||
class_weight (list[float] | str, optional): Weight of each class. If in
|
||||
str format, read them from a file. Defaults to None.
|
||||
        loss_name (str): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_ohem'.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
ignore_label: int = 255,
|
||||
thres: float = 0.7,
|
||||
min_kept: int = 100000,
|
||||
loss_weight: float = 1.0,
|
||||
class_weight: Optional[Union[List[float], str]] = None,
|
||||
loss_name: str = 'loss_ohem'):
|
||||
super().__init__()
|
||||
self.thresh = thres
|
||||
self.min_kept = max(1, min_kept)
|
||||
self.ignore_label = ignore_label
|
||||
self.loss_weight = loss_weight
|
||||
self.loss_name_ = loss_name
|
||||
self.class_weight = class_weight
|
||||
|
||||
def forward(self, score: Tensor, target: Tensor) -> Tensor:
|
||||
"""Forward function.
|
||||
Args:
|
||||
score (Tensor): Predictions of the segmentation head.
|
||||
target (Tensor): Ground truth of the image.
|
||||
|
||||
Returns:
|
||||
Tensor: Loss tensor.
|
||||
"""
|
||||
# score: (N, C, H, W)
|
||||
pred = F.softmax(score, dim=1)
|
||||
if self.class_weight is not None:
|
||||
class_weight = score.new_tensor(self.class_weight)
|
||||
else:
|
||||
class_weight = None
|
||||
|
||||
pixel_losses = F.cross_entropy(
|
||||
score,
|
||||
target,
|
||||
weight=class_weight,
|
||||
ignore_index=self.ignore_label,
|
||||
reduction='none').contiguous().view(-1) # (N*H*W)
|
||||
mask = target.contiguous().view(-1) != self.ignore_label # (N*H*W)
|
||||
|
||||
tmp_target = target.clone() # (N, H, W)
|
||||
tmp_target[tmp_target == self.ignore_label] = 0
|
||||
# pred: (N, C, H, W) -> (N*H*W, C)
|
||||
pred = pred.gather(1, tmp_target.unsqueeze(1))
|
||||
# pred: (N*H*W, C) -> (N*H*W), ind: (N*H*W)
|
||||
pred, ind = pred.contiguous().view(-1, )[mask].contiguous().sort()
|
||||
if pred.numel() > 0:
|
||||
min_value = pred[min(self.min_kept, pred.numel() - 1)]
|
||||
else:
|
||||
return score.new_tensor(0.0)
|
||||
threshold = max(min_value, self.thresh)
|
||||
|
||||
pixel_losses = pixel_losses[mask][ind]
|
||||
pixel_losses = pixel_losses[pred < threshold]
|
||||
return self.loss_weight * pixel_losses.mean()
|
||||
|
||||
@property
|
||||
def loss_name(self):
|
||||
return self.loss_name_
|
||||
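# Illustrative sketch (not part of the original file): minimal usage of
# ``OhemCrossEntropy``. Only pixels whose softmax confidence for the
# ground-truth class falls below the threshold (or the ``min_kept`` hardest
# pixels) contribute to the loss. The shapes below are made up.
def _demo_ohem_cross_entropy_usage():
    import torch

    loss_fn = OhemCrossEntropy(ignore_label=255, thres=0.7, min_kept=16)
    score = torch.randn(2, 4, 8, 8)             # logits for 4 classes
    target = torch.randint(0, 4, (2, 8, 8))
    return loss_fn(score, target)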
122
finetune/mmseg/models/losses/silog_loss.py
Normal file
122
finetune/mmseg/models/losses/silog_loss.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .utils import weight_reduce_loss
|
||||
|
||||
|
||||
def silog_loss(pred: Tensor,
|
||||
target: Tensor,
|
||||
weight: Optional[Tensor] = None,
|
||||
eps: float = 1e-4,
|
||||
reduction: Union[str, None] = 'mean',
|
||||
avg_factor: Optional[int] = None) -> Tensor:
|
||||
"""Computes the Scale-Invariant Logarithmic (SI-Log) loss between
|
||||
prediction and target.
|
||||
|
||||
Args:
|
||||
pred (Tensor): Predicted output.
|
||||
target (Tensor): Ground truth.
|
||||
weight (Optional[Tensor]): Optional weight to apply on the loss.
|
||||
eps (float): Epsilon value to avoid division and log(0).
|
||||
reduction (Union[str, None]): Specifies the reduction to apply to the
|
||||
output: 'mean', 'sum' or None.
|
||||
avg_factor (Optional[int]): Optional average factor for the loss.
|
||||
|
||||
Returns:
|
||||
Tensor: The calculated SI-Log loss.
|
||||
"""
|
||||
pred, target = pred.flatten(1), target.flatten(1)
|
||||
valid_mask = (target > eps).detach().float()
|
||||
|
||||
diff_log = torch.log(target.clamp(min=eps)) - torch.log(
|
||||
pred.clamp(min=eps))
|
||||
|
||||
valid_mask = (target > eps).detach() & (~torch.isnan(diff_log))
|
||||
diff_log[~valid_mask] = 0.0
|
||||
valid_mask = valid_mask.float()
|
||||
|
||||
diff_log_sq_mean = (diff_log.pow(2) * valid_mask).sum(
|
||||
dim=1) / valid_mask.sum(dim=1).clamp(min=eps)
|
||||
diff_log_mean = (diff_log * valid_mask).sum(dim=1) / valid_mask.sum(
|
||||
dim=1).clamp(min=eps)
|
||||
|
||||
loss = torch.sqrt(diff_log_sq_mean - 0.5 * diff_log_mean.pow(2))
|
||||
|
||||
if weight is not None:
|
||||
weight = weight.float()
|
||||
|
||||
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
|
||||
return loss
|
||||
|
||||
|
||||
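# Illustrative sketch (not part of the original file): the SI-log error that
# ``silog_loss`` above computes per sample,
#     sqrt( mean(d^2) - 0.5 * mean(d)^2 ),  d = log(target) - log(pred),
# evaluated only over valid (positive, non-NaN) target pixels. The values
# below are made up for demonstration only.
def _demo_silog_formula():
    import torch

    pred = torch.tensor([[1.0, 2.0, 4.0]])
    target = torch.tensor([[1.0, 2.5, 3.0]])
    d = torch.log(target) - torch.log(pred)
    loss = torch.sqrt((d ** 2).mean(dim=1) - 0.5 * d.mean(dim=1) ** 2)
    return loss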
@MODELS.register_module()
|
||||
class SiLogLoss(nn.Module):
|
||||
"""Compute SiLog loss.
|
||||
|
||||
Args:
|
||||
reduction (str, optional): The method used
|
||||
to reduce the loss. Options are "none",
|
||||
"mean" and "sum". Defaults to 'mean'.
|
||||
loss_weight (float, optional): Weight of loss. Defaults to 1.0.
|
||||
eps (float): Avoid dividing by zero. Defaults to 1e-3.
|
||||
loss_name (str, optional): Name of the loss item. If you want this
|
||||
loss item to be included into the backward graph, `loss_` must
|
||||
be the prefix of the name. Defaults to 'loss_silog'.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
reduction='mean',
|
||||
loss_weight=1.0,
|
||||
eps=1e-6,
|
||||
loss_name='loss_silog'):
|
||||
super().__init__()
|
||||
self.reduction = reduction
|
||||
self.loss_weight = loss_weight
|
||||
self.eps = eps
|
||||
self._loss_name = loss_name
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred,
|
||||
target,
|
||||
weight=None,
|
||||
avg_factor=None,
|
||||
reduction_override=None,
|
||||
):
|
||||
|
||||
assert pred.shape == target.shape, 'the shapes of pred ' \
|
||||
f'({pred.shape}) and target ({target.shape}) do not match'
|
||||
|
||||
assert reduction_override in (None, 'none', 'mean', 'sum')
|
||||
reduction = (
|
||||
reduction_override if reduction_override else self.reduction)
|
||||
|
||||
loss = self.loss_weight * silog_loss(
|
||||
pred,
|
||||
target,
|
||||
weight,
|
||||
eps=self.eps,
|
||||
reduction=reduction,
|
||||
avg_factor=avg_factor,
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
@property
|
||||
def loss_name(self):
|
||||
"""Loss Name.
|
||||
|
||||
This function must be implemented and will return the name of this
|
||||
loss function. This name will be used to combine different loss items
|
||||
by simple sum operation. In addition, if you want this loss item to be
|
||||
included into the backward graph, `loss_` must be the prefix of the
|
||||
name.
|
||||
Returns:
|
||||
str: The name of this loss item.
|
||||
"""
|
||||
return self._loss_name
|
||||
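A self-contained plain-torch re-derivation of the silog_loss computation above (toy shapes, only for sanity-checking the masking and the sqrt(E[d^2] - 0.5 * E[d]^2) formula; it is not imported anywhere in this code):

import torch

def silog_reference(pred, target, eps=1e-4):
    # flatten to (N, H*W) and build the validity mask exactly as above
    pred, target = pred.flatten(1), target.flatten(1)
    diff_log = torch.log(target.clamp(min=eps)) - torch.log(pred.clamp(min=eps))
    valid = (target > eps) & ~torch.isnan(diff_log)
    diff_log = torch.where(valid, diff_log, torch.zeros_like(diff_log))
    valid = valid.float()
    n = valid.sum(dim=1).clamp(min=eps)
    mean_sq = diff_log.pow(2).sum(dim=1) / n
    mean = diff_log.sum(dim=1) / n
    return torch.sqrt(mean_sq - 0.5 * mean.pow(2)).mean()

pred = torch.rand(2, 1, 8, 8) + 0.1
target = torch.rand(2, 1, 8, 8)
print(silog_reference(pred, target))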
137
finetune/mmseg/models/losses/tversky_loss.py
Normal file
137
finetune/mmseg/models/losses/tversky_loss.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
"""Modified from
|
||||
https://github.com/JunMa11/SegLoss/blob/master/losses_pytorch/dice_loss.py#L333
|
||||
(Apache-2.0 License)"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..builder import LOSSES
|
||||
from .utils import get_class_weight, weighted_loss
|
||||
|
||||
|
||||
@weighted_loss
|
||||
def tversky_loss(pred,
|
||||
target,
|
||||
valid_mask,
|
||||
alpha=0.3,
|
||||
beta=0.7,
|
||||
smooth=1,
|
||||
class_weight=None,
|
||||
ignore_index=255):
|
||||
assert pred.shape[0] == target.shape[0]
|
||||
total_loss = 0
|
||||
num_classes = pred.shape[1]
|
||||
for i in range(num_classes):
|
||||
if i != ignore_index:
|
||||
tversky_loss = binary_tversky_loss(
|
||||
pred[:, i],
|
||||
target[..., i],
|
||||
valid_mask=valid_mask,
|
||||
alpha=alpha,
|
||||
beta=beta,
|
||||
smooth=smooth)
|
||||
if class_weight is not None:
|
||||
tversky_loss *= class_weight[i]
|
||||
total_loss += tversky_loss
|
||||
return total_loss / num_classes
|
||||
|
||||
|
||||
@weighted_loss
|
||||
def binary_tversky_loss(pred,
|
||||
target,
|
||||
valid_mask,
|
||||
alpha=0.3,
|
||||
beta=0.7,
|
||||
smooth=1):
|
||||
assert pred.shape[0] == target.shape[0]
|
||||
pred = pred.reshape(pred.shape[0], -1)
|
||||
target = target.reshape(target.shape[0], -1)
|
||||
valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)
|
||||
|
||||
TP = torch.sum(torch.mul(pred, target) * valid_mask, dim=1)
|
||||
FP = torch.sum(torch.mul(pred, 1 - target) * valid_mask, dim=1)
|
||||
FN = torch.sum(torch.mul(1 - pred, target) * valid_mask, dim=1)
|
||||
tversky = (TP + smooth) / (TP + alpha * FP + beta * FN + smooth)
|
||||
|
||||
return 1 - tversky
|
||||
|
||||
|
||||
@LOSSES.register_module()
|
||||
class TverskyLoss(nn.Module):
|
||||
"""TverskyLoss. This loss is proposed in `Tversky loss function for image
|
||||
segmentation using 3D fully convolutional deep networks.
|
||||
|
||||
<https://arxiv.org/abs/1706.05721>`_.
|
||||
Args:
|
||||
smooth (float): A float number to smooth loss, and avoid NaN error.
|
||||
Default: 1.
|
||||
class_weight (list[float] | str, optional): Weight of each class. If in
|
||||
str format, read them from a file. Defaults to None.
|
||||
loss_weight (float, optional): Weight of the loss. Default to 1.0.
|
||||
ignore_index (int | None): The label index to be ignored. Default: 255.
|
||||
alpha (float, in [0, 1]):
|
||||
The coefficient of false positives. Default: 0.3.
|
||||
beta (float, in [0, 1]):
|
||||
The coefficient of false negatives. Default: 0.7.
|
||||
Note: alpha + beta = 1.
|
||||
loss_name (str, optional): Name of the loss item. If you want this loss
|
||||
item to be included into the backward graph, `loss_` must be the
|
||||
prefix of the name. Defaults to 'loss_tversky'.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
smooth=1,
|
||||
class_weight=None,
|
||||
loss_weight=1.0,
|
||||
ignore_index=255,
|
||||
alpha=0.3,
|
||||
beta=0.7,
|
||||
loss_name='loss_tversky'):
|
||||
super().__init__()
|
||||
self.smooth = smooth
|
||||
self.class_weight = get_class_weight(class_weight)
|
||||
self.loss_weight = loss_weight
|
||||
self.ignore_index = ignore_index
|
||||
assert (alpha + beta == 1.0), 'Sum of alpha and beta must be 1.0!'
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
self._loss_name = loss_name
|
||||
|
||||
def forward(self, pred, target, **kwargs):
|
||||
if self.class_weight is not None:
|
||||
class_weight = pred.new_tensor(self.class_weight)
|
||||
else:
|
||||
class_weight = None
|
||||
|
||||
pred = F.softmax(pred, dim=1)
|
||||
num_classes = pred.shape[1]
|
||||
one_hot_target = F.one_hot(
|
||||
torch.clamp(target.long(), 0, num_classes - 1),
|
||||
num_classes=num_classes)
|
||||
valid_mask = (target != self.ignore_index).long()
|
||||
|
||||
loss = self.loss_weight * tversky_loss(
|
||||
pred,
|
||||
one_hot_target,
|
||||
valid_mask=valid_mask,
|
||||
alpha=self.alpha,
|
||||
beta=self.beta,
|
||||
smooth=self.smooth,
|
||||
class_weight=class_weight,
|
||||
ignore_index=self.ignore_index)
|
||||
return loss
|
||||
|
||||
@property
|
||||
def loss_name(self):
|
||||
"""Loss Name.
|
||||
|
||||
This function must be implemented and will return the name of this
|
||||
loss function. This name will be used to combine different loss items
|
||||
by simple sum operation. In addition, if you want this loss item to be
|
||||
included into the backward graph, `loss_` must be the prefix of the
|
||||
name.
|
||||
Returns:
|
||||
str: The name of this loss item.
|
||||
"""
|
||||
return self._loss_name
|
||||
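A quick numeric check (plain torch, arbitrary data): with alpha = beta = 0.5 and no smoothing term, the binary Tversky index above reduces to the Dice coefficient, which makes a convenient sanity test for the implementation.

import torch

pred = torch.rand(2, 64)                      # soft foreground probabilities
target = (torch.rand(2, 64) > 0.5).float()    # binary ground truth
tp = (pred * target).sum(1)
fp = (pred * (1 - target)).sum(1)
fn = ((1 - pred) * target).sum(1)
tversky = tp / (tp + 0.5 * fp + 0.5 * fn)     # alpha = beta = 0.5, smooth = 0
dice = 2 * tp / (pred.sum(1) + target.sum(1))
print(torch.allclose(tversky, dice))          # True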
129
finetune/mmseg/models/losses/utils.py
Normal file
129
finetune/mmseg/models/losses/utils.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import functools
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from mmengine.fileio import load
|
||||
|
||||
|
||||
def get_class_weight(class_weight):
|
||||
"""Get class weight for loss function.
|
||||
|
||||
Args:
|
||||
class_weight (list[float] | str | None): If class_weight is a str,
|
||||
take it as a file name and read from it.
|
||||
"""
|
||||
if isinstance(class_weight, str):
|
||||
# take it as a file path
|
||||
if class_weight.endswith('.npy'):
|
||||
class_weight = np.load(class_weight)
|
||||
else:
|
||||
# pkl, json or yaml
|
||||
class_weight = load(class_weight)
|
||||
|
||||
return class_weight
|
||||
|
||||
|
||||
def reduce_loss(loss, reduction) -> torch.Tensor:
|
||||
"""Reduce loss as specified.
|
||||
|
||||
Args:
|
||||
loss (Tensor): Elementwise loss tensor.
|
||||
reduction (str): Options are "none", "mean" and "sum".
|
||||
|
||||
Return:
|
||||
Tensor: Reduced loss tensor.
|
||||
"""
|
||||
reduction_enum = F._Reduction.get_enum(reduction)
|
||||
# none: 0, elementwise_mean:1, sum: 2
|
||||
if reduction_enum == 0:
|
||||
return loss
|
||||
elif reduction_enum == 1:
|
||||
return loss.mean()
|
||||
elif reduction_enum == 2:
|
||||
return loss.sum()
|
||||
|
||||
|
||||
def weight_reduce_loss(loss,
|
||||
weight=None,
|
||||
reduction='mean',
|
||||
avg_factor=None) -> torch.Tensor:
|
||||
"""Apply element-wise weight and reduce loss.
|
||||
|
||||
Args:
|
||||
loss (Tensor): Element-wise loss.
|
||||
weight (Tensor): Element-wise weights.
|
||||
reduction (str): Same as built-in losses of PyTorch.
|
||||
avg_factor (float): Average factor when computing the mean of losses.
|
||||
|
||||
Returns:
|
||||
Tensor: Processed loss values.
|
||||
"""
|
||||
# if weight is specified, apply element-wise weight
|
||||
if weight is not None:
|
||||
assert weight.dim() == loss.dim()
|
||||
if weight.dim() > 1:
|
||||
assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
|
||||
loss = loss * weight
|
||||
|
||||
# if avg_factor is not specified, just reduce the loss
|
||||
if avg_factor is None:
|
||||
loss = reduce_loss(loss, reduction)
|
||||
else:
|
||||
# if reduction is mean, then average the loss by avg_factor
|
||||
if reduction == 'mean':
|
||||
# Avoid causing ZeroDivisionError when avg_factor is 0.0,
|
||||
# i.e., all labels of an image belong to ignore index.
|
||||
eps = torch.finfo(torch.float32).eps
|
||||
loss = loss.sum() / (avg_factor + eps)
|
||||
# if reduction is 'none', then do nothing, otherwise raise an error
|
||||
elif reduction != 'none':
|
||||
raise ValueError('avg_factor can not be used with reduction="sum"')
|
||||
return loss
|
||||
|
||||
|
||||
def weighted_loss(loss_func):
|
||||
"""Create a weighted version of a given loss function.
|
||||
|
||||
To use this decorator, the loss function must have the signature like
|
||||
`loss_func(pred, target, **kwargs)`. The function only needs to compute
|
||||
element-wise loss without any reduction. This decorator will add weight
|
||||
and reduction arguments to the function. The decorated function will have
|
||||
the signature like `loss_func(pred, target, weight=None, reduction='mean',
|
||||
avg_factor=None, **kwargs)`.
|
||||
|
||||
:Example:
|
||||
|
||||
>>> import torch
|
||||
>>> @weighted_loss
|
||||
>>> def l1_loss(pred, target):
|
||||
>>> return (pred - target).abs()
|
||||
|
||||
>>> pred = torch.Tensor([0, 2, 3])
|
||||
>>> target = torch.Tensor([1, 1, 1])
|
||||
>>> weight = torch.Tensor([1, 0, 1])
|
||||
|
||||
>>> l1_loss(pred, target)
|
||||
tensor(1.3333)
|
||||
>>> l1_loss(pred, target, weight)
|
||||
tensor(1.)
|
||||
>>> l1_loss(pred, target, reduction='none')
|
||||
tensor([1., 1., 2.])
|
||||
>>> l1_loss(pred, target, weight, avg_factor=2)
|
||||
tensor(1.5000)
|
||||
"""
|
||||
|
||||
@functools.wraps(loss_func)
|
||||
def wrapper(pred,
|
||||
target,
|
||||
weight=None,
|
||||
reduction='mean',
|
||||
avg_factor=None,
|
||||
**kwargs):
|
||||
# get element-wise loss
|
||||
loss = loss_func(pred, target, **kwargs)
|
||||
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
|
||||
return loss
|
||||
|
||||
return wrapper
|
||||
14
finetune/mmseg/models/necks/__init__.py
Normal file
14
finetune/mmseg/models/necks/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .featurepyramid import Feature2Pyramid
|
||||
from .fpn import FPN
|
||||
from .ic_neck import ICNeck
|
||||
from .jpu import JPU
|
||||
from .mla_neck import MLANeck
|
||||
from .multilevel_neck import MultiLevelNeck
|
||||
from .fusion_transformer import FusionTransformer
|
||||
from .fusion_multilevel_neck import FusionMultiLevelNeck
|
||||
|
||||
__all__ = [
|
||||
'FPN', 'MultiLevelNeck', 'MLANeck', 'ICNeck', 'JPU', 'Feature2Pyramid',
|
||||
'FusionTransformer', 'FusionMultiLevelNeck'
|
||||
]
|
||||
67
finetune/mmseg/models/necks/featurepyramid.py
Normal file
67
finetune/mmseg/models/necks/featurepyramid.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import build_norm_layer
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class Feature2Pyramid(nn.Module):
|
||||
"""Feature2Pyramid.
|
||||
|
||||
A neck structure connecting the ViT backbone and decode heads.
|
||||
|
||||
Args:
|
||||
embed_dim (int): Embedding dimension.
|
||||
rescales (list[float]): Different sampling ratios are
|
||||
used to obtain pyramid features. Default: [4, 2, 1, 0.5].
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='SyncBN', requires_grad=True).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dim,
|
||||
rescales=[4, 2, 1, 0.5],
|
||||
norm_cfg=dict(type='SyncBN', requires_grad=True)):
|
||||
super().__init__()
|
||||
self.rescales = rescales
|
||||
self.upsample_4x = None
|
||||
for k in self.rescales:
|
||||
if k == 4:
|
||||
self.upsample_4x = nn.Sequential(
|
||||
nn.ConvTranspose2d(
|
||||
embed_dim, embed_dim, kernel_size=2, stride=2),
|
||||
build_norm_layer(norm_cfg, embed_dim)[1],
|
||||
nn.GELU(),
|
||||
nn.ConvTranspose2d(
|
||||
embed_dim, embed_dim, kernel_size=2, stride=2),
|
||||
)
|
||||
elif k == 2:
|
||||
self.upsample_2x = nn.Sequential(
|
||||
nn.ConvTranspose2d(
|
||||
embed_dim, embed_dim, kernel_size=2, stride=2))
|
||||
elif k == 1:
|
||||
self.identity = nn.Identity()
|
||||
elif k == 0.5:
|
||||
self.downsample_2x = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
elif k == 0.25:
|
||||
self.downsample_4x = nn.MaxPool2d(kernel_size=4, stride=4)
|
||||
else:
|
||||
raise KeyError(f'invalid {k} for feature2pyramid')
|
||||
|
||||
def forward(self, inputs):
|
||||
assert len(inputs) == len(self.rescales)
|
||||
outputs = []
|
||||
if self.upsample_4x is not None:
|
||||
ops = [
|
||||
self.upsample_4x, self.upsample_2x, self.identity,
|
||||
self.downsample_2x
|
||||
]
|
||||
else:
|
||||
ops = [
|
||||
self.upsample_2x, self.identity, self.downsample_2x,
|
||||
self.downsample_4x
|
||||
]
|
||||
for i in range(len(inputs)):
|
||||
outputs.append(ops[i](inputs[i]))
|
||||
return tuple(outputs)
|
||||
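Hedged usage sketch for Feature2Pyramid (assumes mmcv/mmseg are importable; 'BN' replaces the default 'SyncBN' so the sketch runs in a single process): four ViT feature maps of the same size become a four-level pyramid according to rescales.

import torch
from mmseg.models.necks import Feature2Pyramid

neck = Feature2Pyramid(
    embed_dim=768,
    rescales=[4, 2, 1, 0.5],
    norm_cfg=dict(type='BN', requires_grad=True))
feats = [torch.rand(1, 768, 32, 32) for _ in range(4)]
for out in neck(feats):
    print(out.shape)   # spatial sizes 128, 64, 32 and 16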
212
finetune/mmseg/models/necks/fpn.py
Normal file
212
finetune/mmseg/models/necks/fpn.py
Normal file
@@ -0,0 +1,212 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class FPN(BaseModule):
|
||||
"""Feature Pyramid Network.
|
||||
|
||||
This neck is the implementation of `Feature Pyramid Networks for Object
|
||||
Detection <https://arxiv.org/abs/1612.03144>`_.
|
||||
|
||||
Args:
|
||||
in_channels (list[int]): Number of input channels per scale.
|
||||
out_channels (int): Number of output channels (used at each scale).
|
||||
num_outs (int): Number of output scales.
|
||||
start_level (int): Index of the start input backbone level used to
|
||||
build the feature pyramid. Default: 0.
|
||||
end_level (int): Index of the end input backbone level (exclusive) to
|
||||
build the feature pyramid. Default: -1, which means the last level.
|
||||
add_extra_convs (bool | str): If bool, it decides whether to add conv
|
||||
layers on top of the original feature maps. Default to False.
|
||||
If True, its actual mode is specified by `extra_convs_on_inputs`.
|
||||
If str, it specifies the source feature map of the extra convs.
|
||||
Only the following options are allowed
|
||||
|
||||
- 'on_input': Last feat map of neck inputs (i.e. backbone feature).
|
||||
- 'on_lateral': Last feature map after lateral convs.
|
||||
- 'on_output': The last output feature map after fpn convs.
|
||||
extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs
|
||||
on the original feature from the backbone. If True,
|
||||
it is equivalent to `add_extra_convs='on_input'`. If False, it is
|
||||
equivalent to set `add_extra_convs='on_output'`. Default to True.
|
||||
relu_before_extra_convs (bool): Whether to apply relu before the extra
|
||||
conv. Default: False.
|
||||
no_norm_on_lateral (bool): Whether to apply norm on lateral.
|
||||
Default: False.
|
||||
conv_cfg (dict): Config dict for convolution layer. Default: None.
|
||||
norm_cfg (dict): Config dict for normalization layer. Default: None.
|
||||
act_cfg (dict): Config dict for activation layer in ConvModule.
|
||||
Default: None.
|
||||
upsample_cfg (dict): Config dict for interpolate layer.
|
||||
Default: dict(mode='nearest').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
|
||||
Example:
|
||||
>>> import torch
|
||||
>>> in_channels = [2, 3, 5, 7]
|
||||
>>> scales = [340, 170, 84, 43]
|
||||
>>> inputs = [torch.rand(1, c, s, s)
|
||||
... for c, s in zip(in_channels, scales)]
|
||||
>>> self = FPN(in_channels, 11, len(in_channels)).eval()
|
||||
>>> outputs = self.forward(inputs)
|
||||
>>> for i in range(len(outputs)):
|
||||
... print(f'outputs[{i}].shape = {outputs[i].shape}')
|
||||
outputs[0].shape = torch.Size([1, 11, 340, 340])
|
||||
outputs[1].shape = torch.Size([1, 11, 170, 170])
|
||||
outputs[2].shape = torch.Size([1, 11, 84, 84])
|
||||
outputs[3].shape = torch.Size([1, 11, 43, 43])
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
num_outs,
|
||||
start_level=0,
|
||||
end_level=-1,
|
||||
add_extra_convs=False,
|
||||
extra_convs_on_inputs=False,
|
||||
relu_before_extra_convs=False,
|
||||
no_norm_on_lateral=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=None,
|
||||
act_cfg=None,
|
||||
upsample_cfg=dict(mode='nearest'),
|
||||
init_cfg=dict(
|
||||
type='Xavier', layer='Conv2d', distribution='uniform')):
|
||||
super().__init__(init_cfg)
|
||||
assert isinstance(in_channels, list)
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_ins = len(in_channels)
|
||||
self.num_outs = num_outs
|
||||
self.relu_before_extra_convs = relu_before_extra_convs
|
||||
self.no_norm_on_lateral = no_norm_on_lateral
|
||||
self.fp16_enabled = False
|
||||
self.upsample_cfg = upsample_cfg.copy()
|
||||
|
||||
if end_level == -1:
|
||||
self.backbone_end_level = self.num_ins
|
||||
assert num_outs >= self.num_ins - start_level
|
||||
else:
|
||||
# if end_level < inputs, no extra level is allowed
|
||||
self.backbone_end_level = end_level
|
||||
assert end_level <= len(in_channels)
|
||||
assert num_outs == end_level - start_level
|
||||
self.start_level = start_level
|
||||
self.end_level = end_level
|
||||
self.add_extra_convs = add_extra_convs
|
||||
assert isinstance(add_extra_convs, (str, bool))
|
||||
if isinstance(add_extra_convs, str):
|
||||
# Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
|
||||
assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
|
||||
elif add_extra_convs: # True
|
||||
if extra_convs_on_inputs:
|
||||
# For compatibility with previous release
|
||||
# TODO: deprecate `extra_convs_on_inputs`
|
||||
self.add_extra_convs = 'on_input'
|
||||
else:
|
||||
self.add_extra_convs = 'on_output'
|
||||
|
||||
self.lateral_convs = nn.ModuleList()
|
||||
self.fpn_convs = nn.ModuleList()
|
||||
|
||||
for i in range(self.start_level, self.backbone_end_level):
|
||||
l_conv = ConvModule(
|
||||
in_channels[i],
|
||||
out_channels,
|
||||
1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
|
||||
act_cfg=act_cfg,
|
||||
inplace=False)
|
||||
fpn_conv = ConvModule(
|
||||
out_channels,
|
||||
out_channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
inplace=False)
|
||||
|
||||
self.lateral_convs.append(l_conv)
|
||||
self.fpn_convs.append(fpn_conv)
|
||||
|
||||
# add extra conv layers (e.g., RetinaNet)
|
||||
extra_levels = num_outs - self.backbone_end_level + self.start_level
|
||||
if self.add_extra_convs and extra_levels >= 1:
|
||||
for i in range(extra_levels):
|
||||
if i == 0 and self.add_extra_convs == 'on_input':
|
||||
in_channels = self.in_channels[self.backbone_end_level - 1]
|
||||
else:
|
||||
in_channels = out_channels
|
||||
extra_fpn_conv = ConvModule(
|
||||
in_channels,
|
||||
out_channels,
|
||||
3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
inplace=False)
|
||||
self.fpn_convs.append(extra_fpn_conv)
|
||||
|
||||
def forward(self, inputs):
|
||||
assert len(inputs) == len(self.in_channels)
|
||||
|
||||
# build laterals
|
||||
laterals = [
|
||||
lateral_conv(inputs[i + self.start_level])
|
||||
for i, lateral_conv in enumerate(self.lateral_convs)
|
||||
]
|
||||
|
||||
# build top-down path
|
||||
used_backbone_levels = len(laterals)
|
||||
for i in range(used_backbone_levels - 1, 0, -1):
|
||||
# In some cases, fixing `scale factor` (e.g. 2) is preferred, but
|
||||
# it cannot co-exist with `size` in `F.interpolate`.
|
||||
if 'scale_factor' in self.upsample_cfg:
|
||||
laterals[i - 1] = laterals[i - 1] + resize(
|
||||
laterals[i], **self.upsample_cfg)
|
||||
else:
|
||||
prev_shape = laterals[i - 1].shape[2:]
|
||||
laterals[i - 1] = laterals[i - 1] + resize(
|
||||
laterals[i], size=prev_shape, **self.upsample_cfg)
|
||||
|
||||
# build outputs
|
||||
# part 1: from original levels
|
||||
outs = [
|
||||
self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
|
||||
]
|
||||
# part 2: add extra levels
|
||||
if self.num_outs > len(outs):
|
||||
# use max pool to get more levels on top of outputs
|
||||
# (e.g., Faster R-CNN, Mask R-CNN)
|
||||
if not self.add_extra_convs:
|
||||
for i in range(self.num_outs - used_backbone_levels):
|
||||
outs.append(F.max_pool2d(outs[-1], 1, stride=2))
|
||||
# add conv layers on top of original feature maps (RetinaNet)
|
||||
else:
|
||||
if self.add_extra_convs == 'on_input':
|
||||
extra_source = inputs[self.backbone_end_level - 1]
|
||||
elif self.add_extra_convs == 'on_lateral':
|
||||
extra_source = laterals[-1]
|
||||
elif self.add_extra_convs == 'on_output':
|
||||
extra_source = outs[-1]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
outs.append(self.fpn_convs[used_backbone_levels](extra_source))
|
||||
for i in range(used_backbone_levels + 1, self.num_outs):
|
||||
if self.relu_before_extra_convs:
|
||||
outs.append(self.fpn_convs[i](F.relu(outs[-1])))
|
||||
else:
|
||||
outs.append(self.fpn_convs[i](outs[-1]))
|
||||
return tuple(outs)
|
||||
90
finetune/mmseg/models/necks/fusion_multilevel_neck.py
Normal file
90
finetune/mmseg/models/necks/fusion_multilevel_neck.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .multilevel_neck import MultiLevelNeck
|
||||
from .fusion_transformer import FusionTransformer
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class FusionMultiLevelNeck(nn.Module):
|
||||
def __init__(self,
|
||||
ts_size=10,
|
||||
in_channels_ml=[768, 768, 768, 768],
|
||||
out_channels_ml=768,
|
||||
scales_ml=[0.5, 1, 2, 4],
|
||||
norm_cfg_ml=None,
|
||||
act_cfg_ml=None,
|
||||
input_dims=768,
|
||||
embed_dims=768,
|
||||
num_layers=4,
|
||||
num_heads=12,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
with_cls_token=True,
|
||||
output_cls_token=True,
|
||||
norm_cfg=dict(type='LN'),
|
||||
act_cfg=dict(type='GELU'),
|
||||
num_fcs=2,
|
||||
norm_eval=False,
|
||||
with_cp=False,
|
||||
init_cfg=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
super(FusionMultiLevelNeck, self).__init__()
|
||||
self.in_channels = in_channels_ml
|
||||
self.ts_size = ts_size
|
||||
self.multilevel_neck = MultiLevelNeck(
|
||||
in_channels_ml,
|
||||
out_channels_ml,
|
||||
scales_ml,
|
||||
norm_cfg_ml,
|
||||
act_cfg_ml
|
||||
)
|
||||
# self.up_head = UPHead(1024, 2816, 4)
|
||||
|
||||
self.fusion_transformer = FusionTransformer(
|
||||
input_dims,
|
||||
embed_dims,
|
||||
num_layers,
|
||||
num_heads,
|
||||
mlp_ratio,
|
||||
qkv_bias,
|
||||
drop_rate,
|
||||
attn_drop_rate,
|
||||
drop_path_rate,
|
||||
with_cls_token,
|
||||
output_cls_token,
|
||||
norm_cfg,
|
||||
act_cfg,
|
||||
num_fcs,
|
||||
norm_eval,
|
||||
with_cp,
|
||||
init_cfg,
|
||||
)
|
||||
|
||||
def init_weights(self):
|
||||
self.fusion_transformer.init_weights()
|
||||
|
||||
def forward(self, inputs, require_feat: bool = False, require_two: bool = False):
|
||||
assert len(inputs) == len(self.in_channels)
|
||||
|
||||
inputs = self.multilevel_neck(inputs)
|
||||
|
||||
ts = self.ts_size
|
||||
b_total, c, h, w = inputs[-1].shape
|
||||
b = int(b_total / ts)
|
||||
outs = []
|
||||
for idx in range(len(inputs)):
|
||||
|
||||
input_feat = inputs[idx]
|
||||
b_total, c, h, w = inputs[idx].shape
|
||||
input_feat = input_feat.reshape(b, ts, c, h, w).permute(0, 3, 4, 1, 2).reshape(b*h*w, ts, c)  # (b*ts, c, h, w) -> (b*h*w, ts, c)
|
||||
feat_fusion = self.fusion_transformer(input_feat, require_feat, require_two)
|
||||
c_fusion = feat_fusion.shape[-1]
|
||||
feat_fusion = feat_fusion.reshape(b, h, w, c_fusion).permute(0, 3, 1, 2) # b*h*w, c -> b, c, h, w
|
||||
outs.append(feat_fusion)
|
||||
|
||||
return tuple(outs)
|
||||
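A small plain-torch sketch of the temporal regrouping performed in the forward above, with a mean over the time axis standing in for self.fusion_transformer: each spatial location of the (b*ts, c, h, w) stack becomes a length-ts token sequence, and the fused vector is scattered back to (b, c, h, w).

import torch

b, ts, c, h, w = 2, 10, 768, 16, 16
x = torch.rand(b * ts, c, h, w)                     # ts per-timestep feature maps per sample
seq = x.reshape(b, ts, c, h, w).permute(0, 3, 4, 1, 2).reshape(b * h * w, ts, c)
fused = seq.mean(dim=1)                             # stand-in for the transformer's cls output
out = fused.reshape(b, h, w, c).permute(0, 3, 1, 2)
print(out.shape)                                    # torch.Size([2, 768, 16, 16])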
166
finetune/mmseg/models/necks/fusion_transformer.py
Normal file
166
finetune/mmseg/models/necks/fusion_transformer.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# Copyright (c) Ant Group. All rights reserved.
|
||||
from collections import OrderedDict
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.model.weight_init import (constant_init, kaiming_init,
|
||||
trunc_normal_)
|
||||
from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
from mmseg.models.backbones.vit import TransformerEncoderLayer
|
||||
|
||||
# from mmseg.utils import get_root_logger
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
# @MODELS.register_module()
|
||||
class FusionTransformer(nn.Module):
|
||||
def __init__(self,
|
||||
input_dims=768,
|
||||
embed_dims=768,
|
||||
num_layers=4,
|
||||
num_heads=12,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
with_cls_token=True,
|
||||
output_cls_token=True,
|
||||
norm_cfg=dict(type='LN'),
|
||||
act_cfg=dict(type='GELU'),
|
||||
num_fcs=2,
|
||||
norm_eval=False,
|
||||
with_cp=False,
|
||||
init_cfg=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
super(FusionTransformer, self).__init__()
|
||||
|
||||
self.porj_linear = nn.Linear(input_dims, embed_dims)
|
||||
if output_cls_token:
|
||||
assert with_cls_token is True, f'with_cls_token must be True if ' \
|
||||
f'set output_cls_token to True, but got {with_cls_token}'
|
||||
|
||||
self.init_cfg = init_cfg
|
||||
self.norm_eval = norm_eval
|
||||
self.with_cp = with_cp
|
||||
self.with_cls_token = with_cls_token
|
||||
self.output_cls_token = output_cls_token
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
|
||||
self.drop_after_pos = nn.Dropout(p=drop_rate)
|
||||
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, num_layers)
|
||||
] # stochastic depth decay rule
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
for i in range(num_layers):
|
||||
self.layers.append(
|
||||
TransformerEncoderLayer(embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
feedforward_channels=mlp_ratio *
|
||||
embed_dims,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_rate=drop_rate,
|
||||
drop_path_rate=dpr[i],
|
||||
num_fcs=num_fcs,
|
||||
qkv_bias=qkv_bias,
|
||||
act_cfg=act_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
with_cp=with_cp,
|
||||
batch_first=True))
|
||||
|
||||
def init_weights(self):
|
||||
if isinstance(self.init_cfg, dict) and \
|
||||
self.init_cfg.get('type') in ['Pretrained', 'Pretrained_Part']:
|
||||
checkpoint = CheckpointLoader.load_checkpoint(
|
||||
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
|
||||
|
||||
if self.init_cfg.get('type') == 'Pretrained':
|
||||
if 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
|
||||
elif self.init_cfg.get('type') == 'Pretrained_Part':
|
||||
state_dict = checkpoint.copy()
|
||||
para_prefix = 'image_encoder'
|
||||
prefix_len = len(para_prefix) + 1
|
||||
for k, v in checkpoint.items():
|
||||
state_dict.pop(k)
|
||||
if para_prefix in k:
|
||||
state_dict[k[prefix_len:]] = v
|
||||
|
||||
# if 'pos_embed' in state_dict.keys():
|
||||
# if self.pos_embed.shape != state_dict['pos_embed'].shape:
|
||||
# print_log(msg=f'Resize the pos_embed shape from '
|
||||
# f'{state_dict["pos_embed"].shape} to '
|
||||
# f'{self.pos_embed.shape}')
|
||||
# h, w = self.img_size
|
||||
# pos_size = int(
|
||||
# math.sqrt(state_dict['pos_embed'].shape[1] - 1))
|
||||
# state_dict['pos_embed'] = self.resize_pos_embed(
|
||||
# state_dict['pos_embed'],
|
||||
# (h // self.patch_size, w // self.patch_size),
|
||||
# (pos_size, pos_size), self.interpolate_mode)
|
||||
|
||||
load_state_dict(self, state_dict, strict=False, logger=None)
|
||||
elif self.init_cfg is not None:
|
||||
super().init_weights()
|
||||
else:
|
||||
# We only implement the 'jax_impl' initialization implemented at
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
|
||||
# trunc_normal_(self.pos_embed, std=.02)
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
for n, m in self.named_modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if m.bias is not None:
|
||||
if 'ffn' in n:
|
||||
nn.init.normal_(m.bias, mean=0., std=1e-6)
|
||||
else:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
kaiming_init(m, mode='fan_in', bias=0.)
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
|
||||
constant_init(m, val=1.0, bias=0.)
|
||||
|
||||
def forward(self, inputs, require_feat: bool = False, require_two: bool = False):
|
||||
inputs = self.porj_linear(inputs)
|
||||
B, N, C = inputs.shape
|
||||
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, inputs), dim=1)
|
||||
if not self.with_cls_token:
|
||||
# Remove class token for transformer encoder input
|
||||
x = x[:, 1:]
|
||||
|
||||
# add hidden and atten state
|
||||
block_outs = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
if require_feat:
|
||||
block_outs.append(x)
|
||||
|
||||
if self.output_cls_token:
|
||||
if require_two:
|
||||
x = x[:, :2]
|
||||
else:
|
||||
x = x[:, 0]
|
||||
elif not self.output_cls_token and self.with_cls_token:
|
||||
x = x[:, 1:]
|
||||
|
||||
if require_feat:
|
||||
return x, block_outs
|
||||
else:
|
||||
return x
|
||||
|
||||
def train(self, mode=True):
|
||||
super(FusionTransformer, self).train(mode)
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.LayerNorm):
|
||||
m.eval()
|
||||
|
||||
if __name__ == '__main__':
|
||||
fusion_transformer = FusionTransformer()
|
||||
print(fusion_transformer)
|
||||
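Hedged forward-shape sketch for FusionTransformer (assumes the mmseg/mmcv stack is importable so TransformerEncoderLayer can be built): with the default output_cls_token=True, forward returns one fused cls embedding per input token sequence.

import torch
from mmseg.models.necks import FusionTransformer

ft = FusionTransformer(input_dims=768, embed_dims=768, num_layers=2, num_heads=12)
tokens = torch.rand(4, 10, 768)   # (b*h*w, ts, input_dims)
print(ft(tokens).shape)           # torch.Size([4, 768])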
148
finetune/mmseg/models/necks/ic_neck.py
Normal file
148
finetune/mmseg/models/necks/ic_neck.py
Normal file
@@ -0,0 +1,148 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
class CascadeFeatureFusion(BaseModule):
|
||||
"""Cascade Feature Fusion Unit in ICNet.
|
||||
|
||||
Args:
|
||||
low_channels (int): The number of input channels for
|
||||
low resolution feature map.
|
||||
high_channels (int): The number of input channels for
|
||||
high resolution feature map.
|
||||
out_channels (int): The number of output channels.
|
||||
conv_cfg (dict): Dictionary to construct and config conv layer.
|
||||
Default: None.
|
||||
norm_cfg (dict): Dictionary to construct and config norm layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Dictionary to construct and config act layer.
|
||||
Default: dict(type='ReLU').
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
|
||||
Returns:
|
||||
x (Tensor): The output tensor of shape (N, out_channels, H, W).
|
||||
x_low (Tensor): The output tensor of shape (N, out_channels, H, W)
|
||||
for Cascade Label Guidance in auxiliary heads.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
low_channels,
|
||||
high_channels,
|
||||
out_channels,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
align_corners=False,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.align_corners = align_corners
|
||||
self.conv_low = ConvModule(
|
||||
low_channels,
|
||||
out_channels,
|
||||
3,
|
||||
padding=2,
|
||||
dilation=2,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
self.conv_high = ConvModule(
|
||||
high_channels,
|
||||
out_channels,
|
||||
1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
def forward(self, x_low, x_high):
|
||||
x_low = resize(
|
||||
x_low,
|
||||
size=x_high.size()[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
# Note: Different from the original paper, `x_low` is passed through
|
||||
# `self.conv_low` rather than another 1x1 conv classifier
|
||||
# before being used for auxiliary head.
|
||||
x_low = self.conv_low(x_low)
|
||||
x_high = self.conv_high(x_high)
|
||||
x = x_low + x_high
|
||||
x = F.relu(x, inplace=True)
|
||||
return x, x_low
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ICNeck(BaseModule):
|
||||
"""ICNet for Real-Time Semantic Segmentation on High-Resolution Images.
|
||||
|
||||
This head is the implementation of `ICHead
|
||||
<https://arxiv.org/abs/1704.08545>`_.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input image channels. Default: 3.
|
||||
out_channels (int): The numbers of output feature channels.
|
||||
Default: 128.
|
||||
conv_cfg (dict): Dictionary to construct and config conv layer.
|
||||
Default: None.
|
||||
norm_cfg (dict): Dictionary to construct and config norm layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Dictionary to construct and config act layer.
|
||||
Default: dict(type='ReLU').
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=(64, 256, 256),
|
||||
out_channels=128,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
align_corners=False,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
assert len(in_channels) == 3, 'Length of input channels \
|
||||
must be 3!'
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.align_corners = align_corners
|
||||
self.cff_24 = CascadeFeatureFusion(
|
||||
self.in_channels[2],
|
||||
self.in_channels[1],
|
||||
self.out_channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
align_corners=self.align_corners)
|
||||
|
||||
self.cff_12 = CascadeFeatureFusion(
|
||||
self.out_channels,
|
||||
self.in_channels[0],
|
||||
self.out_channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
align_corners=self.align_corners)
|
||||
|
||||
def forward(self, inputs):
|
||||
assert len(inputs) == 3, 'Length of input feature \
|
||||
maps must be 3!'
|
||||
|
||||
x_sub1, x_sub2, x_sub4 = inputs
|
||||
x_cff_24, x_24 = self.cff_24(x_sub4, x_sub2)
|
||||
x_cff_12, x_12 = self.cff_12(x_cff_24, x_sub1)
|
||||
# Note: `x_cff_12` is used for decode_head,
|
||||
# `x_24` and `x_12` are used for auxiliary head.
|
||||
return x_24, x_12, x_cff_12
|
||||
131
finetune/mmseg/models/necks/jpu.py
Normal file
131
finetune/mmseg/models/necks/jpu.py
Normal file
@@ -0,0 +1,131 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class JPU(BaseModule):
|
||||
"""FastFCN: Rethinking Dilated Convolution in the Backbone
|
||||
for Semantic Segmentation.
|
||||
|
||||
This Joint Pyramid Upsampling (JPU) neck is the implementation of
|
||||
`FastFCN <https://arxiv.org/abs/1903.11816>`_.
|
||||
|
||||
Args:
|
||||
in_channels (Tuple[int], optional): The number of input channels
|
||||
for each convolution operations before upsampling.
|
||||
Default: (512, 1024, 2048).
|
||||
mid_channels (int): The number of output channels of JPU.
|
||||
Default: 512.
|
||||
start_level (int): Index of the start input backbone level used to
|
||||
build the feature pyramid. Default: 0.
|
||||
end_level (int): Index of the end input backbone level (exclusive) to
|
||||
build the feature pyramid. Default: -1, which means the last level.
|
||||
dilations (tuple[int]): Dilation rate of each Depthwise
|
||||
Separable ConvModule. Default: (1, 2, 4, 8).
|
||||
align_corners (bool, optional): The align_corners argument of
|
||||
resize operation. Default: False.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=(512, 1024, 2048),
|
||||
mid_channels=512,
|
||||
start_level=0,
|
||||
end_level=-1,
|
||||
dilations=(1, 2, 4, 8),
|
||||
align_corners=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
assert isinstance(in_channels, tuple)
|
||||
assert isinstance(dilations, tuple)
|
||||
self.in_channels = in_channels
|
||||
self.mid_channels = mid_channels
|
||||
self.start_level = start_level
|
||||
self.num_ins = len(in_channels)
|
||||
if end_level == -1:
|
||||
self.backbone_end_level = self.num_ins
|
||||
else:
|
||||
self.backbone_end_level = end_level
|
||||
assert end_level <= len(in_channels)
|
||||
|
||||
self.dilations = dilations
|
||||
self.align_corners = align_corners
|
||||
|
||||
self.conv_layers = nn.ModuleList()
|
||||
self.dilation_layers = nn.ModuleList()
|
||||
for i in range(self.start_level, self.backbone_end_level):
|
||||
conv_layer = nn.Sequential(
|
||||
ConvModule(
|
||||
self.in_channels[i],
|
||||
self.mid_channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
self.conv_layers.append(conv_layer)
|
||||
for i in range(len(dilations)):
|
||||
dilation_layer = nn.Sequential(
|
||||
DepthwiseSeparableConvModule(
|
||||
in_channels=(self.backbone_end_level - self.start_level) *
|
||||
self.mid_channels,
|
||||
out_channels=self.mid_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=dilations[i],
|
||||
dilation=dilations[i],
|
||||
dw_norm_cfg=norm_cfg,
|
||||
dw_act_cfg=None,
|
||||
pw_norm_cfg=norm_cfg,
|
||||
pw_act_cfg=act_cfg))
|
||||
self.dilation_layers.append(dilation_layer)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
assert len(inputs) == len(self.in_channels), 'Length of inputs must \
|
||||
be the same as self.in_channels!'
|
||||
|
||||
feats = [
|
||||
self.conv_layers[i - self.start_level](inputs[i])
|
||||
for i in range(self.start_level, self.backbone_end_level)
|
||||
]
|
||||
|
||||
h, w = feats[0].shape[2:]
|
||||
for i in range(1, len(feats)):
|
||||
feats[i] = resize(
|
||||
feats[i],
|
||||
size=(h, w),
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
|
||||
feat = torch.cat(feats, dim=1)
|
||||
concat_feat = torch.cat([
|
||||
self.dilation_layers[i](feat) for i in range(len(self.dilations))
|
||||
],
|
||||
dim=1)
|
||||
|
||||
outs = []
|
||||
|
||||
# Default: outs[2] is the output of JPU for decoder head, outs[1] is
|
||||
# the feature map from backbone for auxiliary head. Additionally,
|
||||
# outs[0] can also be used for auxiliary head.
|
||||
for i in range(self.start_level, self.backbone_end_level - 1):
|
||||
outs.append(inputs[i])
|
||||
outs.append(concat_feat)
|
||||
return tuple(outs)
|
||||
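Hedged shape sketch for JPU (assumes mmcv is importable): three backbone stages at strides 8/16/32 go in; the last output is the joint-upsampled stride-8 feature whose channel count is len(dilations) * mid_channels.

import torch
from mmseg.models.necks import JPU

jpu = JPU(in_channels=(512, 1024, 2048), mid_channels=512)
feats = [torch.rand(1, 512, 64, 64),
         torch.rand(1, 1024, 32, 32),
         torch.rand(1, 2048, 16, 16)]
for out in jpu(feats):
    print(out.shape)   # (1, 512, 64, 64), (1, 1024, 32, 32), (1, 2048, 64, 64)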
118
finetune/mmseg/models/necks/mla_neck.py
Normal file
118
finetune/mmseg/models/necks/mla_neck.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, build_norm_layer
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
class MLAModule(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_channels=[1024, 1024, 1024, 1024],
|
||||
out_channels=256,
|
||||
norm_cfg=None,
|
||||
act_cfg=None):
|
||||
super().__init__()
|
||||
self.channel_proj = nn.ModuleList()
|
||||
for i in range(len(in_channels)):
|
||||
self.channel_proj.append(
|
||||
ConvModule(
|
||||
in_channels=in_channels[i],
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
self.feat_extract = nn.ModuleList()
|
||||
for i in range(len(in_channels)):
|
||||
self.feat_extract.append(
|
||||
ConvModule(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
# feat_list -> [p2, p3, p4, p5]
|
||||
feat_list = []
|
||||
for x, conv in zip(inputs, self.channel_proj):
|
||||
feat_list.append(conv(x))
|
||||
|
||||
# feat_list -> [p5, p4, p3, p2]
|
||||
# mid_list -> [m5, m4, m3, m2]
|
||||
feat_list = feat_list[::-1]
|
||||
mid_list = []
|
||||
for feat in feat_list:
|
||||
if len(mid_list) == 0:
|
||||
mid_list.append(feat)
|
||||
else:
|
||||
mid_list.append(mid_list[-1] + feat)
|
||||
|
||||
# mid_list -> [m5, m4, m3, m2]
|
||||
# out_list -> [o2, o3, o4, o5]
|
||||
out_list = []
|
||||
for mid, conv in zip(mid_list, self.feat_extract):
|
||||
out_list.append(conv(mid))
|
||||
|
||||
return tuple(out_list)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MLANeck(nn.Module):
|
||||
"""Multi-level Feature Aggregation.
|
||||
|
||||
This neck is `The Multi-level Feature Aggregation construction of
|
||||
SETR <https://arxiv.org/abs/2012.15840>`_.
|
||||
|
||||
|
||||
Args:
|
||||
in_channels (List[int]): Number of input channels per scale.
|
||||
out_channels (int): Number of output channels (used at each scale).
|
||||
norm_layer (dict): Config dict for input normalization.
|
||||
Default: norm_layer=dict(type='LN', eps=1e-6, requires_grad=True).
|
||||
norm_cfg (dict): Config dict for normalization layer. Default: None.
|
||||
act_cfg (dict): Config dict for activation layer in ConvModule.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
norm_layer=dict(type='LN', eps=1e-6, requires_grad=True),
|
||||
norm_cfg=None,
|
||||
act_cfg=None):
|
||||
super().__init__()
|
||||
assert isinstance(in_channels, list)
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
# In order to build general vision transformer backbone, we have to
|
||||
# move MLA to neck.
|
||||
self.norm = nn.ModuleList([
|
||||
build_norm_layer(norm_layer, in_channels[i])[1]
|
||||
for i in range(len(in_channels))
|
||||
])
|
||||
|
||||
self.mla = MLAModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
def forward(self, inputs):
|
||||
assert len(inputs) == len(self.in_channels)
|
||||
|
||||
# Convert from nchw to nlc
|
||||
outs = []
|
||||
for i in range(len(inputs)):
|
||||
x = inputs[i]
|
||||
n, c, h, w = x.shape
|
||||
x = x.reshape(n, c, h * w).transpose(2, 1).contiguous()
|
||||
x = self.norm[i](x)
|
||||
x = x.transpose(1, 2).reshape(n, c, h, w).contiguous()
|
||||
outs.append(x)
|
||||
|
||||
outs = self.mla(outs)
|
||||
return tuple(outs)
|
||||
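A plain-torch sketch of the MLAModule aggregation above: the projected levels are summed cumulatively starting from the deepest one (m5 = p5, m4 = m5 + p4, ...), and each partial sum then goes through its own 3x3 conv.

import torch

proj = [torch.rand(1, 256, 32, 32) for _ in range(4)]   # p2..p5 after the 1x1 projections
mids, acc = [], None
for feat in reversed(proj):                              # traverse p5 -> p2
    acc = feat if acc is None else acc + feat
    mids.append(acc)
print(len(mids), mids[0].shape)                          # 4 torch.Size([1, 256, 32, 32])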
79
finetune/mmseg/models/necks/multilevel_neck.py
Normal file
79
finetune/mmseg/models/necks/multilevel_neck.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model.weight_init import xavier_init
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MultiLevelNeck(nn.Module):
|
||||
"""MultiLevelNeck.
|
||||
|
||||
A neck structure connecting the ViT backbone and decode heads.
|
||||
|
||||
Args:
|
||||
in_channels (List[int]): Number of input channels per scale.
|
||||
out_channels (int): Number of output channels (used at each scale).
|
||||
scales (List[float]): Scale factors for each input feature map.
|
||||
Default: [0.5, 1, 2, 4]
|
||||
norm_cfg (dict): Config dict for normalization layer. Default: None.
|
||||
act_cfg (dict): Config dict for activation layer in ConvModule.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
scales=[0.5, 1, 2, 4],
|
||||
norm_cfg=None,
|
||||
act_cfg=None):
|
||||
super().__init__()
|
||||
assert isinstance(in_channels, list)
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.scales = scales
|
||||
self.num_outs = len(scales)
|
||||
self.lateral_convs = nn.ModuleList()
|
||||
self.convs = nn.ModuleList()
|
||||
for in_channel in in_channels:
|
||||
self.lateral_convs.append(
|
||||
ConvModule(
|
||||
in_channel,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
for _ in range(self.num_outs):
|
||||
self.convs.append(
|
||||
ConvModule(
|
||||
out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
stride=1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
|
||||
# default init_weights for conv(msra) and norm in ConvModule
|
||||
def init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
xavier_init(m, distribution='uniform')
|
||||
|
||||
def forward(self, inputs):
|
||||
assert len(inputs) == len(self.in_channels)
|
||||
inputs = [
|
||||
lateral_conv(inputs[i])
|
||||
for i, lateral_conv in enumerate(self.lateral_convs)
|
||||
]
|
||||
# for len(inputs) not equal to self.num_outs
|
||||
if len(inputs) == 1:
|
||||
inputs = [inputs[0] for _ in range(self.num_outs)]
|
||||
outs = []
|
||||
for i in range(self.num_outs):
|
||||
x_resize = resize(
|
||||
inputs[i], scale_factor=self.scales[i], mode='bilinear')
|
||||
outs.append(self.convs[i](x_resize))
|
||||
return tuple(outs)
|
||||
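Hedged usage sketch for MultiLevelNeck (assumes mmcv is importable): each input is projected to out_channels and resized by the matching scale factor, which is also how FusionMultiLevelNeck above prepares its per-level inputs.

import torch
from mmseg.models.necks import MultiLevelNeck

neck = MultiLevelNeck(in_channels=[768, 768, 768, 768], out_channels=768)
feats = [torch.rand(1, 768, 32, 32) for _ in range(4)]
for out in neck(feats):
    print(out.shape)   # spatial sizes 16, 32, 64 and 128 for scales [0.5, 1, 2, 4]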
12
finetune/mmseg/models/segmentors/__init__.py
Normal file
12
finetune/mmseg/models/segmentors/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .base import BaseSegmentor
|
||||
from .cascade_encoder_decoder import CascadeEncoderDecoder
|
||||
from .depth_estimator import DepthEstimator
|
||||
from .encoder_decoder import EncoderDecoder
|
||||
from .multimodal_encoder_decoder import MultimodalEncoderDecoder
|
||||
from .seg_tta import SegTTAModel
|
||||
|
||||
__all__ = [
|
||||
'BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder', 'SegTTAModel',
|
||||
'MultimodalEncoderDecoder', 'DepthEstimator'
|
||||
]
|
||||
200
finetune/mmseg/models/segmentors/base.py
Normal file
200
finetune/mmseg/models/segmentors/base.py
Normal file
@@ -0,0 +1,200 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import List, Tuple
|
||||
|
||||
from mmengine.model import BaseModel
|
||||
from mmengine.structures import PixelData
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.structures import SegDataSample
|
||||
from mmseg.utils import (ForwardResults, OptConfigType, OptMultiConfig,
|
||||
OptSampleList, SampleList)
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
class BaseSegmentor(BaseModel, metaclass=ABCMeta):
|
||||
"""Base class for segmentors.
|
||||
|
||||
Args:
|
||||
data_preprocessor (dict, optional): Model preprocessing config
|
||||
for processing the input data. it usually includes
|
||||
``to_rgb``, ``pad_size_divisor``, ``pad_val``,
|
||||
``mean`` and ``std``. Default to None.
|
||||
init_cfg (dict, optional): the config to control the
|
||||
initialization. Default to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
data_preprocessor: OptConfigType = None,
|
||||
init_cfg: OptMultiConfig = None):
|
||||
super().__init__(
|
||||
data_preprocessor=data_preprocessor, init_cfg=init_cfg)
|
||||
|
||||
@property
|
||||
def with_neck(self) -> bool:
|
||||
"""bool: whether the segmentor has neck"""
|
||||
return hasattr(self, 'neck') and self.neck is not None
|
||||
|
||||
@property
|
||||
def with_auxiliary_head(self) -> bool:
|
||||
"""bool: whether the segmentor has auxiliary head"""
|
||||
return hasattr(self,
|
||||
'auxiliary_head') and self.auxiliary_head is not None
|
||||
|
||||
@property
|
||||
def with_decode_head(self) -> bool:
|
||||
"""bool: whether the segmentor has decode head"""
|
||||
return hasattr(self, 'decode_head') and self.decode_head is not None
|
||||
|
||||
@abstractmethod
|
||||
def extract_feat(self, inputs: Tensor) -> Tuple[Tensor]:
|
||||
"""Placeholder for extract features from images."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def encode_decode(self, inputs: Tensor, batch_data_samples: SampleList):
|
||||
"""Placeholder for encode images with backbone and decode into a
|
||||
semantic segmentation map of the same size as input."""
|
||||
pass
|
||||
|
||||
def forward(self,
|
||||
inputs: Tensor,
|
||||
data_samples: OptSampleList = None,
|
||||
mode: str = 'tensor') -> ForwardResults:
|
||||
"""The unified entry for a forward process in both training and test.
|
||||
|
||||
The method should accept three modes: "tensor", "predict" and "loss":
|
||||
|
||||
- "tensor": Forward the whole network and return tensor or tuple of
|
||||
tensor without any post-processing, same as a common nn.Module.
|
||||
- "predict": Forward and return the predictions, which are fully
|
||||
processed to a list of :obj:`SegDataSample`.
|
||||
- "loss": Forward and return a dict of losses according to the given
|
||||
inputs and data samples.
|
||||
|
||||
Note that this method doesn't handle back propagation or
|
||||
optimizer updating, which are done in the :meth:`train_step`.
|
||||
|
||||
Args:
|
||||
inputs (torch.Tensor): The input tensor with shape (N, C, ...) in
|
||||
general.
|
||||
data_samples (list[:obj:`SegDataSample`]): The seg data samples.
|
||||
It usually includes information such as `metainfo` and
|
||||
`gt_sem_seg`. Default to None.
|
||||
mode (str): Return what kind of value. Defaults to 'tensor'.
|
||||
|
||||
Returns:
|
||||
The return type depends on ``mode``.
|
||||
|
||||
- If ``mode="tensor"``, return a tensor or a tuple of tensor.
|
||||
- If ``mode="predict"``, return a list of :obj:`DetDataSample`.
|
||||
- If ``mode="loss"``, return a dict of tensor.
|
||||
"""
|
||||
if mode == 'loss':
|
||||
return self.loss(inputs, data_samples)
|
||||
elif mode == 'predict':
|
||||
return self.predict(inputs, data_samples)
|
||||
elif mode == 'tensor':
|
||||
return self._forward(inputs, data_samples)
|
||||
else:
|
||||
raise RuntimeError(f'Invalid mode "{mode}". '
|
||||
'Only supports loss, predict and tensor mode')
|
||||
|
||||
@abstractmethod
|
||||
def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
|
||||
"""Calculate losses from a batch of inputs and data samples."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def predict(self,
|
||||
inputs: Tensor,
|
||||
data_samples: OptSampleList = None) -> SampleList:
|
||||
"""Predict results from a batch of inputs and data samples with post-
|
||||
processing."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _forward(self,
|
||||
inputs: Tensor,
|
||||
data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
|
||||
"""Network forward process.
|
||||
|
||||
Usually includes backbone, neck and head forward without any post-
|
||||
processing.
|
||||
"""
|
||||
pass
|
||||
|
||||
def postprocess_result(self,
|
||||
seg_logits: Tensor,
|
||||
data_samples: OptSampleList = None) -> SampleList:
|
||||
""" Convert results list to `SegDataSample`.
|
||||
Args:
|
||||
seg_logits (Tensor): The segmentation results, seg_logits from
|
||||
model of each input image.
|
||||
data_samples (list[:obj:`SegDataSample`]): The seg data samples.
|
||||
It usually includes information such as `metainfo` and
|
||||
`gt_sem_seg`. Default to None.
|
||||
Returns:
|
||||
list[:obj:`SegDataSample`]: Segmentation results of the
|
||||
input images. Each SegDataSample usually contain:
|
||||
|
||||
- ``pred_sem_seg``(PixelData): Prediction of semantic segmentation.
|
||||
- ``seg_logits``(PixelData): Predicted logits of semantic
|
||||
segmentation before normalization.
|
||||
"""
|
||||
batch_size, C, H, W = seg_logits.shape
|
||||
|
||||
if data_samples is None:
|
||||
data_samples = [SegDataSample() for _ in range(batch_size)]
|
||||
only_prediction = True
|
||||
else:
|
||||
only_prediction = False
|
||||
|
||||
for i in range(batch_size):
|
||||
if not only_prediction:
|
||||
img_meta = data_samples[i].metainfo
|
||||
# remove padding area
|
||||
if 'img_padding_size' not in img_meta:
|
||||
padding_size = img_meta.get('padding_size', [0] * 4)
|
||||
else:
|
||||
padding_size = img_meta['img_padding_size']
|
||||
padding_left, padding_right, padding_top, padding_bottom =\
|
||||
padding_size
|
||||
# i_seg_logits shape is 1, C, H, W after remove padding
|
||||
i_seg_logits = seg_logits[i:i + 1, :,
|
||||
padding_top:H - padding_bottom,
|
||||
padding_left:W - padding_right]
|
||||
|
||||
flip = img_meta.get('flip', None)
|
||||
if flip:
|
||||
flip_direction = img_meta.get('flip_direction', None)
|
||||
assert flip_direction in ['horizontal', 'vertical']
|
||||
if flip_direction == 'horizontal':
|
||||
i_seg_logits = i_seg_logits.flip(dims=(3, ))
|
||||
else:
|
||||
i_seg_logits = i_seg_logits.flip(dims=(2, ))
|
||||
|
||||
# resize as original shape
|
||||
i_seg_logits = resize(
|
||||
i_seg_logits,
|
||||
size=img_meta['ori_shape'],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners,
|
||||
warning=False).squeeze(0)
|
||||
else:
|
||||
i_seg_logits = seg_logits[i]
|
||||
|
||||
if C > 1:
|
||||
i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True)
|
||||
else:
|
||||
i_seg_logits = i_seg_logits.sigmoid()
|
||||
i_seg_pred = (i_seg_logits >
|
||||
self.decode_head.threshold).to(i_seg_logits)
|
||||
data_samples[i].set_data({
|
||||
'seg_logits':
|
||||
PixelData(**{'data': i_seg_logits}),
|
||||
'pred_sem_seg':
|
||||
PixelData(**{'data': i_seg_pred})
|
||||
})
|
||||
|
||||
return data_samples
|
||||
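A small plain-torch sketch of the two prediction branches at the end of postprocess_result: multi-class logits are argmax-ed over the channel dimension, while single-channel logits are sigmoid-ed and compared against the decode head's threshold (0.3 below is only an illustrative value).

import torch

multi = torch.rand(5, 64, 64)                       # C > 1: per-class logits
print(multi.argmax(dim=0, keepdim=True).shape)      # torch.Size([1, 64, 64])

binary = torch.rand(1, 64, 64)                      # C == 1: single foreground logit
print((binary.sigmoid() > 0.3).to(binary).shape)    # torch.Size([1, 64, 64])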
138
finetune/mmseg/models/segmentors/cascade_encoder_decoder.py
Normal file
138
finetune/mmseg/models/segmentors/cascade_encoder_decoder.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Optional
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig,
|
||||
OptSampleList, SampleList, add_prefix)
|
||||
from .encoder_decoder import EncoderDecoder
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class CascadeEncoderDecoder(EncoderDecoder):
|
||||
"""Cascade Encoder Decoder segmentors.
|
||||
|
||||
CascadeEncoderDecoder is almost the same as EncoderDecoder, while decoders of
|
||||
CascadeEncoderDecoder are cascaded. The output of previous decoder_head
|
||||
will be the input of next decoder_head.
|
||||
|
||||
Args:
|
||||
|
||||
num_stages (int): How many stages will be cascaded.
|
||||
backbone (ConfigType): The config for the backbone of the segmentor.
|
||||
decode_head (ConfigType): The config for the decode head of segmentor.
|
||||
neck (OptConfigType): The config for the neck of segmentor.
|
||||
Defaults to None.
|
||||
auxiliary_head (OptConfigType): The config for the auxiliary head of
|
||||
segmentor. Defaults to None.
|
||||
train_cfg (OptConfigType): The config for training. Defaults to None.
|
||||
test_cfg (OptConfigType): The config for testing. Defaults to None.
|
||||
data_preprocessor (dict, optional): The pre-process config of
|
||||
:class:`BaseDataPreprocessor`.
|
||||
pretrained (str, optional): The path for pretrained model.
|
||||
Defaults to None.
|
||||
init_cfg (dict, optional): The weight initialized config for
|
||||
:class:`BaseModule`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_stages: int,
|
||||
backbone: ConfigType,
|
||||
decode_head: ConfigType,
|
||||
neck: OptConfigType = None,
|
||||
auxiliary_head: OptConfigType = None,
|
||||
train_cfg: OptConfigType = None,
|
||||
test_cfg: OptConfigType = None,
|
||||
data_preprocessor: OptConfigType = None,
|
||||
pretrained: Optional[str] = None,
|
||||
init_cfg: OptMultiConfig = None):
|
||||
self.num_stages = num_stages
|
||||
super().__init__(
|
||||
backbone=backbone,
|
||||
decode_head=decode_head,
|
||||
neck=neck,
|
||||
auxiliary_head=auxiliary_head,
|
||||
train_cfg=train_cfg,
|
||||
test_cfg=test_cfg,
|
||||
data_preprocessor=data_preprocessor,
|
||||
pretrained=pretrained,
|
||||
init_cfg=init_cfg)
|
||||
|
||||
def _init_decode_head(self, decode_head: ConfigType) -> None:
|
||||
"""Initialize ``decode_head``"""
|
||||
assert isinstance(decode_head, list)
|
||||
assert len(decode_head) == self.num_stages
|
||||
self.decode_head = nn.ModuleList()
|
||||
for i in range(self.num_stages):
|
||||
self.decode_head.append(MODELS.build(decode_head[i]))
|
||||
self.align_corners = self.decode_head[-1].align_corners
|
||||
self.num_classes = self.decode_head[-1].num_classes
|
||||
self.out_channels = self.decode_head[-1].out_channels
|
||||
|
||||
def encode_decode(self, inputs: Tensor,
|
||||
batch_img_metas: List[dict]) -> Tensor:
|
||||
"""Encode images with backbone and decode into a semantic segmentation
|
||||
map of the same size as input."""
|
||||
x = self.extract_feat(inputs)
|
||||
out = self.decode_head[0].forward(x)
|
||||
for i in range(1, self.num_stages - 1):
|
||||
out = self.decode_head[i].forward(x, out)
|
||||
seg_logits_list = self.decode_head[-1].predict(x, out, batch_img_metas,
|
||||
self.test_cfg)
|
||||
|
||||
return seg_logits_list
|
||||
|
||||
def _decode_head_forward_train(self, inputs: Tensor,
|
||||
data_samples: SampleList) -> dict:
|
||||
"""Run forward function and calculate loss for decode head in
|
||||
training."""
|
||||
losses = dict()
|
||||
|
||||
loss_decode = self.decode_head[0].loss(inputs, data_samples,
|
||||
self.train_cfg)
|
||||
|
||||
losses.update(add_prefix(loss_decode, 'decode_0'))
|
||||
# get batch_img_metas
|
||||
batch_size = len(data_samples)
|
||||
batch_img_metas = []
|
||||
for batch_index in range(batch_size):
|
||||
metainfo = data_samples[batch_index].metainfo
|
||||
batch_img_metas.append(metainfo)
|
||||
|
||||
for i in range(1, self.num_stages):
|
||||
# forward test again, maybe unnecessary for most methods.
|
||||
if i == 1:
|
||||
prev_outputs = self.decode_head[0].forward(inputs)
|
||||
else:
|
||||
prev_outputs = self.decode_head[i - 1].forward(
|
||||
inputs, prev_outputs)
|
||||
loss_decode = self.decode_head[i].loss(inputs, prev_outputs,
|
||||
data_samples,
|
||||
self.train_cfg)
|
||||
losses.update(add_prefix(loss_decode, f'decode_{i}'))
|
||||
|
||||
return losses
|
||||
|
||||
def _forward(self,
|
||||
inputs: Tensor,
|
||||
data_samples: OptSampleList = None) -> Tensor:
|
||||
"""Network forward process.
|
||||
|
||||
Args:
|
||||
inputs (Tensor): Inputs with shape (N, C, H, W).
|
||||
data_samples (List[:obj:`SegDataSample`]): The seg data samples.
|
||||
It usually includes information such as `metainfo` and
|
||||
`gt_semantic_seg`.
|
||||
|
||||
Returns:
|
||||
Tensor: Forward output of model without any post-processes.
|
||||
"""
|
||||
x = self.extract_feat(inputs)
|
||||
|
||||
out = self.decode_head[0].forward(x)
|
||||
for i in range(1, self.num_stages):
|
||||
# TODO support PointRend tensor mode
|
||||
out = self.decode_head[i].forward(x, out)
|
||||
|
||||
return out
|
||||
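A schematic config for the cascade above, assuming standard mmseg head types; the keys and values are abbreviated placeholders rather than settings taken from this commit. The point is that `decode_head` is a list of length `num_stages`, and stage i consumes the output of stage i-1.

cascade_segmentor = dict(
    type='CascadeEncoderDecoder',
    num_stages=2,
    backbone=dict(type='ResNetV1c', depth=50),                 # placeholder backbone
    decode_head=[
        dict(type='FCNHead', in_channels=1024, in_index=2,     # stage 0: coarse logits
             channels=256, num_classes=19),
        dict(type='OCRHead', in_channels=2048, in_index=3,     # stage 1: refines stage-0 output
             channels=512, ocr_channels=256, num_classes=19),
    ])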
392
finetune/mmseg/models/segmentors/depth_estimator.py
Normal file
@@ -0,0 +1,392 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.logging import print_log
from mmengine.structures import PixelData
from torch import Tensor

from mmseg.registry import MODELS
from mmseg.structures import SegDataSample
from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig,
                         OptSampleList, SampleList, add_prefix)
from ..utils import resize
from .encoder_decoder import EncoderDecoder


@MODELS.register_module()
class DepthEstimator(EncoderDecoder):
    """Encoder Decoder depth estimator.

    EncoderDecoder typically consists of backbone, decode_head, auxiliary_head.
    Note that auxiliary_head is only used for deep supervision during training,
    which could be dumped during inference.

    1. The ``loss`` method is used to calculate the loss of model,
    which includes two steps: (1) Extracts features to obtain the feature maps
    (2) Call the decode head loss function to forward decode head model and
    calculate losses.

    .. code:: text

     loss(): extract_feat() -> _decode_head_forward_train() -> _auxiliary_head_forward_train (optional)
     _decode_head_forward_train(): decode_head.loss()
     _auxiliary_head_forward_train(): auxiliary_head.loss (optional)

    2. The ``predict`` method is used to predict depth estimation results,
    which includes two steps: (1) Run inference function to obtain the list of
    depth maps (2) Call post-processing function to obtain list of
    ``SegDataSample`` including ``pred_depth_map``.

    .. code:: text

     predict(): inference() -> postprocess_result()
     inference(): whole_inference()/slide_inference()
     whole_inference()/slide_inference(): encoder_decoder()
     encoder_decoder(): extract_feat() -> decode_head.predict()

    3. The ``_forward`` method is used to output the tensor by running the model,
    which includes two steps: (1) Extracts features to obtain the feature maps
    (2) Call the decode head forward function to forward decode head model.

    .. code:: text

     _forward(): extract_feat() -> _decode_head.forward()

    Args:

        backbone (ConfigType): The config for the backbone of depth estimator.
        decode_head (ConfigType): The config for the decode head of depth estimator.
        neck (OptConfigType): The config for the neck of depth estimator.
            Defaults to None.
        auxiliary_head (OptConfigType): The config for the auxiliary head of
            depth estimator. Defaults to None.
        train_cfg (OptConfigType): The config for training. Defaults to None.
        test_cfg (OptConfigType): The config for testing. Defaults to None.
        data_preprocessor (dict, optional): The pre-process config of
            :class:`BaseDataPreprocessor`.
        pretrained (str, optional): The path for pretrained model.
            Defaults to None.
        init_cfg (dict, optional): The weight initialized config for
            :class:`BaseModule`.
    """  # noqa: E501

    def __init__(self,
                 backbone: ConfigType,
                 decode_head: ConfigType,
                 neck: OptConfigType = None,
                 auxiliary_head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 pretrained: Optional[str] = None,
                 init_cfg: OptMultiConfig = None):
        super().__init__(
            backbone=backbone,
            decode_head=decode_head,
            neck=neck,
            auxiliary_head=auxiliary_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            data_preprocessor=data_preprocessor,
            pretrained=pretrained,
            init_cfg=init_cfg)

    def extract_feat(self,
                     inputs: Tensor,
                     batch_img_metas: Optional[List[dict]] = None) -> Tensor:
        """Extract features from images."""

        if getattr(self.backbone, 'class_embed_select', False) and \
                isinstance(batch_img_metas, list) and \
                'category_id' in batch_img_metas[0]:
            cat_ids = [meta['category_id'] for meta in batch_img_metas]
            cat_ids = torch.tensor(cat_ids).to(inputs.device)
            inputs = (inputs, cat_ids)

        x = self.backbone(inputs)
        if self.with_neck:
            x = self.neck(x)
        return x

    def encode_decode(self, inputs: Tensor,
                      batch_img_metas: List[dict]) -> Tensor:
        """Encode images with backbone and decode into a depth map of the same
        size as input."""
        x = self.extract_feat(inputs, batch_img_metas)
        depth = self.decode_head.predict(x, batch_img_metas, self.test_cfg)

        return depth

    def _decode_head_forward_train(self, inputs: List[Tensor],
                                   data_samples: SampleList) -> dict:
        """Run forward function and calculate loss for decode head in
        training."""
        losses = dict()
        loss_decode = self.decode_head.loss(inputs, data_samples,
                                            self.train_cfg)

        losses.update(add_prefix(loss_decode, 'decode'))
        return losses

    def _auxiliary_head_forward_train(self, inputs: List[Tensor],
                                      data_samples: SampleList) -> dict:
        """Run forward function and calculate loss for auxiliary head in
        training."""
        losses = dict()
        if isinstance(self.auxiliary_head, nn.ModuleList):
            for idx, aux_head in enumerate(self.auxiliary_head):
                loss_aux = aux_head.loss(inputs, data_samples, self.train_cfg)
                losses.update(add_prefix(loss_aux, f'aux_{idx}'))
        else:
            loss_aux = self.auxiliary_head.loss(inputs, data_samples,
                                                self.train_cfg)
            losses.update(add_prefix(loss_aux, 'aux'))

        return losses

    def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (Tensor): Input images.
            data_samples (list[:obj:`SegDataSample`]): The seg data samples.
                It usually includes information such as `metainfo` and
                `gt_depth_map`.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        if data_samples is not None:
            batch_img_metas = [
                data_sample.metainfo for data_sample in data_samples
            ]
        else:
            batch_img_metas = [
                dict(
                    ori_shape=inputs.shape[2:],
                    img_shape=inputs.shape[2:],
                    pad_shape=inputs.shape[2:],
                    padding_size=[0, 0, 0, 0])
            ] * inputs.shape[0]

        x = self.extract_feat(inputs, batch_img_metas)

        losses = dict()

        loss_decode = self._decode_head_forward_train(x, data_samples)
        losses.update(loss_decode)

        if self.with_auxiliary_head:
            loss_aux = self._auxiliary_head_forward_train(x, data_samples)
            losses.update(loss_aux)

        return losses

    def predict(self,
                inputs: Tensor,
                data_samples: OptSampleList = None) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).
            data_samples (List[:obj:`SegDataSample`], optional): The seg data
                samples. It usually includes information such as `metainfo`
                and `gt_depth_map`.

        Returns:
            list[:obj:`SegDataSample`]: Depth estimation results of the
            input images. Each SegDataSample usually contains:

            - ``pred_depth_map``(PixelData): Prediction of depth estimation.
        """
        if data_samples is not None:
            batch_img_metas = [
                data_sample.metainfo for data_sample in data_samples
            ]
        else:
            batch_img_metas = [
                dict(
                    ori_shape=inputs.shape[2:],
                    img_shape=inputs.shape[2:],
                    pad_shape=inputs.shape[2:],
                    padding_size=[0, 0, 0, 0])
            ] * inputs.shape[0]

        depth = self.inference(inputs, batch_img_metas)

        return self.postprocess_result(depth, data_samples)

    def _forward(self,
                 inputs: Tensor,
                 data_samples: OptSampleList = None) -> Tensor:
        """Network forward process.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).
            data_samples (List[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `metainfo` and `gt_depth_map`.

        Returns:
            Tensor: Forward output of model without any post-processes.
        """
        x = self.extract_feat(inputs)
        return self.decode_head.forward(x)

    def slide_flip_inference(self, inputs: Tensor,
                             batch_img_metas: List[dict]) -> Tensor:
        """Inference by sliding-window with overlap and flip.

        If h_crop > h_img or w_crop > w_img, the small patch will be used to
        decode without padding.

        Args:
            inputs (tensor): the tensor should have a shape NxCxHxW,
                which contains all images in the batch.
            batch_img_metas (List[dict]): List of image metainfo where each may
                also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
                'ori_shape', and 'pad_shape'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.

        Returns:
            Tensor: The depth estimation results.
        """

        h_stride, w_stride = self.test_cfg.stride
        h_crop, w_crop = self.test_cfg.crop_size
        batch_size, _, h_img, w_img = inputs.size()
        out_channels = self.out_channels
        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
        preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
        count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img))
        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_img = inputs[:, :, y1:y2, x1:x2]
                # change the image shape to patch shape
                batch_img_metas[0]['img_shape'] = crop_img.shape[2:]
                # the output of encode_decode is depth tensor map
                # with shape [N, C, H, W]
                crop_depth_map = self.encode_decode(crop_img, batch_img_metas)

                # average out the original and flipped prediction
                crop_depth_map_flip = self.encode_decode(
                    crop_img.flip(dims=(3, )), batch_img_metas)
                crop_depth_map_flip = crop_depth_map_flip.flip(dims=(3, ))
                crop_depth_map = (crop_depth_map + crop_depth_map_flip) / 2.0

                preds += F.pad(crop_depth_map,
                               (int(x1), int(preds.shape[3] - x2), int(y1),
                                int(preds.shape[2] - y2)))

                count_mat[:, :, y1:y2, x1:x2] += 1
        assert (count_mat == 0).sum() == 0
        depth = preds / count_mat

        return depth

    def inference(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor:
        """Inference with slide/whole style.

        Args:
            inputs (Tensor): The input image of shape (N, 3, H, W).
            batch_img_metas (List[dict]): List of image metainfo where each may
                also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
                'ori_shape', 'pad_shape', and 'padding_size'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.

        Returns:
            Tensor: The depth estimation results.
        """
        assert self.test_cfg.get('mode', 'whole') in ['slide', 'whole',
                                                      'slide_flip'], \
            f'Only "slide", "slide_flip" or "whole" test mode are ' \
            f'supported, but got {self.test_cfg["mode"]}.'
        ori_shape = batch_img_metas[0]['ori_shape']
        if not all(_['ori_shape'] == ori_shape for _ in batch_img_metas):
            print_log(
                'Image shapes are different in the batch.',
                logger='current',
                level=logging.WARN)
        if self.test_cfg.mode == 'slide':
            depth_map = self.slide_inference(inputs, batch_img_metas)
        elif self.test_cfg.mode == 'slide_flip':
            # use elif so the 'slide' result is not overwritten by
            # whole_inference below
            depth_map = self.slide_flip_inference(inputs, batch_img_metas)
        else:
            depth_map = self.whole_inference(inputs, batch_img_metas)

        return depth_map

    def postprocess_result(self,
                           depth: Tensor,
                           data_samples: OptSampleList = None) -> SampleList:
        """ Convert results list to `SegDataSample`.
        Args:
            depth (Tensor): The depth estimation results.
            data_samples (list[:obj:`SegDataSample`]): The seg data samples.
                It usually includes information such as `metainfo` and
                `gt_depth_map`. Defaults to None.
        Returns:
            list[:obj:`SegDataSample`]: Depth estimation results of the
            input images. Each SegDataSample usually contains:

            - ``pred_depth_map``(PixelData): Prediction of depth estimation.
        """
        batch_size, C, H, W = depth.shape

        if data_samples is None:
            data_samples = [SegDataSample() for _ in range(batch_size)]
            only_prediction = True
        else:
            only_prediction = False

        for i in range(batch_size):
            if not only_prediction:
                img_meta = data_samples[i].metainfo
                # remove padding area
                if 'img_padding_size' not in img_meta:
                    padding_size = img_meta.get('padding_size', [0] * 4)
                else:
                    padding_size = img_meta['img_padding_size']
                padding_left, padding_right, padding_top, padding_bottom =\
                    padding_size
                # i_depth shape is 1, C, H, W after remove padding
                i_depth = depth[i:i + 1, :, padding_top:H - padding_bottom,
                                padding_left:W - padding_right]

                flip = img_meta.get('flip', None)
                if flip:
                    flip_direction = img_meta.get('flip_direction', None)
                    assert flip_direction in ['horizontal', 'vertical']
                    if flip_direction == 'horizontal':
                        i_depth = i_depth.flip(dims=(3, ))
                    else:
                        i_depth = i_depth.flip(dims=(2, ))

                # resize as original shape
                i_depth = resize(
                    i_depth,
                    size=img_meta['ori_shape'],
                    mode='bilinear',
                    align_corners=self.align_corners,
                    warning=False).squeeze(0)
            else:
                i_depth = depth[i]

            data_samples[i].set_data(
                {'pred_depth_map': PixelData(**{'data': i_depth})})

        return data_samples
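A small numeric sketch of the sliding-window bookkeeping used by `slide_flip_inference` above, with made-up sizes: overlapping crop predictions are summed into `preds` and divided by `count_mat`, and each crop is additionally averaged with its horizontally flipped prediction.

h_img, h_crop, h_stride = 512, 256, 170
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1   # ceil((512 - 256) / 170) + 1 = 3
# per crop:  depth = (f(crop) + flip(f(flip(crop)))) / 2
# final map: depth_map = preds / count_mat   (count_mat > 0 everywhere by construction)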
364
finetune/mmseg/models/segmentors/encoder_decoder.py
Normal file
@@ -0,0 +1,364 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import List, Optional

import torch.nn as nn
import torch.nn.functional as F
from mmengine.logging import print_log
from torch import Tensor

from mmseg.registry import MODELS
from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig,
                         OptSampleList, SampleList, add_prefix)
from .base import BaseSegmentor


@MODELS.register_module()
class EncoderDecoder(BaseSegmentor):
    """Encoder Decoder segmentors.

    EncoderDecoder typically consists of backbone, decode_head, auxiliary_head.
    Note that auxiliary_head is only used for deep supervision during training,
    which could be dumped during inference.

    1. The ``loss`` method is used to calculate the loss of model,
    which includes two steps: (1) Extracts features to obtain the feature maps
    (2) Call the decode head loss function to forward decode head model and
    calculate losses.

    .. code:: text

     loss(): extract_feat() -> _decode_head_forward_train() -> _auxiliary_head_forward_train (optional)
     _decode_head_forward_train(): decode_head.loss()
     _auxiliary_head_forward_train(): auxiliary_head.loss (optional)

    2. The ``predict`` method is used to predict segmentation results,
    which includes two steps: (1) Run inference function to obtain the list of
    seg_logits (2) Call post-processing function to obtain list of
    ``SegDataSample`` including ``pred_sem_seg`` and ``seg_logits``.

    .. code:: text

     predict(): inference() -> postprocess_result()
     inference(): whole_inference()/slide_inference()
     whole_inference()/slide_inference(): encoder_decoder()
     encoder_decoder(): extract_feat() -> decode_head.predict()

    3. The ``_forward`` method is used to output the tensor by running the model,
    which includes two steps: (1) Extracts features to obtain the feature maps
    (2) Call the decode head forward function to forward decode head model.

    .. code:: text

     _forward(): extract_feat() -> _decode_head.forward()

    Args:

        backbone (ConfigType): The config for the backbone of segmentor.
        decode_head (ConfigType): The config for the decode head of segmentor.
        neck (OptConfigType): The config for the neck of segmentor.
            Defaults to None.
        auxiliary_head (OptConfigType): The config for the auxiliary head of
            segmentor. Defaults to None.
        train_cfg (OptConfigType): The config for training. Defaults to None.
        test_cfg (OptConfigType): The config for testing. Defaults to None.
        data_preprocessor (dict, optional): The pre-process config of
            :class:`BaseDataPreprocessor`.
        pretrained (str, optional): The path for pretrained model.
            Defaults to None.
        init_cfg (dict, optional): The weight initialized config for
            :class:`BaseModule`.
    """  # noqa: E501

    def __init__(self,
                 backbone: ConfigType,
                 decode_head: ConfigType,
                 neck: OptConfigType = None,
                 auxiliary_head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 pretrained: Optional[str] = None,
                 init_cfg: OptMultiConfig = None):
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        if pretrained is not None:
            assert backbone.get('pretrained') is None, \
                'both backbone and segmentor set pretrained weight'
            backbone.pretrained = pretrained
        self.backbone = MODELS.build(backbone)
        if neck is not None:
            self.neck = MODELS.build(neck)
        self._init_decode_head(decode_head)
        self._init_auxiliary_head(auxiliary_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        assert self.with_decode_head

    def _init_decode_head(self, decode_head: ConfigType) -> None:
        """Initialize ``decode_head``"""
        self.decode_head = MODELS.build(decode_head)
        self.align_corners = self.decode_head.align_corners
        self.num_classes = self.decode_head.num_classes
        self.out_channels = self.decode_head.out_channels

    def _init_auxiliary_head(self, auxiliary_head: ConfigType) -> None:
        """Initialize ``auxiliary_head``"""
        if auxiliary_head is not None:
            if isinstance(auxiliary_head, list):
                self.auxiliary_head = nn.ModuleList()
                for head_cfg in auxiliary_head:
                    self.auxiliary_head.append(MODELS.build(head_cfg))
            else:
                self.auxiliary_head = MODELS.build(auxiliary_head)

    def extract_feat(self, inputs: Tensor) -> List[Tensor]:
        """Extract features from images."""
        x = self.backbone(inputs)
        if self.with_neck:
            x = self.neck(x)
        return x

    def encode_decode(self, inputs: Tensor,
                      batch_img_metas: List[dict]) -> Tensor:
        """Encode images with backbone and decode into a semantic segmentation
        map of the same size as input."""
        x = self.extract_feat(inputs)
        seg_logits = self.decode_head.predict(x, batch_img_metas,
                                              self.test_cfg)

        return seg_logits

    def _decode_head_forward_train(self, inputs: List[Tensor],
                                   data_samples: SampleList) -> dict:
        """Run forward function and calculate loss for decode head in
        training."""
        losses = dict()
        loss_decode = self.decode_head.loss(inputs, data_samples,
                                            self.train_cfg)

        losses.update(add_prefix(loss_decode, 'decode'))
        return losses

    def _auxiliary_head_forward_train(self, inputs: List[Tensor],
                                      data_samples: SampleList) -> dict:
        """Run forward function and calculate loss for auxiliary head in
        training."""
        losses = dict()
        if isinstance(self.auxiliary_head, nn.ModuleList):
            for idx, aux_head in enumerate(self.auxiliary_head):
                loss_aux = aux_head.loss(inputs, data_samples, self.train_cfg)
                losses.update(add_prefix(loss_aux, f'aux_{idx}'))
        else:
            loss_aux = self.auxiliary_head.loss(inputs, data_samples,
                                                self.train_cfg)
            losses.update(add_prefix(loss_aux, 'aux'))

        return losses

    def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (Tensor): Input images.
            data_samples (list[:obj:`SegDataSample`]): The seg data samples.
                It usually includes information such as `metainfo` and
                `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """

        x = self.extract_feat(inputs)

        losses = dict()

        loss_decode = self._decode_head_forward_train(x, data_samples)
        losses.update(loss_decode)

        if self.with_auxiliary_head:
            loss_aux = self._auxiliary_head_forward_train(x, data_samples)
            losses.update(loss_aux)

        return losses

    def predict(self,
                inputs: Tensor,
                data_samples: OptSampleList = None) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).
            data_samples (List[:obj:`SegDataSample`], optional): The seg data
                samples. It usually includes information such as `metainfo`
                and `gt_sem_seg`.

        Returns:
            list[:obj:`SegDataSample`]: Segmentation results of the
            input images. Each SegDataSample usually contains:

            - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation.
            - ``seg_logits``(PixelData): Predicted logits of semantic
                segmentation before normalization.
        """
        if data_samples is not None:
            batch_img_metas = [
                data_sample.metainfo for data_sample in data_samples
            ]
        else:
            batch_img_metas = [
                dict(
                    ori_shape=inputs.shape[2:],
                    img_shape=inputs.shape[2:],
                    pad_shape=inputs.shape[2:],
                    padding_size=[0, 0, 0, 0])
            ] * inputs.shape[0]

        seg_logits = self.inference(inputs, batch_img_metas)

        return self.postprocess_result(seg_logits, data_samples)

    def _forward(self,
                 inputs: Tensor,
                 data_samples: OptSampleList = None) -> Tensor:
        """Network forward process.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).
            data_samples (List[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `metainfo` and `gt_sem_seg`.

        Returns:
            Tensor: Forward output of model without any post-processes.
        """
        x = self.extract_feat(inputs)
        return self.decode_head.forward(x)

    def slide_inference(self, inputs: Tensor,
                        batch_img_metas: List[dict]) -> Tensor:
        """Inference by sliding-window with overlap.

        If h_crop > h_img or w_crop > w_img, the small patch will be used to
        decode without padding.

        Args:
            inputs (tensor): the tensor should have a shape NxCxHxW,
                which contains all images in the batch.
            batch_img_metas (List[dict]): List of image metainfo where each may
                also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
                'ori_shape', and 'pad_shape'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.

        Returns:
            Tensor: The segmentation results, seg_logits from model of each
                input image.
        """

        h_stride, w_stride = self.test_cfg.stride
        h_crop, w_crop = self.test_cfg.crop_size
        batch_size, _, h_img, w_img = inputs.size()
        out_channels = self.out_channels
        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
        preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
        count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img))
        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_img = inputs[:, :, y1:y2, x1:x2]
                # change the image shape to patch shape
                batch_img_metas[0]['img_shape'] = crop_img.shape[2:]
                # the output of encode_decode is seg logits tensor map
                # with shape [N, C, H, W]
                crop_seg_logit = self.encode_decode(crop_img, batch_img_metas)
                preds += F.pad(crop_seg_logit,
                               (int(x1), int(preds.shape[3] - x2), int(y1),
                                int(preds.shape[2] - y2)))

                count_mat[:, :, y1:y2, x1:x2] += 1
        assert (count_mat == 0).sum() == 0
        seg_logits = preds / count_mat

        return seg_logits

    def whole_inference(self, inputs: Tensor,
                        batch_img_metas: List[dict]) -> Tensor:
        """Inference with full image.

        Args:
            inputs (Tensor): The tensor should have a shape NxCxHxW, which
                contains all images in the batch.
            batch_img_metas (List[dict]): List of image metainfo where each may
                also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
                'ori_shape', and 'pad_shape'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.

        Returns:
            Tensor: The segmentation results, seg_logits from model of each
                input image.
        """

        seg_logits = self.encode_decode(inputs, batch_img_metas)

        return seg_logits

    def inference(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor:
        """Inference with slide/whole style.

        Args:
            inputs (Tensor): The input image of shape (N, 3, H, W).
            batch_img_metas (List[dict]): List of image metainfo where each may
                also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
                'ori_shape', 'pad_shape', and 'padding_size'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.

        Returns:
            Tensor: The segmentation results, seg_logits from model of each
                input image.
        """
        assert self.test_cfg.get('mode', 'whole') in ['slide', 'whole'], \
            f'Only "slide" or "whole" test mode are supported, but got ' \
            f'{self.test_cfg["mode"]}.'
        ori_shape = batch_img_metas[0]['ori_shape']
        if not all(_['ori_shape'] == ori_shape for _ in batch_img_metas):
            print_log(
                'Image shapes are different in the batch.',
                logger='current',
                level=logging.WARN)
        if self.test_cfg.mode == 'slide':
            seg_logit = self.slide_inference(inputs, batch_img_metas)
        else:
            seg_logit = self.whole_inference(inputs, batch_img_metas)

        return seg_logit

    def aug_test(self, inputs, batch_img_metas, rescale=True):
        """Test with augmentations.

        Only rescale=True is supported.
        """
        # aug_test rescale all imgs back to ori_shape for now
        assert rescale
        # to save memory, we get augmented seg logit inplace
        seg_logit = self.inference(inputs[0], batch_img_metas[0], rescale)
        for i in range(1, len(inputs)):
            cur_seg_logit = self.inference(inputs[i], batch_img_metas[i],
                                           rescale)
            seg_logit += cur_seg_logit
        seg_logit /= len(inputs)
        seg_pred = seg_logit.argmax(dim=1)
        # unravel batch dim
        seg_pred = list(seg_pred)
        return seg_pred
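A hedged usage sketch for the `predict` path of `EncoderDecoder`. The config and checkpoint paths are placeholders, and in a real pipeline the inputs would normally pass through the data preprocessor (normalization, padding) before this call; here the tensor is assumed to be already preprocessed so the default metainfo built from its shape is sufficient.

import torch
from mmseg.apis import init_model

model = init_model('configs/some_config.py', 'some_checkpoint.pth', device='cuda:0')
imgs = torch.rand(1, 3, 512, 512, device='cuda:0')   # already-preprocessed batch (assumed)
data_samples = model.predict(imgs)                    # inference() + postprocess_result()
pred_mask = data_samples[0].pred_sem_seg.data         # (1, H, W) label map
logits = data_samples[0].seg_logits.data              # (C, H, W) logits before argmax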
350
finetune/mmseg/models/segmentors/multimodal_encoder_decoder.py
Normal file
@@ -0,0 +1,350 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

import torch.nn.functional as F
from torch import Tensor

from mmseg.registry import MODELS
from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig,
                         OptSampleList, SampleList, add_prefix)
from .base import BaseSegmentor


@MODELS.register_module()
class MultimodalEncoderDecoder(BaseSegmentor):
    """Multimodal Encoder-Decoder segmentors.

    This multimodal segmentation architecture is used for open-vocabulary
    semantic segmentation by combining visual and language pretrained
    models. It consists of an image_encoder (backbone) to extract visual
    features, a text_encoder to extract text features, and a decode
    head to generate semantic maps.
    Note that the deep supervision during training is implemented in decode head.

    1. The ``loss`` method is used to calculate the loss of model,
    which includes two steps: (1) Extracts features to obtain the feature maps
    (2) Call the decode head loss function to forward decode head model and
    calculate losses.

    .. code:: text

     loss(): extract_feat() -> _decode_head_forward_train()
     _decode_head_forward_train(): decode_head.loss()

    2. The ``predict`` method is used to predict segmentation results,
    which includes two steps: (1) Run inference function to obtain the list of
    seg_logits (2) Call post-processing function to obtain list of
    ``SegDataSample`` including ``pred_sem_seg`` and ``seg_logits``.

    .. code:: text

     predict(): inference() -> postprocess_result()
     inference(): whole_inference()/slide_inference()
     whole_inference()/slide_inference(): encoder_decoder()
     encoder_decoder(): extract_feat() -> decode_head.predict()

    3. The ``_forward`` method is used to output the tensor by running the model,
    which includes two steps: (1) Extracts features to obtain the feature maps
    (2) Call the decode head forward function to forward decode head model.

    .. code:: text

     _forward(): extract_feat() -> _decode_head.forward()

    Args:

        image_encoder (ConfigType): The config for the visual encoder of segmentor.
        text_encoder (ConfigType): The config for the text encoder of segmentor.
        decode_head (ConfigType): The config for the decode head of segmentor.
        train_cfg (OptConfigType): The config for training. Defaults to None.
        test_cfg (OptConfigType): The config for testing. Defaults to None.
        data_preprocessor (dict, optional): The pre-process config of
            :class:`BaseDataPreprocessor`.
        pretrained (str, optional): The path for pretrained model.
            Defaults to None.
        asymetric_input (bool): whether to use different input sizes for the
            image encoder and the decode head. Defaults to False.
        encoder_resolution (float): resize scale of input images for the image
            encoder. Defaults to None.
        init_cfg (dict, optional): The weight initialized config for
            :class:`BaseModule`.
    """  # noqa: E501

    def __init__(self,
                 image_encoder: ConfigType,
                 text_encoder: ConfigType,
                 decode_head: ConfigType,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 pretrained: Optional[str] = None,
                 asymetric_input: bool = True,
                 encoder_resolution: float = None,
                 init_cfg: OptMultiConfig = None):
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        if pretrained is not None:
            image_encoder.init_cfg = dict(
                type='Pretrained_Part', checkpoint=pretrained)
            text_encoder.init_cfg = dict(
                type='Pretrained_Part', checkpoint=pretrained)
            decode_head.init_cfg = dict(
                type='Pretrained_Part', checkpoint=pretrained)

        if asymetric_input:
            assert encoder_resolution is not None, \
                'if asymetric_input is set to True, ' \
                'encoder_resolution must be a certain value'
        self.asymetric_input = asymetric_input
        self.encoder_resolution = encoder_resolution
        self.image_encoder = MODELS.build(image_encoder)
        self.text_encoder = MODELS.build(text_encoder)
        self._init_decode_head(decode_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        assert self.with_decode_head

    def _init_decode_head(self, decode_head: ConfigType) -> None:
        """Initialize ``decode_head``"""
        self.decode_head = MODELS.build(decode_head)
        self.align_corners = self.decode_head.align_corners
        self.num_classes = self.decode_head.num_classes
        self.out_channels = self.decode_head.out_channels

    def extract_feat(self, inputs: Tensor) -> List[Tensor]:
        """Extract visual features from images."""
        x = self.image_encoder(inputs)
        return x

    def encode_decode(self, inputs: Tensor,
                      batch_img_metas: List[dict]) -> Tensor:
        """Encode the name of classes with text_encoder and encode images with
        image_encoder.

        Then decode the class embedding and visual feature into a semantic
        segmentation map of the same size as input.
        """
        classifier_embeds = self.text_encoder()
        clip_inputs = inputs
        if self.asymetric_input:
            clip_inputs = F.interpolate(
                inputs, scale_factor=self.encoder_resolution, mode='bilinear')
        x = self.image_encoder(clip_inputs)
        seg_logits = self.decode_head.predict([inputs, x, classifier_embeds],
                                              batch_img_metas, self.test_cfg)

        return seg_logits

    def _decode_head_forward_train(self, inputs: List[Tensor],
                                   data_samples: SampleList) -> dict:
        """Run forward function and calculate loss for decode head in
        training."""
        losses = dict()
        loss_decode = self.decode_head.loss(inputs, data_samples,
                                            self.train_cfg)

        losses.update(add_prefix(loss_decode, 'decode'))
        return losses

    def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (Tensor): Input images.
            data_samples (list[:obj:`SegDataSample`]): The seg data samples.
                It usually includes information such as `metainfo` and
                `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        classifier_embeds = self.text_encoder()
        clip_inputs = inputs
        if self.asymetric_input:
            clip_inputs = F.interpolate(
                inputs, scale_factor=self.encoder_resolution, mode='bilinear')
        x = self.image_encoder(clip_inputs)

        losses = dict()

        loss_decode = self._decode_head_forward_train(
            [inputs, x, classifier_embeds], data_samples)
        losses.update(loss_decode)

        return losses

    def predict(self,
                inputs: Tensor,
                data_samples: OptSampleList = None) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).
            data_samples (List[:obj:`SegDataSample`], optional): The seg data
                samples. It usually includes information such as `metainfo`
                and `gt_sem_seg`.

        Returns:
            list[:obj:`SegDataSample`]: Segmentation results of the
            input images. Each SegDataSample usually contains:

            - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation.
            - ``seg_logits``(PixelData): Predicted logits of semantic
                segmentation before normalization.
        """
        if data_samples is not None:
            batch_img_metas = [
                data_sample.metainfo for data_sample in data_samples
            ]
        else:
            batch_img_metas = [
                dict(
                    ori_shape=inputs.shape[2:],
                    img_shape=inputs.shape[2:],
                    pad_shape=inputs.shape[2:],
                    padding_size=[0, 0, 0, 0])
            ] * inputs.shape[0]

        seg_logits = self.inference(inputs, batch_img_metas)

        return self.postprocess_result(seg_logits, data_samples)

    def _forward(self,
                 inputs: Tensor,
                 data_samples: OptSampleList = None) -> Tensor:
        """Network forward process.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).
            data_samples (List[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `metainfo` and `gt_sem_seg`.

        Returns:
            Tensor: Forward output of model without any post-processes.
        """
        x = self.extract_feat(inputs)
        return self.decode_head.forward(x)

    def slide_inference(self, inputs: Tensor,
                        batch_img_metas: List[dict]) -> Tensor:
        """Inference by sliding-window with overlap.

        If h_crop > h_img or w_crop > w_img, the small patch will be used to
        decode without padding.

        Args:
            inputs (tensor): the tensor should have a shape NxCxHxW,
                which contains all images in the batch.
            batch_img_metas (List[dict]): List of image metainfo where each may
                also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
                'ori_shape', and 'pad_shape'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.

        Returns:
            Tensor: The segmentation results, seg_logits from model of each
                input image.
        """

        h_stride, w_stride = self.test_cfg.stride
        h_crop, w_crop = self.test_cfg.crop_size
        batch_size, _, h_img, w_img = inputs.size()
        out_channels = self.out_channels
        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
        preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
        count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img))
        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_img = inputs[:, :, y1:y2, x1:x2]
                # change the image shape to patch shape
                batch_img_metas[0]['img_shape'] = crop_img.shape[2:]
                # the output of encode_decode is seg logits tensor map
                # with shape [N, C, H, W]
                crop_seg_logit = self.encode_decode(crop_img, batch_img_metas)
                preds += F.pad(crop_seg_logit,
                               (int(x1), int(preds.shape[3] - x2), int(y1),
                                int(preds.shape[2] - y2)))

                count_mat[:, :, y1:y2, x1:x2] += 1
        assert (count_mat == 0).sum() == 0
        seg_logits = preds / count_mat

        return seg_logits

    def whole_inference(self, inputs: Tensor,
                        batch_img_metas: List[dict]) -> Tensor:
        """Inference with full image.

        Args:
            inputs (Tensor): The tensor should have a shape NxCxHxW, which
                contains all images in the batch.
            batch_img_metas (List[dict]): List of image metainfo where each may
                also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
                'ori_shape', and 'pad_shape'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.

        Returns:
            Tensor: The segmentation results, seg_logits from model of each
                input image.
        """

        seg_logits = self.encode_decode(inputs, batch_img_metas)

        return seg_logits

    def inference(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor:
        """Inference with slide/whole style.

        Args:
            inputs (Tensor): The input image of shape (N, 3, H, W).
            batch_img_metas (List[dict]): List of image metainfo where each may
                also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
                'ori_shape', 'pad_shape', and 'padding_size'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.

        Returns:
            Tensor: The segmentation results, seg_logits from model of each
                input image.
        """

        assert self.test_cfg.mode in ['slide', 'whole']
        ori_shape = batch_img_metas[0]['ori_shape']
        assert all(_['ori_shape'] == ori_shape for _ in batch_img_metas)
        if self.test_cfg.mode == 'slide':
            seg_logit = self.slide_inference(inputs, batch_img_metas)
        else:
            seg_logit = self.whole_inference(inputs, batch_img_metas)

        return seg_logit

    def aug_test(self, inputs, batch_img_metas, rescale=True):
        """Test with augmentations.

        Only rescale=True is supported.
        """
        # aug_test rescale all imgs back to ori_shape for now
        assert rescale
        # to save memory, we get augmented seg logit inplace
        seg_logit = self.inference(inputs[0], batch_img_metas[0], rescale)
        for i in range(1, len(inputs)):
            cur_seg_logit = self.inference(inputs[i], batch_img_metas[i],
                                           rescale)
            seg_logit += cur_seg_logit
        seg_logit /= len(inputs)
        seg_pred = seg_logit.argmax(dim=1)
        # unravel batch dim
        seg_pred = list(seg_pred)
        return seg_pred
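A short sketch of the asymmetric-input path used by `MultimodalEncoderDecoder`: when `asymetric_input` is True, the (typically CLIP-style) image encoder sees a rescaled copy of the batch while the decode head still receives the full-resolution tensor. The resolution value below is an assumed example, not a setting from this commit.

import torch
import torch.nn.functional as F

inputs = torch.rand(1, 3, 1024, 1024)
encoder_resolution = 0.5                      # assumed config value
clip_inputs = F.interpolate(
    inputs, scale_factor=encoder_resolution, mode='bilinear')
print(clip_inputs.shape)                      # torch.Size([1, 3, 512, 512])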