commit 01adcfdf60 by esenke on 2025-12-08 22:16:31 +08:00
305 changed files with 50879 additions and 0 deletions

lib/models/__init__.py

@@ -0,0 +1,7 @@
from .segmentors import SkySensePP
from .losses import (ModalityVAELoss, RecLoss)
from .metrics import (SemMetric)
__all__ = [
'SkySensePP', 'ModalityVAELoss', 'RecLoss', 'SemMetric'
]


@@ -0,0 +1,14 @@
from .swin_v2 import SwinTransformerV2MSL
from .vit import VisionTransformerMSL
__all__ = [
'SwinTransformerV2MSL', 'VisionTransformerMSL'
]
type_mapping = {
'SwinTransformerV2MSL': SwinTransformerV2MSL,
'VisionTransformerMSL': VisionTransformerMSL
}
def build_backbone(type, **kwargs):
return type_mapping[type](**kwargs)
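Usage note (a minimal sketch, not taken from this commit): the factory is presumably driven by a config dict whose 'type' key selects an entry in type_mapping; the argument values below are illustrative assumptions.
# Hypothetical usage sketch for build_backbone; argument values are illustrative only.
backbone_cfg = dict(
    type='SwinTransformerV2MSL',   # key looked up in type_mapping above
    arch='tiny',
    img_size=256,
    vocabulary_size=128,
)
cfg = dict(backbone_cfg)           # copy so the original config stays untouched
backbone = build_backbone(cfg.pop('type'), **cfg)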


@@ -0,0 +1,702 @@
from copy import deepcopy
from typing import Sequence
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmcls.models.utils import (PatchMerging, ShiftWindowMSA, WindowMSAV2,
resize_pos_embed, to_2tuple)
from mmcls.models.backbones.base_backbone import BaseBackbone
from mmcv.runner import (CheckpointLoader,
load_state_dict)
from mmcv.cnn.bricks.transformer import MultiheadAttention
class SwinBlockV2(BaseModule):
"""Swin Transformer V2 block. Use post normalization.
Args:
embed_dims (int): Number of input channels.
num_heads (int): Number of attention heads.
        window_size (int): The height and width of the window. Defaults to 8.
        shift (bool): Shift the attention window or not. Defaults to False.
        extra_norm (bool): Whether to add an extra norm at the end of the main
            branch. Defaults to False.
ffn_ratio (float): The expansion ratio of feedforward network hidden
layer channels. Defaults to 4.
drop_path (float): The drop path rate after attention and ffn.
Defaults to 0.
        pad_small_map (bool): If True, pad the small feature map to the window
            size, which is commonly used in detection and segmentation. If
            False, avoid shifting the window and shrink the window size to the
            size of the feature map, which is commonly used in classification.
            Defaults to False.
attn_cfgs (dict): The extra config of Shift Window-MSA.
Defaults to empty dict.
ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
norm_cfg (dict): The config of norm layers.
Defaults to ``dict(type='LN')``.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
        pretrained_window_size (int): The window size used during pretraining.
            Defaults to 0.
init_cfg (dict, optional): The extra config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads,
window_size=8,
shift=False,
extra_norm=False,
ffn_ratio=4.,
drop_path=0.,
pad_small_map=False,
attn_cfgs=dict(),
ffn_cfgs=dict(),
norm_cfg=dict(type='LN'),
with_cp=False,
pretrained_window_size=0,
init_cfg=None):
super(SwinBlockV2, self).__init__(init_cfg)
self.with_cp = with_cp
self.extra_norm = extra_norm
_attn_cfgs = {
'embed_dims': embed_dims,
'num_heads': num_heads,
'shift_size': window_size // 2 if shift else 0,
'window_size': window_size,
'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
'pad_small_map': pad_small_map,
**attn_cfgs
}
# use V2 attention implementation
_attn_cfgs.update(
window_msa=WindowMSAV2,
msa_cfg=dict(
pretrained_window_size=to_2tuple(pretrained_window_size)))
self.attn = ShiftWindowMSA(**_attn_cfgs)
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
_ffn_cfgs = {
'embed_dims': embed_dims,
'feedforward_channels': int(embed_dims * ffn_ratio),
'num_fcs': 2,
'ffn_drop': 0,
'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
'act_cfg': dict(type='GELU'),
'add_identity': False,
**ffn_cfgs
}
self.ffn = FFN(**_ffn_cfgs)
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
# add extra norm for every n blocks in huge and giant model
if self.extra_norm:
self.norm3 = build_norm_layer(norm_cfg, embed_dims)[1]
def forward(self, x, hw_shape):
def _inner_forward(x):
# Use post normalization
identity = x
x = self.attn(x, hw_shape)
x = self.norm1(x)
x = x + identity
identity = x
x = self.ffn(x)
x = self.norm2(x)
x = x + identity
if self.extra_norm:
x = self.norm3(x)
return x
if self.with_cp and x.requires_grad:
x = cp.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x
class SwinBlockV2Sequence(BaseModule):
"""Module with successive Swin Transformer blocks and downsample layer.
Args:
embed_dims (int): Number of input channels.
depth (int): Number of successive swin transformer blocks.
num_heads (int): Number of attention heads.
        window_size (int): The height and width of the window. Defaults to 8.
downsample (bool): Downsample the output of blocks by patch merging.
Defaults to False.
downsample_cfg (dict): The extra config of the patch merging layer.
Defaults to empty dict.
drop_paths (Sequence[float] | float): The drop path rate in each block.
Defaults to 0.
block_cfgs (Sequence[dict] | dict): The extra config of each block.
Defaults to empty dicts.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
        pad_small_map (bool): If True, pad the small feature map to the window
            size, which is commonly used in detection and segmentation. If
            False, avoid shifting the window and shrink the window size to the
            size of the feature map, which is commonly used in classification.
            Defaults to False.
        extra_norm_every_n_blocks (int): Add an extra norm at the end of the
            main branch every n blocks. Defaults to 0, which means no extra
            norm layer is added.
        pretrained_window_size (int): The window size used during pretraining.
            Defaults to 0.
init_cfg (dict, optional): The extra config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
depth,
num_heads,
window_size=8,
downsample=False,
downsample_cfg=dict(),
drop_paths=0.,
block_cfgs=dict(),
with_cp=False,
pad_small_map=False,
extra_norm_every_n_blocks=0,
pretrained_window_size=0,
init_cfg=None):
super().__init__(init_cfg)
if not isinstance(drop_paths, Sequence):
drop_paths = [drop_paths] * depth
if not isinstance(block_cfgs, Sequence):
block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)]
if downsample:
self.out_channels = 2 * embed_dims
_downsample_cfg = {
'in_channels': embed_dims,
'out_channels': self.out_channels,
'norm_cfg': dict(type='LN'),
**downsample_cfg
}
self.downsample = PatchMerging(**_downsample_cfg)
else:
self.out_channels = embed_dims
self.downsample = None
self.blocks = ModuleList()
for i in range(depth):
extra_norm = True if extra_norm_every_n_blocks and \
(i + 1) % extra_norm_every_n_blocks == 0 else False
_block_cfg = {
'embed_dims': self.out_channels,
'num_heads': num_heads,
'window_size': window_size,
'shift': False if i % 2 == 0 else True,
'extra_norm': extra_norm,
'drop_path': drop_paths[i],
'with_cp': with_cp,
'pad_small_map': pad_small_map,
'pretrained_window_size': pretrained_window_size,
**block_cfgs[i]
}
block = SwinBlockV2(**_block_cfg)
self.blocks.append(block)
def forward(self, x, in_shape):
if self.downsample:
x, out_shape = self.downsample(x, in_shape)
else:
out_shape = in_shape
for block in self.blocks:
x = block(x, out_shape)
return x, out_shape
class SwinTransformerV2(BaseBackbone):
"""Swin Transformer V2.
    A PyTorch implementation of `Swin Transformer V2:
    Scaling Up Capacity and Resolution
    <https://arxiv.org/abs/2111.09883>`_

    Inspired by the official implementation:
    https://github.com/microsoft/Swin-Transformer
Args:
arch (str | dict): Swin Transformer architecture. If use string, choose
from 'tiny', 'small', 'base' and 'large'. If use dict, it should
have below keys:
- **embed_dims** (int): The dimensions of embedding.
- **depths** (List[int]): The number of blocks in each stage.
- **num_heads** (List[int]): The number of heads in attention
modules of each stage.
- **extra_norm_every_n_blocks** (int): Add extra norm at the end
of main branch every n blocks.
Defaults to 'tiny'.
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. Defaults to 256.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 4.
        in_channels (int): The number of input channels. Defaults to 3.
        window_size (int | Sequence): The height and width of the window.
            Defaults to 8.
drop_rate (float): Dropout rate after embedding. Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
use_abs_pos_embed (bool): If True, add absolute position embedding to
the patch embedding. Defaults to False.
        interpolate_mode (str): Select the interpolate mode for absolute
            position embedding vector resize. Defaults to "bicubic".
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
        pad_small_map (bool): If True, pad the small feature map to the window
            size, which is commonly used in detection and segmentation. If
            False, avoid shifting the window and shrink the window size to the
            size of the feature map, which is commonly used in classification.
            Defaults to False.
norm_cfg (dict): Config dict for normalization layer for all output
features. Defaults to ``dict(type='LN')``
stage_cfgs (Sequence[dict] | dict): Extra config dict for each
stage. Defaults to an empty dict.
patch_cfg (dict): Extra config dict for patch embedding.
Defaults to an empty dict.
pretrained_window_sizes (tuple(int)): Pretrained window sizes of
each layer.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
Examples:
>>> from mmcls.models import SwinTransformerV2
>>> import torch
>>> extra_config = dict(
>>> arch='tiny',
>>> stage_cfgs=dict(downsample_cfg={'kernel_size': 3,
>>> 'padding': 'same'}))
>>> self = SwinTransformerV2(**extra_config)
>>> inputs = torch.rand(1, 3, 224, 224)
>>> output = self.forward(inputs)
>>> print(output.shape)
(1, 2592, 4)
"""
arch_zoo = {
**dict.fromkeys(['t', 'tiny'],
{'embed_dims': 96,
'depths': [2, 2, 6, 2],
'num_heads': [3, 6, 12, 24],
'extra_norm_every_n_blocks': 0}),
**dict.fromkeys(['s', 'small'],
{'embed_dims': 96,
'depths': [2, 2, 18, 2],
'num_heads': [3, 6, 12, 24],
'extra_norm_every_n_blocks': 0}),
**dict.fromkeys(['b', 'base'],
{'embed_dims': 128,
'depths': [2, 2, 18, 2],
'num_heads': [4, 8, 16, 32],
'extra_norm_every_n_blocks': 0}),
**dict.fromkeys(['l', 'large'],
{'embed_dims': 192,
'depths': [2, 2, 18, 2],
'num_heads': [6, 12, 24, 48],
'extra_norm_every_n_blocks': 0}),
        # The head count for the 'huge' arch is not finalized; this setting is
        # also used in a parallel self-supervised learning study.
**dict.fromkeys(['h', 'huge'],
{'embed_dims': 352,
'depths': [2, 2, 18, 2],
'num_heads': [8, 16, 32, 64],
'extra_norm_every_n_blocks': 6}),
**dict.fromkeys(['g', 'giant'],
{'embed_dims': 512,
'depths': [2, 2, 42, 4],
'num_heads': [16, 32, 64, 128],
'extra_norm_every_n_blocks': 6}),
} # yapf: disable
_version = 1
num_extra_tokens = 0
def __init__(self,
arch='tiny',
img_size=256,
patch_size=4,
in_channels=3,
vocabulary_size=128,
window_size=8,
drop_rate=0.,
drop_path_rate=0.1,
out_indices=(3, ),
use_abs_pos_embed=False,
interpolate_mode='bicubic',
with_cp=False,
frozen_stages=-1,
norm_eval=False,
pad_small_map=False,
norm_cfg=dict(type='LN'),
stage_cfgs=dict(downsample_cfg=dict(is_post_norm=True)),
patch_cfg=dict(),
pretrained_window_sizes=[0, 0, 0, 0],
init_cfg=None):
super(SwinTransformerV2, self).__init__(init_cfg=init_cfg)
if isinstance(arch, str):
arch = arch.lower()
assert arch in set(self.arch_zoo), \
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
self.arch_settings = self.arch_zoo[arch]
else:
essential_keys = {
'embed_dims', 'depths', 'num_heads',
'extra_norm_every_n_blocks'
}
assert isinstance(arch, dict) and set(arch) == essential_keys, \
f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = arch
        self.vocabulary_size = vocabulary_size + 1  # reserve one extra id for the ignore class
self.embed_dims = self.arch_settings['embed_dims']
self.depths = self.arch_settings['depths']
self.num_heads = self.arch_settings['num_heads']
self.extra_norm_every_n_blocks = self.arch_settings[
'extra_norm_every_n_blocks']
self.num_layers = len(self.depths)
self.out_indices = out_indices
self.use_abs_pos_embed = use_abs_pos_embed
self.interpolate_mode = interpolate_mode
self.frozen_stages = frozen_stages
if isinstance(window_size, int):
self.window_sizes = [window_size for _ in range(self.num_layers)]
elif isinstance(window_size, Sequence):
assert len(window_size) == self.num_layers, \
f'Length of window_sizes {len(window_size)} is not equal to '\
f'length of stages {self.num_layers}.'
self.window_sizes = window_size
else:
raise TypeError('window_size should be a Sequence or int.')
_patch_cfg = dict(
in_channels=in_channels,
input_size=img_size,
embed_dims=self.embed_dims,
conv_type='Conv2d',
kernel_size=patch_size,
stride=patch_size,
norm_cfg=dict(type='LN'),
)
_patch_cfg.update(patch_cfg)
self.patch_embed = PatchEmbed(**_patch_cfg)
self.patch_resolution = self.patch_embed.init_out_size
self.patch_size = patch_size
if self.use_abs_pos_embed:
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
self.absolute_pos_embed = nn.Parameter(
torch.zeros(1, num_patches, self.embed_dims))
self._register_load_state_dict_pre_hook(
self._prepare_abs_pos_embed)
self._register_load_state_dict_pre_hook(self._delete_reinit_params)
self.drop_after_pos = nn.Dropout(p=drop_rate)
self.norm_eval = norm_eval
# stochastic depth
total_depth = sum(self.depths)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
] # stochastic depth decay rule
self.stages = ModuleList()
embed_dims = [self.embed_dims]
for i, (depth, num_heads) in enumerate(zip(self.depths,
self.num_heads)):
if isinstance(stage_cfgs, Sequence):
stage_cfg = stage_cfgs[i]
else:
stage_cfg = deepcopy(stage_cfgs)
downsample = True if i > 0 else False
_stage_cfg = {
'embed_dims': embed_dims[-1],
'depth': depth,
'num_heads': num_heads,
'window_size': self.window_sizes[i],
'downsample': downsample,
'drop_paths': dpr[:depth],
'with_cp': with_cp,
'pad_small_map': pad_small_map,
'extra_norm_every_n_blocks': self.extra_norm_every_n_blocks,
'pretrained_window_size': pretrained_window_sizes[i],
**stage_cfg
}
stage = SwinBlockV2Sequence(**_stage_cfg)
self.stages.append(stage)
dpr = dpr[depth:]
embed_dims.append(stage.out_channels)
for i in out_indices:
if norm_cfg is not None:
norm_layer = build_norm_layer(norm_cfg, embed_dims[i + 1])[1]
else:
norm_layer = nn.Identity()
self.add_module(f'norm{i}', norm_layer)
def init_weights(self):
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
# Suppress default init if use pretrained model.
from mmcls.utils import get_root_logger
logger = get_root_logger()
checkpoint = CheckpointLoader.load_checkpoint(
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
load_state_dict(self, state_dict, strict=False, logger=logger)
return
else:
super(SwinTransformerV2, self).init_weights()
if self.use_abs_pos_embed:
trunc_normal_(self.absolute_pos_embed, std=0.02)
def forward(self, x):
x, hw_shape = self.patch_embed(x)
if self.use_abs_pos_embed:
x = x + resize_pos_embed(
self.absolute_pos_embed, self.patch_resolution, hw_shape,
self.interpolate_mode, self.num_extra_tokens)
x = self.drop_after_pos(x)
outs = []
for i, stage in enumerate(self.stages):
x, hw_shape = stage(x, hw_shape)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
out = norm_layer(x)
out = out.view(-1, *hw_shape,
stage.out_channels).permute(0, 3, 1,
2).contiguous()
outs.append(out)
return outs
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
for i in range(0, self.frozen_stages + 1):
m = self.stages[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
for i in self.out_indices:
if i <= self.frozen_stages:
for param in getattr(self, f'norm{i}').parameters():
param.requires_grad = False
def train(self, mode=True):
super(SwinTransformerV2, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
                # trick: eval() only has an effect on BatchNorm layers
if isinstance(m, _BatchNorm):
m.eval()
def _prepare_abs_pos_embed(self, state_dict, prefix, *args, **kwargs):
name = prefix + 'absolute_pos_embed'
if name not in state_dict.keys():
return
ckpt_pos_embed_shape = state_dict[name].shape
if self.absolute_pos_embed.shape != ckpt_pos_embed_shape:
from mmcls.utils import get_root_logger
logger = get_root_logger()
logger.info(
'Resize the absolute_pos_embed shape from '
f'{ckpt_pos_embed_shape} to {self.absolute_pos_embed.shape}.')
ckpt_pos_embed_shape = to_2tuple(
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
pos_embed_shape = self.patch_embed.init_out_size
state_dict[name] = resize_pos_embed(state_dict[name],
ckpt_pos_embed_shape,
pos_embed_shape,
self.interpolate_mode,
self.num_extra_tokens)
def _delete_reinit_params(self, state_dict, prefix, *args, **kwargs):
# delete relative_position_index since we always re-init it
relative_position_index_keys = [
k for k in state_dict.keys() if 'relative_position_index' in k
]
for k in relative_position_index_keys:
del state_dict[k]
# delete relative_coords_table since we always re-init it
relative_position_index_keys = [
k for k in state_dict.keys() if 'relative_coords_table' in k
]
for k in relative_position_index_keys:
del state_dict[k]
class Proj_MHSA(nn.Module):
def __init__(
self,
embed_dims,
proj_dims,
num_heads=16,
batch_first=True,
        bias=True
):
super().__init__()
self.proj_in = nn.Linear(in_features=embed_dims, out_features=proj_dims)
self.attn = MultiheadAttention(
embed_dims=proj_dims,
num_heads=num_heads,
batch_first=batch_first,
bias=bias
)
self.proj_out = nn.Linear(in_features=proj_dims, out_features=embed_dims)
def forward(self, x):
x = self.proj_in(x)
x = self.attn(x, x, x)
x = self.proj_out(x)
return x
class SwinTransformerV2MSL(SwinTransformerV2):
def __init__(self, **kwargs):
if 'use_attn' in kwargs:
self.use_attn = kwargs.pop('use_attn')
else:
self.use_attn = False
if 'merge_stage' in kwargs:
self.merge_stage = kwargs.pop('merge_stage')
else:
self.merge_stage = 0
if 'with_cls_pos' in kwargs:
self.with_cls_pos = kwargs.pop('with_cls_pos')
else:
self.with_cls_pos = False
super().__init__(**kwargs)
self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
#self.vocabulary_token = nn.Parameter(torch.zeros(1, 1, 1, self.vocabulary_size, self.embed_dims))
self.vocabulary_token = nn.Parameter(torch.zeros(self.vocabulary_size, self.embed_dims))
self.vocabulary_weight = nn.Parameter(torch.zeros(1, self.patch_size * self.patch_size))
trunc_normal_(self.mask_token, mean=0., std=.02)
trunc_normal_(self.vocabulary_token, mean=0., std=.02)
if self.use_attn:
            self.attn1 = Proj_MHSA(embed_dims=352, proj_dims=256, num_heads=16, batch_first=True, bias=True)
            self.attn2 = Proj_MHSA(embed_dims=704, proj_dims=512, num_heads=16, batch_first=True, bias=True)
            self.attn3 = Proj_MHSA(embed_dims=1408, proj_dims=1024, num_heads=16, batch_first=True, bias=True)
self.attention_blocks = [self.attn1, self.attn2, self.attn3]
self.norm_attn = build_norm_layer(dict(type='LN'), 1408)[1]
def create_ann_token(self, anno_img):
B, H, W = anno_img.shape
ann_token = torch.index_select(self.vocabulary_token, 0, anno_img.reshape(-1)).reshape(B, H, W, -1)
assert H % self.patch_size == 0 and W % self.patch_size == 0
nph, npw = H // self.patch_size, W // self.patch_size
weight = F.softmax(self.vocabulary_weight, dim=1) * self.patch_size * self.patch_size
weight = weight.reshape(1, 1, self.patch_size, 1, self.patch_size).repeat(1, nph, 1, npw, 1).reshape(1, H, W, 1)
ann_token = ann_token * weight
ann_token = F.avg_pool2d(torch.einsum('BHWC->BCHW', ann_token), self.patch_size, self.patch_size)
ann_token = torch.einsum('BCHW->BHWC', ann_token).reshape(B, nph * npw, self.embed_dims) # shape B, L, C
return ann_token
def forward(self, hr_img, anno_img, mask=None):
x, hw_shape = self.patch_embed(hr_img)
y = self.create_ann_token(anno_img)
assert x.shape == y.shape
B, L, C = y.shape
if mask is not None:
mask_tokens = self.mask_token.expand(B, L, -1)
w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)
y = y * (1. - w) + mask_tokens * w
if self.merge_stage == 0:
x = (x + y) * 0.5
else:
x = x.reshape(B, *hw_shape, C)
y = y.reshape(B, *hw_shape, C)
x = torch.cat((x, y), dim=2)
hw_shape = (hw_shape[0], hw_shape[1] * 2)
x = x.reshape(B, -1, C)
if self.use_abs_pos_embed:
x = x + resize_pos_embed(
self.absolute_pos_embed, self.patch_resolution, hw_shape,
self.interpolate_mode, self.num_extra_tokens)
if self.with_cls_pos:
hw_shape_half = [hw_shape[0], hw_shape[1] // 2]
x = x.reshape(B, *hw_shape, C)
x1 = x[:, :, :x.shape[2]//2, :].reshape(B, -1, C)
x2 = x[:, :, x.shape[2]//2:, :].reshape(B, -1, C)
x1 = x1 + resize_pos_embed(
self.absolute_pos_embed, self.patch_resolution, hw_shape_half,
self.interpolate_mode, self.num_extra_tokens)
x2 = x2 + resize_pos_embed(
self.absolute_pos_embed, self.patch_resolution, hw_shape_half,
self.interpolate_mode, self.num_extra_tokens)
x1 = x1.reshape(B, *hw_shape_half, C)
x2 = x2.reshape(B, *hw_shape_half, C)
x = torch.cat((x1, x2), dim=2).reshape(B, -1, C)
x = self.drop_after_pos(x)
outs = []
merge_idx = self.merge_stage - 1
for i, stage in enumerate(self.stages):
x, hw_shape = stage(x, hw_shape)
if i == merge_idx:
x = x.reshape(x.shape[0], *hw_shape, x.shape[-1]) # b,l,c -> b, h, w, c
x = (x[:, :, :x.shape[2]//2] + x[:, :, x.shape[2]//2:]) * 0.5
x = x.reshape(x.shape[0], -1, x.shape[-1])
hw_shape = (hw_shape[0], hw_shape[1] // 2)
if self.use_attn:
if i <= len(self.attention_blocks) - 1:
x = x + self.attention_blocks[i](x)
if i == len(self.attention_blocks) - 1:
x = self.norm_attn(x)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
out = norm_layer(x)
out = out.view(-1, *hw_shape, stage.out_channels).permute(0, 3, 1, 2).contiguous()
outs.append(out)
return outs
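For orientation, a hedged shape walk-through of the MSL forward pass, assuming the 'tiny' arch and the defaults defined above; the tensor sizes and id range are illustrative assumptions, not values taken from this commit.
# Hypothetical shape walk-through for SwinTransformerV2MSL; values are illustrative only.
import torch

backbone = SwinTransformerV2MSL(arch='tiny', img_size=256, patch_size=4,
                                vocabulary_size=128, out_indices=(3, ))
hr_img = torch.rand(2, 3, 256, 256)               # B, C, H, W
anno_img = torch.randint(0, 129, (2, 256, 256))   # per-pixel ids in [0, vocabulary_size]
mask = torch.randint(0, 2, (2, 64, 64))           # one flag per 4x4 patch
outs = backbone(hr_img, anno_img, mask)
print(outs[-1].shape)                             # expected: torch.Size([2, 768, 8, 8])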

lib/models/backbones/vit.py

@@ -0,0 +1,611 @@
# Copyright (c) Ant Group. All rights reserved.
import math
import warnings
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
trunc_normal_)
from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList,
load_state_dict)
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.utils import _pair as to_2tuple
from mmseg.ops import resize
from mmseg.utils import get_root_logger
from mmseg.models.utils.embed import PatchEmbed
import torch.nn.functional as F
import numpy as np
class TransformerEncoderLayer(BaseModule):
"""Implements one encoder layer in Vision Transformer.
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Default: 0.0.
attn_drop_rate (float): The drop out rate for attention layer.
Default: 0.0.
drop_path_rate (float): stochastic depth rate. Default 0.0.
num_fcs (int): The number of fully-connected layers for FFNs.
Default: 2.
qkv_bias (bool): enable bias for qkv if True. Default: True
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
batch_first (bool): Key, Query and Value are shape of
(batch, n, embed_dim)
or (n, batch, embed_dim). Default: True.
with_cp (bool): Use checkpoint or not. Using checkpoint will save
some memory while slowing down the training speed. Default: False.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
num_fcs=2,
qkv_bias=True,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
batch_first=True,
attn_cfg=dict(),
ffn_cfg=dict(),
with_cp=False):
super(TransformerEncoderLayer, self).__init__()
self.norm1_name, norm1 = build_norm_layer(norm_cfg,
embed_dims,
postfix=1)
self.add_module(self.norm1_name, norm1)
attn_cfg.update(
dict(embed_dims=embed_dims,
num_heads=num_heads,
attn_drop=attn_drop_rate,
proj_drop=drop_rate,
batch_first=batch_first,
bias=qkv_bias))
self.build_attn(attn_cfg)
self.norm2_name, norm2 = build_norm_layer(norm_cfg,
embed_dims,
postfix=2)
self.add_module(self.norm2_name, norm2)
ffn_cfg.update(
dict(embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=num_fcs,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate)
if drop_path_rate > 0 else None,
act_cfg=act_cfg))
self.build_ffn(ffn_cfg)
self.with_cp = with_cp
def build_attn(self, attn_cfg):
self.attn = MultiheadAttention(**attn_cfg)
def build_ffn(self, ffn_cfg):
self.ffn = FFN(**ffn_cfg)
@property
def norm1(self):
return getattr(self, self.norm1_name)
@property
def norm2(self):
return getattr(self, self.norm2_name)
def forward(self, x):
def _inner_forward(x):
x = self.attn(self.norm1(x), identity=x)
x = self.ffn(self.norm2(x), identity=x)
return x
if self.with_cp and x.requires_grad:
x = cp.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x
class VisionTransformer(BaseModule):
"""Vision Transformer.
This backbone is the implementation of `An Image is Worth 16x16 Words:
Transformers for Image Recognition at
Scale <https://arxiv.org/abs/2010.11929>`_.
Args:
img_size (int | tuple): Input image size. Default: 224.
patch_size (int): The patch size. Default: 16.
in_channels (int): Number of input channels. Default: 3.
embed_dims (int): embedding dimension. Default: 768.
num_layers (int): depth of transformer. Default: 12.
num_heads (int): number of attention heads. Default: 12.
mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
Default: 4.
out_indices (list | tuple | int): Output from which stages.
Default: -1.
qkv_bias (bool): enable bias for qkv if True. Default: True.
drop_rate (float): Probability of an element to be zeroed.
Default 0.0
attn_drop_rate (float): The drop out rate for attention layer.
Default 0.0
drop_path_rate (float): stochastic depth rate. Default 0.0
with_cls_token (bool): Whether concatenating class token into image
tokens as transformer input. Default: True.
output_cls_token (bool): Whether output the cls_token. If set True,
`with_cls_token` must be True. Default: False.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN')
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
patch_norm (bool): Whether to add a norm in PatchEmbed Block.
Default: False.
        final_norm (bool): Whether to add an additional layer to normalize the
            final feature map. Default: False.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Default: bicubic.
num_fcs (int): The number of fully-connected layers for FFNs.
Default: 2.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save
some memory while slowing down the training speed. Default: False.
pretrained (str, optional): model pretrained path. Default: None.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(self,
img_size=224,
patch_size=16,
in_channels=3,
embed_dims=768,
num_layers=12,
num_heads=12,
mlp_ratio=4,
out_indices=-1,
qkv_bias=True,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
with_cls_token=True,
output_cls_token=False,
norm_cfg=dict(type='LN'),
act_cfg=dict(type='GELU'),
patch_norm=False,
final_norm=False,
interpolate_mode='bicubic',
num_fcs=2,
norm_eval=False,
with_cp=False,
use_ccd=False,
ccd_num=0,
pretrained=None,
init_cfg=None):
super(VisionTransformer, self).__init__(init_cfg=init_cfg)
if isinstance(img_size, int):
img_size = to_2tuple(img_size)
elif isinstance(img_size, tuple):
if len(img_size) == 1:
img_size = to_2tuple(img_size[0])
assert len(img_size) == 2, \
f'The size of image should have length 1 or 2, ' \
f'but got {len(img_size)}'
if output_cls_token:
            assert with_cls_token is True, 'with_cls_token must be True if ' \
                f'output_cls_token is set to True, but got {with_cls_token}'
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be set at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is not None:
raise TypeError('pretrained must be a str or None')
self.img_size = img_size
self.patch_size = patch_size
self.interpolate_mode = interpolate_mode
self.norm_eval = norm_eval
self.with_cp = with_cp
self.pretrained = pretrained
self.embed_dims = embed_dims
self.patch_embed = PatchEmbed(
in_channels=in_channels,
embed_dims=embed_dims,
conv_type='Conv2d',
kernel_size=patch_size,
stride=patch_size,
padding='corner',
norm_cfg=norm_cfg if patch_norm else None,
init_cfg=None,
)
self.use_ccd = use_ccd
self.ccd_num = ccd_num
if self.use_ccd:
self.ccd_embed = nn.Parameter(
torch.rand(1, self.ccd_num, embed_dims))
num_patches = (img_size[0] // patch_size) * \
(img_size[1] // patch_size)
self.with_cls_token = with_cls_token
self.output_cls_token = output_cls_token
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
        # self.pos_embed = nn.Parameter(
        #     torch.zeros(1, num_patches, embed_dims))
        # the original definition (with a slot for the cls token) is kept:
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dims))
self.drop_after_pos = nn.Dropout(p=drop_rate)
if isinstance(out_indices, int):
if out_indices == -1:
out_indices = num_layers - 1
self.out_indices = [out_indices]
elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
self.out_indices = out_indices
else:
raise TypeError('out_indices must be type of int, list or tuple')
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, num_layers)
] # stochastic depth decay rule
self.layers = ModuleList()
for i in range(num_layers):
self.layers.append(
TransformerEncoderLayer(embed_dims=embed_dims,
num_heads=num_heads,
feedforward_channels=mlp_ratio *
embed_dims,
attn_drop_rate=attn_drop_rate,
drop_rate=drop_rate,
drop_path_rate=dpr[i],
num_fcs=num_fcs,
qkv_bias=qkv_bias,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
with_cp=with_cp,
batch_first=True))
self.final_norm = final_norm
if final_norm:
self.norm1_name, norm1 = build_norm_layer(norm_cfg,
embed_dims,
postfix=1)
self.add_module(self.norm1_name, norm1)
@property
def norm1(self):
return getattr(self, self.norm1_name)
def init_weights(self):
if (isinstance(self.init_cfg, dict)
and self.init_cfg.get('type') == 'Pretrained'):
logger = get_root_logger()
checkpoint = CheckpointLoader.load_checkpoint(
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
if 'pos_embed' in state_dict.keys():
if self.pos_embed.shape != state_dict['pos_embed'].shape:
logger.info(msg=f'Resize the pos_embed shape from '
f'{state_dict["pos_embed"].shape} to '
f'{self.pos_embed.shape}')
h, w = self.img_size
pos_size = int(
math.sqrt(state_dict['pos_embed'].shape[1] - 1))
state_dict['pos_embed'] = self.resize_pos_embed(
state_dict['pos_embed'],
(h // self.patch_size, w // self.patch_size),
(pos_size, pos_size), self.interpolate_mode)
load_state_dict(self, state_dict, strict=False, logger=logger)
elif self.init_cfg is not None:
super(VisionTransformer, self).init_weights()
else:
# We only implement the 'jax_impl' initialization implemented at
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
if self.use_ccd:
trunc_normal_(self.ccd_embed, std=0.02)
for n, m in self.named_modules():
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
if 'ffn' in n:
nn.init.normal_(m.bias, mean=0., std=1e-6)
else:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv2d):
kaiming_init(m, mode='fan_in', bias=0.)
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
constant_init(m, val=1.0, bias=0.)
def _pos_embeding(self, patched_img, hw_shape, pos_embed):
"""Positioning embeding method.
Resize the pos_embed, if the input image size doesn't match
the training size.
Args:
patched_img (torch.Tensor): The patched image, it should be
shape of [B, L1, C].
hw_shape (tuple): The downsampled image resolution.
            pos_embed (torch.Tensor): The pos_embed weights, it should be
                shape of [B, L2, C].
Return:
torch.Tensor: The pos encoded image feature.
"""
assert patched_img.ndim == 3 and pos_embed.ndim == 3, \
'the shapes of patched_img and pos_embed must be [B, L, C]'
x_len, pos_len = patched_img.shape[1], pos_embed.shape[1]
if x_len != pos_len:
if pos_len == (self.img_size[0] // self.patch_size) * (
self.img_size[1] // self.patch_size) + 1:
pos_h = self.img_size[0] // self.patch_size
pos_w = self.img_size[1] // self.patch_size
else:
raise ValueError(
'Unexpected shape of pos_embed, got {}.'.format(
pos_embed.shape))
pos_embed = self.resize_pos_embed(pos_embed, hw_shape,
(pos_h, pos_w),
self.interpolate_mode)
return self.drop_after_pos(patched_img + pos_embed)
@staticmethod
def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode):
"""Resize pos_embed weights.
Resize pos_embed using bicubic interpolate method.
Args:
pos_embed (torch.Tensor): Position embedding weights.
input_shpae (tuple): Tuple for (downsampled input image height,
downsampled input image width).
pos_shape (tuple): The resolution of downsampled origin training
image.
mode (str): Algorithm used for upsampling:
``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
``'trilinear'``. Default: ``'nearest'``
Return:
torch.Tensor: The resized pos_embed of shape [B, L_new, C]
"""
assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
pos_h, pos_w = pos_shape
# keep dim for easy deployment
cls_token_weight = pos_embed[:, 0:1]
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
pos_embed_weight = pos_embed_weight.reshape(
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
pos_embed_weight = resize(pos_embed_weight,
size=input_shpae,
align_corners=False,
mode=mode)
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
return pos_embed
def forward(self, inputs, ccd_index=None):
B = inputs.shape[0]
x, hw_shape = self.patch_embed(inputs)
if self.use_ccd:
_ccd_idx = np.concatenate(ccd_index, axis=0)
_ccd_embed = self.ccd_embed[:, _ccd_idx, :].permute(1, 0, 2)
x = x + _ccd_embed
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = self._pos_embeding(x, hw_shape, self.pos_embed)
if not self.with_cls_token:
# Remove class token for transformer encoder input
x = x[:, 1:]
outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
if i == len(self.layers) - 1:
if self.final_norm:
x = self.norm1(x)
if i in self.out_indices:
if self.with_cls_token:
# Remove class token and reshape token for decoder heads
out = x[:, 1:]
else:
out = x
B, _, C = out.shape
out = out.reshape(B, hw_shape[0], hw_shape[1],
C).permute(0, 3, 1, 2).contiguous()
if self.output_cls_token:
out = [out, x[:, 0]]
outs.append(out)
return tuple(outs)
def train(self, mode=True):
super(VisionTransformer, self).train(mode)
if mode and self.norm_eval:
for m in self.modules():
if isinstance(m, nn.LayerNorm):
m.eval()
class VisionTransformerMSL(VisionTransformer):
def __init__(self, **kwargs):
if 'use_attn' in kwargs:
self.use_attn = kwargs.pop('use_attn')
else:
self.use_attn = False
if 'merge_stage' in kwargs:
self.merge_stage = kwargs.pop('merge_stage')
else:
self.merge_stage = 0
if 'with_cls_pos' in kwargs:
self.with_cls_pos = kwargs.pop('with_cls_pos')
else:
self.with_cls_pos = False
        self.vocabulary_size = kwargs.pop('vocabulary_size') + 1  # reserve one extra id for the ignore class
super().__init__(**kwargs)
self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
        img_size = to_2tuple(kwargs.pop('img_size'))  # accept both int and tuple inputs
        patch_size = kwargs.pop('patch_size')
        num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dims))
self.vocabulary_token = nn.Parameter(torch.zeros(self.vocabulary_size, self.embed_dims))
self.vocabulary_weight = nn.Parameter(torch.zeros(1, self.patch_size * self.patch_size))
trunc_normal_(self.mask_token, mean=0., std=.02)
trunc_normal_(self.vocabulary_token, mean=0., std=.02)
if self.use_attn:
            self.attn1 = MultiheadAttention(embed_dims=1024, num_heads=16, batch_first=True, bias=True)
            self.attn2 = MultiheadAttention(embed_dims=1024, num_heads=16, batch_first=True, bias=True)
            self.attn3 = MultiheadAttention(embed_dims=1024, num_heads=16, batch_first=True, bias=True)
self.attention_blocks = [self.attn1, self.attn2, self.attn3]
self.norm_attn = build_norm_layer(dict(type='LN'), 1024)[1]
def create_ann_token(self, anno_img):
B, H, W = anno_img.shape
ann_token = torch.index_select(self.vocabulary_token, 0, anno_img.reshape(-1)).reshape(B, H, W, -1)
assert H % self.patch_size == 0 and W % self.patch_size == 0
nph, npw = H // self.patch_size, W // self.patch_size
weight = F.softmax(self.vocabulary_weight, dim=1) * self.patch_size * self.patch_size
weight = weight.reshape(1, 1, self.patch_size, 1, self.patch_size).repeat(1, nph, 1, npw, 1).reshape(1, H, W, 1)
ann_token = ann_token * weight
ann_token = F.avg_pool2d(torch.einsum('BHWC->BCHW', ann_token), self.patch_size, self.patch_size)
ann_token = torch.einsum('BCHW->BHWC', ann_token).reshape(B, nph * npw, self.embed_dims) # shape B, L, C
return ann_token
def _pos_embeding(self, patched_img, hw_shape, pos_embed):
"""Positioning embeding method.
Resize the pos_embed, if the input image size doesn't match
the training size.
Args:
patched_img (torch.Tensor): The patched image, it should be
shape of [B, L1, C].
hw_shape (tuple): The downsampled image resolution.
            pos_embed (torch.Tensor): The pos_embed weights, it should be
                shape of [B, L2, C].
Return:
torch.Tensor: The pos encoded image feature.
"""
assert patched_img.ndim == 3 and pos_embed.ndim == 3, \
'the shapes of patched_img and pos_embed must be [B, L, C]'
x_len, pos_len = patched_img.shape[1], pos_embed.shape[1]
if x_len != pos_len:
if pos_len == (self.img_size[0] // self.patch_size) * (
self.img_size[1] // self.patch_size):
pos_h = self.img_size[0] // self.patch_size
pos_w = self.img_size[1] // self.patch_size
else:
raise ValueError(
'Unexpected shape of pos_embed, got {}.'.format(
pos_embed.shape))
pos_embed = self.resize_pos_embed(pos_embed, hw_shape,
(pos_h, pos_w),
self.interpolate_mode)
return self.drop_after_pos(patched_img + pos_embed)
@staticmethod
def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode):
"""Resize pos_embed weights.
Resize pos_embed using bicubic interpolate method.
Args:
pos_embed (torch.Tensor): Position embedding weights.
input_shpae (tuple): Tuple for (downsampled input image height,
downsampled input image width).
pos_shape (tuple): The resolution of downsampled origin training
image.
mode (str): Algorithm used for upsampling:
``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
``'trilinear'``. Default: ``'nearest'``
Return:
torch.Tensor: The resized pos_embed of shape [B, L_new, C]
"""
assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
pos_h, pos_w = pos_shape
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
pos_embed_weight = pos_embed_weight.reshape(
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
pos_embed_weight = resize(pos_embed_weight,
size=input_shpae,
align_corners=False,
mode=mode)
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
pos_embed = pos_embed_weight # torch.cat((cls_token_weight, pos_embed_weight), dim=1)
return pos_embed
def forward(self, x, y, mask=None):
x, hw_shape = self.patch_embed(x)
y = self.create_ann_token(y)
assert x.shape == y.shape
B, L, C = y.shape
if mask is not None:
mask_tokens = self.mask_token.expand(B, L, -1)
w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)
y = y * (1. - w) + mask_tokens * w
if self.merge_stage == 0:
x = (x + y) * 0.5
else:
x = x.reshape(B, *hw_shape, C)
y = y.reshape(B, *hw_shape, C)
x = torch.cat((x, y), dim=2)
hw_shape = (hw_shape[0], hw_shape[1] * 2)
x = x.reshape(B, -1, C)
x = self._pos_embeding(x, hw_shape, self.pos_embed)
merge_idx = self.merge_stage - 1
outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
if i == merge_idx:
x = x.reshape(x.shape[0], *hw_shape, x.shape[-1]) # b,l,c -> b, h, w, c
x = (x[:, :, :x.shape[2]//2] + x[:, :, x.shape[2]//2:]) * 0.5
x = x.reshape(x.shape[0], -1, x.shape[-1])
hw_shape = (hw_shape[0], hw_shape[1] // 2)
if self.use_attn:
if i <= len(self.attention_blocks) - 1:
x = x + self.attention_blocks[i](x)
if i == len(self.attention_blocks) - 1:
                    x = self.norm_attn(x)  # TODO: check for a potential conflict with the final norm
if (not self.use_attn) and (i == len(self.layers) - 1):
if self.final_norm:
                    x = self.norm1(x)  # TODO: check for a potential conflict with norm_attn
if i in self.out_indices:
if self.with_cls_token:
# Remove class token and reshape token for decoder heads
out = x[:, 1:]
else:
out = x
B, _, C = out.shape
out = out.reshape(B, hw_shape[0], hw_shape[1],
C).permute(0, 3, 1, 2).contiguous()
if self.output_cls_token:
out = [out, x[:, 0]]
outs.append(out)
return tuple(outs)
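A comparable hedged sketch for the ViT variant; img_size and patch_size must be passed explicitly because the MSL constructor pops them again, and with_cls_token=False is assumed since this forward never prepends a cls token. Values are illustrative only.
# Hypothetical shape walk-through for VisionTransformerMSL; values are illustrative only.
import torch

backbone = VisionTransformerMSL(img_size=(224, 224), patch_size=16,
                                embed_dims=768, num_layers=12, num_heads=12,
                                vocabulary_size=128, with_cls_token=False)
img = torch.rand(2, 3, 224, 224)
anno = torch.randint(0, 129, (2, 224, 224))       # per-pixel ids in [0, vocabulary_size]
mask = torch.randint(0, 2, (2, 14, 14))           # one flag per 16x16 patch
outs = backbone(img, anno, mask)
print(outs[-1].shape)                             # expected: torch.Size([2, 768, 14, 14])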


@@ -0,0 +1,15 @@
from .uper_head import UPerHead
from .up_head import UPHead
__all__ = [
'UPerHead', 'UPHead'
]
type_mapping = {
'UPerHead': UPerHead,
'UPHead': UPHead
}
def build_head(type, **kwargs):
return type_mapping[type](**kwargs)
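Usage note (a minimal sketch, not taken from this commit): as with the backbone factory, the 'type' key is popped from a config dict and the remaining keys become constructor arguments; the values below are illustrative assumptions.
# Hypothetical usage sketch for build_head; argument values are illustrative only.
head_cfg = dict(type='UPHead', in_dim=768, out_dim=129, up_scale=4)
cfg = dict(head_cfg)
head = build_head(cfg.pop('type'), **cfg)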


@@ -0,0 +1,201 @@
import warnings
from abc import ABCMeta, abstractmethod
import torch
import torch.nn as nn
from mmcv.runner import BaseModule, auto_fp16
from mmseg.core import build_pixel_sampler
from mmseg.ops import resize
class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
"""Base class for BaseDecodeHead.
Args:
in_channels (int|Sequence[int]): Input channels.
channels (int): Channels after modules, before conv_seg.
num_classes (int): Number of classes.
out_channels (int): Output channels of conv_seg.
threshold (float): Threshold for binary segmentation in the case of
`out_channels==1`. Default: None.
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
conv_cfg (dict|None): Config of conv layers. Default: None.
norm_cfg (dict|None): Config of norm layers. Default: None.
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU')
in_index (int|Sequence[int]): Input feature index. Default: -1
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            'resize_concat': Multiple feature maps will be resized to the
                same size as the first one and then concatenated together.
                Usually used in the FCN head of HRNet.
            'multiple_select': Multiple feature maps will be bundled into
                a list and passed into the decode head.
            None: Only one selected feature map is allowed.
            Default: None.
loss_decode (dict | Sequence[dict]): Config of decode loss.
The `loss_name` is property of corresponding loss function which
could be shown in training log. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_ce'.
e.g. dict(type='CrossEntropyLoss'),
[dict(type='CrossEntropyLoss', loss_name='loss_ce'),
dict(type='DiceLoss', loss_name='loss_dice')]
Default: dict(type='CrossEntropyLoss').
ignore_index (int | None): The label index to be ignored. When using
masked BCE loss, ignore_index should be set to None. Default: 255.
sampler (dict|None): The config of segmentation map sampler.
Default: None.
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
init_cfg (dict or list[dict], optional): Initialization config dict.
"""
def __init__(self,
in_channels,
channels,
*,
num_classes,
out_channels=None,
threshold=None,
dropout_ratio=0.1,
conv_cfg=None,
norm_cfg=None,
act_cfg=dict(type='ReLU'),
in_index=-1,
input_transform=None,
sampler=None,
align_corners=False,
init_cfg=dict(type='Normal',
std=0.01,
override=dict(name='conv_seg'))):
super(BaseDecodeHead, self).__init__(init_cfg)
self._init_inputs(in_channels, in_index, input_transform)
self.channels = channels
self.dropout_ratio = dropout_ratio
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.in_index = in_index
self.align_corners = align_corners
        if out_channels is None:
            if num_classes == 2:
                warnings.warn('For binary segmentation, we suggest using '
                              '`out_channels = 1` to define the output '
                              'channels of the segmentor, and use `threshold` '
                              'to convert seg_logits into a prediction by '
                              'applying a threshold')
            out_channels = num_classes
        if out_channels != num_classes and out_channels != 1:
            raise ValueError(
                'out_channels should be equal to num_classes, except for '
                'binary segmentation, where out_channels == 1 and '
                f'num_classes == 2, but got out_channels={out_channels} '
                f'and num_classes={num_classes}')
        if out_channels == 1 and threshold is None:
            threshold = 0.3
            warnings.warn('threshold is not defined for binary segmentation '
                          'and defaults to 0.3')
self.num_classes = num_classes
self.out_channels = out_channels
self.threshold = threshold
if sampler is not None:
self.sampler = build_pixel_sampler(sampler, context=self)
else:
self.sampler = None
self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
if dropout_ratio > 0:
self.dropout = nn.Dropout2d(dropout_ratio)
else:
self.dropout = None
self.fp16_enabled = False
def extra_repr(self):
"""Extra repr."""
s = f'input_transform={self.input_transform}, ' \
f'align_corners={self.align_corners}'
return s
def _init_inputs(self, in_channels, in_index, input_transform):
"""Check and initialize input transforms.
The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only a single feature map
        will be selected, so in_channels and in_index must be of type int.
        When input_transform is not None, in_channels and in_index must be
        lists or tuples of the same length.
Args:
in_channels (int|Sequence[int]): Input channels.
in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps will be resized to the
                    same size as the first one and then concatenated together.
                    Usually used in the FCN head of HRNet.
                'multiple_select': Multiple feature maps will be bundled into
                    a list and passed into the decode head.
                None: Only one selected feature map is allowed.
"""
if input_transform is not None:
assert input_transform in ['resize_concat', 'multiple_select']
self.input_transform = input_transform
self.in_index = in_index
if input_transform is not None:
assert isinstance(in_channels, (list, tuple))
assert isinstance(in_index, (list, tuple))
assert len(in_channels) == len(in_index)
if input_transform == 'resize_concat':
self.in_channels = sum(in_channels)
else:
self.in_channels = in_channels
else:
assert isinstance(in_channels, int)
assert isinstance(in_index, int)
self.in_channels = in_channels
def _transform_inputs(self, inputs):
"""Transform inputs for decoder.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
Tensor: The transformed inputs
"""
if self.input_transform == 'resize_concat':
inputs = [inputs[i] for i in self.in_index]
upsampled_inputs = [
resize(input=x,
size=inputs[0].shape[2:],
mode='bilinear',
align_corners=self.align_corners) for x in inputs
]
inputs = torch.cat(upsampled_inputs, dim=1)
elif self.input_transform == 'multiple_select':
inputs = [inputs[i] for i in self.in_index]
else:
inputs = inputs[self.in_index]
return inputs
@auto_fp16()
@abstractmethod
def forward(self, inputs):
"""Placeholder of forward function."""
pass
def cls_seg(self, feat):
"""Classify each pixel."""
if self.dropout is not None:
feat = self.dropout(feat)
output = self.conv_seg(feat)
return output
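Since BaseDecodeHead leaves forward abstract, a minimal hypothetical subclass (not part of this commit) illustrates the expected pattern: transform the inputs, run some convolutions, and finish with cls_seg. The layer sizes are illustrative assumptions.
# Hypothetical minimal subclass; not part of this commit.
import torch
from mmcv.cnn import ConvModule

class TinyFCNHead(BaseDecodeHead):
    """A single 3x3 ConvModule followed by the 1x1 conv_seg classifier."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.conv = ConvModule(
            self.in_channels, self.channels, 3, padding=1,
            conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)

    def forward(self, inputs):
        x = self._transform_inputs(inputs)   # honours in_index / input_transform
        return self.cls_seg(self.conv(x))

head = TinyFCNHead(in_channels=768, channels=256, num_classes=129)
logits = head([torch.rand(2, 768, 8, 8)])    # default in_index=-1 selects the last feature map
print(logits.shape)                          # expected: torch.Size([2, 129, 8, 8])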


@@ -0,0 +1,60 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
class PPM(nn.ModuleList):
"""Pooling Pyramid Module used in PSPNet.
Args:
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
Module.
in_channels (int): Input channels.
channels (int): Channels after modules, before conv_seg.
conv_cfg (dict|None): Config of conv layers.
norm_cfg (dict|None): Config of norm layers.
act_cfg (dict): Config of activation layers.
align_corners (bool): align_corners argument of F.interpolate.
"""
def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
act_cfg, align_corners, **kwargs):
super(PPM, self).__init__()
self.pool_scales = pool_scales
self.align_corners = align_corners
self.in_channels = in_channels
self.channels = channels
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
for pool_scale in pool_scales:
self.append(
nn.Sequential(
nn.AdaptiveAvgPool2d(pool_scale),
ConvModule(
self.in_channels,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
**kwargs)))
def forward(self, x):
"""Forward function."""
ppm_outs = []
for ppm in self:
ppm_out = ppm(x)
ppm_out = ppm_out.to(torch.float32)
upsampled_ppm_out = resize(
ppm_out,
size=x.size()[2:],
mode='bilinear',
align_corners=self.align_corners)
upsampled_ppm_out = upsampled_ppm_out.to(torch.bfloat16)
ppm_outs.append(upsampled_ppm_out)
return ppm_outs
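A hedged shape check for the module above (not part of this commit); note that this variant casts each upsampled branch to bfloat16 before returning, which presumably matches the bf16 training setup used elsewhere in the repository. Values are illustrative only.
# Hypothetical shape check for PPM; values are illustrative only.
import torch

ppm = PPM(pool_scales=(1, 2, 3, 6), in_channels=768, channels=256,
          conv_cfg=None, norm_cfg=None, act_cfg=dict(type='ReLU'),
          align_corners=False)
x = torch.rand(2, 768, 8, 8)
outs = ppm(x)
print(len(outs), outs[0].shape, outs[0].dtype)
# expected: 4 branches, each torch.Size([2, 256, 8, 8]), cast to torch.bfloat16 by this variant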


@@ -0,0 +1,52 @@
import torch.nn as nn
from collections import OrderedDict
from mmcv.cnn.utils.weight_init import (kaiming_init, trunc_normal_)
from mmcv.runner import (CheckpointLoader, load_state_dict)
from mmseg.utils import get_root_logger
class UPHead(nn.Module):
def __init__(self, in_dim, out_dim, up_scale, init_cfg=None):
super().__init__()
self.decoder = nn.Sequential(
nn.Conv2d(in_channels=in_dim,
out_channels=up_scale**2 * out_dim,
kernel_size=1),
nn.PixelShuffle(up_scale),
)
self.init_cfg = init_cfg
self.apply(self._init_weights)
def _init_weights(self, m):
if (isinstance(self.init_cfg, dict)
and self.init_cfg.get('type') == 'Pretrained'):
logger = get_root_logger()
checkpoint = CheckpointLoader.load_checkpoint(
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
if 'state_dict' in checkpoint:
_state_dict = checkpoint['state_dict']
else:
_state_dict = checkpoint
state_dict = OrderedDict()
for k, v in _state_dict.items():
if k.startswith('backbone.'):
state_dict[k[9:]] = v
else:
state_dict[k] = v
print(f'loading weight: {self.init_cfg["checkpoint"]}')
load_state_dict(self, state_dict, strict=False, logger=logger)
else:
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv2d):
kaiming_init(m, mode='fan_in', bias=0.)
def forward(self, x):
x = self.decoder(x)
return x
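A hedged shape check (not part of this commit): the head is a 1x1 convolution to up_scale**2 * out_dim channels followed by PixelShuffle. Values are illustrative only.
# Hypothetical shape check for UPHead; values are illustrative only.
import torch

head = UPHead(in_dim=768, out_dim=129, up_scale=4)
x = torch.rand(2, 768, 8, 8)
print(head(x).shape)   # expected: torch.Size([2, 129, 32, 32]); 1x1 conv to 4*4*129 channels, then PixelShuffle(4)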


@@ -0,0 +1,130 @@
# coding: utf-8
# Copyright (c) Ant Group. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from .psp_head import PPM
class UPerHead(BaseDecodeHead):
"""Unified Perceptual Parsing for Scene Understanding.
This head is the implementation of `UPerNet
<https://arxiv.org/abs/1807.10221>`_.
Args:
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
Module applied on the last feature. Default: (1, 2, 3, 6).
"""
def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
super(UPerHead, self).__init__(
input_transform='multiple_select', **kwargs)
# PSP Module
self.psp_modules = PPM(
pool_scales,
self.in_channels[-1],
self.channels,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
align_corners=self.align_corners)
self.bottleneck = ConvModule(
self.in_channels[-1] + len(pool_scales) * self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
# FPN Module
self.lateral_convs = nn.ModuleList()
self.fpn_convs = nn.ModuleList()
for in_channels in self.in_channels[:-1]: # skip the top layer
l_conv = ConvModule(
in_channels,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
inplace=False)
fpn_conv = ConvModule(
self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
inplace=False)
self.lateral_convs.append(l_conv)
self.fpn_convs.append(fpn_conv)
self.fpn_bottleneck = ConvModule(
len(self.in_channels) * self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def psp_forward(self, inputs):
"""Forward function of PSP module."""
x = inputs[-1]
psp_outs = [x]
psp_outs.extend(self.psp_modules(x))
psp_outs = torch.cat(psp_outs, dim=1)
output = self.bottleneck(psp_outs)
return output
def forward(self, inputs):
"""Forward function."""
inputs = self._transform_inputs(inputs)
# build laterals
laterals = [
lateral_conv(inputs[i])
for i, lateral_conv in enumerate(self.lateral_convs)
]
laterals.append(self.psp_forward(inputs))
# build top-down path
used_backbone_levels = len(laterals)
for i in range(used_backbone_levels - 1, 0, -1):
prev_shape = laterals[i - 1].shape[2:]
laterals[i] = laterals[i].type(torch.float32)
laterals[i - 1] = laterals[i - 1] + resize(
laterals[i],
size=prev_shape,
mode='bilinear',
align_corners=self.align_corners)
# build outputs
fpn_outs = [
self.fpn_convs[i](laterals[i])
for i in range(used_backbone_levels - 1)
]
# append psp feature
fpn_outs.append(laterals[-1])
for i in range(used_backbone_levels - 1, 0, -1):
fpn_outs[i] = fpn_outs[i].type(torch.float32)
fpn_outs[i] = resize(
fpn_outs[i],
size=fpn_outs[0].shape[2:],
mode='bilinear',
align_corners=self.align_corners)
fpn_outs = torch.cat(fpn_outs, dim=1)
output = self.fpn_bottleneck(fpn_outs)
output = self.cls_seg(output)
return output
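A hedged configuration sketch (not part of this commit); the per-stage channel list is an illustrative assumption matching a small Swin backbone. The forward pass is omitted here because the float32/bfloat16 casts above assume a mixed-precision training context.
# Hypothetical UPerHead configuration sketch; values are illustrative only.
head = UPerHead(
    in_channels=[96, 192, 384, 768],   # per-stage backbone channels
    in_index=[0, 1, 2, 3],
    channels=256,
    num_classes=129,
    align_corners=False,
)
# The PSP bottleneck fuses 768 + 4 * 256 channels, every FPN level is reduced to
# 256 channels, and fpn_bottleneck sees 4 * 256 channels before the 1x1 conv_seg.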


@@ -0,0 +1,4 @@
from .modality_vae_loss import ModalityVAELoss
from .recon_anno_loss import RecLoss
__all__ = [ "ModalityVAELoss", "RecLoss" ]


@@ -0,0 +1,46 @@
# Copyright (c) Ant Group and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F
from antmmf.common.registry import registry
@registry.register_loss("ModalityVAELoss")
class ModalityVAELoss(nn.Module):
def __init__(self, **params):
super().__init__()
self.weight = params.pop("weight")
def compute_rec_loss(self, x_in, x_out, modal_flag):
loss_per_pixel = F.mse_loss(x_in, x_out, reduction='none')
loss_b = torch.mean(loss_per_pixel, dim=[1, 2, 3])
        return torch.sum(loss_b * modal_flag) / (modal_flag.sum() + 1e-6)
def forward(self, sample_list, output, *args, **kwargs):
vae_out = output["vae_out"]
feat_hr = vae_out['input_hr']
feat_s2 = vae_out['input_s2']
feat_s1 = vae_out['input_s1']
g_hr = vae_out['g_hr']
g_s2 = vae_out['g_s2']
g_s1 = vae_out['g_s1']
# process modality flags
modality_info = vae_out['modality_info']
B_M, L_M = modality_info.shape
modality_hr = modality_info[:,0]
modality_s2 = modality_info[:,1]
modality_s1 = modality_info[:,2]
        ######## rec losses ########
        loss_rec = self.compute_rec_loss(g_hr, feat_hr, modality_hr) \
            + self.compute_rec_loss(g_s2, feat_s2, modality_s2) \
            + self.compute_rec_loss(g_s1, feat_s1, modality_s1)
        loss_quant = vae_out["loss_quant"]
        total_loss = loss_rec / 3 + loss_quant
return total_loss * self.weight
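To make the modality masking concrete, a small numeric sketch (not part of this commit) that repeats the arithmetic of compute_rec_loss for a batch in which one sample lacks the modality; shapes are illustrative only.
# Hypothetical numeric sketch of the masked reconstruction term; shapes are illustrative only.
import torch
import torch.nn.functional as F

x_in = torch.rand(4, 3, 16, 16)                 # reconstructions
x_out = torch.rand(4, 3, 16, 16)                # reconstruction targets
modal_flag = torch.tensor([1., 1., 0., 1.])     # sample 2 does not carry this modality

loss_per_pixel = F.mse_loss(x_in, x_out, reduction='none')
loss_b = loss_per_pixel.mean(dim=[1, 2, 3])     # per-sample MSE
loss = (loss_b * modal_flag).sum() / (modal_flag.sum() + 1e-6)
print(loss)                                     # averaged over the 3 samples with the modality present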


@@ -0,0 +1,89 @@
# Copyright (c) Ant Group and its affiliates.
import torch
import torch.nn as nn
from antmmf.common.registry import registry
import torch.nn.functional as F
@registry.register_loss("RecLoss")
class RecLoss(nn.Module):
def __init__(self, **params):
super().__init__()
self.weight = params.pop("weight")
self.patch_size = params.pop("patch_size")
self.eps = torch.finfo(torch.bfloat16).eps
self.pred_key = params.pop("pred_key")
self.vocabulary_size = params.pop("vocabulary_size") + 1
self.mask_key = params.pop("mask_key")
self.target_key = params.pop("target_key")
self.feature_merged = params.pop("feature_merged")
self.cnt_train = 0
self.cnt_val = 0
self.use_bg = params.pop("use_bg")
if "use_all_patch" in params:
self.use_all_patch = params.pop("use_all_patch")
else:
self.use_all_patch = False
if "balance" in params:
self.balance = params.pop("balance")
else:
self.balance = False
if "sim_regularization" in params:
self.sim_regularization = params.pop("sim_regularization")
else:
self.sim_regularization = False
def unpatchify(self, x):
"""
x: (N, L, patch_size**2 *3)
imgs: (N, 3, H, W)
"""
p = self.patch_size
w = int((x.shape[1]*0.5)**.5)
h = w * 2
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p))
x = torch.einsum('nhwpq->nhpwq', x)
imgs = x.reshape(shape=(x.shape[0], h * p, w * p))
return imgs
def forward(self, sample_list, output, *args, **kwargs):
pred = output[self.pred_key] # B, C, H, W
target = output[self.target_key] # B, H, W
mask = output[self.mask_key]
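        # expand the patch-level mask (B, h, w) to a pixel-level map aligned with
        # the target: flatten, repeat per patch, then unpatchify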
b_mask, h_mask, w_mask = mask.shape
mask = mask.reshape((b_mask, h_mask*w_mask))
mask = mask[:, :, None].repeat(1, 1, self.patch_size**2)
mask = self.unpatchify(mask)
if not self.use_bg:
valid = sample_list['valid']
mask = mask * valid
loss = F.cross_entropy(pred, target, reduction="none")
if self.balance:
if self.use_all_patch:
loss_pos = loss[target > 0].sum() / ((target > 0).sum() + 1e-6)
loss_neg = loss[target == 0].sum() / ((target == 0).sum() + 1e-6)
loss = (loss_pos + loss_neg) * 0.5
else:
loss_pos = loss[(target > 0) & (mask == 1)].sum() / (((target > 0) & (mask == 1)).sum() + 1e-6)
loss_neg = loss[(target == 0) & (mask == 1)].sum() / (((target == 0) & (mask == 1)).sum() + 1e-6)
loss = (loss_pos + loss_neg) * 0.5
else:
if self.use_all_patch:
loss = loss.mean()
else:
loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
if self.sim_regularization:
vocabulary_token = output['vocabulary_token']
voca_normed = F.normalize(vocabulary_token, 2, 1)
similarity_matrix = 1 + torch.einsum('nd,md->nm', voca_normed, voca_normed)
num = voca_normed.shape[0]
index = torch.triu(voca_normed.new_ones(num, num), diagonal=1).type(torch.bool)
loss_reg = similarity_matrix[index].mean()
return loss * self.weight + loss_reg * 0.05
return loss * self.weight
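
# Hedged usage sketch with assumed hyper-parameters (not taken from a real
# config): it only demonstrates how a patch-level mask is expanded back to a
# pixel-level map via unpatchify.
if __name__ == "__main__":
    loss_fn = RecLoss(weight=1.0, patch_size=4, pred_key="logits_hr",
                      vocabulary_size=100, mask_key="mask_hr",
                      target_key="target_hr", feature_merged=False,
                      use_bg=True)
    mask = torch.randint(0, 2, (2, 8, 4)).float()       # (B, h, w) patch grid
    mask_flat = mask.reshape(2, 8 * 4)[:, :, None].repeat(1, 1, 4 ** 2)
    print(loss_fn.unpatchify(mask_flat).shape)           # -> torch.Size([2, 32, 16])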

View File

@@ -0,0 +1,4 @@
from .sem_metrics import SemMetric
__all__ = ["SemMetric"]

View File

@@ -0,0 +1,93 @@
# coding: utf-8
# Copyright (c) Ant Group. All rights reserved.
import torch
from torch.distributed import all_reduce, ReduceOp
from antmmf.common.registry import registry
from antmmf.modules.metrics.base_metric import BaseMetric
@registry.register_metric("sem_metric")
class SemMetric(BaseMetric):
"""Segmentation metrics used in evaluation phase.
Args:
name (str): Name of the metric.
eval_type(str): 3 types are supported: 'mIoU', 'mDice', 'mFscore'
result_field(str): key of predicted results in output dict
target_field(str): key of ground truth in output dict
ignore_index(int): class value will be ignored in evaluation
num_cls(int): total number of categories in evaluation
"""
def __init__(self,
name="dummy_metric", **kwargs
):
super().__init__(name)
self.reset()
def calculate(self, sample_list, model_output, *args, **kwargs):
"""Calculate Intersection and Union for a batch.
Args:
sample_list (Sample_List): data which contains ground truth segmentation maps
model_output (dict): data which contains prediction segmentation maps
Returns:
torch.Tensor: The intersection of prediction and ground truth histogram
on all classes.
torch.Tensor: The union of prediction and ground truth histogram on all
classes.
torch.Tensor: The prediction histogram on all classes.
torch.Tensor: The ground truth histogram on all classes.
"""
return torch.tensor(0).float()
def reset(self):
""" initialized all attributes value before evaluation
"""
self.total_mask_mae = 0
self.total_num = torch.tensor(0)
def collect(self, sample_list, model_output, *args, **kwargs):
"""
Args:
sample_list(Sample_List): data which contains ground truth segmentation maps
model_output (Dict): Dict returned by model, that contains two modalities
Returns:
torch.FloatTensor: Accuracy
"""
batch_mask_mae = \
self.calculate(sample_list, model_output, *args, **kwargs)
self.total_mask_mae += batch_mask_mae
self.total_num += 1
def format(self, *args):
""" Format evaluated metrics for profile.
Returns:
dict: dict of all evaluated metrics.
"""
output_metric = dict()
mae = args[0]
output_metric['mae'] = mae.item()
return output_metric
def summarize(self, *args, **kwargs):
"""This method is used to calculate the overall metric.
Returns:
dict: dict of all evaluated metrics.
"""
mae = self.total_mask_mae / (self.total_num)
return self.format(mae)
def all_reduce(self):
total_number = torch.stack([
self.total_mask_mae, self.total_num
]).cuda()
all_reduce(total_number, op=ReduceOp.SUM)
self.total_mask_mae = total_number[0].cpu()
self.total_num = total_number[1].cpu()

View File

@@ -0,0 +1,13 @@
from .transformer_encoder import TransformerEncoder
from .modality_completion import ModalityCompletion
__all__ = ['TransformerEncoder', 'ModalityCompletion']
type_mapping = {
'TransformerEncoder': TransformerEncoder,
'ModalityCompletion': ModalityCompletion
}
def build_neck(type, **kwargs):
return type_mapping[type](**kwargs)

View File

@@ -0,0 +1,212 @@
# Copyright (c) AntGroup. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
class BFloat16UpsampleNearest2d(nn.Module):
def __init__(self, scale_factor, mode='bilinear'):
super().__init__()
self.scale_factor = scale_factor
self.mode = mode
def forward(self, x):
x_float = x.float()
upsampled_x = F.interpolate(x_float, scale_factor=self.scale_factor, mode=self.mode)
return upsampled_x.to(x.dtype)
class ConvVQVAEv2(nn.Module):
def __init__(self, input_shape, conv_dim, z_dim, num_tokens=8192, temp=0.9):
super().__init__()
self.z_dim = z_dim
self.conv_dim = conv_dim # 256
self.input_shape = input_shape # 256
self.temp = temp
# code book
self.codebook = nn.Embedding(num_tokens, z_dim)
# encoder
self.relu = nn.LeakyReLU()
self.pool = nn.AvgPool2d(2)
self.conv1 = nn.Conv2d(input_shape[0], conv_dim, 5, stride=1, padding=2)
self.enc_block1 = nn.Sequential(
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
nn.LeakyReLU(),
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
nn.LeakyReLU(),
)
self.gamma_1 = nn.Parameter(0.001 * torch.ones((1, conv_dim, 1, 1)))
self.enc_block2 = nn.Sequential(
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
nn.LeakyReLU(),
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
nn.LeakyReLU(),
)
self.gamma_2 = nn.Parameter(0.001 * torch.ones((1, conv_dim, 1, 1)))
self.logit_conv = nn.Conv2d(conv_dim, num_tokens, 1)
# decoder
self.unpool = BFloat16UpsampleNearest2d(scale_factor=2)
self.conv2 = nn.Conv2d(z_dim, conv_dim, 3, stride=1, padding=1)
self.dec_block1 = nn.Sequential(
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
nn.LeakyReLU(),
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
nn.LeakyReLU(),
)
self.gamma_3 = nn.Parameter(0.001 * torch.ones((1, conv_dim, 1, 1)))
self.dec_block2 = nn.Sequential(
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
nn.LeakyReLU(),
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
nn.LeakyReLU(),
)
self.gamma_4 = nn.Parameter(0.001 * torch.ones((1, conv_dim, 1, 1)))
self.rec_conv = nn.Conv2d(conv_dim, input_shape[0], 3, stride=1, padding=1)
def forward_encoder(self, x):
x = self.relu(self.conv1(x))
x = x + self.gamma_1 * self.enc_block1(x)
x = self.pool(x)
x = x + self.gamma_2 * self.enc_block2(x)
x = self.pool(x)
logits = self.logit_conv(x)
return logits
def forward_decoder(self, logits):
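        # relaxed codebook lookup: temperature-scaled logits -> soft one-hot over
        # tokens -> convex combination of codebook embeddings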
soft_one_hot = F.softmax(logits * (self.temp*10), dim=1)
sampled = torch.einsum('bnhw,nd->bdhw', soft_one_hot, self.codebook.weight)
x = self.relu(self.conv2(sampled))
x = self.unpool(x)
x = x + self.gamma_3 * self.dec_block1(x)
x = self.unpool(x)
x = x + self.gamma_4 * self.dec_block2(x)
rec_feats = self.rec_conv(x)
return rec_feats, soft_one_hot
def forward(self, x):
logits = self.forward_encoder(x)
images_p, soft_one_hot = self.forward_decoder(logits)
return [logits, images_p]
class ModalityCompletion(nn.Module):
def __init__(self,
input_shape_hr=(2816, 16, 16),
input_shape_s2=(2816, 16, 16),
input_shape_s1=(2816, 16, 16),
conv_dim=256,
z_dim=256,
n_codebook=8192,
init_cfg=None
):
super(ModalityCompletion, self).__init__()
self.vae_hr = ConvVQVAEv2(input_shape=input_shape_hr, conv_dim=conv_dim, z_dim=z_dim, num_tokens=n_codebook)
self.vae_s2 = ConvVQVAEv2(input_shape=input_shape_s2, conv_dim=conv_dim, z_dim=z_dim, num_tokens=n_codebook)
self.vae_s1 = ConvVQVAEv2(input_shape=input_shape_s1, conv_dim=conv_dim, z_dim=z_dim, num_tokens=n_codebook)
self.kl_div_loss = torch.nn.KLDivLoss(reduction="none", log_target=True)
self.init_cfg=init_cfg
def init_weights(self):
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
# Suppress default init if use pretrained model.
from mmcls.utils import get_root_logger
from mmcv.runner import CheckpointLoader, load_state_dict
logger = get_root_logger()
checkpoint = CheckpointLoader.load_checkpoint(
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
load_state_dict(self, state_dict, strict=False, logger=logger)
else:
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def kl_loss(self, logits_hr, logits_s2, logits_s1, modality_info):
prob_hr = F.log_softmax(logits_hr, dim=1)
prob_s2 = F.log_softmax(logits_s2, dim=1)
prob_s1 = F.log_softmax(logits_s1, dim=1)
flag_hr = modality_info[:,0][:, None, None, None]
flag_s2 = modality_info[:,1][:, None, None, None]
flag_s1 = modality_info[:,2][:, None, None, None]
loss_hr_s2 = self.kl_div_loss(prob_hr, prob_s2) + self.kl_div_loss(prob_s2, prob_hr)
loss_hr_s2 = (loss_hr_s2 * flag_hr * flag_s2).sum((1, 2, 3)).mean()
loss_hr_s1 = self.kl_div_loss(prob_hr, prob_s1) + self.kl_div_loss(prob_s1, prob_hr)
loss_hr_s1 = (loss_hr_s1 * flag_hr * flag_s1).sum((1, 2, 3)).mean()
loss_s2_s1 = self.kl_div_loss(prob_s2, prob_s1) + self.kl_div_loss(prob_s1, prob_s2)
loss_s2_s1 = (loss_s2_s1 * flag_s2 * flag_s1).sum((1, 2, 3)).mean()
loss = (loss_hr_s2 + loss_hr_s1 + loss_s2_s1) / 6.0
return loss
def forward(self, feat_hr, feat_s2, feat_s1, modality_info):
        # Encode each modality independently:
        # (2816, 16, 16) feature map -> codebook logits at 1/4 spatial resolution,
        # then reconstruct missing modalities from the logits of the available ones.
B, C, H, W = feat_hr.shape
B_M, L_M = modality_info.shape
assert B == B_M, f'feat_hr batch: {B}, modality_info batch: {B_M}'
# quant, emb_loss, info
# hr input flow
logits_hr = self.vae_hr.forward_encoder(feat_hr)
logits_s2 = self.vae_s2.forward_encoder(feat_s2)
logits_s1 = self.vae_s1.forward_encoder(feat_s1)
modality_hr = modality_info[:,0]
modality_s2 = modality_info[:,1]
modality_s1 = modality_info[:,2]
flag_hr = modality_hr[:, None, None, None] # B => B, C, H, W
flag_s2 = modality_s2[:, None, None, None]
flag_s1 = modality_s1[:, None, None, None]
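        # for a missing modality (flag 0), substitute its codebook logits with the
        # sum of the available modalities' masked logits; the boolean flags zero
        # out the absent terms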
mean_logits_hr_s2 = logits_hr * flag_hr + logits_s2 * flag_s2
mean_logits_hr_s1 = logits_hr * flag_hr + logits_s1 * flag_s1
mean_logits_s1_s2 = logits_s1 * flag_s1 + logits_s2 * flag_s2
logits_hr_rec = logits_hr * flag_hr + mean_logits_s1_s2 * (~flag_hr)
logits_s2_rec = logits_s2 * flag_s2 + mean_logits_hr_s1 * (~flag_s2)
logits_s1_rec = logits_s1 * flag_s1 + mean_logits_hr_s2 * (~flag_s1)
g_hr, soft_one_hot_hr = self.vae_hr.forward_decoder(logits_hr_rec)
g_s2, soft_one_s2 = self.vae_s2.forward_decoder(logits_s2_rec)
g_s1, soft_one_s1 = self.vae_s1.forward_decoder(logits_s1_rec)
hr_out = feat_hr * flag_hr + g_hr * (~flag_hr)
s2_out = feat_s2 * flag_s2 + g_s2 * (~flag_s2)
s1_out = feat_s1 * flag_s1 + g_s1 * (~flag_s1)
output = {}
output['hr_out'] = hr_out
output['s2_out'] = s2_out
output['s1_out'] = s1_out
output['modality_info'] = modality_info
output['input_hr'] = feat_hr
output['input_s2'] = feat_s2
output['input_s1'] = feat_s1
output['logits_hr'] = logits_hr
output['logits_s2'] = logits_s2
output['logits_s1'] = logits_s1
output['soft_one_hot_hr'] = soft_one_hot_hr
output['soft_one_hot_s2'] = soft_one_s2
output['soft_one_hot_s1'] = soft_one_s1
output['g_hr'] = g_hr
output['g_s2'] = g_s2
output['g_s1'] = g_s1
output['loss_quant'] = self.kl_loss(logits_hr, logits_s2, logits_s1, modality_info)
return output
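
# Hedged smoke test with toy sizes (the real model uses (2816, 16, 16) features
# and an 8192-entry codebook): it shows that a missing modality's features are
# replaced by a decoded reconstruction while present modalities pass through.
if __name__ == "__main__":
    net = ModalityCompletion(input_shape_hr=(8, 16, 16),
                             input_shape_s2=(8, 16, 16),
                             input_shape_s1=(8, 16, 16),
                             conv_dim=16, z_dim=16, n_codebook=32)
    net.init_weights()
    feat_hr, feat_s2, feat_s1 = (torch.randn(2, 8, 16, 16) for _ in range(3))
    flags = torch.tensor([[True, True, False], [True, False, True]])
    out = net(feat_hr, feat_s2, feat_s1, flags)
    print(out["s1_out"].shape, out["loss_quant"].item())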

View File

@@ -0,0 +1,144 @@
# Copyright (c) Ant Group. All rights reserved.
from collections import OrderedDict
import torch
import torch.nn as nn
from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
trunc_normal_)
from mmcv.runner import (CheckpointLoader, load_state_dict)
from torch.nn.modules.batchnorm import _BatchNorm
from mmseg.models.backbones.vit import TransformerEncoderLayer
from mmseg.utils import get_root_logger
class TransformerEncoder(nn.Module):
def __init__(self,
input_dims=768,
embed_dims=768,
num_layers=4,
num_heads=12,
mlp_ratio=4,
qkv_bias=True,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
with_cls_token=True,
output_cls_token=False,
norm_cfg=dict(type='LN'),
act_cfg=dict(type='GELU'),
num_fcs=2,
norm_eval=False,
with_cp=False,
init_cfg=None,
*args,
**kwargs):
super(TransformerEncoder, self).__init__()
self.porj_linear = nn.Linear(input_dims, embed_dims)
if output_cls_token:
            assert with_cls_token is True, f'with_cls_token must be True ' \
                f'if output_cls_token is set to True, but got {with_cls_token}'
self.init_cfg = init_cfg
self.norm_eval = norm_eval
self.with_cp = with_cp
self.with_cls_token = with_cls_token
self.output_cls_token = output_cls_token
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
self.drop_after_pos = nn.Dropout(p=drop_rate)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, num_layers)
] # stochastic depth decay rule
self.layers = nn.ModuleList()
for i in range(num_layers):
self.layers.append(
TransformerEncoderLayer(embed_dims=embed_dims,
num_heads=num_heads,
feedforward_channels=mlp_ratio *
embed_dims,
attn_drop_rate=attn_drop_rate,
drop_rate=drop_rate,
drop_path_rate=dpr[i],
num_fcs=num_fcs,
qkv_bias=qkv_bias,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
with_cp=with_cp,
batch_first=True))
def init_weights(self):
if (isinstance(self.init_cfg, dict)
and self.init_cfg.get('type') == 'Pretrained'):
logger = get_root_logger()
checkpoint = CheckpointLoader.load_checkpoint(
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
if 'state_dict' in checkpoint:
_state_dict = checkpoint['state_dict']
else:
_state_dict = checkpoint
state_dict = OrderedDict()
for k, v in _state_dict.items():
if k.startswith('backbone.'):
state_dict[k[9:]] = v
else:
state_dict[k] = v
load_state_dict(self, state_dict, strict=False, logger=logger)
else:
# We only implement the 'jax_impl' initialization implemented at
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
trunc_normal_(self.cls_token, std=.02)
for n, m in self.named_modules():
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
if 'ffn' in n:
nn.init.normal_(m.bias, mean=0., std=1e-6)
else:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv2d):
kaiming_init(m, mode='fan_in', bias=0.)
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
constant_init(m, val=1.0, bias=0.)
def forward(self, inputs, require_feat: bool = False, require_two: bool = False):
inputs = self.porj_linear(inputs)
B, N, C = inputs.shape
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, inputs), dim=1)
if not self.with_cls_token:
# Remove class token for transformer encoder input
x = x[:, 1:]
# add hidden and atten state
block_outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
if require_feat:
block_outs.append(x)
if self.output_cls_token:
if require_two:
x = x[:, :2]
else:
x = x[:, 0]
elif not self.output_cls_token and self.with_cls_token:
            pass  # keep the full token sequence, including the cls token
if require_feat:
return x, block_outs
else:
return x
def train(self, mode=True):
super(TransformerEncoder, self).train(mode)
if mode and self.norm_eval:
for m in self.modules():
if isinstance(m, nn.LayerNorm):
m.eval()
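
# Hedged smoke test with toy sizes: project a short token sequence, run it
# through two encoder layers and read back the cls token used for fusion.
if __name__ == "__main__":
    enc = TransformerEncoder(input_dims=32, embed_dims=32, num_layers=2,
                             num_heads=4, with_cls_token=True,
                             output_cls_token=True)
    enc.init_weights()
    tokens = torch.randn(6, 10, 32)   # (B * H * W, sequence length, input_dims)
    cls_token = enc(tokens)
    print(cls_token.shape)            # -> torch.Size([6, 32])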

View File

@@ -0,0 +1,4 @@
# Copyright (c) Ant Financial Service Group and its affiliates.
from .skysense_pp_pipeline import SkySensePP
__all__ = ['SkySensePP']

View File

@@ -0,0 +1,458 @@
# coding: utf-8
# Copyright (c) Ant Group. All rights reserved.
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import math
import random
from antmmf.common.registry import registry
from antmmf.models.base_model import BaseModel
from lib.models.backbones import build_backbone
from lib.models.necks import build_neck
from lib.models.heads import build_head
from lib.utils.utils import LayerDecayValueAssigner
@registry.register_model("SkySensePP")
class SkySensePP(BaseModel):
def __init__(self, config):
super().__init__(config)
self.config = config
self.sources = config.sources
assert len(self.sources) > 0, 'at least one data source is required'
if 's2' in self.sources:
self.use_ctpe = config.use_ctpe
self.use_modal_vae = config.use_modal_vae
self.use_cls_token_uper_head = config.use_cls_token_uper_head
self.target_mean=[0.485, 0.456, 0.406]
self.target_std=[0.229, 0.224, 0.225]
self.vocabulary_size = config.vocabulary_size
self.vocabulary = list(range(1, config.vocabulary_size + 1)) # 0 for ignore
def build(self):
if 'hr' in self.sources:
self.backbone_hr = self._build_backbone('hr')
if 's2' in self.sources:
self.backbone_s2 = self._build_backbone('s2')
if self.use_ctpe:
self.ctpe = nn.Parameter(
torch.zeros(1, self.config.calendar_time,
self.config.necks.input_dims))
if 'head_s2' in self.config.keys():
self.head_s2 = self._build_head('head_s2')
self.fusion = self._build_neck('necks')
if 's1' in self.sources:
self.backbone_s1 = self._build_backbone('s1')
if 'head_s1' in self.config.keys():
self.head_s1 = self._build_head('head_s1')
self.head_rec_hr = self._build_head('rec_head_hr')
self.with_aux_head = False
if self.use_modal_vae:
self.modality_vae = self._build_neck('modality_vae')
if 'auxiliary_head' in self.config.keys():
self.with_aux_head = True
self.aux_head = self._build_head('auxiliary_head')
        if ('init_cfg' in self.config.keys()
                and self.config.init_cfg is not None
                and self.config.init_cfg.checkpoint is not None
                and self.config.init_cfg.key is not None):
            self.load_pretrained(self.config.init_cfg.checkpoint,
                                 self.config.init_cfg.key)
def _build_backbone(self, key):
config_dict = self.config[f'backbone_{key}'].to_dict()
backbone_type = config_dict.pop('type')
backbone = build_backbone(backbone_type, **config_dict)
backbone.init_weights()
return backbone
def _build_neck(self, key):
config_dict = self.config[key].to_dict()
neck_type = config_dict.pop('type')
neck = build_neck(neck_type, **config_dict)
neck.init_weights()
return neck
def _build_head(self, key):
head_config = self.config[key].to_dict()
head_type = head_config.pop('type')
head = build_head(head_type, **head_config)
return head
def get_optimizer_parameters(self, config):
optimizer_grouped_parameters = [
{
"params": [],
"lr": config.optimizer_attributes.params.lr,
"weight_decay": config.optimizer_attributes.params.weight_decay,
},
{
"params": [],
"lr": config.optimizer_attributes.params.lr,
"weight_decay": 0.0,
},
]
layer_decay_value_assigner_hr = LayerDecayValueAssigner(
config.lr_parameters.layer_decay, None,
config.optimizer_attributes.params.lr, 'swin',
config.model_attributes.SkySensePP.backbone_hr.arch
)
layer_decay_value_assigner_s2 = LayerDecayValueAssigner(
config.lr_parameters.layer_decay, 24,
config.optimizer_attributes.params.lr, 'vit',
)
layer_decay_value_assigner_s1 = LayerDecayValueAssigner(
config.lr_parameters.layer_decay, 24,
config.optimizer_attributes.params.lr, 'vit',
)
layer_decay_value_assigner_fusion = LayerDecayValueAssigner(
config.lr_parameters.layer_decay, 24,
config.optimizer_attributes.params.lr, 'vit',
)
num_frozen_params = 0
if 'hr' in self.sources:
print('hr'.center(60, '-'))
num_frozen_params += layer_decay_value_assigner_hr.fix_param(
self.backbone_hr,
config.lr_parameters.frozen_blocks,
)
optimizer_grouped_parameters.extend(
layer_decay_value_assigner_hr.get_parameter_groups(
self.backbone_hr, config.optimizer_attributes.params.weight_decay
)
)
if 's2' in self.sources:
print('s2'.center(60, '-'))
num_frozen_params += layer_decay_value_assigner_s2.fix_param(
self.backbone_s2,
config.lr_parameters.frozen_blocks,
)
optimizer_grouped_parameters.extend(
layer_decay_value_assigner_s2.get_parameter_groups(
self.backbone_s2, config.optimizer_attributes.params.weight_decay
)
)
no_decay = [".bn.", "bias"]
optimizer_grouped_parameters[0]["params"] += [
p for n, p in self.head_s2.named_parameters()
if not any(nd in n for nd in no_decay)
]
optimizer_grouped_parameters[1]["params"] += [
p for n, p in self.head_s2.named_parameters()
if any(nd in n for nd in no_decay)
]
if self.use_ctpe:
optimizer_grouped_parameters[1]["params"] += [self.ctpe]
if 's1' in self.sources:
print('s1'.center(60, '-'))
num_frozen_params += layer_decay_value_assigner_s1.fix_param(
self.backbone_s1,
config.lr_parameters.frozen_blocks,
)
optimizer_grouped_parameters.extend(
layer_decay_value_assigner_s1.get_parameter_groups(
self.backbone_s1, config.optimizer_attributes.params.weight_decay
)
)
no_decay = [".bn.", "bias"]
optimizer_grouped_parameters[0]["params"] += [
p for n, p in self.head_s1.named_parameters()
if not any(nd in n for nd in no_decay)
]
optimizer_grouped_parameters[1]["params"] += [
p for n, p in self.head_s1.named_parameters()
if any(nd in n for nd in no_decay)
]
if len(self.sources) > 1:
print('fusion'.center(60, '-'))
num_frozen_params += layer_decay_value_assigner_fusion.fix_param_deeper(
self.fusion,
                config.lr_parameters.frozen_fusion_blocks_start,  # freeze all fusion stages from this index onwards
)
optimizer_grouped_parameters.extend(
layer_decay_value_assigner_fusion.get_parameter_groups(
self.fusion, config.optimizer_attributes.params.weight_decay
)
)
if self.use_modal_vae:
no_decay = [".bn.", "bias"]
optimizer_grouped_parameters[0]["params"] += [
p for n, p in self.modality_vae.named_parameters()
if not any(nd in n for nd in no_decay)
]
optimizer_grouped_parameters[1]["params"] += [
p for n, p in self.modality_vae.named_parameters()
if any(nd in n for nd in no_decay)
]
no_decay = [".bn.", "bias"]
optimizer_grouped_parameters[0]["params"] += [
p for n, p in self.head_rec_hr.named_parameters()
if not any(nd in n for nd in no_decay)
]
optimizer_grouped_parameters[1]["params"] += [
p for n, p in self.head_rec_hr.named_parameters()
if any(nd in n for nd in no_decay)
]
if self.with_aux_head:
no_decay = [".bn.", "bias"]
optimizer_grouped_parameters[0]["params"] += [
p for n, p in self.aux_head.named_parameters()
if not any(nd in n for nd in no_decay)
]
optimizer_grouped_parameters[1]["params"] += [
p for n, p in self.aux_head.named_parameters()
if any(nd in n for nd in no_decay)
]
num_params = [len(x['params']) for x in optimizer_grouped_parameters]
print(len(list(self.parameters())), sum(num_params), num_frozen_params)
assert len(list(self.parameters())) == sum(num_params) + num_frozen_params
return optimizer_grouped_parameters
def get_custom_scheduler(self, trainer):
optimizer = trainer.optimizer
num_training_steps = trainer.config.training_parameters.max_iterations
num_warmup_steps = trainer.config.training_parameters.num_warmup_steps
if "train" in trainer.run_type:
if num_training_steps == math.inf:
epoches = trainer.config.training_parameters.max_epochs
assert epoches != math.inf
num_training_steps = trainer.config.training_parameters.max_epochs * trainer.epoch_iterations
            def linear_with_warm_up(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(
1, num_warmup_steps))
return max(
0.0,
float(num_training_steps - current_step) /
float(max(1, num_training_steps - num_warmup_steps)),
)
            def cos_with_warm_up(current_step):
num_cycles = 0.5
if current_step < num_warmup_steps:
return float(current_step) / float(max(
1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(
max(1, num_training_steps - num_warmup_steps))
return max(
0.0,
0.5 *
(1.0 +
math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
)
            lr_lambda = cos_with_warm_up if trainer.config.training_parameters.cos_lr else linear_with_warm_up
else:
def lr_lambda(current_step):
return 0.0 # noqa
return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, -1)
def convert_target(self, target):
mean = target.new_tensor(self.target_mean).reshape(1, 3, 1, 1)
std = target.new_tensor(self.target_std).reshape(1, 3, 1, 1)
target = ((target * std + mean)*255).to(torch.long)
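        # pack the de-normalised RGB triplet into a single integer id: R * 256**2 + G * 256 + B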
target[:, 0] = target[:, 0] * 256 * 256
target[:, 1] = target[:, 1] * 256
target = target.sum(1).type(torch.long)
unique_target = target.unique()
target_index = torch.searchsorted(unique_target, target)
no_bg = False
if unique_target[0].item() > 0:
target_index += 1
no_bg = True
target_index_unique = target_index.unique().tolist()
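        # randomly re-assign vocabulary tokens to the colours present in this
        # batch; index 0 stays reserved for background/ignore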
random.shuffle(self.vocabulary)
value = target.new_tensor([0] + self.vocabulary)
mapped_target = target_index.clone()
idx_2_color = {}
for v in target_index_unique:
mapped_target[target_index == v] = value[v]
idx_2_color[value[v].item()] = unique_target[v - 1 if no_bg else v].item()
return mapped_target, idx_2_color
def forward(self, sample_list):
output = dict()
modality_flag_hr = sample_list["modality_flag_hr"]
modality_flag_s2 = sample_list["modality_flag_s2"]
modality_flag_s1 = sample_list["modality_flag_s1"]
modalities = [modality_flag_hr, modality_flag_s2, modality_flag_s1]
modalities = torch.tensor(modalities).permute(1,0).contiguous() # L, B => B, L
anno_img = sample_list["targets"]
anno_img, idx_2_color = self.convert_target(anno_img)
output["mapped_targets"] = anno_img
output["idx_2_color"] = idx_2_color
anno_mask = sample_list["anno_mask"]
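        # sub-sample the HR annotation to the 1/32 feature grid (one label per
        # 32x32 block, taken near the block centre); S1 shares the same grid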
anno_s2 = anno_img[:, 15::32, 15::32]
anno_s1 = anno_s2
output["anno_hr"] = anno_img
output["anno_s2"] = anno_s2
### 1. backbone
if 'hr' in self.sources:
hr_img = sample_list["hr_img"]
B_MASK, H_MASK, W_MASK = anno_mask.shape
block_size = 32
anno_mask_hr = anno_mask.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 1, block_size, block_size)
anno_mask_hr = anno_mask_hr.permute(0, 1, 3, 2, 4).reshape(B_MASK, H_MASK*block_size, W_MASK*block_size).contiguous()
B, C_G, H_G, W_G = hr_img.shape
hr_features = self.backbone_hr(hr_img, anno_img, anno_mask_hr)
output['mask_hr'] = anno_mask_hr
output['target_hr'] = anno_img
if 's2' in self.sources:
s2_img = sample_list["s2_img"]
B, C_S2, S_S2, H_S2, W_S2 = s2_img.shape
            s2_img = s2_img.permute(0, 2, 1, 3, 4).reshape(
                B * S_S2, C_S2, H_S2, W_S2).contiguous()  # fold the time steps into the batch dimension
anno_mask_s2 = anno_mask
s2_features = self.backbone_s2(s2_img, anno_s2, anno_mask_s2)
if 'head_s2' in self.config.keys():
s2_features = self.head_s2(s2_features[-1])
s2_features = [s2_features]
if 's1' in self.sources:
s1_img = sample_list["s1_img"]
B, C_S1, S_S1, H_S1, W_S1 = s1_img.shape
s1_img = s1_img.permute(0, 2, 1, 3,
4).reshape(B * S_S1, C_S1, H_S1, W_S1).contiguous()
anno_mask_s1 = anno_mask
s1_features = self.backbone_s1(s1_img, anno_s1, anno_mask_s1)
if 'head_s1' in self.config.keys():
s1_features = self.head_s1(s1_features[-1])
s1_features = [s1_features]
### 2. prepare features for fusion
hr_features_stage3 = hr_features[-1]
s2_features_stage3 = s2_features[-1]
s1_features_stage3 = s1_features[-1]
modalities = modalities.to(hr_features_stage3.device)
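        # if enabled, reconstruct the stage-3 features of missing modalities from
        # the available ones before fusion (see ModalityCompletion)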
if self.use_modal_vae:
vae_out = self.modality_vae(hr_features_stage3, s2_features_stage3, s1_features_stage3, modalities)
hr_features_stage3 = vae_out['hr_out']
s2_features_stage3 = vae_out['s2_out']
s1_features_stage3 = vae_out['s1_out']
output['vae_out'] = vae_out
features_stage3 = []
if 'hr' in self.sources:
B, C3_G, H3_G, W3_G = hr_features_stage3.shape
hr_features_stage3 = hr_features_stage3.permute(
0, 2, 3, 1).reshape(B * H3_G * W3_G, C3_G).unsqueeze(1).contiguous() # B * H3_G * W3_G, 1, C3_G
features_stage3 = hr_features_stage3
if 's2' in self.sources:
# s2_features_stage3 = s2_features[-1]
_, C3_S2, H3_S2, W3_S2 = s2_features_stage3.shape
s2_features_stage3 = s2_features_stage3.reshape(
B, S_S2, C3_S2, H3_S2,
W3_S2).permute(0, 3, 4, 1, 2).reshape(B, H3_S2 * W3_S2, S_S2,
C3_S2).contiguous()
if self.use_ctpe:
ct_index = sample_list["s2_ct"]
ctpe = self.ctpe[:, ct_index, :].contiguous().permute(1, 0, 2, 3).contiguous()
ctpe = ctpe.expand(-1, 256, -1, -1)
ct_index_2 = sample_list["s2_ct2"]
ctpe2 = self.ctpe[:, ct_index_2, :].contiguous().permute(1, 0, 2, 3).contiguous()
ctpe2 = ctpe2.expand(-1, 256, -1, -1)
ctpe_comb = torch.cat([ctpe, ctpe2], 1)
s2_features_stage3 = (s2_features_stage3 + ctpe_comb).reshape(
B * H3_S2 * W3_S2, S_S2, C3_S2).contiguous()
else:
s2_features_stage3 = s2_features_stage3.reshape(
B * H3_S2 * W3_S2, S_S2, C3_S2).contiguous()
if len(features_stage3) > 0:
assert H3_G == H3_S2 and W3_G == W3_S2 and C3_G == C3_S2
features_stage3 = torch.cat((features_stage3, s2_features_stage3), dim=1)
else:
features_stage3 = s2_features_stage3
if 's1' in self.sources:
# s1_features_stage3 = s1_features[-1]
_, C3_S1, H3_S1, W3_S1 = s1_features_stage3.shape
s1_features_stage3 = s1_features_stage3.reshape(
B, S_S1, C3_S1, H3_S1,
W3_S1).permute(0, 3, 4, 1, 2).reshape(B, H3_S1 * W3_S1, S_S1,
C3_S1).contiguous()
s1_features_stage3 = s1_features_stage3.reshape(
B * H3_S1 * W3_S1, S_S1, C3_S1).contiguous()
if len(features_stage3) > 0:
assert H3_S1 == H3_S2 and W3_S1 == W3_S2 and C3_S1 == C3_S2
features_stage3 = torch.cat((features_stage3, s1_features_stage3),
dim=1)
else:
features_stage3 = s1_features_stage3
### 3. fusion
if self.config.necks.output_cls_token:
if self.config.necks.get('require_feat', False):
cls_token, block_outs = self.fusion(features_stage3 , True)
else:
cls_token = self.fusion(features_stage3)
_, C3_G = cls_token.shape
cls_token = cls_token.reshape(B, H3_G, W3_G,
C3_G).contiguous().permute(0, 3, 1, 2).contiguous() # b, c, h, w
else:
assert self.config.necks.with_cls_token is False
if self.config.necks.get('require_feat', False):
features_stage3, block_outs = self.fusion(features_stage3, True)
else:
features_stage3 = self.fusion(features_stage3)
features_stage3 = features_stage3.reshape(
B, H3_S2, W3_S2, S_S2,
C3_S2).permute(0, 3, 4, 1, 2).reshape(B * S_S2, C3_S2, H3_S2,
W3_S2).contiguous()
### 4. decoder for rec
hr_rec_inputs = hr_features
feat_stage1 = hr_rec_inputs[0]
if feat_stage1.shape[-1] == feat_stage1.shape[-2]:
feat_stage1_left, feat_stage1_right = torch.split(feat_stage1, feat_stage1.shape[-1] // 2, dim=-1)
feat_stage1 = torch.cat((feat_stage1_left, feat_stage1_right), dim=1)
hr_rec_inputs = list(hr_features)
hr_rec_inputs[0] = feat_stage1
rec_feats = [*hr_rec_inputs, cls_token]
logits_hr = self.head_rec_hr(rec_feats)
if self.config.get('upsacle_results', True):
logits_hr = logits_hr.to(torch.float32)
logits_hr = F.interpolate(logits_hr, scale_factor=4, mode='bilinear', align_corners=True)
output["logits_hr"] = logits_hr
return output
def load_pretrained(self, ckpt_path, key):
pretrained_dict = torch.load(ckpt_path, map_location={'cuda:0': 'cpu'})
pretrained_dict = pretrained_dict[key]
for k, v in pretrained_dict.items():
if k == 'backbone_s2.patch_embed.projection.weight':
pretrained_in_channels = v.shape[1]
if self.config.backbone_s2.in_channels == 4:
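                    # keep only a 4-channel subset of the pretrained patch-embed
                    # weights and rescale so the expected activation magnitude is
                    # preserved (the selected channel indices are config-specific)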
new_weight = v[:, [0, 1, 2, 6]]
new_weight = new_weight * (
pretrained_in_channels /
self.config.backbone_s2.in_channels)
pretrained_dict[k] = new_weight
missing_keys, unexpected_keys = self.load_state_dict(pretrained_dict,
strict=False)
print('missing_keys:', missing_keys)
print('unexpected_keys:', unexpected_keys)