init
This commit is contained in:
7
lib/models/__init__.py
Normal file
7
lib/models/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .segmentors import SkySensePP
|
||||
from .losses import (ModalityVAELoss, RecLoss)
|
||||
from .metrics import (SemMetric)
|
||||
|
||||
__all__ = [
|
||||
'SkySensePP', 'ModalityVAELoss', 'RecLoss', 'SemMetric'
|
||||
]
|
||||
14
lib/models/backbones/__init__.py
Normal file
14
lib/models/backbones/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from .swin_v2 import SwinTransformerV2MSL
|
||||
from .vit import VisionTransformerMSL
|
||||
|
||||
__all__ = [
|
||||
'SwinTransformerV2MSL', 'VisionTransformerMSL'
|
||||
]
|
||||
|
||||
type_mapping = {
|
||||
'SwinTransformerV2MSL': SwinTransformerV2MSL,
|
||||
'VisionTransformerMSL': VisionTransformerMSL
|
||||
}
|
||||
|
||||
def build_backbone(type, **kwargs):
|
||||
return type_mapping[type](**kwargs)
|
||||
702
lib/models/backbones/swin_v2.py
Normal file
702
lib/models/backbones/swin_v2.py
Normal file
@@ -0,0 +1,702 @@
|
||||
from copy import deepcopy
|
||||
from typing import Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
|
||||
from mmcv.cnn.utils.weight_init import trunc_normal_
|
||||
from mmcv.runner.base_module import BaseModule, ModuleList
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.models.utils import (PatchMerging, ShiftWindowMSA, WindowMSAV2,
|
||||
resize_pos_embed, to_2tuple)
|
||||
from mmcls.models.backbones.base_backbone import BaseBackbone
|
||||
from mmcv.runner import (CheckpointLoader,
|
||||
load_state_dict)
|
||||
from mmcv.cnn.bricks.transformer import MultiheadAttention
|
||||
|
||||
class SwinBlockV2(BaseModule):
|
||||
"""Swin Transformer V2 block. Use post normalization.
|
||||
|
||||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): The height and width of the window. Defaults to 7.
|
||||
shift (bool): Shift the attention window or not. Defaults to False.
|
||||
extra_norm (bool): Whether add extra norm at the end of main branch.
|
||||
ffn_ratio (float): The expansion ratio of feedforward network hidden
|
||||
layer channels. Defaults to 4.
|
||||
drop_path (float): The drop path rate after attention and ffn.
|
||||
Defaults to 0.
|
||||
pad_small_map (bool): If True, pad the small feature map to the window
|
||||
size, which is common used in detection and segmentation. If False,
|
||||
avoid shifting window and shrink the window size to the size of
|
||||
feature map, which is common used in classification.
|
||||
Defaults to False.
|
||||
attn_cfgs (dict): The extra config of Shift Window-MSA.
|
||||
Defaults to empty dict.
|
||||
ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
|
||||
norm_cfg (dict): The config of norm layers.
|
||||
Defaults to ``dict(type='LN')``.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Defaults to False.
|
||||
pretrained_window_size (int): Window size in pretrained.
|
||||
init_cfg (dict, optional): The extra config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
window_size=8,
|
||||
shift=False,
|
||||
extra_norm=False,
|
||||
ffn_ratio=4.,
|
||||
drop_path=0.,
|
||||
pad_small_map=False,
|
||||
attn_cfgs=dict(),
|
||||
ffn_cfgs=dict(),
|
||||
norm_cfg=dict(type='LN'),
|
||||
with_cp=False,
|
||||
pretrained_window_size=0,
|
||||
init_cfg=None):
|
||||
|
||||
super(SwinBlockV2, self).__init__(init_cfg)
|
||||
self.with_cp = with_cp
|
||||
self.extra_norm = extra_norm
|
||||
|
||||
_attn_cfgs = {
|
||||
'embed_dims': embed_dims,
|
||||
'num_heads': num_heads,
|
||||
'shift_size': window_size // 2 if shift else 0,
|
||||
'window_size': window_size,
|
||||
'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
|
||||
'pad_small_map': pad_small_map,
|
||||
**attn_cfgs
|
||||
}
|
||||
# use V2 attention implementation
|
||||
_attn_cfgs.update(
|
||||
window_msa=WindowMSAV2,
|
||||
msa_cfg=dict(
|
||||
pretrained_window_size=to_2tuple(pretrained_window_size)))
|
||||
self.attn = ShiftWindowMSA(**_attn_cfgs)
|
||||
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
|
||||
_ffn_cfgs = {
|
||||
'embed_dims': embed_dims,
|
||||
'feedforward_channels': int(embed_dims * ffn_ratio),
|
||||
'num_fcs': 2,
|
||||
'ffn_drop': 0,
|
||||
'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
|
||||
'act_cfg': dict(type='GELU'),
|
||||
'add_identity': False,
|
||||
**ffn_cfgs
|
||||
}
|
||||
self.ffn = FFN(**_ffn_cfgs)
|
||||
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
|
||||
# add extra norm for every n blocks in huge and giant model
|
||||
if self.extra_norm:
|
||||
self.norm3 = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
|
||||
def forward(self, x, hw_shape):
|
||||
|
||||
def _inner_forward(x):
|
||||
# Use post normalization
|
||||
identity = x
|
||||
x = self.attn(x, hw_shape)
|
||||
x = self.norm1(x)
|
||||
x = x + identity
|
||||
|
||||
identity = x
|
||||
x = self.ffn(x)
|
||||
x = self.norm2(x)
|
||||
x = x + identity
|
||||
|
||||
if self.extra_norm:
|
||||
x = self.norm3(x)
|
||||
|
||||
return x
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
x = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
x = _inner_forward(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SwinBlockV2Sequence(BaseModule):
|
||||
"""Module with successive Swin Transformer blocks and downsample layer.
|
||||
|
||||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
depth (int): Number of successive swin transformer blocks.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): The height and width of the window. Defaults to 7.
|
||||
downsample (bool): Downsample the output of blocks by patch merging.
|
||||
Defaults to False.
|
||||
downsample_cfg (dict): The extra config of the patch merging layer.
|
||||
Defaults to empty dict.
|
||||
drop_paths (Sequence[float] | float): The drop path rate in each block.
|
||||
Defaults to 0.
|
||||
block_cfgs (Sequence[dict] | dict): The extra config of each block.
|
||||
Defaults to empty dicts.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Defaults to False.
|
||||
pad_small_map (bool): If True, pad the small feature map to the window
|
||||
size, which is common used in detection and segmentation. If False,
|
||||
avoid shifting window and shrink the window size to the size of
|
||||
feature map, which is common used in classification.
|
||||
Defaults to False.
|
||||
extra_norm_every_n_blocks (int): Add extra norm at the end of main
|
||||
branch every n blocks. Defaults to 0, which means no needs for
|
||||
extra norm layer.
|
||||
pretrained_window_size (int): Window size in pretrained.
|
||||
init_cfg (dict, optional): The extra config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
depth,
|
||||
num_heads,
|
||||
window_size=8,
|
||||
downsample=False,
|
||||
downsample_cfg=dict(),
|
||||
drop_paths=0.,
|
||||
block_cfgs=dict(),
|
||||
with_cp=False,
|
||||
pad_small_map=False,
|
||||
extra_norm_every_n_blocks=0,
|
||||
pretrained_window_size=0,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
if not isinstance(drop_paths, Sequence):
|
||||
drop_paths = [drop_paths] * depth
|
||||
|
||||
if not isinstance(block_cfgs, Sequence):
|
||||
block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)]
|
||||
|
||||
if downsample:
|
||||
self.out_channels = 2 * embed_dims
|
||||
_downsample_cfg = {
|
||||
'in_channels': embed_dims,
|
||||
'out_channels': self.out_channels,
|
||||
'norm_cfg': dict(type='LN'),
|
||||
**downsample_cfg
|
||||
}
|
||||
self.downsample = PatchMerging(**_downsample_cfg)
|
||||
else:
|
||||
self.out_channels = embed_dims
|
||||
self.downsample = None
|
||||
|
||||
self.blocks = ModuleList()
|
||||
for i in range(depth):
|
||||
extra_norm = True if extra_norm_every_n_blocks and \
|
||||
(i + 1) % extra_norm_every_n_blocks == 0 else False
|
||||
_block_cfg = {
|
||||
'embed_dims': self.out_channels,
|
||||
'num_heads': num_heads,
|
||||
'window_size': window_size,
|
||||
'shift': False if i % 2 == 0 else True,
|
||||
'extra_norm': extra_norm,
|
||||
'drop_path': drop_paths[i],
|
||||
'with_cp': with_cp,
|
||||
'pad_small_map': pad_small_map,
|
||||
'pretrained_window_size': pretrained_window_size,
|
||||
**block_cfgs[i]
|
||||
}
|
||||
block = SwinBlockV2(**_block_cfg)
|
||||
self.blocks.append(block)
|
||||
|
||||
def forward(self, x, in_shape):
|
||||
if self.downsample:
|
||||
x, out_shape = self.downsample(x, in_shape)
|
||||
else:
|
||||
out_shape = in_shape
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, out_shape)
|
||||
|
||||
return x, out_shape
|
||||
|
||||
|
||||
class SwinTransformerV2(BaseBackbone):
|
||||
"""Swin Transformer V2.
|
||||
|
||||
A PyTorch implement of : `Swin Transformer V2:
|
||||
Scaling Up Capacity and Resolution
|
||||
<https://arxiv.org/abs/2111.09883>`_
|
||||
|
||||
Inspiration from
|
||||
https://github.com/microsoft/Swin-Transformer
|
||||
|
||||
Args:
|
||||
arch (str | dict): Swin Transformer architecture. If use string, choose
|
||||
from 'tiny', 'small', 'base' and 'large'. If use dict, it should
|
||||
have below keys:
|
||||
|
||||
- **embed_dims** (int): The dimensions of embedding.
|
||||
- **depths** (List[int]): The number of blocks in each stage.
|
||||
- **num_heads** (List[int]): The number of heads in attention
|
||||
modules of each stage.
|
||||
- **extra_norm_every_n_blocks** (int): Add extra norm at the end
|
||||
of main branch every n blocks.
|
||||
|
||||
Defaults to 'tiny'.
|
||||
img_size (int | tuple): The expected input image shape. Because we
|
||||
support dynamic input shape, just set the argument to the most
|
||||
common input image shape. Defaults to 224.
|
||||
patch_size (int | tuple): The patch size in patch embedding.
|
||||
Defaults to 4.
|
||||
in_channels (int): The num of input channels. Defaults to 3.
|
||||
window_size (int | Sequence): The height and width of the window.
|
||||
Defaults to 7.
|
||||
drop_rate (float): Dropout rate after embedding. Defaults to 0.
|
||||
drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
|
||||
use_abs_pos_embed (bool): If True, add absolute position embedding to
|
||||
the patch embedding. Defaults to False.
|
||||
interpolate_mode (str): Select the interpolate mode for absolute
|
||||
position embeding vector resize. Defaults to "bicubic".
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Defaults to False.
|
||||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
||||
-1 means not freezing any parameters. Defaults to -1.
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Defaults to False.
|
||||
pad_small_map (bool): If True, pad the small feature map to the window
|
||||
size, which is common used in detection and segmentation. If False,
|
||||
avoid shifting window and shrink the window size to the size of
|
||||
feature map, which is common used in classification.
|
||||
Defaults to False.
|
||||
norm_cfg (dict): Config dict for normalization layer for all output
|
||||
features. Defaults to ``dict(type='LN')``
|
||||
stage_cfgs (Sequence[dict] | dict): Extra config dict for each
|
||||
stage. Defaults to an empty dict.
|
||||
patch_cfg (dict): Extra config dict for patch embedding.
|
||||
Defaults to an empty dict.
|
||||
pretrained_window_sizes (tuple(int)): Pretrained window sizes of
|
||||
each layer.
|
||||
init_cfg (dict, optional): The Config for initialization.
|
||||
Defaults to None.
|
||||
|
||||
Examples:
|
||||
>>> from mmcls.models import SwinTransformerV2
|
||||
>>> import torch
|
||||
>>> extra_config = dict(
|
||||
>>> arch='tiny',
|
||||
>>> stage_cfgs=dict(downsample_cfg={'kernel_size': 3,
|
||||
>>> 'padding': 'same'}))
|
||||
>>> self = SwinTransformerV2(**extra_config)
|
||||
>>> inputs = torch.rand(1, 3, 224, 224)
|
||||
>>> output = self.forward(inputs)
|
||||
>>> print(output.shape)
|
||||
(1, 2592, 4)
|
||||
"""
|
||||
arch_zoo = {
|
||||
**dict.fromkeys(['t', 'tiny'],
|
||||
{'embed_dims': 96,
|
||||
'depths': [2, 2, 6, 2],
|
||||
'num_heads': [3, 6, 12, 24],
|
||||
'extra_norm_every_n_blocks': 0}),
|
||||
**dict.fromkeys(['s', 'small'],
|
||||
{'embed_dims': 96,
|
||||
'depths': [2, 2, 18, 2],
|
||||
'num_heads': [3, 6, 12, 24],
|
||||
'extra_norm_every_n_blocks': 0}),
|
||||
**dict.fromkeys(['b', 'base'],
|
||||
{'embed_dims': 128,
|
||||
'depths': [2, 2, 18, 2],
|
||||
'num_heads': [4, 8, 16, 32],
|
||||
'extra_norm_every_n_blocks': 0}),
|
||||
**dict.fromkeys(['l', 'large'],
|
||||
{'embed_dims': 192,
|
||||
'depths': [2, 2, 18, 2],
|
||||
'num_heads': [6, 12, 24, 48],
|
||||
'extra_norm_every_n_blocks': 0}),
|
||||
# head count not certain for huge, and is employed for another
|
||||
# parallel study about self-supervised learning.
|
||||
**dict.fromkeys(['h', 'huge'],
|
||||
{'embed_dims': 352,
|
||||
'depths': [2, 2, 18, 2],
|
||||
'num_heads': [8, 16, 32, 64],
|
||||
'extra_norm_every_n_blocks': 6}),
|
||||
**dict.fromkeys(['g', 'giant'],
|
||||
{'embed_dims': 512,
|
||||
'depths': [2, 2, 42, 4],
|
||||
'num_heads': [16, 32, 64, 128],
|
||||
'extra_norm_every_n_blocks': 6}),
|
||||
} # yapf: disable
|
||||
|
||||
_version = 1
|
||||
num_extra_tokens = 0
|
||||
|
||||
def __init__(self,
|
||||
arch='tiny',
|
||||
img_size=256,
|
||||
patch_size=4,
|
||||
in_channels=3,
|
||||
vocabulary_size=128,
|
||||
window_size=8,
|
||||
drop_rate=0.,
|
||||
drop_path_rate=0.1,
|
||||
out_indices=(3, ),
|
||||
use_abs_pos_embed=False,
|
||||
interpolate_mode='bicubic',
|
||||
with_cp=False,
|
||||
frozen_stages=-1,
|
||||
norm_eval=False,
|
||||
pad_small_map=False,
|
||||
norm_cfg=dict(type='LN'),
|
||||
stage_cfgs=dict(downsample_cfg=dict(is_post_norm=True)),
|
||||
patch_cfg=dict(),
|
||||
pretrained_window_sizes=[0, 0, 0, 0],
|
||||
init_cfg=None):
|
||||
super(SwinTransformerV2, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
if isinstance(arch, str):
|
||||
arch = arch.lower()
|
||||
assert arch in set(self.arch_zoo), \
|
||||
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
|
||||
self.arch_settings = self.arch_zoo[arch]
|
||||
else:
|
||||
essential_keys = {
|
||||
'embed_dims', 'depths', 'num_heads',
|
||||
'extra_norm_every_n_blocks'
|
||||
}
|
||||
assert isinstance(arch, dict) and set(arch) == essential_keys, \
|
||||
f'Custom arch needs a dict with keys {essential_keys}'
|
||||
self.arch_settings = arch
|
||||
|
||||
self.vocabulary_size = vocabulary_size + 1 # 增加ignore类别
|
||||
self.embed_dims = self.arch_settings['embed_dims']
|
||||
self.depths = self.arch_settings['depths']
|
||||
self.num_heads = self.arch_settings['num_heads']
|
||||
self.extra_norm_every_n_blocks = self.arch_settings[
|
||||
'extra_norm_every_n_blocks']
|
||||
self.num_layers = len(self.depths)
|
||||
self.out_indices = out_indices
|
||||
self.use_abs_pos_embed = use_abs_pos_embed
|
||||
self.interpolate_mode = interpolate_mode
|
||||
self.frozen_stages = frozen_stages
|
||||
|
||||
if isinstance(window_size, int):
|
||||
self.window_sizes = [window_size for _ in range(self.num_layers)]
|
||||
elif isinstance(window_size, Sequence):
|
||||
assert len(window_size) == self.num_layers, \
|
||||
f'Length of window_sizes {len(window_size)} is not equal to '\
|
||||
f'length of stages {self.num_layers}.'
|
||||
self.window_sizes = window_size
|
||||
else:
|
||||
raise TypeError('window_size should be a Sequence or int.')
|
||||
|
||||
_patch_cfg = dict(
|
||||
in_channels=in_channels,
|
||||
input_size=img_size,
|
||||
embed_dims=self.embed_dims,
|
||||
conv_type='Conv2d',
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
norm_cfg=dict(type='LN'),
|
||||
)
|
||||
_patch_cfg.update(patch_cfg)
|
||||
self.patch_embed = PatchEmbed(**_patch_cfg)
|
||||
self.patch_resolution = self.patch_embed.init_out_size
|
||||
self.patch_size = patch_size
|
||||
|
||||
if self.use_abs_pos_embed:
|
||||
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
|
||||
self.absolute_pos_embed = nn.Parameter(
|
||||
torch.zeros(1, num_patches, self.embed_dims))
|
||||
self._register_load_state_dict_pre_hook(
|
||||
self._prepare_abs_pos_embed)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self._delete_reinit_params)
|
||||
|
||||
self.drop_after_pos = nn.Dropout(p=drop_rate)
|
||||
self.norm_eval = norm_eval
|
||||
|
||||
# stochastic depth
|
||||
total_depth = sum(self.depths)
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
|
||||
] # stochastic depth decay rule
|
||||
|
||||
self.stages = ModuleList()
|
||||
embed_dims = [self.embed_dims]
|
||||
for i, (depth, num_heads) in enumerate(zip(self.depths,
|
||||
self.num_heads)):
|
||||
if isinstance(stage_cfgs, Sequence):
|
||||
stage_cfg = stage_cfgs[i]
|
||||
else:
|
||||
stage_cfg = deepcopy(stage_cfgs)
|
||||
downsample = True if i > 0 else False
|
||||
_stage_cfg = {
|
||||
'embed_dims': embed_dims[-1],
|
||||
'depth': depth,
|
||||
'num_heads': num_heads,
|
||||
'window_size': self.window_sizes[i],
|
||||
'downsample': downsample,
|
||||
'drop_paths': dpr[:depth],
|
||||
'with_cp': with_cp,
|
||||
'pad_small_map': pad_small_map,
|
||||
'extra_norm_every_n_blocks': self.extra_norm_every_n_blocks,
|
||||
'pretrained_window_size': pretrained_window_sizes[i],
|
||||
**stage_cfg
|
||||
}
|
||||
|
||||
stage = SwinBlockV2Sequence(**_stage_cfg)
|
||||
self.stages.append(stage)
|
||||
|
||||
dpr = dpr[depth:]
|
||||
embed_dims.append(stage.out_channels)
|
||||
|
||||
for i in out_indices:
|
||||
if norm_cfg is not None:
|
||||
norm_layer = build_norm_layer(norm_cfg, embed_dims[i + 1])[1]
|
||||
else:
|
||||
norm_layer = nn.Identity()
|
||||
|
||||
self.add_module(f'norm{i}', norm_layer)
|
||||
|
||||
def init_weights(self):
|
||||
if (isinstance(self.init_cfg, dict)
|
||||
and self.init_cfg['type'] == 'Pretrained'):
|
||||
# Suppress default init if use pretrained model.
|
||||
from mmcls.utils import get_root_logger
|
||||
logger = get_root_logger()
|
||||
checkpoint = CheckpointLoader.load_checkpoint(
|
||||
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
|
||||
|
||||
if 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
# print(self.state_dict().keys())
|
||||
# print('---')
|
||||
# print(state_dict.keys())
|
||||
# import pdb; pdb.set_trace()
|
||||
load_state_dict(self, state_dict, strict=False, logger=logger)
|
||||
return
|
||||
else:
|
||||
super(SwinTransformerV2, self).init_weights()
|
||||
if self.use_abs_pos_embed:
|
||||
trunc_normal_(self.absolute_pos_embed, std=0.02)
|
||||
|
||||
def forward(self, x):
|
||||
x, hw_shape = self.patch_embed(x)
|
||||
|
||||
if self.use_abs_pos_embed:
|
||||
x = x + resize_pos_embed(
|
||||
self.absolute_pos_embed, self.patch_resolution, hw_shape,
|
||||
self.interpolate_mode, self.num_extra_tokens)
|
||||
x = self.drop_after_pos(x)
|
||||
|
||||
outs = []
|
||||
for i, stage in enumerate(self.stages):
|
||||
x, hw_shape = stage(x, hw_shape)
|
||||
if i in self.out_indices:
|
||||
norm_layer = getattr(self, f'norm{i}')
|
||||
out = norm_layer(x)
|
||||
out = out.view(-1, *hw_shape,
|
||||
stage.out_channels).permute(0, 3, 1,
|
||||
2).contiguous()
|
||||
outs.append(out)
|
||||
|
||||
return outs
|
||||
|
||||
def _freeze_stages(self):
|
||||
if self.frozen_stages >= 0:
|
||||
self.patch_embed.eval()
|
||||
for param in self.patch_embed.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
for i in range(0, self.frozen_stages + 1):
|
||||
m = self.stages[i]
|
||||
m.eval()
|
||||
for param in m.parameters():
|
||||
param.requires_grad = False
|
||||
for i in self.out_indices:
|
||||
if i <= self.frozen_stages:
|
||||
for param in getattr(self, f'norm{i}').parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def train(self, mode=True):
|
||||
super(SwinTransformerV2, self).train(mode)
|
||||
self._freeze_stages()
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
# trick: eval have effect on BatchNorm only
|
||||
if isinstance(m, _BatchNorm):
|
||||
m.eval()
|
||||
|
||||
def _prepare_abs_pos_embed(self, state_dict, prefix, *args, **kwargs):
|
||||
name = prefix + 'absolute_pos_embed'
|
||||
if name not in state_dict.keys():
|
||||
return
|
||||
|
||||
ckpt_pos_embed_shape = state_dict[name].shape
|
||||
if self.absolute_pos_embed.shape != ckpt_pos_embed_shape:
|
||||
from mmcls.utils import get_root_logger
|
||||
logger = get_root_logger()
|
||||
logger.info(
|
||||
'Resize the absolute_pos_embed shape from '
|
||||
f'{ckpt_pos_embed_shape} to {self.absolute_pos_embed.shape}.')
|
||||
|
||||
ckpt_pos_embed_shape = to_2tuple(
|
||||
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
|
||||
pos_embed_shape = self.patch_embed.init_out_size
|
||||
|
||||
state_dict[name] = resize_pos_embed(state_dict[name],
|
||||
ckpt_pos_embed_shape,
|
||||
pos_embed_shape,
|
||||
self.interpolate_mode,
|
||||
self.num_extra_tokens)
|
||||
|
||||
def _delete_reinit_params(self, state_dict, prefix, *args, **kwargs):
|
||||
# delete relative_position_index since we always re-init it
|
||||
relative_position_index_keys = [
|
||||
k for k in state_dict.keys() if 'relative_position_index' in k
|
||||
]
|
||||
for k in relative_position_index_keys:
|
||||
del state_dict[k]
|
||||
|
||||
# delete relative_coords_table since we always re-init it
|
||||
relative_position_index_keys = [
|
||||
k for k in state_dict.keys() if 'relative_coords_table' in k
|
||||
]
|
||||
for k in relative_position_index_keys:
|
||||
del state_dict[k]
|
||||
|
||||
class Proj_MHSA(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dims,
|
||||
proj_dims,
|
||||
num_heads=16,
|
||||
batch_first=True,
|
||||
bias = True
|
||||
):
|
||||
super().__init__()
|
||||
self.proj_in = nn.Linear(in_features=embed_dims, out_features=proj_dims)
|
||||
self.attn = MultiheadAttention(
|
||||
embed_dims=proj_dims,
|
||||
num_heads=num_heads,
|
||||
batch_first=batch_first,
|
||||
bias=bias
|
||||
)
|
||||
self.proj_out = nn.Linear(in_features=proj_dims, out_features=embed_dims)
|
||||
def forward(self, x):
|
||||
x = self.proj_in(x)
|
||||
x = self.attn(x, x, x)
|
||||
x = self.proj_out(x)
|
||||
return x
|
||||
|
||||
class SwinTransformerV2MSL(SwinTransformerV2):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
if 'use_attn' in kwargs:
|
||||
self.use_attn = kwargs.pop('use_attn')
|
||||
else:
|
||||
self.use_attn = False
|
||||
if 'merge_stage' in kwargs:
|
||||
self.merge_stage = kwargs.pop('merge_stage')
|
||||
else:
|
||||
self.merge_stage = 0
|
||||
if 'with_cls_pos' in kwargs:
|
||||
self.with_cls_pos = kwargs.pop('with_cls_pos')
|
||||
else:
|
||||
self.with_cls_pos = False
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
|
||||
#self.vocabulary_token = nn.Parameter(torch.zeros(1, 1, 1, self.vocabulary_size, self.embed_dims))
|
||||
self.vocabulary_token = nn.Parameter(torch.zeros(self.vocabulary_size, self.embed_dims))
|
||||
self.vocabulary_weight = nn.Parameter(torch.zeros(1, self.patch_size * self.patch_size))
|
||||
trunc_normal_(self.mask_token, mean=0., std=.02)
|
||||
trunc_normal_(self.vocabulary_token, mean=0., std=.02)
|
||||
|
||||
if self.use_attn:
|
||||
self.attn1 = Proj_MHSA(embed_dims=352, proj_dims=256, num_heads=16, batch_first=True, bias = True)
|
||||
self.attn2 = Proj_MHSA(embed_dims=704, proj_dims=512, num_heads=16, batch_first=True, bias = True)
|
||||
self.attn3 = Proj_MHSA( embed_dims=1408, proj_dims=1024, num_heads=16, batch_first=True, bias = True)
|
||||
self.attention_blocks = [self.attn1, self.attn2, self.attn3]
|
||||
self.norm_attn = build_norm_layer(dict(type='LN'), 1408)[1]
|
||||
|
||||
def create_ann_token(self, anno_img):
|
||||
B, H, W = anno_img.shape
|
||||
ann_token = torch.index_select(self.vocabulary_token, 0, anno_img.reshape(-1)).reshape(B, H, W, -1)
|
||||
assert H % self.patch_size == 0 and W % self.patch_size == 0
|
||||
nph, npw = H // self.patch_size, W // self.patch_size
|
||||
weight = F.softmax(self.vocabulary_weight, dim=1) * self.patch_size * self.patch_size
|
||||
weight = weight.reshape(1, 1, self.patch_size, 1, self.patch_size).repeat(1, nph, 1, npw, 1).reshape(1, H, W, 1)
|
||||
ann_token = ann_token * weight
|
||||
ann_token = F.avg_pool2d(torch.einsum('BHWC->BCHW', ann_token), self.patch_size, self.patch_size)
|
||||
ann_token = torch.einsum('BCHW->BHWC', ann_token).reshape(B, nph * npw, self.embed_dims) # shape B, L, C
|
||||
return ann_token
|
||||
|
||||
def forward(self, hr_img, anno_img, mask=None):
|
||||
x, hw_shape = self.patch_embed(hr_img)
|
||||
y = self.create_ann_token(anno_img)
|
||||
assert x.shape == y.shape
|
||||
B, L, C = y.shape
|
||||
if mask is not None:
|
||||
mask_tokens = self.mask_token.expand(B, L, -1)
|
||||
w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)
|
||||
y = y * (1. - w) + mask_tokens * w
|
||||
|
||||
if self.merge_stage == 0:
|
||||
x = (x + y) * 0.5
|
||||
else:
|
||||
x = x.reshape(B, *hw_shape, C)
|
||||
y = y.reshape(B, *hw_shape, C)
|
||||
x = torch.cat((x, y), dim=2)
|
||||
hw_shape = (hw_shape[0], hw_shape[1] * 2)
|
||||
x = x.reshape(B, -1, C)
|
||||
if self.use_abs_pos_embed:
|
||||
x = x + resize_pos_embed(
|
||||
self.absolute_pos_embed, self.patch_resolution, hw_shape,
|
||||
self.interpolate_mode, self.num_extra_tokens)
|
||||
if self.with_cls_pos:
|
||||
hw_shape_half = [hw_shape[0], hw_shape[1] // 2]
|
||||
x = x.reshape(B, *hw_shape, C)
|
||||
x1 = x[:, :, :x.shape[2]//2, :].reshape(B, -1, C)
|
||||
x2 = x[:, :, x.shape[2]//2:, :].reshape(B, -1, C)
|
||||
x1 = x1 + resize_pos_embed(
|
||||
self.absolute_pos_embed, self.patch_resolution, hw_shape_half,
|
||||
self.interpolate_mode, self.num_extra_tokens)
|
||||
x2 = x2 + resize_pos_embed(
|
||||
self.absolute_pos_embed, self.patch_resolution, hw_shape_half,
|
||||
self.interpolate_mode, self.num_extra_tokens)
|
||||
x1 = x1.reshape(B, *hw_shape_half, C)
|
||||
x2 = x2.reshape(B, *hw_shape_half, C)
|
||||
x = torch.cat((x1, x2), dim=2).reshape(B, -1, C)
|
||||
x = self.drop_after_pos(x)
|
||||
outs = []
|
||||
merge_idx = self.merge_stage - 1
|
||||
for i, stage in enumerate(self.stages):
|
||||
x, hw_shape = stage(x, hw_shape)
|
||||
if i == merge_idx:
|
||||
x = x.reshape(x.shape[0], *hw_shape, x.shape[-1]) # b,l,c -> b, h, w, c
|
||||
x = (x[:, :, :x.shape[2]//2] + x[:, :, x.shape[2]//2:]) * 0.5
|
||||
x = x.reshape(x.shape[0], -1, x.shape[-1])
|
||||
hw_shape = (hw_shape[0], hw_shape[1] // 2)
|
||||
if self.use_attn:
|
||||
if i <= len(self.attention_blocks) - 1:
|
||||
x = x + self.attention_blocks[i](x)
|
||||
if i == len(self.attention_blocks) - 1:
|
||||
x = self.norm_attn(x)
|
||||
if i in self.out_indices:
|
||||
norm_layer = getattr(self, f'norm{i}')
|
||||
out = norm_layer(x)
|
||||
out = out.view(-1, *hw_shape, stage.out_channels).permute(0, 3, 1, 2).contiguous()
|
||||
outs.append(out)
|
||||
return outs
|
||||
611
lib/models/backbones/vit.py
Normal file
611
lib/models/backbones/vit.py
Normal file
@@ -0,0 +1,611 @@
|
||||
# Copyright (c) Ant Group. All rights reserved.
|
||||
import math
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
|
||||
from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
|
||||
trunc_normal_)
|
||||
from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList,
|
||||
load_state_dict)
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
from torch.nn.modules.utils import _pair as to_2tuple
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.utils import get_root_logger
|
||||
from mmseg.models.utils.embed import PatchEmbed
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
class TransformerEncoderLayer(BaseModule):
|
||||
"""Implements one encoder layer in Vision Transformer.
|
||||
Args:
|
||||
embed_dims (int): The feature dimension.
|
||||
num_heads (int): Parallel attention heads.
|
||||
feedforward_channels (int): The hidden dimension for FFNs.
|
||||
drop_rate (float): Probability of an element to be zeroed
|
||||
after the feed forward layer. Default: 0.0.
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Default: 0.0.
|
||||
drop_path_rate (float): stochastic depth rate. Default 0.0.
|
||||
num_fcs (int): The number of fully-connected layers for FFNs.
|
||||
Default: 2.
|
||||
qkv_bias (bool): enable bias for qkv if True. Default: True
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Default: dict(type='GELU').
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN').
|
||||
batch_first (bool): Key, Query and Value are shape of
|
||||
(batch, n, embed_dim)
|
||||
or (n, batch, embed_dim). Default: True.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save
|
||||
some memory while slowing down the training speed. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
feedforward_channels,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
num_fcs=2,
|
||||
qkv_bias=True,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN'),
|
||||
batch_first=True,
|
||||
attn_cfg=dict(),
|
||||
ffn_cfg=dict(),
|
||||
with_cp=False):
|
||||
super(TransformerEncoderLayer, self).__init__()
|
||||
|
||||
self.norm1_name, norm1 = build_norm_layer(norm_cfg,
|
||||
embed_dims,
|
||||
postfix=1)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
|
||||
attn_cfg.update(
|
||||
dict(embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
attn_drop=attn_drop_rate,
|
||||
proj_drop=drop_rate,
|
||||
batch_first=batch_first,
|
||||
bias=qkv_bias))
|
||||
|
||||
self.build_attn(attn_cfg)
|
||||
|
||||
self.norm2_name, norm2 = build_norm_layer(norm_cfg,
|
||||
embed_dims,
|
||||
postfix=2)
|
||||
self.add_module(self.norm2_name, norm2)
|
||||
|
||||
ffn_cfg.update(
|
||||
dict(embed_dims=embed_dims,
|
||||
feedforward_channels=feedforward_channels,
|
||||
num_fcs=num_fcs,
|
||||
ffn_drop=drop_rate,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate)
|
||||
if drop_path_rate > 0 else None,
|
||||
act_cfg=act_cfg))
|
||||
self.build_ffn(ffn_cfg)
|
||||
self.with_cp = with_cp
|
||||
|
||||
def build_attn(self, attn_cfg):
|
||||
self.attn = MultiheadAttention(**attn_cfg)
|
||||
|
||||
def build_ffn(self, ffn_cfg):
|
||||
self.ffn = FFN(**ffn_cfg)
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
return getattr(self, self.norm1_name)
|
||||
|
||||
@property
|
||||
def norm2(self):
|
||||
return getattr(self, self.norm2_name)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
def _inner_forward(x):
|
||||
x = self.attn(self.norm1(x), identity=x)
|
||||
x = self.ffn(self.norm2(x), identity=x)
|
||||
return x
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
x = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
x = _inner_forward(x)
|
||||
return x
|
||||
|
||||
|
||||
class VisionTransformer(BaseModule):
|
||||
"""Vision Transformer.
|
||||
This backbone is the implementation of `An Image is Worth 16x16 Words:
|
||||
Transformers for Image Recognition at
|
||||
Scale <https://arxiv.org/abs/2010.11929>`_.
|
||||
Args:
|
||||
img_size (int | tuple): Input image size. Default: 224.
|
||||
patch_size (int): The patch size. Default: 16.
|
||||
in_channels (int): Number of input channels. Default: 3.
|
||||
embed_dims (int): embedding dimension. Default: 768.
|
||||
num_layers (int): depth of transformer. Default: 12.
|
||||
num_heads (int): number of attention heads. Default: 12.
|
||||
mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
|
||||
Default: 4.
|
||||
out_indices (list | tuple | int): Output from which stages.
|
||||
Default: -1.
|
||||
qkv_bias (bool): enable bias for qkv if True. Default: True.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Default 0.0
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Default 0.0
|
||||
drop_path_rate (float): stochastic depth rate. Default 0.0
|
||||
with_cls_token (bool): Whether concatenating class token into image
|
||||
tokens as transformer input. Default: True.
|
||||
output_cls_token (bool): Whether output the cls_token. If set True,
|
||||
`with_cls_token` must be True. Default: False.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN')
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Default: dict(type='GELU').
|
||||
patch_norm (bool): Whether to add a norm in PatchEmbed Block.
|
||||
Default: False.
|
||||
final_norm (bool): Whether to add a additional layer to normalize
|
||||
final feature map. Default: False.
|
||||
interpolate_mode (str): Select the interpolate mode for position
|
||||
embeding vector resize. Default: bicubic.
|
||||
num_fcs (int): The number of fully-connected layers for FFNs.
|
||||
Default: 2.
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Default: False.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save
|
||||
some memory while slowing down the training speed. Default: False.
|
||||
pretrained (str, optional): model pretrained path. Default: None.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_channels=3,
|
||||
embed_dims=768,
|
||||
num_layers=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4,
|
||||
out_indices=-1,
|
||||
qkv_bias=True,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
with_cls_token=True,
|
||||
output_cls_token=False,
|
||||
norm_cfg=dict(type='LN'),
|
||||
act_cfg=dict(type='GELU'),
|
||||
patch_norm=False,
|
||||
final_norm=False,
|
||||
interpolate_mode='bicubic',
|
||||
num_fcs=2,
|
||||
norm_eval=False,
|
||||
with_cp=False,
|
||||
use_ccd=False,
|
||||
ccd_num=0,
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
super(VisionTransformer, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
if isinstance(img_size, int):
|
||||
img_size = to_2tuple(img_size)
|
||||
elif isinstance(img_size, tuple):
|
||||
if len(img_size) == 1:
|
||||
img_size = to_2tuple(img_size[0])
|
||||
assert len(img_size) == 2, \
|
||||
f'The size of image should have length 1 or 2, ' \
|
||||
f'but got {len(img_size)}'
|
||||
|
||||
if output_cls_token:
|
||||
assert with_cls_token is True, f'with_cls_token must be True if' \
|
||||
f'set output_cls_token to True, but got {with_cls_token}'
|
||||
|
||||
assert not (init_cfg and pretrained), \
|
||||
'init_cfg and pretrained cannot be set at the same time'
|
||||
if isinstance(pretrained, str):
|
||||
warnings.warn('DeprecationWarning: pretrained is deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
elif pretrained is not None:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.interpolate_mode = interpolate_mode
|
||||
self.norm_eval = norm_eval
|
||||
self.with_cp = with_cp
|
||||
self.pretrained = pretrained
|
||||
self.embed_dims = embed_dims
|
||||
self.patch_embed = PatchEmbed(
|
||||
in_channels=in_channels,
|
||||
embed_dims=embed_dims,
|
||||
conv_type='Conv2d',
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
padding='corner',
|
||||
norm_cfg=norm_cfg if patch_norm else None,
|
||||
init_cfg=None,
|
||||
)
|
||||
|
||||
self.use_ccd = use_ccd
|
||||
self.ccd_num = ccd_num
|
||||
if self.use_ccd:
|
||||
self.ccd_embed = nn.Parameter(
|
||||
torch.rand(1, self.ccd_num, embed_dims))
|
||||
|
||||
num_patches = (img_size[0] // patch_size) * \
|
||||
(img_size[1] // patch_size)
|
||||
|
||||
self.with_cls_token = with_cls_token
|
||||
self.output_cls_token = output_cls_token
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
|
||||
# self.pos_embed = nn.Parameter(
|
||||
# torch.zeros(1, num_patches, embed_dims))
|
||||
# 原来是
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, embed_dims))
|
||||
self.drop_after_pos = nn.Dropout(p=drop_rate)
|
||||
|
||||
if isinstance(out_indices, int):
|
||||
if out_indices == -1:
|
||||
out_indices = num_layers - 1
|
||||
self.out_indices = [out_indices]
|
||||
elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
|
||||
self.out_indices = out_indices
|
||||
else:
|
||||
raise TypeError('out_indices must be type of int, list or tuple')
|
||||
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, num_layers)
|
||||
] # stochastic depth decay rule
|
||||
|
||||
self.layers = ModuleList()
|
||||
for i in range(num_layers):
|
||||
self.layers.append(
|
||||
TransformerEncoderLayer(embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
feedforward_channels=mlp_ratio *
|
||||
embed_dims,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_rate=drop_rate,
|
||||
drop_path_rate=dpr[i],
|
||||
num_fcs=num_fcs,
|
||||
qkv_bias=qkv_bias,
|
||||
act_cfg=act_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
with_cp=with_cp,
|
||||
batch_first=True))
|
||||
|
||||
self.final_norm = final_norm
|
||||
if final_norm:
|
||||
self.norm1_name, norm1 = build_norm_layer(norm_cfg,
|
||||
embed_dims,
|
||||
postfix=1)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
return getattr(self, self.norm1_name)
|
||||
|
||||
def init_weights(self):
|
||||
if (isinstance(self.init_cfg, dict)
|
||||
and self.init_cfg.get('type') == 'Pretrained'):
|
||||
logger = get_root_logger()
|
||||
checkpoint = CheckpointLoader.load_checkpoint(
|
||||
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
|
||||
|
||||
if 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
|
||||
if 'pos_embed' in state_dict.keys():
|
||||
if self.pos_embed.shape != state_dict['pos_embed'].shape:
|
||||
logger.info(msg=f'Resize the pos_embed shape from '
|
||||
f'{state_dict["pos_embed"].shape} to '
|
||||
f'{self.pos_embed.shape}')
|
||||
h, w = self.img_size
|
||||
pos_size = int(
|
||||
math.sqrt(state_dict['pos_embed'].shape[1] - 1))
|
||||
state_dict['pos_embed'] = self.resize_pos_embed(
|
||||
state_dict['pos_embed'],
|
||||
(h // self.patch_size, w // self.patch_size),
|
||||
(pos_size, pos_size), self.interpolate_mode)
|
||||
|
||||
load_state_dict(self, state_dict, strict=False, logger=logger)
|
||||
elif self.init_cfg is not None:
|
||||
super(VisionTransformer, self).init_weights()
|
||||
else:
|
||||
# We only implement the 'jax_impl' initialization implemented at
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
|
||||
trunc_normal_(self.pos_embed, std=.02)
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
if self.use_ccd:
|
||||
trunc_normal_(self.ccd_embed, std=0.02)
|
||||
for n, m in self.named_modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if m.bias is not None:
|
||||
if 'ffn' in n:
|
||||
nn.init.normal_(m.bias, mean=0., std=1e-6)
|
||||
else:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
kaiming_init(m, mode='fan_in', bias=0.)
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
|
||||
constant_init(m, val=1.0, bias=0.)
|
||||
|
||||
def _pos_embeding(self, patched_img, hw_shape, pos_embed):
|
||||
"""Positioning embeding method.
|
||||
Resize the pos_embed, if the input image size doesn't match
|
||||
the training size.
|
||||
Args:
|
||||
patched_img (torch.Tensor): The patched image, it should be
|
||||
shape of [B, L1, C].
|
||||
hw_shape (tuple): The downsampled image resolution.
|
||||
pos_embed (torch.Tensor): The pos_embed weighs, it should be
|
||||
shape of [B, L2, c].
|
||||
Return:
|
||||
torch.Tensor: The pos encoded image feature.
|
||||
"""
|
||||
assert patched_img.ndim == 3 and pos_embed.ndim == 3, \
|
||||
'the shapes of patched_img and pos_embed must be [B, L, C]'
|
||||
x_len, pos_len = patched_img.shape[1], pos_embed.shape[1]
|
||||
if x_len != pos_len:
|
||||
if pos_len == (self.img_size[0] // self.patch_size) * (
|
||||
self.img_size[1] // self.patch_size) + 1:
|
||||
pos_h = self.img_size[0] // self.patch_size
|
||||
pos_w = self.img_size[1] // self.patch_size
|
||||
else:
|
||||
raise ValueError(
|
||||
'Unexpected shape of pos_embed, got {}.'.format(
|
||||
pos_embed.shape))
|
||||
pos_embed = self.resize_pos_embed(pos_embed, hw_shape,
|
||||
(pos_h, pos_w),
|
||||
self.interpolate_mode)
|
||||
return self.drop_after_pos(patched_img + pos_embed)
|
||||
|
||||
@staticmethod
|
||||
def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode):
|
||||
"""Resize pos_embed weights.
|
||||
Resize pos_embed using bicubic interpolate method.
|
||||
Args:
|
||||
pos_embed (torch.Tensor): Position embedding weights.
|
||||
input_shpae (tuple): Tuple for (downsampled input image height,
|
||||
downsampled input image width).
|
||||
pos_shape (tuple): The resolution of downsampled origin training
|
||||
image.
|
||||
mode (str): Algorithm used for upsampling:
|
||||
``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
|
||||
``'trilinear'``. Default: ``'nearest'``
|
||||
Return:
|
||||
torch.Tensor: The resized pos_embed of shape [B, L_new, C]
|
||||
"""
|
||||
assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
|
||||
pos_h, pos_w = pos_shape
|
||||
# keep dim for easy deployment
|
||||
cls_token_weight = pos_embed[:, 0:1]
|
||||
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
|
||||
pos_embed_weight = pos_embed_weight.reshape(
|
||||
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
|
||||
pos_embed_weight = resize(pos_embed_weight,
|
||||
size=input_shpae,
|
||||
align_corners=False,
|
||||
mode=mode)
|
||||
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
|
||||
pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
|
||||
return pos_embed
|
||||
|
||||
def forward(self, inputs, ccd_index=None):
|
||||
B = inputs.shape[0]
|
||||
|
||||
x, hw_shape = self.patch_embed(inputs)
|
||||
|
||||
if self.use_ccd:
|
||||
_ccd_idx = np.concatenate(ccd_index, axis=0)
|
||||
_ccd_embed = self.ccd_embed[:, _ccd_idx, :].permute(1, 0, 2)
|
||||
x = x + _ccd_embed
|
||||
|
||||
# stole cls_tokens impl from Phil Wang, thanks
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x = self._pos_embeding(x, hw_shape, self.pos_embed)
|
||||
|
||||
if not self.with_cls_token:
|
||||
# Remove class token for transformer encoder input
|
||||
x = x[:, 1:]
|
||||
|
||||
outs = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
if i == len(self.layers) - 1:
|
||||
if self.final_norm:
|
||||
x = self.norm1(x)
|
||||
if i in self.out_indices:
|
||||
if self.with_cls_token:
|
||||
# Remove class token and reshape token for decoder heads
|
||||
out = x[:, 1:]
|
||||
else:
|
||||
out = x
|
||||
B, _, C = out.shape
|
||||
out = out.reshape(B, hw_shape[0], hw_shape[1],
|
||||
C).permute(0, 3, 1, 2).contiguous()
|
||||
if self.output_cls_token:
|
||||
out = [out, x[:, 0]]
|
||||
outs.append(out)
|
||||
|
||||
return tuple(outs)
|
||||
|
||||
def train(self, mode=True):
|
||||
super(VisionTransformer, self).train(mode)
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.LayerNorm):
|
||||
m.eval()
|
||||
|
||||
|
||||
class VisionTransformerMSL(VisionTransformer):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
if 'use_attn' in kwargs:
|
||||
self.use_attn = kwargs.pop('use_attn')
|
||||
else:
|
||||
self.use_attn = False
|
||||
if 'merge_stage' in kwargs:
|
||||
self.merge_stage = kwargs.pop('merge_stage')
|
||||
else:
|
||||
self.merge_stage = 0
|
||||
if 'with_cls_pos' in kwargs:
|
||||
self.with_cls_pos = kwargs.pop('with_cls_pos')
|
||||
else:
|
||||
self.with_cls_pos = False
|
||||
|
||||
self.vocabulary_size = kwargs.pop('vocabulary_size') + 1 # 增加ignore类别
|
||||
super().__init__(**kwargs)
|
||||
self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
|
||||
img_size = kwargs.pop('img_size')
|
||||
patch_size = kwargs.pop('patch_size')
|
||||
num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dims))
|
||||
|
||||
self.vocabulary_token = nn.Parameter(torch.zeros(self.vocabulary_size, self.embed_dims))
|
||||
self.vocabulary_weight = nn.Parameter(torch.zeros(1, self.patch_size * self.patch_size))
|
||||
trunc_normal_(self.mask_token, mean=0., std=.02)
|
||||
trunc_normal_(self.vocabulary_token, mean=0., std=.02)
|
||||
|
||||
if self.use_attn:
|
||||
self.attn1 = MultiheadAttention(embed_dims=1024, num_heads=16, batch_first=True, bias = True)
|
||||
self.attn2 = MultiheadAttention(embed_dims=1024, num_heads=16, batch_first=True, bias = True)
|
||||
self.attn3 = MultiheadAttention(embed_dims=1024, num_heads=16, batch_first=True, bias = True)
|
||||
self.attention_blocks = [self.attn1, self.attn2, self.attn3]
|
||||
self.norm_attn = build_norm_layer(dict(type='LN'), 1024)[1]
|
||||
|
||||
def create_ann_token(self, anno_img):
|
||||
B, H, W = anno_img.shape
|
||||
ann_token = torch.index_select(self.vocabulary_token, 0, anno_img.reshape(-1)).reshape(B, H, W, -1)
|
||||
assert H % self.patch_size == 0 and W % self.patch_size == 0
|
||||
nph, npw = H // self.patch_size, W // self.patch_size
|
||||
weight = F.softmax(self.vocabulary_weight, dim=1) * self.patch_size * self.patch_size
|
||||
weight = weight.reshape(1, 1, self.patch_size, 1, self.patch_size).repeat(1, nph, 1, npw, 1).reshape(1, H, W, 1)
|
||||
ann_token = ann_token * weight
|
||||
ann_token = F.avg_pool2d(torch.einsum('BHWC->BCHW', ann_token), self.patch_size, self.patch_size)
|
||||
ann_token = torch.einsum('BCHW->BHWC', ann_token).reshape(B, nph * npw, self.embed_dims) # shape B, L, C
|
||||
return ann_token
|
||||
def _pos_embeding(self, patched_img, hw_shape, pos_embed):
|
||||
"""Positioning embeding method.
|
||||
Resize the pos_embed, if the input image size doesn't match
|
||||
the training size.
|
||||
Args:
|
||||
patched_img (torch.Tensor): The patched image, it should be
|
||||
shape of [B, L1, C].
|
||||
hw_shape (tuple): The downsampled image resolution.
|
||||
pos_embed (torch.Tensor): The pos_embed weighs, it should be
|
||||
shape of [B, L2, c].
|
||||
Return:
|
||||
torch.Tensor: The pos encoded image feature.
|
||||
"""
|
||||
assert patched_img.ndim == 3 and pos_embed.ndim == 3, \
|
||||
'the shapes of patched_img and pos_embed must be [B, L, C]'
|
||||
x_len, pos_len = patched_img.shape[1], pos_embed.shape[1]
|
||||
if x_len != pos_len:
|
||||
if pos_len == (self.img_size[0] // self.patch_size) * (
|
||||
self.img_size[1] // self.patch_size):
|
||||
pos_h = self.img_size[0] // self.patch_size
|
||||
pos_w = self.img_size[1] // self.patch_size
|
||||
else:
|
||||
raise ValueError(
|
||||
'Unexpected shape of pos_embed, got {}.'.format(
|
||||
pos_embed.shape))
|
||||
pos_embed = self.resize_pos_embed(pos_embed, hw_shape,
|
||||
(pos_h, pos_w),
|
||||
self.interpolate_mode)
|
||||
return self.drop_after_pos(patched_img + pos_embed)
|
||||
|
||||
@staticmethod
|
||||
def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode):
|
||||
"""Resize pos_embed weights.
|
||||
Resize pos_embed using bicubic interpolate method.
|
||||
Args:
|
||||
pos_embed (torch.Tensor): Position embedding weights.
|
||||
input_shpae (tuple): Tuple for (downsampled input image height,
|
||||
downsampled input image width).
|
||||
pos_shape (tuple): The resolution of downsampled origin training
|
||||
image.
|
||||
mode (str): Algorithm used for upsampling:
|
||||
``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
|
||||
``'trilinear'``. Default: ``'nearest'``
|
||||
Return:
|
||||
torch.Tensor: The resized pos_embed of shape [B, L_new, C]
|
||||
"""
|
||||
assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
|
||||
pos_h, pos_w = pos_shape
|
||||
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
|
||||
pos_embed_weight = pos_embed_weight.reshape(
|
||||
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
|
||||
pos_embed_weight = resize(pos_embed_weight,
|
||||
size=input_shpae,
|
||||
align_corners=False,
|
||||
mode=mode)
|
||||
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
|
||||
pos_embed = pos_embed_weight # torch.cat((cls_token_weight, pos_embed_weight), dim=1)
|
||||
return pos_embed
|
||||
|
||||
def forward(self, x, y, mask=None):
|
||||
x, hw_shape = self.patch_embed(x)
|
||||
y = self.create_ann_token(y)
|
||||
assert x.shape == y.shape
|
||||
B, L, C = y.shape
|
||||
if mask is not None:
|
||||
mask_tokens = self.mask_token.expand(B, L, -1)
|
||||
w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)
|
||||
y = y * (1. - w) + mask_tokens * w
|
||||
if self.merge_stage == 0:
|
||||
x = (x + y) * 0.5
|
||||
else:
|
||||
x = x.reshape(B, *hw_shape, C)
|
||||
y = y.reshape(B, *hw_shape, C)
|
||||
x = torch.cat((x, y), dim=2)
|
||||
hw_shape = (hw_shape[0], hw_shape[1] * 2)
|
||||
x = x.reshape(B, -1, C)
|
||||
x = self._pos_embeding(x, hw_shape, self.pos_embed)
|
||||
merge_idx = self.merge_stage - 1
|
||||
outs = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
if i == merge_idx:
|
||||
x = x.reshape(x.shape[0], *hw_shape, x.shape[-1]) # b,l,c -> b, h, w, c
|
||||
x = (x[:, :, :x.shape[2]//2] + x[:, :, x.shape[2]//2:]) * 0.5
|
||||
x = x.reshape(x.shape[0], -1, x.shape[-1])
|
||||
hw_shape = (hw_shape[0], hw_shape[1] // 2)
|
||||
if self.use_attn:
|
||||
if i <= len(self.attention_blocks) - 1:
|
||||
x = x + self.attention_blocks[i](x)
|
||||
if i == len(self.attention_blocks) - 1:
|
||||
x = self.norm_attn(x) # 会不会有冲突
|
||||
if (not self.use_attn) and (i == len(self.layers) - 1):
|
||||
if self.final_norm:
|
||||
x = self.norm1(x) # 会不会有冲突
|
||||
if i in self.out_indices:
|
||||
if self.with_cls_token:
|
||||
# Remove class token and reshape token for decoder heads
|
||||
out = x[:, 1:]
|
||||
else:
|
||||
out = x
|
||||
B, _, C = out.shape
|
||||
out = out.reshape(B, hw_shape[0], hw_shape[1],
|
||||
C).permute(0, 3, 1, 2).contiguous()
|
||||
if self.output_cls_token:
|
||||
out = [out, x[:, 0]]
|
||||
outs.append(out)
|
||||
|
||||
return tuple(outs)
|
||||
15
lib/models/heads/__init__.py
Normal file
15
lib/models/heads/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from .uper_head import UPerHead
|
||||
from .up_head import UPHead
|
||||
|
||||
__all__ = [
|
||||
'UPerHead', 'UPHead'
|
||||
]
|
||||
|
||||
type_mapping = {
|
||||
'UPerHead': UPerHead,
|
||||
'UPHead': UPHead
|
||||
}
|
||||
|
||||
|
||||
def build_head(type, **kwargs):
|
||||
return type_mapping[type](**kwargs)
|
||||
201
lib/models/heads/decode_head.py
Normal file
201
lib/models/heads/decode_head.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import warnings
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.runner import BaseModule, auto_fp16
|
||||
|
||||
from mmseg.core import build_pixel_sampler
|
||||
from mmseg.ops import resize
|
||||
|
||||
|
||||
class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
|
||||
"""Base class for BaseDecodeHead.
|
||||
|
||||
Args:
|
||||
in_channels (int|Sequence[int]): Input channels.
|
||||
channels (int): Channels after modules, before conv_seg.
|
||||
num_classes (int): Number of classes.
|
||||
out_channels (int): Output channels of conv_seg.
|
||||
threshold (float): Threshold for binary segmentation in the case of
|
||||
`out_channels==1`. Default: None.
|
||||
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
|
||||
conv_cfg (dict|None): Config of conv layers. Default: None.
|
||||
norm_cfg (dict|None): Config of norm layers. Default: None.
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU')
|
||||
in_index (int|Sequence[int]): Input feature index. Default: -1
|
||||
input_transform (str|None): Transformation type of input features.
|
||||
Options: 'resize_concat', 'multiple_select', None.
|
||||
'resize_concat': Multiple feature maps will be resize to the
|
||||
same size as first one and than concat together.
|
||||
Usually used in FCN head of HRNet.
|
||||
'multiple_select': Multiple feature maps will be bundle into
|
||||
a list and passed into decode head.
|
||||
None: Only one select feature map is allowed.
|
||||
Default: None.
|
||||
loss_decode (dict | Sequence[dict]): Config of decode loss.
|
||||
The `loss_name` is property of corresponding loss function which
|
||||
could be shown in training log. If you want this loss
|
||||
item to be included into the backward graph, `loss_` must be the
|
||||
prefix of the name. Defaults to 'loss_ce'.
|
||||
e.g. dict(type='CrossEntropyLoss'),
|
||||
[dict(type='CrossEntropyLoss', loss_name='loss_ce'),
|
||||
dict(type='DiceLoss', loss_name='loss_dice')]
|
||||
Default: dict(type='CrossEntropyLoss').
|
||||
ignore_index (int | None): The label index to be ignored. When using
|
||||
masked BCE loss, ignore_index should be set to None. Default: 255.
|
||||
sampler (dict|None): The config of segmentation map sampler.
|
||||
Default: None.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
channels,
|
||||
*,
|
||||
num_classes,
|
||||
out_channels=None,
|
||||
threshold=None,
|
||||
dropout_ratio=0.1,
|
||||
conv_cfg=None,
|
||||
norm_cfg=None,
|
||||
act_cfg=dict(type='ReLU'),
|
||||
in_index=-1,
|
||||
input_transform=None,
|
||||
sampler=None,
|
||||
align_corners=False,
|
||||
init_cfg=dict(type='Normal',
|
||||
std=0.01,
|
||||
override=dict(name='conv_seg'))):
|
||||
super(BaseDecodeHead, self).__init__(init_cfg)
|
||||
self._init_inputs(in_channels, in_index, input_transform)
|
||||
self.channels = channels
|
||||
self.dropout_ratio = dropout_ratio
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.in_index = in_index
|
||||
|
||||
self.align_corners = align_corners
|
||||
|
||||
if out_channels is None:
|
||||
if num_classes == 2:
|
||||
warnings.warn('For binary segmentation, we suggest using'
|
||||
'`out_channels = 1` to define the output'
|
||||
'channels of segmentor, and use `threshold`'
|
||||
'to convert seg_logist into a prediction'
|
||||
'applying a threshold')
|
||||
out_channels = num_classes
|
||||
|
||||
if out_channels != num_classes and out_channels != 1:
|
||||
raise ValueError(
|
||||
'out_channels should be equal to num_classes,'
|
||||
'except binary segmentation set out_channels == 1 and'
|
||||
f'num_classes == 2, but got out_channels={out_channels}'
|
||||
f'and num_classes={num_classes}')
|
||||
|
||||
if out_channels == 1 and threshold is None:
|
||||
threshold = 0.3
|
||||
warnings.warn('threshold is not defined for binary, and defaults'
|
||||
'to 0.3')
|
||||
self.num_classes = num_classes
|
||||
self.out_channels = out_channels
|
||||
self.threshold = threshold
|
||||
|
||||
if sampler is not None:
|
||||
self.sampler = build_pixel_sampler(sampler, context=self)
|
||||
else:
|
||||
self.sampler = None
|
||||
|
||||
self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
|
||||
if dropout_ratio > 0:
|
||||
self.dropout = nn.Dropout2d(dropout_ratio)
|
||||
else:
|
||||
self.dropout = None
|
||||
self.fp16_enabled = False
|
||||
|
||||
def extra_repr(self):
|
||||
"""Extra repr."""
|
||||
s = f'input_transform={self.input_transform}, ' \
|
||||
f'align_corners={self.align_corners}'
|
||||
return s
|
||||
|
||||
def _init_inputs(self, in_channels, in_index, input_transform):
|
||||
"""Check and initialize input transforms.
|
||||
|
||||
The in_channels, in_index and input_transform must match.
|
||||
Specifically, when input_transform is None, only single feature map
|
||||
will be selected. So in_channels and in_index must be of type int.
|
||||
When input_transform
|
||||
|
||||
Args:
|
||||
in_channels (int|Sequence[int]): Input channels.
|
||||
in_index (int|Sequence[int]): Input feature index.
|
||||
input_transform (str|None): Transformation type of input features.
|
||||
Options: 'resize_concat', 'multiple_select', None.
|
||||
'resize_concat': Multiple feature maps will be resize to the
|
||||
same size as first one and than concat together.
|
||||
Usually used in FCN head of HRNet.
|
||||
'multiple_select': Multiple feature maps will be bundle into
|
||||
a list and passed into decode head.
|
||||
None: Only one select feature map is allowed.
|
||||
"""
|
||||
|
||||
if input_transform is not None:
|
||||
assert input_transform in ['resize_concat', 'multiple_select']
|
||||
self.input_transform = input_transform
|
||||
self.in_index = in_index
|
||||
if input_transform is not None:
|
||||
assert isinstance(in_channels, (list, tuple))
|
||||
assert isinstance(in_index, (list, tuple))
|
||||
assert len(in_channels) == len(in_index)
|
||||
if input_transform == 'resize_concat':
|
||||
self.in_channels = sum(in_channels)
|
||||
else:
|
||||
self.in_channels = in_channels
|
||||
else:
|
||||
assert isinstance(in_channels, int)
|
||||
assert isinstance(in_index, int)
|
||||
self.in_channels = in_channels
|
||||
|
||||
def _transform_inputs(self, inputs):
|
||||
"""Transform inputs for decoder.
|
||||
|
||||
Args:
|
||||
inputs (list[Tensor]): List of multi-level img features.
|
||||
|
||||
Returns:
|
||||
Tensor: The transformed inputs
|
||||
"""
|
||||
|
||||
if self.input_transform == 'resize_concat':
|
||||
inputs = [inputs[i] for i in self.in_index]
|
||||
upsampled_inputs = [
|
||||
resize(input=x,
|
||||
size=inputs[0].shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners) for x in inputs
|
||||
]
|
||||
inputs = torch.cat(upsampled_inputs, dim=1)
|
||||
elif self.input_transform == 'multiple_select':
|
||||
inputs = [inputs[i] for i in self.in_index]
|
||||
else:
|
||||
inputs = inputs[self.in_index]
|
||||
|
||||
return inputs
|
||||
|
||||
@auto_fp16()
|
||||
@abstractmethod
|
||||
def forward(self, inputs):
|
||||
"""Placeholder of forward function."""
|
||||
pass
|
||||
|
||||
def cls_seg(self, feat):
|
||||
"""Classify each pixel."""
|
||||
if self.dropout is not None:
|
||||
feat = self.dropout(feat)
|
||||
output = self.conv_seg(feat)
|
||||
return output
|
||||
60
lib/models/heads/psp_head.py
Normal file
60
lib/models/heads/psp_head.py
Normal file
@@ -0,0 +1,60 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class PPM(nn.ModuleList):
|
||||
"""Pooling Pyramid Module used in PSPNet.
|
||||
|
||||
Args:
|
||||
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
||||
Module.
|
||||
in_channels (int): Input channels.
|
||||
channels (int): Channels after modules, before conv_seg.
|
||||
conv_cfg (dict|None): Config of conv layers.
|
||||
norm_cfg (dict|None): Config of norm layers.
|
||||
act_cfg (dict): Config of activation layers.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
"""
|
||||
|
||||
def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
|
||||
act_cfg, align_corners, **kwargs):
|
||||
super(PPM, self).__init__()
|
||||
self.pool_scales = pool_scales
|
||||
self.align_corners = align_corners
|
||||
self.in_channels = in_channels
|
||||
self.channels = channels
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
for pool_scale in pool_scales:
|
||||
self.append(
|
||||
nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d(pool_scale),
|
||||
ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
**kwargs)))
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
ppm_outs = []
|
||||
for ppm in self:
|
||||
ppm_out = ppm(x)
|
||||
ppm_out = ppm_out.to(torch.float32)
|
||||
upsampled_ppm_out = resize(
|
||||
ppm_out,
|
||||
size=x.size()[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
upsampled_ppm_out = upsampled_ppm_out.to(torch.bfloat16)
|
||||
ppm_outs.append(upsampled_ppm_out)
|
||||
return ppm_outs
|
||||
52
lib/models/heads/up_head.py
Normal file
52
lib/models/heads/up_head.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import torch.nn as nn
|
||||
from collections import OrderedDict
|
||||
from mmcv.cnn.utils.weight_init import (kaiming_init, trunc_normal_)
|
||||
from mmcv.runner import (CheckpointLoader, load_state_dict)
|
||||
from mmseg.utils import get_root_logger
|
||||
|
||||
|
||||
class UPHead(nn.Module):
|
||||
|
||||
def __init__(self, in_dim, out_dim, up_scale, init_cfg=None):
|
||||
super().__init__()
|
||||
self.decoder = nn.Sequential(
|
||||
nn.Conv2d(in_channels=in_dim,
|
||||
out_channels=up_scale**2 * out_dim,
|
||||
kernel_size=1),
|
||||
nn.PixelShuffle(up_scale),
|
||||
)
|
||||
self.init_cfg = init_cfg
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if (isinstance(self.init_cfg, dict)
|
||||
and self.init_cfg.get('type') == 'Pretrained'):
|
||||
logger = get_root_logger()
|
||||
|
||||
checkpoint = CheckpointLoader.load_checkpoint(
|
||||
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
|
||||
|
||||
if 'state_dict' in checkpoint:
|
||||
_state_dict = checkpoint['state_dict']
|
||||
else:
|
||||
_state_dict = checkpoint
|
||||
|
||||
state_dict = OrderedDict()
|
||||
for k, v in _state_dict.items():
|
||||
if k.startswith('backbone.'):
|
||||
state_dict[k[9:]] = v
|
||||
else:
|
||||
state_dict[k] = v
|
||||
print(f'loading weight: {self.init_cfg["checkpoint"]}')
|
||||
load_state_dict(self, state_dict, strict=False, logger=logger)
|
||||
else:
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
kaiming_init(m, mode='fan_in', bias=0.)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.decoder(x)
|
||||
return x
|
||||
130
lib/models/heads/uper_head.py
Normal file
130
lib/models/heads/uper_head.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# coding: utf-8
|
||||
# Copyright (c) Ant Group. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
from .psp_head import PPM
|
||||
|
||||
|
||||
class UPerHead(BaseDecodeHead):
|
||||
"""Unified Perceptual Parsing for Scene Understanding.
|
||||
|
||||
This head is the implementation of `UPerNet
|
||||
<https://arxiv.org/abs/1807.10221>`_.
|
||||
|
||||
Args:
|
||||
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
||||
Module applied on the last feature. Default: (1, 2, 3, 6).
|
||||
"""
|
||||
|
||||
def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
|
||||
super(UPerHead, self).__init__(
|
||||
input_transform='multiple_select', **kwargs)
|
||||
# PSP Module
|
||||
self.psp_modules = PPM(
|
||||
pool_scales,
|
||||
self.in_channels[-1],
|
||||
self.channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
align_corners=self.align_corners)
|
||||
self.bottleneck = ConvModule(
|
||||
self.in_channels[-1] + len(pool_scales) * self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
# FPN Module
|
||||
self.lateral_convs = nn.ModuleList()
|
||||
self.fpn_convs = nn.ModuleList()
|
||||
for in_channels in self.in_channels[:-1]: # skip the top layer
|
||||
l_conv = ConvModule(
|
||||
in_channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
inplace=False)
|
||||
fpn_conv = ConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
inplace=False)
|
||||
self.lateral_convs.append(l_conv)
|
||||
self.fpn_convs.append(fpn_conv)
|
||||
|
||||
self.fpn_bottleneck = ConvModule(
|
||||
len(self.in_channels) * self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def psp_forward(self, inputs):
|
||||
"""Forward function of PSP module."""
|
||||
x = inputs[-1]
|
||||
psp_outs = [x]
|
||||
# breakpoint()
|
||||
psp_outs.extend(self.psp_modules(x))
|
||||
psp_outs = torch.cat(psp_outs, dim=1)
|
||||
output = self.bottleneck(psp_outs)
|
||||
|
||||
return output
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
# breakpoint()
|
||||
inputs = self._transform_inputs(inputs)
|
||||
|
||||
# build laterals
|
||||
laterals = [
|
||||
lateral_conv(inputs[i])
|
||||
for i, lateral_conv in enumerate(self.lateral_convs)
|
||||
]
|
||||
|
||||
laterals.append(self.psp_forward(inputs))
|
||||
|
||||
# build top-down path
|
||||
used_backbone_levels = len(laterals)
|
||||
for i in range(used_backbone_levels - 1, 0, -1):
|
||||
prev_shape = laterals[i - 1].shape[2:]
|
||||
laterals[i] = laterals[i].type(torch.float32)
|
||||
laterals[i - 1] = laterals[i - 1] + resize(
|
||||
laterals[i],
|
||||
size=prev_shape,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
|
||||
# build outputs
|
||||
fpn_outs = [
|
||||
self.fpn_convs[i](laterals[i])
|
||||
for i in range(used_backbone_levels - 1)
|
||||
]
|
||||
# append psp feature
|
||||
fpn_outs.append(laterals[-1])
|
||||
|
||||
for i in range(used_backbone_levels - 1, 0, -1):
|
||||
fpn_outs[i] = fpn_outs[i].type(torch.float32)
|
||||
fpn_outs[i] = resize(
|
||||
fpn_outs[i],
|
||||
size=fpn_outs[0].shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
fpn_outs = torch.cat(fpn_outs, dim=1)
|
||||
output = self.fpn_bottleneck(fpn_outs)
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
4
lib/models/losses/__init__.py
Normal file
4
lib/models/losses/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .modality_vae_loss import ModalityVAELoss
|
||||
from .recon_anno_loss import RecLoss
|
||||
|
||||
__all__ = [ "ModalityVAELoss", "RecLoss" ]
|
||||
46
lib/models/losses/modality_vae_loss.py
Normal file
46
lib/models/losses/modality_vae_loss.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# Copyright (c) Ant Group and its affiliates.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from antmmf.common.registry import registry
|
||||
|
||||
|
||||
@registry.register_loss("ModalityVAELoss")
|
||||
class ModalityVAELoss(nn.Module):
|
||||
def __init__(self, **params):
|
||||
super().__init__()
|
||||
self.weight = params.pop("weight")
|
||||
|
||||
def compute_rec_loss(self, x_in, x_out, modal_flag):
|
||||
loss_per_pixel = F.mse_loss(x_in, x_out, reduction='none')
|
||||
loss_b = torch.mean(loss_per_pixel, dim=[1, 2, 3])
|
||||
return torch.sum(loss_b * modal_flag)/ (modal_flag.sum() + 1e-6)
|
||||
|
||||
def forward(self, sample_list, output, *args, **kwargs):
|
||||
vae_out = output["vae_out"]
|
||||
feat_hr = vae_out['input_hr']
|
||||
feat_s2 = vae_out['input_s2']
|
||||
feat_s1 = vae_out['input_s1']
|
||||
|
||||
g_hr = vae_out['g_hr']
|
||||
g_s2 = vae_out['g_s2']
|
||||
g_s1 = vae_out['g_s1']
|
||||
|
||||
# process modality flags
|
||||
modality_info = vae_out['modality_info']
|
||||
B_M, L_M = modality_info.shape
|
||||
|
||||
modality_hr = modality_info[:,0]
|
||||
modality_s2 = modality_info[:,1]
|
||||
modality_s1 = modality_info[:,2]
|
||||
|
||||
######## rec losses ########
|
||||
loss_xent = self.compute_rec_loss(g_hr, feat_hr, modality_hr) \
|
||||
+ self.compute_rec_loss(g_s2, feat_s2, modality_s2) \
|
||||
+ self.compute_rec_loss(g_s1, feat_s1, modality_s1)
|
||||
|
||||
|
||||
loss_quant = vae_out["loss_quant"]
|
||||
total_loss = loss_xent / 3 + loss_quant
|
||||
return total_loss * self.weight
|
||||
89
lib/models/losses/recon_anno_loss.py
Normal file
89
lib/models/losses/recon_anno_loss.py
Normal file
@@ -0,0 +1,89 @@
|
||||
# Copyright (c) Ant Group and its affiliates.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from antmmf.common.registry import registry
|
||||
import torch.nn.functional as F
|
||||
|
||||
@registry.register_loss("RecLoss")
|
||||
class RecLoss(nn.Module):
|
||||
def __init__(self, **params):
|
||||
super().__init__()
|
||||
self.weight = params.pop("weight")
|
||||
self.patch_size = params.pop("patch_size")
|
||||
self.eps = torch.finfo(torch.bfloat16).eps
|
||||
self.pred_key = params.pop("pred_key")
|
||||
self.vocabulary_size = params.pop("vocabulary_size") + 1
|
||||
self.mask_key = params.pop("mask_key")
|
||||
self.target_key = params.pop("target_key")
|
||||
self.feature_merged = params.pop("feature_merged")
|
||||
self.cnt_train = 0
|
||||
self.cnt_val = 0
|
||||
self.use_bg = params.pop("use_bg")
|
||||
if "use_all_patch" in params:
|
||||
self.use_all_patch = params.pop("use_all_patch")
|
||||
else:
|
||||
self.use_all_patch = False
|
||||
if "balance" in params:
|
||||
self.balance = params.pop("balance")
|
||||
else:
|
||||
self.balance = False
|
||||
if "sim_regularization" in params:
|
||||
self.sim_regularization = params.pop("sim_regularization")
|
||||
else:
|
||||
self.sim_regularization = False
|
||||
|
||||
def unpatchify(self, x):
|
||||
"""
|
||||
x: (N, L, patch_size**2 *3)
|
||||
imgs: (N, 3, H, W)
|
||||
"""
|
||||
p = self.patch_size
|
||||
w = int((x.shape[1]*0.5)**.5)
|
||||
h = w * 2
|
||||
assert h * w == x.shape[1]
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, p))
|
||||
x = torch.einsum('nhwpq->nhpwq', x)
|
||||
imgs = x.reshape(shape=(x.shape[0], h * p, w * p))
|
||||
return imgs
|
||||
|
||||
|
||||
def forward(self, sample_list, output, *args, **kwargs):
|
||||
pred = output[self.pred_key] # B, C, H, W
|
||||
target = output[self.target_key] # B, H, W
|
||||
mask = output[self.mask_key]
|
||||
b_mask, h_mask, w_mask = mask.shape
|
||||
mask = mask.reshape((b_mask, h_mask*w_mask))
|
||||
mask = mask[:, :, None].repeat(1, 1, self.patch_size**2)
|
||||
mask = self.unpatchify(mask)
|
||||
|
||||
if not self.use_bg:
|
||||
valid = sample_list['valid']
|
||||
mask = mask * valid
|
||||
|
||||
loss = F.cross_entropy(pred, target, reduction="none")
|
||||
|
||||
if self.balance:
|
||||
if self.use_all_patch:
|
||||
loss_pos = loss[target > 0].sum() / ((target > 0).sum() + 1e-6)
|
||||
loss_neg = loss[target == 0].sum() / ((target == 0).sum() + 1e-6)
|
||||
loss = (loss_pos + loss_neg) * 0.5
|
||||
else:
|
||||
loss_pos = loss[(target > 0) & (mask == 1)].sum() / (((target > 0) & (mask == 1)).sum() + 1e-6)
|
||||
loss_neg = loss[(target == 0) & (mask == 1)].sum() / (((target == 0) & (mask == 1)).sum() + 1e-6)
|
||||
loss = (loss_pos + loss_neg) * 0.5
|
||||
else:
|
||||
if self.use_all_patch:
|
||||
loss = loss.mean()
|
||||
else:
|
||||
loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
|
||||
if self.sim_regularization:
|
||||
vocabulary_token = output['vocabulary_token']
|
||||
voca_normed = F.normalize(vocabulary_token, 2, 1)
|
||||
similarity_matrix = 1 + torch.einsum('nd,md->nm', voca_normed, voca_normed)
|
||||
num = voca_normed.shape[0]
|
||||
index = torch.triu(voca_normed.new_ones(num, num), diagonal=1).type(torch.bool)
|
||||
loss_reg = similarity_matrix[index].mean()
|
||||
return loss * self.weight + loss_reg * 0.05
|
||||
return loss * self.weight
|
||||
|
||||
4
lib/models/metrics/__init__.py
Normal file
4
lib/models/metrics/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .sem_metrics import SemMetric
|
||||
|
||||
__all__ = ["SemMetric"]
|
||||
|
||||
93
lib/models/metrics/sem_metrics.py
Normal file
93
lib/models/metrics/sem_metrics.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# coding: utf-8
|
||||
# Copyright (c) Ant Group. All rights reserved.
|
||||
import torch
|
||||
from torch.distributed import all_reduce, ReduceOp
|
||||
from antmmf.common.registry import registry
|
||||
from antmmf.modules.metrics.base_metric import BaseMetric
|
||||
|
||||
@registry.register_metric("sem_metric")
|
||||
class SemMetric(BaseMetric):
|
||||
"""Segmentation metrics used in evaluation phase.
|
||||
|
||||
Args:
|
||||
name (str): Name of the metric.
|
||||
eval_type(str): 3 types are supported: 'mIoU', 'mDice', 'mFscore'
|
||||
result_field(str): key of predicted results in output dict
|
||||
target_field(str): key of ground truth in output dict
|
||||
ignore_index(int): class value will be ignored in evaluation
|
||||
num_cls(int): total number of categories in evaluation
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
name="dummy_metric", **kwargs
|
||||
):
|
||||
super().__init__(name)
|
||||
self.reset()
|
||||
|
||||
def calculate(self, sample_list, model_output, *args, **kwargs):
|
||||
"""Calculate Intersection and Union for a batch.
|
||||
|
||||
Args:
|
||||
sample_list (Sample_List): data which contains ground truth segmentation maps
|
||||
model_output (dict): data which contains prediction segmentation maps
|
||||
Returns:
|
||||
torch.Tensor: The intersection of prediction and ground truth histogram
|
||||
on all classes.
|
||||
torch.Tensor: The union of prediction and ground truth histogram on all
|
||||
classes.
|
||||
torch.Tensor: The prediction histogram on all classes.
|
||||
torch.Tensor: The ground truth histogram on all classes.
|
||||
"""
|
||||
|
||||
return torch.tensor(0).float()
|
||||
|
||||
def reset(self):
|
||||
""" initialized all attributes value before evaluation
|
||||
|
||||
"""
|
||||
self.total_mask_mae = 0
|
||||
self.total_num = torch.tensor(0)
|
||||
|
||||
def collect(self, sample_list, model_output, *args, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
sample_list(Sample_List): data which contains ground truth segmentation maps
|
||||
model_output (Dict): Dict returned by model, that contains two modalities
|
||||
Returns:
|
||||
torch.FloatTensor: Accuracy
|
||||
"""
|
||||
batch_mask_mae = \
|
||||
self.calculate(sample_list, model_output, *args, **kwargs)
|
||||
self.total_mask_mae += batch_mask_mae
|
||||
self.total_num += 1
|
||||
|
||||
def format(self, *args):
|
||||
""" Format evaluated metrics for profile.
|
||||
|
||||
Returns:
|
||||
dict: dict of all evaluated metrics.
|
||||
"""
|
||||
output_metric = dict()
|
||||
# if self.eval_type == 'mae':
|
||||
mae = args[0]
|
||||
output_metric['mae'] = mae.item()
|
||||
return output_metric
|
||||
|
||||
def summarize(self, *args, **kwargs):
|
||||
"""This method is used to calculate the overall metric.
|
||||
|
||||
Returns:
|
||||
dict: dict of all evaluated metrics.
|
||||
|
||||
"""
|
||||
# if self.eval_type == 'mae':
|
||||
mae = self.total_mask_mae / (self.total_num)
|
||||
return self.format(mae)
|
||||
|
||||
def all_reduce(self):
|
||||
total_number = torch.stack([
|
||||
self.total_mask_mae, self.total_num
|
||||
]).cuda()
|
||||
all_reduce(total_number, op=ReduceOp.SUM)
|
||||
self.total_mask_mae = total_number[0].cpu()
|
||||
self.total_num = total_number[1].cpu()
|
||||
13
lib/models/necks/__init__.py
Normal file
13
lib/models/necks/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from .transformer_encoder import TransformerEncoder
|
||||
from .modality_completion import ModalityCompletion
|
||||
|
||||
__all__ = ['TransformerEncoder', 'ModalityCompletion']
|
||||
|
||||
type_mapping = {
|
||||
'TransformerEncoder': TransformerEncoder,
|
||||
'ModalityCompletion': ModalityCompletion
|
||||
}
|
||||
|
||||
|
||||
def build_neck(type, **kwargs):
|
||||
return type_mapping[type](**kwargs)
|
||||
212
lib/models/necks/modality_completion.py
Normal file
212
lib/models/necks/modality_completion.py
Normal file
@@ -0,0 +1,212 @@
|
||||
# Copyright (c) AntGroup. All rights reserved.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class BFloat16UpsampleNearest2d(nn.Module):
|
||||
def __init__(self, scale_factor, mode='bilinear'):
|
||||
super().__init__()
|
||||
self.scale_factor = scale_factor
|
||||
self.mode = mode
|
||||
|
||||
def forward(self, x):
|
||||
x_float = x.float()
|
||||
upsampled_x = F.interpolate(x_float, scale_factor=self.scale_factor, mode=self.mode)
|
||||
return upsampled_x.to(x.dtype)
|
||||
|
||||
class ConvVQVAEv2(nn.Module):
|
||||
def __init__(self, input_shape, conv_dim, z_dim, num_tokens=8192, temp=0.9):
|
||||
super().__init__()
|
||||
self.z_dim = z_dim
|
||||
self.conv_dim = conv_dim # 256
|
||||
self.input_shape = input_shape # 256
|
||||
self.temp = temp
|
||||
# code book
|
||||
self.codebook = nn.Embedding(num_tokens, z_dim)
|
||||
# encoder
|
||||
self.relu = nn.LeakyReLU()
|
||||
self.pool = nn.AvgPool2d(2)
|
||||
self.conv1 = nn.Conv2d(input_shape[0], conv_dim, 5, stride=1, padding=2)
|
||||
self.enc_block1 = nn.Sequential(
|
||||
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
|
||||
nn.LeakyReLU(),
|
||||
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
|
||||
nn.LeakyReLU(),
|
||||
)
|
||||
self.gamma_1 = nn.Parameter(0.001 * torch.ones((1, conv_dim, 1, 1)))
|
||||
self.enc_block2 = nn.Sequential(
|
||||
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
|
||||
nn.LeakyReLU(),
|
||||
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
|
||||
nn.LeakyReLU(),
|
||||
)
|
||||
self.gamma_2 = nn.Parameter(0.001 * torch.ones((1, conv_dim, 1, 1)))
|
||||
self.logit_conv = nn.Conv2d(conv_dim, num_tokens, 1)
|
||||
# decoder
|
||||
self.unpool = BFloat16UpsampleNearest2d(scale_factor=2)
|
||||
self.conv2 = nn.Conv2d(z_dim, conv_dim, 3, stride=1, padding=1)
|
||||
self.dec_block1 = nn.Sequential(
|
||||
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
|
||||
nn.LeakyReLU(),
|
||||
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
|
||||
nn.LeakyReLU(),
|
||||
)
|
||||
self.gamma_3 = nn.Parameter(0.001 * torch.ones((1, conv_dim, 1, 1)))
|
||||
self.dec_block2 = nn.Sequential(
|
||||
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
|
||||
nn.LeakyReLU(),
|
||||
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
|
||||
nn.LeakyReLU(),
|
||||
)
|
||||
self.gamma_4 = nn.Parameter(0.001 * torch.ones((1, conv_dim, 1, 1)))
|
||||
self.rec_conv = nn.Conv2d(conv_dim, input_shape[0], 3, stride=1, padding=1)
|
||||
|
||||
def forward_encoder(self, x):
|
||||
x = self.relu(self.conv1(x))
|
||||
x = x + self.gamma_1 * self.enc_block1(x)
|
||||
x = self.pool(x)
|
||||
x = x + self.gamma_2 * self.enc_block2(x)
|
||||
x = self.pool(x)
|
||||
logits = self.logit_conv(x)
|
||||
return logits
|
||||
|
||||
def forward_decoder(self, logits):
|
||||
soft_one_hot = F.softmax(logits * (self.temp*10), dim=1)
|
||||
sampled = torch.einsum('bnhw,nd->bdhw', soft_one_hot, self.codebook.weight)
|
||||
x = self.relu(self.conv2(sampled))
|
||||
x = self.unpool(x)
|
||||
x = x + self.gamma_3 * self.dec_block1(x)
|
||||
x = self.unpool(x)
|
||||
x = x + self.gamma_4 * self.dec_block2(x)
|
||||
rec_feats = self.rec_conv(x)
|
||||
return rec_feats, soft_one_hot
|
||||
|
||||
def forward(self, x):
|
||||
print(x.shape)
|
||||
logits = self.forward_encoder(x)
|
||||
images_p, soft_one_hot = self.forward_decoder(logits)
|
||||
return [logits, images_p]
|
||||
|
||||
class ModalityCompletion(nn.Module):
|
||||
def __init__(self,
|
||||
input_shape_hr=(2816, 16, 16),
|
||||
input_shape_s2=(2816, 16, 16),
|
||||
input_shape_s1=(2816, 16, 16),
|
||||
conv_dim=256,
|
||||
z_dim=256,
|
||||
n_codebook=8192,
|
||||
init_cfg=None
|
||||
):
|
||||
super(ModalityCompletion, self).__init__()
|
||||
self.vae_hr = ConvVQVAEv2(input_shape=input_shape_hr, conv_dim=conv_dim, z_dim=z_dim, num_tokens=n_codebook)
|
||||
self.vae_s2 = ConvVQVAEv2(input_shape=input_shape_s2, conv_dim=conv_dim, z_dim=z_dim, num_tokens=n_codebook)
|
||||
self.vae_s1 = ConvVQVAEv2(input_shape=input_shape_s1, conv_dim=conv_dim, z_dim=z_dim, num_tokens=n_codebook)
|
||||
self.kl_div_loss = torch.nn.KLDivLoss(reduction="none", log_target=True)
|
||||
self.init_cfg=init_cfg
|
||||
|
||||
def init_weights(self):
|
||||
if (isinstance(self.init_cfg, dict)
|
||||
and self.init_cfg['type'] == 'Pretrained'):
|
||||
# Suppress default init if use pretrained model.
|
||||
from mmcls.utils import get_root_logger
|
||||
from mmcv.runner import CheckpointLoader, load_state_dict
|
||||
logger = get_root_logger()
|
||||
checkpoint = CheckpointLoader.load_checkpoint(
|
||||
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
|
||||
|
||||
if 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
|
||||
load_state_dict(self, state_dict, strict=False, logger=logger)
|
||||
else:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def kl_loss(self, logits_hr, logits_s2, logits_s1, modality_info):
|
||||
prob_hr = F.log_softmax(logits_hr, dim=1)
|
||||
prob_s2 = F.log_softmax(logits_s2, dim=1)
|
||||
prob_s1 = F.log_softmax(logits_s1, dim=1)
|
||||
flag_hr = modality_info[:,0][:, None, None, None]
|
||||
flag_s2 = modality_info[:,1][:, None, None, None]
|
||||
flag_s1 = modality_info[:,2][:, None, None, None]
|
||||
loss_hr_s2 = self.kl_div_loss(prob_hr, prob_s2) + self.kl_div_loss(prob_s2, prob_hr)
|
||||
loss_hr_s2 = (loss_hr_s2 * flag_hr * flag_s2).sum((1, 2, 3)).mean()
|
||||
loss_hr_s1 = self.kl_div_loss(prob_hr, prob_s1) + self.kl_div_loss(prob_s1, prob_hr)
|
||||
loss_hr_s1 = (loss_hr_s1 * flag_hr * flag_s1).sum((1, 2, 3)).mean()
|
||||
loss_s2_s1 = self.kl_div_loss(prob_s2, prob_s1) + self.kl_div_loss(prob_s1, prob_s2)
|
||||
loss_s2_s1 = (loss_s2_s1 * flag_s2 * flag_s1).sum((1, 2, 3)).mean()
|
||||
loss = (loss_hr_s2 + loss_hr_s1 + loss_s2_s1) / 6.0
|
||||
|
||||
return loss
|
||||
|
||||
def forward(self, feat_hr, feat_s2, feat_s1, modality_info):
|
||||
# encoders,add noise
|
||||
# each modality
|
||||
# 2816, 16, 16 => conv 256, 4, 4 => flatten 4096(256*4*4) => linear mu 256, log_var 256
|
||||
B, C, H, W = feat_hr.shape
|
||||
B_M, L_M = modality_info.shape
|
||||
assert B == B_M, f'feat_hr batch: {B}, modality_info batch: {B_M}'
|
||||
|
||||
# quant, emb_loss, info
|
||||
# hr input flow
|
||||
logits_hr = self.vae_hr.forward_encoder(feat_hr)
|
||||
logits_s2 = self.vae_s2.forward_encoder(feat_s2)
|
||||
logits_s1 = self.vae_s1.forward_encoder(feat_s1)
|
||||
modality_hr = modality_info[:,0]
|
||||
modality_s2 = modality_info[:,1]
|
||||
modality_s1 = modality_info[:,2]
|
||||
flag_hr = modality_hr[:, None, None, None] # B => B, C, H, W
|
||||
flag_s2 = modality_s2[:, None, None, None]
|
||||
flag_s1 = modality_s1[:, None, None, None]
|
||||
|
||||
mean_logits_hr_s2 = logits_hr * flag_hr + logits_s2 * flag_s2
|
||||
mean_logits_hr_s1 = logits_hr * flag_hr + logits_s1 * flag_s1
|
||||
mean_logits_s1_s2 = logits_s1 * flag_s1 + logits_s2 * flag_s2
|
||||
|
||||
logits_hr_rec = logits_hr * flag_hr + mean_logits_s1_s2 * (~flag_hr)
|
||||
logits_s2_rec = logits_s2 * flag_s2 + mean_logits_hr_s1 * (~flag_s2)
|
||||
logits_s1_rec = logits_s1 * flag_s1 + mean_logits_hr_s2 * (~flag_s1)
|
||||
g_hr, soft_one_hot_hr = self.vae_hr.forward_decoder(logits_hr_rec)
|
||||
g_s2, soft_one_s2 = self.vae_s2.forward_decoder(logits_s2_rec)
|
||||
g_s1, soft_one_s1 = self.vae_s1.forward_decoder(logits_s1_rec)
|
||||
|
||||
hr_out = feat_hr * flag_hr + g_hr * (~flag_hr)
|
||||
s2_out = feat_s2 * flag_s2 + g_s2 * (~flag_s2)
|
||||
s1_out = feat_s1 * flag_s1 + g_s1 * (~flag_s1)
|
||||
|
||||
output = {}
|
||||
|
||||
output['hr_out'] = hr_out
|
||||
output['s2_out'] = s2_out
|
||||
output['s1_out'] = s1_out
|
||||
|
||||
output['modality_info'] = modality_info
|
||||
|
||||
output['input_hr'] = feat_hr
|
||||
output['input_s2'] = feat_s2
|
||||
output['input_s1'] = feat_s1
|
||||
|
||||
output['logits_hr'] = logits_hr
|
||||
output['logits_s2'] = logits_s2
|
||||
output['logits_s1'] = logits_s1
|
||||
|
||||
output['soft_one_hot_hr'] = soft_one_hot_hr
|
||||
output['soft_one_hot_s2'] = soft_one_s2
|
||||
output['soft_one_hot_s1'] = soft_one_s1
|
||||
|
||||
output['g_hr'] = g_hr
|
||||
output['g_s2'] = g_s2
|
||||
output['g_s1'] = g_s1
|
||||
output['loss_quant'] = self.kl_loss(logits_hr, logits_s2, logits_s1, modality_info)
|
||||
|
||||
return output
|
||||
|
||||
144
lib/models/necks/transformer_encoder.py
Normal file
144
lib/models/necks/transformer_encoder.py
Normal file
@@ -0,0 +1,144 @@
|
||||
# Copyright (c) Ant Group. All rights reserved.
|
||||
from collections import OrderedDict
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
|
||||
trunc_normal_)
|
||||
from mmcv.runner import (CheckpointLoader, load_state_dict)
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
from mmseg.models.backbones.vit import TransformerEncoderLayer
|
||||
|
||||
from mmseg.utils import get_root_logger
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
input_dims=768,
|
||||
embed_dims=768,
|
||||
num_layers=4,
|
||||
num_heads=12,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
with_cls_token=True,
|
||||
output_cls_token=False,
|
||||
norm_cfg=dict(type='LN'),
|
||||
act_cfg=dict(type='GELU'),
|
||||
num_fcs=2,
|
||||
norm_eval=False,
|
||||
with_cp=False,
|
||||
init_cfg=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
super(TransformerEncoder, self).__init__()
|
||||
|
||||
self.porj_linear = nn.Linear(input_dims, embed_dims)
|
||||
if output_cls_token:
|
||||
assert with_cls_token is True, f'with_cls_token must be True if' \
|
||||
f'set output_cls_token to True, but got {with_cls_token}'
|
||||
|
||||
self.init_cfg = init_cfg
|
||||
self.norm_eval = norm_eval
|
||||
self.with_cp = with_cp
|
||||
self.with_cls_token = with_cls_token
|
||||
self.output_cls_token = output_cls_token
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
|
||||
self.drop_after_pos = nn.Dropout(p=drop_rate)
|
||||
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, num_layers)
|
||||
] # stochastic depth decay rule
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
for i in range(num_layers):
|
||||
self.layers.append(
|
||||
TransformerEncoderLayer(embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
feedforward_channels=mlp_ratio *
|
||||
embed_dims,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_rate=drop_rate,
|
||||
drop_path_rate=dpr[i],
|
||||
num_fcs=num_fcs,
|
||||
qkv_bias=qkv_bias,
|
||||
act_cfg=act_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
with_cp=with_cp,
|
||||
batch_first=True))
|
||||
|
||||
def init_weights(self):
|
||||
if (isinstance(self.init_cfg, dict)
|
||||
and self.init_cfg.get('type') == 'Pretrained'):
|
||||
logger = get_root_logger()
|
||||
checkpoint = CheckpointLoader.load_checkpoint(
|
||||
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
|
||||
|
||||
if 'state_dict' in checkpoint:
|
||||
_state_dict = checkpoint['state_dict']
|
||||
else:
|
||||
_state_dict = checkpoint
|
||||
|
||||
state_dict = OrderedDict()
|
||||
for k, v in _state_dict.items():
|
||||
if k.startswith('backbone.'):
|
||||
state_dict[k[9:]] = v
|
||||
else:
|
||||
state_dict[k] = v
|
||||
|
||||
load_state_dict(self, state_dict, strict=False, logger=logger)
|
||||
else:
|
||||
# We only implement the 'jax_impl' initialization implemented at
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
for n, m in self.named_modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if m.bias is not None:
|
||||
if 'ffn' in n:
|
||||
nn.init.normal_(m.bias, mean=0., std=1e-6)
|
||||
else:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
kaiming_init(m, mode='fan_in', bias=0.)
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
|
||||
constant_init(m, val=1.0, bias=0.)
|
||||
|
||||
def forward(self, inputs, require_feat: bool = False, require_two: bool = False):
|
||||
inputs = self.porj_linear(inputs)
|
||||
B, N, C = inputs.shape
|
||||
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, inputs), dim=1)
|
||||
if not self.with_cls_token:
|
||||
# Remove class token for transformer encoder input
|
||||
x = x[:, 1:]
|
||||
|
||||
# add hidden and atten state
|
||||
block_outs = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
if require_feat:
|
||||
block_outs.append(x)
|
||||
|
||||
if self.output_cls_token:
|
||||
if require_two:
|
||||
x = x[:, :2]
|
||||
else:
|
||||
x = x[:, 0]
|
||||
elif not self.output_cls_token and self.with_cls_token:
|
||||
x = x # [:, :]
|
||||
|
||||
if require_feat:
|
||||
return x, block_outs
|
||||
else:
|
||||
return x
|
||||
|
||||
def train(self, mode=True):
|
||||
super(TransformerEncoder, self).train(mode)
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.LayerNorm):
|
||||
m.eval()
|
||||
4
lib/models/segmentors/__init__.py
Normal file
4
lib/models/segmentors/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Ant Financial Service Group and its affiliates.
|
||||
from .skysense_pp_pipeline import SkySensePP
|
||||
|
||||
__all__ = ['SkySensePP']
|
||||
458
lib/models/segmentors/skysense_pp_pipeline.py
Normal file
458
lib/models/segmentors/skysense_pp_pipeline.py
Normal file
@@ -0,0 +1,458 @@
|
||||
# coding: utf-8
|
||||
# Copyright (c) Ant Group. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import optim
|
||||
import torch.nn.functional as F
|
||||
|
||||
import math
|
||||
import random
|
||||
from antmmf.common.registry import registry
|
||||
from antmmf.models.base_model import BaseModel
|
||||
from lib.models.backbones import build_backbone
|
||||
from lib.models.necks import build_neck
|
||||
from lib.models.heads import build_head
|
||||
from lib.utils.utils import LayerDecayValueAssigner
|
||||
|
||||
|
||||
@registry.register_model("SkySensePP")
|
||||
class SkySensePP(BaseModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.sources = config.sources
|
||||
assert len(self.sources) > 0, 'at least one data source is required'
|
||||
if 's2' in self.sources:
|
||||
self.use_ctpe = config.use_ctpe
|
||||
self.use_modal_vae = config.use_modal_vae
|
||||
self.use_cls_token_uper_head = config.use_cls_token_uper_head
|
||||
self.target_mean=[0.485, 0.456, 0.406]
|
||||
self.target_std=[0.229, 0.224, 0.225]
|
||||
self.vocabulary_size = config.vocabulary_size
|
||||
self.vocabulary = list(range(1, config.vocabulary_size + 1)) # 0 for ignore
|
||||
|
||||
|
||||
def build(self):
|
||||
if 'hr' in self.sources:
|
||||
self.backbone_hr = self._build_backbone('hr')
|
||||
if 's2' in self.sources:
|
||||
self.backbone_s2 = self._build_backbone('s2')
|
||||
if self.use_ctpe:
|
||||
self.ctpe = nn.Parameter(
|
||||
torch.zeros(1, self.config.calendar_time,
|
||||
self.config.necks.input_dims))
|
||||
if 'head_s2' in self.config.keys():
|
||||
self.head_s2 = self._build_head('head_s2')
|
||||
self.fusion = self._build_neck('necks')
|
||||
if 's1' in self.sources:
|
||||
self.backbone_s1 = self._build_backbone('s1')
|
||||
if 'head_s1' in self.config.keys():
|
||||
self.head_s1 = self._build_head('head_s1')
|
||||
self.head_rec_hr = self._build_head('rec_head_hr')
|
||||
|
||||
self.with_aux_head = False
|
||||
|
||||
if self.use_modal_vae:
|
||||
self.modality_vae = self._build_neck('modality_vae')
|
||||
if 'auxiliary_head' in self.config.keys():
|
||||
self.with_aux_head = True
|
||||
self.aux_head = self._build_head('auxiliary_head')
|
||||
if 'init_cfg' in self.config.keys(
|
||||
) and self.config.init_cfg is not None and self.config.init_cfg.checkpoint is not None and self.config.init_cfg.key is not None:
|
||||
self.load_pretrained(self.config.init_cfg.checkpoint,
|
||||
self.config.init_cfg.key)
|
||||
|
||||
def _build_backbone(self, key):
|
||||
config_dict = self.config[f'backbone_{key}'].to_dict()
|
||||
backbone_type = config_dict.pop('type')
|
||||
backbone = build_backbone(backbone_type, **config_dict)
|
||||
backbone.init_weights()
|
||||
return backbone
|
||||
|
||||
def _build_neck(self, key):
|
||||
config_dict = self.config[key].to_dict()
|
||||
neck_type = config_dict.pop('type')
|
||||
neck = build_neck(neck_type, **config_dict)
|
||||
neck.init_weights()
|
||||
return neck
|
||||
|
||||
def _build_head(self, key):
|
||||
head_config = self.config[key].to_dict()
|
||||
head_type = head_config.pop('type')
|
||||
head = build_head(head_type, **head_config)
|
||||
return head
|
||||
|
||||
def get_optimizer_parameters(self, config):
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [],
|
||||
"lr": config.optimizer_attributes.params.lr,
|
||||
"weight_decay": config.optimizer_attributes.params.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [],
|
||||
"lr": config.optimizer_attributes.params.lr,
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
layer_decay_value_assigner_hr = LayerDecayValueAssigner(
|
||||
config.lr_parameters.layer_decay, None,
|
||||
config.optimizer_attributes.params.lr, 'swin',
|
||||
config.model_attributes.SkySensePP.backbone_hr.arch
|
||||
)
|
||||
layer_decay_value_assigner_s2 = LayerDecayValueAssigner(
|
||||
config.lr_parameters.layer_decay, 24,
|
||||
config.optimizer_attributes.params.lr, 'vit',
|
||||
)
|
||||
layer_decay_value_assigner_s1 = LayerDecayValueAssigner(
|
||||
config.lr_parameters.layer_decay, 24,
|
||||
config.optimizer_attributes.params.lr, 'vit',
|
||||
)
|
||||
layer_decay_value_assigner_fusion = LayerDecayValueAssigner(
|
||||
config.lr_parameters.layer_decay, 24,
|
||||
config.optimizer_attributes.params.lr, 'vit',
|
||||
)
|
||||
num_frozen_params = 0
|
||||
if 'hr' in self.sources:
|
||||
print('hr'.center(60, '-'))
|
||||
num_frozen_params += layer_decay_value_assigner_hr.fix_param(
|
||||
self.backbone_hr,
|
||||
config.lr_parameters.frozen_blocks,
|
||||
)
|
||||
optimizer_grouped_parameters.extend(
|
||||
layer_decay_value_assigner_hr.get_parameter_groups(
|
||||
self.backbone_hr, config.optimizer_attributes.params.weight_decay
|
||||
)
|
||||
)
|
||||
if 's2' in self.sources:
|
||||
print('s2'.center(60, '-'))
|
||||
num_frozen_params += layer_decay_value_assigner_s2.fix_param(
|
||||
self.backbone_s2,
|
||||
config.lr_parameters.frozen_blocks,
|
||||
)
|
||||
optimizer_grouped_parameters.extend(
|
||||
layer_decay_value_assigner_s2.get_parameter_groups(
|
||||
self.backbone_s2, config.optimizer_attributes.params.weight_decay
|
||||
)
|
||||
)
|
||||
no_decay = [".bn.", "bias"]
|
||||
optimizer_grouped_parameters[0]["params"] += [
|
||||
p for n, p in self.head_s2.named_parameters()
|
||||
if not any(nd in n for nd in no_decay)
|
||||
]
|
||||
optimizer_grouped_parameters[1]["params"] += [
|
||||
p for n, p in self.head_s2.named_parameters()
|
||||
if any(nd in n for nd in no_decay)
|
||||
]
|
||||
if self.use_ctpe:
|
||||
optimizer_grouped_parameters[1]["params"] += [self.ctpe]
|
||||
|
||||
if 's1' in self.sources:
|
||||
print('s1'.center(60, '-'))
|
||||
num_frozen_params += layer_decay_value_assigner_s1.fix_param(
|
||||
self.backbone_s1,
|
||||
config.lr_parameters.frozen_blocks,
|
||||
)
|
||||
optimizer_grouped_parameters.extend(
|
||||
layer_decay_value_assigner_s1.get_parameter_groups(
|
||||
self.backbone_s1, config.optimizer_attributes.params.weight_decay
|
||||
)
|
||||
)
|
||||
no_decay = [".bn.", "bias"]
|
||||
optimizer_grouped_parameters[0]["params"] += [
|
||||
p for n, p in self.head_s1.named_parameters()
|
||||
if not any(nd in n for nd in no_decay)
|
||||
]
|
||||
optimizer_grouped_parameters[1]["params"] += [
|
||||
p for n, p in self.head_s1.named_parameters()
|
||||
if any(nd in n for nd in no_decay)
|
||||
]
|
||||
|
||||
if len(self.sources) > 1:
|
||||
print('fusion'.center(60, '-'))
|
||||
num_frozen_params += layer_decay_value_assigner_fusion.fix_param_deeper(
|
||||
self.fusion,
|
||||
config.lr_parameters.frozen_fusion_blocks_start, # 冻结后面所有的stage
|
||||
)
|
||||
optimizer_grouped_parameters.extend(
|
||||
layer_decay_value_assigner_fusion.get_parameter_groups(
|
||||
self.fusion, config.optimizer_attributes.params.weight_decay
|
||||
)
|
||||
)
|
||||
|
||||
if self.use_modal_vae:
|
||||
no_decay = [".bn.", "bias"]
|
||||
optimizer_grouped_parameters[0]["params"] += [
|
||||
p for n, p in self.modality_vae.named_parameters()
|
||||
if not any(nd in n for nd in no_decay)
|
||||
]
|
||||
optimizer_grouped_parameters[1]["params"] += [
|
||||
p for n, p in self.modality_vae.named_parameters()
|
||||
if any(nd in n for nd in no_decay)
|
||||
]
|
||||
|
||||
no_decay = [".bn.", "bias"]
|
||||
optimizer_grouped_parameters[0]["params"] += [
|
||||
p for n, p in self.head_rec_hr.named_parameters()
|
||||
if not any(nd in n for nd in no_decay)
|
||||
]
|
||||
optimizer_grouped_parameters[1]["params"] += [
|
||||
p for n, p in self.head_rec_hr.named_parameters()
|
||||
if any(nd in n for nd in no_decay)
|
||||
]
|
||||
|
||||
if self.with_aux_head:
|
||||
no_decay = [".bn.", "bias"]
|
||||
optimizer_grouped_parameters[0]["params"] += [
|
||||
p for n, p in self.aux_head.named_parameters()
|
||||
if not any(nd in n for nd in no_decay)
|
||||
]
|
||||
optimizer_grouped_parameters[1]["params"] += [
|
||||
p for n, p in self.aux_head.named_parameters()
|
||||
if any(nd in n for nd in no_decay)
|
||||
]
|
||||
num_params = [len(x['params']) for x in optimizer_grouped_parameters]
|
||||
print(len(list(self.parameters())), sum(num_params), num_frozen_params)
|
||||
assert len(list(self.parameters())) == sum(num_params) + num_frozen_params
|
||||
return optimizer_grouped_parameters
|
||||
|
||||
def get_custom_scheduler(self, trainer):
|
||||
optimizer = trainer.optimizer
|
||||
num_training_steps = trainer.config.training_parameters.max_iterations
|
||||
num_warmup_steps = trainer.config.training_parameters.num_warmup_steps
|
||||
|
||||
if "train" in trainer.run_type:
|
||||
if num_training_steps == math.inf:
|
||||
epoches = trainer.config.training_parameters.max_epochs
|
||||
assert epoches != math.inf
|
||||
num_training_steps = trainer.config.training_parameters.max_epochs * trainer.epoch_iterations
|
||||
|
||||
def linear_with_wram_up(current_step):
|
||||
if current_step < num_warmup_steps:
|
||||
return float(current_step) / float(max(
|
||||
1, num_warmup_steps))
|
||||
return max(
|
||||
0.0,
|
||||
float(num_training_steps - current_step) /
|
||||
float(max(1, num_training_steps - num_warmup_steps)),
|
||||
)
|
||||
|
||||
def cos_with_wram_up(current_step):
|
||||
num_cycles = 0.5
|
||||
if current_step < num_warmup_steps:
|
||||
return float(current_step) / float(max(
|
||||
1, num_warmup_steps))
|
||||
progress = float(current_step - num_warmup_steps) / float(
|
||||
max(1, num_training_steps - num_warmup_steps))
|
||||
return max(
|
||||
0.0,
|
||||
0.5 *
|
||||
(1.0 +
|
||||
math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
|
||||
)
|
||||
|
||||
lr_lambda = cos_with_wram_up if trainer.config.training_parameters.cos_lr else linear_with_wram_up
|
||||
|
||||
else:
|
||||
def lr_lambda(current_step):
|
||||
return 0.0 # noqa
|
||||
|
||||
return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
def convert_target(self, target):
|
||||
mean = target.new_tensor(self.target_mean).reshape(1, 3, 1, 1)
|
||||
std = target.new_tensor(self.target_std).reshape(1, 3, 1, 1)
|
||||
target = ((target * std + mean)*255).to(torch.long)
|
||||
target[:, 0] = target[:, 0] * 256 * 256
|
||||
target[:, 1] = target[:, 1] * 256
|
||||
target = target.sum(1).type(torch.long)
|
||||
unique_target = target.unique()
|
||||
target_index = torch.searchsorted(unique_target, target)
|
||||
no_bg = False
|
||||
if unique_target[0].item() > 0:
|
||||
target_index += 1
|
||||
no_bg = True
|
||||
target_index_unique = target_index.unique().tolist()
|
||||
random.shuffle(self.vocabulary)
|
||||
value = target.new_tensor([0] + self.vocabulary)
|
||||
mapped_target = target_index.clone()
|
||||
idx_2_color = {}
|
||||
for v in target_index_unique:
|
||||
mapped_target[target_index == v] = value[v]
|
||||
idx_2_color[value[v].item()] = unique_target[v - 1 if no_bg else v].item()
|
||||
return mapped_target, idx_2_color
|
||||
|
||||
def forward(self, sample_list):
|
||||
output = dict()
|
||||
modality_flag_hr = sample_list["modality_flag_hr"]
|
||||
modality_flag_s2 = sample_list["modality_flag_s2"]
|
||||
modality_flag_s1 = sample_list["modality_flag_s1"]
|
||||
modalities = [modality_flag_hr, modality_flag_s2, modality_flag_s1]
|
||||
modalities = torch.tensor(modalities).permute(1,0).contiguous() # L, B => B, L
|
||||
|
||||
anno_img = sample_list["targets"]
|
||||
anno_img, idx_2_color = self.convert_target(anno_img)
|
||||
output["mapped_targets"] = anno_img
|
||||
output["idx_2_color"] = idx_2_color
|
||||
anno_mask = sample_list["anno_mask"]
|
||||
anno_s2 = anno_img[:, 15::32, 15::32]
|
||||
anno_s1 = anno_s2
|
||||
|
||||
output["anno_hr"] = anno_img
|
||||
output["anno_s2"] = anno_s2
|
||||
|
||||
### 1. backbone
|
||||
if 'hr' in self.sources:
|
||||
hr_img = sample_list["hr_img"]
|
||||
B_MASK, H_MASK, W_MASK = anno_mask.shape
|
||||
block_size = 32
|
||||
anno_mask_hr = anno_mask.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 1, block_size, block_size)
|
||||
anno_mask_hr = anno_mask_hr.permute(0, 1, 3, 2, 4).reshape(B_MASK, H_MASK*block_size, W_MASK*block_size).contiguous()
|
||||
B, C_G, H_G, W_G = hr_img.shape
|
||||
hr_features = self.backbone_hr(hr_img, anno_img, anno_mask_hr)
|
||||
output['mask_hr'] = anno_mask_hr
|
||||
output['target_hr'] = anno_img
|
||||
|
||||
if 's2' in self.sources:
|
||||
s2_img = sample_list["s2_img"]
|
||||
B, C_S2, S_S2, H_S2, W_S2 = s2_img.shape
|
||||
s2_img = s2_img.permute(0, 2, 1, 3,
|
||||
4).reshape(B * S_S2, C_S2, H_S2, W_S2).contiguous() # ts time to batch
|
||||
anno_mask_s2 = anno_mask
|
||||
s2_features = self.backbone_s2(s2_img, anno_s2, anno_mask_s2)
|
||||
if 'head_s2' in self.config.keys():
|
||||
s2_features = self.head_s2(s2_features[-1])
|
||||
s2_features = [s2_features]
|
||||
|
||||
if 's1' in self.sources:
|
||||
s1_img = sample_list["s1_img"]
|
||||
B, C_S1, S_S1, H_S1, W_S1 = s1_img.shape
|
||||
s1_img = s1_img.permute(0, 2, 1, 3,
|
||||
4).reshape(B * S_S1, C_S1, H_S1, W_S1).contiguous()
|
||||
|
||||
anno_mask_s1 = anno_mask
|
||||
s1_features = self.backbone_s1(s1_img, anno_s1, anno_mask_s1)
|
||||
if 'head_s1' in self.config.keys():
|
||||
s1_features = self.head_s1(s1_features[-1])
|
||||
s1_features = [s1_features]
|
||||
|
||||
### 2. prepare features for fusion
|
||||
hr_features_stage3 = hr_features[-1]
|
||||
s2_features_stage3 = s2_features[-1]
|
||||
s1_features_stage3 = s1_features[-1]
|
||||
modalities = modalities.to(hr_features_stage3.device)
|
||||
if self.use_modal_vae:
|
||||
vae_out = self.modality_vae(hr_features_stage3, s2_features_stage3, s1_features_stage3, modalities)
|
||||
hr_features_stage3 = vae_out['hr_out']
|
||||
s2_features_stage3 = vae_out['s2_out']
|
||||
s1_features_stage3 = vae_out['s1_out']
|
||||
output['vae_out'] = vae_out
|
||||
|
||||
features_stage3 = []
|
||||
if 'hr' in self.sources:
|
||||
B, C3_G, H3_G, W3_G = hr_features_stage3.shape
|
||||
hr_features_stage3 = hr_features_stage3.permute(
|
||||
0, 2, 3, 1).reshape(B * H3_G * W3_G, C3_G).unsqueeze(1).contiguous() # B * H3_G * W3_G, 1, C3_G
|
||||
features_stage3 = hr_features_stage3
|
||||
|
||||
if 's2' in self.sources:
|
||||
# s2_features_stage3 = s2_features[-1]
|
||||
_, C3_S2, H3_S2, W3_S2 = s2_features_stage3.shape
|
||||
s2_features_stage3 = s2_features_stage3.reshape(
|
||||
B, S_S2, C3_S2, H3_S2,
|
||||
W3_S2).permute(0, 3, 4, 1, 2).reshape(B, H3_S2 * W3_S2, S_S2,
|
||||
C3_S2).contiguous()
|
||||
if self.use_ctpe:
|
||||
ct_index = sample_list["s2_ct"]
|
||||
ctpe = self.ctpe[:, ct_index, :].contiguous().permute(1, 0, 2, 3).contiguous()
|
||||
ctpe = ctpe.expand(-1, 256, -1, -1)
|
||||
|
||||
ct_index_2 = sample_list["s2_ct2"]
|
||||
ctpe2 = self.ctpe[:, ct_index_2, :].contiguous().permute(1, 0, 2, 3).contiguous()
|
||||
ctpe2 = ctpe2.expand(-1, 256, -1, -1)
|
||||
|
||||
ctpe_comb = torch.cat([ctpe, ctpe2], 1)
|
||||
# import pdb;pdb.set_trace()
|
||||
s2_features_stage3 = (s2_features_stage3 + ctpe_comb).reshape(
|
||||
B * H3_S2 * W3_S2, S_S2, C3_S2).contiguous()
|
||||
else:
|
||||
s2_features_stage3 = s2_features_stage3.reshape(
|
||||
B * H3_S2 * W3_S2, S_S2, C3_S2).contiguous()
|
||||
|
||||
if len(features_stage3) > 0:
|
||||
assert H3_G == H3_S2 and W3_G == W3_S2 and C3_G == C3_S2
|
||||
features_stage3 = torch.cat((features_stage3, s2_features_stage3), dim=1)
|
||||
else:
|
||||
features_stage3 = s2_features_stage3
|
||||
|
||||
if 's1' in self.sources:
|
||||
# s1_features_stage3 = s1_features[-1]
|
||||
_, C3_S1, H3_S1, W3_S1 = s1_features_stage3.shape
|
||||
s1_features_stage3 = s1_features_stage3.reshape(
|
||||
B, S_S1, C3_S1, H3_S1,
|
||||
W3_S1).permute(0, 3, 4, 1, 2).reshape(B, H3_S1 * W3_S1, S_S1,
|
||||
C3_S1).contiguous()
|
||||
s1_features_stage3 = s1_features_stage3.reshape(
|
||||
B * H3_S1 * W3_S1, S_S1, C3_S1).contiguous()
|
||||
|
||||
if len(features_stage3) > 0:
|
||||
assert H3_S1 == H3_S2 and W3_S1 == W3_S2 and C3_S1 == C3_S2
|
||||
features_stage3 = torch.cat((features_stage3, s1_features_stage3),
|
||||
dim=1)
|
||||
else:
|
||||
features_stage3 = s1_features_stage3
|
||||
|
||||
### 3. fusion
|
||||
if self.config.necks.output_cls_token:
|
||||
if self.config.necks.get('require_feat', False):
|
||||
cls_token, block_outs = self.fusion(features_stage3 , True)
|
||||
else:
|
||||
cls_token = self.fusion(features_stage3)
|
||||
_, C3_G = cls_token.shape
|
||||
cls_token = cls_token.reshape(B, H3_G, W3_G,
|
||||
C3_G).contiguous().permute(0, 3, 1, 2).contiguous() # b, c, h, w
|
||||
else:
|
||||
assert self.config.necks.with_cls_token is False
|
||||
if self.config.necks.get('require_feat', False):
|
||||
features_stage3, block_outs = self.fusion(features_stage3, True)
|
||||
else:
|
||||
features_stage3 = self.fusion(features_stage3)
|
||||
features_stage3 = features_stage3.reshape(
|
||||
B, H3_S2, W3_S2, S_S2,
|
||||
C3_S2).permute(0, 3, 4, 1, 2).reshape(B * S_S2, C3_S2, H3_S2,
|
||||
W3_S2).contiguous()
|
||||
### 4. decoder for rec
|
||||
hr_rec_inputs = hr_features
|
||||
feat_stage1 = hr_rec_inputs[0]
|
||||
|
||||
if feat_stage1.shape[-1] == feat_stage1.shape[-2]:
|
||||
feat_stage1_left, feat_stage1_right = torch.split(feat_stage1, feat_stage1.shape[-1] // 2, dim=-1)
|
||||
feat_stage1 = torch.cat((feat_stage1_left, feat_stage1_right), dim=1)
|
||||
hr_rec_inputs = list(hr_features)
|
||||
hr_rec_inputs[0] = feat_stage1
|
||||
|
||||
rec_feats = [*hr_rec_inputs, cls_token]
|
||||
logits_hr = self.head_rec_hr(rec_feats)
|
||||
if self.config.get('upsacle_results', True):
|
||||
logits_hr = logits_hr.to(torch.float32)
|
||||
logits_hr = F.interpolate(logits_hr, scale_factor=4, mode='bilinear', align_corners=True)
|
||||
output["logits_hr"] = logits_hr
|
||||
return output
|
||||
|
||||
def load_pretrained(self, ckpt_path, key):
|
||||
pretrained_dict = torch.load(ckpt_path, map_location={'cuda:0': 'cpu'})
|
||||
pretrained_dict = pretrained_dict[key]
|
||||
for k, v in pretrained_dict.items():
|
||||
if k == 'backbone_s2.patch_embed.projection.weight':
|
||||
pretrained_in_channels = v.shape[1]
|
||||
if self.config.backbone_s2.in_channels == 4:
|
||||
new_weight = v[:, [0, 1, 2, 6]]
|
||||
new_weight = new_weight * (
|
||||
pretrained_in_channels /
|
||||
self.config.backbone_s2.in_channels)
|
||||
pretrained_dict[k] = new_weight
|
||||
missing_keys, unexpected_keys = self.load_state_dict(pretrained_dict,
|
||||
strict=False)
|
||||
print('missing_keys:', missing_keys)
|
||||
print('unexpected_keys:', unexpected_keys)
|
||||
|
||||
Reference in New Issue
Block a user