commit 01adcfdf60
Author: esenke
Date: 2025-12-08 22:16:31 +08:00

305 changed files with 50879 additions and 0 deletions

lib/utils/optim_utils.py (new file, 122 lines)

@@ -0,0 +1,122 @@
import torch
from torch.nn import LayerNorm, Linear, GELU
from torch.nn import MultiheadAttention, Sequential
import warnings

try:
    from atorch.normalization import LayerNorm as FastLayerNorm
    from atorch.modules.transformer.inject import replace_module
    from atorch.modules.transformer.layers import MultiheadAttentionFA, BertAttentionFA
except (ImportError, ModuleNotFoundError) as e:
    warnings.warn("Using replace_speedup_op but no atorch/apex installed:%s" % e)

try:
    from transformers.models.bert.modeling_bert import BertAttention
    replace_transformer_bert = True
except ImportError:
    replace_transformer_bert = False


class DefaultStrategy:
    replace_mha = True
    replace_layernorm = True
    replace_linear_gelu = False  # TODO: numerical consistency
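# Illustrative sketch (not part of this commit): any object exposing the same
# three flags can serve as a strategy, e.g.
#
#     class LayerNormOnlyStrategy:
#         replace_mha = False        # keep torch.nn.MultiheadAttention as-is
#         replace_layernorm = True   # still swap in atorch FastLayerNorm
#         replace_linear_gelu = False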


def replace_layer_norm(module: torch.nn.Module, cur_name: str):
    """Recursively replace torch.nn.LayerNorm with atorch FastLayerNorm, copying weights."""
    for name, child in module.named_children():
        child_name = cur_name + "." + name
        if isinstance(child, LayerNorm):
            new_module = FastLayerNorm(child.normalized_shape, eps=child.eps)
            new_module.load_state_dict(child.state_dict())
            setattr(module, name, new_module)
        else:
            replace_layer_norm(child, child_name)


def is_atorch_available(raise_error=True, log=None):
    try:
        import atorch  # noqa: F401
        return True
    except ImportError as e:
        if raise_error is True:
            raise ImportError(e, log)
        else:
            return False


def _cast_if_autocast_enabled(*args):
    if not torch.is_autocast_enabled():
        return args
    else:
        return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype())


def _fused_dense_gelu_dense(input, weight1, bias1, weight2, bias2):
    batch, seq_length, hidden_size = input.size()
    input = input.view(batch * seq_length, hidden_size)
    # Inputs are cast to the autocast dtype manually, so run the fused apex
    # kernel with autocast disabled to avoid a second implicit cast.
    args = _cast_if_autocast_enabled(input, weight1, bias1, weight2, bias2)
    from apex.fused_dense import FusedDenseGeluDenseFunc
    with torch.cuda.amp.autocast(enabled=False):
        out = FusedDenseGeluDenseFunc.apply(*args)
    out = out.view(batch, seq_length, -1)
    return out


def linear_gelu_forward(input_, weight1, bias1, weight2, bias2):
    return _fused_dense_gelu_dense(input_, weight1, bias1, weight2, bias2)


def replace_linear_gelu(module, cur_name: str):
    """Patch a Sequential(Linear -> GELU -> ...) + Linear pair to use the fused apex kernel, e.g.:

    (layers): Sequential(
      (0): Sequential(
        (0): Linear(in_features=1536, out_features=6144, bias=True)
        (1): GELU()
        (2): Dropout(p=0.0, inplace=False)
      )
      (1): Linear(in_features=6144, out_features=1536, bias=True)
      (2): Dropout(p=0.0, inplace=False)
    )
    """
    for name, child in module.named_children():
        child_name = cur_name + "." + name
        if isinstance(child, Sequential):
            if (
                len(child) >= 2
                and isinstance(child[0], Sequential)
                and isinstance(child[1], Linear)
                and len(child[0]) >= 2
                and isinstance(child[0][0], Linear)
                and isinstance(child[0][1], GELU)
            ):  # Sequential(Linear, GELU, ...) followed by Linear
                linear0 = child[0][0]
                linear1 = child[1]
                if getattr(child, "replace_linear_gelu", False):
                    continue
                # Bind the Linear modules as default arguments so each patched
                # forward keeps its own weights (avoids late binding in the loop).
                child.forward = lambda x, l0=linear0, l1=linear1: linear_gelu_forward(
                    x, l0.weight, l0.bias, l1.weight, l1.bias)
                child.replace_linear_gelu = True
                print("REPLACE linear+gelu:%s" % child_name)
        else:
            replace_linear_gelu(child, child_name)


def replace_speedup_op(model, strategy=DefaultStrategy):
    if not is_atorch_available(raise_error=False):
        raise ImportError("Install Atorch/apex before using speedup op")
    if strategy.replace_mha:
        model = replace_module(model, MultiheadAttention, MultiheadAttentionFA, need_scr_module=True)
        if replace_transformer_bert:
            model = replace_module(model, BertAttention, BertAttentionFA, need_scr_module=True)
    root_name = model.__class__.__name__
    if strategy.replace_layernorm:
        replace_layer_norm(model, root_name)  # inplace
    if strategy.replace_linear_gelu:
        replace_linear_gelu(model, root_name)
    return model
# TODO:
# 1. SyncBatchNorm
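
For reference, a minimal usage sketch of the new module (illustrative only, not part of the diff; it assumes a CUDA device with atorch installed, and the strategy/model names below are made up):

    import torch
    from lib.utils.optim_utils import replace_speedup_op, DefaultStrategy

    class LayerNormOnlyStrategy(DefaultStrategy):
        replace_mha = False  # keep torch.nn.MultiheadAttention; only swap LayerNorm

    encoder_layer = torch.nn.TransformerEncoderLayer(d_model=1536, nhead=12, batch_first=True)
    model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2).cuda()

    # Returns the (partly in-place) modified model with atorch kernels injected.
    model = replace_speedup_op(model, strategy=LayerNormOnlyStrategy)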