init
4
lib/trainer/__init__.py
Normal file
@@ -0,0 +1,4 @@
from .seg_trainer import SEGTrainer

__all__ = ['SEGTrainer']
399
lib/trainer/seg_trainer.py
Normal file
@@ -0,0 +1,399 @@
# Copyright (c) Ant Group and its affiliates.
import gc
import math
from itertools import chain

import torch
from torch import nn
from tqdm import tqdm

from antmmf.common.registry import registry
from antmmf.common.report import Report
from antmmf.common.meter import Meter
from antmmf.modules.metrics import Metrics
from antmmf.optimizer.combine_optimizers import CombinedOptimizer
from antmmf.utils.distributed_utils import broadcast_scalar, is_main_process
from antmmf.utils.early_stopping import EarlyStopping
from antmmf.utils.general import clip_gradients, count_parameters, nullcontext
from antmmf.utils.timer import Timer
from antmmf.trainers.base_trainer import BaseTrainer

from lib.utils.utils import cancel_gradients_backbone, EMA
from lib.utils.checkpoint import SegCheckpoint

# atorch is optional; it is only required when atorch amp is enabled.
try:
    import atorch
    from atorch import amp
except ImportError:
    pass


@registry.register_trainer("seg_trainer")
class SEGTrainer(BaseTrainer):
    """Segmentation trainer registered as ``seg_trainer``.

    Extends BaseTrainer with bfloat16 autocast training, optional EMA weight
    averaging, backbone freezing and fp16 evaluation.
    """

    def __init__(self, config):
        super().__init__(config)
        self.enable_torch_amp = True
        self.enable_atorch_amp = False

    def load(self, has_check_point=True):
        super().load(has_check_point)
        torch.backends.cuda.matmul.allow_tf32 = self.config.training_parameters.get(
            "enable_tf32", False)
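        # Optionally freeze the backbone / fusion branches so that only the
        # remaining parameters keep receiving gradients.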
        if hasattr(
                self.config.training_parameters, "freeze_backbone"
        ) and self.config.training_parameters.freeze_backbone is True:
            frozen_keys = ("backbone_hr.", "backbone_s2.", "head_s2.",
                           "backbone_s1.", "head_s1.", "fusion.", "ctpe",
                           "glbank.")
            for n, p in self.model.named_parameters():
                if any(key in n for key in frozen_keys):
                    p.requires_grad = False
                else:
                    print(n, '-->', p.requires_grad)
        if hasattr(self.config.training_parameters,
                   "ema") and self.config.training_parameters.ema is True:
            self.ema = EMA(self.model, 0.96)
            self.ema.register()

    def load_extras(self, has_check_point=True):
        self.checkpoint = None if has_check_point is False else SegCheckpoint(
            self)
        self.meter = Meter()

        self.training_parameters = self.config.training_parameters

        monitored_metric = self.training_parameters.monitored_metric
        metric_minimize = self.training_parameters.metric_minimize
        should_early_stop = self.training_parameters.should_early_stop
        patience = self.training_parameters.patience

        self.log_interval = self.training_parameters.log_interval
        self.snapshot_interval = self.training_parameters.snapshot_interval
        self.max_iterations = self.training_parameters.max_iterations
        self.should_clip_gradients = self.training_parameters.clip_gradients
        self.max_epochs = self.training_parameters.max_epochs
        self.gradient_accumulation_steps = int(
            self.training_parameters.gradient_accumulation_steps)
        assert self.gradient_accumulation_steps >= 1
        for t_type in self.task_loader.task_type:
            if t_type == "train":
                self.dataset_train_order = self.training_parameters.get(
                    "dataset_train_order", self.train_task.datasets_name)
            if t_type == "val":
                self.dataset_val_order = self.training_parameters.get(
                    "dataset_val_order", self.val_task.datasets_name)
            if t_type == "test":
                self.dataset_test_order = self.training_parameters.get(
                    "dataset_test_order", self.test_task.datasets_name)
            if t_type == "interpret":
                self.dataset_interpret_order = self.training_parameters.get(
                    "dataset_interpret_order",
                    self.interpret_task.datasets_name)

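        # Early stopping is driven by `monitored_metric`, evaluated during the
        # full validation runs in `_try_full_validation`.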
        self.early_stopping = EarlyStopping(
            self.model,
            self.checkpoint,
            monitored_metric,
            patience=patience,
            minimize=metric_minimize,
            should_stop=should_early_stop,
        )
        self.current_epoch = 1
        self.current_iteration = 0

        self.not_debug = self.training_parameters.logger_level != "debug"

        self.lr_scheduler = None
        self.setup_lr_scheduler()

        if self.checkpoint is not None:
            self.checkpoint.load_state_dict()

        if "overall_metrics" in self.training_parameters:
            self.overall_metric_evaluator = Metrics(
                self.config.training_parameters.get("overall_metrics", []))
        self.synchronized_loss = self.config.training_parameters.synchronized_loss

    def train(self):
        self.writer.write("===== Model =====")
        self.writer.write(self.model)
        self.writer.write(
            "Model Params: Trainable {Trainable:.3f}M Total {Total:.3f}M".
            format(**count_parameters(self.model)))

        if "train" not in self.run_type:
            self.inference()
            return

        should_break = False

        if self.max_epochs is None:
            self.max_epochs = math.inf
        else:
            self.max_iterations = min(self.max_iterations,
                                      self.max_epochs * self.epoch_iterations)

        self.model.train()
        self.train_timer = Timer()

        self.profile("Setup Time")

        if self.enable_torch_amp:
            self.writer.write("Using Automatic mixed precision training")
            if hasattr(self.config, "amp_attributes") and hasattr(
                    self.config.amp_attributes, "growth_interval"):
                growth_interval = self.config.amp_attributes.growth_interval
            else:
                growth_interval = 2000
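            # The GradScaler is constructed but left disabled: the forward
            # pass autocasts to bfloat16 (see `_forward_pass`), which does not
            # need loss scaling, so `scale()`/`unscale_()`/`step()` act as
            # pass-throughs.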
            self.scaler = torch.cuda.amp.GradScaler(
                init_scale=self.config.amp_attributes.init_scale,
                enabled=False,
                growth_interval=growth_interval)
            self.writer.write("Using Init scale:%s" %
                              self.config.amp_attributes.init_scale)

        self.optimizer.zero_grad()

        self.writer.write("Starting training...")
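        # Main loop: run until `max_iterations` (or `max_epochs`) is reached,
        # or until early stopping is triggered inside `_logistics`.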
        while self.current_iteration < self.max_iterations and not should_break:
            registry.register("current_epoch", self.current_epoch)
            self.task_loader.seed_sampler("train", self.current_epoch)

            if self.current_epoch > self.max_epochs:
                break

            for batch in tqdm(
                    chain(*self.train_loader_list),
                    total=self._len_of_loader_list(self.train_loader_list),
                    disable=self.disable_tqdm or (not is_main_process()),
            ):
                self.profile("Batch load time")
                report, _, _ = self._forward_pass(
                    batch, enable_amp=self.enable_torch_amp)
                if report is None:
                    continue

                self._update_meter(report, self.meter)

                loss = self._extract_loss(report)
                self._backward(loss)
                if hasattr(
                        self.config.training_parameters,
                        "ema") and self.config.training_parameters.ema is True:
                    self.ema.update()
                should_break = self._logistics()

                self._run_scheduler()

                self.current_iteration += 1
                self.writer.write(self.current_iteration, "debug")
                registry.register("current_iteration", self.current_iteration)
                if self.current_iteration >= self.max_iterations:
                    break
            if should_break:
                break

            self.current_epoch += 1

        self.finalize()

    def _forward_pass(self, batch, enable_amp=False):
        if not batch:  # SampleList might be an empty dict
            return None, None, None
        prepared_batch = self.task_loader.prepare_batch(batch)

        self.profile("Batch prepare time")
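        # When AMP is enabled the forward pass runs under bfloat16 autocast;
        # otherwise a no-op context is used.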
        forward_context = torch.cuda.amp.autocast(
            enabled=True,
            dtype=torch.bfloat16) if enable_amp else nullcontext()

        with forward_context:
            # Arguments should be a dict at this point
            model_output = self.model(prepared_batch)

            if self.synchronized_loss:
                is_parallel = isinstance(
                    self.model,
                    (nn.DataParallel, nn.parallel.DistributedDataParallel))
                if "losses" not in model_output:
                    loss_func = getattr(
                        self.model.module if is_parallel else self.model,
                        "losses")
                    model_output["losses"] = loss_func(
                        prepared_batch,
                        model_output,
                        iteration=self.current_iteration)
                if "metrics" not in model_output:
                    metric_func = getattr(
                        self.model.module if is_parallel else self.model,
                        "metrics")
                    model_output["metrics"] = metric_func(
                        prepared_batch, model_output)

        report = Report(prepared_batch, model_output)
        self.profile("Forward time")

        return report, model_output, prepared_batch

    def _backward(self, loss):
        # Average the loss over the accumulation window so the effective
        # gradient matches a single large batch.
        loss = loss / self.gradient_accumulation_steps

        if self.enable_torch_amp:
            self.scaler.scale(loss).backward()

            # Unscale the gradients of the optimizer's assigned params in-place;
            # this must happen before clip_gradients so clipping operates on the
            # true gradient magnitudes.
            self.scaler.unscale_(self.optimizer)
        elif self.enable_atorch_amp:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.profile("Backward time")
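        # Gradient accumulation gate: clipping and the optimizer step only run
        # once every `gradient_accumulation_steps` iterations; other iterations
        # just accumulate gradients.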
        if self.current_iteration % self.gradient_accumulation_steps != 0:
            return

        if self.should_clip_gradients:
            if self.enable_atorch_amp:
                clip_gradients(amp.master_params(self.optimizer),
                               self.current_iteration, self.writer,
                               self.config)
            else:
                clip_gradients(self.model, self.current_iteration, self.writer,
                               self.config)

        if hasattr(
                self.config.training_parameters, "freeze_backbone_steps"
        ) and self.config.training_parameters.freeze_backbone_steps is not None:
            cancel_gradients_backbone(
                self.current_iteration, self.model,
                self.config.training_parameters.freeze_backbone_steps)

        if self.enable_torch_amp:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.optimizer.zero_grad()
        self.profile("Optimizer time")

    def _logistics(self):
        should_print = (self.current_iteration
                        and self.current_iteration % self.log_interval == 0)
        extra = {}
        prefix = ""

        if should_print:
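            # Peak GPU memory is reported in MiB (bytes / 1024 // 1024).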
if "cuda" in str(self.device):
|
||||
extra["max mem"] = torch.cuda.max_memory_allocated() / 1024
|
||||
extra["max mem"] //= 1024
|
||||
|
||||
# display lr
|
||||
if isinstance(self.optimizer, CombinedOptimizer):
|
||||
extra["lr"] = self.optimizer.get_optimizers_lr_str()
|
||||
else:
|
||||
extra["lr"] = "|".join([
|
||||
"{:.8f}".format(x["lr"]).rstrip("0")
|
||||
for x in self.optimizer.param_groups
|
||||
])
|
||||
|
||||
extra.update({
|
||||
"time": self.train_timer.get_time_since_start(),
|
||||
"eta": self._calculate_time_left(),
|
||||
})
|
||||
|
||||
self.train_timer.reset()
|
||||
|
||||
self._summarize_meter(
|
||||
self.meter,
|
||||
prefix=prefix,
|
||||
extra=extra,
|
||||
should_print=should_print,
|
||||
)
|
||||
|
||||
should_break = self._try_full_validation()
|
||||
|
||||
return should_break
|
||||
|
||||
    def _try_full_validation(self, force=False):
        should_break = False

        if (self.current_iteration
                and self.current_iteration % self.snapshot_interval == 0) or force:
            self.writer.write(
                "Evaluation time. Running on full validation set...")

            validation_timer = Timer()
            dataset_name, meter = self.evaluate_set(self.val_loader_list)
            extra = {
                "validation time": validation_timer.get_time_since_start()
            }

            overall_metric = self.overall_metric_evaluator.summarize()
            stop = self.early_stopping(self.current_iteration, overall_metric,
                                       meter)
            if hasattr(self.config.training_parameters,
                       "ema") and self.config.training_parameters.ema is True:
                # Restore the raw weights that `evaluate_set` swapped for the
                # EMA shadow weights.
                self.ema.restore()
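            # Broadcast the early-stopping decision from the main process so
            # every rank leaves the training loop together.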
            stop = bool(broadcast_scalar(stop, src=0, device=self.device))

            extra.update(self.early_stopping.get_info())

            prefix = "{}: full val".format(dataset_name)
            self._summarize_overall(overall_metric,
                                    meter,
                                    prefix=prefix,
                                    extra=extra)
            gc.collect()

            if "cuda" in str(self.device):
                with torch.cuda.device(self.device):
                    torch.cuda.empty_cache()

            # The broadcast carries an int because NCCL cannot broadcast
            # booleans; it is converted back to `bool` above.
            if stop:
                self.writer.write("Early stopping activated")
                should_break = True

        return should_break

    def evaluate_set(self, loader_list):
        from antmmf.structures import SampleList

        meter = Meter()
        torch.cuda.empty_cache()
        with torch.no_grad():
            self.model.eval()
            if hasattr(self.config.training_parameters,
                       "ema") and self.config.training_parameters.ema is True:
                # Evaluate with the EMA shadow weights; they are restored in
                # `_try_full_validation` after validation finishes.
                self.ema.apply_shadow()
            if self.config.training_parameters.get('fp16', False):
                self.model.half()
            self.overall_metric_evaluator.reset()
            for idx, batch in tqdm(
                    enumerate(chain(*loader_list)),
                    total=self._len_of_loader_list(loader_list),
                    disable=not is_main_process() or self.disable_tqdm,
            ):
                # report, model_output, prepared_batch = self._forward_pass(
                #     batch, enable_amp=self.enable_torch_amp)
                if idx >= self.config.training_parameters.get('num_eval', 1e7):
                    break
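                # With fp16 evaluation the model was converted via `.half()`,
                # so float inputs are cast to half precision to match it.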
                if self.config.training_parameters.get('fp16', False):
                    input_dict = SampleList()
                    for k, v in batch.items():
                        if isinstance(v, torch.cuda.FloatTensor):
                            input_dict[k] = v.half()
                        else:
                            input_dict[k] = v
                    report, model_output, prepared_batch = self._forward_pass(
                        input_dict, enable_amp=self.enable_torch_amp)
                else:
                    report, model_output, prepared_batch = self._forward_pass(
                        batch, enable_amp=self.enable_torch_amp)
                self._update_meter(report, meter)
                self.overall_metric_evaluator.collect(prepared_batch,
                                                      model_output)
            for _, metric_object in self.overall_metric_evaluator.metrics.items():
                metric_object.all_reduce()
            self.model.train()

        return report.dataset_name, meter