# Copyright (c) Ant Group. and its affiliates.
import gc
import math
from itertools import chain

import torch
from torch import nn
from tqdm import tqdm

from antmmf.common.registry import registry
from antmmf.common.report import Report
from antmmf.common.meter import Meter
from antmmf.modules.metrics import Metrics
from antmmf.optimizer.combine_optimizers import CombinedOptimizer
from antmmf.utils.distributed_utils import broadcast_scalar, is_main_process
from antmmf.utils.early_stopping import EarlyStopping
from antmmf.utils.general import clip_gradients, count_parameters, nullcontext
from antmmf.utils.timer import Timer
from antmmf.trainers.base_trainer import BaseTrainer

from lib.utils.utils import cancel_gradients_backbone, EMA
from lib.utils.checkpoint import SegCheckpoint

try:
    import atorch
    from atorch import amp
except ImportError:
    pass

@registry.register_trainer("seg_trainer")
class SEGTrainer(BaseTrainer):
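    """Trainer registered as ``seg_trainer``.

    Extends ``BaseTrainer`` with bfloat16 autocast training, optional
    backbone freezing and EMA weight averaging, gradient accumulation,
    periodic full validation with early stopping, and checkpointing
    through ``SegCheckpoint``.
    """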

    def __init__(self, config):
        super().__init__(config)
        # Use native torch AMP (bfloat16 autocast in the forward pass);
        # atorch AMP is disabled.
        self.enable_torch_amp = True
        self.enable_atorch_amp = False

    def load(self, has_check_point=True):
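        """Load the model via the base trainer, then apply trainer-specific setup.

        Honours three optional ``training_parameters`` entries:
        ``enable_tf32`` toggles TF32 matmuls, ``freeze_backbone`` freezes the
        backbone/head/fusion parameter groups listed below, and ``ema``
        attaches an EMA helper (decay 0.96) whose shadow weights are applied
        during evaluation.
        """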
        super().load(has_check_point)
        torch.backends.cuda.matmul.allow_tf32 = self.config.training_parameters.get(
            "enable_tf32", False)
        if hasattr(
                self.config.training_parameters, "freeze_backbone"
        ) and self.config.training_parameters.freeze_backbone is True:
            frozen_keys = ("backbone_hr.", "backbone_s2.", "head_s2.",
                           "backbone_s1.", "head_s1.", "fusion.", "ctpe",
                           "glbank.")
            for n, p in self.model.named_parameters():
                if any(key in n for key in frozen_keys):
                    p.requires_grad = False
                else:
                    print(n, "-->", p.requires_grad)
        if hasattr(self.config.training_parameters,
                   "ema") and self.config.training_parameters.ema is True:
            self.ema = EMA(self.model, 0.96)
            self.ema.register()

    def load_extras(self, has_check_point=True):
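        """Build the training state that the loop in ``train`` relies on.

        Sets up the checkpoint (a ``SegCheckpoint`` unless checkpointing is
        disabled), the meter, logging/snapshot intervals, gradient
        accumulation, per-task dataset ordering, early stopping, the LR
        scheduler, and the overall metrics evaluator, then restores any
        existing checkpoint state.
        """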
        self.checkpoint = None if has_check_point is False else SegCheckpoint(
            self)
        self.meter = Meter()

        self.training_parameters = self.config.training_parameters

        monitored_metric = self.training_parameters.monitored_metric
        metric_minimize = self.training_parameters.metric_minimize
        should_early_stop = self.training_parameters.should_early_stop
        patience = self.training_parameters.patience

        self.log_interval = self.training_parameters.log_interval
        self.snapshot_interval = self.training_parameters.snapshot_interval
        self.max_iterations = self.training_parameters.max_iterations
        self.should_clip_gradients = self.training_parameters.clip_gradients
        self.max_epochs = self.training_parameters.max_epochs
        self.gradient_accumulation_steps = int(
            self.training_parameters.gradient_accumulation_steps)
        assert self.gradient_accumulation_steps >= 1

        for t_type in self.task_loader.task_type:
            if t_type == "train":
                self.dataset_train_order = self.training_parameters.get(
                    "dataset_train_order", self.train_task.datasets_name)
            if t_type == "val":
                self.dataset_val_order = self.training_parameters.get(
                    "dataset_val_order", self.val_task.datasets_name)
            if t_type == "test":
                self.dataset_test_order = self.training_parameters.get(
                    "dataset_test_order", self.test_task.datasets_name)
            if t_type == "interpret":
                self.dataset_interpret_order = self.training_parameters.get(
                    "dataset_interpret_order",
                    self.interpret_task.datasets_name)

        self.early_stopping = EarlyStopping(
            self.model,
            self.checkpoint,
            monitored_metric,
            patience=patience,
            minimize=metric_minimize,
            should_stop=should_early_stop,
        )
        self.current_epoch = 1
        self.current_iteration = 0

        self.not_debug = self.training_parameters.logger_level != "debug"

        self.lr_scheduler = None
        self.setup_lr_scheduler()

        if self.checkpoint is not None:
            self.checkpoint.load_state_dict()

        if "overall_metrics" in self.training_parameters:
            self.overall_metric_evaluator = Metrics(
                self.config.training_parameters.get("overall_metrics", []))
        self.synchronized_loss = self.config.training_parameters.synchronized_loss

    def train(self):
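        """Run the main training loop.

        Logs the model and its parameter counts, short-circuits to
        ``inference`` for non-training runs, then iterates over the chained
        train loaders until ``max_iterations`` or ``max_epochs`` is reached:
        forward pass (optionally under AMP), backward/optimizer step via
        ``_backward``, optional EMA update, logging and periodic validation
        via ``_logistics``, and LR scheduling.
        """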
        self.writer.write("===== Model =====")
        self.writer.write(self.model)
        self.writer.write(
            "Model Params: Trainable {Trainable:.3f}M Total {Total:.3f}M".format(
                **count_parameters(self.model)))

        if "train" not in self.run_type:
            self.inference()
            return

        should_break = False

        if self.max_epochs is None:
            self.max_epochs = math.inf
        else:
            self.max_iterations = min(self.max_iterations,
                                      self.max_epochs * self.epoch_iterations)

        self.model.train()
        self.train_timer = Timer()

        self.profile("Setup Time")

        if self.enable_torch_amp:
            self.writer.write("Using Automatic mixed precision training")
            if hasattr(self.config, "amp_attributes") and hasattr(
                    self.config.amp_attributes, "growth_interval"):
                growth_interval = self.config.amp_attributes.growth_interval
            else:
                growth_interval = 2000
            # The scaler is constructed disabled: the forward pass autocasts
            # to bfloat16, which needs no loss scaling, so scale/unscale_
            # are pass-throughs and step simply calls optimizer.step().
            self.scaler = torch.cuda.amp.GradScaler(
                init_scale=self.config.amp_attributes.init_scale,
                enabled=False,
                growth_interval=growth_interval)
            self.writer.write("Using Init scale:%s" %
                              self.config.amp_attributes.init_scale)

        self.optimizer.zero_grad()

        self.writer.write("Starting training...")
        while self.current_iteration < self.max_iterations and not should_break:
            registry.register("current_epoch", self.current_epoch)
            self.task_loader.seed_sampler("train", self.current_epoch)

            if self.current_epoch > self.max_epochs:
                break

            for batch in tqdm(
                    chain(*self.train_loader_list),
                    total=self._len_of_loader_list(self.train_loader_list),
                    disable=self.disable_tqdm or (not is_main_process()),
            ):
                self.profile("Batch load time")
                report, _, _ = self._forward_pass(
                    batch, enable_amp=self.enable_torch_amp)
                if report is None:
                    continue

                self._update_meter(report, self.meter)

                loss = self._extract_loss(report)
                self._backward(loss)
                if hasattr(
                        self.config.training_parameters,
                        "ema") and self.config.training_parameters.ema is True:
                    self.ema.update()
                should_break = self._logistics()

                self._run_scheduler()

                self.current_iteration += 1
                self.writer.write(self.current_iteration, "debug")
                registry.register("current_iteration", self.current_iteration)
                if self.current_iteration >= self.max_iterations:
                    break
                if should_break:
                    break

            self.current_epoch += 1

        self.finalize()

    def _forward_pass(self, batch, enable_amp=False):
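        """Run one forward pass and wrap the outputs in a ``Report``.

        Empty batches are skipped by returning ``(None, None, None)``. When
        ``enable_amp`` is set, the model runs under a bfloat16 autocast
        context. With ``synchronized_loss``, losses and metrics that the
        model did not return are computed here via the underlying module's
        ``losses``/``metrics`` functions.

        Returns a ``(report, model_output, prepared_batch)`` tuple.
        """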
        if not batch:  # SampleList might be an empty dict
            return None, None, None
        prepared_batch = self.task_loader.prepare_batch(batch)

        self.profile("Batch prepare time")
        forward_context = torch.cuda.amp.autocast(
            enabled=True,
            dtype=torch.bfloat16) if enable_amp else nullcontext()

        with forward_context:
            # Arguments should be a dict at this point
            model_output = self.model(prepared_batch)

            if self.synchronized_loss:
                is_parallel = isinstance(
                    self.model,
                    (nn.DataParallel, nn.parallel.DistributedDataParallel))
                if "losses" not in model_output:
                    loss_func = getattr(
                        self.model.module if is_parallel else self.model,
                        "losses")
                    model_output["losses"] = loss_func(
                        prepared_batch,
                        model_output,
                        iteration=self.current_iteration)
                if "metrics" not in model_output:
                    metric_func = getattr(
                        self.model.module if is_parallel else self.model,
                        "metrics")
                    model_output["metrics"] = metric_func(
                        prepared_batch, model_output)

        report = Report(prepared_batch, model_output)
        self.profile("Forward time")

        return report, model_output, prepared_batch

    def _backward(self, loss):
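        """Backpropagate ``loss`` and step the optimizer when due.

        The loss is divided by ``gradient_accumulation_steps``; backward runs
        through the torch GradScaler, atorch amp, or plain autograd depending
        on the enabled AMP mode. The optimizer step (with optional gradient
        clipping and backbone-gradient cancellation) only happens on
        iterations that complete an accumulation window.
        """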
        loss = loss / self.gradient_accumulation_steps

        if self.enable_torch_amp:
            self.scaler.scale(loss).backward()

            # Unscale the optimizer's gradients in place first, so that
            # clip_gradients below operates on unscaled gradients.
            self.scaler.unscale_(self.optimizer)
        elif self.enable_atorch_amp:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.profile("Backward time")

        if self.current_iteration % self.gradient_accumulation_steps != 0:
            return

        if self.should_clip_gradients:
            if self.enable_atorch_amp:
                clip_gradients(amp.master_params(self.optimizer),
                               self.current_iteration, self.writer,
                               self.config)
            else:
                clip_gradients(self.model, self.current_iteration, self.writer,
                               self.config)

        if hasattr(
                self.config.training_parameters, "freeze_backbone_steps"
        ) and self.config.training_parameters.freeze_backbone_steps is not None:
            cancel_gradients_backbone(
                self.current_iteration, self.model,
                self.config.training_parameters.freeze_backbone_steps)

        if self.enable_torch_amp:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.optimizer.zero_grad()
        self.profile("Optimizer time")

    def _logistics(self):
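        """Log training progress and trigger periodic validation.

        Every ``log_interval`` iterations this reports peak GPU memory,
        learning rates, elapsed time, and ETA through the meter summary; it
        then delegates to ``_try_full_validation`` and returns whether
        training should stop early.
        """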
        should_print = self.current_iteration and self.current_iteration % self.log_interval == 0
        extra = {}
        prefix = ""

        if should_print is True:
            if "cuda" in str(self.device):
                # Peak allocated memory, converted from bytes to MiB.
                extra["max mem"] = torch.cuda.max_memory_allocated() / 1024
                extra["max mem"] //= 1024

            # Display the learning rate(s).
            if isinstance(self.optimizer, CombinedOptimizer):
                extra["lr"] = self.optimizer.get_optimizers_lr_str()
            else:
                extra["lr"] = "|".join([
                    "{:.8f}".format(x["lr"]).rstrip("0")
                    for x in self.optimizer.param_groups
                ])

            extra.update({
                "time": self.train_timer.get_time_since_start(),
                "eta": self._calculate_time_left(),
            })

            self.train_timer.reset()

        self._summarize_meter(
            self.meter,
            prefix=prefix,
            extra=extra,
            should_print=should_print,
        )

        should_break = self._try_full_validation()

        return should_break

    def _try_full_validation(self, force=False):
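        """Run full validation every ``snapshot_interval`` iterations, or when forced.

        Evaluates the validation loaders, summarizes the overall metric,
        feeds it to early stopping, restores the original weights if EMA is
        enabled, broadcasts the stop decision from rank 0, logs the results,
        and frees cached GPU memory. Returns ``True`` when early stopping has
        triggered.
        """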
        should_break = False

        if (self.current_iteration and
                self.current_iteration % self.snapshot_interval == 0) or force:
            self.writer.write(
                "Evaluation time. Running on full validation set...")

            validation_timer = Timer()
            dataset_name, meter = self.evaluate_set(self.val_loader_list)
            extra = {
                "validation time": validation_timer.get_time_since_start()
            }

            overall_metric = self.overall_metric_evaluator.summarize()
            stop = self.early_stopping(self.current_iteration, overall_metric,
                                       meter)
            if hasattr(self.config.training_parameters,
                       "ema") and self.config.training_parameters.ema is True:
                self.ema.restore()
            # Broadcast the stop decision from rank 0 as a scalar (NCCL cannot
            # broadcast Python booleans) and convert it back to bool.
            stop = bool(broadcast_scalar(stop, src=0, device=self.device))

            extra.update(self.early_stopping.get_info())

            prefix = "{}: full val".format(dataset_name)
            self._summarize_overall(overall_metric,
                                    meter,
                                    prefix=prefix,
                                    extra=extra)
            gc.collect()

            if "cuda" in str(self.device):
                with torch.cuda.device(self.device):
                    torch.cuda.empty_cache()

            if stop:
                self.writer.write("Early stopping activated")
                should_break = True

        return should_break

    def evaluate_set(self, loader_list):
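        """Evaluate the model on ``loader_list`` and aggregate metrics.

        Runs under ``torch.no_grad`` in eval mode, optionally with EMA shadow
        weights applied and, when ``fp16`` is set, with a half-precision
        model and inputs. Metric objects are all-reduced across processes
        before the model is put back into train mode.

        Returns the last report's dataset name and the accumulated ``Meter``.
        """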
        from antmmf.structures import SampleList

        meter = Meter()
        torch.cuda.empty_cache()
        with torch.no_grad():
            self.model.eval()
            if hasattr(self.config.training_parameters,
                       "ema") and self.config.training_parameters.ema is True:
                self.ema.apply_shadow()
            if self.config.training_parameters.get("fp16", False):
                self.model.half()
            self.overall_metric_evaluator.reset()
            for idx, batch in tqdm(
                    enumerate(chain(*loader_list)),
                    total=self._len_of_loader_list(loader_list),
                    disable=not is_main_process() or self.disable_tqdm,
            ):
                if idx >= self.config.training_parameters.get("num_eval", 1e7):
                    break
                if self.config.training_parameters.get("fp16", False):
                    # Cast float inputs to half precision to match the model.
                    input_dict = SampleList()
                    for k, v in batch.items():
                        if isinstance(v, torch.cuda.FloatTensor):
                            input_dict[k] = v.half()
                        else:
                            input_dict[k] = v
                    report, model_output, prepared_batch = self._forward_pass(
                        input_dict, enable_amp=self.enable_torch_amp)
                else:
                    report, model_output, prepared_batch = self._forward_pass(
                        batch, enable_amp=self.enable_torch_amp)
                self._update_meter(report, meter)
                self.overall_metric_evaluator.collect(prepared_batch,
                                                      model_output)
            for metric_object in self.overall_metric_evaluator.metrics.values():
                metric_object.all_reduce()
            self.model.train()

        return report.dataset_name, meter