import torch
from torch.nn import LayerNorm, Linear, GELU
from torch.nn import MultiheadAttention, Sequential
import warnings

try:
    from atorch.normalization import LayerNorm as FastLayerNorm
    from atorch.modules.transformer.inject import replace_module
    from atorch.modules.transformer.layers import MultiheadAttentionFA, BertAttentionFA
except (ImportError, ModuleNotFoundError) as e:
    warnings.warn("Using replace_speedup_op but no atorch/apex installed:%s" % e)

try:
    from transformers.models.bert.modeling_bert import BertAttention

    replace_transformer_bert = True
except ImportError:
    replace_transformer_bert = False


class DefaultStrategy:
    replace_mha = True
    replace_layernorm = True
    replace_linear_gelu = False  # TODO: numerical consistency


def replace_layer_norm(module: torch.nn.Module, cur_name: str):
    """Recursively replace every torch.nn.LayerNorm in `module` with atorch's FastLayerNorm, in place."""
    for name, child in module.named_children():
        child_name = cur_name + "." + name
        if isinstance(child, LayerNorm):
            new_module = FastLayerNorm(child.normalized_shape, eps=child.eps)
            new_module.load_state_dict(child.state_dict())
            setattr(module, name, new_module)
        else:
            replace_layer_norm(child, child_name)


def is_atorch_available(raise_error=True, log=None):
    try:
        import atorch  # noqa: F401

        return True
    except ImportError as e:
        if raise_error is True:
            raise ImportError(e, log)
        else:
            return False


def _cast_if_autocast_enabled(*args):
    if not torch.is_autocast_enabled():
        return args
    else:
        return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype())


def _fused_dense_gelu_dense(input, weight1, bias1, weight2, bias2):
    # Flatten (batch, seq) into one leading dimension for the fused kernel,
    # then restore the sequence dimension afterwards.
    batch, seq_length, hidden_size = input.size()
    input = input.view(batch * seq_length, hidden_size)
    args = _cast_if_autocast_enabled(input, weight1, bias1, weight2, bias2)
    from apex.fused_dense import FusedDenseGeluDenseFunc  # with cast

    with torch.cuda.amp.autocast(enabled=False):
        out = FusedDenseGeluDenseFunc.apply(*args)
    out = out.view(batch, seq_length, -1)
    return out


def linear_gelu_forward(input_, weight1, bias1, weight2, bias2):
    return _fused_dense_gelu_dense(input_, weight1, bias1, weight2, bias2)


def replace_linear_gelu(module, cur_name: str):
    """Patch Sequential(Linear -> GELU, Linear) blocks to call the apex fused dense-GELU-dense kernel.

    Expected layout:
        (layers): Sequential(
          (0): Sequential(
            (0): Linear(in_features=1536, out_features=6144, bias=True)
            (1): GELU()
            (2): Dropout(p=0.0, inplace=False)
          )
          (1): Linear(in_features=6144, out_features=1536, bias=True)
          (2): Dropout(p=0.0, inplace=False)
        )
    """
    for name, child in module.named_children():
        child_name = cur_name + "." + name
        if isinstance(child, Sequential):
            if (
                len(child) >= 2
                and isinstance(child[0], Sequential)
                and isinstance(child[1], Linear)
                and len(child[0]) >= 2
                and isinstance(child[0][0], Linear)
                and isinstance(child[0][1], GELU)
            ):  # Sequential + Linear
                if getattr(child, "replace_linear_gelu", False):
                    continue
                linear0 = child[0][0]
                linear1 = child[1]
                # Bind the layers as default arguments so each patched child keeps its own
                # parameters; a plain closure would late-bind to the last match in the loop.
                child.forward = lambda x, l0=linear0, l1=linear1: linear_gelu_forward(
                    x, l0.weight, l0.bias, l1.weight, l1.bias)
                child.replace_linear_gelu = True
                print("REPLACE linear+gelu:%s" % child_name)
                # setattr(module, name, new_module)
        else:
            replace_linear_gelu(child, child_name)


def replace_speedup_op(model, strategy=DefaultStrategy):
    """Swap supported modules in `model` for atorch/apex fused implementations, per `strategy`."""
    if not is_atorch_available(raise_error=False):
        raise ImportError("Install Atorch/apex before using speedup op")
    if strategy.replace_mha:
        model = replace_module(model, MultiheadAttention, MultiheadAttentionFA, need_scr_module=True)
        if replace_transformer_bert:
            model = replace_module(model, BertAttention, BertAttentionFA, need_scr_module=True)
    root_name = model.__class__.__name__
    if strategy.replace_layernorm:
        replace_layer_norm(model, root_name)  # inplace
    if strategy.replace_linear_gelu:
        replace_linear_gelu(model, root_name)
    return model


# TODO:
# 1. SyncBatchNorm
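

if __name__ == "__main__":
    # Illustrative usage only, a minimal sketch rather than part of the original module:
    # build a toy model whose layout matches what replace_linear_gelu and replace_layer_norm
    # look for, then let replace_speedup_op swap in the fused ops. `CustomStrategy` and the
    # toy model below are hypothetical stand-ins, not names from atorch or apex; running this
    # still requires atorch (and apex for the fused linear+gelu path) to be installed.
    import torch.nn as nn

    class CustomStrategy(DefaultStrategy):
        # keep the fused linear+gelu path off until numerical consistency is verified
        replace_linear_gelu = False

    toy_model = nn.Sequential(
        nn.Sequential(
            nn.Linear(1536, 6144),
            nn.GELU(),
            nn.Dropout(p=0.0),
        ),
        nn.Linear(6144, 1536),
        nn.LayerNorm(1536),
    )
    toy_model = replace_speedup_op(toy_model, strategy=CustomStrategy)
    print(toy_model)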