This commit is contained in:
esenke
2025-12-08 22:16:31 +08:00
commit 01adcfdf60
305 changed files with 50879 additions and 0 deletions

View File

@@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .hooks import SegVisualizationHook
from .optimizers import (ForceDefaultOptimWrapperConstructor,
LayerDecayOptimizerConstructor,
LearningRateDecayOptimizerConstructor)
from .schedulers import PolyLRRatio
__all__ = [
'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor',
'SegVisualizationHook', 'PolyLRRatio',
'ForceDefaultOptimWrapperConstructor'
]

View File

@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .visualization_hook import SegVisualizationHook
__all__ = ['SegVisualizationHook']

View File

@@ -0,0 +1,129 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from typing import Optional, Sequence
import mmcv
from mmengine.fileio import get
from mmengine.hooks import Hook
from mmengine.runner import Runner
from mmengine.visualization import Visualizer
from mmseg.registry import HOOKS
from mmseg.structures import SegDataSample
@HOOKS.register_module()
class SegVisualizationHook(Hook):
"""Segmentation Visualization Hook. Used to visualize validation and
testing process prediction results.
In the testing phase:
1. If ``show`` is True, it means that only the prediction results are
visualized without storing data, so ``vis_backends`` needs to
be excluded.
Args:
draw (bool): whether to draw prediction results. If it is False,
it means that no drawing will be done. Defaults to False.
interval (int): The interval of visualization. Defaults to 50.
show (bool): Whether to display the drawn image. Default to False.
wait_time (float): The interval of show (s). Defaults to 0.
backend_args (dict, Optional): Arguments to instantiate a file backend.
See https://mmengine.readthedocs.io/en/latest/api/fileio.htm
for details. Defaults to None.
Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
"""
def __init__(self,
draw: bool = False,
interval: int = 50,
show: bool = False,
wait_time: float = 0.,
backend_args: Optional[dict] = None):
self._visualizer: Visualizer = Visualizer.get_current_instance()
self.interval = interval
self.show = show
if self.show:
# No need to think about vis backends.
self._visualizer._vis_backends = {}
warnings.warn('The show is True, it means that only '
'the prediction results are visualized '
'without storing data, so vis_backends '
'needs to be excluded.')
self.wait_time = wait_time
self.backend_args = backend_args.copy() if backend_args else None
self.draw = draw
if not self.draw:
warnings.warn('The draw is False, it means that the '
'hook for visualization will not take '
'effect. The results will NOT be '
'visualized or stored.')
self._test_index = 0
def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
outputs: Sequence[SegDataSample]) -> None:
"""Run after every ``self.interval`` validation iterations.
Args:
runner (:obj:`Runner`): The runner of the validation process.
batch_idx (int): The index of the current batch in the val loop.
data_batch (dict): Data from dataloader.
outputs (Sequence[:obj:`SegDataSample`]]): A batch of data samples
that contain annotations and predictions.
"""
if self.draw is False:
return
# There is no guarantee that the same batch of images
# is visualized for each evaluation.
total_curr_iter = runner.iter + batch_idx
# Visualize only the first data
img_path = outputs[0].img_path
img_bytes = get(img_path, backend_args=self.backend_args)
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
window_name = f'val_{osp.basename(img_path)}'
if total_curr_iter % self.interval == 0:
self._visualizer.add_datasample(
window_name,
img,
data_sample=outputs[0],
show=self.show,
wait_time=self.wait_time,
step=total_curr_iter)
def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
outputs: Sequence[SegDataSample]) -> None:
"""Run after every testing iterations.
Args:
runner (:obj:`Runner`): The runner of the testing process.
batch_idx (int): The index of the current batch in the val loop.
data_batch (dict): Data from dataloader.
outputs (Sequence[:obj:`SegDataSample`]): A batch of data samples
that contain annotations and predictions.
"""
if self.draw is False:
return
for data_sample in outputs:
self._test_index += 1
img_path = data_sample.img_path
window_name = f'test_{osp.basename(img_path)}'
img_path = data_sample.img_path
img_bytes = get(img_path, backend_args=self.backend_args)
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
self._visualizer.add_datasample(
window_name,
img,
data_sample=data_sample,
show=self.show,
wait_time=self.wait_time,
step=self._test_index)

View File

@@ -0,0 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .force_default_constructor import ForceDefaultOptimWrapperConstructor
from .layer_decay_optimizer_constructor import (
LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor)
__all__ = [
'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor',
'ForceDefaultOptimWrapperConstructor'
]

View File

@@ -0,0 +1,255 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import List, Optional, Union
import torch
import torch.nn as nn
from mmengine.logging import print_log
from mmengine.optim import DefaultOptimWrapperConstructor
from mmengine.utils.dl_utils import mmcv_full_available
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm
from torch.nn import GroupNorm, LayerNorm
from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class ForceDefaultOptimWrapperConstructor(DefaultOptimWrapperConstructor):
"""Default constructor with forced optimizer settings.
This constructor extends the default constructor to add an option for
forcing default optimizer settings. This is useful for ensuring that
certain parameters or layers strictly adhere to pre-defined default
settings, regardless of any custom settings specified.
By default, each parameter share the same optimizer settings, and we
provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
It is a dict and may contain various fields like 'custom_keys',
'bias_lr_mult', etc., as well as the additional field
`force_default_settings` which allows for enforcing default settings on
optimizer parameters.
- ``custom_keys`` (dict): Specified parameters-wise settings by keys. If
one of the keys in ``custom_keys`` is a substring of the name of one
parameter, then the setting of the parameter will be specified by
``custom_keys[key]`` and other setting like ``bias_lr_mult`` etc. will
be ignored. It should be noted that the aforementioned ``key`` is the
longest key that is a substring of the name of the parameter. If there
are multiple matched keys with the same length, then the key with lower
alphabet order will be chosen.
``custom_keys[key]`` should be a dict and may contain fields ``lr_mult``
and ``decay_mult``. See Example 2 below.
- ``bias_lr_mult`` (float): It will be multiplied to the learning
rate for all bias parameters (except for those in normalization
layers and offset layers of DCN).
- ``bias_decay_mult`` (float): It will be multiplied to the weight
decay for all bias parameters (except for those in
normalization layers, depthwise conv layers, offset layers of DCN).
- ``norm_decay_mult`` (float): It will be multiplied to the weight
decay for all weight and bias parameters of normalization
layers.
- ``flat_decay_mult`` (float): It will be multiplied to the weight
decay for all one-dimensional parameters
- ``dwconv_decay_mult`` (float): It will be multiplied to the weight
decay for all weight and bias parameters of depthwise conv
layers.
- ``dcn_offset_lr_mult`` (float): It will be multiplied to the learning
rate for parameters of offset layer in the deformable convs
of a model.
- ``bypass_duplicate`` (bool): If true, the duplicate parameters
would not be added into optimizer. Defaults to False.
- ``force_default_settings`` (bool): If true, this will override any
custom settings defined by ``custom_keys`` and enforce the use of
default settings for optimizer parameters like ``bias_lr_mult``.
This is particularly useful when you want to ensure that certain layers
or parameters adhere strictly to the pre-defined default settings.
Note:
1. If the option ``dcn_offset_lr_mult`` is used, the constructor will
override the effect of ``bias_lr_mult`` in the bias of offset layer.
So be careful when using both ``bias_lr_mult`` and
``dcn_offset_lr_mult``. If you wish to apply both of them to the offset
layer in deformable convs, set ``dcn_offset_lr_mult`` to the original
``dcn_offset_lr_mult`` * ``bias_lr_mult``.
2. If the option ``dcn_offset_lr_mult`` is used, the constructor will
apply it to all the DCN layers in the model. So be careful when the
model contains multiple DCN layers in places other than backbone.
3. When the option ``force_default_settings`` is true, it will override
any custom settings provided in ``custom_keys``. This ensures that the
default settings for the optimizer parameters are used.
Args:
optim_wrapper_cfg (dict): The config dict of the optimizer wrapper.
Required fields of ``optim_wrapper_cfg`` are
- ``type``: class name of the OptimizerWrapper
- ``optimizer``: The configuration of optimizer.
Optional fields of ``optim_wrapper_cfg`` are
- any arguments of the corresponding optimizer wrapper type,
e.g., accumulative_counts, clip_grad, etc.
Required fields of ``optimizer`` are
- `type`: class name of the optimizer.
Optional fields of ``optimizer`` are
- any arguments of the corresponding optimizer type, e.g.,
lr, weight_decay, momentum, etc.
paramwise_cfg (dict, optional): Parameter-wise options.
Example 1:
>>> model = torch.nn.modules.Conv1d(1, 1, 1)
>>> optim_wrapper_cfg = dict(
>>> dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01,
>>> momentum=0.9, weight_decay=0.0001))
>>> paramwise_cfg = dict(norm_decay_mult=0.)
>>> optim_wrapper_builder = DefaultOptimWrapperConstructor(
>>> optim_wrapper_cfg, paramwise_cfg)
>>> optim_wrapper = optim_wrapper_builder(model)
Example 2:
>>> # assume model have attribute model.backbone and model.cls_head
>>> optim_wrapper_cfg = dict(type='OptimWrapper', optimizer=dict(
>>> type='SGD', lr=0.01, weight_decay=0.95))
>>> paramwise_cfg = dict(custom_keys={
>>> 'backbone': dict(lr_mult=0.1, decay_mult=0.9)})
>>> optim_wrapper_builder = DefaultOptimWrapperConstructor(
>>> optim_wrapper_cfg, paramwise_cfg)
>>> optim_wrapper = optim_wrapper_builder(model)
>>> # Then the `lr` and `weight_decay` for model.backbone is
>>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for
>>> # model.cls_head is (0.01, 0.95).
"""
def add_params(self,
params: List[dict],
module: nn.Module,
prefix: str = '',
is_dcn_module: Optional[Union[int, float]] = None) -> None:
"""Add all parameters of module to the params list.
The parameters of the given module will be added to the list of param
groups, with specific rules defined by paramwise_cfg.
Args:
params (list[dict]): A list of param groups, it will be modified
in place.
module (nn.Module): The module to be added.
prefix (str): The prefix of the module
is_dcn_module (int|float|None): If the current module is a
submodule of DCN, `is_dcn_module` will be passed to
control conv_offset layer's learning rate. Defaults to None.
"""
# get param-wise options
custom_keys = self.paramwise_cfg.get('custom_keys', {})
# first sort with alphabet order and then sort with reversed len of str
sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', None)
bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', None)
norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', None)
dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', None)
flat_decay_mult = self.paramwise_cfg.get('flat_decay_mult', None)
bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)
dcn_offset_lr_mult = self.paramwise_cfg.get('dcn_offset_lr_mult', None)
force_default_settings = self.paramwise_cfg.get(
'force_default_settings', False)
# special rules for norm layers and depth-wise conv layers
is_norm = isinstance(module,
(_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
is_dwconv = (
isinstance(module, torch.nn.Conv2d)
and module.in_channels == module.groups)
for name, param in module.named_parameters(recurse=False):
param_group = {'params': [param]}
if bypass_duplicate and self._is_in(param_group, params):
print_log(
f'{prefix} is duplicate. It is skipped since '
f'bypass_duplicate={bypass_duplicate}',
logger='current',
level=logging.WARNING)
continue
if not param.requires_grad:
params.append(param_group)
continue
# if the parameter match one of the custom keys, ignore other rules
is_custom = False
for key in sorted_keys:
if key in f'{prefix}.{name}':
is_custom = True
lr_mult = custom_keys[key].get('lr_mult', 1.)
param_group['lr'] = self.base_lr * lr_mult
if self.base_wd is not None:
decay_mult = custom_keys[key].get('decay_mult', 1.)
param_group['weight_decay'] = self.base_wd * decay_mult
# add custom settings to param_group
for k, v in custom_keys[key].items():
param_group[k] = v
break
if not is_custom or force_default_settings:
# bias_lr_mult affects all bias parameters
# except for norm.bias dcn.conv_offset.bias
if name == 'bias' and not (
is_norm or is_dcn_module) and bias_lr_mult is not None:
param_group['lr'] = self.base_lr * bias_lr_mult
if (prefix.find('conv_offset') != -1 and is_dcn_module
and dcn_offset_lr_mult is not None
and isinstance(module, torch.nn.Conv2d)):
# deal with both dcn_offset's bias & weight
param_group['lr'] = self.base_lr * dcn_offset_lr_mult
# apply weight decay policies
if self.base_wd is not None:
# norm decay
if is_norm and norm_decay_mult is not None:
param_group[
'weight_decay'] = self.base_wd * norm_decay_mult
# bias lr and decay
elif (name == 'bias' and not is_dcn_module
and bias_decay_mult is not None):
param_group[
'weight_decay'] = self.base_wd * bias_decay_mult
# depth-wise conv
elif is_dwconv and dwconv_decay_mult is not None:
param_group[
'weight_decay'] = self.base_wd * dwconv_decay_mult
# flatten parameters except dcn offset
elif (param.ndim == 1 and not is_dcn_module
and flat_decay_mult is not None):
param_group[
'weight_decay'] = self.base_wd * flat_decay_mult
params.append(param_group)
for key, value in param_group.items():
if key == 'params':
continue
full_name = f'{prefix}.{name}' if prefix else name
print_log(
f'paramwise_options -- {full_name}:{key}={value}',
logger='current')
if mmcv_full_available():
from mmcv.ops import DeformConv2d, ModulatedDeformConv2d
is_dcn_module = isinstance(module,
(DeformConv2d, ModulatedDeformConv2d))
else:
is_dcn_module = False
for child_name, child_mod in module.named_children():
child_prefix = f'{prefix}.{child_name}' if prefix else child_name
self.add_params(
params,
child_mod,
prefix=child_prefix,
is_dcn_module=is_dcn_module)

View File

@@ -0,0 +1,207 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import warnings
from mmengine.dist import get_dist_info
from mmengine.logging import print_log
from mmengine.optim import DefaultOptimWrapperConstructor
from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS
def get_layer_id_for_convnext(var_name, max_layer_id):
"""Get the layer id to set the different learning rates in ``layer_wise``
decay_type.
Args:
var_name (str): The key of the model.
max_layer_id (int): Maximum number of backbone layers.
Returns:
int: The id number corresponding to different learning rate in
``LearningRateDecayOptimizerConstructor``.
"""
if var_name in ('backbone.cls_token', 'backbone.mask_token',
'backbone.pos_embed'):
return 0
elif var_name.startswith('backbone.downsample_layers'):
stage_id = int(var_name.split('.')[2])
if stage_id == 0:
layer_id = 0
elif stage_id == 1:
layer_id = 2
elif stage_id == 2:
layer_id = 3
elif stage_id == 3:
layer_id = max_layer_id
return layer_id
elif var_name.startswith('backbone.stages'):
stage_id = int(var_name.split('.')[2])
block_id = int(var_name.split('.')[3])
if stage_id == 0:
layer_id = 1
elif stage_id == 1:
layer_id = 2
elif stage_id == 2:
layer_id = 3 + block_id // 3
elif stage_id == 3:
layer_id = max_layer_id
return layer_id
else:
return max_layer_id + 1
def get_stage_id_for_convnext(var_name, max_stage_id):
"""Get the stage id to set the different learning rates in ``stage_wise``
decay_type.
Args:
var_name (str): The key of the model.
max_stage_id (int): Maximum number of backbone layers.
Returns:
int: The id number corresponding to different learning rate in
``LearningRateDecayOptimizerConstructor``.
"""
if var_name in ('backbone.cls_token', 'backbone.mask_token',
'backbone.pos_embed'):
return 0
elif var_name.startswith('backbone.downsample_layers'):
return 0
elif var_name.startswith('backbone.stages'):
stage_id = int(var_name.split('.')[2])
return stage_id + 1
else:
return max_stage_id - 1
def get_layer_id_for_vit(var_name, max_layer_id):
"""Get the layer id to set the different learning rates.
Args:
var_name (str): The key of the model.
num_max_layer (int): Maximum number of backbone layers.
Returns:
int: Returns the layer id of the key.
"""
if var_name in ('backbone.cls_token', 'backbone.mask_token',
'backbone.pos_embed'):
return 0
elif var_name.startswith('backbone.patch_embed'):
return 0
elif var_name.startswith('backbone.layers'):
layer_id = int(var_name.split('.')[2])
return layer_id + 1
else:
return max_layer_id - 1
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class LearningRateDecayOptimizerConstructor(DefaultOptimWrapperConstructor):
"""Different learning rates are set for different layers of backbone.
Note: Currently, this optimizer constructor is built for ConvNeXt,
BEiT and MAE.
"""
def add_params(self, params, module, **kwargs):
"""Add all parameters of module to the params list.
The parameters of the given module will be added to the list of param
groups, with specific rules defined by paramwise_cfg.
Args:
params (list[dict]): A list of param groups, it will be modified
in place.
module (nn.Module): The module to be added.
"""
parameter_groups = {}
print_log(f'self.paramwise_cfg is {self.paramwise_cfg}')
num_layers = self.paramwise_cfg.get('num_layers') + 2
decay_rate = self.paramwise_cfg.get('decay_rate')
decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise')
print_log('Build LearningRateDecayOptimizerConstructor '
f'{decay_type} {decay_rate} - {num_layers}')
weight_decay = self.base_wd
for name, param in module.named_parameters():
if not param.requires_grad:
continue # frozen weights
if len(param.shape) == 1 or name.endswith('.bias') or name in (
'pos_embed', 'cls_token'):
group_name = 'no_decay'
this_weight_decay = 0.
else:
group_name = 'decay'
this_weight_decay = weight_decay
if 'layer_wise' in decay_type:
if 'ConvNeXt' in module.backbone.__class__.__name__:
layer_id = get_layer_id_for_convnext(
name, self.paramwise_cfg.get('num_layers'))
print_log(f'set param {name} as id {layer_id}')
elif 'BEiT' in module.backbone.__class__.__name__ or \
'MAE' in module.backbone.__class__.__name__:
layer_id = get_layer_id_for_vit(name, num_layers)
print_log(f'set param {name} as id {layer_id}')
else:
raise NotImplementedError()
elif decay_type == 'stage_wise':
if 'ConvNeXt' in module.backbone.__class__.__name__:
layer_id = get_stage_id_for_convnext(name, num_layers)
print_log(f'set param {name} as id {layer_id}')
else:
raise NotImplementedError()
group_name = f'layer_{layer_id}_{group_name}'
if group_name not in parameter_groups:
scale = decay_rate**(num_layers - layer_id - 1)
parameter_groups[group_name] = {
'weight_decay': this_weight_decay,
'params': [],
'param_names': [],
'lr_scale': scale,
'group_name': group_name,
'lr': scale * self.base_lr,
}
parameter_groups[group_name]['params'].append(param)
parameter_groups[group_name]['param_names'].append(name)
rank, _ = get_dist_info()
if rank == 0:
to_display = {}
for key in parameter_groups:
to_display[key] = {
'param_names': parameter_groups[key]['param_names'],
'lr_scale': parameter_groups[key]['lr_scale'],
'lr': parameter_groups[key]['lr'],
'weight_decay': parameter_groups[key]['weight_decay'],
}
print_log(f'Param groups = {json.dumps(to_display, indent=2)}')
params.extend(parameter_groups.values())
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class LayerDecayOptimizerConstructor(LearningRateDecayOptimizerConstructor):
"""Different learning rates are set for different layers of backbone.
Note: Currently, this optimizer constructor is built for BEiT,
and it will be deprecated.
Please use ``LearningRateDecayOptimizerConstructor`` instead.
"""
def __init__(self, optim_wrapper_cfg, paramwise_cfg):
warnings.warn('DeprecationWarning: Original '
'LayerDecayOptimizerConstructor of BEiT '
'will be deprecated. Please use '
'LearningRateDecayOptimizerConstructor instead, '
'and set decay_type = layer_wise_vit in paramwise_cfg.')
paramwise_cfg.update({'decay_type': 'layer_wise_vit'})
warnings.warn('DeprecationWarning: Layer_decay_rate will '
'be deleted, please use decay_rate instead.')
paramwise_cfg['decay_rate'] = paramwise_cfg.pop('layer_decay_rate')
super().__init__(optim_wrapper_cfg, paramwise_cfg)

View File

@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .poly_ratio_scheduler import PolyLRRatio
__all__ = ['PolyLRRatio']

View File

@@ -0,0 +1,62 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
from mmengine.optim.scheduler import PolyLR
from mmseg.registry import PARAM_SCHEDULERS
@PARAM_SCHEDULERS.register_module()
class PolyLRRatio(PolyLR):
"""Implements polynomial learning rate decay with ratio.
This scheduler adjusts the learning rate of each parameter group
following a polynomial decay equation. The decay can occur in
conjunction with external parameter adjustments made outside this
scheduler.
Args:
optimizer (Optimizer or OptimWrapper): Wrapped optimizer.
eta_min (float): Minimum learning rate at the end of scheduling.
Defaults to 0.
eta_min_ratio (float, optional): The ratio of the minimum parameter
value to the base parameter value. Either `eta_min` or
`eta_min_ratio` should be specified. Defaults to None.
power (float): The power of the polynomial. Defaults to 1.0.
begin (int): Step at which to start updating the parameters.
Defaults to 0.
end (int): Step at which to stop updating the parameters.
Defaults to INF.
last_step (int): The index of last step. Used for resume without
state dict. Defaults to -1.
by_epoch (bool): Whether the scheduled parameters are updated by
epochs. Defaults to True.
verbose (bool): Whether to print the value for each update.
Defaults to False.
"""
def __init__(self, eta_min_ratio: Optional[int] = None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.eta_min_ratio = eta_min_ratio
def _get_value(self):
"""Compute value using chainable form of the scheduler."""
if self.last_step == 0:
return [
group[self.param_name] for group in self.optimizer.param_groups
]
param_groups_value = []
for base_value, param_group in zip(self.base_values,
self.optimizer.param_groups):
eta_min = self.eta_min if self.eta_min_ratio is None else \
base_value * self.eta_min_ratio
step_ratio = (1 - 1 /
(self.total_iters - self.last_step + 1))**self.power
step_value = (param_group[self.param_name] -
eta_min) * step_ratio + eta_min
param_groups_value.append(step_value)
return param_groups_value