commit 01adcfdf60
Author: esenke
Date: 2025-12-08 22:16:31 +08:00
305 changed files with 50879 additions and 0 deletions

lib/trainer/seg_trainer.py (new file, 399 lines)

@@ -0,0 +1,399 @@
# Copyright (c) Ant Group and its affiliates.
import gc
import math
from itertools import chain
import torch
from torch import nn
from tqdm import tqdm
from antmmf.common.registry import registry
from antmmf.common.report import Report
from antmmf.common.meter import Meter
from antmmf.modules.metrics import Metrics
from antmmf.optimizer.combine_optimizers import CombinedOptimizer
from antmmf.utils.distributed_utils import (broadcast_scalar, is_main_process)
from antmmf.utils.early_stopping import EarlyStopping
from antmmf.utils.general import clip_gradients, count_parameters, nullcontext
from antmmf.utils.timer import Timer
from antmmf.trainers.base_trainer import BaseTrainer
from lib.utils.utils import cancel_gradients_backbone, EMA
from lib.utils.checkpoint import SegCheckpoint
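# atorch is an optional dependency: its amp module is only needed when
# enable_atorch_amp is turned on, so a missing install is tolerated here.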
try:
import atorch
from atorch import amp
except ImportError:
pass
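

# Segmentation trainer registered under the name "seg_trainer"; it extends
# antmmf's BaseTrainer with optional backbone freezing, EMA weight tracking,
# bfloat16 autocast training, and fp16 evaluation.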
@registry.register_trainer("seg_trainer")
class SEGTrainer(BaseTrainer):
def __init__(self, config):
super().__init__(config)
        self.enable_torch_amp = True
        self.enable_atorch_amp = False
def load(self, has_check_point=True):
super().load(has_check_point)
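        # TF32 matmuls trade a small amount of precision for speed on Ampere or
        # newer GPUs; kept off unless `enable_tf32` is set in the config.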
torch.backends.cuda.matmul.allow_tf32 = self.config.training_parameters.get(
"enable_tf32", False)
if hasattr(
self.config.training_parameters, "freeze_backbone"
) and self.config.training_parameters.freeze_backbone is True:
            frozen_prefixes = ("backbone_hr.", "backbone_s2.", "head_s2.",
                               "backbone_s1.", "head_s1.", "fusion.",
                               "ctpe", "glbank.")
            for n, p in self.model.named_parameters():
                if any(prefix in n for prefix in frozen_prefixes):
                    p.requires_grad = False
                else:
                    print(n, '-->', p.requires_grad)
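        # Optionally track an exponential moving average of the weights
        # (decay 0.96); the EMA shadow is swapped in for evaluation
        # (see evaluate_set) and restored afterwards (see _try_full_validation).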
if hasattr(self.config.training_parameters,
"ema") and self.config.training_parameters.ema is True:
self.ema = EMA(self.model, 0.96)
self.ema.register()
def load_extras(self, has_check_point=True):
self.checkpoint = None if has_check_point is False else SegCheckpoint(
self)
self.meter = Meter()
self.training_parameters = self.config.training_parameters
monitored_metric = self.training_parameters.monitored_metric
metric_minimize = self.training_parameters.metric_minimize
should_early_stop = self.training_parameters.should_early_stop
patience = self.training_parameters.patience
self.log_interval = self.training_parameters.log_interval
self.snapshot_interval = self.training_parameters.snapshot_interval
self.max_iterations = self.training_parameters.max_iterations
self.should_clip_gradients = self.training_parameters.clip_gradients
self.max_epochs = self.training_parameters.max_epochs
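        # The loss is divided by gradient_accumulation_steps and the optimizer
        # steps only once per that many iterations; see _backward().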
self.gradient_accumulation_steps = int(
self.training_parameters.gradient_accumulation_steps)
assert self.gradient_accumulation_steps >= 1
for t_type in self.task_loader.task_type:
if t_type == "train":
self.dataset_train_order = self.training_parameters.get(
"dataset_train_order", self.train_task.datasets_name)
if t_type == "val":
self.dataset_val_order = self.training_parameters.get(
"dataset_val_order", self.val_task.datasets_name)
if t_type == "test":
self.dataset_test_order = self.training_parameters.get(
"dataset_test_order", self.test_task.datasets_name)
if t_type == "interpret":
self.dataset_interpret_order = self.training_parameters.get(
"dataset_interpret_order",
self.interpret_task.datasets_name)
self.early_stopping = EarlyStopping(
self.model,
self.checkpoint,
monitored_metric,
patience=patience,
minimize=metric_minimize,
should_stop=should_early_stop,
)
self.current_epoch = 1
self.current_iteration = 0
self.not_debug = self.training_parameters.logger_level != "debug"
self.lr_scheduler = None
self.setup_lr_scheduler()
if self.checkpoint is not None:
self.checkpoint.load_state_dict()
if "overall_metrics" in self.training_parameters:
self.overall_metric_evaluator = Metrics(
self.config.training_parameters.get("overall_metrics", []))
self.synchronized_loss = self.config.training_parameters.synchronized_loss
def train(self):
self.writer.write("===== Model =====")
self.writer.write(self.model)
self.writer.write(
"Model Params: Trainable {Trainable:.3f}M Total {Total:.3f}M".
format(**count_parameters(self.model)))
if "train" not in self.run_type:
self.inference()
return
should_break = False
if self.max_epochs is None:
self.max_epochs = math.inf
else:
self.max_iterations = min(self.max_iterations,
self.max_epochs * self.epoch_iterations)
self.model.train()
self.train_timer = Timer()
self.profile("Setup Time")
if self.enable_torch_amp:
self.writer.write("Using Automatic mixed precision training")
if hasattr(self.config, "amp_attributes") and hasattr(
self.config.amp_attributes, "growth_interval"):
growth_interval = self.config.amp_attributes.growth_interval
else:
growth_interval = 2000
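            # The scaler is created with enabled=False: the forward pass runs
            # under bfloat16 autocast (see _forward_pass), which needs no loss
            # scaling, so scale()/unscale_()/step() behave as pass-throughs.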
self.scaler = torch.cuda.amp.GradScaler(
init_scale=self.config.amp_attributes.init_scale,
enabled=False,
growth_interval=growth_interval)
self.writer.write("Using Init scale:%s" %
self.config.amp_attributes.init_scale)
self.optimizer.zero_grad()
self.writer.write("Starting training...")
while self.current_iteration < self.max_iterations and not should_break:
registry.register("current_epoch", self.current_epoch)
self.task_loader.seed_sampler("train", self.current_epoch)
if self.current_epoch > self.max_epochs:
break
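            # chain() flattens the per-dataset train loaders into one stream,
            # so each epoch walks every training dataset in sequence.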
for batch in tqdm(
chain(*self.train_loader_list),
total=self._len_of_loader_list(self.train_loader_list),
disable=self.disable_tqdm or (not is_main_process()),
):
self.profile("Batch load time")
report, _, _ = self._forward_pass(
batch, enable_amp=self.enable_torch_amp)
if report is None:
continue
self._update_meter(report, self.meter)
loss = self._extract_loss(report)
self._backward(loss)
if hasattr(
self.config.training_parameters,
"ema") and self.config.training_parameters.ema is True:
self.ema.update()
should_break = self._logistics()
self._run_scheduler()
self.current_iteration += 1
self.writer.write(self.current_iteration, "debug")
registry.register("current_iteration", self.current_iteration)
if self.current_iteration >= self.max_iterations:
break
if should_break:
break
self.current_epoch += 1
self.finalize()
def _forward_pass(self, batch, enable_amp=False):
        if not batch:  # the SampleList may be an empty dict
return None, None, None
prepared_batch = self.task_loader.prepare_batch(batch)
self.profile("Batch prepare time")
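        # Run the forward pass under bfloat16 autocast when AMP is enabled,
        # otherwise under a no-op context.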
forward_context = torch.cuda.amp.autocast(
enabled=True,
dtype=torch.bfloat16) if enable_amp else nullcontext()
with forward_context:
# Arguments should be a dict at this point
model_output = self.model(prepared_batch)
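            # With synchronized_loss, losses and metrics are computed here on
            # the unwrapped module (model.module under DataParallel/DDP) from
            # the model output, instead of inside the model's forward pass.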
if self.synchronized_loss:
is_parallel = isinstance(
self.model, nn.DataParallel) or isinstance(
self.model, nn.parallel.DistributedDataParallel)
if "losses" not in model_output:
loss_func = getattr(
self.model.module if is_parallel else self.model,
"losses")
model_output["losses"] = loss_func(
prepared_batch,
model_output,
iteration=self.current_iteration)
if "metrics" not in model_output:
metric_func = getattr(
self.model.module if is_parallel else self.model,
"metrics")
model_output["metrics"] = metric_func(
prepared_batch, model_output)
report = Report(prepared_batch, model_output)
self.profile("Forward time")
return report, model_output, prepared_batch
def _backward(self, loss):
loss = loss / self.gradient_accumulation_steps
if self.enable_torch_amp:
self.scaler.scale(loss).backward()
            # Unscale the gradients of the optimizer's assigned params in place;
            # this must happen before clip_gradients so clipping sees the true
            # gradient magnitudes.
self.scaler.unscale_(self.optimizer)
elif self.enable_atorch_amp:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
self.profile("Backward time")
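        # Only every gradient_accumulation_steps-th iteration clips gradients
        # and steps the optimizer; the other iterations just accumulate.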
if self.current_iteration % self.gradient_accumulation_steps != 0:
return
if self.should_clip_gradients:
if self.enable_atorch_amp:
clip_gradients(amp.master_params(self.optimizer),
self.current_iteration, self.writer,
self.config)
else:
clip_gradients(self.model, self.current_iteration, self.writer,
self.config)
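        # Optionally keep the backbone frozen for the first
        # freeze_backbone_steps iterations; cancel_gradients_backbone (from
        # lib.utils.utils) presumably cancels those gradients before the step.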
if hasattr(
self.config.training_parameters, "freeze_backbone_steps"
) and self.config.training_parameters.freeze_backbone_steps is not None:
cancel_gradients_backbone(
self.current_iteration, self.model,
self.config.training_parameters.freeze_backbone_steps)
if self.enable_torch_amp:
self.scaler.step(self.optimizer)
self.scaler.update()
else:
self.optimizer.step()
self.optimizer.zero_grad()
self.profile("Optimizer time")
def _logistics(self):
should_print = self.current_iteration and self.current_iteration % self.log_interval == 0
extra = {}
prefix = ""
if should_print is True:
if "cuda" in str(self.device):
extra["max mem"] = torch.cuda.max_memory_allocated() / 1024
extra["max mem"] //= 1024
# display lr
if isinstance(self.optimizer, CombinedOptimizer):
extra["lr"] = self.optimizer.get_optimizers_lr_str()
else:
extra["lr"] = "|".join([
"{:.8f}".format(x["lr"]).rstrip("0")
for x in self.optimizer.param_groups
])
extra.update({
"time": self.train_timer.get_time_since_start(),
"eta": self._calculate_time_left(),
})
self.train_timer.reset()
self._summarize_meter(
self.meter,
prefix=prefix,
extra=extra,
should_print=should_print,
)
should_break = self._try_full_validation()
return should_break
def _try_full_validation(self, force=False):
should_break = False
if self.current_iteration and self.current_iteration % self.snapshot_interval == 0 or force:
self.writer.write(
"Evaluation time. Running on full validation set...")
validation_timer = Timer()
dataset_name, meter = self.evaluate_set(self.val_loader_list)
extra = {
"validation time": validation_timer.get_time_since_start()
}
overall_metric = self.overall_metric_evaluator.summarize()
stop = self.early_stopping(self.current_iteration, overall_metric,
meter)
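            # If EMA is enabled, evaluate_set() swapped in the shadow weights;
            # restore the raw training weights before training resumes.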
if hasattr(self.config.training_parameters,
"ema") and self.config.training_parameters.ema is True:
self.ema.restore()
stop = bool(broadcast_scalar(stop, src=0, device=self.device))
extra.update(self.early_stopping.get_info())
prefix = "{}: full val".format(dataset_name)
self._summarize_overall(overall_metric,
meter,
prefix=prefix,
extra=extra)
gc.collect()
if "cuda" in str(self.device):
with torch.cuda.device(self.device):
torch.cuda.empty_cache()
            if stop > 0:  # NCCL cannot broadcast booleans, so the flag travels as an int
self.writer.write("Early stopping activated")
should_break = True
return should_break
def evaluate_set(self, loader_list):
from antmmf.structures import SampleList
meter = Meter()
torch.cuda.empty_cache()
with torch.no_grad():
self.model.eval()
if hasattr(self.config.training_parameters,
"ema") and self.config.training_parameters.ema is True:
self.ema.apply_shadow()
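            # Optional fp16 evaluation: the model (and the float inputs below)
            # are cast to half precision to reduce memory during validation.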
if self.config.training_parameters.get('fp16', False):
self.model.half()
self.overall_metric_evaluator.reset()
for idx, batch in tqdm(
enumerate(chain(*loader_list)),
total=self._len_of_loader_list(loader_list),
disable=not is_main_process() or self.disable_tqdm,
):
if idx >= self.config.training_parameters.get('num_eval', 1e7):
break
if self.config.training_parameters.get('fp16', False):
input_dict = SampleList()
for k, v in batch.items():
if isinstance(v, torch.cuda.FloatTensor):
input_dict[k] = v.half()
else:
input_dict[k] = v
report, model_output, prepared_batch = self._forward_pass(
input_dict, enable_amp=self.enable_torch_amp)
else:
report, model_output, prepared_batch = self._forward_pass(
batch, enable_amp=self.enable_torch_amp)
self._update_meter(report, meter)
self.overall_metric_evaluator.collect(prepared_batch,
model_output)
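            # Aggregate per-rank metric state across processes before summarizing.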
for _, metric_object in self.overall_metric_evaluator.metrics.items(
):
metric_object.all_reduce()
self.model.train()
return report.dataset_name, meter