init
lib/utils/checkpoint.py (new file, 135 lines)
@@ -0,0 +1,135 @@
# Copyright (c) Ant Financial Service Group. and its affiliates.
import os
import warnings

import torch

from antmmf.common import constants
from antmmf.common.registry import registry
from antmmf.common.checkpoint import Checkpoint
from antmmf.utils.distributed_utils import is_main_process


class SegCheckpoint(Checkpoint):

    def __init__(self, trainer, load_only=False):
        # Forward the caller's load_only flag instead of hard-coding False.
        super().__init__(trainer, load_only=load_only)

    def load_model_weights(self, file, force=False):
        self.trainer.writer.write("Loading checkpoint")
        ckpt = self._torch_load(file)
        if registry.get(constants.STATE) is constants.STATE_ONLINE_SERVING:
            data_parallel = False
        else:
            data_parallel = registry.get("data_parallel") or registry.get(
                "distributed")

        if "model" in ckpt:
            ckpt_model = ckpt["model"]
        else:
            ckpt_model = ckpt
            ckpt = {"model": ckpt}

        new_dict = {}

        # TODO: Move to separate function
        # Remap checkpoint keys to match how the current model is wrapped:
        # add/strip the "module." prefix depending on data parallelism, and
        # rename fused-attention keys ("fa_history" -> "fa_context",
        # ".Wqkv." -> ".in_proj_") to the names the current modules expect.
        for attr in ckpt_model:
            if "fa_history" in attr:
                new_dict[attr.replace("fa_history",
                                      "fa_context")] = ckpt_model[attr]
            elif data_parallel is False and attr.startswith("module."):
                new_k = attr.replace("module.", "", 1)
                if ".Wqkv." in new_k:
                    new_k = new_k.replace(".Wqkv.", ".in_proj_")
                new_dict[new_k] = ckpt_model[attr]
            elif data_parallel is not False and not attr.startswith("module."):
                new_dict["module." + attr] = ckpt_model[attr]
            elif data_parallel is False and not attr.startswith("module."):
                new_k = attr
                if ".Wqkv." in new_k:
                    new_k = new_k.replace(".Wqkv.", ".in_proj_")
                new_dict[new_k] = ckpt_model[attr]
            else:
                new_dict[attr] = ckpt_model[attr]

        self._load_state_dict(new_dict)
        self._load_model_weights_with_mapping(new_dict, force=force)
        self.trainer.writer.write("Loaded model weights from {}".format(file))
        return ckpt

    def _load(self, file, force=False, resume_state=False):
        ckpt = self.load_model_weights(file, force=force)

        # Skip loading the training state (optimizer, scheduler, counters).
        if resume_state is False:
            return

        if "optimizer" in ckpt:
            try:
                self.trainer.optimizer.load_state_dict(ckpt["optimizer"])
                # Work around the optimizer-state restore issue in PyTorch 1.12+
                # where loaded param_groups default to capturable=False.
                if "capturable" in self.trainer.optimizer.param_groups[0]:
                    self.trainer.optimizer.param_groups[0]["capturable"] = True
            except Exception as e:
                warnings.warn("Failed to load optimizer state: %s" % e)
        else:
            warnings.warn(
                "'optimizer' key is not present in the checkpoint asked to be loaded. Skipping."
            )

        if "lr_scheduler" in ckpt:
            self.trainer.lr_scheduler.load_state_dict(ckpt["lr_scheduler"])
        else:
            warnings.warn(
                "'lr_scheduler' key is not present in the checkpoint asked to be loaded. Skipping."
            )

        self.trainer.early_stopping.init_from_checkpoint(ckpt)

        self.trainer.writer.write("Checkpoint {} loaded".format(file))

        if "current_iteration" in ckpt:
            self.trainer.current_iteration = ckpt["current_iteration"]
            registry.register("current_iteration",
                              self.trainer.current_iteration)

        if "current_epoch" in ckpt:
            self.trainer.current_epoch = ckpt["current_epoch"]
            registry.register("current_epoch", self.trainer.current_epoch)

    def save(self, iteration, update_best=False):
        # Only the main process writes checkpoints in distributed training.
        if not is_main_process():
            return

        ckpt_filepath = os.path.join(self.models_foldername,
                                     "model_%d.ckpt" % iteration)
        best_ckpt_filepath = os.path.join(self.ckpt_foldername,
                                          self.ckpt_prefix + "best.ckpt")

        best_iteration = self.trainer.early_stopping.best_monitored_iteration
        best_metric = self.trainer.early_stopping.best_monitored_value
        current_iteration = self.trainer.current_iteration
        current_epoch = self.trainer.current_epoch
        model = self.trainer.model
        data_parallel = registry.get("data_parallel") or registry.get(
            "distributed")

        # Unwrap DataParallel/DistributedDataParallel so saved keys carry no
        # "module." prefix.
        if data_parallel is True:
            model = model.module

        ckpt = {
            "model": model.state_dict(),
            "optimizer": self.trainer.optimizer.state_dict(),
            "lr_scheduler": self.trainer.lr_scheduler.state_dict(),
            "current_iteration": current_iteration,
            "current_epoch": current_epoch,
            "best_iteration": best_iteration,
            "best_metric_value": best_metric,
        }

        torch.save(ckpt, ckpt_filepath)
        self.remove_redundant_ckpts()

        if update_best:
            torch.save(ckpt, best_ckpt_filepath)
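The key-remapping rule in load_model_weights can be exercised on a plain dict. A minimal standalone sketch (not part of the commit; the keys are made up for illustration):

    # Hypothetical example of the remapping applied when the checkpoint was
    # saved with DataParallel but is loaded without it.
    ckpt_model = {
        "module.encoder.attn.Wqkv.weight": 1,
        "module.encoder.norm.bias": 2,
    }
    new_dict = {}
    for attr, value in ckpt_model.items():
        new_k = attr.replace("module.", "", 1)
        new_k = new_k.replace(".Wqkv.", ".in_proj_")
        new_dict[new_k] = value
    # new_dict == {"encoder.attn.in_proj_weight": 1, "encoder.norm.bias": 2}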
lib/utils/optim_utils.py (new file, 122 lines)
@@ -0,0 +1,122 @@
import warnings

import torch
from torch.nn import GELU, LayerNorm, Linear, MultiheadAttention, Sequential

try:
    from atorch.normalization import LayerNorm as FastLayerNorm
    from atorch.modules.transformer.inject import replace_module
    from atorch.modules.transformer.layers import MultiheadAttentionFA, BertAttentionFA
except (ImportError, ModuleNotFoundError) as e:
    warnings.warn("Using replace_speedup_op but no atorch/apex installed: %s" % e)

try:
    from transformers.models.bert.modeling_bert import BertAttention
    replace_transformer_bert = True
except ImportError:
    replace_transformer_bert = False


class DefaultStrategy:
    replace_mha = True
    replace_layernorm = True
    replace_linear_gelu = False  # TODO: numerical consistency


def replace_layer_norm(module: torch.nn.Module, cur_name: str):
    # Recursively swap every torch.nn.LayerNorm for atorch's fused FastLayerNorm,
    # copying the existing affine parameters.
    for name, child in module.named_children():
        child_name = cur_name + "." + name
        if isinstance(child, LayerNorm):
            new_module = FastLayerNorm(child.normalized_shape, eps=child.eps)
            new_module.load_state_dict(child.state_dict())
            setattr(module, name, new_module)
        else:
            replace_layer_norm(child, child_name)


def is_atorch_available(raise_error=True, log=None):
    try:
        import atorch  # noqa: F401

        return True
    except ImportError as e:
        if raise_error is True:
            raise ImportError(e, log)
        else:
            return False


def _cast_if_autocast_enabled(*args):
    if not torch.is_autocast_enabled():
        return args
    else:
        return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype())


def _fused_dense_gelu_dense(input, weight1, bias1, weight2, bias2):
    # Fuse Linear -> GELU -> Linear into a single apex kernel. Inputs are cast
    # to the autocast dtype up front, then autocast is disabled so the fused op
    # runs directly in that dtype.
    batch, seq_length, hidden_size = input.size()
    input = input.view(batch * seq_length, hidden_size)
    args = _cast_if_autocast_enabled(input, weight1, bias1, weight2, bias2)
    from apex.fused_dense import FusedDenseGeluDenseFunc  # with cast
    with torch.cuda.amp.autocast(enabled=False):
        out = FusedDenseGeluDenseFunc.apply(*args)
    out = out.view(batch, seq_length, -1)
    return out


def linear_gelu_forward(input_, weight1, bias1, weight2, bias2):
    return _fused_dense_gelu_dense(input_, weight1, bias1, weight2, bias2)


def replace_linear_gelu(module, cur_name: str):
    """Patch the forward of MLP blocks with the fused Linear+GELU+Linear kernel.

    Targets blocks shaped like:

    (layers): Sequential(
      (0): Sequential(
        (0): Linear(in_features=1536, out_features=6144, bias=True)
        (1): GELU()
        (2): Dropout(p=0.0, inplace=False)
      )
      (1): Linear(in_features=6144, out_features=1536, bias=True)
      (2): Dropout(p=0.0, inplace=False)
    )
    """
    for name, child in module.named_children():
        child_name = cur_name + "." + name
        if isinstance(child, Sequential):
            # Match Sequential(Sequential(Linear, GELU, ...), Linear, ...).
            if (len(child) >= 2 and isinstance(child[0], Sequential)
                    and isinstance(child[1], Linear) and len(child[0]) >= 2
                    and isinstance(child[0][0], Linear)
                    and isinstance(child[0][1], GELU)):
                linear0 = child[0][0]
                linear1 = child[1]
                if getattr(child, "replace_linear_gelu", False):
                    continue
                # Bind linear0/linear1 as default arguments so each patched
                # block keeps its own weights (avoids late binding in the loop).
                child.forward = lambda x, l0=linear0, l1=linear1: linear_gelu_forward(
                    x, l0.weight, l0.bias, l1.weight, l1.bias)
                child.replace_linear_gelu = True
                print("REPLACE linear+gelu:%s" % child_name)
                # setattr(module, name, new_module)
        else:
            replace_linear_gelu(child, child_name)


def replace_speedup_op(model, strategy=DefaultStrategy):
    if not is_atorch_available(raise_error=False):
        raise ImportError("Install atorch/apex before using speedup ops")
    if strategy.replace_mha:
        model = replace_module(model, MultiheadAttention, MultiheadAttentionFA, need_scr_module=True)
        if replace_transformer_bert:
            model = replace_module(model, BertAttention, BertAttentionFA, need_scr_module=True)
    root_name = model.__class__.__name__
    if strategy.replace_layernorm:
        replace_layer_norm(model, root_name)  # inplace
    if strategy.replace_linear_gelu:
        replace_linear_gelu(model, root_name)
    return model


# TODO:
# 1. SyncBatchNorm
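A sketch of how these helpers might be wired into model construction; the model and call site below are illustrative and not part of the commit:

    import torch.nn as nn

    model = nn.TransformerEncoder(
        nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=2)

    # Only attempt the in-place replacements when atorch is importable;
    # replace_speedup_op raises ImportError otherwise.
    if is_atorch_available(raise_error=False):
        model = replace_speedup_op(model, strategy=DefaultStrategy)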
lib/utils/utils.py (new file, 243 lines)
@@ -0,0 +1,243 @@
import numpy as np


def cosine_scheduler(base_value,
                     final_value,
                     all_iters,
                     warmup_iters=0,
                     start_warmup_value=0):
    # Linear warmup from start_warmup_value to base_value, followed by a
    # cosine decay from base_value to final_value over the remaining iterations.
    warmup_schedule = np.array([])
    if warmup_iters > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value,
                                      warmup_iters)

    iters = np.arange(all_iters - warmup_iters)
    schedule = final_value + 0.5 * (base_value - final_value) * (
        1 + np.cos(np.pi * iters / len(iters)))

    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == all_iters
    return schedule
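For example, a learning-rate schedule over 1000 iterations with 100 warmup steps could be precomputed once and indexed per step (the values are illustrative):

    lr_schedule = cosine_scheduler(base_value=1e-4, final_value=1e-6,
                                   all_iters=1000, warmup_iters=100)
    # lr_schedule[0] == 0 (warmup start), lr_schedule[100] == 1e-4,
    # lr_schedule[-1] is close to 1e-6; use lr_schedule[it] at step `it`.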

def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
    # Drop gradients of the last layer during the first `freeze_last_layer` epochs.
    if epoch >= freeze_last_layer:
        return
    for n, p in model.named_parameters():
        if "last_layer" in n:
            p.grad = None


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
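As a quick check of the shape contract (a standalone snippet, not from the commit):

    pos = np.arange(16)                              # 16 positions
    pe = get_1d_sincos_pos_embed_from_grid(64, pos)  # sin/cos halves concatenated
    assert pe.shape == (16, 64)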

def cancel_gradients_backbone(iteration, model, freeze_backbone_steps):
    # Drop backbone/fusion gradients during the first `freeze_backbone_steps`
    # iterations so only the remaining parameters are updated.
    if iteration >= freeze_backbone_steps:
        return
    for n, p in model.named_parameters():
        if "backbon_hr" in n or 'backbon_s2' in n or 'head_s2' in n or 'fusion' in n or 'ctpe' in n:
            p.grad = None


class EMA:
    """Exponential moving average of model parameters.

    `register` snapshots the current parameters, `update` blends new values
    into the shadow copy after each optimizer step, and `apply_shadow`/`restore`
    temporarily swap the EMA weights in (e.g. for evaluation) and back out.
    """

    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}

    def register(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                new_average = (1.0 - self.decay
                               ) * param.data + self.decay * self.shadow[name]
                self.shadow[name] = new_average.clone()

    def apply_shadow(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                self.backup[name] = param.data
                param.data = self.shadow[name]

    def restore(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}
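A typical training-loop integration might look like this (sketch only; model, optimizer, train_loader and evaluate are placeholders):

    ema = EMA(model, decay=0.999)
    ema.register()                 # snapshot the initial weights
    for batch in train_loader:
        loss = model(batch)        # placeholder forward/loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        ema.update()               # fold the new weights into the shadow copy
    ema.apply_shadow()             # evaluate with the EMA weights
    evaluate(model)                # placeholder evaluation
    ema.restore()                  # back to the raw training weights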

class LayerDecayValueAssigner(object):
    """Assigns per-layer learning-rate scales for layer-wise lr decay and
    builds the corresponding optimizer parameter groups."""

    def __init__(self, layer_decay, num_layers, base_lr, net_type, arch='huge'):
        assert net_type in ['swin', 'vit']
        assert 0 < layer_decay <= 1
        depths_dict = {
            'tiny': [2, 2, 6, 2],
            'small': [2, 2, 18, 2],
            'base': [2, 2, 18, 2],
            'large': [2, 2, 18, 2],
            'huge': [2, 2, 18, 2],
            'giant': [2, 2, 42, 4],
        }
        num_layers = num_layers if net_type == 'vit' else sum(depths_dict[arch])
        self.layer_decay = layer_decay
        # Scale factors: earlier layers get smaller multipliers, the head gets 1.
        self.values = list(layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2))
        self.depths = depths_dict[arch]
        self.base_lr = base_lr
        self.net_type = net_type

    def get_num_layer_for_vit(self, var_name):
        if var_name in ("cls_token", "mask_token", "pos_embed"):
            return 0
        elif var_name.startswith("patch_embed"):
            return 0
        elif var_name.startswith("layers"):
            layer_id = int(var_name.split('.')[1])
            return layer_id + 1
        else:
            return len(self.values) - 1

    def get_num_layer_for_swin(self, var_name):
        if var_name in ("mask_token", "pos_embed"):
            return 0
        elif var_name.startswith("patch_embed"):
            return 0
        elif var_name.startswith("stages"):
            layer_id = int(var_name.split('.')[1])
            if 'blocks' in var_name:
                block_id = int(var_name.split('.')[3])
            else:
                block_id = self.depths[layer_id] - 1
            layer_id = sum(self.depths[:layer_id]) + block_id
            return layer_id + 1
        else:
            return len(self.values) - 1

    def get_layer_id(self, var_name):
        if self.net_type == 'swin':
            return self.get_num_layer_for_swin(var_name)
        if self.net_type == 'vit':
            return self.get_num_layer_for_vit(var_name)

    def fix_param(self, model, num_block=4):
        # Freeze the patch embedding and the first `num_block` blocks.
        if num_block < 1:
            return 0
        frozen_num = 0
        if self.net_type == 'swin':
            for name, param in model.named_parameters():
                if name.startswith("patch_embed"):
                    param.requires_grad = False
                    frozen_num += 1
                if name.startswith("stages") and self.get_layer_id(name) <= num_block:
                    param.requires_grad = False
                    frozen_num += 1
        if self.net_type == 'vit':
            for name, param in model.named_parameters():
                if name.startswith("patch_embed"):
                    param.requires_grad = False
                    frozen_num += 1
                if name.startswith("layers") and self.get_layer_id(name) <= num_block:
                    param.requires_grad = False
                    frozen_num += 1
        return frozen_num

    def fix_param_deeper(self, model, num_block=4):
        # Freeze the patch embedding and all blocks from `num_block` onwards.
        if num_block < 1:
            return 0
        frozen_num = 0
        if self.net_type == 'swin':
            raise ValueError('Not supported')
        if self.net_type == 'vit':
            for name, param in model.named_parameters():
                if name.startswith("patch_embed"):
                    param.requires_grad = False
                    frozen_num += 1
                if name.startswith("layers") and self.get_layer_id(name) >= num_block:
                    param.requires_grad = False
                    frozen_num += 1
        return frozen_num

    def get_parameter_groups(self, model, weight_decay):
        # For each layer scale, build one group with weight decay and one
        # without; no-decay parameters are matched by substring below.
        parameter_groups_with_wd, parameter_groups_without_wd = [], []
        print_info_with_wd, print_info_without_wd = [], []
        no_decay = [
            "absolute_pos_embed", "relative_position_bias_table", "norm", "bias"
        ]
        if self.layer_decay == 1:
            parameter_groups_with_wd.append(
                {"params": [], "weight_decay": weight_decay, "lr": self.base_lr}
            )
            print_info_with_wd.append(
                {"params": [], "weight_decay": weight_decay, "lr": self.base_lr}
            )
            parameter_groups_without_wd.append(
                {"params": [], "weight_decay": 0, "lr": self.base_lr}
            )
            print_info_without_wd.append(
                {"params": [], "weight_decay": 0, "lr": self.base_lr}
            )
        else:
            for scale in self.values:
                parameter_groups_with_wd.append(
                    {"params": [], "weight_decay": weight_decay, "lr": scale * self.base_lr}
                )
                print_info_with_wd.append(
                    {"params": [], "weight_decay": weight_decay, "lr": scale * self.base_lr}
                )
                parameter_groups_without_wd.append(
                    {"params": [], "weight_decay": 0, "lr": scale * self.base_lr}
                )
                print_info_without_wd.append(
                    {"params": [], "weight_decay": 0, "lr": scale * self.base_lr}
                )
        for name, param in model.named_parameters():
            if not param.requires_grad:
                print(f'frozen param: {name}')
                continue  # frozen weights
            layer_id = self.get_layer_id(name) if self.layer_decay < 1 else 0
            if any(nd in name for nd in no_decay):
                parameter_groups_without_wd[layer_id]['params'].append(param)
                print_info_without_wd[layer_id]['params'].append(name)
            else:
                parameter_groups_with_wd[layer_id]['params'].append(param)
                print_info_with_wd[layer_id]['params'].append(name)
        # Drop empty groups, then optionally log the assignment per layer.
        parameter_groups_with_wd = [x for x in parameter_groups_with_wd if len(x['params']) > 0]
        parameter_groups_without_wd = [x for x in parameter_groups_without_wd if len(x['params']) > 0]
        print_info_with_wd = [x for x in print_info_with_wd if len(x['params']) > 0]
        print_info_without_wd = [x for x in print_info_without_wd if len(x['params']) > 0]
        if self.layer_decay < 1:
            for wd, nwd in zip(print_info_with_wd, print_info_without_wd):
                print(wd)
                print(nwd)
        parameter_groups = []
        parameter_groups.extend(parameter_groups_with_wd)
        parameter_groups.extend(parameter_groups_without_wd)
        return parameter_groups
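A usage sketch for the parameter groups (the backbone and hyper-parameters are illustrative, not from the commit):

    import torch

    assigner = LayerDecayValueAssigner(layer_decay=0.9, num_layers=24,
                                       base_lr=1e-4, net_type='vit')
    param_groups = assigner.get_parameter_groups(model, weight_decay=0.05)
    optimizer = torch.optim.AdamW(param_groups)  # per-group lr already set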