# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer
from mmcv.cnn.bricks.transformer import BaseTransformerLayer
from mmcv.ops import point_sample
from mmengine.dist import all_reduce
from mmengine.model.weight_init import (caffe2_xavier_init, normal_init,
                                        trunc_normal_)
from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
from mmengine.structures import InstanceData
from torch import Tensor
from torch.nn import functional as F

from mmseg.models.backbones.vit import TransformerEncoderLayer
from mmseg.registry import MODELS
from mmseg.utils import (ConfigType, MatchMasks, SampleList,
                         seg_data_to_instance_data)
from ..utils import (MLP, LayerNorm2d, PatchEmbed, cross_attn_layer,
                     get_uncertain_point_coords_with_randomness, resize)
from .decode_head import BaseDecodeHead


class MLPMaskDecoder(nn.Module):
    """Module for decoding query and visual features with MLP layers to
    generate the attention biases and the mask proposals."""

    def __init__(
        self,
        *,
        in_channels: int,
        total_heads: int = 1,
        total_layers: int = 1,
        embed_channels: int = 256,
        mlp_channels: int = 256,
        mlp_num_layers: int = 3,
        rescale_attn_bias: bool = False,
    ):
        super().__init__()
        self.total_heads = total_heads
        self.total_layers = total_layers

        dense_affine_func = partial(nn.Conv2d, kernel_size=1)
        # Query Branch
        self.query_mlp = MLP(in_channels, mlp_channels, embed_channels,
                             mlp_num_layers)
        # Pixel Branch
        self.pix_mlp = MLP(
            in_channels,
            mlp_channels,
            embed_channels,
            mlp_num_layers,
            affine_func=dense_affine_func,
        )
        # Attention Bias Branch
        self.attn_mlp = MLP(
            in_channels,
            mlp_channels,
            embed_channels * self.total_heads * self.total_layers,
            mlp_num_layers,
            affine_func=dense_affine_func,
        )
        if rescale_attn_bias:
            self.bias_scaling = nn.Linear(1, 1)
        else:
            self.bias_scaling = nn.Identity()

    def forward(self, query: torch.Tensor,
                x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward function.

        Args:
            query (Tensor): Query Tokens [B,N,C].
            x (Tensor): Visual features [B,C,H,W].

        Return:
            mask_preds (Tensor): Mask proposals.
            attn_bias (List[Tensor]): List of attention bias.
        """
        query = self.query_mlp(query)
        pix = self.pix_mlp(x)
        b, c, h, w = pix.shape
        # predict mask
        mask_preds = torch.einsum('bqc,bchw->bqhw', query, pix)
        # generate attn bias
        attn = self.attn_mlp(x)
        attn = attn.reshape(b, self.total_layers, self.total_heads, c, h, w)
        attn_bias = torch.einsum('bqc,blnchw->blnqhw', query, attn)
        attn_bias = self.bias_scaling(attn_bias[..., None]).squeeze(-1)
        attn_bias = attn_bias.chunk(self.total_layers, dim=1)
        attn_bias = [attn.squeeze(1) for attn in attn_bias]
        return mask_preds, attn_bias


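# Illustrative shape walk-through for ``MLPMaskDecoder`` (a minimal sketch;
# the sizes below are assumptions for illustration, not values taken from a
# released config):
#
#     decoder = MLPMaskDecoder(
#         in_channels=240, total_heads=12, total_layers=3,
#         embed_channels=256, mlp_channels=256, mlp_num_layers=3)
#     query = torch.randn(2, 100, 240)   # [B, N, C] query tokens
#     x = torch.randn(2, 240, 40, 40)    # [B, C, H, W] visual features
#     mask_preds, attn_bias = decoder(query, x)
#     # mask_preds: [2, 100, 40, 40]
#     # attn_bias: list of 3 tensors, each [2, 12, 100, 40, 40]

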
class SideAdapterNetwork(nn.Module):
    """Side Adapter Network for predicting mask proposals and attention bias.

    Args:
        in_channels (int): Number of input channels. Default: 3.
        clip_channels (int): Number of channels of visual features.
            Default: 768.
        embed_dims (int): The embedding dimension. Default: 240.
        patch_size (int): The patch size. Default: 16.
        patch_bias (bool): Whether to use bias in patch embedding.
            Default: True.
        num_queries (int): Number of queries for mask proposals.
            Default: 100.
        fusion_index (List[int]): The layer number of the encode
            transformer to fuse with the CLIP feature.
            Default: [0, 1, 2, 3].
        cfg_encoder (ConfigType): Configs for the encode layers.
        cfg_decoder (ConfigType): Configs for the decode layers.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
    """

    def __init__(
        self,
        in_channels: int = 3,
        clip_channels: int = 768,
        embed_dims: int = 240,
        patch_size: int = 16,
        patch_bias: bool = True,
        num_queries: int = 100,
        fusion_index: list = [0, 1, 2, 3],
        cfg_encoder: ConfigType = ...,
        cfg_decoder: ConfigType = ...,
        norm_cfg: dict = dict(type='LN'),
    ):
        super().__init__()

        self.patch_embed = PatchEmbed(
            in_channels=in_channels,
            embed_dims=embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size,
            stride=patch_size,
            padding=0,
            input_size=(640, 640),
            bias=patch_bias,
            norm_cfg=None,
            init_cfg=None,
        )
        ori_h, ori_w = self.patch_embed.init_out_size
        num_patches = ori_h * ori_w
        self.pos_embed = nn.Parameter(
            torch.randn(1, num_patches, embed_dims) * .02)
        self.query_pos_embed = nn.Parameter(
            torch.zeros(1, num_queries, embed_dims))
        self.query_embed = nn.Parameter(
            torch.zeros(1, num_queries, embed_dims))
        encode_layers = []
        for i in range(cfg_encoder.num_encode_layer):
            encode_layers.append(
                TransformerEncoderLayer(
                    embed_dims=embed_dims,
                    num_heads=cfg_encoder.num_heads,
                    feedforward_channels=cfg_encoder.mlp_ratio * embed_dims,
                    norm_cfg=norm_cfg))
        self.encode_layers = nn.ModuleList(encode_layers)
        conv_clips = []
        for i in range(len(fusion_index)):
            conv_clips.append(
                nn.Sequential(
                    LayerNorm2d(clip_channels),
                    ConvModule(
                        clip_channels,
                        embed_dims,
                        kernel_size=1,
                        norm_cfg=None,
                        act_cfg=None)))
        self.conv_clips = nn.ModuleList(conv_clips)
        self.fusion_index = fusion_index
        self.mask_decoder = MLPMaskDecoder(
            in_channels=embed_dims,
            total_heads=cfg_decoder.num_heads,
            total_layers=cfg_decoder.num_layers,
            embed_channels=cfg_decoder.embed_channels,
            mlp_channels=cfg_decoder.mlp_channels,
            mlp_num_layers=cfg_decoder.num_mlp,
            rescale_attn_bias=cfg_decoder.rescale)

    def init_weights(self):
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.query_embed, std=0.02)
        nn.init.normal_(self.query_pos_embed, std=0.02)
        for i in range(len(self.conv_clips)):
            caffe2_xavier_init(self.conv_clips[i][1].conv)

    def fuse_clip(self, fused_index: int, x: torch.Tensor,
                  clip_feature: torch.Tensor, hwshape: Tuple[int, int],
                  L: int):
        """Fuse CLIP feature and visual tokens."""
        fused_clip = resize(
            self.conv_clips[fused_index](clip_feature.contiguous()),
            size=hwshape,
            mode='bilinear',
            align_corners=False).permute(0, 2, 3, 1).reshape(
                x[:, -L:, ...].shape)
        x = torch.cat([x[:, :-L, ...], x[:, -L:, ...] + fused_clip], dim=1)
        return x

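    # Token layout used by ``fuse_clip`` and ``encode_feature`` below (a
    # descriptive note, not extra behaviour):
    #
    #     x = [query tokens (num_queries) | patch tokens (L = H/P * W/P)]
    #
    # ``fuse_clip`` only touches the trailing ``L`` patch tokens: the CLIP
    # feature map is projected to ``embed_dims`` with a 1x1 conv, resized to
    # the patch grid and added element-wise, while the query tokens are left
    # unchanged.
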
    def encode_feature(self, image: torch.Tensor,
                       clip_features: List[torch.Tensor],
                       deep_supervision_idxs: List[int]) -> List[List]:
        """Encode images by a lightweight vision transformer."""
        assert len(self.fusion_index) == len(clip_features)
        x, hwshape = self.patch_embed(image)
        ori_h, ori_w = self.patch_embed.init_out_size
        pos_embed = self.pos_embed
        if self.pos_embed.shape[1] != x.shape[1]:
            # resize the position embedding
            pos_embed = (
                resize(
                    self.pos_embed.reshape(1, ori_h, ori_w,
                                           -1).permute(0, 3, 1, 2),
                    size=hwshape,
                    mode='bicubic',
                    align_corners=False,
                ).flatten(2).permute(0, 2, 1))
        pos_embed = torch.cat([
            self.query_pos_embed.expand(pos_embed.shape[0], -1, -1), pos_embed
        ], dim=1)
        x = torch.cat([self.query_embed.expand(x.shape[0], -1, -1), x], dim=1)
        x = x + pos_embed
        L = hwshape[0] * hwshape[1]
        fused_index = 0
        if self.fusion_index[fused_index] == 0:
            x = self.fuse_clip(fused_index, x, clip_features[0][0], hwshape, L)
            fused_index += 1
        outs = []
        for index, block in enumerate(self.encode_layers, start=1):
            x = block(x)
            if (index < len(self.fusion_index)
                    and index == self.fusion_index[fused_index]):
                x = self.fuse_clip(fused_index, x,
                                   clip_features[fused_index][0], hwshape, L)
                fused_index += 1
            x_query = x[:, :-L, ...]
            x_feat = x[:, -L:, ...].permute(0, 2, 1)\
                .reshape(x.shape[0], x.shape[-1], hwshape[0], hwshape[1])

            if index in deep_supervision_idxs or index == len(
                    self.encode_layers):
                outs.append({'query': x_query, 'x': x_feat})

            if index < len(self.encode_layers):
                x = x + pos_embed
        return outs

    def decode_feature(self, features):
        """Decode mask proposals and attention biases from SAN features."""
        mask_embeds = []
        attn_biases = []
        for feature in features:
            mask_embed, attn_bias = self.mask_decoder(**feature)
            mask_embeds.append(mask_embed)
            attn_biases.append(attn_bias)
        return mask_embeds, attn_biases

    def forward(
            self, image: torch.Tensor, clip_features: List[torch.Tensor],
            deep_supervision_idxs: List[int]
    ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]:
        """Forward function."""
        features = self.encode_feature(image, clip_features,
                                       deep_supervision_idxs)
        mask_embeds, attn_biases = self.decode_feature(features)
        return mask_embeds, attn_biases


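# Illustrative construction of ``SideAdapterNetwork`` (a minimal sketch; the
# config values are assumptions for illustration and may differ from released
# SAN configs):
#
#     from mmengine import Config
#     san = SideAdapterNetwork(
#         clip_channels=768,
#         embed_dims=240,
#         fusion_index=[0, 1, 2, 3],
#         cfg_encoder=Config(
#             dict(num_encode_layer=8, num_heads=6, mlp_ratio=4)),
#         cfg_decoder=Config(
#             dict(num_heads=12, num_layers=1, embed_channels=256,
#                  mlp_channels=256, num_mlp=3, rescale=True)))
#     # ``clip_features`` has one entry per fusion point; each entry is a
#     # (feature_map [B, clip_channels, h, w], cls_token) pair from the
#     # frozen CLIP image encoder.
#     mask_embeds, attn_biases = san(image, clip_features,
#                                    deep_supervision_idxs=[])

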
class RecWithAttnbias(nn.Module):
    """Mask recognition module that applies the attention biases to the
    remaining deeper CLIP layers.

    Args:
        sos_token_format (str): The format of sos token. It should be
            chosen from ["cls_token", "learnable_token", "pos_embedding"].
            Default: 'cls_token'.
        sos_token_num (int): Number of sos tokens. It should be equal to
            the number of queries. Default: 100.
        num_layers (int): Number of remaining CLIP layers used for mask
            recognition. Default: 3.
        cross_attn (bool): Whether to use cross attention to update the sos
            tokens. Default: False.
        embed_dims (int): The feature dimension of CLIP layers.
            Default: 768.
        num_heads (int): Parallel attention heads of CLIP layers.
            Default: 12.
        mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
            Default: 4.
        num_fcs (int): The number of fully-connected layers in FFNs.
            Default: 2.
        qkv_bias (bool): Whether to use bias in multihead-attention.
            Default: True.
        out_dims (int): Number of channels of the output mask proposals.
            It should be equal to the out_dims of text_encoder.
            Default: 512.
        final_norm (bool): Whether to normalize the sos tokens at the end.
            Default: True.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        frozen_exclude (List): List of parameters that are not to be frozen.
    """

    def __init__(self,
                 sos_token_format: str = 'cls_token',
                 sos_token_num: int = 100,
                 num_layers: int = 3,
                 cross_attn: bool = False,
                 embed_dims: int = 768,
                 num_heads: int = 12,
                 mlp_ratio: int = 4,
                 num_fcs: int = 2,
                 qkv_bias: bool = True,
                 out_dims: int = 512,
                 final_norm: bool = True,
                 act_cfg: dict = dict(type='GELU'),
                 norm_cfg: dict = dict(type='LN'),
                 frozen_exclude: List = []):
        super().__init__()

        assert sos_token_format in [
            'cls_token', 'learnable_token', 'pos_embedding'
        ]
        self.sos_token_format = sos_token_format
        self.sos_token_num = sos_token_num
        self.frozen_exclude = frozen_exclude
        self.cross_attn = cross_attn
        self.num_layers = num_layers
        self.num_heads = num_heads
        if sos_token_format in ['learnable_token', 'pos_embedding']:
            # The learnable sos tokens live in the CLIP feature space
            # (``embed_dims``) and must stay trainable while the rest of the
            # module may be frozen.
            self.sos_token = nn.Parameter(
                torch.randn(sos_token_num, 1, embed_dims))
            self.frozen_exclude.append('sos_token')

        layers = []
        for i in range(num_layers):
            layers.append(
                BaseTransformerLayer(
                    attn_cfgs=dict(
                        type='MultiheadAttention',
                        embed_dims=embed_dims,
                        num_heads=num_heads,
                        batch_first=False,
                        bias=qkv_bias),
                    ffn_cfgs=dict(
                        type='FFN',
                        embed_dims=embed_dims,
                        feedforward_channels=mlp_ratio * embed_dims,
                        act_cfg=act_cfg),
                    operation_order=('norm', 'self_attn', 'norm', 'ffn')))
        self.layers = nn.ModuleList(layers)

        self.ln_post = build_norm_layer(norm_cfg, embed_dims)[1]
        self.proj = nn.Linear(embed_dims, out_dims, bias=False)

        self.final_norm = final_norm
        self._freeze()

    def init_weights(self, rec_state_dict):
        if hasattr(self, 'sos_token'):
            normal_init(self.sos_token, std=0.02)
        if rec_state_dict is not None:
            load_state_dict(self, rec_state_dict, strict=False, logger=None)
        else:
            super().init_weights()

    def _freeze(self):
        if 'all' in self.frozen_exclude:
            return
        for name, param in self.named_parameters():
            if not any([exclude in name for exclude in self.frozen_exclude]):
                param.requires_grad = False

    def _build_attn_biases(self, attn_biases, target_shape):
        formatted_attn_biases = []
        for attn_bias in attn_biases:
            # convert it to proper format: N*num_head,L,L
            # attn_bias: [N, num_head/1, num_sos, H, W]
            n, num_head, num_sos, h, w = attn_bias.shape
            # reshape and downsample
            attn_bias = F.adaptive_max_pool2d(
                attn_bias.reshape(n, num_head * num_sos, h, w),
                output_size=target_shape)
            attn_bias = attn_bias.reshape(n, num_head, num_sos, *target_shape)

            true_num_head = self.num_heads
            assert (num_head == 1 or num_head
                    == true_num_head), f'num_head={num_head} is not supported.'
            if num_head == 1:
                attn_bias = attn_bias.repeat(1, true_num_head, 1, 1, 1)
            attn_bias = attn_bias.reshape(n * true_num_head, num_sos, -1)
            L = attn_bias.shape[-1]
            if self.cross_attn:
                # [n*num_head, num_sos, L]
                formatted_attn_biases.append(attn_bias)
            else:
                # [n*num_head, num_sos+1+L, num_sos+1+L]
                new_attn_bias = attn_bias.new_zeros(num_sos + 1 + L,
                                                    num_sos + 1 + L)
                new_attn_bias[:, :num_sos] = -100
                new_attn_bias[torch.arange(num_sos), torch.arange(num_sos)] = 0
                new_attn_bias[:num_sos, num_sos] = -100
                new_attn_bias = (
                    new_attn_bias[None, ...].expand(n * true_num_head, -1,
                                                    -1).clone())
                new_attn_bias[..., :num_sos, -L:] = attn_bias
                formatted_attn_biases.append(new_attn_bias)

        if len(formatted_attn_biases) == 1:
            formatted_attn_biases = [
                formatted_attn_biases[0] for _ in range(self.num_layers)
            ]
        return formatted_attn_biases

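    # Layout of the self-attention bias built by ``_build_attn_biases`` when
    # ``cross_attn=False`` (a sketch of the code above; -100 effectively
    # masks a position, rows attend to columns):
    #
    #                   sos cols            cls col   patch cols
    #     sos rows   :  0 on diag / -100     -100     attn_bias
    #     cls row    :  -100                  0           0
    #     patch rows :  -100                  0           0
    #
    # Each sos token attends to itself and to the patches it is biased
    # towards, while the cls/patch tokens never attend back to the sos
    # tokens, so the frozen CLIP features are left undisturbed.
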
    def forward(self, bias: List[Tensor], feature: List[Tensor]):
        """Forward function to recognize the category of masks.

        Args:
            bias (List[Tensor]): Attention bias for transformer layers.
            feature (List[Tensor]): Output of the image encoder,
                including cls_token and img_feature.
        """
        cls_token = feature[1].unsqueeze(0)
        img_feature = feature[0]
        b, c, h, w = img_feature.shape
        # construct clip shadow features
        x = torch.cat(
            [cls_token,
             img_feature.reshape(b, c, -1).permute(2, 0, 1)])

        # construct sos token
        if self.sos_token_format == 'cls_token':
            sos_token = cls_token.repeat(self.sos_token_num, 1, 1)
        elif self.sos_token_format == 'learnable_token':
            sos_token = self.sos_token.expand(-1, b, -1)
        elif self.sos_token_format == 'pos_embedding':
            sos_token = self.sos_token.expand(-1, b, -1) + cls_token

        # construct attn bias
        attn_biases = self._build_attn_biases(bias, target_shape=(h, w))

        if self.cross_attn:
            for i, block in enumerate(self.layers):
                sos_token = cross_attn_layer(
                    block,
                    sos_token,
                    x[1:, ],
                    attn_biases[i],
                )
                if i < len(self.layers) - 1:
                    x = block(x)
        else:
            x = torch.cat([sos_token, x], dim=0)
            for i, block in enumerate(self.layers):
                x = block(x, attn_masks=[attn_biases[i]])
            sos_token = x[:self.sos_token_num]

        sos_token = sos_token.permute(1, 0, 2)  # LND -> NLD
        sos_token = self.ln_post(sos_token)
        sos_token = self.proj(sos_token)
        if self.final_norm:
            sos_token = F.normalize(sos_token, dim=-1)
        return sos_token


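# Illustrative call into ``RecWithAttnbias`` (a minimal sketch; the shapes are
# assumptions matching a ViT-B/16 CLIP image encoder on a 640x640 input, not
# values copied from a released config):
#
#     rec = RecWithAttnbias(sos_token_num=100, num_layers=3,
#                           embed_dims=768, num_heads=12, out_dims=512)
#     img_feature = torch.randn(2, 768, 40, 40)   # frozen CLIP feature map
#     cls_token = torch.randn(2, 768)             # frozen CLIP cls token
#     bias = [torch.randn(2, 12, 100, 40, 40)]    # attn bias from the SAN
#     mask_embeds = rec(bias, [img_feature, cls_token])
#     # mask_embeds: [2, 100, 512], one embedding per mask query, ready to be
#     # compared with the text embeddings of class names.

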
@MODELS.register_module()
class SideAdapterCLIPHead(BaseDecodeHead):
    """Side Adapter Network (SAN) for open-vocabulary semantic segmentation
    with pre-trained vision-language model.

    This decode head is the implementation of `Side Adapter Network
    for Open-Vocabulary Semantic Segmentation
    <https://arxiv.org/abs/2302.12242>`_.
    Modified from https://github.com/MendelXu/SAN/blob/main/san/model/side_adapter/side_adapter.py # noqa:E501
    Copyright (c) 2023 MendelXu.
    Licensed under the MIT License.

    Args:
        num_classes (int): The number of classes.
        san_cfg (ConfigType): Configs for SideAdapterNetwork module.
        maskgen_cfg (ConfigType): Configs for RecWithAttnbias module.
        deep_supervision_idxs (List[int]): Indices of the SAN encoder layers
            whose outputs receive deep supervision during training.
        train_cfg (ConfigType): Training config, including the point-sampling
            settings and the mask-to-ground-truth assigner.
    """

    def __init__(self, num_classes: int, san_cfg: ConfigType,
                 maskgen_cfg: ConfigType, deep_supervision_idxs: List[int],
                 train_cfg: ConfigType, **kwargs):
        super().__init__(
            in_channels=san_cfg.in_channels,
            channels=san_cfg.embed_dims,
            num_classes=num_classes,
            **kwargs)
        assert san_cfg.num_queries == maskgen_cfg.sos_token_num, \
            'num_queries in san_cfg should be equal to sos_token_num ' \
            'in maskgen_cfg'
        del self.conv_seg
        self.side_adapter_network = SideAdapterNetwork(**san_cfg)
        self.rec_with_attnbias = RecWithAttnbias(**maskgen_cfg)
        self.deep_supervision_idxs = deep_supervision_idxs
        self.train_cfg = train_cfg
        if train_cfg:
            self.match_masks = MatchMasks(
                num_points=train_cfg.num_points,
                num_queries=san_cfg.num_queries,
                num_classes=num_classes,
                assigner=train_cfg.assigner)

    def init_weights(self):

        rec_state_dict = None
        if isinstance(self.init_cfg, dict) and \
                self.init_cfg.get('type') == 'Pretrained_Part':
            checkpoint = CheckpointLoader.load_checkpoint(
                self.init_cfg['checkpoint'], logger=None, map_location='cpu')

            rec_state_dict = checkpoint.copy()
            para_prefix = 'decode_head.rec_with_attnbias'
            prefix_len = len(para_prefix) + 1
            for k, v in checkpoint.items():
                rec_state_dict.pop(k)
                if para_prefix in k:
                    rec_state_dict[k[prefix_len:]] = v

        self.side_adapter_network.init_weights()
        self.rec_with_attnbias.init_weights(rec_state_dict)

    def forward(self, inputs: Tuple[Tensor],
                deep_supervision_idxs) -> Tuple[List]:
        """Forward function.

        Args:
            inputs (Tuple[Tensor]): A triplet including images,
                list of multi-level visual features from image encoder and
                class embeddings from text_encoder.

        Returns:
            mask_props (List[Tensor]): Mask proposals predicted by SAN.
            mask_logits (List[Tensor]): Class logits of mask proposals.
        """
        imgs, clip_feature, class_embeds = inputs
        # predict mask proposals and attention bias
        mask_props, attn_biases = self.side_adapter_network(
            imgs, clip_feature, deep_supervision_idxs)

        # mask recognition with attention bias
        mask_embeds = [
            self.rec_with_attnbias(att_bias, clip_feature[-1])
            for att_bias in attn_biases
        ]
        # Obtain class prediction of masks by comparing the similarity
        # between the image token and the text embedding of class names.
        mask_logits = [
            torch.einsum('bqc,nc->bqn', mask_embed, class_embeds)
            for mask_embed in mask_embeds
        ]
        return mask_props, mask_logits

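    # Shape summary for ``forward`` (descriptive only; P is the SAN patch
    # size):
    #   imgs:         [B, 3, H, W]
    #   clip_feature: list of (feature_map, cls_token) pairs from the frozen
    #                 CLIP image encoder; the last pair is reused for mask
    #                 recognition.
    #   class_embeds: [N, C] text embeddings of class names (the last one is
    #                 treated as the background class in ``predict_by_feat``).
    #   mask_props:   list of [B, num_queries, H/P, W/P] mask proposals.
    #   mask_logits:  list of [B, num_queries, N] class logits.
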
    def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
                test_cfg: ConfigType) -> Tensor:
        """Forward function for prediction.

        Args:
            inputs (Tuple[Tensor]): Images, visual features from image encoder
                and class embedding from text encoder.
            batch_img_metas (List[dict]): List of image info where each dict
                may also contain: 'img_shape', 'scale_factor', 'flip',
                'img_path', 'ori_shape', and 'pad_shape'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
            test_cfg (dict): The testing config.

        Returns:
            Tensor: Output segmentation logits map.
        """
        mask_props, mask_logits = self.forward(inputs, [])

        return self.predict_by_feat([mask_props[-1], mask_logits[-1]],
                                    batch_img_metas)

    def predict_by_feat(self, seg_logits: List[Tensor],
                        batch_img_metas: List[dict]) -> Tensor:
        """1. Transform a batch of mask proposals to the input shape.
        2. Generate segmentation map with mask proposals and class logits.
        """
        mask_pred = seg_logits[0]
        cls_score = seg_logits[1]
        if isinstance(batch_img_metas[0]['img_shape'], torch.Size):
            # slide inference
            size = batch_img_metas[0]['img_shape']
        elif 'pad_shape' in batch_img_metas[0]:
            size = batch_img_metas[0]['pad_shape'][:2]
        else:
            size = batch_img_metas[0]['img_shape']
        # upsample mask
        mask_pred = F.interpolate(
            mask_pred, size=size, mode='bilinear', align_corners=False)

        mask_cls = F.softmax(cls_score, dim=-1)[..., :-1]
        mask_pred = mask_pred.sigmoid()
        seg_logits = torch.einsum('bqc,bqhw->bchw', mask_cls, mask_pred)
        return seg_logits

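    # The einsum above composes the final map: for every class c,
    #
    #     seg_logits[b, c, h, w] = sum_q softmax(cls_score)[b, q, c]
    #                              * sigmoid(mask_pred)[b, q, h, w]
    #
    # i.e. each query votes for its predicted class with its soft mask, and
    # the background column of the softmax is dropped beforehand.
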
    def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList,
             train_cfg: ConfigType) -> dict:
        """Perform forward propagation and loss calculation of the decode head
        on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Multi-level features from the upstream
                network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`SegDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_sem_seg`.
            train_cfg (ConfigType): Training config.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        # batch SegDataSample to InstanceDataSample
        batch_gt_instances = seg_data_to_instance_data(self.ignore_index,
                                                       batch_data_samples)

        # forward
        all_mask_props, all_mask_logits = self.forward(
            x, self.deep_supervision_idxs)

        # loss
        losses = self.loss_by_feat(all_mask_logits, all_mask_props,
                                   batch_gt_instances)

        return losses

    def loss_by_feat(
            self, all_cls_scores: Tensor, all_mask_preds: Tensor,
            batch_gt_instances: List[InstanceData]) -> Dict[str, Tensor]:
        """Loss function.

        Args:
            all_cls_scores (Tensor): Classification scores for all decoder
                layers with shape (num_decoder, batch_size, num_queries,
                cls_out_channels). Note `cls_out_channels` should include
                background.
            all_mask_preds (Tensor): Mask scores for all decoder layers with
                shape (num_decoder, batch_size, num_queries, h, w).
            batch_gt_instances (list[obj:`InstanceData`]): Each contains
                ``labels`` and ``masks``.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        num_dec_layers = len(all_cls_scores)
        batch_gt_instances_list = [
            batch_gt_instances for _ in range(num_dec_layers)
        ]

        losses = []
        for i in range(num_dec_layers):
            cls_scores = all_cls_scores[i]
            mask_preds = all_mask_preds[i]
            # matching N mask predictions to K category labels
            (labels, mask_targets, mask_weights,
             avg_factor) = self.match_masks.get_targets(
                 cls_scores, mask_preds, batch_gt_instances_list[i])
            cls_scores = cls_scores.flatten(0, 1)
            labels = labels.flatten(0, 1)
            num_total_masks = cls_scores.new_tensor([avg_factor],
                                                    dtype=torch.float)
            all_reduce(num_total_masks, op='mean')
            num_total_masks = max(num_total_masks, 1)

            # extract positive ones
            # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
            mask_preds = mask_preds[mask_weights > 0]

            if mask_targets.shape[0] != 0:
                with torch.no_grad():
                    points_coords = \
                        get_uncertain_point_coords_with_randomness(
                            mask_preds.unsqueeze(1), None,
                            self.train_cfg.num_points,
                            self.train_cfg.oversample_ratio,
                            self.train_cfg.importance_sample_ratio)
                    # shape (num_total_gts, h, w)
                    # -> (num_total_gts, num_points)
                    mask_point_targets = point_sample(
                        mask_targets.unsqueeze(1).float(),
                        points_coords).squeeze(1)
                # shape (num_queries, h, w) -> (num_queries, num_points)
                mask_point_preds = point_sample(
                    mask_preds.unsqueeze(1), points_coords).squeeze(1)

            if not isinstance(self.loss_decode, nn.ModuleList):
                losses_decode = [self.loss_decode]
            else:
                losses_decode = self.loss_decode
            loss = dict()
            for loss_decode in losses_decode:
                if 'loss_cls' in loss_decode.loss_name:
                    if loss_decode.loss_name == 'loss_cls_ce':
                        loss[loss_decode.loss_name] = loss_decode(
                            cls_scores, labels)
                    else:
                        assert False, "Only support 'CrossEntropyLoss' in" \
                                      ' classification loss'

                elif 'loss_mask' in loss_decode.loss_name:
                    if mask_targets.shape[0] == 0:
                        loss[loss_decode.loss_name] = mask_preds.sum()
                    elif loss_decode.loss_name == 'loss_mask_ce':
                        loss[loss_decode.loss_name] = loss_decode(
                            mask_point_preds,
                            mask_point_targets,
                            avg_factor=num_total_masks *
                            self.train_cfg.num_points)
                    elif loss_decode.loss_name == 'loss_mask_dice':
                        loss[loss_decode.loss_name] = loss_decode(
                            mask_point_preds,
                            mask_point_targets,
                            avg_factor=num_total_masks)
                    else:
                        assert False, "Only support 'CrossEntropyLoss' and" \
                                      " 'DiceLoss' in mask loss"
                else:
                    assert False, "Only support for 'loss_cls' and 'loss_mask'"

            losses.append(loss)

        loss_dict = dict()
        # loss from the last decoder layer
        loss_dict.update(losses[-1])
        # loss from other decoder layers
        for i, loss in enumerate(losses[:-1]):
            for k, v in loss.items():
                loss_dict[f'd{self.deep_supervision_idxs[i]}.{k}'] = v
        return loss_dict
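

# Hedged configuration sketch for this head (illustrative only; the key names
# follow the attribute accesses above, but the numeric values are assumptions
# rather than a released config, and the assigner/loss entries are left as
# placeholders):
#
#     decode_head=dict(
#         type='SideAdapterCLIPHead',
#         num_classes=171,
#         deep_supervision_idxs=[7],
#         san_cfg=dict(
#             in_channels=3, clip_channels=768, embed_dims=240,
#             patch_size=16, num_queries=100, fusion_index=[0, 1, 2, 3],
#             cfg_encoder=dict(num_encode_layer=8, num_heads=6, mlp_ratio=4),
#             cfg_decoder=dict(num_heads=12, num_layers=1, embed_channels=256,
#                              mlp_channels=256, num_mlp=3, rescale=True)),
#         maskgen_cfg=dict(
#             sos_token_format='cls_token', sos_token_num=100, num_layers=3,
#             embed_dims=768, num_heads=12, out_dims=512),
#         train_cfg=dict(
#             num_points=12544, oversample_ratio=3.0,
#             importance_sample_ratio=0.75, assigner=dict(...)),
#         loss_decode=[...])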