init
This commit is contained in:
458
lib/models/segmentors/skysense_pp_pipeline.py
Normal file
458
lib/models/segmentors/skysense_pp_pipeline.py
Normal file
@@ -0,0 +1,458 @@
|
||||
# coding: utf-8
|
||||
# Copyright (c) Ant Group. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import optim
|
||||
import torch.nn.functional as F
|
||||
|
||||
import math
|
||||
import random
|
||||
from antmmf.common.registry import registry
|
||||
from antmmf.models.base_model import BaseModel
|
||||
from lib.models.backbones import build_backbone
|
||||
from lib.models.necks import build_neck
|
||||
from lib.models.heads import build_head
|
||||
from lib.utils.utils import LayerDecayValueAssigner
|
||||
|
||||
|
||||
@registry.register_model("SkySensePP")
|
||||
class SkySensePP(BaseModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.sources = config.sources
|
||||
assert len(self.sources) > 0, 'at least one data source is required'
|
||||
if 's2' in self.sources:
|
||||
self.use_ctpe = config.use_ctpe
|
||||
self.use_modal_vae = config.use_modal_vae
|
||||
self.use_cls_token_uper_head = config.use_cls_token_uper_head
|
||||
self.target_mean=[0.485, 0.456, 0.406]
|
||||
self.target_std=[0.229, 0.224, 0.225]
|
||||
self.vocabulary_size = config.vocabulary_size
|
||||
self.vocabulary = list(range(1, config.vocabulary_size + 1)) # 0 for ignore
|
||||
|
||||
|
||||
def build(self):
|
||||
if 'hr' in self.sources:
|
||||
self.backbone_hr = self._build_backbone('hr')
|
||||
if 's2' in self.sources:
|
||||
self.backbone_s2 = self._build_backbone('s2')
|
||||
if self.use_ctpe:
|
||||
self.ctpe = nn.Parameter(
|
||||
torch.zeros(1, self.config.calendar_time,
|
||||
self.config.necks.input_dims))
|
||||
if 'head_s2' in self.config.keys():
|
||||
self.head_s2 = self._build_head('head_s2')
|
||||
self.fusion = self._build_neck('necks')
|
||||
if 's1' in self.sources:
|
||||
self.backbone_s1 = self._build_backbone('s1')
|
||||
if 'head_s1' in self.config.keys():
|
||||
self.head_s1 = self._build_head('head_s1')
|
||||
self.head_rec_hr = self._build_head('rec_head_hr')
|
||||
|
||||
self.with_aux_head = False
|
||||
|
||||
if self.use_modal_vae:
|
||||
self.modality_vae = self._build_neck('modality_vae')
|
||||
if 'auxiliary_head' in self.config.keys():
|
||||
self.with_aux_head = True
|
||||
self.aux_head = self._build_head('auxiliary_head')
|
||||
if 'init_cfg' in self.config.keys(
|
||||
) and self.config.init_cfg is not None and self.config.init_cfg.checkpoint is not None and self.config.init_cfg.key is not None:
|
||||
self.load_pretrained(self.config.init_cfg.checkpoint,
|
||||
self.config.init_cfg.key)
|
||||
|
||||
def _build_backbone(self, key):
|
||||
config_dict = self.config[f'backbone_{key}'].to_dict()
|
||||
backbone_type = config_dict.pop('type')
|
||||
backbone = build_backbone(backbone_type, **config_dict)
|
||||
backbone.init_weights()
|
||||
return backbone
|
||||
|
||||
def _build_neck(self, key):
|
||||
config_dict = self.config[key].to_dict()
|
||||
neck_type = config_dict.pop('type')
|
||||
neck = build_neck(neck_type, **config_dict)
|
||||
neck.init_weights()
|
||||
return neck
|
||||
|
||||
def _build_head(self, key):
|
||||
head_config = self.config[key].to_dict()
|
||||
head_type = head_config.pop('type')
|
||||
head = build_head(head_type, **head_config)
|
||||
return head
|
||||
|
||||
def get_optimizer_parameters(self, config):
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [],
|
||||
"lr": config.optimizer_attributes.params.lr,
|
||||
"weight_decay": config.optimizer_attributes.params.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [],
|
||||
"lr": config.optimizer_attributes.params.lr,
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
layer_decay_value_assigner_hr = LayerDecayValueAssigner(
|
||||
config.lr_parameters.layer_decay, None,
|
||||
config.optimizer_attributes.params.lr, 'swin',
|
||||
config.model_attributes.SkySensePP.backbone_hr.arch
|
||||
)
|
||||
layer_decay_value_assigner_s2 = LayerDecayValueAssigner(
|
||||
config.lr_parameters.layer_decay, 24,
|
||||
config.optimizer_attributes.params.lr, 'vit',
|
||||
)
|
||||
layer_decay_value_assigner_s1 = LayerDecayValueAssigner(
|
||||
config.lr_parameters.layer_decay, 24,
|
||||
config.optimizer_attributes.params.lr, 'vit',
|
||||
)
|
||||
layer_decay_value_assigner_fusion = LayerDecayValueAssigner(
|
||||
config.lr_parameters.layer_decay, 24,
|
||||
config.optimizer_attributes.params.lr, 'vit',
|
||||
)
|
||||
num_frozen_params = 0
|
||||
if 'hr' in self.sources:
|
||||
print('hr'.center(60, '-'))
|
||||
num_frozen_params += layer_decay_value_assigner_hr.fix_param(
|
||||
self.backbone_hr,
|
||||
config.lr_parameters.frozen_blocks,
|
||||
)
|
||||
optimizer_grouped_parameters.extend(
|
||||
layer_decay_value_assigner_hr.get_parameter_groups(
|
||||
self.backbone_hr, config.optimizer_attributes.params.weight_decay
|
||||
)
|
||||
)
|
||||
if 's2' in self.sources:
|
||||
print('s2'.center(60, '-'))
|
||||
num_frozen_params += layer_decay_value_assigner_s2.fix_param(
|
||||
self.backbone_s2,
|
||||
config.lr_parameters.frozen_blocks,
|
||||
)
|
||||
optimizer_grouped_parameters.extend(
|
||||
layer_decay_value_assigner_s2.get_parameter_groups(
|
||||
self.backbone_s2, config.optimizer_attributes.params.weight_decay
|
||||
)
|
||||
)
|
||||
no_decay = [".bn.", "bias"]
|
||||
optimizer_grouped_parameters[0]["params"] += [
|
||||
p for n, p in self.head_s2.named_parameters()
|
||||
if not any(nd in n for nd in no_decay)
|
||||
]
|
||||
optimizer_grouped_parameters[1]["params"] += [
|
||||
p for n, p in self.head_s2.named_parameters()
|
||||
if any(nd in n for nd in no_decay)
|
||||
]
|
||||
if self.use_ctpe:
|
||||
optimizer_grouped_parameters[1]["params"] += [self.ctpe]
|
||||
|
||||
if 's1' in self.sources:
|
||||
print('s1'.center(60, '-'))
|
||||
num_frozen_params += layer_decay_value_assigner_s1.fix_param(
|
||||
self.backbone_s1,
|
||||
config.lr_parameters.frozen_blocks,
|
||||
)
|
||||
optimizer_grouped_parameters.extend(
|
||||
layer_decay_value_assigner_s1.get_parameter_groups(
|
||||
self.backbone_s1, config.optimizer_attributes.params.weight_decay
|
||||
)
|
||||
)
|
||||
no_decay = [".bn.", "bias"]
|
||||
optimizer_grouped_parameters[0]["params"] += [
|
||||
p for n, p in self.head_s1.named_parameters()
|
||||
if not any(nd in n for nd in no_decay)
|
||||
]
|
||||
optimizer_grouped_parameters[1]["params"] += [
|
||||
p for n, p in self.head_s1.named_parameters()
|
||||
if any(nd in n for nd in no_decay)
|
||||
]
|
||||
|
||||
if len(self.sources) > 1:
|
||||
print('fusion'.center(60, '-'))
|
||||
num_frozen_params += layer_decay_value_assigner_fusion.fix_param_deeper(
|
||||
self.fusion,
|
||||
config.lr_parameters.frozen_fusion_blocks_start, # 冻结后面所有的stage
|
||||
)
|
||||
optimizer_grouped_parameters.extend(
|
||||
layer_decay_value_assigner_fusion.get_parameter_groups(
|
||||
self.fusion, config.optimizer_attributes.params.weight_decay
|
||||
)
|
||||
)
|
||||
|
||||
if self.use_modal_vae:
|
||||
no_decay = [".bn.", "bias"]
|
||||
optimizer_grouped_parameters[0]["params"] += [
|
||||
p for n, p in self.modality_vae.named_parameters()
|
||||
if not any(nd in n for nd in no_decay)
|
||||
]
|
||||
optimizer_grouped_parameters[1]["params"] += [
|
||||
p for n, p in self.modality_vae.named_parameters()
|
||||
if any(nd in n for nd in no_decay)
|
||||
]
|
||||
|
||||
no_decay = [".bn.", "bias"]
|
||||
optimizer_grouped_parameters[0]["params"] += [
|
||||
p for n, p in self.head_rec_hr.named_parameters()
|
||||
if not any(nd in n for nd in no_decay)
|
||||
]
|
||||
optimizer_grouped_parameters[1]["params"] += [
|
||||
p for n, p in self.head_rec_hr.named_parameters()
|
||||
if any(nd in n for nd in no_decay)
|
||||
]
|
||||
|
||||
if self.with_aux_head:
|
||||
no_decay = [".bn.", "bias"]
|
||||
optimizer_grouped_parameters[0]["params"] += [
|
||||
p for n, p in self.aux_head.named_parameters()
|
||||
if not any(nd in n for nd in no_decay)
|
||||
]
|
||||
optimizer_grouped_parameters[1]["params"] += [
|
||||
p for n, p in self.aux_head.named_parameters()
|
||||
if any(nd in n for nd in no_decay)
|
||||
]
|
||||
num_params = [len(x['params']) for x in optimizer_grouped_parameters]
|
||||
print(len(list(self.parameters())), sum(num_params), num_frozen_params)
|
||||
assert len(list(self.parameters())) == sum(num_params) + num_frozen_params
|
||||
return optimizer_grouped_parameters
|
||||
|
||||
def get_custom_scheduler(self, trainer):
|
||||
optimizer = trainer.optimizer
|
||||
num_training_steps = trainer.config.training_parameters.max_iterations
|
||||
num_warmup_steps = trainer.config.training_parameters.num_warmup_steps
|
||||
|
||||
if "train" in trainer.run_type:
|
||||
if num_training_steps == math.inf:
|
||||
epoches = trainer.config.training_parameters.max_epochs
|
||||
assert epoches != math.inf
|
||||
num_training_steps = trainer.config.training_parameters.max_epochs * trainer.epoch_iterations
|
||||
|
||||
def linear_with_wram_up(current_step):
|
||||
if current_step < num_warmup_steps:
|
||||
return float(current_step) / float(max(
|
||||
1, num_warmup_steps))
|
||||
return max(
|
||||
0.0,
|
||||
float(num_training_steps - current_step) /
|
||||
float(max(1, num_training_steps - num_warmup_steps)),
|
||||
)
|
||||
|
||||
def cos_with_wram_up(current_step):
|
||||
num_cycles = 0.5
|
||||
if current_step < num_warmup_steps:
|
||||
return float(current_step) / float(max(
|
||||
1, num_warmup_steps))
|
||||
progress = float(current_step - num_warmup_steps) / float(
|
||||
max(1, num_training_steps - num_warmup_steps))
|
||||
return max(
|
||||
0.0,
|
||||
0.5 *
|
||||
(1.0 +
|
||||
math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
|
||||
)
|
||||
|
||||
lr_lambda = cos_with_wram_up if trainer.config.training_parameters.cos_lr else linear_with_wram_up
|
||||
|
||||
else:
|
||||
def lr_lambda(current_step):
|
||||
return 0.0 # noqa
|
||||
|
||||
return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
def convert_target(self, target):
|
||||
mean = target.new_tensor(self.target_mean).reshape(1, 3, 1, 1)
|
||||
std = target.new_tensor(self.target_std).reshape(1, 3, 1, 1)
|
||||
target = ((target * std + mean)*255).to(torch.long)
|
||||
target[:, 0] = target[:, 0] * 256 * 256
|
||||
target[:, 1] = target[:, 1] * 256
|
||||
target = target.sum(1).type(torch.long)
|
||||
unique_target = target.unique()
|
||||
target_index = torch.searchsorted(unique_target, target)
|
||||
no_bg = False
|
||||
if unique_target[0].item() > 0:
|
||||
target_index += 1
|
||||
no_bg = True
|
||||
target_index_unique = target_index.unique().tolist()
|
||||
random.shuffle(self.vocabulary)
|
||||
value = target.new_tensor([0] + self.vocabulary)
|
||||
mapped_target = target_index.clone()
|
||||
idx_2_color = {}
|
||||
for v in target_index_unique:
|
||||
mapped_target[target_index == v] = value[v]
|
||||
idx_2_color[value[v].item()] = unique_target[v - 1 if no_bg else v].item()
|
||||
return mapped_target, idx_2_color
|
||||
|
||||
def forward(self, sample_list):
|
||||
output = dict()
|
||||
modality_flag_hr = sample_list["modality_flag_hr"]
|
||||
modality_flag_s2 = sample_list["modality_flag_s2"]
|
||||
modality_flag_s1 = sample_list["modality_flag_s1"]
|
||||
modalities = [modality_flag_hr, modality_flag_s2, modality_flag_s1]
|
||||
modalities = torch.tensor(modalities).permute(1,0).contiguous() # L, B => B, L
|
||||
|
||||
anno_img = sample_list["targets"]
|
||||
anno_img, idx_2_color = self.convert_target(anno_img)
|
||||
output["mapped_targets"] = anno_img
|
||||
output["idx_2_color"] = idx_2_color
|
||||
anno_mask = sample_list["anno_mask"]
|
||||
anno_s2 = anno_img[:, 15::32, 15::32]
|
||||
anno_s1 = anno_s2
|
||||
|
||||
output["anno_hr"] = anno_img
|
||||
output["anno_s2"] = anno_s2
|
||||
|
||||
### 1. backbone
|
||||
if 'hr' in self.sources:
|
||||
hr_img = sample_list["hr_img"]
|
||||
B_MASK, H_MASK, W_MASK = anno_mask.shape
|
||||
block_size = 32
|
||||
anno_mask_hr = anno_mask.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 1, block_size, block_size)
|
||||
anno_mask_hr = anno_mask_hr.permute(0, 1, 3, 2, 4).reshape(B_MASK, H_MASK*block_size, W_MASK*block_size).contiguous()
|
||||
B, C_G, H_G, W_G = hr_img.shape
|
||||
hr_features = self.backbone_hr(hr_img, anno_img, anno_mask_hr)
|
||||
output['mask_hr'] = anno_mask_hr
|
||||
output['target_hr'] = anno_img
|
||||
|
||||
if 's2' in self.sources:
|
||||
s2_img = sample_list["s2_img"]
|
||||
B, C_S2, S_S2, H_S2, W_S2 = s2_img.shape
|
||||
s2_img = s2_img.permute(0, 2, 1, 3,
|
||||
4).reshape(B * S_S2, C_S2, H_S2, W_S2).contiguous() # ts time to batch
|
||||
anno_mask_s2 = anno_mask
|
||||
s2_features = self.backbone_s2(s2_img, anno_s2, anno_mask_s2)
|
||||
if 'head_s2' in self.config.keys():
|
||||
s2_features = self.head_s2(s2_features[-1])
|
||||
s2_features = [s2_features]
|
||||
|
||||
if 's1' in self.sources:
|
||||
s1_img = sample_list["s1_img"]
|
||||
B, C_S1, S_S1, H_S1, W_S1 = s1_img.shape
|
||||
s1_img = s1_img.permute(0, 2, 1, 3,
|
||||
4).reshape(B * S_S1, C_S1, H_S1, W_S1).contiguous()
|
||||
|
||||
anno_mask_s1 = anno_mask
|
||||
s1_features = self.backbone_s1(s1_img, anno_s1, anno_mask_s1)
|
||||
if 'head_s1' in self.config.keys():
|
||||
s1_features = self.head_s1(s1_features[-1])
|
||||
s1_features = [s1_features]
|
||||
|
||||
### 2. prepare features for fusion
|
||||
hr_features_stage3 = hr_features[-1]
|
||||
s2_features_stage3 = s2_features[-1]
|
||||
s1_features_stage3 = s1_features[-1]
|
||||
modalities = modalities.to(hr_features_stage3.device)
|
||||
if self.use_modal_vae:
|
||||
vae_out = self.modality_vae(hr_features_stage3, s2_features_stage3, s1_features_stage3, modalities)
|
||||
hr_features_stage3 = vae_out['hr_out']
|
||||
s2_features_stage3 = vae_out['s2_out']
|
||||
s1_features_stage3 = vae_out['s1_out']
|
||||
output['vae_out'] = vae_out
|
||||
|
||||
features_stage3 = []
|
||||
if 'hr' in self.sources:
|
||||
B, C3_G, H3_G, W3_G = hr_features_stage3.shape
|
||||
hr_features_stage3 = hr_features_stage3.permute(
|
||||
0, 2, 3, 1).reshape(B * H3_G * W3_G, C3_G).unsqueeze(1).contiguous() # B * H3_G * W3_G, 1, C3_G
|
||||
features_stage3 = hr_features_stage3
|
||||
|
||||
if 's2' in self.sources:
|
||||
# s2_features_stage3 = s2_features[-1]
|
||||
_, C3_S2, H3_S2, W3_S2 = s2_features_stage3.shape
|
||||
s2_features_stage3 = s2_features_stage3.reshape(
|
||||
B, S_S2, C3_S2, H3_S2,
|
||||
W3_S2).permute(0, 3, 4, 1, 2).reshape(B, H3_S2 * W3_S2, S_S2,
|
||||
C3_S2).contiguous()
|
||||
if self.use_ctpe:
|
||||
ct_index = sample_list["s2_ct"]
|
||||
ctpe = self.ctpe[:, ct_index, :].contiguous().permute(1, 0, 2, 3).contiguous()
|
||||
ctpe = ctpe.expand(-1, 256, -1, -1)
|
||||
|
||||
ct_index_2 = sample_list["s2_ct2"]
|
||||
ctpe2 = self.ctpe[:, ct_index_2, :].contiguous().permute(1, 0, 2, 3).contiguous()
|
||||
ctpe2 = ctpe2.expand(-1, 256, -1, -1)
|
||||
|
||||
ctpe_comb = torch.cat([ctpe, ctpe2], 1)
|
||||
# import pdb;pdb.set_trace()
|
||||
s2_features_stage3 = (s2_features_stage3 + ctpe_comb).reshape(
|
||||
B * H3_S2 * W3_S2, S_S2, C3_S2).contiguous()
|
||||
else:
|
||||
s2_features_stage3 = s2_features_stage3.reshape(
|
||||
B * H3_S2 * W3_S2, S_S2, C3_S2).contiguous()
|
||||
|
||||
if len(features_stage3) > 0:
|
||||
assert H3_G == H3_S2 and W3_G == W3_S2 and C3_G == C3_S2
|
||||
features_stage3 = torch.cat((features_stage3, s2_features_stage3), dim=1)
|
||||
else:
|
||||
features_stage3 = s2_features_stage3
|
||||
|
||||
if 's1' in self.sources:
|
||||
# s1_features_stage3 = s1_features[-1]
|
||||
_, C3_S1, H3_S1, W3_S1 = s1_features_stage3.shape
|
||||
s1_features_stage3 = s1_features_stage3.reshape(
|
||||
B, S_S1, C3_S1, H3_S1,
|
||||
W3_S1).permute(0, 3, 4, 1, 2).reshape(B, H3_S1 * W3_S1, S_S1,
|
||||
C3_S1).contiguous()
|
||||
s1_features_stage3 = s1_features_stage3.reshape(
|
||||
B * H3_S1 * W3_S1, S_S1, C3_S1).contiguous()
|
||||
|
||||
if len(features_stage3) > 0:
|
||||
assert H3_S1 == H3_S2 and W3_S1 == W3_S2 and C3_S1 == C3_S2
|
||||
features_stage3 = torch.cat((features_stage3, s1_features_stage3),
|
||||
dim=1)
|
||||
else:
|
||||
features_stage3 = s1_features_stage3
|
||||
|
||||
### 3. fusion
|
||||
if self.config.necks.output_cls_token:
|
||||
if self.config.necks.get('require_feat', False):
|
||||
cls_token, block_outs = self.fusion(features_stage3 , True)
|
||||
else:
|
||||
cls_token = self.fusion(features_stage3)
|
||||
_, C3_G = cls_token.shape
|
||||
cls_token = cls_token.reshape(B, H3_G, W3_G,
|
||||
C3_G).contiguous().permute(0, 3, 1, 2).contiguous() # b, c, h, w
|
||||
else:
|
||||
assert self.config.necks.with_cls_token is False
|
||||
if self.config.necks.get('require_feat', False):
|
||||
features_stage3, block_outs = self.fusion(features_stage3, True)
|
||||
else:
|
||||
features_stage3 = self.fusion(features_stage3)
|
||||
features_stage3 = features_stage3.reshape(
|
||||
B, H3_S2, W3_S2, S_S2,
|
||||
C3_S2).permute(0, 3, 4, 1, 2).reshape(B * S_S2, C3_S2, H3_S2,
|
||||
W3_S2).contiguous()
|
||||
### 4. decoder for rec
|
||||
hr_rec_inputs = hr_features
|
||||
feat_stage1 = hr_rec_inputs[0]
|
||||
|
||||
if feat_stage1.shape[-1] == feat_stage1.shape[-2]:
|
||||
feat_stage1_left, feat_stage1_right = torch.split(feat_stage1, feat_stage1.shape[-1] // 2, dim=-1)
|
||||
feat_stage1 = torch.cat((feat_stage1_left, feat_stage1_right), dim=1)
|
||||
hr_rec_inputs = list(hr_features)
|
||||
hr_rec_inputs[0] = feat_stage1
|
||||
|
||||
rec_feats = [*hr_rec_inputs, cls_token]
|
||||
logits_hr = self.head_rec_hr(rec_feats)
|
||||
if self.config.get('upsacle_results', True):
|
||||
logits_hr = logits_hr.to(torch.float32)
|
||||
logits_hr = F.interpolate(logits_hr, scale_factor=4, mode='bilinear', align_corners=True)
|
||||
output["logits_hr"] = logits_hr
|
||||
return output
|
||||
|
||||
def load_pretrained(self, ckpt_path, key):
|
||||
pretrained_dict = torch.load(ckpt_path, map_location={'cuda:0': 'cpu'})
|
||||
pretrained_dict = pretrained_dict[key]
|
||||
for k, v in pretrained_dict.items():
|
||||
if k == 'backbone_s2.patch_embed.projection.weight':
|
||||
pretrained_in_channels = v.shape[1]
|
||||
if self.config.backbone_s2.in_channels == 4:
|
||||
new_weight = v[:, [0, 1, 2, 6]]
|
||||
new_weight = new_weight * (
|
||||
pretrained_in_channels /
|
||||
self.config.backbone_s2.in_channels)
|
||||
pretrained_dict[k] = new_weight
|
||||
missing_keys, unexpected_keys = self.load_state_dict(pretrained_dict,
|
||||
strict=False)
|
||||
print('missing_keys:', missing_keys)
|
||||
print('unexpected_keys:', unexpected_keys)
|
||||
|
||||
Reference in New Issue
Block a user