esenke
2025-12-08 22:16:31 +08:00
commit 01adcfdf60
305 changed files with 50879 additions and 0 deletions

lib/utils/checkpoint.py (new file, 135 lines)

@@ -0,0 +1,135 @@
# Copyright (c) Ant Financial Service Group. and its affiliates.
import os
import warnings
import torch
from antmmf.common import constants
from antmmf.common.registry import registry
from antmmf.common.checkpoint import Checkpoint
from antmmf.utils.distributed_utils import is_main_process

class SegCheckpoint(Checkpoint):

    def __init__(self, trainer, load_only=False):
        # forward the flag instead of hard-coding it
        super().__init__(trainer, load_only=load_only)

    def load_model_weights(self, file, force=False):
        self.trainer.writer.write("Loading checkpoint")
        ckpt = self._torch_load(file)

        if registry.get(constants.STATE) == constants.STATE_ONLINE_SERVING:
            data_parallel = False
        else:
            data_parallel = registry.get("data_parallel") or registry.get(
                "distributed")

        if "model" in ckpt:
            ckpt_model = ckpt["model"]
        else:
            ckpt_model = ckpt
            ckpt = {"model": ckpt}

        new_dict = {}
        # TODO: Move to separate function
        for attr in ckpt_model:
            if "fa_history" in attr:
                new_dict[attr.replace("fa_history",
                                      "fa_context")] = ckpt_model[attr]
            elif data_parallel is False and attr.startswith("module."):
                new_k = attr.replace("module.", "", 1)
                if '.Wqkv.' in new_k:
                    new_k = new_k.replace('.Wqkv.', '.in_proj_')
                new_dict[new_k] = ckpt_model[attr]
            elif data_parallel is not False and not attr.startswith("module."):
                new_dict["module." + attr] = ckpt_model[attr]
            elif data_parallel is False and not attr.startswith("module."):
                new_k = attr
                if '.Wqkv.' in new_k:
                    new_k = new_k.replace('.Wqkv.', '.in_proj_')
                new_dict[new_k] = ckpt_model[attr]
            else:
                new_dict[attr] = ckpt_model[attr]

        self._load_state_dict(new_dict)
        self._load_model_weights_with_mapping(new_dict, force=force)
        print(f'load weight: {file} done!')
        return ckpt

    def _load(self, file, force=False, resume_state=False):
        ckpt = self.load_model_weights(file, force=force)

        # skip loading training state
        if resume_state is False:
            return

        if "optimizer" in ckpt:
            try:
                self.trainer.optimizer.load_state_dict(ckpt["optimizer"])
                # work around the optimizer checkpoint issue in PyTorch
                # versions newer than 1.11 (the "capturable" flag)
                if "capturable" in self.trainer.optimizer.param_groups[0]:
                    self.trainer.optimizer.param_groups[0]["capturable"] = True
            except Exception as e:
                print(e)
        else:
            warnings.warn(
                "'optimizer' key is not present in the checkpoint asked to be loaded. Skipping."
            )

        if "lr_scheduler" in ckpt:
            self.trainer.lr_scheduler.load_state_dict(ckpt["lr_scheduler"])
        else:
            warnings.warn(
                "'lr_scheduler' key is not present in the checkpoint asked to be loaded. Skipping."
            )

        self.trainer.early_stopping.init_from_checkpoint(ckpt)
        self.trainer.writer.write("Checkpoint {} loaded".format(file))

        if "current_iteration" in ckpt:
            self.trainer.current_iteration = ckpt["current_iteration"]
            registry.register("current_iteration",
                              self.trainer.current_iteration)
        if "current_epoch" in ckpt:
            self.trainer.current_epoch = ckpt["current_epoch"]
            registry.register("current_epoch", self.trainer.current_epoch)

    def save(self, iteration, update_best=False):
        if not is_main_process():
            return

        ckpt_filepath = os.path.join(self.models_foldername,
                                     "model_%d.ckpt" % iteration)
        best_ckpt_filepath = os.path.join(self.ckpt_foldername,
                                          self.ckpt_prefix + "best.ckpt")

        best_iteration = self.trainer.early_stopping.best_monitored_iteration
        best_metric = self.trainer.early_stopping.best_monitored_value
        current_iteration = self.trainer.current_iteration
        current_epoch = self.trainer.current_epoch

        model = self.trainer.model
        data_parallel = registry.get("data_parallel") or registry.get(
            "distributed")
        if data_parallel is True:
            model = model.module

        ckpt = {
            "model": model.state_dict(),
            "optimizer": self.trainer.optimizer.state_dict(),
            "lr_scheduler": self.trainer.lr_scheduler.state_dict(),
            "current_iteration": current_iteration,
            "current_epoch": current_epoch,
            "best_iteration": best_iteration,
            "best_metric_value": best_metric,
        }
        torch.save(ckpt, ckpt_filepath)
        self.remove_redundant_ckpts()
        if update_best:
            torch.save(ckpt, best_ckpt_filepath)
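
For reference, the state-dict key remapping that load_model_weights applies when pulling a DataParallel/DDP checkpoint into a single-GPU model boils down to three renames. Below is a minimal sketch on a toy, made-up state dict, assuming data_parallel is False; the keys are purely illustrative.

toy_ckpt = {
    "module.encoder.attn.Wqkv.weight": 1,   # fused-qkv weight saved from a DDP run
    "module.decoder.fa_history.bias": 2,    # older naming of the context buffer
}
remapped = {}
for key, value in toy_ckpt.items():
    if "fa_history" in key:                  # handled first; the "module." prefix is kept
        remapped[key.replace("fa_history", "fa_context")] = value
    elif key.startswith("module."):          # drop the DataParallel/DDP wrapper prefix
        new_k = key.replace("module.", "", 1)
        if '.Wqkv.' in new_k:                # fused qkv -> nn.MultiheadAttention naming
            new_k = new_k.replace('.Wqkv.', '.in_proj_')
        remapped[new_k] = value
    else:
        remapped[key] = value
# remapped == {"module.decoder.fa_context.bias": 2,
#              "encoder.attn.in_proj_weight": 1}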

lib/utils/optim_utils.py (new file, 122 lines)

@@ -0,0 +1,122 @@
import warnings

import torch
from torch.nn import LayerNorm, Linear, GELU
from torch.nn import MultiheadAttention, Sequential

try:
    from atorch.normalization import LayerNorm as FastLayerNorm
    from atorch.modules.transformer.inject import replace_module
    from atorch.modules.transformer.layers import MultiheadAttentionFA, BertAttentionFA
except (ImportError, ModuleNotFoundError) as e:
    warnings.warn("Using replace_speedup_op but no atorch/apex installed:%s" % e)

try:
    from transformers.models.bert.modeling_bert import BertAttention
    replace_transformer_bert = True
except ImportError:
    replace_transformer_bert = False

class DefaultStrategy:
    replace_mha = True
    replace_layernorm = True
    replace_linear_gelu = False  # TODO: numerical consistency


def replace_layer_norm(module: torch.nn.Module, cur_name: str):
    for name, child in module.named_children():
        child_name = cur_name + "." + name
        if isinstance(child, LayerNorm):
            new_module = FastLayerNorm(child.normalized_shape, eps=child.eps)
            new_module.load_state_dict(child.state_dict())
            setattr(module, name, new_module)
        else:
            replace_layer_norm(child, child_name)


def is_atorch_available(raise_error=True, log=None):
    try:
        import atorch  # noqa: F401
        return True
    except ImportError as e:
        if raise_error is True:
            raise ImportError(e, log)
        else:
            return False

def _cast_if_autocast_enabled(*args):
    if not torch.is_autocast_enabled():
        return args
    else:
        return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype())


def _fused_dense_gelu_dense(input, weight1, bias1, weight2, bias2):
    batch, seq_length, hidden_size = input.size()
    input = input.view(batch * seq_length, hidden_size)
    args = _cast_if_autocast_enabled(input, weight1, bias1, weight2, bias2)
    from apex.fused_dense import FusedDenseGeluDenseFunc  # with cast
    with torch.cuda.amp.autocast(enabled=False):
        out = FusedDenseGeluDenseFunc.apply(*args)
    out = out.view(batch, seq_length, -1)
    return out


def linear_gelu_forward(input_, weight1, bias1, weight2, bias2):
    return _fused_dense_gelu_dense(input_, weight1, bias1, weight2, bias2)
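
As a sanity check for the replace_linear_gelu path below (the numerical-consistency TODO in DefaultStrategy), a plain-PyTorch reference of the same Linear -> GELU -> Linear composition can be compared against the fused kernel. This is only a sketch under the assumption that the apex kernel computes this composition; the fused kernel may use a different GELU approximation, which is one likely source of small numerical gaps. The tensors x, w1, b1, w2, b2 in the comment are placeholders.

import torch
import torch.nn.functional as F

def dense_gelu_dense_reference(input_, weight1, bias1, weight2, bias2):
    # eager-mode equivalent of _fused_dense_gelu_dense, for comparison only
    hidden = F.gelu(F.linear(input_, weight1, bias1))
    return F.linear(hidden, weight2, bias2)

# e.g. torch.allclose(dense_gelu_dense_reference(x, w1, b1, w2, b2),
#                     linear_gelu_forward(x, w1, b1, w2, b2), atol=1e-3)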

def replace_linear_gelu(module, cur_name: str):
    """
    (layers): Sequential(
      (0): Sequential(
        (0): Linear(in_features=1536, out_features=6144, bias=True)
        (1): GELU()
        (2): Dropout(p=0.0, inplace=False)
      )
      (1): Linear(in_features=6144, out_features=1536, bias=True)
      (2): Dropout(p=0.0, inplace=False)
    )
    """
    for name, child in module.named_children():
        child_name = cur_name + "." + name
        if isinstance(child, Sequential):
            if (len(child) >= 2 and isinstance(child[0], Sequential)
                    and isinstance(child[1], Linear) and len(child[0]) >= 2
                    and isinstance(child[0][0], Linear)
                    and isinstance(child[0][1], GELU)):  # Sequential + Linear
                linear0 = child[0][0]
                linear1 = child[1]
                if getattr(child, "replace_linear_gelu", False):
                    continue
                # bind the current linears as default arguments so every patched
                # forward keeps its own weights (avoids late binding in the loop)
                child.forward = lambda x, l0=linear0, l1=linear1: linear_gelu_forward(
                    x, l0.weight, l0.bias, l1.weight, l1.bias)
                child.replace_linear_gelu = True
                print("REPLACE linear+gelu:%s" % child_name)
                # setattr(module, name, new_module)
        else:
            replace_linear_gelu(child, child_name)

def replace_speedup_op(model, strategy=DefaultStrategy):
    if not is_atorch_available(raise_error=False):
        raise ImportError("Install Atorch/apex before using speedup op")
    if strategy.replace_mha:
        model = replace_module(model, MultiheadAttention, MultiheadAttentionFA, need_scr_module=True)
        if replace_transformer_bert:
            model = replace_module(model, BertAttention, BertAttentionFA, need_scr_module=True)
    root_name = model.__class__.__name__
    if strategy.replace_layernorm:
        replace_layer_norm(model, root_name)  # inplace
    if strategy.replace_linear_gelu:
        replace_linear_gelu(model, root_name)
    return model


# TODO:
# 1. SyncBatchNorm
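
A hedged usage sketch: model stands for any existing nn.Module, and the strategy subclass is hypothetical. The call is guarded on atorch availability, and the linear+gelu fusion stays off until its numerical consistency has been verified.

class ConservativeStrategy(DefaultStrategy):
    replace_linear_gelu = False  # leave off until numerical consistency is verified

if is_atorch_available(raise_error=False):
    model = replace_speedup_op(model, strategy=ConservativeStrategy)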

lib/utils/utils.py (new file, 243 lines)

@@ -0,0 +1,243 @@
import numpy as np


def cosine_scheduler(base_value,
                     final_value,
                     all_iters,
                     warmup_iters=0,
                     start_warmup_value=0):
    warmup_schedule = np.array([])
    if warmup_iters > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value,
                                      warmup_iters)

    iters = np.arange(all_iters - warmup_iters)
    schedule = final_value + 0.5 * (base_value - final_value) * (
        1 + np.cos(np.pi * iters / len(iters)))

    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == all_iters
    return schedule
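
For example (illustrative numbers), a 100-iteration run with 10 warmup steps ramping from 0 to 1e-3 and then decaying toward 1e-5 builds the full per-iteration table once, which the training loop then indexes by iteration.

lr_schedule = cosine_scheduler(base_value=1e-3,
                               final_value=1e-5,
                               all_iters=100,
                               warmup_iters=10)
assert len(lr_schedule) == 100
# lr_schedule[0] == 0.0, lr_schedule[9] == 1e-3, lr_schedule[-1] is close to 1e-5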

def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
    if epoch >= freeze_last_layer:
        return
    for n, p in model.named_parameters():
        if "last_layer" in n:
            p.grad = None

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
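
A quick shape check, e.g. embedding four positions into an 8-dimensional sin/cos table:

positions = np.arange(4)
pos_embed = get_1d_sincos_pos_embed_from_grid(8, positions)
assert pos_embed.shape == (4, 8)  # first half sin, second half cos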

def cancel_gradients_backbone(iteration, model, freeze_backbone_steps):
    if iteration >= freeze_backbone_steps:
        return
    for n, p in model.named_parameters():
        if "backbon_hr" in n or 'backbon_s2' in n or 'head_s2' in n or 'fusion' in n or 'ctpe' in n:
            p.grad = None

class EMA():

    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}

    def register(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                new_average = (1.0 - self.decay
                               ) * param.data + self.decay * self.shadow[name]
                self.shadow[name] = new_average.clone()

    def apply_shadow(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                self.backup[name] = param.data
                param.data = self.shadow[name]

    def restore(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}
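
A sketch of how the EMA helper is typically wired into a training loop; model, loader, optimizer and evaluate are placeholders, not part of this file.

ema = EMA(model, decay=0.999)
ema.register()                    # snapshot the initial weights

for batch in loader:
    loss = model(batch).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.update()                  # shadow = (1 - decay) * param + decay * shadow

ema.apply_shadow()                # evaluate/save with the averaged weights
evaluate(model)
ema.restore()                     # switch back to the raw training weights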

class LayerDecayValueAssigner(object):

    def __init__(self, layer_decay, num_layers, base_lr, net_type, arch='huge'):
        assert net_type in ['swin', 'vit']
        assert 0 < layer_decay <= 1
        depths_dict = {
            'tiny': [2, 2, 6, 2],
            'small': [2, 2, 18, 2],
            'base': [2, 2, 18, 2],
            'large': [2, 2, 18, 2],
            'huge': [2, 2, 18, 2],
            'giant': [2, 2, 42, 4],
        }
        num_layers = num_layers if net_type == 'vit' else sum(depths_dict[arch])
        self.layer_decay = layer_decay
        self.values = list(layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2))
        self.depths = depths_dict[arch]
        self.base_lr = base_lr
        self.net_type = net_type

    def get_num_layer_for_vit(self, var_name):
        if var_name in ("cls_token", "mask_token", "pos_embed"):
            return 0
        elif var_name.startswith("patch_embed"):
            return 0
        elif var_name.startswith("layers"):
            layer_id = int(var_name.split('.')[1])
            return layer_id + 1
        else:
            return len(self.values) - 1

    def get_num_layer_for_swin(self, var_name):
        if var_name in ("mask_token", "pos_embed"):
            return 0
        elif var_name.startswith("patch_embed"):
            return 0
        elif var_name.startswith("stages"):
            layer_id = int(var_name.split('.')[1])
            if 'blocks' in var_name:
                block_id = int(var_name.split('.')[3])
            else:
                block_id = self.depths[layer_id] - 1
            layer_id = sum(self.depths[:layer_id]) + block_id
            return layer_id + 1
        else:
            return len(self.values) - 1

    def get_layer_id(self, var_name):
        if self.net_type == 'swin':
            return self.get_num_layer_for_swin(var_name)
        if self.net_type == 'vit':
            return self.get_num_layer_for_vit(var_name)

    def fix_param(self, model, num_block=4):
        if num_block < 1:
            return 0
        frozen_num = 0
        if self.net_type == 'swin':
            for name, param in model.named_parameters():
                if name.startswith("patch_embed"):
                    param.requires_grad = False
                    frozen_num += 1
                if name.startswith("stages") and self.get_layer_id(name) <= num_block:
                    param.requires_grad = False
                    frozen_num += 1
        if self.net_type == 'vit':
            for name, param in model.named_parameters():
                if name.startswith("patch_embed"):
                    param.requires_grad = False
                    frozen_num += 1
                if name.startswith("layers") and self.get_layer_id(name) <= num_block:
                    param.requires_grad = False
                    frozen_num += 1
        return frozen_num

    def fix_param_deeper(self, model, num_block=4):
        if num_block < 1:
            return 0
        frozen_num = 0
        if self.net_type == 'swin':
            raise ValueError('Not supported')
        if self.net_type == 'vit':
            for name, param in model.named_parameters():
                if name.startswith("patch_embed"):
                    param.requires_grad = False
                    frozen_num += 1
                if name.startswith("layers") and self.get_layer_id(name) >= num_block:
                    param.requires_grad = False
                    frozen_num += 1
        return frozen_num

    def get_parameter_groups(self, model, weight_decay):
        parameter_groups_with_wd, parameter_groups_without_wd = [], []
        print_info_with_wd, print_info_without_wd = [], []
        no_decay = [
            "absolute_pos_embed", "relative_position_bias_table", "norm", "bias"
        ]
        if self.layer_decay == 1:
            parameter_groups_with_wd.append(
                {"params": [], "weight_decay": weight_decay, "lr": self.base_lr}
            )
            print_info_with_wd.append(
                {"params": [], "weight_decay": weight_decay, "lr": self.base_lr}
            )
            parameter_groups_without_wd.append(
                {"params": [], "weight_decay": 0, "lr": self.base_lr}
            )
            print_info_without_wd.append(
                {"params": [], "weight_decay": 0, "lr": self.base_lr}
            )
        else:
            for scale in self.values:
                parameter_groups_with_wd.append(
                    {"params": [], "weight_decay": weight_decay, "lr": scale * self.base_lr}
                )
                print_info_with_wd.append(
                    {"params": [], "weight_decay": weight_decay, "lr": scale * self.base_lr}
                )
                parameter_groups_without_wd.append(
                    {"params": [], "weight_decay": 0, "lr": scale * self.base_lr}
                )
                print_info_without_wd.append(
                    {"params": [], "weight_decay": 0, "lr": scale * self.base_lr}
                )

        for name, param in model.named_parameters():
            if not param.requires_grad:
                print(f'frozen param: {name}')
                continue  # frozen weights
            layer_id = self.get_layer_id(name) if self.layer_decay < 1 else 0
            if any(nd in name for nd in no_decay):
                parameter_groups_without_wd[layer_id]['params'].append(param)
                print_info_without_wd[layer_id]['params'].append(name)
            else:
                parameter_groups_with_wd[layer_id]['params'].append(param)
                print_info_with_wd[layer_id]['params'].append(name)

        parameter_groups_with_wd = [x for x in parameter_groups_with_wd if len(x['params']) > 0]
        parameter_groups_without_wd = [x for x in parameter_groups_without_wd if len(x['params']) > 0]
        print_info_with_wd = [x for x in print_info_with_wd if len(x['params']) > 0]
        print_info_without_wd = [x for x in print_info_without_wd if len(x['params']) > 0]

        if self.layer_decay < 1:
            for wd, nwd in zip(print_info_with_wd, print_info_without_wd):
                print(wd)
                print(nwd)

        parameter_groups = []
        parameter_groups.extend(parameter_groups_with_wd)
        parameter_groups.extend(parameter_groups_without_wd)
        return parameter_groups
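
A possible way to wire the assigner into an optimizer; backbone is a placeholder model and the hyperparameters are illustrative, not values prescribed by this file.

import torch

assigner = LayerDecayValueAssigner(layer_decay=0.9,
                                   num_layers=24,
                                   base_lr=1e-4,
                                   net_type='vit',
                                   arch='large')
assigner.fix_param(backbone, num_block=4)     # optionally freeze the earliest blocks
param_groups = assigner.get_parameter_groups(backbone, weight_decay=0.05)
optimizer = torch.optim.AdamW(param_groups)   # each group carries its own lr and weight decay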