init
finetune/configs/atlantic.py (new file)
@@ -0,0 +1,303 @@
crop_size = (256, 256)
data_preprocessor = dict(
    bgr_to_rgb=True,
    mean=[537.9411629981602, 615.7886221108977, 343.4481583821405,
          3010.641650390625],
    pad_val=0,
    seg_pad_val=255,
    size=crop_size,
    std=[367.4598430230881, 254.2473100510193, 187.5437562223154,
         921.0792775874182],
    type='SegDataPreProcessor')
data_root = 'rs_datasets/'
dataset_type = 'AtlanticDataset'
default_hooks = dict(
    checkpoint=dict(
        by_epoch=False, interval=1000, save_best='mIoU', max_keep_ckpts=1,
        type='CheckpointHook'),
    logger=dict(interval=50, log_metric_by_epoch=False, type='LoggerHook'),
    param_scheduler=dict(type='ParamSchedulerHook'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    timer=dict(type='IterTimerHook'),
    visualization=dict(type='SegVisualizationHook'))
default_scope = 'mmseg'
env_cfg = dict(
    cudnn_benchmark=True,
    dist_cfg=dict(backend='nccl'),
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
launcher = 'pytorch'
load_from = None
log_level = 'INFO'
log_processor = dict(by_epoch=False)
mean = [537.9411629981602, 615.7886221108977, 343.4481583821405,
        3010.641650390625]
model = dict(
    auxiliary_head=dict(
        align_corners=False,
        channels=256,
        concat_input=False,
        dropout_ratio=0.1,
        in_channels=512,
        in_index=3,
        loss_decode=dict(
            loss_weight=0.4, type='CrossEntropyLoss', use_sigmoid=False),
        norm_cfg=dict(requires_grad=True, type='SyncBN'),
        num_classes=2,
        num_convs=1,
        type='FCNHead'),
    backbone=dict(
        act_cfg=dict(type='GELU'),
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        drop_rate=0.1,
        embed_dims=1024,
        img_size=crop_size,
        in_channels=4,
        interpolate_mode='bilinear',
        mlp_ratio=4,
        norm_cfg=dict(eps=1e-06, requires_grad=True, type='LN'),
        norm_eval=False,
        num_heads=16,
        num_layers=24,
        out_indices=(5, 11, 17, 23),
        patch_size=4,
        qkv_bias=True,
        type='VisionTransformer',
        with_cls_token=False),
    data_preprocessor=dict(
        bgr_to_rgb=True,
        mean=[537.9411629981602, 615.7886221108977, 343.4481583821405,
              3010.641650390625],
        pad_val=0,
        seg_pad_val=255,
        size=crop_size,
        std=[367.4598430230881, 254.2473100510193, 187.5437562223154,
             921.0792775874182],
        type='SegDataPreProcessor'),
    decode_head=dict(
        align_corners=False,
        channels=512,
        dropout_ratio=0.1,
        in_channels=[512, 512, 512, 512],
        in_index=[0, 1, 2, 3],
        loss_decode=dict(
            loss_weight=1.0, type='CrossEntropyLoss', use_sigmoid=False),
        norm_cfg=dict(requires_grad=True, type='SyncBN'),
        num_classes=2,
        pool_scales=(1, 2, 3, 6),
        type='UPerHead'),
    neck=dict(
        in_channels=[1024, 1024, 1024, 1024],
        out_channels=512,
        scales=[1, 1, 1, 1],
        type='MultiLevelNeck'),
    pretrained='pretrain/skysensepp_release_s2.pth',
    test_cfg=dict(mode='whole'),
    train_cfg=dict(),
    type='EncoderDecoder')
norm_cfg = dict(requires_grad=True, type='SyncBN')
optim_wrapper = dict(
    optimizer=dict(
        betas=(0.9, 0.999), lr=6e-05, type='AdamW', weight_decay=0.01),
    paramwise_cfg=dict(
        custom_keys=dict(
            cls_token=dict(decay_mult=0.0),
            norm=dict(decay_mult=0.0),
            pos_embed=dict(decay_mult=0.0))),
    type='OptimWrapper')
optimizer = dict(lr=0.01, momentum=0.9, type='SGD', weight_decay=0.0005)
param_scheduler = [
    dict(begin=0, by_epoch=False, end=1000, start_factor=1e-06,
         type='LinearLR'),
    dict(begin=1000, by_epoch=False, end=10000, eta_min=0.0, power=1.0,
         type='PolyLR'),
]
randomness = dict(seed=20240315)
resume = False
std = [367.4598430230881, 254.2473100510193, 187.5437562223154,
       921.0792775874182]
test_cfg = dict(type='TestLoop')
test_dataloader = dict(
    batch_size=1,
    dataset=dict(
        ann_file='deforestation_atlantic/deforestation_atlantic_test.json',
        data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
        data_root='rs_datasets/',
        pipeline=[
            dict(type='LoadSingleRSImageFromFile'),
            dict(reduce_zero_label=False, type='LoadAnnotations'),
            dict(type='PackSegInputs'),
        ],
        type='AtlanticDataset'),
    num_workers=4,
    persistent_workers=True,
    sampler=dict(shuffle=False, type='DefaultSampler'))
test_evaluator = dict(iou_metrics=['mIoU', 'mFscore'], type='IoUMetric')
test_pipeline = [
    dict(type='LoadSingleRSImageFromFile'),
    dict(reduce_zero_label=False, type='LoadAnnotations'),
    dict(type='PackSegInputs'),
]
train_cfg = dict(max_iters=10000, type='IterBasedTrainLoop', val_interval=500)
train_dataloader = dict(
    batch_size=1,
    dataset=dict(
        ann_file='deforestation_atlantic/deforestation_atlantic_train.json',
        data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
        data_root='rs_datasets/',
        pipeline=[
            dict(type='LoadSingleRSImageFromFile'),
            dict(reduce_zero_label=False, type='LoadAnnotations'),
            dict(cat_max_ratio=0.75, crop_size=crop_size, type='RandomCrop'),
            dict(prob=0.5, type='RandomFlip'),
            dict(type='PackSegInputs'),
        ],
        type='AtlanticDataset'),
    num_workers=4,
    persistent_workers=True,
    sampler=dict(shuffle=True, type='InfiniteSampler'))
train_pipeline = [
    dict(type='LoadSingleRSImageFromFile'),
    dict(reduce_zero_label=False, type='LoadAnnotations'),
    dict(cat_max_ratio=0.75, crop_size=crop_size, type='RandomCrop'),
    dict(prob=0.5, type='RandomFlip'),
    dict(type='PackSegInputs'),
]
tta_model = dict(type='SegTTAModel')
tta_pipeline = [
    dict(backend_args=None, type='LoadImageFromFile'),
    dict(
        transforms=[
            [
                dict(keep_ratio=True, scale_factor=0.5, type='Resize'),
                dict(keep_ratio=True, scale_factor=0.75, type='Resize'),
                dict(keep_ratio=True, scale_factor=1.0, type='Resize'),
                dict(keep_ratio=True, scale_factor=1.25, type='Resize'),
                dict(keep_ratio=True, scale_factor=1.5, type='Resize'),
                dict(keep_ratio=True, scale_factor=1.75, type='Resize'),
            ],
            [
                dict(direction='horizontal', prob=0.0, type='RandomFlip'),
                dict(direction='horizontal', prob=1.0, type='RandomFlip'),
            ],
            [dict(type='LoadAnnotations')],
            [dict(type='PackSegInputs')],
        ],
        type='TestTimeAug'),
]
val_cfg = dict(type='ValLoop')
val_dataloader = dict(
    batch_size=1,
    dataset=dict(
        ann_file='deforestation_atlantic/deforestation_atlantic_val.json',
        data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
        data_root='rs_datasets/',
        pipeline=[
            dict(type='LoadSingleRSImageFromFile'),
            dict(reduce_zero_label=False, type='LoadAnnotations'),
            dict(type='PackSegInputs'),
        ],
        type='AtlanticDataset'),
    num_workers=4,
    persistent_workers=True,
    sampler=dict(shuffle=False, type='DefaultSampler'))
val_evaluator = dict(iou_metrics=['mIoU', 'mFscore'], type='IoUMetric')
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    name='visualizer',
    type='SegLocalVisualizer',
    vis_backends=[dict(type='LocalVisBackend')])
work_dir = 'save/atlantic_skysensepp'
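
Each of these dumped configs is a self-contained MMEngine config, so it can be driven straight from Python. A minimal sketch, assuming mmseg and this repo's custom registrations (AtlanticDataset, LoadSingleRSImageFromFile, etc.) are importable and the rs_datasets/ and pretrain/ paths exist; with launcher='pytorch' this script would normally be wrapped in torchrun for multi-GPU runs:

# Hypothetical driver script; Config.fromfile and Runner.from_cfg are
# standard mmengine APIs.
from mmengine.config import Config
from mmengine.runner import Runner

import mmseg  # noqa: F401  # assumption: importing registers models/datasets

cfg = Config.fromfile('finetune/configs/atlantic.py')
runner = Runner.from_cfg(cfg)  # builds model, loaders, hooks from the dict
runner.train()  # IterBasedTrainLoop: 10000 iters, validation every 500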
finetune/configs/c2smsflood.py (new file)
@@ -0,0 +1,284 @@
dataset_type = 'C2SFloodDataset'
default_hooks = dict(
    checkpoint=dict(
        by_epoch=False,
        interval=2000,
        max_keep_ckpts=2,
        save_best='mIoU',
        type='CheckpointHook'),
    logger=dict(interval=50, log_metric_by_epoch=False, type='LoggerHook'),
    param_scheduler=dict(type='ParamSchedulerHook'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    timer=dict(type='IterTimerHook'),
    visualization=dict(type='SegVisualizationHook'))
default_scope = 'mmseg'
env_cfg = dict(
    cudnn_benchmark=True,
    dist_cfg=dict(backend='nccl'),
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
launcher = 'pytorch'
load_from = None
log_level = 'INFO'
log_processor = dict(by_epoch=False)
model = dict(
    auxiliary_head=dict(
        align_corners=False,
        channels=256,
        concat_input=False,
        dropout_ratio=0.1,
        in_channels=512,
        in_index=3,
        loss_decode=dict(
            loss_weight=0.4, type='CrossEntropyLoss', use_sigmoid=False),
        norm_cfg=dict(requires_grad=True, type='SyncBN'),
        num_classes=4,
        num_convs=1,
        type='FCNHead'),
    backbone=dict(
        act_cfg=dict(type='GELU'),
        attn_drop_rate=0.0,
        downscale_indices=(5, 11),
        drop_path_rate=0.0,
        drop_rate=0.1,
        embed_dims=1024,
        img_size=(256, 256),
        in_channels=10,
        interpolate_mode='bilinear',
        mlp_ratio=4,
        norm_cfg=dict(eps=1e-06, requires_grad=True, type='LN'),
        norm_eval=False,
        num_heads=16,
        num_layers=24,
        out_indices=(5, 11, 17, 23),
        patch_size=4,
        qkv_bias=True,
        type='VisionTransformer',
        with_cls_token=False),
    data_preprocessor=dict(
        bgr_to_rgb=True,
        mean=[1381.26, 1302.05, 1179.27, 1393.56, 2164.76, 2561.75, 2377.94,
              2791.13, 1998.09, 1196.08],
        pad_val=0,
        seg_pad_val=255,
        size=(256, 256),
        std=[653.25, 659.61, 779.58, 720.45, 871.09, 1035.57, 965.36,
             1141.71, 1019.73, 825.01],
        type='SegDataPreProcessor'),
    decode_head=dict(
        align_corners=False,
        channels=512,
        dropout_ratio=0.1,
        in_channels=[512, 512, 512, 512],
        in_index=[0, 1, 2, 3],
        loss_decode=dict(
            loss_weight=1.0, type='CrossEntropyLoss', use_sigmoid=False),
        norm_cfg=dict(requires_grad=True, type='SyncBN'),
        num_classes=4,
        pool_scales=(1, 2, 3, 6),
        type='UPerHead'),
    neck=dict(
        in_channels=[1024, 1024, 1024, 1024],
        out_channels=512,
        scales=[1, 1, 1, 1],
        type='MultiLevelNeck'),
    pretrained='pretrain/skysensepp_mmcvt_s2.pth',
    test_cfg=dict(mode='whole'),
    train_cfg=dict(),
    type='EncoderDecoder')
norm_cfg = None
optim_wrapper = dict(
    constructor='LearningRateDecayOptimizerConstructor',
    optimizer=dict(
        betas=(0.9, 0.999), lr=0.0005, type='AdamW', weight_decay=0.015),
    paramwise_cfg=dict(
        custom_keys=dict(
            cls_token=dict(decay_mult=0.0),
            norm=dict(decay_mult=0.0),
            pos_embed=dict(decay_mult=0.0)),
        decay_rate=0.84,
        decay_type='layer_wise',
        num_layers=24),
    type='OptimWrapper')
param_scheduler = [
    dict(begin=0, by_epoch=False, end=1000, start_factor=1e-06,
         type='LinearLR'),
    dict(begin=1000, by_epoch=False, end=10000, eta_min=0.0, power=1.0,
         type='PolyLR'),
]
randomness = dict(seed=1)
resume = False
test_cfg = dict(type='TestLoop')
test_dataloader = dict(
    batch_size=1,
    dataset=dict(
        ann_file='rs_datasets/c2smsfloods/c2s_s2_val_mm.json',
        data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
        data_root='rs_datasets/',
        pipeline=[
            dict(type='LoadImageFromNpz'),
            dict(reduce_zero_label=False, type='LoadAnnotationsNpz'),
            dict(type='PackSegInputs'),
        ],
        type='C2SFloodDataset'),
    num_workers=4,
    persistent_workers=True,
    sampler=dict(shuffle=False, type='DefaultSampler'))
test_evaluator = dict(iou_metrics=['mIoU', 'mFscore'], type='IoUMetric')
test_pipeline = [
    dict(type='LoadImageFromNpz'),
    dict(reduce_zero_label=False, type='LoadAnnotationsNpz'),
    dict(type='PackSegInputs'),
]
train_cfg = dict(max_iters=10000, type='IterBasedTrainLoop', val_interval=500)
train_dataloader = dict(
    batch_size=2,
    dataset=dict(
        ann_file='rs_datasets/c2smsfloods/c2s_s2_train_mm.json',
        data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
        data_root='rs_datasets/',
        pipeline=[
            dict(type='LoadImageFromNpz'),
            dict(reduce_zero_label=False, type='LoadAnnotationsNpz'),
            dict(cat_max_ratio=0.75, crop_size=(256, 256), type='RandomCrop'),
            dict(prob=0.5, type='RandomFlip'),
            dict(type='PackSegInputs'),
        ],
        type='C2SFloodDataset'),
    num_workers=4,
    persistent_workers=True,
    sampler=dict(shuffle=True, type='InfiniteSampler'))
tta_model = dict(type='SegTTAModel')
tta_pipeline = [
    dict(backend_args=None, type='LoadImageFromFile'),
    dict(
        transforms=[
            [
                dict(keep_ratio=True, scale_factor=1.0, type='Resize'),
                dict(keep_ratio=True, scale_factor=1.2, type='Resize'),
                dict(keep_ratio=True, scale_factor=1.4, type='Resize'),
                dict(keep_ratio=True, scale_factor=1.6, type='Resize'),
                dict(keep_ratio=True, scale_factor=1.8, type='Resize'),
                dict(keep_ratio=True, scale_factor=2.0, type='Resize'),
            ],
            [
                dict(direction='horizontal', prob=0.0, type='RandomFlip'),
                dict(direction='horizontal', prob=1.0, type='RandomFlip'),
            ],
            [dict(type='LoadAnnotations')],
            [dict(type='PackSegInputs')],
        ],
        type='TestTimeAug'),
]
val_cfg = dict(type='ValLoop')
val_dataloader = dict(
    batch_size=1,
    dataset=dict(
        ann_file='rs_datasets/c2smsfloods/c2s_s2_val_mm.json',
        data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
        data_root='rs_datasets/',
        pipeline=[
            dict(type='LoadImageFromNpz'),
            dict(reduce_zero_label=False, type='LoadAnnotationsNpz'),
            dict(type='PackSegInputs'),
        ],
        type='C2SFloodDataset'),
    num_workers=4,
    persistent_workers=True,
    sampler=dict(shuffle=False, type='DefaultSampler'))
val_evaluator = dict(iou_metrics=['mIoU'], type='IoUMetric')
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    name='visualizer',
    type='SegLocalVisualizer',
    vis_backends=[dict(type='LocalVisBackend')])
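
This config opts into layer-wise learning-rate decay via LearningRateDecayOptimizerConstructor. A hedged sketch of the effect with the values above (assuming block i of N is scaled by decay_rate ** (N - i), the common BEiT/ConvNeXt recipe; the constructor's exact offset may differ by one):

# Layer-wise LR decay: deeper blocks train near base_lr, early blocks
# stay close to their pretrained weights.
base_lr, decay_rate, num_layers = 0.0005, 0.84, 24
for i in (24, 12, 1):
    print(i, base_lr * decay_rate ** (num_layers - i))
# block 24 -> 5.0e-04; block 1 -> roughly 0.84**23 ~ 0.018 of base_lr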
finetune/configs/cabuar.py (new file)
@@ -0,0 +1,227 @@
default_hooks = dict(
    checkpoint=dict(
        by_epoch=False,
        interval=1000,
        max_keep_ckpts=1,
        save_best='mIoU',
        type='CheckpointHook'),
    logger=dict(interval=50, log_metric_by_epoch=False, type='LoggerHook'),
    param_scheduler=dict(type='ParamSchedulerHook'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    timer=dict(type='IterTimerHook'),
    visualization=dict(type='SegVisualizationHook'))
default_scope = 'mmseg'
env_cfg = dict(
    cudnn_benchmark=True,
    dist_cfg=dict(backend='nccl'),
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
launcher = 'pytorch'
load_from = None
log_level = 'INFO'
log_processor = dict(by_epoch=False)
model = dict(
    auxiliary_head=dict(
        align_corners=False,
        channels=256,
        concat_input=False,
        dropout_ratio=0.1,
        in_channels=512,
        in_index=3,
        loss_decode=dict(
            loss_weight=0.4, type='CrossEntropyLoss', use_sigmoid=False),
        norm_cfg=dict(requires_grad=True, type='SyncBN'),
        num_classes=2,
        num_convs=1,
        type='FCNHead'),
    backbone=dict(
        act_cfg=dict(type='GELU'),
        attn_drop_rate=0.0,
        downscale_indices=(5, 11),
        drop_path_rate=0.0,
        drop_rate=0.1,
        embed_dims=1024,
        img_size=(256, 256),
        in_channels=10,
        interpolate_mode='bilinear',
        mlp_ratio=4,
        norm_cfg=dict(eps=1e-06, requires_grad=True, type='LN'),
        norm_eval=False,
        num_heads=16,
        num_layers=24,
        out_indices=(5, 11, 17, 23),
        patch_size=4,
        qkv_bias=True,
        type='VisionTransformer',
        with_cls_token=False),
    data_preprocessor=dict(
        bgr_to_rgb=True,
        mean=[4.50021132, 6.09891466, 7.50766315, 9.54643074, 12.82568112,
              14.29062133, 15.24644993, 15.73945708, 16.60374872,
              12.31011599],
        pad_val=0,
        seg_pad_val=255,
        size=(256, 256),
        std=[2.6094148, 2.49566825, 1.37103968, 2.6094148, 2.49566825,
             1.37103968, 2.6094148, 2.49566825, 1.37103968, 1.37103968],
        type='SegDataPreProcessor'),
    decode_head=dict(
        align_corners=False,
        channels=512,
        dropout_ratio=0.1,
        in_channels=[512, 512, 512, 512],
        in_index=[0, 1, 2, 3],
        loss_decode=dict(
            balance=True,
            loss_weight=1.0,
            max_scale=4.0,
            type='CrossEntropyLoss',
            use_sigmoid=False),
        norm_cfg=dict(requires_grad=True, type='SyncBN'),
        num_classes=2,
        pool_scales=(1, 2, 3, 6),
        type='UPerHead'),
    neck=dict(
        in_channels=[1024, 1024, 1024, 1024],
        out_channels=512,
        scales=[1, 1, 1, 1],
        type='MultiLevelNeck'),
    pretrained='pretrain/skysensepp_mmcvt_s2.pth',
    test_cfg=dict(mode='whole'),
    train_cfg=dict(),
    type='EncoderDecoder')
optim_wrapper = dict(
    constructor='LearningRateDecayOptimizerConstructor',
    optimizer=dict(
        betas=(0.9, 0.999), lr=None, type='AdamW', weight_decay=None),
    paramwise_cfg=dict(
        custom_keys=dict(
            cls_token=dict(decay_mult=0.0),
            norm=dict(decay_mult=0.0),
            pos_embed=dict(decay_mult=0.0)),
        decay_rate=None,
        decay_type='layer_wise',
        num_layers=24),
    type='OptimWrapper')
param_scheduler = [
    dict(begin=0, by_epoch=False, end=1000, start_factor=1e-06,
         type='LinearLR'),
    dict(begin=1000, by_epoch=False, end=10000, eta_min=0.0, power=1.0,
         type='PolyLR'),
]
randomness = dict(seed=20240315)
resume = False
test_cfg = dict(type='TestLoop')
test_dataloader = dict(
    batch_size=1,
    dataset=dict(
        ann_file='rs_datasets/cabuar/cabura_val_fold_4.json',
        data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
        data_root='rs_datasets/',
        pipeline=[
            dict(data_key='image', type='LoadImageFromNpz'),
            dict(
                data_key='image',
                reduce_zero_label=False,
                type='LoadAnnotationsNpz'),
            dict(type='PackSegInputs'),
        ],
        type='CABURADataset'),
    num_workers=4,
    persistent_workers=True,
    sampler=dict(shuffle=False, type='DefaultSampler'))
test_evaluator = dict(iou_metrics=['mIoU', 'mFscore'], type='IoUMetric')
test_pipeline = [
    dict(data_key='image', type='LoadImageFromNpz'),
    dict(data_key='image', reduce_zero_label=False, type='LoadAnnotationsNpz'),
    dict(type='PackSegInputs'),
]
train_cfg = dict(max_iters=10000, type='IterBasedTrainLoop', val_interval=500)
train_dataloader = dict(
    batch_size=1,
    dataset=dict(
        ann_file='rs_datasets/cabuar/cabura_train_fold0_3.json',
        data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
        data_root='rs_datasets/',
        pipeline=[
            dict(data_key='image', type='LoadImageFromNpz'),
            dict(
                data_key='image',
                reduce_zero_label=False,
                type='LoadAnnotationsNpz'),
            dict(cat_max_ratio=0.75, crop_size=(256, 256), type='RandomCrop'),
            dict(prob=0.5, type='RandomFlip'),
            dict(type='PackSegInputs'),
        ],
        type='CABURADataset'),
    num_workers=4,
    persistent_workers=True,
    sampler=dict(shuffle=True, type='InfiniteSampler'))
tta_model = dict(type='SegTTAModel')
tta_pipeline = [
    dict(backend_args=None, type='LoadImageFromFile'),
    dict(
        transforms=[
            [
                dict(keep_ratio=True, scale_factor=1.0, type='Resize'),
                dict(keep_ratio=True, scale_factor=1.2, type='Resize'),
                dict(keep_ratio=True, scale_factor=1.4, type='Resize'),
                dict(keep_ratio=True, scale_factor=1.6, type='Resize'),
                dict(keep_ratio=True, scale_factor=1.8, type='Resize'),
                dict(keep_ratio=True, scale_factor=2.0, type='Resize'),
            ],
            [
                dict(direction='horizontal', prob=0.0, type='RandomFlip'),
                dict(direction='horizontal', prob=1.0, type='RandomFlip'),
            ],
            [dict(type='LoadAnnotations')],
            [dict(type='PackSegInputs')],
        ],
        type='TestTimeAug'),
]
val_cfg = dict(type='ValLoop')
val_dataloader = dict(
    batch_size=1,
    dataset=dict(
        ann_file='rs_datasets/cabuar/cabura_val_fold_4.json',
        data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
        data_root='rs_datasets/',
        pipeline=[
            dict(data_key='image', type='LoadImageFromNpz'),
            dict(
                data_key='image',
                reduce_zero_label=False,
                type='LoadAnnotationsNpz'),
            dict(type='PackSegInputs'),
        ],
        type='CABURADataset'),
    num_workers=4,
    persistent_workers=True,
    sampler=dict(shuffle=False, type='DefaultSampler'))
val_evaluator = dict(iou_metrics=['mIoU', 'mFscore'], type='IoUMetric')
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    name='visualizer',
    type='SegLocalVisualizer',
    vis_backends=[dict(type='LocalVisBackend')])
work_dir = 'work_dirs/ft_cabura_test'
finetune/configs/germany.py (new file)
@@ -0,0 +1,414 @@
crop_size = (24, 24)
data_preprocessor = dict(
    bgr_to_rgb=False,
    mean=[2482.0061841829206, 2456.642580060208, 2667.8229979675334,
          2744.9377076257624, 3620.1499158373827, 4063.9126981046647,
          3922.2406108776354, 4264.908986788407, 2453.0070206816135,
          1774.0019119673998],
    pad_val=0,
    seg_pad_val=255,
    size=(24, 24),
    std=[2392.1256366526068, 2100.1364646122875, 2262.6154840764625,
         2353.899770400333, 2089.598452203458, 2057.1247114077073,
         2013.2108514271458, 2041.0248949410561, 1380.4643757742374,
         1243.547946113518],
    ts_size=30,
    type='RSTsSegDataPreProcessor')
data_root = 'rs_datasets/'
dataset_type = 'GermanyCropDataset'
default_hooks = dict(
    checkpoint=dict(
        by_epoch=False,
        interval=2000,
        max_keep_ckpts=1,
        save_best='mIoU',
        type='CheckpointHook'),
    logger=dict(interval=20, log_metric_by_epoch=False, type='LoggerHook'),
    param_scheduler=dict(type='ParamSchedulerHook'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    timer=dict(type='IterTimerHook'),
    visualization=dict(type='SegVisualizationHook'))
default_scope = 'mmseg'
env_cfg = dict(
    cudnn_benchmark=True,
    dist_cfg=dict(backend='nccl'),
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
find_unused_parameters = True
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
launcher = 'pytorch'
load_from = None
log_level = 'INFO'
log_processor = dict(by_epoch=False)
mean = [2482.0061841829206, 2456.642580060208, 2667.8229979675334,
        2744.9377076257624, 3620.1499158373827, 4063.9126981046647,
        3922.2406108776354, 4264.908986788407, 2453.0070206816135,
        1774.0019119673998]
model = dict(
    auxiliary_head=dict(
        align_corners=False,
        channels=256,
        concat_input=False,
        dropout_ratio=0.1,
        in_channels=1024,
        in_index=3,
        loss_decode=dict(
            loss_weight=0.4, type='CrossEntropyLoss', use_sigmoid=False),
        norm_cfg=dict(requires_grad=True, type='SyncBN'),
        num_classes=18,
        num_convs=1,
        type='FCNHead'),
    backbone=dict(
        act_cfg=dict(type='GELU'),
        attn_drop_rate=0.0,
        downscale_indices=[-1],
        drop_path_rate=0.0,
        drop_rate=0.1,
        embed_dims=1024,
        img_size=(24, 24),
        in_channels=10,
        init_cfg=dict(
            checkpoint='pretrain/skysensepp_mmcvt_s2.pth',
            type='Pretrained'),
        interpolate_mode='bilinear',
        mlp_ratio=4,
        norm_cfg=dict(eps=1e-06, requires_grad=True, type='LN'),
        norm_eval=False,
        num_heads=16,
        num_layers=24,
        out_indices=(5, 11, 17, 23),
        patch_size=4,
        qkv_bias=True,
        type='VisionTransformer',
        with_cls_token=False),
    data_preprocessor=dict(
        bgr_to_rgb=False,
        mean=[2482.0061841829206, 2456.642580060208, 2667.8229979675334,
              2744.9377076257624, 3620.1499158373827, 4063.9126981046647,
              3922.2406108776354, 4264.908986788407, 2453.0070206816135,
              1774.0019119673998],
        pad_val=0,
        seg_pad_val=255,
        size=(24, 24),
        std=[2392.1256366526068, 2100.1364646122875, 2262.6154840764625,
             2353.899770400333, 2089.598452203458, 2057.1247114077073,
             2013.2108514271458, 2041.0248949410561, 1380.4643757742374,
             1243.547946113518],
        ts_size=30,
        type='RSTsSegDataPreProcessor'),
    decode_head=dict(
        align_corners=False,
        channels=512,
        dropout_ratio=0.1,
        in_channels=[1024, 1024, 1024, 1024],
        in_index=[0, 1, 2, 3],
        loss_decode=dict(
            loss_weight=1.0, type='CrossEntropyLoss', use_sigmoid=False),
        norm_cfg=dict(requires_grad=True, type='SyncBN'),
        num_classes=18,
        pool_scales=(1, 2, 3, 6),
        type='UPerHead'),
    neck=dict(
        attn_drop_rate=0.0,
        drop_path_rate=0.3,
        drop_rate=0.0,
        embed_dims=1024,
        in_channels=[768, 768, 768, 768],
        in_channels_ml=[1024, 1024, 1024, 1024],
        init_cfg=dict(
            checkpoint='pretrain/skysensepp_mmcvt_fusion.pth',
            type='Pretrained'),
        input_dims=1024,
        mlp_ratio=4,
        num_heads=16,
        num_layers=24,
        out_channels=768,
        out_channels_ml=1024,
        output_cls_token=True,
        qkv_bias=True,
        scales=[4, 2, 1, 0.5],
        scales_ml=[1, 1, 1, 1],
        ts_size=30,
        type='FusionMultiLevelNeck',
        with_cls_token=True),
    test_cfg=dict(mode='whole'),
    train_cfg=dict(),
    type='EncoderDecoder')
norm_cfg = dict(requires_grad=True, type='SyncBN')
optim_wrapper = dict(
    optimizer=dict(
        betas=(0.9, 0.999), lr=0.0001, type='AdamW', weight_decay=0.01),
    paramwise_cfg=dict(
        custom_keys=dict(
            cls_token=dict(decay_mult=0.0),
            norm=dict(decay_mult=0.0),
            pos_embed=dict(decay_mult=0.0))),
    type='OptimWrapper')
optimizer = dict(lr=0.01, momentum=0.9, type='SGD', weight_decay=0.0005)
param_scheduler = [
    dict(begin=0, by_epoch=False, end=2000, start_factor=1e-06,
         type='LinearLR'),
    dict(begin=2000, by_epoch=False, end=20000, eta_min=0.0, power=1.0,
         type='PolyLR'),
]
randomness = dict(seed=20240311)
resume = False
static_graph = True
std = [2392.1256366526068, 2100.1364646122875, 2262.6154840764625,
       2353.899770400333, 2089.598452203458, 2057.1247114077073,
       2013.2108514271458, 2041.0248949410561, 1380.4643757742374,
       1243.547946113518]
test_cfg = dict(type='TestLoop')
test_dataloader = dict(
    batch_size=1,
    dataset=dict(
        ann_file='rs_datasets/germany_crop/germany_crop_val.json',
        data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
        data_root='rs_datasets/',
        pipeline=[
            dict(data_key='image', ts_size=30, type='LoadTsImageFromNpz'),
            dict(
                data_key='image',
                reduce_zero_label=True,
                type='LoadAnnotationsNpz'),
            dict(type='PackSegInputs'),
        ],
        type='GermanyCropDataset'),
    num_workers=4,
    persistent_workers=True,
    sampler=dict(shuffle=False, type='DefaultSampler'))
test_evaluator = dict(iou_metrics=['mIoU', 'mFscore'], type='IoUMetric')
test_pipeline = [
    dict(data_key='image', ts_size=30, type='LoadTsImageFromNpz'),
    dict(data_key='image', reduce_zero_label=True, type='LoadAnnotationsNpz'),
    dict(type='PackSegInputs'),
]
train_cfg = dict(
    dynamic_intervals=[(0, 1000), (4000, 2000), (8000, 4000)],
    max_iters=20000,
    type='IterBasedTrainLoop',
    val_interval=2000)
train_dataloader = dict(
    batch_size=2,
    dataset=dict(
        ann_file='rs_datasets/germany_crop/germany_crop_train.json',
        data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
        data_root='rs_datasets/',
        pipeline=[
            dict(data_key='image', ts_size=30, type='LoadTsImageFromNpz'),
            dict(
                data_key='image',
                reduce_zero_label=True,
                type='LoadAnnotationsNpz'),
            dict(prob=0.5, type='RandomFlip'),
            dict(type='PackSegInputs'),
        ],
        type='GermanyCropDataset'),
    num_workers=4,
    persistent_workers=True,
    sampler=dict(shuffle=True, type='InfiniteSampler'))
train_pipeline = [
    dict(data_key='image', ts_size=30, type='LoadTsImageFromNpz'),
    dict(data_key='image', reduce_zero_label=True, type='LoadAnnotationsNpz'),
    dict(prob=0.5, type='RandomFlip'),
    dict(type='PackSegInputs'),
]
tta_model = dict(type='SegTTAModel')
tta_pipeline = [
    dict(backend_args=None, type='LoadImageFromFile'),
    dict(
        transforms=[
            [
                dict(keep_ratio=True, scale_factor=0.5, type='Resize'),
                dict(keep_ratio=True, scale_factor=0.75, type='Resize'),
                dict(keep_ratio=True, scale_factor=1.0, type='Resize'),
                dict(keep_ratio=True, scale_factor=1.25, type='Resize'),
                dict(keep_ratio=True, scale_factor=1.5, type='Resize'),
                dict(keep_ratio=True, scale_factor=1.75, type='Resize'),
            ],
            [
                dict(direction='horizontal', prob=0.0, type='RandomFlip'),
                dict(direction='horizontal', prob=1.0, type='RandomFlip'),
            ],
            [dict(type='LoadAnnotations')],
            [dict(type='PackSegInputs')],
        ],
        type='TestTimeAug'),
]
val_cfg = dict(type='ValLoop')
val_dataloader = dict(
    batch_size=1,
    dataset=dict(
        ann_file='rs_datasets/germany_crop/germany_crop_val.json',
        data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
        data_root='rs_datasets/',
        pipeline=[
            dict(data_key='image', ts_size=30, type='LoadTsImageFromNpz'),
            dict(
                data_key='image',
                reduce_zero_label=True,
                type='LoadAnnotationsNpz'),
            dict(type='PackSegInputs'),
        ],
        type='GermanyCropDataset'),
    num_workers=4,
    persistent_workers=True,
    sampler=dict(shuffle=False, type='DefaultSampler'))
val_evaluator = dict(iou_metrics=['mIoU', 'mFscore'], type='IoUMetric')
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    name='visualizer',
    type='SegLocalVisualizer',
    vis_backends=[dict(type='LocalVisBackend')])
work_dir = 'save_germany'
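
The dynamic_intervals entry above changes the validation cadence as training progresses. A hedged sketch of the semantics (assumption: mmengine's IterBasedTrainLoop treats each (milestone, interval) pair as "from this iteration on, validate every `interval` iters"):

# Validation cadence implied by dynamic_intervals=[(0,1000),(4000,2000),(8000,4000)]
milestones = [(0, 1000), (4000, 2000), (8000, 4000)]

def val_interval_at(it):
    interval = milestones[0][1]
    for start, step in milestones:
        if it >= start:
            interval = step  # last milestone passed wins
    return interval

assert val_interval_at(500) == 1000    # early: validate often
assert val_interval_at(5000) == 2000
assert val_interval_at(12000) == 4000  # late: validate rarely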
finetune/configs/sos.py (new file)
@@ -0,0 +1,297 @@
crop_size = (256, 256)
data_preprocessor = dict(
    bgr_to_rgb=True,
    mean=[83.37, 83.37, 83.37],
    pad_val=0,
    seg_pad_val=255,
    size=(256, 256),
    std=[40.45, 40.45, 40.45],
    type='SegDataPreProcessor')
data_root = 'rs_datasets/'
dataset_type = 'SOSDataset'
default_hooks = dict(
    checkpoint=dict(
        by_epoch=False,
        interval=1000,
        max_keep_ckpts=2,
        save_best='mIoU',
        type='CheckpointHook'),
    logger=dict(interval=50, log_metric_by_epoch=False, type='LoggerHook'),
    param_scheduler=dict(type='ParamSchedulerHook'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    timer=dict(type='IterTimerHook'),
    visualization=dict(type='SegVisualizationHook'))
default_scope = 'mmseg'
env_cfg = dict(
    cudnn_benchmark=True,
    dist_cfg=dict(backend='nccl'),
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
launcher = 'pytorch'
load_from = None
log_level = 'INFO'
log_processor = dict(by_epoch=False)
mean = [83.37, 83.37, 83.37]
model = dict(
    backbone=dict(
        act_cfg=dict(type='GELU'),
        attn_drop_rate=0.0,
        downscale_indices=[-1],
        drop_path_rate=0.0,
        drop_rate=0.1,
        embed_dims=1024,
        img_size=(256, 256),
        in_channels=2,
        interpolate_mode='bilinear',
        mlp_ratio=4,
        norm_cfg=dict(eps=1e-06, requires_grad=True, type='LN'),
        norm_eval=False,
        num_heads=16,
        num_layers=24,
        out_indices=(5, 11, 17, 23),
        patch_size=4,
        qkv_bias=True,
        type='VisionTransformer',
        with_cls_token=False),
    data_preprocessor=dict(
        bgr_to_rgb=True,
        mean=[83.37, 83.37, 83.37],
        pad_val=0,
        seg_pad_val=255,
        size=(256, 256),
        std=[40.45, 40.45, 40.45],
        type='SegDataPreProcessor'),
    decode_head=dict(
        align_corners=False,
        channels=512,
        dropout_ratio=0.1,
        in_channels=[1024, 1024, 1024, 1024],
        in_index=[0, 1, 2, 3],
        loss_decode=dict(
            loss_weight=1.0, type='CrossEntropyLoss', use_sigmoid=True),
        norm_cfg=dict(requires_grad=True, type='SyncBN'),
        num_classes=2,
        pool_scales=(1, 2, 3, 6),
        type='UPerHead'),
    neck=dict(
        in_channels=[1024, 1024, 1024, 1024],
        out_channels=1024,
        scales=[4, 2, 1, 0.5],
        type='MultiLevelNeck'),
    pretrained='pretrain/skysensepp_mmcvt_s1.pth',
    test_cfg=dict(mode='whole'),
    train_cfg=dict(),
    type='EncoderDecoder')
norm_cfg = dict(requires_grad=True, type='SyncBN')
optim_wrapper = dict(
    optimizer=dict(
        betas=(0.9, 0.999), lr=6e-05, type='AdamW', weight_decay=0.01),
    paramwise_cfg=dict(
        custom_keys=dict(
            cls_token=dict(decay_mult=0.0),
            norm=dict(decay_mult=0.0),
            pos_embed=dict(decay_mult=0.0))),
    type='OptimWrapper')
param_scheduler = [
    dict(begin=0, by_epoch=False, end=1000, start_factor=1e-06,
         type='LinearLR'),
    dict(begin=1000, by_epoch=False, end=10000, eta_min=0.0, power=1.0,
         type='PolyLR'),
]
randomness = dict(seed=0)
resume = False
std = [40.45, 40.45, 40.45]
test_cfg = dict(type='TestLoop')
test_dataloader = dict(
    batch_size=1,
    dataset=dict(
        ann_file='rs_datasets/sos/sos_test_sentinel.json',
        data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
        data_root='rs_datasets/',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(reduce_zero_label=False, type='LoadAnnotations'),
            dict(type='PackSegInputs'),
        ],
        type='SOSDataset'),
    num_workers=2,
    persistent_workers=True,
    sampler=dict(shuffle=False, type='DefaultSampler'))
test_evaluator = dict(iou_metrics=['mIoU', 'mFscore'], type='IoUMetric')
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(reduce_zero_label=False, type='LoadAnnotations'),
    dict(type='PackSegInputs'),
]
train_cfg = dict(max_iters=0, type='IterBasedTrainLoop', val_interval=500)
train_dataloader = dict(
    batch_size=1,
    dataset=dict(
        ann_file='rs_datasets/sos/sos_train_sentinel.json',
        data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
        data_root='rs_datasets/',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(reduce_zero_label=False, type='LoadAnnotations'),
            dict(prob=0.5, type='RandomFlip'),
            dict(type='PackSegInputs'),
        ],
        type='SOSDataset'),
    num_workers=2,
    persistent_workers=True,
    sampler=dict(shuffle=True, type='InfiniteSampler'))
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(reduce_zero_label=False, type='LoadAnnotations'),
    dict(prob=0.5, type='RandomFlip'),
    dict(type='PackSegInputs'),
]
tta_model = dict(type='SegTTAModel')
tta_pipeline = [
    dict(backend_args=None, type='LoadImageFromFile'),
    dict(
        transforms=[
            [
                dict(keep_ratio=True, scale_factor=0.5, type='Resize'),
                dict(keep_ratio=True, scale_factor=0.75, type='Resize'),
                dict(keep_ratio=True, scale_factor=1.0, type='Resize'),
                dict(keep_ratio=True, scale_factor=1.25, type='Resize'),
                dict(keep_ratio=True, scale_factor=1.5, type='Resize'),
                dict(keep_ratio=True, scale_factor=1.75, type='Resize'),
            ],
            [
                dict(direction='horizontal', prob=0.0, type='RandomFlip'),
                dict(direction='horizontal', prob=1.0, type='RandomFlip'),
            ],
            [dict(type='LoadAnnotations')],
            [dict(type='PackSegInputs')],
        ],
        type='TestTimeAug'),
]
val_cfg = dict(type='ValLoop')
val_dataloader = dict(
    batch_size=1,
    dataset=dict(
        ann_file='rs_datasets/sos/sos_test_sentinel.json',
        data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
        data_root='rs_datasets/',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(reduce_zero_label=False, type='LoadAnnotations'),
            dict(type='PackSegInputs'),
        ],
        type='SOSDataset'),
    num_workers=2,
    persistent_workers=True,
    sampler=dict(shuffle=False, type='DefaultSampler'))
val_evaluator = dict(iou_metrics=['mIoU', 'mFscore'], type='IoUMetric')
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    name='visualizer',
    type='SegLocalVisualizer',
    vis_backends=[dict(type='LocalVisBackend')])
work_dir = 'save_sos/'
finetune/mmseg/__init__.py (new file)
@@ -0,0 +1,74 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import mmcv
import mmengine
from packaging.version import parse

from .version import __version__, version_info

MMCV_MIN = '2.0.0rc4'
MMCV_MAX = '2.2.0'
MMENGINE_MIN = '0.5.0'
MMENGINE_MAX = '1.0.0'


def digit_version(version_str: str, length: int = 4):
    """Convert a version string into a tuple of integers.

    This method is usually used for comparing two versions. For pre-release
    versions: alpha < beta < rc.

    Args:
        version_str (str): The version string.
        length (int): The maximum number of version levels. Default: 4.

    Returns:
        tuple[int]: The version info in digits (integers).
    """
    version = parse(version_str)
    assert version.release, f'failed to parse version {version_str}'
    release = list(version.release)
    release = release[:length]
    if len(release) < length:
        release = release + [0] * (length - len(release))
    if version.is_prerelease:
        mapping = {'a': -3, 'b': -2, 'rc': -1}
        val = -4
        # version.pre can be None
        if version.pre:
            if version.pre[0] not in mapping:
                warnings.warn(f'unknown prerelease version {version.pre[0]}, '
                              'version checking may go wrong')
            else:
                val = mapping[version.pre[0]]
            release.extend([val, version.pre[-1]])
        else:
            release.extend([val, 0])

    elif version.is_postrelease:
        release.extend([1, version.post])
    else:
        release.extend([0, 0])
    return tuple(release)


mmcv_min_version = digit_version(MMCV_MIN)
mmcv_max_version = digit_version(MMCV_MAX)
mmcv_version = digit_version(mmcv.__version__)


assert (mmcv_min_version <= mmcv_version < mmcv_max_version), \
    f'MMCV=={mmcv.__version__} is used but incompatible. ' \
    f'Please install mmcv>=2.0.0rc4.'

mmengine_min_version = digit_version(MMENGINE_MIN)
mmengine_max_version = digit_version(MMENGINE_MAX)
mmengine_version = digit_version(mmengine.__version__)

assert (mmengine_min_version <= mmengine_version < mmengine_max_version), \
    f'MMEngine=={mmengine.__version__} is used but incompatible. ' \
    f'Please install mmengine>={mmengine_min_version}, '\
    f'<{mmengine_max_version}.'

__all__ = ['__version__', 'version_info', 'digit_version']
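
A quick worked example of what digit_version returns; the tuples follow directly from the function body above, where 'rc' maps to -1 so release candidates sort below the final release:

# digit_version pads the release to `length` digits, then appends a
# (pre/post-release marker, number) pair, so plain tuple comparison
# orders versions correctly.
assert digit_version('2.0.0rc4') == (2, 0, 0, 0, -1, 4)
assert digit_version('2.0.0') == (2, 0, 0, 0, 0, 0)
assert digit_version('2.0.0rc4') < digit_version('2.0.0') < digit_version('2.2.0')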
finetune/mmseg/apis/__init__.py (new file)
@@ -0,0 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import inference_model, init_model, show_result_pyplot
from .mmseg_inferencer import MMSegInferencer
from .remote_sense_inferencer import RSImage, RSInferencer

__all__ = [
    'init_model', 'inference_model', 'show_result_pyplot', 'MMSegInferencer',
    'RSInferencer', 'RSImage'
]
finetune/mmseg/apis/inference.py (new file)
@@ -0,0 +1,189 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from pathlib import Path
from typing import Optional, Union

import mmcv
import numpy as np
import torch
from mmengine import Config
from mmengine.registry import init_default_scope
from mmengine.runner import load_checkpoint
from mmengine.utils import mkdir_or_exist

from mmseg.models import BaseSegmentor
from mmseg.registry import MODELS
from mmseg.structures import SegDataSample
from mmseg.utils import SampleList, dataset_aliases, get_classes, get_palette
from mmseg.visualization import SegLocalVisualizer
from .utils import ImageType, _preprare_data


def init_model(config: Union[str, Path, Config],
               checkpoint: Optional[str] = None,
               device: str = 'cuda:0',
               cfg_options: Optional[dict] = None):
    """Initialize a segmentor from config file.

    Args:
        config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file
            path, :obj:`Path`, or the config object.
        checkpoint (str, optional): Checkpoint path. If left as None, the
            model will not load any weights.
        device (str, optional): CPU/CUDA device option. Default 'cuda:0'.
            Use 'cpu' for loading the model on CPU.
        cfg_options (dict, optional): Options to override some settings in
            the used config.
    Returns:
        nn.Module: The constructed segmentor.
    """
    if isinstance(config, (str, Path)):
        config = Config.fromfile(config)
    elif not isinstance(config, Config):
        raise TypeError('config must be a filename or Config object, '
                        'but got {}'.format(type(config)))
    if cfg_options is not None:
        config.merge_from_dict(cfg_options)
    if config.model.type == 'EncoderDecoder':
        if 'init_cfg' in config.model.backbone:
            config.model.backbone.init_cfg = None
    elif config.model.type == 'MultimodalEncoderDecoder':
        for k, v in config.model.items():
            if isinstance(v, dict) and 'init_cfg' in v:
                config.model[k].init_cfg = None
    config.model.pretrained = None
    config.model.train_cfg = None
    init_default_scope(config.get('default_scope', 'mmseg'))

    model = MODELS.build(config.model)
    if checkpoint is not None:
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        dataset_meta = checkpoint['meta'].get('dataset_meta', None)
        # save the dataset_meta in the model for convenience
        if 'dataset_meta' in checkpoint.get('meta', {}):
            # mmseg 1.x
            model.dataset_meta = dataset_meta
        elif 'CLASSES' in checkpoint.get('meta', {}):
            # < mmseg 1.x
            classes = checkpoint['meta']['CLASSES']
            palette = checkpoint['meta']['PALETTE']
            model.dataset_meta = {'classes': classes, 'palette': palette}
        else:
            warnings.simplefilter('once')
            warnings.warn(
                'dataset_meta or class names are not saved in the '
                "checkpoint's meta data, classes and palette will be "
                'set according to num_classes')
            num_classes = model.decode_head.num_classes
            dataset_name = None
            for name in dataset_aliases.keys():
                if len(get_classes(name)) == num_classes:
                    dataset_name = name
                    break
            if dataset_name is None:
                warnings.warn(
                    'No suitable dataset found, use Cityscapes by default')
                dataset_name = 'cityscapes'
            model.dataset_meta = {
                'classes': get_classes(dataset_name),
                'palette': get_palette(dataset_name)
            }
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model


def inference_model(model: BaseSegmentor,
                    img: ImageType) -> Union[SegDataSample, SampleList]:
    """Inference image(s) with the segmentor.

    Args:
        model (nn.Module): The loaded segmentor.
        img (str/ndarray or list[str/ndarray]): Either image files or loaded
            images.

    Returns:
        :obj:`SegDataSample` or list[:obj:`SegDataSample`]:
        If img is a list or tuple, a list of results of the same length is
        returned; otherwise the segmentation result is returned directly.
    """
    # prepare data
    data, is_batch = _preprare_data(img, model)

    # forward the model
    with torch.no_grad():
        results = model.test_step(data)

    return results if is_batch else results[0]


def show_result_pyplot(model: BaseSegmentor,
                       img: Union[str, np.ndarray],
                       result: SegDataSample,
                       opacity: float = 0.5,
                       title: str = '',
                       draw_gt: bool = True,
                       draw_pred: bool = True,
                       wait_time: float = 0,
                       show: bool = True,
                       with_labels: Optional[bool] = True,
                       save_dir=None,
                       out_file=None):
    """Visualize the segmentation results on the image.

    Args:
        model (nn.Module): The loaded segmentor.
        img (str or np.ndarray): Image filename or loaded image.
        result (SegDataSample): The prediction SegDataSample result.
        opacity (float): Opacity of painted segmentation map.
            Default 0.5. Must be in (0, 1] range.
        title (str): The title of the pyplot figure. Default is ''.
        draw_gt (bool): Whether to draw the GT SegDataSample.
            Defaults to True.
        draw_pred (bool): Whether to draw the prediction SegDataSample.
            Defaults to True.
        wait_time (float): The interval of show (s). 0 is the special value
            that means "forever". Defaults to 0.
        show (bool): Whether to display the drawn image. Defaults to True.
        with_labels (bool, optional): Add semantic labels in the
            visualization result. Defaults to True.
        save_dir (str, optional): Save file dir for all storage backends.
            If it is None, the backend storage will not save any data.
        out_file (str, optional): Path to output file. Defaults to None.

    Returns:
        np.ndarray: the drawn image whose channel order is RGB.
    """
    if hasattr(model, 'module'):
        model = model.module
    if isinstance(img, str):
        image = mmcv.imread(img, channel_order='rgb')
    else:
        image = img
    if save_dir is not None:
        mkdir_or_exist(save_dir)
    # init visualizer
    visualizer = SegLocalVisualizer(
        vis_backends=[dict(type='LocalVisBackend')],
        save_dir=save_dir,
        alpha=opacity)
    visualizer.dataset_meta = dict(
        classes=model.dataset_meta['classes'],
        palette=model.dataset_meta['palette'])
    visualizer.add_datasample(
        name=title,
        image=image,
        data_sample=result,
        draw_gt=draw_gt,
        draw_pred=draw_pred,
        wait_time=wait_time,
        out_file=out_file,
        show=show,
        with_labels=with_labels)
    vis_img = visualizer.get_image()

    return vis_img
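
Together these three functions cover the single-image path: build the model, run it, render the prediction. A minimal usage sketch (the paths are placeholders, not files from this repo):

# Sketch using only the APIs defined above (exported via mmseg.apis).
from mmseg.apis import inference_model, init_model, show_result_pyplot

model = init_model('path/to/config.py', 'path/to/checkpoint.pth',
                   device='cuda:0')
result = inference_model(model, 'path/to/image.png')  # -> SegDataSample
vis = show_result_pyplot(model, 'path/to/image.png', result,
                         show=False, out_file='pred.png', opacity=0.5)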
finetune/mmseg/apis/mmseg_inferencer.py (new file)
@@ -0,0 +1,382 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import warnings
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
import mmcv
|
||||
import mmengine
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.transforms import Compose
|
||||
from mmengine.infer.infer import BaseInferencer, ModelType
|
||||
from mmengine.model import revert_sync_batchnorm
|
||||
from mmengine.registry import init_default_scope
|
||||
from mmengine.runner.checkpoint import _load_checkpoint_to_model
|
||||
from PIL import Image
|
||||
|
||||
from mmseg.structures import SegDataSample
|
||||
from mmseg.utils import ConfigType, SampleList, get_classes, get_palette
|
||||
from mmseg.visualization import SegLocalVisualizer
|
||||
|
||||
InputType = Union[str, np.ndarray]
|
||||
InputsType = Union[InputType, Sequence[InputType]]
|
||||
PredType = Union[SegDataSample, SampleList]
|
||||
|
||||
|
||||
class MMSegInferencer(BaseInferencer):
|
||||
"""Semantic segmentation inferencer, provides inference and visualization
|
||||
interfaces. Note: MMEngine >= 0.5.0 is required.
|
||||
|
||||
Args:
|
||||
model (str, optional): Path to the config file or the model name
|
||||
defined in metafile. Take the `mmseg metafile <https://github.com/open-mmlab/mmsegmentation/blob/main/configs/fcn/metafile.yaml>`_
|
||||
as an example the `model` could be
|
||||
"fcn_r50-d8_4xb2-40k_cityscapes-512x1024", and the weights of model
|
||||
will be download automatically. If use config file, like
|
||||
"configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py", the
|
||||
`weights` should be defined.
|
||||
weights (str, optional): Path to the checkpoint. If it is not specified
|
||||
and model is a model name of metafile, the weights will be loaded
|
||||
from metafile. Defaults to None.
|
||||
classes (list, optional): Input classes for result rendering, as the
|
||||
prediction of segmentation model is a segment map with label
|
||||
indices, `classes` is a list which includes items responding to the
|
||||
label indices. If classes is not defined, visualizer will take
|
||||
`cityscapes` classes by default. Defaults to None.
|
||||
palette (list, optional): Input palette for result rendering, which is
|
||||
a list of color palette responding to the classes. If palette is
|
||||
not defined, visualizer will take `cityscapes` palette by default.
|
||||
Defaults to None.
|
||||
dataset_name (str, optional): `Dataset name or alias <https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/utils/class_names.py#L302-L317>`_
|
||||
visulizer will use the meta information of the dataset i.e. classes
|
||||
and palette, but the `classes` and `palette` have higher priority.
|
||||
Defaults to None.
|
||||
device (str, optional): Device to run inference. If None, the available
|
||||
device will be automatically used. Defaults to None.
|
||||
scope (str, optional): The scope of the model. Defaults to 'mmseg'.
|
||||
""" # noqa
|
||||
|
||||
    preprocess_kwargs: set = set()
    forward_kwargs: set = {'mode', 'out_dir'}
    visualize_kwargs: set = {
        'show', 'wait_time', 'img_out_dir', 'opacity', 'return_vis',
        'with_labels'
    }
    postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}

    def __init__(self,
                 model: Union[ModelType, str],
                 weights: Optional[str] = None,
                 classes: Optional[Union[str, List]] = None,
                 palette: Optional[Union[str, List]] = None,
                 dataset_name: Optional[str] = None,
                 device: Optional[str] = None,
                 scope: Optional[str] = 'mmseg') -> None:
        # Global counters tracking the number of images processed, used for
        # naming of the output images.
        self.num_visualized_imgs = 0
        self.num_pred_imgs = 0
        init_default_scope(scope if scope else 'mmseg')
        super().__init__(
            model=model, weights=weights, device=device, scope=scope)

        if device == 'cpu' or not torch.cuda.is_available():
            self.model = revert_sync_batchnorm(self.model)

        assert isinstance(self.visualizer, SegLocalVisualizer)
        self.visualizer.set_dataset_meta(classes, palette, dataset_name)

    def _load_weights_to_model(self, model: nn.Module,
                               checkpoint: Optional[dict],
                               cfg: Optional[ConfigType]) -> None:
        """Loading model weights and meta information from cfg and checkpoint.

        Subclasses could override this method to load extra meta information
        from ``checkpoint`` and ``cfg`` to model.

        Args:
            model (nn.Module): Model to load weights and meta information.
            checkpoint (dict, optional): The loaded checkpoint.
            cfg (Config or ConfigDict, optional): The loaded config.
        """

        if checkpoint is not None:
            _load_checkpoint_to_model(model, checkpoint)
            checkpoint_meta = checkpoint.get('meta', {})
            # save the dataset_meta in the model for convenience
            if 'dataset_meta' in checkpoint_meta:
                # mmsegmentation 1.x
                model.dataset_meta = {
                    'classes': checkpoint_meta['dataset_meta'].get('classes'),
                    'palette': checkpoint_meta['dataset_meta'].get('palette')
                }
            elif 'CLASSES' in checkpoint_meta:
                # mmsegmentation 0.x
                classes = checkpoint_meta['CLASSES']
                palette = checkpoint_meta.get('PALETTE', None)
                model.dataset_meta = {'classes': classes, 'palette': palette}
            else:
                warnings.warn(
                    'dataset_meta or class names are not saved in the '
                    'checkpoint\'s meta data, use classes of Cityscapes by '
                    'default.')
                model.dataset_meta = {
                    'classes': get_classes('cityscapes'),
                    'palette': get_palette('cityscapes')
                }
        else:
            warnings.warn('Checkpoint is not loaded, and the inference '
                          'result is calculated by the randomly initialized '
                          'model!')
            warnings.warn(
                'weights is None, use cityscapes classes by default.')
            model.dataset_meta = {
                'classes': get_classes('cityscapes'),
                'palette': get_palette('cityscapes')
            }

    def __call__(self,
                 inputs: InputsType,
                 return_datasamples: bool = False,
                 batch_size: int = 1,
                 return_vis: bool = False,
                 show: bool = False,
                 wait_time: int = 0,
                 out_dir: str = '',
                 img_out_dir: str = 'vis',
                 pred_out_dir: str = 'pred',
                 **kwargs) -> dict:
        """Call the inferencer.

        Args:
            inputs (Union[list, str, np.ndarray]): Inputs for the inferencer.
            return_datasamples (bool): Whether to return results as
                :obj:`SegDataSample`. Defaults to False.
            batch_size (int): Batch size. Defaults to 1.
            show (bool): Whether to display the rendered color segmentation
                mask in a popup window. Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            out_dir (str): Output directory of inference results. Defaults
                to ''.
            img_out_dir (str): Subdirectory of `out_dir`, used to save the
                rendered color segmentation masks, so `out_dir` must be
                defined if you would like to save predicted masks.
                Defaults to 'vis'.
            pred_out_dir (str): Subdirectory of `out_dir`, used to save
                predicted mask files, so `out_dir` must be defined if you
                would like to save predicted masks. Defaults to 'pred'.

            **kwargs: Other keyword arguments passed to :meth:`preprocess`,
                :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
                Each key in kwargs should be in the corresponding set of
                ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
                and ``postprocess_kwargs``.

        Returns:
            dict: Inference and visualization results.
        """

        if out_dir != '':
            pred_out_dir = osp.join(out_dir, pred_out_dir)
            img_out_dir = osp.join(out_dir, img_out_dir)
        else:
            pred_out_dir = ''
            img_out_dir = ''

        return super().__call__(
            inputs=inputs,
            return_datasamples=return_datasamples,
            batch_size=batch_size,
            show=show,
            wait_time=wait_time,
            img_out_dir=img_out_dir,
            pred_out_dir=pred_out_dir,
            return_vis=return_vis,
            **kwargs)

    def visualize(self,
                  inputs: list,
                  preds: List[dict],
                  return_vis: bool = False,
                  show: bool = False,
                  wait_time: int = 0,
                  img_out_dir: str = '',
                  opacity: float = 0.8,
                  with_labels: Optional[bool] = True) -> List[np.ndarray]:
        """Visualize predictions.

        Args:
            inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
            preds (Any): Predictions of the model.
            show (bool): Whether to display the image in a popup window.
                Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            img_out_dir (str): Output directory of the rendered prediction,
                i.e. the color segmentation mask. Defaults to ''.
            opacity (float): The transparency of the segmentation mask.
                Defaults to 0.8.

        Returns:
            List[np.ndarray]: Visualization results.
        """
        if not show and img_out_dir == '' and not return_vis:
            return None
        if self.visualizer is None:
            raise ValueError('Visualization needs the "visualizer" term '
                             'defined in the config, but got None.')

        self.visualizer.set_dataset_meta(**self.model.dataset_meta)
        self.visualizer.alpha = opacity

        results = []

        for single_input, pred in zip(inputs, preds):
            if isinstance(single_input, str):
                img_bytes = mmengine.fileio.get(single_input)
                img = mmcv.imfrombytes(img_bytes)
                img = img[:, :, ::-1]
                img_name = osp.basename(single_input)
            elif isinstance(single_input, np.ndarray):
                img = single_input.copy()
                img_num = str(self.num_visualized_imgs).zfill(8) + '_vis'
                img_name = f'{img_num}.jpg'
            else:
                raise ValueError('Unsupported input type: '
                                 f'{type(single_input)}')

            out_file = osp.join(img_out_dir, img_name) if img_out_dir != '' \
                else None

            self.visualizer.add_datasample(
                img_name,
                img,
                pred,
                show=show,
                wait_time=wait_time,
                draw_gt=False,
                draw_pred=True,
                out_file=out_file,
                with_labels=with_labels)
            if return_vis:
                results.append(self.visualizer.get_image())
            self.num_visualized_imgs += 1

        return results if return_vis else None

    def postprocess(self,
                    preds: PredType,
                    visualization: List[np.ndarray],
                    return_datasample: bool = False,
                    pred_out_dir: str = '') -> dict:
        """Process the predictions and visualization results from ``forward``
        and ``visualize``.

        This method should be responsible for the following tasks:

        1. Pack the predictions and visualization results and return them.
        2. Save the predictions, if needed.

        Args:
            preds (List[Dict]): Predictions of the model.
            visualization (List[np.ndarray]): The list of rendered color
                segmentation masks.
            return_datasample (bool): Whether to return results as
                datasamples. Defaults to False.
            pred_out_dir: Directory to save the inference results w/o
                visualization. If left as empty, no file will be saved.
                Defaults to ''.

        Returns:
            dict: Inference and visualization results with key ``predictions``
            and ``visualization``

            - ``visualization (Any)``: Returned by :meth:`visualize`
            - ``predictions`` (List[np.ndarray], np.ndarray): Returned by
              :meth:`forward` and processed in :meth:`postprocess`.
              If ``return_datasample=False``, it will be the segmentation mask
              with label indices.
        """
        if return_datasample:
            if len(preds) == 1:
                return preds[0]
            else:
                return preds

        results_dict = {}

        results_dict['predictions'] = []
        results_dict['visualization'] = []

        for i, pred in enumerate(preds):
            pred_data = dict()
            if 'pred_sem_seg' in pred.keys():
                pred_data['sem_seg'] = pred.pred_sem_seg.numpy().data[0]
            elif 'pred_depth_map' in pred.keys():
                pred_data['depth_map'] = pred.pred_depth_map.numpy().data[0]

            if visualization is not None:
                vis = visualization[i]
                results_dict['visualization'].append(vis)
            if pred_out_dir != '':
                mmengine.mkdir_or_exist(pred_out_dir)
                for key, data in pred_data.items():
                    post_fix = '_pred.png' if key == 'sem_seg' else '_pred.npy'
                    img_name = str(self.num_pred_imgs).zfill(8) + post_fix
                    img_path = osp.join(pred_out_dir, img_name)
                    if key == 'sem_seg':
                        output = Image.fromarray(data.astype(np.uint8))
                        output.save(img_path)
                    else:
                        np.save(img_path, data)
            pred_data = next(iter(pred_data.values()))
            results_dict['predictions'].append(pred_data)
            self.num_pred_imgs += 1

        if len(results_dict['predictions']) == 1:
            results_dict['predictions'] = results_dict['predictions'][0]
            if visualization is not None:
                results_dict['visualization'] = \
                    results_dict['visualization'][0]
        return results_dict

    def _init_pipeline(self, cfg: ConfigType) -> Compose:
        """Initialize the test pipeline.

        Return a pipeline to handle various input data, such as ``str``,
        ``np.ndarray``. It is an abstract method in BaseInferencer, and should
        be implemented in subclasses.

        The returned pipeline will be used to process a single data.
        It will be used in :meth:`preprocess` like this:

        .. code-block:: python

            def preprocess(self, inputs, batch_size, **kwargs):
                ...
                dataset = map(self.pipeline, dataset)
                ...
        """
        pipeline_cfg = cfg.test_dataloader.dataset.pipeline
        # Loading annotations is also not applicable
        for transform in ('LoadAnnotations', 'LoadDepthAnnotation'):
            idx = self._get_transform_idx(pipeline_cfg, transform)
            if idx != -1:
                del pipeline_cfg[idx]

        load_img_idx = self._get_transform_idx(pipeline_cfg,
                                               'LoadImageFromFile')
        if load_img_idx == -1:
            raise ValueError(
                'LoadImageFromFile is not found in the test pipeline')
        pipeline_cfg[load_img_idx]['type'] = 'InferencerLoader'
        return Compose(pipeline_cfg)

    def _get_transform_idx(self, pipeline_cfg: ConfigType, name: str) -> int:
        """Returns the index of the transform in a pipeline.

        If the transform is not found, returns -1.
        """
        for i, transform in enumerate(pipeline_cfg):
            if transform['type'] == name:
                return i
        return -1
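A minimal usage sketch for the inferencer above (not part of the file; the config, checkpoint, and image paths are hypothetical placeholders, and the import path assumes the class is exported from mmseg.apis as in upstream MMSegmentation):

# Usage sketch for MMSegInferencer (all paths are placeholders).
from mmseg.apis import MMSegInferencer

inferencer = MMSegInferencer(
    model='finetune/configs/atlantic.py',        # config file, so weights must be given
    weights='work_dirs/atlantic/best_mIoU.pth',  # hypothetical checkpoint
    classes=('Background', 'Deforestation area'),
    palette=[[0, 0, 0], [255, 255, 255]])

# Run on one image; the color mask lands in out/vis and the label map in
# out/pred (see the out_dir/img_out_dir/pred_out_dir contract in __call__).
result = inferencer('rs_datasets/sample.tif', out_dir='out', opacity=0.6)
print(result['predictions'].shape)  # (H, W) array of label indices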
279
finetune/mmseg/apis/remote_sense_inferencer.py
Normal file
@@ -0,0 +1,279 @@
# Copyright (c) OpenMMLab. All rights reserved.
import threading
from queue import Queue
from typing import List, Optional, Tuple

import numpy as np
import torch
from mmengine import Config
from mmengine.model import BaseModel
from mmengine.registry import init_default_scope
from mmengine.runner import load_checkpoint

try:
    from osgeo import gdal
except ImportError:
    gdal = None

from mmseg.registry import MODELS
from .utils import _preprare_data


class RSImage:
    """Remote sensing image class.

    Args:
        image (str or gdal.Dataset): Image file path or gdal.Dataset.
    """

    def __init__(self, image):
        self.dataset = gdal.Open(image, gdal.GA_ReadOnly) if isinstance(
            image, str) else image
        assert isinstance(self.dataset, gdal.Dataset), \
            f'{image} is not an image'
        self.width = self.dataset.RasterXSize
        self.height = self.dataset.RasterYSize
        self.channel = self.dataset.RasterCount
        self.trans = self.dataset.GetGeoTransform()
        self.proj = self.dataset.GetProjection()
        self.band_list = []
        self.band_list.extend(
            self.dataset.GetRasterBand(c + 1) for c in range(self.channel))
        self.grids = []

    def read(self, grid: Optional[List] = None) -> np.ndarray:
        """Read image data. If grid is None, read the whole image.

        Args:
            grid (Optional[List], optional): Grid to read. Defaults to None.
        Returns:
            np.ndarray: Image data.
        """
        if grid is None:
            return np.einsum('ijk->jki', self.dataset.ReadAsArray())
        assert len(
            grid) >= 4, 'grid must be a list containing at least 4 elements'
        data = self.dataset.ReadAsArray(*grid[:4])
        if data.ndim == 2:
            data = data[np.newaxis, ...]
        return np.einsum('ijk->jki', data)

    def write(self, data: Optional[np.ndarray], grid: Optional[List] = None):
        """Write image data.

        Args:
            data (Optional[np.ndarray], optional): Data to write.
                Defaults to None.
            grid (Optional[List], optional): Grid to write. Defaults to None.

        Raises:
            ValueError: Either grid or data must be provided.
        """
        if grid is not None:
            assert len(grid) == 8, 'grid must be a list of 8 elements'
            for band in self.band_list:
                band.WriteArray(
                    data[grid[5]:grid[5] + grid[7], grid[4]:grid[4] + grid[6]],
                    grid[0] + grid[4], grid[1] + grid[5])
        elif data is not None:
            for i in range(self.channel):
                self.band_list[i].WriteArray(data[..., i])
        else:
            raise ValueError('Either grid or data must be provided.')

    def create_seg_map(self, output_path: Optional[str] = None):
        """Create a single-band GeoTIFF with the same size, transform and
        projection as this image, to hold the segmentation map."""
        if output_path is None:
            output_path = 'output_label.tif'
        driver = gdal.GetDriverByName('GTiff')
        seg_map = driver.Create(output_path, self.width, self.height, 1,
                                gdal.GDT_Byte)
        seg_map.SetGeoTransform(self.trans)
        seg_map.SetProjection(self.proj)
        seg_map_img = RSImage(seg_map)
        seg_map_img.path = output_path
        return seg_map_img

    def create_grids(self,
                     window_size: Tuple[int, int],
                     stride: Tuple[int, int] = (0, 0)):
        """Create grids for image inference.

        Args:
            window_size (Tuple[int, int]): the size of the sliding window.
            stride (Tuple[int, int], optional): the stride of the sliding
                window. Defaults to (0, 0).

        Raises:
            AssertionError: window_size must be a tuple of 2 elements.
            AssertionError: stride must be a tuple of 2 elements.
        """
        assert len(
            window_size) == 2, 'window_size must be a tuple of 2 elements'
        assert len(stride) == 2, 'stride must be a tuple of 2 elements'
        win_w, win_h = window_size
        stride_x, stride_y = stride

        stride_x = win_w if stride_x == 0 else stride_x
        stride_y = win_h if stride_y == 0 else stride_y

        x_half_overlap = (win_w - stride_x + 1) // 2
        y_half_overlap = (win_h - stride_y + 1) // 2

        for y in range(0, self.height, stride_y):
            y_end = y + win_h >= self.height
            y_offset = self.height - win_h if y_end else y
            y_size = win_h
            y_crop_off = 0 if y_offset == 0 else y_half_overlap
            y_crop_size = y_size if y_end else win_h - y_crop_off

            for x in range(0, self.width, stride_x):
                x_end = x + win_w >= self.width
                x_offset = self.width - win_w if x_end else x
                x_size = win_w
                x_crop_off = 0 if x_offset == 0 else x_half_overlap
                x_crop_size = x_size if x_end else win_w - x_crop_off

                self.grids.append([
                    x_offset, y_offset, x_size, y_size, x_crop_off, y_crop_off,
                    x_crop_size, y_crop_size
                ])

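# Illustration (assumption: a 512x512 image): create_grids((256, 256)) with
# the default stride (0, 0), i.e. stride == window size, yields four
# non-overlapping grid records of the form
# [x_offset, y_offset, x_size, y_size, x_crop_off, y_crop_off,
#  x_crop_size, y_crop_size]:
#   [0, 0, 256, 256, 0, 0, 256, 256], [256, 0, 256, 256, 0, 0, 256, 256],
#   [0, 256, 256, 256, 0, 0, 256, 256], [256, 256, 256, 256, 0, 0, 256, 256]
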
class RSInferencer:
    """Remote sensing inference class.

    Args:
        model (BaseModel): The loaded model.
        batch_size (int, optional): Batch size. Defaults to 1.
        thread (int, optional): Number of threads. Defaults to 1.
    """

    def __init__(self, model: BaseModel, batch_size: int = 1, thread: int = 1):
        self.model = model
        self.batch_size = batch_size
        self.END_FLAG = object()
        self.read_buffer = Queue(self.batch_size)
        self.write_buffer = Queue(self.batch_size)
        self.thread = thread

    @classmethod
    def from_config_path(cls,
                         config_path: str,
                         checkpoint_path: str,
                         batch_size: int = 1,
                         thread: int = 1,
                         device: Optional[str] = 'cpu'):
        """Initialize a segmentor from a config file.

        Args:
            config_path (str): Config file path.
            checkpoint_path (str): Checkpoint path.
            batch_size (int, optional): Batch size. Defaults to 1.
            thread (int, optional): Number of inference threads.
                Defaults to 1.
            device (str, optional): Device to load the model on.
                Defaults to 'cpu'.
        """
        init_default_scope('mmseg')
        cfg = Config.fromfile(config_path)
        model = MODELS.build(cfg.model)
        model.cfg = cfg
        load_checkpoint(model, checkpoint_path, map_location='cpu')
        model.to(device)
        model.eval()
        return cls(model, batch_size, thread)

    @classmethod
    def from_model(cls,
                   model: BaseModel,
                   checkpoint_path: Optional[str] = None,
                   batch_size: int = 1,
                   thread: int = 1,
                   device: Optional[str] = 'cpu'):
        """Initialize a segmentor from an already built model.

        Args:
            model (BaseModel): The loaded model.
            checkpoint_path (Optional[str]): Checkpoint path.
            batch_size (int, optional): Batch size. Defaults to 1.
            thread (int, optional): Number of inference threads.
                Defaults to 1.
            device (str, optional): Device to load the model on.
                Defaults to 'cpu'.
        """
        if checkpoint_path is not None:
            load_checkpoint(model, checkpoint_path, map_location='cpu')
        model.to(device)
        return cls(model, batch_size, thread)

    def read(self,
             image: RSImage,
             window_size: Tuple[int, int],
             strides: Tuple[int, int] = (0, 0)):
        """Load image data to read buffer.

        Args:
            image (RSImage): The image to read.
            window_size (Tuple[int, int]): The size of the sliding window.
            strides (Tuple[int, int], optional): The stride of the sliding
                window. Defaults to (0, 0).
        """
        image.create_grids(window_size, strides)
        for grid in image.grids:
            self.read_buffer.put([grid, image.read(grid=grid)])
        self.read_buffer.put(self.END_FLAG)

    def inference(self):
        """Inference image data from the read buffer and put the result into
        the write buffer."""
        while True:
            item = self.read_buffer.get()
            # Identity check against the sentinel object; `==` would try to
            # compare the [grid, ndarray] payload instead.
            if item is self.END_FLAG:
                self.read_buffer.put(self.END_FLAG)
                self.write_buffer.put(item)
                break
            data, _ = _preprare_data(item[1], self.model)
            with torch.no_grad():
                result = self.model.test_step(data)
            item[1] = result[0].pred_sem_seg.cpu().data.numpy()[0]
            self.write_buffer.put(item)
            self.read_buffer.task_done()

    def write(self, image: RSImage, output_path: Optional[str] = None):
        """Write image data from the write buffer.

        Args:
            image (RSImage): The image to write.
            output_path (Optional[str], optional): The path to save the
                segmentation map. Defaults to None.
        """
        seg_map = image.create_seg_map(output_path)
        while True:
            item = self.write_buffer.get()
            if item is self.END_FLAG:
                break
            seg_map.write(data=item[1], grid=item[0])
            self.write_buffer.task_done()

    def run(self,
            image: RSImage,
            window_size: Tuple[int, int],
            strides: Tuple[int, int] = (0, 0),
            output_path: Optional[str] = None):
        """Run inference with multi-threading.

        Args:
            image (RSImage): The image to inference.
            window_size (Tuple[int, int]): The size of the sliding window.
            strides (Tuple[int, int], optional): The stride of the sliding
                window. Defaults to (0, 0).
            output_path (Optional[str], optional): The path to save the
                segmentation map. Defaults to None.
        """
        read_thread = threading.Thread(
            target=self.read, args=(image, window_size, strides))
        read_thread.start()
        inference_threads = []
        for _ in range(self.thread):
            inference_thread = threading.Thread(target=self.inference)
            inference_thread.start()
            inference_threads.append(inference_thread)
        write_thread = threading.Thread(
            target=self.write, args=(image, output_path))
        write_thread.start()
        read_thread.join()
        for inference_thread in inference_threads:
            inference_thread.join()
        write_thread.join()
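A sketch of the intended sliding-window workflow (not part of the file; the config, checkpoint, and raster paths are hypothetical, and GDAL must be installed for RSImage to work):

# Sliding-window inference over a GeoTIFF (all paths are placeholders).
inferencer = RSInferencer.from_config_path(
    config_path='finetune/configs/atlantic.py',
    checkpoint_path='work_dirs/atlantic/best_mIoU.pth',
    batch_size=1, thread=2, device='cuda:0')
image = RSImage('rs_datasets/scene.tif')
# 256x256 windows with a 128-pixel stride; overlapping borders are cropped on
# write using the 8-element grid records produced by RSImage.create_grids().
inferencer.run(image, window_size=(256, 256), strides=(128, 128),
               output_path='scene_label.tif')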
41
finetune/mmseg/apis/utils.py
Normal file
@@ -0,0 +1,41 @@
# Copyright (c) OpenMMLab. All rights reserved.
from collections import defaultdict
from typing import Sequence, Union

import numpy as np
from mmengine.dataset import Compose
from mmengine.model import BaseModel

ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]


def _preprare_data(imgs: ImageType, model: BaseModel):
    """Build the test pipeline from ``model.cfg``, run the input image(s)
    through it, and return batched inputs/data_samples plus a flag telling
    whether the input was a batch."""
    cfg = model.cfg
    # Iterate over a copy: removing from the list being iterated would skip
    # the element following each removed one.
    for t in list(cfg.test_pipeline):
        if t.get('type') == 'LoadAnnotations':
            cfg.test_pipeline.remove(t)

    is_batch = True
    if not isinstance(imgs, (list, tuple)):
        imgs = [imgs]
        is_batch = False

    if isinstance(imgs[0], np.ndarray):
        cfg.test_pipeline[0]['type'] = 'LoadImageFromNDArray'

    # TODO: Consider using the singleton pattern to avoid building
    # a pipeline for each inference
    pipeline = Compose(cfg.test_pipeline)

    data = defaultdict(list)
    for img in imgs:
        if isinstance(img, np.ndarray):
            data_ = dict(img=img)
        else:
            data_ = dict(img_path=img)
        data_ = pipeline(data_)
        data['inputs'].append(data_['inputs'])
        data['data_samples'].append(data_['data_samples'])

    return data, is_batch
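A minimal calling sketch for the helper above (note the misspelled name `_preprare_data` is the actual identifier used throughout this repo; `model` is assumed to be a segmentor with its `cfg` attached, as RSInferencer.from_config_path arranges):

import numpy as np

# Hypothetical 4-band input matching the configs in this repo.
img = np.random.randint(0, 255, (256, 256, 4), dtype=np.uint8)
data, is_batch = _preprare_data(img, model)
# data['inputs'] and data['data_samples'] are length-1 lists; is_batch is
# False because a single ndarray (not a list of images) was passed in.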
35
finetune/mmseg/datasets/__init__.py
Normal file
@@ -0,0 +1,35 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .atlantic import AtlanticDataset
from .basesegdataset import BaseSegDataset
from .c2sfloods import C2SFloodDataset
from .cabuar import CABURADataset
from .germany import GermanyCropDataset
from .sos import SOSDataset
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad,
                         BioMedical3DRandomCrop, BioMedical3DRandomFlip,
                         BioMedicalGaussianBlur, BioMedicalGaussianNoise,
                         BioMedicalRandomGamma, ConcatCDInput, GenerateEdge,
                         LoadAnnotations, LoadBiomedicalAnnotation,
                         LoadBiomedicalData, LoadBiomedicalImageFromFile,
                         LoadImageFromNDArray, LoadMultipleRSImageFromFile,
                         LoadSingleRSImageFromFile, PackSegInputs,
                         PhotoMetricDistortion, RandomCrop, RandomCutOut,
                         RandomMosaic, RandomRotate, RandomRotFlip, Rerange,
                         ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
                         SegRescale)

# yapf: enable
__all__ = [
    'BaseSegDataset', 'BioMedical3DRandomCrop', 'BioMedical3DRandomFlip',
    'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
    'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
    'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
    'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
    'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
    'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
    'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip',
    'Albu', 'LoadMultipleRSImageFromFile', 'LoadSingleRSImageFromFile',
    'ConcatCDInput', 'AtlanticDataset', 'C2SFloodDataset',
    'CABURADataset', 'GermanyCropDataset', 'SOSDataset'
]
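Since each dataset registers itself in DATASETS, configs refer to them by name; a minimal build sketch (the annotation file name is hypothetical):

from mmseg.registry import DATASETS

dataset = DATASETS.build(
    dict(
        type='AtlanticDataset',
        data_root='rs_datasets/',
        ann_file='atlantic_train.json',  # hypothetical annotation file
        pipeline=[]))
print(len(dataset))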
48
finetune/mmseg/datasets/atlantic.py
Normal file
@@ -0,0 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
from typing import List

from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset


@DATASETS.register_module()
class AtlanticDataset(BaseSegDataset):
    """Atlantic deforestation dataset with two classes (Background,
    Deforestation area); images and masks are GeoTIFFs."""
    METAINFO = dict(
        classes=('Background', 'Deforestation area'),
        palette=[[0, 0, 0], [255, 255, 255]]
    )

    def __init__(self,
                 img_suffix='.tif',
                 seg_map_suffix='.tif',
                 reduce_zero_label=False,
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)

    def load_data_list(self) -> List[dict]:
        """Load annotation from directory or annotation file.

        Returns:
            list[dict]: All data info of dataset.
        """
        data_list = []
        if not osp.isdir(self.ann_file) and self.ann_file:
            assert osp.isfile(self.ann_file), \
                f'Failed to load `ann_file` {self.ann_file}'
            with open(self.ann_file) as f:
                lines = json.load(f)
            for line in lines:
                data_info = dict(
                    img_path=osp.join(self.data_root, line['s2_path']))
                if 'target_path' in line.keys():
                    data_info['seg_map_path'] = osp.join(
                        self.data_root, line['target_path'])
                data_info['label_map'] = self.label_map
                data_info['reduce_zero_label'] = self.reduce_zero_label
                data_info['seg_fields'] = []
                data_list.append(data_info)
        data_list = sorted(data_list, key=lambda x: x['img_path'])
        return data_list
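load_data_list above expects the annotation file to be a JSON list of records with an `s2_path` key and an optional `target_path` key, both relative to `data_root`; a sketch of writing such a file (the file and directory names are hypothetical):

import json

records = [
    {'s2_path': 'atlantic/img/0001.tif', 'target_path': 'atlantic/ann/0001.tif'},
    {'s2_path': 'atlantic/img/0002.tif', 'target_path': 'atlantic/ann/0002.tif'},
]
with open('rs_datasets/atlantic_train.json', 'w') as f:
    json.dump(records, f)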
552
finetune/mmseg/datasets/basesegdataset.py
Normal file
@@ -0,0 +1,552 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from typing import Callable, Dict, List, Optional, Sequence, Union

import mmengine
import mmengine.fileio as fileio
import numpy as np
from mmengine.dataset import BaseDataset, Compose

from mmseg.registry import DATASETS


@DATASETS.register_module()
class BaseSegDataset(BaseDataset):
    """Custom dataset for semantic segmentation. An example of file structure
    is as followed.

    .. code-block:: none

        ├── data
        │   ├── my_dataset
        │   │   ├── img_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{img_suffix}
        │   │   │   │   ├── yyy{img_suffix}
        │   │   │   │   ├── zzz{img_suffix}
        │   │   │   ├── val
        │   │   ├── ann_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{seg_map_suffix}
        │   │   │   │   ├── yyy{seg_map_suffix}
        │   │   │   │   ├── zzz{seg_map_suffix}
        │   │   │   ├── val

    The img/gt_semantic_seg pair of BaseSegDataset should be the same
    except for the suffix. A valid img/gt_semantic_seg filename pair should
    be like ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (the extension is
    also included in the suffix). If split is given, then ``xxx`` is specified
    in the txt file. Otherwise, all files in ``img_dir/`` and ``ann_dir`` will
    be loaded. Please refer to ``docs/en/tutorials/new_dataset.md`` for more
    details.

    Args:
        ann_file (str): Annotation file path. Defaults to ''.
        metainfo (dict, optional): Meta information for dataset, such as
            specify classes to load. Defaults to None.
        data_root (str, optional): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to None.
        data_prefix (dict, optional): Prefix for training data. Defaults to
            dict(img_path=None, seg_map_path=None).
        img_suffix (str): Suffix of images. Default: '.jpg'
        seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
        filter_cfg (dict, optional): Config for filter data. Defaults to None.
        indices (int or Sequence[int], optional): Support using first few
            data in annotation file to facilitate training/testing on a
            smaller dataset. Defaults to None which means using all
            ``data_infos``.
        serialize_data (bool, optional): Whether to hold memory using
            serialized objects, when enabled, data loader workers can use
            shared RAM from master process instead of making a copy. Defaults
            to True.
        pipeline (list, optional): Processing pipeline. Defaults to [].
        test_mode (bool, optional): ``test_mode=True`` means in test phase.
            Defaults to False.
        lazy_init (bool, optional): Whether to load annotation during
            instantiation. In some cases, such as visualization, only the meta
            information of the dataset is needed, which is not necessary to
            load annotation file. ``Basedataset`` can skip load annotations to
            save time by set ``lazy_init=True``. Defaults to False.
        max_refetch (int, optional): If ``Basedataset.prepare_data`` gets a
            None img. The maximum extra number of cycles to get a valid
            image. Defaults to 1000.
        ignore_index (int): The label index to be ignored. Default: 255
        reduce_zero_label (bool): Whether to mark label zero as ignored.
            Defaults to False.
        backend_args (dict, Optional): Arguments to instantiate a file
            backend. See https://mmengine.readthedocs.io/en/latest/api/fileio.htm
            for details. Defaults to None.
    Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
    """

    METAINFO: dict = dict()

    def __init__(self,
                 ann_file: str = '',
                 img_suffix='.jpg',
                 seg_map_suffix='.png',
                 metainfo: Optional[dict] = None,
                 data_root: Optional[str] = None,
                 data_prefix: dict = dict(img_path='', seg_map_path=''),
                 filter_cfg: Optional[dict] = None,
                 indices: Optional[Union[int, Sequence[int]]] = None,
                 serialize_data: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 test_mode: bool = False,
                 lazy_init: bool = False,
                 max_refetch: int = 1000,
                 ignore_index: int = 255,
                 reduce_zero_label: bool = False,
                 backend_args: Optional[dict] = None) -> None:

        self.img_suffix = img_suffix
        self.seg_map_suffix = seg_map_suffix
        self.ignore_index = ignore_index
        self.reduce_zero_label = reduce_zero_label
        self.backend_args = backend_args.copy() if backend_args else None

        self.data_root = data_root
        self.data_prefix = copy.copy(data_prefix)
        self.ann_file = ann_file
        self.filter_cfg = copy.deepcopy(filter_cfg)
        self._indices = indices
        self.serialize_data = serialize_data
        self.test_mode = test_mode
        self.max_refetch = max_refetch
        self.data_list: List[dict] = []
        self.data_bytes: np.ndarray

        # Set meta information.
        self._metainfo = self._load_metainfo(copy.deepcopy(metainfo))

        # Get label map for custom classes
        new_classes = self._metainfo.get('classes', None)
        self.label_map = self.get_label_map(new_classes)
        self._metainfo.update(
            dict(
                label_map=self.label_map,
                reduce_zero_label=self.reduce_zero_label))

        # Update palette based on label map or generate palette
        # if it is not defined
        updated_palette = self._update_palette()
        self._metainfo.update(dict(palette=updated_palette))

        # Join paths.
        if self.data_root is not None:
            self._join_prefix()

        # Build pipeline.
        self.pipeline = Compose(pipeline)
        # Fully initialize the dataset.
        if not lazy_init:
            self.full_init()

        if test_mode:
            assert self._metainfo.get('classes') is not None, \
                'dataset metainfo `classes` should be specified when testing'

    @classmethod
    def get_label_map(cls,
                      new_classes: Optional[Sequence] = None
                      ) -> Union[Dict, None]:
        """Require label mapping.

        The ``label_map`` is a dictionary, its keys are the old label ids and
        its values are the new label ids, and it is used for changing pixel
        labels in load_annotations. If and only if old classes in
        cls.METAINFO are not equal to new classes in self._metainfo and
        neither of them is None, ``label_map`` is not None.

        Args:
            new_classes (list, tuple, optional): The new classes name from
                metainfo. Defaults to None.

        Returns:
            dict, optional: The mapping from old classes in cls.METAINFO to
                new classes in self._metainfo
        """
        old_classes = cls.METAINFO.get('classes', None)
        if (new_classes is not None and old_classes is not None
                and list(new_classes) != list(old_classes)):

            label_map = {}
            if not set(new_classes).issubset(cls.METAINFO['classes']):
                raise ValueError(
                    f'new classes {new_classes} is not a '
                    f'subset of classes {old_classes} in METAINFO.')
            for i, c in enumerate(old_classes):
                if c not in new_classes:
                    label_map[i] = 255
                else:
                    label_map[i] = new_classes.index(c)
            return label_map
        else:
            return None

    def _update_palette(self) -> list:
        """Update palette after loading metainfo.

        If the length of the palette is equal to the number of classes, just
        return the palette. If the palette is not defined, it will randomly
        generate a palette. If classes is updated by the user, it will return
        the subset of the palette.

        Returns:
            Sequence: Palette for current dataset.
        """
        palette = self._metainfo.get('palette', [])
        classes = self._metainfo.get('classes', [])
        # palette matches classes
        if len(palette) == len(classes):
            return palette

        if len(palette) == 0:
            # Get random state before set seed, and restore
            # random state later.
            # It will prevent loss of randomness, as the palette
            # may be different in each iteration if not specified.
            # See: https://github.com/open-mmlab/mmdetection/issues/5844
            state = np.random.get_state()
            np.random.seed(42)
            # random palette
            new_palette = np.random.randint(
                0, 255, size=(len(classes), 3)).tolist()
            np.random.set_state(state)
        elif len(palette) >= len(classes) and self.label_map is not None:
            new_palette = []
            # return subset of palette
            for old_id, new_id in sorted(
                    self.label_map.items(), key=lambda x: x[1]):
                if new_id != 255:
                    new_palette.append(palette[old_id])
            new_palette = type(palette)(new_palette)
        else:
            raise ValueError('palette does not match classes '
                             f'as metainfo is {self._metainfo}.')
        return new_palette

    def load_data_list(self) -> List[dict]:
        """Load annotation from directory or annotation file.

        Returns:
            list[dict]: All data info of dataset.
        """
        data_list = []
        img_dir = self.data_prefix.get('img_path', None)
        ann_dir = self.data_prefix.get('seg_map_path', None)
        if not osp.isdir(self.ann_file) and self.ann_file:
            assert osp.isfile(self.ann_file), \
                f'Failed to load `ann_file` {self.ann_file}'
            lines = mmengine.list_from_file(
                self.ann_file, backend_args=self.backend_args)
            for line in lines:
                img_name = line.strip()
                data_info = dict(
                    img_path=osp.join(img_dir, img_name + self.img_suffix))
                if ann_dir is not None:
                    seg_map = img_name + self.seg_map_suffix
                    data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
                data_info['label_map'] = self.label_map
                data_info['reduce_zero_label'] = self.reduce_zero_label
                data_info['seg_fields'] = []
                data_list.append(data_info)
        else:
            _suffix_len = len(self.img_suffix)
            for img in fileio.list_dir_or_file(
                    dir_path=img_dir,
                    list_dir=False,
                    suffix=self.img_suffix,
                    recursive=True,
                    backend_args=self.backend_args):
                data_info = dict(img_path=osp.join(img_dir, img))
                if ann_dir is not None:
                    seg_map = img[:-_suffix_len] + self.seg_map_suffix
                    data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
                data_info['label_map'] = self.label_map
                data_info['reduce_zero_label'] = self.reduce_zero_label
                data_info['seg_fields'] = []
                data_list.append(data_info)
        data_list = sorted(data_list, key=lambda x: x['img_path'])
        return data_list


@DATASETS.register_module()
class BaseCDDataset(BaseDataset):
    """Custom dataset for change detection. An example of file structure is as
    followed.

    .. code-block:: none

        ├── data
        │   ├── my_dataset
        │   │   ├── img_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{img_suffix}
        │   │   │   │   ├── yyy{img_suffix}
        │   │   │   │   ├── zzz{img_suffix}
        │   │   │   ├── val
        │   │   ├── img_dir2
        │   │   │   ├── train
        │   │   │   │   ├── xxx{img_suffix}
        │   │   │   │   ├── yyy{img_suffix}
        │   │   │   │   ├── zzz{img_suffix}
        │   │   │   ├── val
        │   │   ├── ann_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{seg_map_suffix}
        │   │   │   │   ├── yyy{seg_map_suffix}
        │   │   │   │   ├── zzz{seg_map_suffix}
        │   │   │   ├── val

    The image names in img_dir and img_dir2 should be consistent.
    The img/gt_semantic_seg pair of BaseCDDataset should be the same
    except for the suffix. A valid img/gt_semantic_seg filename pair should
    be like ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (the extension is
    also included in the suffix). If split is given, then ``xxx`` is specified
    in the txt file. Otherwise, all files in ``img_dir/`` and ``ann_dir`` will
    be loaded. Please refer to ``docs/en/tutorials/new_dataset.md`` for more
    details.

    Args:
        ann_file (str): Annotation file path. Defaults to ''.
        metainfo (dict, optional): Meta information for dataset, such as
            specify classes to load. Defaults to None.
        data_root (str, optional): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to None.
        data_prefix (dict, optional): Prefix for training data. Defaults to
            dict(img_path=None, img_path2=None, seg_map_path=None).
        img_suffix (str): Suffix of images. Default: '.jpg'
        img_suffix2 (str): Suffix of the second set of images.
            Default: '.jpg'
        seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
        filter_cfg (dict, optional): Config for filter data. Defaults to None.
        indices (int or Sequence[int], optional): Support using first few
            data in annotation file to facilitate training/testing on a
            smaller dataset. Defaults to None which means using all
            ``data_infos``.
        serialize_data (bool, optional): Whether to hold memory using
            serialized objects, when enabled, data loader workers can use
            shared RAM from master process instead of making a copy. Defaults
            to True.
        pipeline (list, optional): Processing pipeline. Defaults to [].
        test_mode (bool, optional): ``test_mode=True`` means in test phase.
            Defaults to False.
        lazy_init (bool, optional): Whether to load annotation during
            instantiation. In some cases, such as visualization, only the meta
            information of the dataset is needed, which is not necessary to
            load annotation file. ``Basedataset`` can skip load annotations to
            save time by set ``lazy_init=True``. Defaults to False.
        max_refetch (int, optional): If ``Basedataset.prepare_data`` gets a
            None img. The maximum extra number of cycles to get a valid
            image. Defaults to 1000.
        ignore_index (int): The label index to be ignored. Default: 255
        reduce_zero_label (bool): Whether to mark label zero as ignored.
            Defaults to False.
        backend_args (dict, Optional): Arguments to instantiate a file
            backend. See https://mmengine.readthedocs.io/en/latest/api/fileio.htm
            for details. Defaults to None.
    Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
    """

    METAINFO: dict = dict()

    def __init__(self,
                 ann_file: str = '',
                 img_suffix='.jpg',
                 img_suffix2='.jpg',
                 seg_map_suffix='.png',
                 metainfo: Optional[dict] = None,
                 data_root: Optional[str] = None,
                 data_prefix: dict = dict(
                     img_path='', img_path2='', seg_map_path=''),
                 filter_cfg: Optional[dict] = None,
                 indices: Optional[Union[int, Sequence[int]]] = None,
                 serialize_data: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 test_mode: bool = False,
                 lazy_init: bool = False,
                 max_refetch: int = 1000,
                 ignore_index: int = 255,
                 reduce_zero_label: bool = False,
                 backend_args: Optional[dict] = None) -> None:

        self.img_suffix = img_suffix
        self.img_suffix2 = img_suffix2
        self.seg_map_suffix = seg_map_suffix
        self.ignore_index = ignore_index
        self.reduce_zero_label = reduce_zero_label
        self.backend_args = backend_args.copy() if backend_args else None

        self.data_root = data_root
        self.data_prefix = copy.copy(data_prefix)
        self.ann_file = ann_file
        self.filter_cfg = copy.deepcopy(filter_cfg)
        self._indices = indices
        self.serialize_data = serialize_data
        self.test_mode = test_mode
        self.max_refetch = max_refetch
        self.data_list: List[dict] = []
        self.data_bytes: np.ndarray

        # Set meta information.
        self._metainfo = self._load_metainfo(copy.deepcopy(metainfo))

        # Get label map for custom classes
        new_classes = self._metainfo.get('classes', None)
        self.label_map = self.get_label_map(new_classes)
        self._metainfo.update(
            dict(
                label_map=self.label_map,
                reduce_zero_label=self.reduce_zero_label))

        # Update palette based on label map or generate palette
        # if it is not defined
        updated_palette = self._update_palette()
        self._metainfo.update(dict(palette=updated_palette))

        # Join paths.
        if self.data_root is not None:
            self._join_prefix()

        # Build pipeline.
        self.pipeline = Compose(pipeline)
        # Fully initialize the dataset.
        if not lazy_init:
            self.full_init()

        if test_mode:
            assert self._metainfo.get('classes') is not None, \
                'dataset metainfo `classes` should be specified when testing'

    @classmethod
    def get_label_map(cls,
                      new_classes: Optional[Sequence] = None
                      ) -> Union[Dict, None]:
        """Require label mapping.

        The ``label_map`` is a dictionary, its keys are the old label ids and
        its values are the new label ids, and it is used for changing pixel
        labels in load_annotations. If and only if old classes in
        cls.METAINFO are not equal to new classes in self._metainfo and
        neither of them is None, ``label_map`` is not None.

        Args:
            new_classes (list, tuple, optional): The new classes name from
                metainfo. Defaults to None.

        Returns:
            dict, optional: The mapping from old classes in cls.METAINFO to
                new classes in self._metainfo
        """
        old_classes = cls.METAINFO.get('classes', None)
        if (new_classes is not None and old_classes is not None
                and list(new_classes) != list(old_classes)):

            label_map = {}
            if not set(new_classes).issubset(cls.METAINFO['classes']):
                raise ValueError(
                    f'new classes {new_classes} is not a '
                    f'subset of classes {old_classes} in METAINFO.')
            for i, c in enumerate(old_classes):
                if c not in new_classes:
                    label_map[i] = 255
                else:
                    label_map[i] = new_classes.index(c)
            return label_map
        else:
            return None

    def _update_palette(self) -> list:
        """Update palette after loading metainfo.

        If the length of the palette is equal to the number of classes, just
        return the palette. If the palette is not defined, it will randomly
        generate a palette. If classes is updated by the user, it will return
        the subset of the palette.

        Returns:
            Sequence: Palette for current dataset.
        """
        palette = self._metainfo.get('palette', [])
        classes = self._metainfo.get('classes', [])
        # palette matches classes
        if len(palette) == len(classes):
            return palette

        if len(palette) == 0:
            # Get random state before set seed, and restore
            # random state later.
            # It will prevent loss of randomness, as the palette
            # may be different in each iteration if not specified.
            # See: https://github.com/open-mmlab/mmdetection/issues/5844
            state = np.random.get_state()
            np.random.seed(42)
            # random palette
            new_palette = np.random.randint(
                0, 255, size=(len(classes), 3)).tolist()
            np.random.set_state(state)
        elif len(palette) >= len(classes) and self.label_map is not None:
            new_palette = []
            # return subset of palette
            for old_id, new_id in sorted(
                    self.label_map.items(), key=lambda x: x[1]):
                if new_id != 255:
                    new_palette.append(palette[old_id])
            new_palette = type(palette)(new_palette)
        else:
            raise ValueError('palette does not match classes '
                             f'as metainfo is {self._metainfo}.')
        return new_palette

    def load_data_list(self) -> List[dict]:
        """Load annotation from directory or annotation file.

        Returns:
            list[dict]: All data info of dataset.
        """
        data_list = []
        img_dir = self.data_prefix.get('img_path', None)
        img_dir2 = self.data_prefix.get('img_path2', None)
        ann_dir = self.data_prefix.get('seg_map_path', None)
        if osp.isfile(self.ann_file):
            lines = mmengine.list_from_file(
                self.ann_file, backend_args=self.backend_args)
            for line in lines:
                img_name = line.strip()
                if '.' in osp.basename(img_name):
                    img_name, img_ext = osp.splitext(img_name)
                    self.img_suffix = img_ext
                    self.img_suffix2 = img_ext
                data_info = dict(
                    img_path=osp.join(img_dir, img_name + self.img_suffix),
                    img_path2=osp.join(img_dir2, img_name + self.img_suffix2))

                if ann_dir is not None:
                    seg_map = img_name + self.seg_map_suffix
                    data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
                data_info['label_map'] = self.label_map
                data_info['reduce_zero_label'] = self.reduce_zero_label
                data_info['seg_fields'] = []
                data_list.append(data_info)
        else:
            for img in fileio.list_dir_or_file(
                    dir_path=img_dir,
                    list_dir=False,
                    suffix=self.img_suffix,
                    recursive=True,
                    backend_args=self.backend_args):
                if '.' in osp.basename(img):
                    img, img_ext = osp.splitext(img)
                    self.img_suffix = img_ext
                    self.img_suffix2 = img_ext
                data_info = dict(
                    img_path=osp.join(img_dir, img + self.img_suffix),
                    img_path2=osp.join(img_dir2, img + self.img_suffix2))
                if ann_dir is not None:
                    seg_map = img + self.seg_map_suffix
                    data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
                data_info['label_map'] = self.label_map
                data_info['reduce_zero_label'] = self.reduce_zero_label
                data_info['seg_fields'] = []
                data_list.append(data_info)
        data_list = sorted(data_list, key=lambda x: x['img_path'])
        return data_list
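A short sketch of the get_label_map contract defined on both classes above: remapping a dataset to a subset of its METAINFO classes sends every dropped class to the ignore index 255 (the class names here are hypothetical):

# Hypothetical three-class dataset remapped to two of its classes.
class ThreeClassDataset(BaseSegDataset):
    METAINFO = dict(classes=('bg', 'water', 'cloud'))

# Keeping only ('bg', 'cloud'): old id 0 -> 0, old id 1 -> 255 (ignored),
# old id 2 -> 1.
label_map = ThreeClassDataset.get_label_map(new_classes=('bg', 'cloud'))
assert label_map == {0: 0, 1: 255, 2: 1}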
65
finetune/mmseg/datasets/c2sfloods.py
Normal file
@@ -0,0 +1,65 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
from typing import List

from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset


@DATASETS.register_module()
class C2SFloodDataset(BaseSegDataset):
    """Flood segmentation dataset (C2SFloods) with four classes (Background,
    Water, Cloud, Cloud shadow); samples are stored as ``.npz`` files."""
    METAINFO = dict(
        classes=('Background', 'Water', 'Cloud', 'Cloud shadow'),
        palette=[[0, 0, 0], [255, 255, 255], [255, 0, 0], [0, 255, 0]]
    )

    def __init__(self,
                 img_suffix='.npz',
                 seg_map_suffix='.npz',
                 reduce_zero_label=False,
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)

    def load_data_list(self) -> List[dict]:
        """Load annotation from directory or annotation file.

        Returns:
            list[dict]: All data info of dataset.
        """
        data_list = []
        if not osp.isdir(self.ann_file) and self.ann_file:
            assert osp.isfile(self.ann_file), \
                f'Failed to load `ann_file` {self.ann_file}'
            with open(self.ann_file) as f:
                lines = json.load(f)
            for line in lines:
                data_info = dict(
                    img_path=osp.join(self.data_root, line['s2_path']))
                if 'target_path' in line.keys():
                    data_info['seg_map_path'] = osp.join(
                        self.data_root, line['target_path'])
                data_info['label_map'] = self.label_map
                data_info['reduce_zero_label'] = self.reduce_zero_label
                data_info['seg_fields'] = []
                data_list.append(data_info)
        data_list = sorted(data_list, key=lambda x: x['img_path'])
        return data_list
54
finetune/mmseg/datasets/cabuar.py
Normal file
@@ -0,0 +1,54 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
from typing import List

from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset


@DATASETS.register_module()
class CABURADataset(BaseSegDataset):
    """Burned-area segmentation dataset (CaBuAr) with two classes
    (Background, Burned area); samples are stored as ``.npz`` files."""
    METAINFO = dict(
        classes=('Background', 'Burned area'),
        palette=[[0, 0, 0], [255, 255, 255]]
    )

    def __init__(self,
                 img_suffix='.npz',
                 seg_map_suffix='.npz',
                 reduce_zero_label=False,
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)

    def load_data_list(self) -> List[dict]:
        """Load annotation from directory or annotation file.

        Returns:
            list[dict]: All data info of dataset.
        """
        data_list = []
        if not osp.isdir(self.ann_file) and self.ann_file:
            assert osp.isfile(self.ann_file), \
                f'Failed to load `ann_file` {self.ann_file}'
            with open(self.ann_file) as f:
                lines = json.load(f)
            for line in lines:
                data_info = dict(
                    img_path=osp.join(self.data_root, line['post_fire']))
                if 'mask' in line.keys():
                    data_info['seg_map_path'] = osp.join(
                        self.data_root, line['mask'])
                data_info['label_map'] = self.label_map
                data_info['reduce_zero_label'] = self.reduce_zero_label
                data_info['seg_fields'] = []
                data_list.append(data_info)
        data_list = sorted(data_list, key=lambda x: x['img_path'])
        return data_list
72
finetune/mmseg/datasets/germany.py
Normal file
@@ -0,0 +1,72 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
from typing import List

from mmengine.logging import print_log

from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset


@DATASETS.register_module()
class GermanyCropDataset(BaseSegDataset):
    """Crop-type segmentation dataset for Germany with 17 crop classes;
    the raw label 0 ('unknown') is marked as ignored via
    ``reduce_zero_label=True``. Samples are stored as ``.pickle`` files.
    """
    # Raw label ids: {0: "unknown", 1: "sugar_beet", 2: "summer_oat",
    # 3: "meadow", 5: "rape", 8: "hop", 9: "winter_spelt",
    # 12: "winter_triticale", 13: "beans", 15: "peas", 16: "potatoes",
    # 17: "soybeans", 19: "asparagus", 22: "winter_wheat",
    # 23: "winter_barley", 24: "winter_rye", 25: "summer_barley",
    # 26: "maize"}
    METAINFO = dict(
        classes=('sugar_beet', 'summer_oat', 'meadow', 'rape', 'hop',
                 'winter_spelt', 'winter_triticale', 'beans', 'peas',
                 'potatoes', 'soybeans', 'asparagus', 'winter_wheat',
                 'winter_barley', 'winter_rye', 'summer_barley', 'maize'),
        palette=[(255, 255, 255), (255, 255, 170), (255, 255, 85),
                 (255, 170, 255), (255, 170, 170), (255, 170, 85),
                 (255, 85, 255), (255, 85, 170), (255, 85, 85),
                 (170, 255, 255), (170, 255, 170), (170, 255, 85),
                 (170, 170, 255), (170, 170, 170), (170, 170, 85),
                 (170, 85, 255), (170, 85, 170)])

    def __init__(self,
                 img_suffix='.pickle',
                 seg_map_suffix='.pickle',
                 reduce_zero_label=True,
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)

    def load_data_list(self) -> List[dict]:
        """Load annotation from directory or annotation file.

        Returns:
            list[dict]: All data info of dataset.
        """
        data_list = []
        if not osp.isdir(self.ann_file) and self.ann_file:
            assert osp.isfile(self.ann_file), \
                f'Failed to load `ann_file` {self.ann_file}'
            with open(self.ann_file) as f:
                lines = json.load(f)
            print_log(f'dataset count: {len(lines)}')
            for line in lines:
                data_info = dict(
                    img_path=osp.join(self.data_root, line['s2_path']))
                if 'target_path' in line.keys():
                    data_info['seg_map_path'] = osp.join(
                        self.data_root, line['target_path'])
                data_info['label_map'] = self.label_map
                data_info['reduce_zero_label'] = self.reduce_zero_label
                data_info['seg_fields'] = []
                data_list.append(data_info)
        data_list = sorted(data_list, key=lambda x: x['img_path'])
        return data_list
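GermanyCropDataset sets reduce_zero_label=True, so the raw 'unknown' label 0 is folded into the ignore index when masks are loaded; the effect, as mmseg's LoadAnnotations implements it, is roughly the following (a sketch, not a call into the actual transform):

import numpy as np

gt = np.array([[0, 1, 26], [22, 0, 5]], dtype=np.uint8)  # raw crop ids
# reduce_zero_label: 0 becomes 255 (ignored), every other id shifts down by 1.
gt[gt == 0] = 255
gt = gt - 1
gt[gt == 254] = 255
# gt is now [[255, 0, 25], [21, 255, 4]]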
66
finetune/mmseg/datasets/sos.py
Normal file
@@ -0,0 +1,66 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
from typing import List

from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset


@DATASETS.register_module()
class SOSDataset(BaseSegDataset):
    """SOS oil-spill dataset.

    Binary segmentation of oil-spill areas. 0 is the background class, so
    ``reduce_zero_label`` is kept False. The ``img_suffix`` is fixed to
    '.jpg' and ``seg_map_suffix`` to '.png'.
    """
    METAINFO = dict(
        classes=('Background', 'Oil Spill Area'),
        palette=[[0, 0, 0], [255, 255, 255]])

    def __init__(self,
                 img_suffix='.jpg',
                 seg_map_suffix='.png',
                 reduce_zero_label=False,
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)

    def load_data_list(self) -> List[dict]:
        """Load annotation from directory or annotation file.

        Returns:
            list[dict]: All data info of dataset.
        """
        data_list = []
        if not osp.isdir(self.ann_file) and self.ann_file:
            assert osp.isfile(self.ann_file), \
                f'Failed to load `ann_file` {self.ann_file}'
            # Each JSON record points at an image ('s1_path') and,
            # optionally, its target mask ('target_path').
            with open(self.ann_file) as f:
                lines = json.load(f)
            for line in lines:
                data_info = dict(
                    img_path=osp.join(self.data_root, line['s1_path']))
                if 'target_path' in line:
                    data_info['seg_map_path'] = osp.join(
                        self.data_root, line['target_path'])
                data_info['label_map'] = self.label_map
                data_info['reduce_zero_label'] = self.reduce_zero_label
                data_info['seg_fields'] = []
                data_list.append(data_info)
        data_list = sorted(data_list, key=lambda x: x['img_path'])
        return data_list
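# --- Expected annotation-file layout (reconstructed from the loader above;
# the file names are hypothetical) ---
#   [
#     {"s1_path": "sos/train/img_0001.jpg",
#      "target_path": "sos/train/ann_0001.png"},
#     {"s1_path": "sos/test/img_0042.jpg"}   # no mask: inference only
#   ]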
32
finetune/mmseg/datasets/transforms/__init__.py
Normal file
@@ -0,0 +1,32 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .formatting import PackSegInputs
from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
                      LoadBiomedicalData, LoadBiomedicalImageFromFile,
                      LoadDepthAnnotation, LoadImageFromNDArray,
                      LoadMultipleRSImageFromFile, LoadSingleRSImageFromFile)
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad,
                         BioMedical3DRandomCrop, BioMedical3DRandomFlip,
                         BioMedicalGaussianBlur, BioMedicalGaussianNoise,
                         BioMedicalRandomGamma, ConcatCDInput, GenerateEdge,
                         PhotoMetricDistortion, RandomCrop, RandomCutOut,
                         RandomDepthMix, RandomFlip, RandomMosaic,
                         RandomRotate, RandomRotFlip, Rerange, Resize,
                         ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
                         SegRescale)
from .loading_npz import (LoadAnnotationsNpz, LoadImageFromNpz,
                          LoadTsImageFromNpz, LoadAnnotationsOil,
                          LoadImageOil, LoadImageSingleChannel)

# yapf: enable
__all__ = [
    'LoadAnnotations', 'RandomCrop', 'BioMedical3DRandomCrop', 'SegRescale',
    'PhotoMetricDistortion', 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange',
    'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs',
    'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
    'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
    'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
    'BioMedical3DRandomFlip', 'BioMedicalRandomGamma', 'BioMedical3DPad',
    'RandomRotFlip', 'Albu', 'LoadSingleRSImageFromFile', 'ConcatCDInput',
    'LoadMultipleRSImageFromFile', 'LoadDepthAnnotation', 'RandomDepthMix',
    'RandomFlip', 'Resize', 'LoadAnnotationsNpz', 'LoadImageFromNpz',
    'LoadTsImageFromNpz', 'LoadAnnotationsOil', 'LoadImageOil',
    'LoadImageSingleChannel'
]
112
finetune/mmseg/datasets/transforms/formatting.py
Normal file
@@ -0,0 +1,112 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import numpy as np
from mmcv.transforms import to_tensor
from mmcv.transforms.base import BaseTransform
from mmengine.structures import PixelData

from mmseg.registry import TRANSFORMS
from mmseg.structures import SegDataSample


@TRANSFORMS.register_module()
class PackSegInputs(BaseTransform):
    """Pack the inputs data for semantic segmentation.

    The ``img_meta`` item is always populated. The contents of the
    ``img_meta`` dictionary depend on ``meta_keys``. By default this
    includes:

    - ``img_path``: filename of the image

    - ``ori_shape``: original shape of the image as a tuple (h, w, c)

    - ``img_shape``: shape of the image input to the network as a tuple
      (h, w, c). Note that images may be zero padded on the
      bottom/right if the batch tensor is larger than this shape.

    - ``pad_shape``: shape of padded images

    - ``scale_factor``: a float indicating the preprocessing scale

    - ``flip``: a boolean indicating if image flip transform was used

    - ``flip_direction``: the flipping direction

    Args:
        meta_keys (Sequence[str], optional): Meta keys to be packed from
            ``SegDataSample`` and collected in ``data[img_metas]``.
            Default: ``('img_path', 'seg_map_path', 'ori_shape',
            'img_shape', 'pad_shape', 'scale_factor', 'flip',
            'flip_direction', 'reduce_zero_label')``
    """

    def __init__(self,
                 meta_keys=('img_path', 'seg_map_path', 'ori_shape',
                            'img_shape', 'pad_shape', 'scale_factor', 'flip',
                            'flip_direction', 'reduce_zero_label')):
        self.meta_keys = meta_keys

    def transform(self, results: dict) -> dict:
        """Method to pack the input data.

        Args:
            results (dict): Result dict from the data pipeline.

        Returns:
            dict:

            - 'inputs' (obj:`torch.Tensor`): The forward data of models.
            - 'data_samples' (obj:`SegDataSample`): The annotation info of
              the sample.
        """
        packed_results = dict()
        if 'img' in results:
            img = results['img']
            if len(img.shape) < 3:
                img = np.expand_dims(img, -1)
            if not img.flags.c_contiguous:
                # Make a contiguous CHW copy before converting to tensor.
                img = to_tensor(np.ascontiguousarray(img.transpose(2, 0, 1)))
            else:
                img = img.transpose(2, 0, 1)
                img = to_tensor(img).contiguous()
            packed_results['inputs'] = img

        data_sample = SegDataSample()
        if 'gt_seg_map' in results:
            if len(results['gt_seg_map'].shape) == 2:
                data = to_tensor(results['gt_seg_map'][None,
                                                       ...].astype(np.int64))
            else:
                warnings.warn('Please pay attention to your ground truth '
                              'segmentation map, usually the segmentation '
                              'map is 2D, but got '
                              f'{results["gt_seg_map"].shape}')
                data = to_tensor(results['gt_seg_map'].astype(np.int64))
            gt_sem_seg_data = dict(data=data)
            data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)

        if 'gt_edge_map' in results:
            gt_edge_data = dict(
                data=to_tensor(results['gt_edge_map'][None,
                                                      ...].astype(np.int64)))
            data_sample.set_data(dict(gt_edge_map=PixelData(**gt_edge_data)))

        if 'gt_depth_map' in results:
            gt_depth_data = dict(
                data=to_tensor(results['gt_depth_map'][None, ...]))
            data_sample.set_data(dict(gt_depth_map=PixelData(**gt_depth_data)))

        img_meta = {}
        for key in self.meta_keys:
            if key in results:
                img_meta[key] = results[key]
        data_sample.set_metainfo(img_meta)
        packed_results['data_samples'] = data_sample

        return packed_results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(meta_keys={self.meta_keys})'
        return repr_str
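# --- Minimal sketch of PackSegInputs on synthetic data (illustrative) ---
#   import numpy as np
#   results = dict(
#       img=np.zeros((64, 64, 3), dtype=np.uint8),
#       gt_seg_map=np.zeros((64, 64), dtype=np.uint8),
#       img_path='demo.png', ori_shape=(64, 64), img_shape=(64, 64),
#       pad_shape=(64, 64), scale_factor=1.0, flip=False,
#       flip_direction=None, reduce_zero_label=False)
#   packed = PackSegInputs().transform(results)
#   packed['inputs'].shape                        # torch.Size([3, 64, 64])
#   packed['data_samples'].gt_sem_seg.data.shape  # torch.Size([1, 64, 64])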
771
finetune/mmseg/datasets/transforms/loading.py
Normal file
@@ -0,0 +1,771 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from pathlib import Path
from typing import Dict, Optional, Union

import mmcv
import mmengine.fileio as fileio
import numpy as np
from mmcv.transforms import BaseTransform
from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations
from mmcv.transforms import LoadImageFromFile

from mmseg.registry import TRANSFORMS
from mmseg.utils import datafrombytes

try:
    from osgeo import gdal
except ImportError:
    gdal = None


@TRANSFORMS.register_module()
class LoadAnnotations(MMCV_LoadAnnotations):
    """Load annotations for semantic segmentation provided by dataset.

    The annotation format is as the following:

    .. code-block:: python

        {
            # Filename of semantic segmentation ground truth file.
            'seg_map_path': 'a/b/c'
        }

    After this module, the annotation has been changed to the format below:

    .. code-block:: python

        {
            # in str
            'seg_fields': List
            # In uint8 type.
            'gt_seg_map': np.ndarray (H, W)
        }

    Required Keys:

    - seg_map_path (str): Path of semantic segmentation ground truth file.

    Added Keys:

    - seg_fields (List)
    - gt_seg_map (np.uint8)

    Args:
        reduce_zero_label (bool, optional): Whether to reduce all label
            values by 1. Usually used for datasets where 0 is the
            background label. Defaults to None.
        imdecode_backend (str): The image decoding backend type. The backend
            argument for :func:`mmcv.imfrombytes`.
            See :func:`mmcv.imfrombytes` for details.
            Defaults to 'pillow'.
        backend_args (dict): Arguments to instantiate a file backend.
            See https://mmengine.readthedocs.io/en/latest/api/fileio.html
            for details. Defaults to None.
            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
    """

    def __init__(
        self,
        reduce_zero_label=None,
        backend_args=None,
        imdecode_backend='pillow',
    ) -> None:
        super().__init__(
            with_bbox=False,
            with_label=False,
            with_seg=True,
            with_keypoints=False,
            imdecode_backend=imdecode_backend,
            backend_args=backend_args)
        self.reduce_zero_label = reduce_zero_label
        if self.reduce_zero_label is not None:
            warnings.warn('`reduce_zero_label` will be deprecated, '
                          'if you would like to ignore the zero label, '
                          'please set `reduce_zero_label=True` when the '
                          'dataset is initialized')
        self.imdecode_backend = imdecode_backend

    def _load_seg_map(self, results: dict) -> None:
        """Private function to load semantic segmentation annotations.

        Args:
            results (dict): Result dict from :obj:`mmcv.BaseDataset`.

        Returns:
            dict: The dict contains loaded semantic segmentation annotations.
        """

        img_bytes = fileio.get(
            results['seg_map_path'], backend_args=self.backend_args)
        gt_semantic_seg = mmcv.imfrombytes(
            img_bytes, flag='unchanged',
            backend=self.imdecode_backend).squeeze().astype(np.uint8)

        # reduce zero_label
        if self.reduce_zero_label is None:
            self.reduce_zero_label = results['reduce_zero_label']
        assert self.reduce_zero_label == results['reduce_zero_label'], \
            'Initialize dataset with `reduce_zero_label` as ' \
            f'{results["reduce_zero_label"]} but when load annotation ' \
            f'the `reduce_zero_label` is {self.reduce_zero_label}'
        if self.reduce_zero_label:
            # avoid using underflow conversion
            gt_semantic_seg[gt_semantic_seg == 0] = 255
            gt_semantic_seg = gt_semantic_seg - 1
            gt_semantic_seg[gt_semantic_seg == 254] = 255
        # modify if custom classes
        if results.get('label_map', None) is not None:
            # Add deep copy to solve bug of repeatedly
            # replace `gt_semantic_seg`, which is reported in
            # https://github.com/open-mmlab/mmsegmentation/pull/1445/
            gt_semantic_seg_copy = gt_semantic_seg.copy()
            for old_id, new_id in results['label_map'].items():
                gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
        results['gt_seg_map'] = gt_semantic_seg
        results['seg_fields'].append('gt_seg_map')

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(reduce_zero_label={self.reduce_zero_label}, '
        repr_str += f"imdecode_backend='{self.imdecode_backend}', "
        repr_str += f'backend_args={self.backend_args})'
        return repr_str
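# --- Effect of reduce_zero_label in _load_seg_map (illustrative) ---
#   raw labels:  [0,   1, 2, 255]
#   remapped:    [255, 0, 1, 255]
# Class 0 becomes the ignore index 255 and every remaining class shifts
# down by one; a raw 255 (which the shift turns into 254) is restored to
# 255 by the final fix-up line.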
@TRANSFORMS.register_module()
class LoadImageFromNDArray(LoadImageFromFile):
    """Load an image from ``results['img']``.

    Similar to :obj:`LoadImageFromFile`, but the image has already been
    loaded as a :obj:`np.ndarray` in ``results['img']``. Can be used when
    loading images from a webcam.

    Required Keys:

    - img

    Modified Keys:

    - img
    - img_path
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is an uint8
            array. Defaults to False.
    """

    def transform(self, results: dict) -> dict:
        """Transform function to add image meta information.

        Args:
            results (dict): Result dict with webcam-read image in
                ``results['img']``.

        Returns:
            dict: The dict contains loaded image and meta information.
        """

        img = results['img']
        if self.to_float32:
            img = img.astype(np.float32)

        results['img_path'] = None
        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results
@TRANSFORMS.register_module()
class LoadBiomedicalImageFromFile(BaseTransform):
    """Load a biomedical image from file.

    Required Keys:

    - img_path

    Added Keys:

    - img (np.ndarray): Biomedical image with shape (N, Z, Y, X) by default,
      N is the number of modalities, and data type is float32
      if set to_float32 = True, or float64 if decode_backend is 'nifti' and
      to_float32 is False.
    - img_shape
    - ori_shape

    Args:
        decode_backend (str): The data decoding backend type. Options are
            'numpy' and 'nifti', and there is a convention that when the
            backend is 'nifti' the axis order of the loaded data is XYZ,
            and when the backend is 'numpy', the axis order is ZYX. The
            data will be transposed if the backend is 'nifti'.
            Defaults to 'nifti'.
        to_xyz (bool): Whether to transpose data from Z, Y, X to X, Y, Z.
            Defaults to False.
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is a float64
            array. Defaults to True.
        backend_args (dict, Optional): Arguments to instantiate a file
            backend. See https://mmengine.readthedocs.io/en/latest/api/fileio.html
            for details. Defaults to None.
            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
    """

    def __init__(self,
                 decode_backend: str = 'nifti',
                 to_xyz: bool = False,
                 to_float32: bool = True,
                 backend_args: Optional[dict] = None) -> None:
        self.decode_backend = decode_backend
        self.to_xyz = to_xyz
        self.to_float32 = to_float32
        self.backend_args = backend_args.copy() if backend_args else None

    def transform(self, results: Dict) -> Dict:
        """Functions to load image.

        Args:
            results (dict): Result dict from :obj:`mmcv.BaseDataset`.

        Returns:
            dict: The dict contains loaded image and meta information.
        """

        filename = results['img_path']

        data_bytes = fileio.get(filename, self.backend_args)
        img = datafrombytes(data_bytes, backend=self.decode_backend)

        if self.to_float32:
            img = img.astype(np.float32)

        if len(img.shape) == 3:
            img = img[None, ...]

        if self.decode_backend == 'nifti':
            img = img.transpose(0, 3, 2, 1)

        if self.to_xyz:
            img = img.transpose(0, 3, 2, 1)

        results['img'] = img
        results['img_shape'] = img.shape[1:]
        results['ori_shape'] = img.shape[1:]
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f"decode_backend='{self.decode_backend}', "
                    f'to_xyz={self.to_xyz}, '
                    f'to_float32={self.to_float32}, '
                    f'backend_args={self.backend_args})')
        return repr_str
@TRANSFORMS.register_module()
class LoadBiomedicalAnnotation(BaseTransform):
    """Load ``seg_map`` annotation provided by biomedical dataset.

    The annotation format is as the following:

    .. code-block:: python

        {
            'gt_seg_map': np.ndarray (X, Y, Z) or (Z, Y, X)
        }

    Required Keys:

    - seg_map_path

    Added Keys:

    - gt_seg_map (np.ndarray): Biomedical seg map with shape (Z, Y, X) by
      default, and data type is float32 if set to_float32 = True, or
      float64 if decode_backend is 'nifti' and to_float32 is False.

    Args:
        decode_backend (str): The data decoding backend type. Options are
            'numpy' and 'nifti', and there is a convention that when the
            backend is 'nifti' the axis order of the loaded data is XYZ,
            and when the backend is 'numpy', the axis order is ZYX. The
            data will be transposed if the backend is 'nifti'.
            Defaults to 'nifti'.
        to_xyz (bool): Whether to transpose data from Z, Y, X to X, Y, Z.
            Defaults to False.
        to_float32 (bool): Whether to convert the loaded seg map to a
            float32 numpy array. If set to False, the loaded image is a
            float64 array. Defaults to True.
        backend_args (dict, Optional): Arguments to instantiate a file
            backend. See :class:`mmengine.fileio` for details.
            Defaults to None.
            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
    """

    def __init__(self,
                 decode_backend: str = 'nifti',
                 to_xyz: bool = False,
                 to_float32: bool = True,
                 backend_args: Optional[dict] = None) -> None:
        super().__init__()
        self.decode_backend = decode_backend
        self.to_xyz = to_xyz
        self.to_float32 = to_float32
        self.backend_args = backend_args.copy() if backend_args else None

    def transform(self, results: Dict) -> Dict:
        """Functions to load annotation.

        Args:
            results (dict): Result dict from :obj:`mmcv.BaseDataset`.

        Returns:
            dict: The dict contains the loaded annotation.
        """
        data_bytes = fileio.get(results['seg_map_path'], self.backend_args)
        gt_seg_map = datafrombytes(data_bytes, backend=self.decode_backend)

        if self.to_float32:
            gt_seg_map = gt_seg_map.astype(np.float32)

        if self.decode_backend == 'nifti':
            gt_seg_map = gt_seg_map.transpose(2, 1, 0)

        if self.to_xyz:
            gt_seg_map = gt_seg_map.transpose(2, 1, 0)

        results['gt_seg_map'] = gt_seg_map
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f"decode_backend='{self.decode_backend}', "
                    f'to_xyz={self.to_xyz}, '
                    f'to_float32={self.to_float32}, '
                    f'backend_args={self.backend_args})')
        return repr_str
@TRANSFORMS.register_module()
class LoadBiomedicalData(BaseTransform):
    """Load a biomedical image and annotation from file.

    The loading data format is as the following:

    .. code-block:: python

        {
            'img': np.ndarray data[:-1, X, Y, Z]
            'seg_map': np.ndarray data[-1, X, Y, Z]
        }

    Required Keys:

    - img_path

    Added Keys:

    - img (np.ndarray): Biomedical image with shape (N, Z, Y, X) by default,
      N is the number of modalities.
    - gt_seg_map (np.ndarray, optional): Biomedical seg map with shape
      (Z, Y, X) by default.
    - img_shape
    - ori_shape

    Args:
        with_seg (bool): Whether to parse and load the semantic segmentation
            annotation. Defaults to False.
        decode_backend (str): The data decoding backend type. Options are
            'numpy' and 'nifti', and there is a convention that when the
            backend is 'nifti' the axis order of the loaded data is XYZ,
            and when the backend is 'numpy', the axis order is ZYX. The
            data will be transposed if the backend is 'nifti'.
            Defaults to 'nifti'.
        to_xyz (bool): Whether to transpose data from Z, Y, X to X, Y, Z.
            Defaults to False.
        backend_args (dict, Optional): Arguments to instantiate a file
            backend. See https://mmengine.readthedocs.io/en/latest/api/fileio.html
            for details. Defaults to None.
            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
    """

    def __init__(self,
                 with_seg=False,
                 decode_backend: str = 'numpy',
                 to_xyz: bool = False,
                 backend_args: Optional[dict] = None) -> None:  # noqa
        self.with_seg = with_seg
        self.decode_backend = decode_backend
        self.to_xyz = to_xyz
        self.backend_args = backend_args.copy() if backend_args else None

    def transform(self, results: Dict) -> Dict:
        """Functions to load image and annotation.

        Args:
            results (dict): Result dict from :obj:`mmcv.BaseDataset`.

        Returns:
            dict: The dict contains loaded image and meta information.
        """
        data_bytes = fileio.get(results['img_path'], self.backend_args)
        data = datafrombytes(data_bytes, backend=self.decode_backend)
        # data is 4D (N, X, Y, Z); the first N-1 slices are the modalities.
        img = data[:-1, :]

        if self.decode_backend == 'nifti':
            img = img.transpose(0, 3, 2, 1)

        if self.to_xyz:
            img = img.transpose(0, 3, 2, 1)

        results['img'] = img
        results['img_shape'] = img.shape[1:]
        results['ori_shape'] = img.shape[1:]

        if self.with_seg:
            # The last slice holds the segmentation map.
            gt_seg_map = data[-1, :]
            if self.decode_backend == 'nifti':
                gt_seg_map = gt_seg_map.transpose(2, 1, 0)

            if self.to_xyz:
                gt_seg_map = gt_seg_map.transpose(2, 1, 0)
            results['gt_seg_map'] = gt_seg_map
        return results

    def __repr__(self) -> str:
        repr_str = (f'{self.__class__.__name__}('
                    f'with_seg={self.with_seg}, '
                    f"decode_backend='{self.decode_backend}', "
                    f'to_xyz={self.to_xyz}, '
                    f'backend_args={self.backend_args})')
        return repr_str
@TRANSFORMS.register_module()
class InferencerLoader(BaseTransform):
    """Load an image for inference.

    Similar to :obj:`LoadImageFromFile`, but accepts either a file path, a
    :obj:`np.ndarray`, or an already-built results dict, dispatching to the
    appropriate underlying loader.

    Required Keys:

    - img

    Modified Keys:

    - img
    - img_path
    - img_shape
    - ori_shape

    Args:
        **kwargs: Keyword arguments forwarded to both underlying loaders
            (:obj:`LoadImageFromFile` and :obj:`LoadImageFromNDArray`),
            e.g. ``to_float32``.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__()
        self.from_file = TRANSFORMS.build(
            dict(type='LoadImageFromFile', **kwargs))
        self.from_ndarray = TRANSFORMS.build(
            dict(type='LoadImageFromNDArray', **kwargs))

    def transform(self, single_input: Union[str, np.ndarray, dict]) -> dict:
        """Transform function to add image meta information.

        Args:
            single_input (str | np.ndarray | dict): A file path, an image
                array, or a results dict containing either ``img_path``
                or ``img``.

        Returns:
            dict: The dict contains loaded image and meta information.
        """
        if isinstance(single_input, str):
            inputs = dict(img_path=single_input)
        elif isinstance(single_input, np.ndarray):
            inputs = dict(img=single_input)
        elif isinstance(single_input, dict):
            inputs = single_input
        else:
            raise NotImplementedError

        if 'img' in inputs:
            return self.from_ndarray(inputs)
        return self.from_file(inputs)
@TRANSFORMS.register_module()
class LoadSingleRSImageFromFile(BaseTransform):
    """Load a remote sensing image from file.

    Required Keys:

    - img_path

    Modified Keys:

    - img
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is a float64
            array. Defaults to True.
    """

    def __init__(self, to_float32: bool = True):
        self.to_float32 = to_float32

        if gdal is None:
            raise RuntimeError('gdal is not installed')

    def transform(self, results: Dict) -> Dict:
        """Functions to load image.

        Args:
            results (dict): Result dict from :obj:`mmcv.BaseDataset`.

        Returns:
            dict: The dict contains loaded image and meta information.
        """

        filename = results['img_path']
        ds = gdal.Open(filename)
        if ds is None:
            raise Exception(f'Unable to open file: {filename}')
        # (bands, h, w) -> (h, w, bands)
        img = np.einsum('ijk->jki', ds.ReadAsArray())

        if self.to_float32:
            img = img.astype(np.float32)

        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'to_float32={self.to_float32})')
        return repr_str
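# --- Band-axis convention (note on the einsum above) ---
# gdal's ReadAsArray() yields (bands, H, W); 'ijk->jki' moves bands last,
# giving the (H, W, bands) layout the rest of the pipeline expects. A
# hypothetical pipeline using this loader could look like:
#   train_pipeline = [
#       dict(type='LoadSingleRSImageFromFile'),
#       dict(type='LoadAnnotations'),
#       dict(type='PackSegInputs')]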
@TRANSFORMS.register_module()
class LoadMultipleRSImageFromFile(BaseTransform):
    """Load two remote sensing images from file.

    Required Keys:

    - img_path
    - img_path2

    Modified Keys:

    - img
    - img2
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to convert the loaded images to float32
            numpy arrays. If set to False, the loaded images are float64
            arrays. Defaults to True.
    """

    def __init__(self, to_float32: bool = True):
        if gdal is None:
            raise RuntimeError('gdal is not installed')
        self.to_float32 = to_float32

    def transform(self, results: Dict) -> Dict:
        """Functions to load images.

        Args:
            results (dict): Result dict from :obj:`mmcv.BaseDataset`.

        Returns:
            dict: The dict contains loaded images and meta information.
        """

        filename = results['img_path']
        filename2 = results['img_path2']

        ds = gdal.Open(filename)
        ds2 = gdal.Open(filename2)

        if ds is None:
            raise Exception(f'Unable to open file: {filename}')
        if ds2 is None:
            raise Exception(f'Unable to open file: {filename2}')

        # (bands, h, w) -> (h, w, bands) for both images.
        img = np.einsum('ijk->jki', ds.ReadAsArray())
        img2 = np.einsum('ijk->jki', ds2.ReadAsArray())

        if self.to_float32:
            img = img.astype(np.float32)
            img2 = img2.astype(np.float32)

        if img.shape != img2.shape:
            raise Exception(f'Image shapes do not match:'
                            f' {img.shape} vs {img2.shape}')

        results['img'] = img
        results['img2'] = img2
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'to_float32={self.to_float32})')
        return repr_str
@TRANSFORMS.register_module()
class LoadDepthAnnotation(BaseTransform):
    """Load ``depth_map`` annotation provided by depth estimation dataset.

    The annotation format is as the following:

    .. code-block:: python

        {
            'gt_depth_map': np.ndarray [Y, X]
        }

    Required Keys:

    - depth_map_path

    Added Keys:

    - gt_depth_map (np.ndarray): Depth map with shape (Y, X) by
      default, and data type is float32 if set to_float32 = True.
    - depth_rescale_factor (float): The rescale factor of the depth map,
      which can be used to recover the original value of the depth map.

    Args:
        decode_backend (str): The data decoding backend type. Options are
            'numpy', 'nifti', and 'cv2'. Defaults to 'cv2'.
        to_float32 (bool): Whether to convert the loaded depth map to a
            float32 numpy array. If set to False, the loaded image is an
            uint16 array. Defaults to True.
        depth_rescale_factor (float): Factor to rescale the depth value to
            limit the range. Defaults to 1.0.
        backend_args (dict, Optional): Arguments to instantiate a file
            backend. See :class:`mmengine.fileio` for details.
            Defaults to None.
            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
    """

    def __init__(self,
                 decode_backend: str = 'cv2',
                 to_float32: bool = True,
                 depth_rescale_factor: float = 1.0,
                 backend_args: Optional[dict] = None) -> None:
        super().__init__()
        self.decode_backend = decode_backend
        self.to_float32 = to_float32
        self.depth_rescale_factor = depth_rescale_factor
        self.backend_args = backend_args.copy() if backend_args else None

    def transform(self, results: Dict) -> Dict:
        """Functions to load depth map.

        Args:
            results (dict): Result dict from :obj:`mmcv.BaseDataset`.

        Returns:
            dict: The dict contains the loaded depth map.
        """
        data_bytes = fileio.get(results['depth_map_path'], self.backend_args)
        gt_depth_map = datafrombytes(data_bytes, backend=self.decode_backend)

        if self.to_float32:
            gt_depth_map = gt_depth_map.astype(np.float32)

        gt_depth_map *= self.depth_rescale_factor
        results['gt_depth_map'] = gt_depth_map
        results['seg_fields'].append('gt_depth_map')
        results['depth_rescale_factor'] = self.depth_rescale_factor
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f"decode_backend='{self.decode_backend}', "
                    f'to_float32={self.to_float32}, '
                    f'backend_args={self.backend_args})')
        return repr_str
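# --- depth_rescale_factor example (illustrative) ---
# With depth_rescale_factor=1/256. a uint16 depth map in [0, 65535] is
# converted to float32 and scaled into [0, 256); the factor is stored in
# results['depth_rescale_factor'] so original values can be recovered
# downstream.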
@TRANSFORMS.register_module()
class LoadImageFromNpyFile(LoadImageFromFile):
    """Load an image from ``results['img_path']``.

    Supports regular image files as well as .npy/.npz arrays.

    Required Keys:

    - img_path

    Modified Keys:

    - img
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is an uint8
            array. Defaults to False.
    """

    def transform(self, results: dict) -> Optional[dict]:
        """Functions to load image.

        Args:
            results (dict): Result dict from
                :class:`mmengine.dataset.BaseDataset`.

        Returns:
            dict: The dict contains loaded image and meta information.
        """

        filename = results['img_path']

        try:
            suffix = Path(filename).suffix
            if suffix in ['.npy', '.npz']:
                data = np.load(filename)
                # ``np.load`` returns the array directly for .npy files but
                # an ``NpzFile`` archive for .npz, so unwrap the first entry.
                img = data if suffix == '.npy' else data[data.files[0]]
            else:
                if self.file_client_args is not None:
                    file_client = fileio.FileClient.infer_client(
                        self.file_client_args, filename)
                    img_bytes = file_client.get(filename)
                else:
                    img_bytes = fileio.get(
                        filename, backend_args=self.backend_args)
                img = mmcv.imfrombytes(
                    img_bytes,
                    flag=self.color_type,
                    backend=self.imdecode_backend)
        except Exception as e:
            if self.ignore_empty:
                return None
            else:
                raise e

        # in some cases, images are not read successfully, the img would be
        # `None`, refer to https://github.com/open-mmlab/mmpretrain/issues/1427
        assert img is not None, f'failed to load image: {filename}'
        if self.to_float32:
            img = img.astype(np.float32)

        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results
493
finetune/mmseg/datasets/transforms/loading_npz.py
Normal file
@@ -0,0 +1,493 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Dict

import imageio
import numpy as np
from mmcv.transforms import BaseTransform
from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations

from mmseg.registry import TRANSFORMS

try:
    from osgeo import gdal
except ImportError:
    gdal = None


@TRANSFORMS.register_module()
class LoadAnnotationsNpz(MMCV_LoadAnnotations):
    """Load semantic segmentation annotations stored as .npz archives.

    The annotation format is as the following:

    .. code-block:: python

        {
            # Filename of semantic segmentation ground truth file.
            'seg_map_path': 'a/b/c.npz'
        }

    After this module, the annotation has been changed to the format below:

    .. code-block:: python

        {
            # in str
            'seg_fields': List
            # In uint8 type.
            'gt_seg_map': np.ndarray (H, W)
        }

    Required Keys:

    - seg_map_path (str): Path of semantic segmentation ground truth file.

    Added Keys:

    - seg_fields (List)
    - gt_seg_map (np.uint8)

    Args:
        reduce_zero_label (bool, optional): Whether to reduce all label
            values by 1. Usually used for datasets where 0 is the
            background label. Defaults to None.
        data_key (str): Key of the segmentation array inside the .npz
            archive. Defaults to 'arr_0'.
        imdecode_backend (str): Kept for interface compatibility with
            :class:`LoadAnnotations`; the mask itself is read with
            ``np.load``. Defaults to 'pillow'.
        backend_args (dict): Arguments to instantiate a file backend.
            See https://mmengine.readthedocs.io/en/latest/api/fileio.html
            for details. Defaults to None.
            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
    """

    def __init__(
        self,
        reduce_zero_label=None,
        backend_args=None,
        data_key='arr_0',
        imdecode_backend='pillow',
    ) -> None:
        super().__init__(
            with_bbox=False,
            with_label=False,
            with_seg=True,
            with_keypoints=False,
            imdecode_backend=imdecode_backend,
            backend_args=backend_args)
        self.reduce_zero_label = reduce_zero_label
        if self.reduce_zero_label is not None:
            warnings.warn('`reduce_zero_label` will be deprecated, '
                          'if you would like to ignore the zero label, '
                          'please set `reduce_zero_label=True` when the '
                          'dataset is initialized')
        self.imdecode_backend = imdecode_backend
        self.data_key = data_key

    def _load_seg_map(self, results: dict) -> None:
        """Private function to load semantic segmentation annotations.

        Args:
            results (dict): Result dict from :obj:`mmcv.BaseDataset`.

        Returns:
            dict: The dict contains loaded semantic segmentation annotations.
        """
        # The mask is stored inside an .npz archive under ``self.data_key``.
        gt_semantic_seg = np.load(
            results['seg_map_path'])[self.data_key].squeeze().astype(np.uint8)
        # reduce zero_label
        if self.reduce_zero_label is None:
            self.reduce_zero_label = results['reduce_zero_label']
        assert self.reduce_zero_label == results['reduce_zero_label'], \
            'Initialize dataset with `reduce_zero_label` as ' \
            f'{results["reduce_zero_label"]} but when load annotation ' \
            f'the `reduce_zero_label` is {self.reduce_zero_label}'
        if self.reduce_zero_label:
            # avoid using underflow conversion
            gt_semantic_seg[gt_semantic_seg == 0] = 255
            gt_semantic_seg = gt_semantic_seg - 1
            gt_semantic_seg[gt_semantic_seg == 254] = 255
        # modify if custom classes
        if results.get('label_map', None) is not None:
            # Add deep copy to solve bug of repeatedly
            # replace `gt_semantic_seg`, which is reported in
            # https://github.com/open-mmlab/mmsegmentation/pull/1445/
            gt_semantic_seg_copy = gt_semantic_seg.copy()
            for old_id, new_id in results['label_map'].items():
                gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
        results['gt_seg_map'] = gt_semantic_seg
        results['seg_fields'].append('gt_seg_map')

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(reduce_zero_label={self.reduce_zero_label}, '
        repr_str += f"data_key='{self.data_key}', "
        repr_str += f'backend_args={self.backend_args})'
        return repr_str
@TRANSFORMS.register_module()
class LoadImageSingleChannel(BaseTransform):
    """Load a single-channel image from file.

    The image is read with ``imageio`` and only the first channel is kept.

    Required Keys:

    - img_path

    Modified Keys:

    - img
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image keeps its
            original dtype. Defaults to True.
    """

    def __init__(self, to_float32: bool = True):
        self.to_float32 = to_float32

    def transform(self, results: Dict) -> Dict:
        """Functions to load image.

        Args:
            results (dict): Result dict from :obj:`mmcv.BaseDataset`.

        Returns:
            dict: The dict contains loaded image and meta information.
        """

        filename = results['img_path']
        img = imageio.imread(filename)  # (h, w, c)
        # Keep only the first channel.
        img = img[:, :, 0]

        if self.to_float32:
            img = img.astype(np.float32)

        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'to_float32={self.to_float32})')
        return repr_str
@TRANSFORMS.register_module()
class LoadAnnotationsOil(MMCV_LoadAnnotations):
    """Load oil-spill annotations and binarize them.

    The raw mask is read with GDAL; pixels whose value equals 3 are mapped
    to foreground class 1 and everything else to background 0.

    After this module, the annotation has been changed to the format below:

    .. code-block:: python

        {
            # in str
            'seg_fields': List
            # In uint8 type.
            'gt_seg_map': np.ndarray (H, W)
        }

    Required Keys:

    - seg_map_path (str): Path of semantic segmentation ground truth file.

    Added Keys:

    - seg_fields (List)
    - gt_seg_map (np.uint8)

    Args:
        reduce_zero_label (bool, optional): Whether to reduce all label
            values by 1. Usually used for datasets where 0 is the
            background label. Defaults to None.
        data_key (str): Kept for interface compatibility with
            :class:`LoadAnnotationsNpz`; unused here. Defaults to 'arr_0'.
        imdecode_backend (str): Kept for interface compatibility; the mask
            is read with GDAL. Defaults to 'pillow'.
        backend_args (dict): Arguments to instantiate a file backend.
            See https://mmengine.readthedocs.io/en/latest/api/fileio.html
            for details. Defaults to None.
            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
    """

    def __init__(
        self,
        reduce_zero_label=None,
        backend_args=None,
        data_key='arr_0',
        imdecode_backend='pillow',
    ) -> None:
        super().__init__(
            with_bbox=False,
            with_label=False,
            with_seg=True,
            with_keypoints=False,
            imdecode_backend=imdecode_backend,
            backend_args=backend_args)
        self.reduce_zero_label = reduce_zero_label
        if self.reduce_zero_label is not None:
            warnings.warn('`reduce_zero_label` will be deprecated, '
                          'if you would like to ignore the zero label, '
                          'please set `reduce_zero_label=True` when the '
                          'dataset is initialized')
        self.imdecode_backend = imdecode_backend
        self.data_key = data_key
        if gdal is None:
            # GDAL is needed in ``_load_seg_map``; fail early.
            raise RuntimeError('gdal is not installed')

    def _load_seg_map(self, results: dict) -> None:
        """Private function to load semantic segmentation annotations.

        Args:
            results (dict): Result dict from :obj:`mmcv.BaseDataset`.

        Returns:
            dict: The dict contains loaded semantic segmentation annotations.
        """
        seg_map = gdal.Open(results['seg_map_path']).ReadAsArray()
        # Binarize: raw value 3 marks the foreground class, everything
        # else is background.
        gt_semantic_seg = np.zeros_like(seg_map).astype(np.uint8)
        gt_semantic_seg[seg_map == 3.] = 1
        # reduce zero_label
        if self.reduce_zero_label is None:
            self.reduce_zero_label = results['reduce_zero_label']
        assert self.reduce_zero_label == results['reduce_zero_label'], \
            'Initialize dataset with `reduce_zero_label` as ' \
            f'{results["reduce_zero_label"]} but when load annotation ' \
            f'the `reduce_zero_label` is {self.reduce_zero_label}'
        if self.reduce_zero_label:
            # avoid using underflow conversion
            gt_semantic_seg[gt_semantic_seg == 0] = 255
            gt_semantic_seg = gt_semantic_seg - 1
            gt_semantic_seg[gt_semantic_seg == 254] = 255
        # modify if custom classes
        if results.get('label_map', None) is not None:
            # Add deep copy to solve bug of repeatedly
            # replace `gt_semantic_seg`, which is reported in
            # https://github.com/open-mmlab/mmsegmentation/pull/1445/
            gt_semantic_seg_copy = gt_semantic_seg.copy()
            for old_id, new_id in results['label_map'].items():
                gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
        results['gt_seg_map'] = gt_semantic_seg
        results['seg_fields'].append('gt_seg_map')

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(reduce_zero_label={self.reduce_zero_label}, '
        repr_str += f"imdecode_backend='{self.imdecode_backend}', "
        repr_str += f'backend_args={self.backend_args})'
        return repr_str
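# --- Binarization in LoadAnnotationsOil._load_seg_map (illustrative) ---
#   raw GDAL mask:  [0., 1., 3., 3.]
#   gt_seg_map:     [0,  0,  1,  1]
# Only pixels whose raw value equals 3.0 (assumed to mark the oil-spill
# class in the source product) become foreground.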
@TRANSFORMS.register_module()
class LoadImageFromNpz(BaseTransform):
    """Load a remote sensing image stored in an .npz archive.

    Required Keys:

    - img_path

    Modified Keys:

    - img
    - img_shape
    - ori_shape

    Args:
        data_key (str): Key of the image array inside the .npz archive.
            Defaults to 'arr_0'.
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is a float64
            array. Defaults to True.
    """

    def __init__(self, data_key='arr_0', to_float32: bool = True):
        self.to_float32 = to_float32
        self.data_key = data_key

    def transform(self, results: Dict) -> Dict:
        """Functions to load image.

        Args:
            results (dict): Result dict from :obj:`mmcv.BaseDataset`.

        Returns:
            dict: The dict contains loaded image and meta information.
        """

        filename = results['img_path']
        img = np.load(filename)[self.data_key]
        # (c, h, w) -> (h, w, c) to match the pipeline's HWC convention.
        img = np.einsum('ijk->jki', img)

        if self.to_float32:
            img = img.astype(np.float32)

        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'to_float32={self.to_float32})')
        return repr_str
@TRANSFORMS.register_module()
class LoadImageOil(BaseTransform):
    """Load a single-band image with GDAL for oil-spill segmentation.

    Required Keys:

    - img_path

    Modified Keys:

    - img
    - img_shape
    - ori_shape

    Args:
        data_key (str): Kept for interface compatibility; unused here.
            Defaults to 'arr_0'.
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is a float64
            array. Defaults to True.
    """

    def __init__(self, data_key='arr_0', to_float32: bool = True):
        self.to_float32 = to_float32
        self.data_key = data_key
        if gdal is None:
            raise RuntimeError('gdal is not installed')

    def transform(self, results: Dict) -> Dict:
        """Functions to load image.

        Args:
            results (dict): Result dict from :obj:`mmcv.BaseDataset`.

        Returns:
            dict: The dict contains loaded image and meta information.
        """

        filename = results['img_path']
        img = gdal.Open(filename).ReadAsArray()
        # The band is 2D (h, w); append a channel axis -> (h, w, 1).
        img = img[:, :, None]

        if self.to_float32:
            img = img.astype(np.float32)

        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'to_float32={self.to_float32})')
        return repr_str
@TRANSFORMS.register_module()
class LoadTsImageFromNpz(BaseTransform):
    """Load a remote sensing image time series from an .npz archive.

    The stored array has shape (ts, c, h, w). ``ts_size`` timestamps are
    sampled (without replacement when enough are available, with
    replacement otherwise) and folded into the channel axis, producing an
    (h, w, ts_size * c) image.

    Required Keys:

    - img_path

    Modified Keys:

    - img
    - img_shape
    - ori_shape

    Args:
        data_key (str): Key of the image array inside the .npz archive.
            Defaults to 'arr_0'.
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is a float64
            array. Defaults to True.
        ts_size (int): Number of timestamps to sample. Defaults to 10.
    """

    def __init__(self, data_key='arr_0', to_float32: bool = True,
                 ts_size: int = 10):
        self.to_float32 = to_float32
        self.data_key = data_key
        self.ts_size = ts_size

    def transform(self, results: Dict) -> Dict:
        """Functions to load image.

        Args:
            results (dict): Result dict from :obj:`mmcv.BaseDataset`.

        Returns:
            dict: The dict contains loaded image and meta information.
        """

        filename = results['img_path']
        img = np.load(filename)[self.data_key]
        ts, c, h, w = img.shape
        if ts >= self.ts_size:
            selected_indices = np.random.choice(
                ts, size=self.ts_size, replace=False)
        else:
            selected_indices = np.random.choice(
                ts, size=self.ts_size, replace=True)
        # Keep the sampled timestamps in temporal order.
        selected_indices.sort()
        img = img[selected_indices, :, :, :]
        # (ts, c, h, w) -> (h, w, ts, c) -> (h, w, ts*c)
        img = img.transpose(2, 3, 0, 1).reshape(h, w, self.ts_size * c)

        if self.to_float32:
            img = img.astype(np.float32)

        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'to_float32={self.to_float32})')
        return repr_str
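# --- Shape walk-through for LoadTsImageFromNpz (illustrative numbers) ---
#   stored array:  (ts, c, h, w) = (12, 4, 256, 256), ts_size = 10
#   sampling:      10 timestamps, without replacement since ts >= ts_size
#   transpose:     (2, 3, 0, 1) -> (h, w, ts, c) = (256, 256, 10, 4)
#   reshape:       (h, w, ts*c) = (256, 256, 40)
# The time dimension is folded into the channel axis so downstream
# transforms can treat the sample as an ordinary multi-channel image.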
2537
finetune/mmseg/datasets/transforms/transforms.py
Normal file
File diff suppressed because it is too large
12
finetune/mmseg/engine/__init__.py
Normal file
@@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .hooks import SegVisualizationHook
from .optimizers import (ForceDefaultOptimWrapperConstructor,
                         LayerDecayOptimizerConstructor,
                         LearningRateDecayOptimizerConstructor)
from .schedulers import PolyLRRatio

__all__ = [
    'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor',
    'SegVisualizationHook', 'PolyLRRatio',
    'ForceDefaultOptimWrapperConstructor'
]
4
finetune/mmseg/engine/hooks/__init__.py
Normal file
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .visualization_hook import SegVisualizationHook

__all__ = ['SegVisualizationHook']
129
finetune/mmseg/engine/hooks/visualization_hook.py
Normal file
@@ -0,0 +1,129 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from typing import Optional, Sequence

import mmcv
from mmengine.fileio import get
from mmengine.hooks import Hook
from mmengine.runner import Runner
from mmengine.visualization import Visualizer

from mmseg.registry import HOOKS
from mmseg.structures import SegDataSample


@HOOKS.register_module()
class SegVisualizationHook(Hook):
    """Segmentation Visualization Hook. Used to visualize validation and
    testing process prediction results.

    In the testing phase:

    1. If ``show`` is True, it means that only the prediction results are
       visualized without storing data, so ``vis_backends`` needs to
       be excluded.

    Args:
        draw (bool): Whether to draw prediction results. If it is False,
            it means that no drawing will be done. Defaults to False.
        interval (int): The interval of visualization. Defaults to 50.
        show (bool): Whether to display the drawn image. Defaults to False.
        wait_time (float): The interval of show (s). Defaults to 0.
        backend_args (dict, Optional): Arguments to instantiate a file
            backend. See https://mmengine.readthedocs.io/en/latest/api/fileio.html
            for details. Defaults to None.
            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
    """

    def __init__(self,
                 draw: bool = False,
                 interval: int = 50,
                 show: bool = False,
                 wait_time: float = 0.,
                 backend_args: Optional[dict] = None):
        self._visualizer: Visualizer = Visualizer.get_current_instance()
        self.interval = interval
        self.show = show
        if self.show:
            # No need to think about vis backends.
            self._visualizer._vis_backends = {}
            warnings.warn('The show is True, it means that only '
                          'the prediction results are visualized '
                          'without storing data, so vis_backends '
                          'needs to be excluded.')

        self.wait_time = wait_time
        self.backend_args = backend_args.copy() if backend_args else None
        self.draw = draw
        if not self.draw:
            warnings.warn('The draw is False, it means that the '
                          'hook for visualization will not take '
                          'effect. The results will NOT be '
                          'visualized or stored.')
        self._test_index = 0

    def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
                       outputs: Sequence[SegDataSample]) -> None:
        """Run after every ``self.interval`` validation iterations.

        Args:
            runner (:obj:`Runner`): The runner of the validation process.
            batch_idx (int): The index of the current batch in the val loop.
            data_batch (dict): Data from dataloader.
            outputs (Sequence[:obj:`SegDataSample`]): A batch of data samples
                that contain annotations and predictions.
        """
        if self.draw is False:
            return

        # There is no guarantee that the same batch of images
        # is visualized for each evaluation.
        total_curr_iter = runner.iter + batch_idx

        # Visualize only the first data
        img_path = outputs[0].img_path
        img_bytes = get(img_path, backend_args=self.backend_args)
        img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
        window_name = f'val_{osp.basename(img_path)}'

        if total_curr_iter % self.interval == 0:
            self._visualizer.add_datasample(
                window_name,
                img,
                data_sample=outputs[0],
                show=self.show,
                wait_time=self.wait_time,
                step=total_curr_iter)

    def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
                        outputs: Sequence[SegDataSample]) -> None:
        """Run after every testing iteration.

        Args:
            runner (:obj:`Runner`): The runner of the testing process.
            batch_idx (int): The index of the current batch in the test loop.
            data_batch (dict): Data from dataloader.
            outputs (Sequence[:obj:`SegDataSample`]): A batch of data samples
                that contain annotations and predictions.
        """
        if self.draw is False:
            return

        for data_sample in outputs:
            self._test_index += 1

            img_path = data_sample.img_path
            window_name = f'test_{osp.basename(img_path)}'

            img_bytes = get(img_path, backend_args=self.backend_args)
            img = mmcv.imfrombytes(img_bytes, channel_order='rgb')

            self._visualizer.add_datasample(
                window_name,
                img,
                data_sample=data_sample,
                show=self.show,
                wait_time=self.wait_time,
                step=self._test_index)
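# --- Enabling the hook (illustrative config fragment) ---
# Drawing is off by default; a config would turn it on via:
#   default_hooks = dict(
#       visualization=dict(
#           type='SegVisualizationHook', draw=True, interval=50))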
9
finetune/mmseg/engine/optimizers/__init__.py
Normal file
@@ -0,0 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .force_default_constructor import ForceDefaultOptimWrapperConstructor
from .layer_decay_optimizer_constructor import (
    LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor)

__all__ = [
    'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor',
    'ForceDefaultOptimWrapperConstructor'
]
255
finetune/mmseg/engine/optimizers/force_default_constructor.py
Normal file
@@ -0,0 +1,255 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import List, Optional, Union

import torch
import torch.nn as nn
from mmengine.logging import print_log
from mmengine.optim import DefaultOptimWrapperConstructor
from mmengine.utils.dl_utils import mmcv_full_available
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm
from torch.nn import GroupNorm, LayerNorm

from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS


@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class ForceDefaultOptimWrapperConstructor(DefaultOptimWrapperConstructor):
    """Default constructor with forced optimizer settings.

    This constructor extends the default constructor to add an option for
    forcing default optimizer settings. This is useful for ensuring that
    certain parameters or layers strictly adhere to pre-defined default
    settings, regardless of any custom settings specified.

    By default, each parameter shares the same optimizer settings, and we
    provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
    It is a dict and may contain various fields like 'custom_keys',
    'bias_lr_mult', etc., as well as the additional field
    ``force_default_settings`` which allows for enforcing default settings on
    optimizer parameters.

    - ``custom_keys`` (dict): Specified parameters-wise settings by keys. If
      one of the keys in ``custom_keys`` is a substring of the name of one
      parameter, then the setting of the parameter will be specified by
      ``custom_keys[key]`` and other settings like ``bias_lr_mult`` etc. will
      be ignored. It should be noted that the aforementioned ``key`` is the
      longest key that is a substring of the name of the parameter. If there
      are multiple matched keys with the same length, then the key with the
      lower alphabetical order will be chosen.
      ``custom_keys[key]`` should be a dict and may contain fields ``lr_mult``
      and ``decay_mult``. See Example 2 below.
    - ``bias_lr_mult`` (float): It will be multiplied to the learning
      rate for all bias parameters (except for those in normalization
      layers and offset layers of DCN).
    - ``bias_decay_mult`` (float): It will be multiplied to the weight
      decay for all bias parameters (except for those in
      normalization layers, depthwise conv layers, offset layers of DCN).
    - ``norm_decay_mult`` (float): It will be multiplied to the weight
      decay for all weight and bias parameters of normalization
      layers.
    - ``flat_decay_mult`` (float): It will be multiplied to the weight
      decay for all one-dimensional parameters.
    - ``dwconv_decay_mult`` (float): It will be multiplied to the weight
      decay for all weight and bias parameters of depthwise conv
      layers.
    - ``dcn_offset_lr_mult`` (float): It will be multiplied to the learning
      rate for parameters of offset layer in the deformable convs
      of a model.
    - ``bypass_duplicate`` (bool): If true, the duplicate parameters
      would not be added into optimizer. Defaults to False.
    - ``force_default_settings`` (bool): If true, this will override any
      custom settings defined by ``custom_keys`` and enforce the use of
      default settings for optimizer parameters like ``bias_lr_mult``.
      This is particularly useful when you want to ensure that certain layers
      or parameters adhere strictly to the pre-defined default settings.

    Note:

        1. If the option ``dcn_offset_lr_mult`` is used, the constructor will
        override the effect of ``bias_lr_mult`` in the bias of offset layer.
        So be careful when using both ``bias_lr_mult`` and
        ``dcn_offset_lr_mult``. If you wish to apply both of them to the offset
        layer in deformable convs, set ``dcn_offset_lr_mult`` to the original
        ``dcn_offset_lr_mult`` * ``bias_lr_mult``.

        2. If the option ``dcn_offset_lr_mult`` is used, the constructor will
        apply it to all the DCN layers in the model. So be careful when the
        model contains multiple DCN layers in places other than backbone.

        3. When the option ``force_default_settings`` is true, it will override
        any custom settings provided in ``custom_keys``. This ensures that the
        default settings for the optimizer parameters are used.

    Args:
        optim_wrapper_cfg (dict): The config dict of the optimizer wrapper.

            Required fields of ``optim_wrapper_cfg`` are

            - ``type``: class name of the OptimizerWrapper
            - ``optimizer``: The configuration of optimizer.

            Optional fields of ``optim_wrapper_cfg`` are

            - any arguments of the corresponding optimizer wrapper type,
              e.g., accumulative_counts, clip_grad, etc.

            Required fields of ``optimizer`` are

            - `type`: class name of the optimizer.

            Optional fields of ``optimizer`` are

            - any arguments of the corresponding optimizer type, e.g.,
              lr, weight_decay, momentum, etc.

        paramwise_cfg (dict, optional): Parameter-wise options.

    Example 1:
        >>> model = torch.nn.modules.Conv1d(1, 1, 1)
        >>> optim_wrapper_cfg = dict(
        >>>     type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01,
        >>>     momentum=0.9, weight_decay=0.0001))
        >>> paramwise_cfg = dict(norm_decay_mult=0.)
        >>> optim_wrapper_builder = DefaultOptimWrapperConstructor(
        >>>     optim_wrapper_cfg, paramwise_cfg)
        >>> optim_wrapper = optim_wrapper_builder(model)

    Example 2:
        >>> # assume model have attribute model.backbone and model.cls_head
        >>> optim_wrapper_cfg = dict(type='OptimWrapper', optimizer=dict(
        >>>     type='SGD', lr=0.01, weight_decay=0.95))
        >>> paramwise_cfg = dict(custom_keys={
        >>>     'backbone': dict(lr_mult=0.1, decay_mult=0.9)})
        >>> optim_wrapper_builder = DefaultOptimWrapperConstructor(
        >>>     optim_wrapper_cfg, paramwise_cfg)
        >>> optim_wrapper = optim_wrapper_builder(model)
        >>> # Then the `lr` and `weight_decay` for model.backbone is
        >>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for
        >>> # model.cls_head is (0.01, 0.95).
    """

    def add_params(self,
                   params: List[dict],
                   module: nn.Module,
                   prefix: str = '',
                   is_dcn_module: Optional[Union[int, float]] = None) -> None:
        """Add all parameters of module to the params list.

        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg.

        Args:
            params (list[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
            prefix (str): The prefix of the module.
            is_dcn_module (int|float|None): If the current module is a
                submodule of DCN, `is_dcn_module` will be passed to
                control conv_offset layer's learning rate. Defaults to None.
        """
        # get param-wise options
        custom_keys = self.paramwise_cfg.get('custom_keys', {})
        # first sort in alphabetical order, then by reversed string length
        sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)

        bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', None)
        bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', None)
        norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', None)
        dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', None)
        flat_decay_mult = self.paramwise_cfg.get('flat_decay_mult', None)
        bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)
        dcn_offset_lr_mult = self.paramwise_cfg.get('dcn_offset_lr_mult', None)
        force_default_settings = self.paramwise_cfg.get(
            'force_default_settings', False)

        # special rules for norm layers and depth-wise conv layers
        is_norm = isinstance(module,
                             (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
        is_dwconv = (
            isinstance(module, torch.nn.Conv2d)
            and module.in_channels == module.groups)

        for name, param in module.named_parameters(recurse=False):
            param_group = {'params': [param]}
            if bypass_duplicate and self._is_in(param_group, params):
                print_log(
                    f'{prefix} is duplicate. It is skipped since '
                    f'bypass_duplicate={bypass_duplicate}',
                    logger='current',
                    level=logging.WARNING)
                continue
            if not param.requires_grad:
                params.append(param_group)
                continue

            # if the parameter matches one of the custom keys, ignore other rules
            is_custom = False
            for key in sorted_keys:
                if key in f'{prefix}.{name}':
                    is_custom = True
                    lr_mult = custom_keys[key].get('lr_mult', 1.)
                    param_group['lr'] = self.base_lr * lr_mult
                    if self.base_wd is not None:
                        decay_mult = custom_keys[key].get('decay_mult', 1.)
                        param_group['weight_decay'] = self.base_wd * decay_mult
                    # add custom settings to param_group
                    for k, v in custom_keys[key].items():
                        param_group[k] = v
                    break

            if not is_custom or force_default_settings:
                # bias_lr_mult affects all bias parameters
                # except for norm.bias dcn.conv_offset.bias
                if name == 'bias' and not (
                        is_norm or is_dcn_module) and bias_lr_mult is not None:
                    param_group['lr'] = self.base_lr * bias_lr_mult

                if (prefix.find('conv_offset') != -1 and is_dcn_module
                        and dcn_offset_lr_mult is not None
                        and isinstance(module, torch.nn.Conv2d)):
                    # deal with both dcn_offset's bias & weight
                    param_group['lr'] = self.base_lr * dcn_offset_lr_mult

                # apply weight decay policies
                if self.base_wd is not None:
                    # norm decay
                    if is_norm and norm_decay_mult is not None:
                        param_group[
                            'weight_decay'] = self.base_wd * norm_decay_mult
                    # bias lr and decay
                    elif (name == 'bias' and not is_dcn_module
                          and bias_decay_mult is not None):
                        param_group[
                            'weight_decay'] = self.base_wd * bias_decay_mult
                    # depth-wise conv
                    elif is_dwconv and dwconv_decay_mult is not None:
                        param_group[
                            'weight_decay'] = self.base_wd * dwconv_decay_mult
                    # flatten parameters except dcn offset
                    elif (param.ndim == 1 and not is_dcn_module
                          and flat_decay_mult is not None):
                        param_group[
                            'weight_decay'] = self.base_wd * flat_decay_mult
            params.append(param_group)
            for key, value in param_group.items():
                if key == 'params':
                    continue
                full_name = f'{prefix}.{name}' if prefix else name
                print_log(
                    f'paramwise_options -- {full_name}:{key}={value}',
                    logger='current')

        if mmcv_full_available():
            from mmcv.ops import DeformConv2d, ModulatedDeformConv2d
            is_dcn_module = isinstance(module,
                                       (DeformConv2d, ModulatedDeformConv2d))
        else:
            is_dcn_module = False
        for child_name, child_mod in module.named_children():
            child_prefix = f'{prefix}.{child_name}' if prefix else child_name
            self.add_params(
                params,
                child_mod,
                prefix=child_prefix,
                is_dcn_module=is_dcn_module)
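A hedged sketch of the ``force_default_settings`` interaction (hypothetical values, not part of the commit):

optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=1e-4, weight_decay=0.05),
    constructor='ForceDefaultOptimWrapperConstructor',
    paramwise_cfg=dict(
        custom_keys={'backbone': dict(lr_mult=0.1)},
        bias_lr_mult=2.0,
        norm_decay_mult=0.0,
        # with force_default_settings=True, defaults such as bias_lr_mult
        # and norm_decay_mult are applied even to parameters matched by
        # 'backbone', overriding the custom settings
        force_default_settings=True))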
207
finetune/mmseg/engine/optimizers/layer_decay_optimizer_constructor.py
Normal file
@@ -0,0 +1,207 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import warnings

from mmengine.dist import get_dist_info
from mmengine.logging import print_log
from mmengine.optim import DefaultOptimWrapperConstructor

from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS


def get_layer_id_for_convnext(var_name, max_layer_id):
    """Get the layer id to set the different learning rates in ``layer_wise``
    decay_type.

    Args:
        var_name (str): The key of the model.
        max_layer_id (int): Maximum number of backbone layers.

    Returns:
        int: The id number corresponding to different learning rate in
        ``LearningRateDecayOptimizerConstructor``.
    """

    if var_name in ('backbone.cls_token', 'backbone.mask_token',
                    'backbone.pos_embed'):
        return 0
    elif var_name.startswith('backbone.downsample_layers'):
        stage_id = int(var_name.split('.')[2])
        if stage_id == 0:
            layer_id = 0
        elif stage_id == 1:
            layer_id = 2
        elif stage_id == 2:
            layer_id = 3
        elif stage_id == 3:
            layer_id = max_layer_id
        return layer_id
    elif var_name.startswith('backbone.stages'):
        stage_id = int(var_name.split('.')[2])
        block_id = int(var_name.split('.')[3])
        if stage_id == 0:
            layer_id = 1
        elif stage_id == 1:
            layer_id = 2
        elif stage_id == 2:
            layer_id = 3 + block_id // 3
        elif stage_id == 3:
            layer_id = max_layer_id
        return layer_id
    else:
        return max_layer_id + 1


def get_stage_id_for_convnext(var_name, max_stage_id):
    """Get the stage id to set the different learning rates in ``stage_wise``
    decay_type.

    Args:
        var_name (str): The key of the model.
        max_stage_id (int): Maximum number of backbone stages.

    Returns:
        int: The id number corresponding to different learning rate in
        ``LearningRateDecayOptimizerConstructor``.
    """

    if var_name in ('backbone.cls_token', 'backbone.mask_token',
                    'backbone.pos_embed'):
        return 0
    elif var_name.startswith('backbone.downsample_layers'):
        return 0
    elif var_name.startswith('backbone.stages'):
        stage_id = int(var_name.split('.')[2])
        return stage_id + 1
    else:
        return max_stage_id - 1


def get_layer_id_for_vit(var_name, max_layer_id):
    """Get the layer id to set the different learning rates.

    Args:
        var_name (str): The key of the model.
        max_layer_id (int): Maximum number of backbone layers.

    Returns:
        int: Returns the layer id of the key.
    """

    if var_name in ('backbone.cls_token', 'backbone.mask_token',
                    'backbone.pos_embed'):
        return 0
    elif var_name.startswith('backbone.patch_embed'):
        return 0
    elif var_name.startswith('backbone.layers'):
        layer_id = int(var_name.split('.')[2])
        return layer_id + 1
    else:
        return max_layer_id - 1


@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class LearningRateDecayOptimizerConstructor(DefaultOptimWrapperConstructor):
    """Different learning rates are set for different layers of backbone.

    Note: Currently, this optimizer constructor is built for ConvNeXt,
    BEiT and MAE.
    """

    def add_params(self, params, module, **kwargs):
        """Add all parameters of module to the params list.

        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg.

        Args:
            params (list[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
        """

        parameter_groups = {}
        print_log(f'self.paramwise_cfg is {self.paramwise_cfg}')
        num_layers = self.paramwise_cfg.get('num_layers') + 2
        decay_rate = self.paramwise_cfg.get('decay_rate')
        decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise')
        print_log('Build LearningRateDecayOptimizerConstructor '
                  f'{decay_type} {decay_rate} - {num_layers}')
        weight_decay = self.base_wd
        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue  # frozen weights
            if len(param.shape) == 1 or name.endswith('.bias') or name in (
                    'pos_embed', 'cls_token'):
                group_name = 'no_decay'
                this_weight_decay = 0.
            else:
                group_name = 'decay'
                this_weight_decay = weight_decay
            if 'layer_wise' in decay_type:
                if 'ConvNeXt' in module.backbone.__class__.__name__:
                    layer_id = get_layer_id_for_convnext(
                        name, self.paramwise_cfg.get('num_layers'))
                    print_log(f'set param {name} as id {layer_id}')
                elif 'BEiT' in module.backbone.__class__.__name__ or \
                        'MAE' in module.backbone.__class__.__name__:
                    layer_id = get_layer_id_for_vit(name, num_layers)
                    print_log(f'set param {name} as id {layer_id}')
                else:
                    raise NotImplementedError()
            elif decay_type == 'stage_wise':
                if 'ConvNeXt' in module.backbone.__class__.__name__:
                    layer_id = get_stage_id_for_convnext(name, num_layers)
                    print_log(f'set param {name} as id {layer_id}')
                else:
                    raise NotImplementedError()
            group_name = f'layer_{layer_id}_{group_name}'

            if group_name not in parameter_groups:
                scale = decay_rate**(num_layers - layer_id - 1)

                parameter_groups[group_name] = {
                    'weight_decay': this_weight_decay,
                    'params': [],
                    'param_names': [],
                    'lr_scale': scale,
                    'group_name': group_name,
                    'lr': scale * self.base_lr,
                }

            parameter_groups[group_name]['params'].append(param)
            parameter_groups[group_name]['param_names'].append(name)
        rank, _ = get_dist_info()
        if rank == 0:
            to_display = {}
            for key in parameter_groups:
                to_display[key] = {
                    'param_names': parameter_groups[key]['param_names'],
                    'lr_scale': parameter_groups[key]['lr_scale'],
                    'lr': parameter_groups[key]['lr'],
                    'weight_decay': parameter_groups[key]['weight_decay'],
                }
            print_log(f'Param groups = {json.dumps(to_display, indent=2)}')
        params.extend(parameter_groups.values())


@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class LayerDecayOptimizerConstructor(LearningRateDecayOptimizerConstructor):
    """Different learning rates are set for different layers of backbone.

    Note: Currently, this optimizer constructor is built for BEiT,
    and it will be deprecated.
    Please use ``LearningRateDecayOptimizerConstructor`` instead.
    """

    def __init__(self, optim_wrapper_cfg, paramwise_cfg):
        warnings.warn('DeprecationWarning: Original '
                      'LayerDecayOptimizerConstructor of BEiT '
                      'will be deprecated. Please use '
                      'LearningRateDecayOptimizerConstructor instead, '
                      'and set decay_type = layer_wise_vit in paramwise_cfg.')
        paramwise_cfg.update({'decay_type': 'layer_wise_vit'})
        warnings.warn('DeprecationWarning: Layer_decay_rate will '
                      'be deleted, please use decay_rate instead.')
        paramwise_cfg['decay_rate'] = paramwise_cfg.pop('layer_decay_rate')
        super().__init__(optim_wrapper_cfg, paramwise_cfg)
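A minimal config sketch for the constructor above (illustrative values, assuming a backbone recognized by the branches in ``add_params``, e.g. a ConvNeXt with ``num_layers=12``):

optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=1e-4, weight_decay=0.05),
    constructor='LearningRateDecayOptimizerConstructor',
    paramwise_cfg=dict(
        num_layers=12,           # backbone depth; groups span num_layers + 2 ids
        decay_rate=0.9,          # group i gets lr scaled by 0.9**(num_layers + 1 - i)
        decay_type='layer_wise'))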
4
finetune/mmseg/engine/schedulers/__init__.py
Normal file
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .poly_ratio_scheduler import PolyLRRatio

__all__ = ['PolyLRRatio']
62
finetune/mmseg/engine/schedulers/poly_ratio_scheduler.py
Normal file
@@ -0,0 +1,62 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

from mmengine.optim.scheduler import PolyLR

from mmseg.registry import PARAM_SCHEDULERS


@PARAM_SCHEDULERS.register_module()
class PolyLRRatio(PolyLR):
    """Implements polynomial learning rate decay with ratio.

    This scheduler adjusts the learning rate of each parameter group
    following a polynomial decay equation. The decay can occur in
    conjunction with external parameter adjustments made outside this
    scheduler.

    Args:
        optimizer (Optimizer or OptimWrapper): Wrapped optimizer.
        eta_min (float): Minimum learning rate at the end of scheduling.
            Defaults to 0.
        eta_min_ratio (float, optional): The ratio of the minimum parameter
            value to the base parameter value. Either `eta_min` or
            `eta_min_ratio` should be specified. Defaults to None.
        power (float): The power of the polynomial. Defaults to 1.0.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of last step. Used for resuming without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """

    def __init__(self, eta_min_ratio: Optional[float] = None, *args,
                 **kwargs):
        super().__init__(*args, **kwargs)

        self.eta_min_ratio = eta_min_ratio

    def _get_value(self):
        """Compute value using chainable form of the scheduler."""

        if self.last_step == 0:
            return [
                group[self.param_name] for group in self.optimizer.param_groups
            ]

        param_groups_value = []
        for base_value, param_group in zip(self.base_values,
                                           self.optimizer.param_groups):
            eta_min = self.eta_min if self.eta_min_ratio is None else \
                base_value * self.eta_min_ratio
            step_ratio = (1 - 1 /
                          (self.total_iters - self.last_step + 1))**self.power
            step_value = (param_group[self.param_name] -
                          eta_min) * step_ratio + eta_min
            param_groups_value.append(step_value)

        return param_groups_value
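A scheduler sketch (illustrative step counts): with ``eta_min_ratio`` set, every param group decays toward a fixed fraction of its own base value, so groups given different ``lr_scale`` by the constructors above keep their relative ratios throughout training:

param_scheduler = [
    dict(
        type='PolyLRRatio',
        eta_min_ratio=0.1,   # floor each group at 10% of its own base lr
        power=0.9,
        begin=0,
        end=80000,
        by_epoch=False)
]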
4
finetune/mmseg/evaluation/__init__.py
Normal file
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .metrics import CityscapesMetric, DepthMetric, IoUMetric

__all__ = ['IoUMetric', 'CityscapesMetric', 'DepthMetric']
6
finetune/mmseg/evaluation/metrics/__init__.py
Normal file
@@ -0,0 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .citys_metric import CityscapesMetric
from .depth_metric import DepthMetric
from .iou_metric import IoUMetric

__all__ = ['IoUMetric', 'CityscapesMetric', 'DepthMetric']
158
finetune/mmseg/evaluation/metrics/citys_metric.py
Normal file
@@ -0,0 +1,158 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import shutil
from collections import OrderedDict
from typing import Dict, Optional, Sequence

try:
    import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval  # noqa
    import cityscapesscripts.helpers.labels as CSLabels
except ImportError:
    CSLabels = None
    CSEval = None

import numpy as np
from mmengine.dist import is_main_process, master_only
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger, print_log
from mmengine.utils import mkdir_or_exist
from PIL import Image

from mmseg.registry import METRICS


@METRICS.register_module()
class CityscapesMetric(BaseMetric):
    """Cityscapes evaluation metric.

    Args:
        output_dir (str): The directory for output prediction.
        ignore_index (int): Index that will be ignored in evaluation.
            Default: 255.
        format_only (bool): Only format results for submission without
            performing evaluation. It is useful when you want to format the
            result to a specific format and submit it to the test server.
            Defaults to False.
        keep_results (bool): Whether to keep the results. When ``format_only``
            is True, ``keep_results`` must be True. Defaults to False.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    """

    def __init__(self,
                 output_dir: str,
                 ignore_index: int = 255,
                 format_only: bool = False,
                 keep_results: bool = False,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None,
                 **kwargs) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)
        if CSEval is None:
            raise ImportError('Please run "pip install cityscapesscripts" to '
                              'install cityscapesscripts first.')
        self.output_dir = output_dir
        self.ignore_index = ignore_index

        self.format_only = format_only
        if format_only:
            assert keep_results, (
                'When format_only is True, the results must be kept, please '
                f'set keep_results as True, but got {keep_results}')
        self.keep_results = keep_results
        self.prefix = prefix
        if is_main_process():
            mkdir_or_exist(self.output_dir)

    @master_only
    def __del__(self) -> None:
        """Clean up."""
        if not self.keep_results:
            shutil.rmtree(self.output_dir)

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Process one batch of data and data_samples.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

        Args:
            data_batch (dict): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        mkdir_or_exist(self.output_dir)

        for data_sample in data_samples:
            pred_label = data_sample['pred_sem_seg']['data'][0].cpu().numpy()
            # when evaluating with official cityscapesscripts,
            # labelIds should be used
            pred_label = self._convert_to_label_id(pred_label)
            basename = osp.splitext(osp.basename(data_sample['img_path']))[0]
            png_filename = osp.abspath(
                osp.join(self.output_dir, f'{basename}.png'))
            output = Image.fromarray(pred_label.astype(np.uint8)).convert('P')
            output.save(png_filename)
            if self.format_only:
                # format_only always for test dataset without ground truth
                gt_filename = ''
            else:
                # when evaluating with official cityscapesscripts,
                # **_gtFine_labelIds.png is used
                gt_filename = data_sample['seg_map_path'].replace(
                    'labelTrainIds.png', 'labelIds.png')
            self.results.append((png_filename, gt_filename))

    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute the metrics from processed results.

        Args:
            results (list): Testing results of the dataset.

        Returns:
            dict[str: float]: Cityscapes evaluation results.
        """
        logger: MMLogger = MMLogger.get_current_instance()
        if self.format_only:
            logger.info(f'results are saved to {osp.dirname(self.output_dir)}')
            return OrderedDict()

        msg = 'Evaluating in Cityscapes style'
        if logger is None:
            msg = '\n' + msg
        print_log(msg, logger=logger)

        eval_results = dict()
        print_log(
            f'Evaluating results under {self.output_dir} ...', logger=logger)

        CSEval.args.evalInstLevelScore = True
        CSEval.args.predictionPath = osp.abspath(self.output_dir)
        CSEval.args.evalPixelAccuracy = True
        CSEval.args.JSONOutput = False

        pred_list, gt_list = zip(*results)
        metric = dict()
        eval_results.update(
            CSEval.evaluateImgLists(pred_list, gt_list, CSEval.args))
        metric['averageScoreCategories'] = eval_results[
            'averageScoreCategories']
        metric['averageScoreInstCategories'] = eval_results[
            'averageScoreInstCategories']
        return metric

    @staticmethod
    def _convert_to_label_id(result):
        """Convert trainId to id for cityscapes."""
        if isinstance(result, str):
            result = np.load(result)
        result_copy = result.copy()
        for trainId, label in CSLabels.trainId2label.items():
            result_copy[result == trainId] = label.id

        return result_copy
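A hedged sketch of a format-only (test-server submission) setup; the output path is hypothetical:

test_evaluator = dict(
    type='CityscapesMetric',
    output_dir='work_dirs/format_results',  # hypothetical path
    format_only=True,   # write labelId PNGs only, skip local evaluation
    keep_results=True)  # required when format_only=True, see the assert above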
212
finetune/mmseg/evaluation/metrics/depth_metric.py
Normal file
@@ -0,0 +1,212 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from collections import OrderedDict, defaultdict
from typing import Dict, List, Optional, Sequence

import cv2
import numpy as np
import torch
from mmengine.dist import is_main_process
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger, print_log
from mmengine.utils import mkdir_or_exist
from prettytable import PrettyTable
from torch import Tensor

from mmseg.registry import METRICS


@METRICS.register_module()
class DepthMetric(BaseMetric):
    """Depth estimation evaluation metric.

    Args:
        depth_metrics (List[str], optional): List of metrics to compute. If
            not specified, defaults to all metrics in self.METRICS.
        min_depth_eval (float): Minimum depth value for evaluation.
            Defaults to 0.0.
        max_depth_eval (float): Maximum depth value for evaluation.
            Defaults to infinity.
        crop_type (str, optional): Specifies the type of cropping to be used
            during evaluation. This option can affect how the evaluation mask
            is generated. Currently, 'nyu_crop' is supported, but other
            types can be added in future. Defaults to None if no cropping
            should be applied.
        depth_scale_factor (float): Factor to scale the depth values.
            Defaults to 1.0.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        output_dir (str): The directory for output prediction. Defaults to
            None.
        format_only (bool): Only format results for submission without
            performing evaluation. It is useful when you want to save the
            result to a specific format and submit it to the test server.
            Defaults to False.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    """
    METRICS = ('d1', 'd2', 'd3', 'abs_rel', 'sq_rel', 'rmse', 'rmse_log',
               'log10', 'silog')

    def __init__(self,
                 depth_metrics: Optional[List[str]] = None,
                 min_depth_eval: float = 0.0,
                 max_depth_eval: float = float('inf'),
                 crop_type: Optional[str] = None,
                 depth_scale_factor: float = 1.0,
                 collect_device: str = 'cpu',
                 output_dir: Optional[str] = None,
                 format_only: bool = False,
                 prefix: Optional[str] = None,
                 **kwargs) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)

        if depth_metrics is None:
            self.metrics = self.METRICS
        elif isinstance(depth_metrics, (tuple, list)):
            for metric in depth_metrics:
                assert metric in self.METRICS, f'the metric {metric} is not ' \
                    f'supported. Please use metrics in {self.METRICS}'
            self.metrics = depth_metrics

        # Validate crop_type, if provided
        assert crop_type in [
            None, 'nyu_crop'
        ], (f'Invalid value for crop_type: {crop_type}. Supported values are '
            'None or \'nyu_crop\'.')
        self.crop_type = crop_type
        self.min_depth_eval = min_depth_eval
        self.max_depth_eval = max_depth_eval
        self.output_dir = output_dir
        if self.output_dir and is_main_process():
            mkdir_or_exist(self.output_dir)
        self.format_only = format_only
        self.depth_scale_factor = depth_scale_factor

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Process one batch of data and data_samples.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

        Args:
            data_batch (dict): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for data_sample in data_samples:
            pred_label = data_sample['pred_depth_map']['data'].squeeze()
            # format_only always for test dataset without ground truth
            if not self.format_only:
                gt_depth = data_sample['gt_depth_map']['data'].squeeze().to(
                    pred_label)

                eval_mask = self._get_eval_mask(gt_depth)
                self.results.append(
                    (gt_depth[eval_mask], pred_label[eval_mask]))
            # format_result
            if self.output_dir is not None:
                basename = osp.splitext(osp.basename(
                    data_sample['img_path']))[0]
                png_filename = osp.abspath(
                    osp.join(self.output_dir, f'{basename}.png'))
                output_mask = pred_label.cpu().numpy(
                ) * self.depth_scale_factor

                cv2.imwrite(png_filename, output_mask.astype(np.uint16),
                            [cv2.IMWRITE_PNG_COMPRESSION, 0])

    def _get_eval_mask(self, gt_depth: Tensor):
        """Generates an evaluation mask based on ground truth depth and
        cropping.

        Args:
            gt_depth (Tensor): Ground truth depth map.

        Returns:
            Tensor: Boolean mask where evaluation should be performed.
        """
        valid_mask = torch.logical_and(gt_depth > self.min_depth_eval,
                                       gt_depth < self.max_depth_eval)

        if self.crop_type == 'nyu_crop':
            # this implementation is adapted from
            # https://github.com/zhyever/Monocular-Depth-Estimation-Toolbox/blob/main/depth/datasets/nyu.py  # noqa
            crop_mask = torch.zeros_like(valid_mask)
            crop_mask[45:471, 41:601] = 1
        else:
            crop_mask = torch.ones_like(valid_mask)

        eval_mask = torch.logical_and(valid_mask, crop_mask)
        return eval_mask

    @staticmethod
    def _calc_all_metrics(gt_depth, pred_depth):
        """Computes final evaluation metrics based on accumulated results."""
        assert gt_depth.shape == pred_depth.shape

        thresh = torch.max((gt_depth / pred_depth), (pred_depth / gt_depth))
        diff = pred_depth - gt_depth
        diff_log = torch.log(pred_depth) - torch.log(gt_depth)

        d1 = torch.sum(thresh < 1.25).float() / len(thresh)
        d2 = torch.sum(thresh < 1.25**2).float() / len(thresh)
        d3 = torch.sum(thresh < 1.25**3).float() / len(thresh)

        abs_rel = torch.mean(torch.abs(diff) / gt_depth)
        sq_rel = torch.mean(torch.pow(diff, 2) / gt_depth)

        rmse = torch.sqrt(torch.mean(torch.pow(diff, 2)))
        rmse_log = torch.sqrt(torch.mean(torch.pow(diff_log, 2)))

        log10 = torch.mean(
            torch.abs(torch.log10(pred_depth) - torch.log10(gt_depth)))
        silog = torch.sqrt(
            torch.pow(diff_log, 2).mean() -
            0.5 * torch.pow(diff_log.mean(), 2))

        return {
            'd1': d1.item(),
            'd2': d2.item(),
            'd3': d3.item(),
            'abs_rel': abs_rel.item(),
            'sq_rel': sq_rel.item(),
            'rmse': rmse.item(),
            'rmse_log': rmse_log.item(),
            'log10': log10.item(),
            'silog': silog.item()
        }

    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            Dict[str, float]: The computed metrics. The keys are the names of
            the metrics, and the values are corresponding results. The keys
            are identical with self.metrics.
        """
        logger: MMLogger = MMLogger.get_current_instance()
        if self.format_only:
            logger.info(f'results are saved to {osp.dirname(self.output_dir)}')
            return OrderedDict()

        metrics = defaultdict(list)
        for gt_depth, pred_depth in results:
            for key, value in self._calc_all_metrics(gt_depth,
                                                     pred_depth).items():
                metrics[key].append(value)
        metrics = {k: sum(metrics[k]) / len(metrics[k]) for k in self.metrics}

        table_data = PrettyTable()
        for key, val in metrics.items():
            table_data.add_column(key, [round(val, 5)])

        print_log('results:', logger=logger)
        print_log('\n' + table_data.get_string(), logger=logger)

        return metrics
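A minimal evaluator sketch; the NYU-style bounds and millimetre scale are assumptions, not values from this commit:

val_evaluator = dict(
    type='DepthMetric',
    min_depth_eval=0.001,       # assumed valid-depth lower bound
    max_depth_eval=10.0,        # assumed upper bound (metres, NYU-style)
    crop_type='nyu_crop',       # apply the [45:471, 41:601] eval crop above
    depth_scale_factor=1000.0)  # write 16-bit PNGs in millimetres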
286
finetune/mmseg/evaluation/metrics/iou_metric.py
Normal file
@@ -0,0 +1,286 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from collections import OrderedDict
from typing import Dict, List, Optional, Sequence

import numpy as np
import torch
from mmengine.dist import is_main_process
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger, print_log
from mmengine.utils import mkdir_or_exist
from PIL import Image
from prettytable import PrettyTable

from mmseg.registry import METRICS


@METRICS.register_module()
class IoUMetric(BaseMetric):
    """IoU evaluation metric.

    Args:
        ignore_index (int): Index that will be ignored in evaluation.
            Default: 255.
        iou_metrics (list[str] | str): Metrics to be calculated, the options
            include 'mIoU', 'mDice' and 'mFscore'.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        beta (int): Determines the weight of recall in the combined score.
            Default: 1.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        output_dir (str): The directory for output prediction. Defaults to
            None.
        format_only (bool): Only format results for submission without
            performing evaluation. It is useful when you want to save the
            result to a specific format and submit it to the test server.
            Defaults to False.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    """

    def __init__(self,
                 ignore_index: int = 255,
                 iou_metrics: List[str] = ['mIoU'],
                 nan_to_num: Optional[int] = None,
                 beta: int = 1,
                 collect_device: str = 'cpu',
                 output_dir: Optional[str] = None,
                 format_only: bool = False,
                 prefix: Optional[str] = None,
                 **kwargs) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)

        self.ignore_index = ignore_index
        self.metrics = iou_metrics
        self.nan_to_num = nan_to_num
        self.beta = beta
        self.output_dir = output_dir
        if self.output_dir and is_main_process():
            mkdir_or_exist(self.output_dir)
        self.format_only = format_only

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Process one batch of data and data_samples.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

        Args:
            data_batch (dict): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        num_classes = len(self.dataset_meta['classes'])
        for data_sample in data_samples:
            pred_label = data_sample['pred_sem_seg']['data'].squeeze()
            # format_only always for test dataset without ground truth
            if not self.format_only:
                label = data_sample['gt_sem_seg']['data'].squeeze().to(
                    pred_label)
                self.results.append(
                    self.intersect_and_union(pred_label, label, num_classes,
                                             self.ignore_index))
            # format_result
            if self.output_dir is not None:
                basename = osp.splitext(osp.basename(
                    data_sample['img_path']))[0]
                png_filename = osp.abspath(
                    osp.join(self.output_dir, f'{basename}.png'))
                output_mask = pred_label.cpu().numpy()
                # The index range of official ADE20k dataset is from 0 to 150.
                # But the index range of output is from 0 to 149.
                # That is because we set reduce_zero_label=True.
                if data_sample.get('reduce_zero_label', False):
                    output_mask = output_mask + 1
                output = Image.fromarray(output_mask.astype(np.uint8))
                output.save(png_filename)

    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            Dict[str, float]: The computed metrics. The keys are the names of
            the metrics, and the values are corresponding results. The keys
            mainly include aAcc, mIoU, mAcc, mDice, mFscore, mPrecision,
            mRecall.
        """
        logger: MMLogger = MMLogger.get_current_instance()
        if self.format_only:
            logger.info(f'results are saved to {osp.dirname(self.output_dir)}')
            return OrderedDict()
        # convert list of tuples to tuple of lists, e.g.
        # [(A_1, B_1, C_1, D_1), ..., (A_n, B_n, C_n, D_n)] to
        # ([A_1, ..., A_n], ..., [D_1, ..., D_n])
        results = tuple(zip(*results))
        assert len(results) == 4

        total_area_intersect = sum(results[0])
        total_area_union = sum(results[1])
        total_area_pred_label = sum(results[2])
        total_area_label = sum(results[3])
        ret_metrics = self.total_area_to_metrics(
            total_area_intersect, total_area_union, total_area_pred_label,
            total_area_label, self.metrics, self.nan_to_num, self.beta)

        class_names = self.dataset_meta['classes']

        # summary table
        ret_metrics_summary = OrderedDict({
            ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
            for ret_metric, ret_metric_value in ret_metrics.items()
        })
        metrics = dict()
        for key, val in ret_metrics_summary.items():
            if key == 'aAcc':
                metrics[key] = val
            else:
                metrics['m' + key] = val

        # each class table
        ret_metrics.pop('aAcc', None)
        ret_metrics_class = OrderedDict({
            ret_metric: np.round(ret_metric_value * 100, 2)
            for ret_metric, ret_metric_value in ret_metrics.items()
        })
        ret_metrics_class.update({'Class': class_names})
        ret_metrics_class.move_to_end('Class', last=False)
        class_table_data = PrettyTable()
        for key, val in ret_metrics_class.items():
            class_table_data.add_column(key, val)

        print_log('per class results:', logger=logger)
        print_log('\n' + class_table_data.get_string(), logger=logger)

        return metrics

    @staticmethod
    def intersect_and_union(pred_label: torch.Tensor, label: torch.Tensor,
                            num_classes: int, ignore_index: int):
        """Calculate Intersection and Union.

        Args:
            pred_label (torch.Tensor): Prediction segmentation map
                or predict result filename. The shape is (H, W).
            label (torch.Tensor): Ground truth segmentation map
                or label filename. The shape is (H, W).
            num_classes (int): Number of categories.
            ignore_index (int): Index that will be ignored in evaluation.

        Returns:
            torch.Tensor: The intersection of prediction and ground truth
                histogram on all classes.
            torch.Tensor: The union of prediction and ground truth histogram
                on all classes.
            torch.Tensor: The prediction histogram on all classes.
            torch.Tensor: The ground truth histogram on all classes.
        """

        mask = (label != ignore_index)
        pred_label = pred_label[mask]
        label = label[mask]

        intersect = pred_label[pred_label == label]
        area_intersect = torch.histc(
            intersect.float(), bins=num_classes, min=0,
            max=num_classes - 1).cpu()
        area_pred_label = torch.histc(
            pred_label.float(), bins=num_classes, min=0,
            max=num_classes - 1).cpu()
        area_label = torch.histc(
            label.float(), bins=num_classes, min=0,
            max=num_classes - 1).cpu()
        area_union = area_pred_label + area_label - area_intersect
        return area_intersect, area_union, area_pred_label, area_label

    @staticmethod
    def total_area_to_metrics(total_area_intersect: torch.Tensor,
                              total_area_union: torch.Tensor,
                              total_area_pred_label: torch.Tensor,
                              total_area_label: torch.Tensor,
                              metrics: List[str] = ['mIoU'],
                              nan_to_num: Optional[int] = None,
                              beta: int = 1):
        """Calculate evaluation metrics.

        Args:
            total_area_intersect (torch.Tensor): The intersection of
                prediction and ground truth histogram on all classes.
            total_area_union (torch.Tensor): The union of prediction and
                ground truth histogram on all classes.
            total_area_pred_label (torch.Tensor): The prediction histogram on
                all classes.
            total_area_label (torch.Tensor): The ground truth histogram on
                all classes.
            metrics (List[str] | str): Metrics to be evaluated, 'mIoU' and
                'mDice'.
            nan_to_num (int, optional): If specified, NaN values will be
                replaced by the numbers defined by the user. Default: None.
            beta (int): Determines the weight of recall in the combined score.
                Default: 1.

        Returns:
            Dict[str, np.ndarray]: per category evaluation metrics,
                shape (num_classes, ).
        """

        def f_score(precision, recall, beta=1):
            """Calculate the f-score value.

            Args:
                precision (float | torch.Tensor): The precision value.
                recall (float | torch.Tensor): The recall value.
                beta (int): Determines the weight of recall in the combined
                    score. Default: 1.

            Returns:
                [torch.tensor]: The f-score value.
            """
            score = (1 + beta**2) * (precision * recall) / (
                (beta**2 * precision) + recall)
            return score

        if isinstance(metrics, str):
            metrics = [metrics]
        allowed_metrics = ['mIoU', 'mDice', 'mFscore']
        if not set(metrics).issubset(set(allowed_metrics)):
            raise KeyError(f'metrics {metrics} is not supported')

        all_acc = total_area_intersect.sum() / total_area_label.sum()
        ret_metrics = OrderedDict({'aAcc': all_acc})
        for metric in metrics:
            if metric == 'mIoU':
                iou = total_area_intersect / total_area_union
                acc = total_area_intersect / total_area_label
                ret_metrics['IoU'] = iou
                ret_metrics['Acc'] = acc
            elif metric == 'mDice':
                dice = 2 * total_area_intersect / (
                    total_area_pred_label + total_area_label)
                acc = total_area_intersect / total_area_label
                ret_metrics['Dice'] = dice
                ret_metrics['Acc'] = acc
            elif metric == 'mFscore':
                precision = total_area_intersect / total_area_pred_label
                recall = total_area_intersect / total_area_label
                f_value = torch.tensor([
                    f_score(x[0], x[1], beta) for x in zip(precision, recall)
                ])
                ret_metrics['Fscore'] = f_value
                ret_metrics['Precision'] = precision
                ret_metrics['Recall'] = recall

        ret_metrics = {
            metric: value.numpy()
            for metric, value in ret_metrics.items()
        }
        if nan_to_num is not None:
            ret_metrics = OrderedDict({
                metric: np.nan_to_num(metric_value, nan=nan_to_num)
                for metric, metric_value in ret_metrics.items()
            })
        return ret_metrics
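A small sanity check of ``intersect_and_union`` on toy tensors (hypothetical data, runnable outside any config):

import torch
from mmseg.evaluation import IoUMetric

pred = torch.tensor([[0, 0], [1, 1]])
gt = torch.tensor([[0, 1], [1, 1]])
inter, union, _, _ = IoUMetric.intersect_and_union(
    pred, gt, num_classes=2, ignore_index=255)
# inter = [1., 2.], union = [2., 3.] -> per-class IoU = [0.50, 0.67]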
16
finetune/mmseg/models/__init__.py
Normal file
@@ -0,0 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .assigners import *  # noqa: F401,F403
from .backbones import *  # noqa: F401,F403
from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone,
                      build_head, build_loss, build_segmentor)
from .data_preprocessor import SegDataPreProcessor
from .decode_heads import *  # noqa: F401,F403
from .losses import *  # noqa: F401,F403
from .necks import *  # noqa: F401,F403
from .segmentors import *  # noqa: F401,F403
from .text_encoder import *  # noqa: F401,F403

__all__ = [
    'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone',
    'build_head', 'build_loss', 'build_segmentor', 'SegDataPreProcessor'
]
12
finetune/mmseg/models/assigners/__init__.py
Normal file
@@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_assigner import BaseAssigner
from .hungarian_assigner import HungarianAssigner
from .match_cost import ClassificationCost, CrossEntropyLossCost, DiceCost

__all__ = [
    'BaseAssigner',
    'HungarianAssigner',
    'ClassificationCost',
    'CrossEntropyLossCost',
    'DiceCost',
]
18
finetune/mmseg/models/assigners/base_assigner.py
Normal file
@@ -0,0 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Optional

from mmengine.structures import InstanceData


class BaseAssigner(metaclass=ABCMeta):
    """Base assigner that assigns masks to ground truth class labels."""

    @abstractmethod
    def assign(self,
               pred_instances: InstanceData,
               gt_instances: InstanceData,
               gt_instances_ignore: Optional[InstanceData] = None,
               **kwargs):
        """Assign masks to either a ground truth class label or a negative
        label."""
86
finetune/mmseg/models/assigners/hungarian_assigner.py
Normal file
@@ -0,0 +1,86 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Union

import torch
from mmengine import ConfigDict
from mmengine.structures import InstanceData
from torch.cuda.amp import autocast

try:
    from scipy.optimize import linear_sum_assignment
except ImportError:
    linear_sum_assignment = None

from mmseg.registry import TASK_UTILS
from .base_assigner import BaseAssigner


@TASK_UTILS.register_module()
class HungarianAssigner(BaseAssigner):
    """Computes one-to-one matching between prediction masks and ground truth.

    This class uses bipartite matching-based assignment to compute an
    assignment between the prediction masks and the ground truth. The
    assignment result is based on the weighted sum of match costs. The
    Hungarian algorithm is used to calculate the best matching with the
    minimum cost. The prediction masks that are not matched are classified
    as background.

    Args:
        match_costs (ConfigDict|List[ConfigDict]): Match cost configs.
    """

    def __init__(
        self, match_costs: Union[List[Union[dict, ConfigDict]], dict,
                                 ConfigDict]
    ) -> None:

        if isinstance(match_costs, dict):
            match_costs = [match_costs]
        elif isinstance(match_costs, list):
            assert len(match_costs) > 0, \
                'match_costs must not be an empty list.'

        self.match_costs = [
            TASK_UTILS.build(match_cost) for match_cost in match_costs
        ]

    def assign(self, pred_instances: InstanceData, gt_instances: InstanceData,
               **kwargs):
        """Computes one-to-one matching based on the weighted costs.

        This method assigns each query prediction to a ground truth or
        background. The assignment first calculates the cost for each
        category assigned to each query mask, and then uses the
        Hungarian algorithm to calculate the minimum cost as the best
        match.

        Args:
            pred_instances (InstanceData): Instances of model
                predictions. It includes "masks", with shape
                (n, h, w) or (n, l), and "cls", with shape (n, num_classes+1).
            gt_instances (InstanceData): Ground truth of instance
                annotations. It includes "labels", with shape (k, ),
                and "masks", with shape (k, h, w) or (k, l).

        Returns:
            matched_query_inds (Tensor): The indices of matched queries.
            matched_label_inds (Tensor): The indices of matched labels.
        """
        # compute weighted cost
        cost_list = []
        with autocast(enabled=False):
            for match_cost in self.match_costs:
                cost = match_cost(
                    pred_instances=pred_instances, gt_instances=gt_instances)
                cost_list.append(cost)
        cost = torch.stack(cost_list).sum(dim=0)

        device = cost.device
        # do Hungarian matching on CPU using linear_sum_assignment
        cost = cost.detach().cpu()
        if linear_sum_assignment is None:
            raise ImportError('Please run "pip install scipy" '
                              'to install scipy first.')

        matched_query_inds, matched_label_inds = linear_sum_assignment(cost)
        matched_query_inds = torch.from_numpy(matched_query_inds).to(device)
        matched_label_inds = torch.from_numpy(matched_label_inds).to(device)

        return matched_query_inds, matched_label_inds
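A hedged sketch of building the assigner from the registry (illustrative weights, using only the match costs shown in this excerpt):

from mmseg.registry import TASK_UTILS

assigner = TASK_UTILS.build(
    dict(
        type='HungarianAssigner',
        match_costs=[
            dict(type='ClassificationCost', weight=2.0),
            dict(type='DiceCost', weight=5.0, pred_act=True, eps=1.0),
        ]))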
231
finetune/mmseg/models/assigners/match_cost.py
Normal file
@@ -0,0 +1,231 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from typing import Union

import torch
import torch.nn.functional as F
from mmengine.structures import InstanceData
from torch import Tensor

from mmseg.registry import TASK_UTILS


class BaseMatchCost:
    """Base match cost class.

    Args:
        weight (Union[float, int]): Cost weight. Defaults to 1.
    """

    def __init__(self, weight: Union[float, int] = 1.) -> None:
        self.weight = weight

    @abstractmethod
    def __call__(self, pred_instances: InstanceData,
                 gt_instances: InstanceData, **kwargs) -> Tensor:
        """Compute match cost.

        Args:
            pred_instances (InstanceData): Instances of model predictions.
                It often includes "labels" and "scores".
            gt_instances (InstanceData): Ground truth of instance
                annotations. It usually includes "labels".

        Returns:
            Tensor: Match cost matrix of shape (num_preds, num_gts).
        """
        pass


@TASK_UTILS.register_module()
class ClassificationCost(BaseMatchCost):
    """ClsSoftmaxCost.

    Args:
        weight (Union[float, int]): Cost weight. Defaults to 1.

    Examples:
        >>> import torch
        >>> from mmengine.structures import InstanceData
        >>> from mmseg.models.assigners import ClassificationCost
        >>> self = ClassificationCost()
        >>> pred_instances = InstanceData(scores=torch.rand(4, 3))
        >>> gt_instances = InstanceData(labels=torch.tensor([0, 1, 2]))
        >>> self(pred_instances, gt_instances)
        tensor([[-0.3430, -0.3525, -0.3045],
                [-0.3077, -0.2931, -0.3992],
                [-0.3664, -0.3455, -0.2881],
                [-0.3343, -0.2701, -0.3956]])
    """

    def __init__(self, weight: Union[float, int] = 1) -> None:
        super().__init__(weight=weight)

    def __call__(self, pred_instances: InstanceData,
                 gt_instances: InstanceData, **kwargs) -> Tensor:
        """Compute match cost.

        Args:
            pred_instances (InstanceData): "scores" inside is
                predicted classification logits, of shape
                (num_queries, num_class).
            gt_instances (InstanceData): "labels" inside should have
                shape (num_gt, ).

        Returns:
            Tensor: Match cost matrix of shape (num_preds, num_gts).
        """
        assert hasattr(pred_instances, 'scores'), \
            "pred_instances must contain 'scores'"
        assert hasattr(gt_instances, 'labels'), \
            "gt_instances must contain 'labels'"
        pred_scores = pred_instances.scores
        gt_labels = gt_instances.labels

        pred_scores = pred_scores.softmax(-1)
        cls_cost = -pred_scores[:, gt_labels]

        return cls_cost * self.weight


@TASK_UTILS.register_module()
class DiceCost(BaseMatchCost):
    """Cost of mask assignments based on dice losses.

    Args:
        pred_act (bool): Whether to apply sigmoid to mask_pred.
            Defaults to False.
        eps (float): Defaults to 1e-3.
        naive_dice (bool): If True, use the naive dice loss
            in which the terms in the denominator are raised to
            the first power. If False, use the second power as
            adopted by K-Net and SOLO. Defaults to True.
        weight (Union[float, int]): Cost weight. Defaults to 1.
    """

    def __init__(self,
                 pred_act: bool = False,
                 eps: float = 1e-3,
                 naive_dice: bool = True,
                 weight: Union[float, int] = 1.) -> None:
        super().__init__(weight=weight)
        self.pred_act = pred_act
        self.eps = eps
        self.naive_dice = naive_dice

    def _binary_mask_dice_loss(self, mask_preds: Tensor,
                               gt_masks: Tensor) -> Tensor:
        """
        Args:
            mask_preds (Tensor): Mask prediction in shape (num_queries, *).
            gt_masks (Tensor): Ground truth in shape (num_gt, *), storing
                0 or 1; 0 for the negative class and 1 for the
                positive class.

        Returns:
            Tensor: Dice cost matrix in shape (num_queries, num_gt).
        """
        mask_preds = mask_preds.flatten(1)
        gt_masks = gt_masks.flatten(1).float()
        numerator = 2 * torch.einsum('nc,mc->nm', mask_preds, gt_masks)
        if self.naive_dice:
            denominator = mask_preds.sum(-1)[:, None] + \
                gt_masks.sum(-1)[None, :]
        else:
            denominator = mask_preds.pow(2).sum(1)[:, None] + \
                gt_masks.pow(2).sum(1)[None, :]
        loss = 1 - (numerator + self.eps) / (denominator + self.eps)
        return loss

    def __call__(self, pred_instances: InstanceData,
                 gt_instances: InstanceData, **kwargs) -> Tensor:
        """Compute match cost.

        Args:
            pred_instances (InstanceData): Predicted instances which
                must contain "masks".
            gt_instances (InstanceData): Ground truth which must contain
                "masks".

        Returns:
            Tensor: Match cost matrix of shape (num_preds, num_gts).
        """
        assert hasattr(pred_instances, 'masks'), \
            "pred_instances must contain 'masks'"
        assert hasattr(gt_instances, 'masks'), \
            "gt_instances must contain 'masks'"
        pred_masks = pred_instances.masks
        gt_masks = gt_instances.masks

        if self.pred_act:
            pred_masks = pred_masks.sigmoid()
        dice_cost = self._binary_mask_dice_loss(pred_masks, gt_masks)
        return dice_cost * self.weight
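# A quick standalone check (not part of this file) of the pairwise dice cost
# above with the naive (first-power) denominator: a prediction identical to
# its ground truth costs (almost) nothing.
import torch

pred = torch.tensor([[1., 1., 0., 0.]])  # 1 query mask, 4 pixels, in [0, 1]
gt = torch.tensor([[1., 1., 0., 0.]])    # 1 gt mask
eps = 1e-3
numerator = 2 * torch.einsum('nc,mc->nm', pred, gt)        # 2 * intersection
denominator = pred.sum(-1)[:, None] + gt.sum(-1)[None, :]  # |pred| + |gt|
print(1 - (numerator + eps) / (denominator + eps))         # tensor([[0.]])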


@TASK_UTILS.register_module()
class CrossEntropyLossCost(BaseMatchCost):
    """CrossEntropyLossCost.

    Args:
        use_sigmoid (bool): Whether the prediction uses sigmoid
            or softmax. Defaults to True.
        weight (Union[float, int]): Cost weight. Defaults to 1.
    """

    def __init__(self,
                 use_sigmoid: bool = True,
                 weight: Union[float, int] = 1.) -> None:
        super().__init__(weight=weight)
        self.use_sigmoid = use_sigmoid

    def _binary_cross_entropy(self, cls_pred: Tensor,
                              gt_labels: Tensor) -> Tensor:
        """
        Args:
            cls_pred (Tensor): The prediction with shape (num_queries, 1, *)
                or (num_queries, *).
            gt_labels (Tensor): The learning label of prediction with
                shape (num_gt, *).

        Returns:
            Tensor: Cross entropy cost matrix in shape (num_queries, num_gt).
        """
        cls_pred = cls_pred.flatten(1).float()
        gt_labels = gt_labels.flatten(1).float()
        n = cls_pred.shape[1]
        pos = F.binary_cross_entropy_with_logits(
            cls_pred, torch.ones_like(cls_pred), reduction='none')
        neg = F.binary_cross_entropy_with_logits(
            cls_pred, torch.zeros_like(cls_pred), reduction='none')
        cls_cost = torch.einsum('nc,mc->nm', pos, gt_labels) + \
            torch.einsum('nc,mc->nm', neg, 1 - gt_labels)
        cls_cost = cls_cost / n

        return cls_cost

    def __call__(self, pred_instances: InstanceData,
                 gt_instances: InstanceData, **kwargs) -> Tensor:
        """Compute match cost.

        Args:
            pred_instances (:obj:`InstanceData`): Predicted instances which
                must contain ``masks``.
            gt_instances (:obj:`InstanceData`): Ground truth which must
                contain ``masks``.

        Returns:
            Tensor: Match cost matrix of shape (num_preds, num_gts).
        """
        assert hasattr(pred_instances, 'masks'), \
            "pred_instances must contain 'masks'"
        assert hasattr(gt_instances, 'masks'), \
            "gt_instances must contain 'masks'"
        pred_masks = pred_instances.masks
        gt_masks = gt_instances.masks
        if self.use_sigmoid:
            cls_cost = self._binary_cross_entropy(pred_masks, gt_masks)
        else:
            raise NotImplementedError

        return cls_cost * self.weight
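# A standalone sanity check (not part of this file) that the einsum form
# above equals per-pair binary cross-entropy averaged over pixels, computed
# the slow way with an explicit double loop.
import torch
import torch.nn.functional as F

pred = torch.randn(2, 6)                   # 2 query masks (logits), 6 pixels
gt = torch.randint(0, 2, (3, 6)).float()   # 3 binary gt masks

pos = F.binary_cross_entropy_with_logits(
    pred, torch.ones_like(pred), reduction='none')
neg = F.binary_cross_entropy_with_logits(
    pred, torch.zeros_like(pred), reduction='none')
fast = (torch.einsum('nc,mc->nm', pos, gt) +
        torch.einsum('nc,mc->nm', neg, 1 - gt)) / pred.shape[1]

slow = torch.stack([
    F.binary_cross_entropy_with_logits(pred[n], gt[m], reduction='mean')
    for n in range(2) for m in range(3)
]).reshape(2, 3)
assert torch.allclose(fast, slow, atol=1e-6)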
35
finetune/mmseg/models/backbones/__init__.py
Normal file
@@ -0,0 +1,35 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .beit import BEiT
from .bisenetv1 import BiSeNetV1
from .bisenetv2 import BiSeNetV2
from .cgnet import CGNet
from .ddrnet import DDRNet
from .erfnet import ERFNet
from .fast_scnn import FastSCNN
from .hrnet import HRNet
from .icnet import ICNet
from .mae import MAE
from .mit import MixVisionTransformer
from .mobilenet_v2 import MobileNetV2
from .mobilenet_v3 import MobileNetV3
from .mscan import MSCAN
from .pidnet import PIDNet
from .resnest import ResNeSt
from .resnet import ResNet, ResNetV1c, ResNetV1d
from .resnext import ResNeXt
from .stdc import STDCContextPathNet, STDCNet
from .swin import SwinTransformer
from .timm_backbone import TIMMBackbone
from .twins import PCPVT, SVT
from .unet import UNet
from .vit import VisionTransformer
from .vpd import VPD

__all__ = [
    'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
    'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
    'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
    'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT',
    'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet', 'MSCAN',
    'DDRNet', 'VPD'
]
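# A minimal hedged sketch (not part of this file): backbones registered above
# are usually instantiated from config dicts through the mmseg MODELS
# registry rather than imported directly; the config values are illustrative.
import torch

from mmseg.registry import MODELS

backbone = MODELS.build(dict(type='BiSeNetV2', in_channels=3))
feats = backbone(torch.randn(1, 3, 128, 128))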
554
finetune/mmseg/models/backbones/beit.py
Normal file
@@ -0,0 +1,554 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import (constant_init, kaiming_init,
                                        trunc_normal_)
from mmengine.runner.checkpoint import _load_checkpoint
from scipy import interpolate
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.utils import _pair as to_2tuple

from mmseg.registry import MODELS
from ..utils import PatchEmbed
from .vit import TransformerEncoderLayer as VisionTransformerEncoderLayer


class BEiTAttention(BaseModule):
    """Window based multi-head self-attention (W-MSA) module with relative
    position bias.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (tuple[int]): The height and width of the window.
        bias (bool): The option to add learnable bias for q, k, v. If bias
            is True, it will add learnable bias. If bias is 'qv_bias', it
            will only add learnable bias for q, v. If bias is False, it
            will not add bias for q, k, v. Default to 'qv_bias'.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        attn_drop_rate (float): Dropout ratio of attention weight.
            Default: 0.0.
        proj_drop_rate (float): Dropout ratio of output. Default: 0.
        init_cfg (dict | None, optional): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 window_size,
                 bias='qv_bias',
                 qk_scale=None,
                 attn_drop_rate=0.,
                 proj_drop_rate=0.,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg)
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        head_embed_dims = embed_dims // num_heads
        self.bias = bias
        self.scale = qk_scale or head_embed_dims**-0.5

        qkv_bias = bias
        if bias == 'qv_bias':
            self._init_qv_bias()
            qkv_bias = False

        self.window_size = window_size
        self._init_rel_pos_embedding()

        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_rate)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(proj_drop_rate)

    def _init_qv_bias(self):
        self.q_bias = nn.Parameter(torch.zeros(self.embed_dims))
        self.v_bias = nn.Parameter(torch.zeros(self.embed_dims))

    def _init_rel_pos_embedding(self):
        Wh, Ww = self.window_size
        # cls to token & token to cls & cls to cls
        self.num_relative_distance = (2 * Wh - 1) * (2 * Ww - 1) + 3
        # relative_position_bias_table shape is (2*Wh-1 * 2*Ww-1 + 3, nH)
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, self.num_heads))

        # get pair-wise relative position index for
        # each token inside the window
        coords_h = torch.arange(Wh)
        coords_w = torch.arange(Ww)
        # coords shape is (2, Wh, Ww)
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
        # coords_flatten shape is (2, Wh*Ww)
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = (
            coords_flatten[:, :, None] - coords_flatten[:, None, :])
        # relative_coords shape is (Wh*Ww, Wh*Ww, 2)
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        # shift to start from 0
        relative_coords[:, :, 0] += Wh - 1
        relative_coords[:, :, 1] += Ww - 1
        relative_coords[:, :, 0] *= 2 * Ww - 1
        relative_position_index = torch.zeros(
            size=(Wh * Ww + 1, ) * 2, dtype=relative_coords.dtype)
        # relative_position_index shape is (Wh*Ww+1, Wh*Ww+1)
        relative_position_index[1:, 1:] = relative_coords.sum(-1)
        relative_position_index[0, 0:] = self.num_relative_distance - 3
        relative_position_index[0:, 0] = self.num_relative_distance - 2
        relative_position_index[0, 0] = self.num_relative_distance - 1

        self.register_buffer('relative_position_index',
                             relative_position_index)

    def init_weights(self):
        trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x):
        """
        Args:
            x (tensor): input features with shape of (num_windows*B, N, C).
        """
        B, N, C = x.shape

        if self.bias == 'qv_bias':
            k_bias = torch.zeros_like(self.v_bias, requires_grad=False)
            qkv_bias = torch.cat((self.q_bias, k_bias, self.v_bias))
            qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        else:
            qkv = self.qkv(x)

        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        if self.relative_position_bias_table is not None:
            Wh = self.window_size[0]
            Ww = self.window_size[1]
            relative_position_bias = self.relative_position_bias_table[
                self.relative_position_index.view(-1)].view(
                    Wh * Ww + 1, Wh * Ww + 1, -1)
            relative_position_bias = relative_position_bias.permute(
                2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            attn = attn + relative_position_bias.unsqueeze(0)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
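# A small standalone sketch (not part of this file) of the bias-table
# bookkeeping above for a 2x2 window: (2*Wh-1)*(2*Ww-1) relative offsets
# between patch tokens, plus 3 reserved entries for cls-to-token,
# token-to-cls and cls-to-cls interactions.
Wh, Ww = 2, 2
num_relative_distance = (2 * Wh - 1) * (2 * Ww - 1) + 3  # 9 + 3 = 12 rows
index_entries = (Wh * Ww + 1) ** 2  # (N + 1)^2 = 25 lookups into the table
print(num_relative_distance, index_entries)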


class BEiTTransformerEncoderLayer(VisionTransformerEncoderLayer):
    """Implements one encoder layer in BEiT.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default: 0.0.
        drop_path_rate (float): Stochastic depth rate. Default 0.0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default: 2.
        bias (bool): The option to add learnable bias for q, k, v. If bias
            is True, it will add learnable bias. If bias is 'qv_bias', it
            will only add learnable bias for q, v. If bias is False, it
            will not add bias for q, k, v. Default to 'qv_bias'.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        window_size (tuple[int], optional): The height and width of the
            window. Default: None.
        init_values (float, optional): Initialize the values of BEiTAttention
            and FFN with learnable scaling. Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 bias='qv_bias',
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 window_size=None,
                 attn_cfg=dict(),
                 ffn_cfg=dict(add_identity=False),
                 init_values=None):
        attn_cfg.update(dict(window_size=window_size, qk_scale=None))

        super().__init__(
            embed_dims=embed_dims,
            num_heads=num_heads,
            feedforward_channels=feedforward_channels,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=0.,
            drop_rate=0.,
            num_fcs=num_fcs,
            qkv_bias=bias,
            act_cfg=act_cfg,
            norm_cfg=norm_cfg,
            attn_cfg=attn_cfg,
            ffn_cfg=ffn_cfg)

        # NOTE: drop path for stochastic depth; we shall see if
        # this is better than dropout here
        dropout_layer = dict(type='DropPath', drop_prob=drop_path_rate)
        self.drop_path = build_dropout(
            dropout_layer) if dropout_layer else nn.Identity()
        self.gamma_1 = nn.Parameter(
            init_values * torch.ones(embed_dims), requires_grad=True)
        self.gamma_2 = nn.Parameter(
            init_values * torch.ones(embed_dims), requires_grad=True)

    def build_attn(self, attn_cfg):
        self.attn = BEiTAttention(**attn_cfg)

    def forward(self, x):
        x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
        x = x + self.drop_path(self.gamma_2 * self.ffn(self.norm2(x)))
        return x


@MODELS.register_module()
class BEiT(BaseModule):
    """BERT Pre-Training of Image Transformers.

    Args:
        img_size (int | tuple): Input image size. Default: 224.
        patch_size (int): The patch size. Default: 16.
        in_channels (int): Number of input channels. Default: 3.
        embed_dims (int): Embedding dimension. Default: 768.
        num_layers (int): Depth of transformer. Default: 12.
        num_heads (int): Number of attention heads. Default: 12.
        mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
            Default: 4.
        out_indices (list | tuple | int): Output from which stages.
            Default: -1.
        qv_bias (bool): Enable bias for qv if True. Default: True.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default 0.0.
        drop_path_rate (float): Stochastic depth rate. Default 0.0.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        patch_norm (bool): Whether to add a norm in PatchEmbed Block.
            Default: False.
        final_norm (bool): Whether to add an additional layer to normalize
            final feature map. Default: False.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default: 2.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        pretrained (str, optional): Model pretrained path. Default: None.
        init_values (float): Initialize the values of BEiTAttention and FFN
            with learnable scaling.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 embed_dims=768,
                 num_layers=12,
                 num_heads=12,
                 mlp_ratio=4,
                 out_indices=-1,
                 qv_bias=True,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_cfg=dict(type='LN'),
                 act_cfg=dict(type='GELU'),
                 patch_norm=False,
                 final_norm=False,
                 num_fcs=2,
                 norm_eval=False,
                 pretrained=None,
                 init_values=0.1,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        if isinstance(img_size, int):
            img_size = to_2tuple(img_size)
        elif isinstance(img_size, tuple):
            if len(img_size) == 1:
                img_size = to_2tuple(img_size[0])
            assert len(img_size) == 2, \
                f'The size of image should have length 1 or 2, ' \
                f'but got {len(img_size)}'

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is not None:
            raise TypeError('pretrained must be a str or None')

        self.in_channels = in_channels
        self.img_size = img_size
        self.patch_size = patch_size
        self.norm_eval = norm_eval
        self.pretrained = pretrained
        self.num_layers = num_layers
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.attn_drop_rate = attn_drop_rate
        self.drop_path_rate = drop_path_rate
        self.num_fcs = num_fcs
        self.qv_bias = qv_bias
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.patch_norm = patch_norm
        self.init_values = init_values
        self.window_size = (img_size[0] // patch_size,
                            img_size[1] // patch_size)
        self.patch_shape = self.window_size
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))

        self._build_patch_embedding()
        self._build_layers()

        if isinstance(out_indices, int):
            if out_indices == -1:
                out_indices = num_layers - 1
            self.out_indices = [out_indices]
        elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
            self.out_indices = out_indices
        else:
            raise TypeError('out_indices must be type of int, list or tuple')

        self.final_norm = final_norm
        if final_norm:
            self.norm1_name, norm1 = build_norm_layer(
                norm_cfg, embed_dims, postfix=1)
            self.add_module(self.norm1_name, norm1)

    def _build_patch_embedding(self):
        """Build patch embedding layer."""
        self.patch_embed = PatchEmbed(
            in_channels=self.in_channels,
            embed_dims=self.embed_dims,
            conv_type='Conv2d',
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding=0,
            norm_cfg=self.norm_cfg if self.patch_norm else None,
            init_cfg=None)

    def _build_layers(self):
        """Build transformer encoding layers."""

        dpr = [
            x.item()
            for x in torch.linspace(0, self.drop_path_rate, self.num_layers)
        ]
        self.layers = ModuleList()
        for i in range(self.num_layers):
            self.layers.append(
                BEiTTransformerEncoderLayer(
                    embed_dims=self.embed_dims,
                    num_heads=self.num_heads,
                    feedforward_channels=self.mlp_ratio * self.embed_dims,
                    attn_drop_rate=self.attn_drop_rate,
                    drop_path_rate=dpr[i],
                    num_fcs=self.num_fcs,
                    bias='qv_bias' if self.qv_bias else False,
                    act_cfg=self.act_cfg,
                    norm_cfg=self.norm_cfg,
                    window_size=self.window_size,
                    init_values=self.init_values))

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    def _geometric_sequence_interpolation(self, src_size, dst_size, sequence,
                                          num):
        """Get a new sequence via geometric sequence interpolation.

        Args:
            src_size (int): Pos_embedding size in pre-trained model.
            dst_size (int): Pos_embedding size in the current model.
            sequence (tensor): The relative position bias of the pretrained
                model after removing the extra tokens.
            num (int): Number of attention heads.
        Returns:
            new_sequence (tensor): The pre-trained relative position bias,
                interpolated by geometric sequence to the size of the
                current model.
        """

        def geometric_progression(a, r, n):
            return a * (1.0 - r**n) / (1.0 - r)

        # Binary search for the progression ratio q.
        left, right = 1.01, 1.5
        while right - left > 1e-6:
            q = (left + right) / 2.0
            gp = geometric_progression(1, q, src_size // 2)
            if gp > dst_size // 2:
                right = q
            else:
                left = q
        # The position of each interpolated point is determined
        # by the ratio obtained from the bisection above.
        dis = []
        cur = 1
        for i in range(src_size // 2):
            dis.append(cur)
            cur += q**(i + 1)
        r_ids = [-_ for _ in reversed(dis)]
        x = r_ids + [0] + dis
        y = r_ids + [0] + dis
        t = dst_size // 2.0
        dx = np.arange(-t, t + 0.1, 1.0)
        dy = np.arange(-t, t + 0.1, 1.0)
        # Interpolation functions are being executed and called.
        new_sequence = []
        for i in range(num):
            z = sequence[:, i].view(src_size, src_size).float().numpy()
            f = interpolate.interp2d(x, y, z, kind='cubic')
            new_sequence.append(
                torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(sequence))
        new_sequence = torch.cat(new_sequence, dim=-1)
        return new_sequence

    def resize_rel_pos_embed(self, checkpoint):
        """Resize relative pos_embed weights.

        This function is modified from
        https://github.com/microsoft/unilm/blob/master/beit/semantic_segmentation/mmcv_custom/checkpoint.py.  # noqa: E501
        Copyright (c) Microsoft Corporation
        Licensed under the MIT License

        Args:
            checkpoint (dict): Key and value of the pretrained model.
        Returns:
            state_dict (dict): Interpolate the relative pos_embed weights
                in the pre-trained model to the current model size.
        """
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        all_keys = list(state_dict.keys())
        for key in all_keys:
            if 'relative_position_index' in key:
                state_dict.pop(key)
            # In order to keep the center of pos_bias as consistent as
            # possible after interpolation, and vice versa in the edge
            # area, the geometric sequence interpolation method is adopted.
            if 'relative_position_bias_table' in key:
                rel_pos_bias = state_dict[key]
                src_num_pos, num_attn_heads = rel_pos_bias.size()
                dst_num_pos, _ = self.state_dict()[key].size()
                dst_patch_shape = self.patch_shape
                if dst_patch_shape[0] != dst_patch_shape[1]:
                    raise NotImplementedError()
                # Count the number of extra tokens.
                num_extra_tokens = dst_num_pos - (
                    dst_patch_shape[0] * 2 - 1) * (
                        dst_patch_shape[1] * 2 - 1)
                src_size = int((src_num_pos - num_extra_tokens)**0.5)
                dst_size = int((dst_num_pos - num_extra_tokens)**0.5)
                if src_size != dst_size:
                    extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
                    rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
                    new_rel_pos_bias = self._geometric_sequence_interpolation(
                        src_size, dst_size, rel_pos_bias, num_attn_heads)
                    new_rel_pos_bias = torch.cat(
                        (new_rel_pos_bias, extra_tokens), dim=0)
                    state_dict[key] = new_rel_pos_bias

        return state_dict

    def init_weights(self):

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        self.apply(_init_weights)

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg.get('type') == 'Pretrained'):
            checkpoint = _load_checkpoint(
                self.init_cfg['checkpoint'], logger=None, map_location='cpu')
            state_dict = self.resize_rel_pos_embed(checkpoint)
            self.load_state_dict(state_dict, False)
        elif self.init_cfg is not None:
            super().init_weights()
        else:
            # We only implement the 'jax_impl' initialization implemented at
            # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
            # Copyright 2019 Ross Wightman
            # Licensed under the Apache License, Version 2.0 (the "License")
            trunc_normal_(self.cls_token, std=.02)
            for n, m in self.named_modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_(m.weight, std=.02)
                    if m.bias is not None:
                        if 'ffn' in n:
                            nn.init.normal_(m.bias, mean=0., std=1e-6)
                        else:
                            nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Conv2d):
                    kaiming_init(m, mode='fan_in', bias=0.)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                    constant_init(m, val=1.0, bias=0.)

    def forward(self, inputs):
        B = inputs.shape[0]

        x, hw_shape = self.patch_embed(inputs)

        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i == len(self.layers) - 1:
                if self.final_norm:
                    x = self.norm1(x)
            if i in self.out_indices:
                # Remove class token and reshape token for decoder head
                out = x[:, 1:]
                B, _, C = out.shape
                out = out.reshape(B, hw_shape[0], hw_shape[1],
                                  C).permute(0, 3, 1, 2).contiguous()
                outs.append(out)

        return tuple(outs)

    def train(self, mode=True):
        super().train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, nn.LayerNorm):
                    m.eval()
332
finetune/mmseg/models/backbones/bisenetv1.py
Normal file
@@ -0,0 +1,332 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule

from mmseg.registry import MODELS
from ..utils import resize


class SpatialPath(BaseModule):
    """Spatial Path to preserve the spatial size of the original input image
    and encode affluent spatial information.

    Args:
        in_channels (int): The number of channels of input
            image. Default: 3.
        num_channels (Tuple[int]): The number of channels of
            each layer in Spatial Path.
            Default: (64, 64, 64, 128).
    Returns:
        x (torch.Tensor): Feature map for Feature Fusion Module.
    """

    def __init__(self,
                 in_channels=3,
                 num_channels=(64, 64, 64, 128),
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert len(num_channels) == 4, 'Length of input channels \
            of Spatial Path must be 4!'

        self.layers = []
        for i in range(len(num_channels)):
            layer_name = f'layer{i + 1}'
            self.layers.append(layer_name)
            if i == 0:
                self.add_module(
                    layer_name,
                    ConvModule(
                        in_channels=in_channels,
                        out_channels=num_channels[i],
                        kernel_size=7,
                        stride=2,
                        padding=3,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg))
            elif i == len(num_channels) - 1:
                self.add_module(
                    layer_name,
                    ConvModule(
                        in_channels=num_channels[i - 1],
                        out_channels=num_channels[i],
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg))
            else:
                self.add_module(
                    layer_name,
                    ConvModule(
                        in_channels=num_channels[i - 1],
                        out_channels=num_channels[i],
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg))

    def forward(self, x):
        for i, layer_name in enumerate(self.layers):
            layer_stage = getattr(self, layer_name)
            x = layer_stage(x)
        return x


class AttentionRefinementModule(BaseModule):
    """Attention Refinement Module (ARM) to refine the features of each stage.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
    Returns:
        x_out (torch.Tensor): Feature map of Attention Refinement Module.
    """

    def __init__(self,
                 in_channels,
                 out_channel,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.conv_layer = ConvModule(
            in_channels=in_channels,
            out_channels=out_channel,
            kernel_size=3,
            stride=1,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.atten_conv_layer = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            ConvModule(
                in_channels=out_channel,
                out_channels=out_channel,
                kernel_size=1,
                bias=False,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None), nn.Sigmoid())

    def forward(self, x):
        x = self.conv_layer(x)
        x_atten = self.atten_conv_layer(x)
        x_out = x * x_atten
        return x_out


class ContextPath(BaseModule):
    """Context Path to provide sufficient receptive field.

    Args:
        backbone_cfg (dict): Config of backbone of
            Context Path.
        context_channels (Tuple[int]): The number of channel numbers
            of various modules in Context Path.
            Default: (128, 256, 512).
        align_corners (bool, optional): The align_corners argument of
            resize operation. Default: False.
    Returns:
        x_16_up, x_32_up (torch.Tensor, torch.Tensor): Two feature maps
            undergoing upsampling from 1/16 and 1/32 downsampling
            feature maps. These two feature maps are used for Feature
            Fusion Module and Auxiliary Head.
    """

    def __init__(self,
                 backbone_cfg,
                 context_channels=(128, 256, 512),
                 align_corners=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert len(context_channels) == 3, 'Length of input channels \
            of Context Path must be 3!'

        self.backbone = MODELS.build(backbone_cfg)

        self.align_corners = align_corners
        self.arm16 = AttentionRefinementModule(context_channels[1],
                                               context_channels[0])
        self.arm32 = AttentionRefinementModule(context_channels[2],
                                               context_channels[0])
        self.conv_head32 = ConvModule(
            in_channels=context_channels[0],
            out_channels=context_channels[0],
            kernel_size=3,
            stride=1,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.conv_head16 = ConvModule(
            in_channels=context_channels[0],
            out_channels=context_channels[0],
            kernel_size=3,
            stride=1,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.gap_conv = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            ConvModule(
                in_channels=context_channels[2],
                out_channels=context_channels[0],
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg))

    def forward(self, x):
        x_4, x_8, x_16, x_32 = self.backbone(x)
        x_gap = self.gap_conv(x_32)

        x_32_arm = self.arm32(x_32)
        x_32_sum = x_32_arm + x_gap
        x_32_up = resize(input=x_32_sum, size=x_16.shape[2:], mode='nearest')
        x_32_up = self.conv_head32(x_32_up)

        x_16_arm = self.arm16(x_16)
        x_16_sum = x_16_arm + x_32_up
        x_16_up = resize(input=x_16_sum, size=x_8.shape[2:], mode='nearest')
        x_16_up = self.conv_head16(x_16_up)

        return x_16_up, x_32_up


class FeatureFusionModule(BaseModule):
    """Feature Fusion Module to fuse low level output feature of Spatial Path
    and high level output feature of Context Path.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
    Returns:
        x_out (torch.Tensor): Feature map of Feature Fusion Module.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.conv1 = ConvModule(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.conv_atten = nn.Sequential(
            ConvModule(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg), nn.Sigmoid())

    def forward(self, x_sp, x_cp):
        x_concat = torch.cat([x_sp, x_cp], dim=1)
        x_fuse = self.conv1(x_concat)
        x_atten = self.gap(x_fuse)
        # Note: No BN and more 1x1 conv in paper.
        x_atten = self.conv_atten(x_atten)
        x_atten = x_fuse * x_atten
        x_out = x_atten + x_fuse
        return x_out


@MODELS.register_module()
class BiSeNetV1(BaseModule):
    """BiSeNetV1 backbone.

    This backbone is the implementation of `BiSeNet: Bilateral
    Segmentation Network for Real-time Semantic
    Segmentation <https://arxiv.org/abs/1808.00897>`_.

    Args:
        backbone_cfg (dict): Config of backbone of
            Context Path.
        in_channels (int): The number of channels of input
            image. Default: 3.
        spatial_channels (Tuple[int]): Size of channel numbers of
            various layers in Spatial Path.
            Default: (64, 64, 64, 128).
        context_channels (Tuple[int]): Size of channel numbers of
            various modules in Context Path.
            Default: (128, 256, 512).
        out_indices (Tuple[int] | int, optional): Output from which stages.
            Default: (0, 1, 2).
        align_corners (bool, optional): The align_corners argument of
            resize operation in Bilateral Guided Aggregation Layer.
            Default: False.
        out_channels (int): The number of channels of output.
            It must be the same with `in_channels` of decode_head.
            Default: 256.
    """

    def __init__(self,
                 backbone_cfg,
                 in_channels=3,
                 spatial_channels=(64, 64, 64, 128),
                 context_channels=(128, 256, 512),
                 out_indices=(0, 1, 2),
                 align_corners=False,
                 out_channels=256,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):

        super().__init__(init_cfg=init_cfg)
        assert len(spatial_channels) == 4, 'Length of input channels \
            of Spatial Path must be 4!'

        assert len(context_channels) == 3, 'Length of input channels \
            of Context Path must be 3!'

        self.out_indices = out_indices
        self.align_corners = align_corners
        self.context_path = ContextPath(backbone_cfg, context_channels,
                                        self.align_corners)
        self.spatial_path = SpatialPath(in_channels, spatial_channels)
        self.ffm = FeatureFusionModule(context_channels[1], out_channels)
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

    def forward(self, x):
        # stole refactoring code from Coin Cheung, thanks
        x_context8, x_context16 = self.context_path(x)
        x_spatial = self.spatial_path(x)
        x_fuse = self.ffm(x_spatial, x_context8)

        outs = [x_fuse, x_context8, x_context16]
        outs = [outs[i] for i in self.out_indices]
        return tuple(outs)
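# A standalone sketch (not part of this file) of the channel-attention
# pattern shared by the ARM and FFM above: global average pooling yields
# per-channel statistics, a 1x1 conv plus sigmoid turns them into gates in
# (0, 1), and the gates rescale the feature map channel-wise.
import torch
import torch.nn as nn

feat = torch.randn(1, 128, 32, 32)
gate = nn.Sequential(
    nn.AdaptiveAvgPool2d((1, 1)),       # (1, 128, 1, 1) channel statistics
    nn.Conv2d(128, 128, 1, bias=False),
    nn.Sigmoid())                       # squash to (0, 1) attention weights
out = feat * gate(feat)                 # broadcast over the spatial dims
print(out.shape)                        # torch.Size([1, 128, 32, 32])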
622
finetune/mmseg/models/backbones/bisenetv2.py
Normal file
@@ -0,0 +1,622 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule,
|
||||
build_activation_layer, build_norm_layer)
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
class DetailBranch(BaseModule):
|
||||
"""Detail Branch with wide channels and shallow layers to capture low-level
|
||||
details and generate high-resolution feature representation.
|
||||
|
||||
Args:
|
||||
detail_channels (Tuple[int]): Size of channel numbers of each stage
|
||||
in Detail Branch, in paper it has 3 stages.
|
||||
Default: (64, 64, 128).
|
||||
in_channels (int): Number of channels of input image. Default: 3.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
Returns:
|
||||
x (torch.Tensor): Feature map of Detail Branch.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
detail_channels=(64, 64, 128),
|
||||
in_channels=3,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
detail_branch = []
|
||||
for i in range(len(detail_channels)):
|
||||
if i == 0:
|
||||
detail_branch.append(
|
||||
nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=detail_channels[i],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg),
|
||||
ConvModule(
|
||||
in_channels=detail_channels[i],
|
||||
out_channels=detail_channels[i],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)))
|
||||
else:
|
||||
detail_branch.append(
|
||||
nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=detail_channels[i - 1],
|
||||
out_channels=detail_channels[i],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg),
|
||||
ConvModule(
|
||||
in_channels=detail_channels[i],
|
||||
out_channels=detail_channels[i],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg),
|
||||
ConvModule(
|
||||
in_channels=detail_channels[i],
|
||||
out_channels=detail_channels[i],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)))
|
||||
self.detail_branch = nn.ModuleList(detail_branch)
|
||||
|
||||
def forward(self, x):
|
||||
for stage in self.detail_branch:
|
||||
x = stage(x)
|
||||
return x
|
||||
|
||||
|
||||
class StemBlock(BaseModule):
|
||||
"""Stem Block at the beginning of Semantic Branch.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
Default: 3.
|
||||
out_channels (int): Number of output channels.
|
||||
Default: 16.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
Returns:
|
||||
x (torch.Tensor): First feature map in Semantic Branch.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
out_channels=16,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
self.conv_first = ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
self.convs = nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels // 2,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg),
|
||||
ConvModule(
|
||||
in_channels=out_channels // 2,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
self.pool = nn.MaxPool2d(
|
||||
kernel_size=3, stride=2, padding=1, ceil_mode=False)
|
||||
self.fuse_last = ConvModule(
|
||||
in_channels=out_channels * 2,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_first(x)
|
||||
x_left = self.convs(x)
|
||||
x_right = self.pool(x)
|
||||
x = self.fuse_last(torch.cat([x_left, x_right], dim=1))
|
||||
return x
|
||||
|
||||
|
||||
class GELayer(BaseModule):
|
||||
"""Gather-and-Expansion Layer.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
exp_ratio (int): Expansion ratio for middle channels.
|
||||
Default: 6.
|
||||
stride (int): Stride of GELayer. Default: 1
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
Returns:
|
||||
x (torch.Tensor): Intermediate feature map in
|
||||
Semantic Branch.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
exp_ratio=6,
|
||||
stride=1,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
mid_channel = in_channels * exp_ratio
|
||||
self.conv1 = ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
if stride == 1:
|
||||
self.dwconv = nn.Sequential(
|
||||
# ReLU in ConvModule not shown in paper
|
||||
ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=mid_channel,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
groups=in_channels,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
self.shortcut = None
|
||||
else:
|
||||
self.dwconv = nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=mid_channel,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
groups=in_channels,
|
||||
bias=False,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None),
|
||||
# ReLU in ConvModule not shown in paper
|
||||
ConvModule(
|
||||
in_channels=mid_channel,
|
||||
out_channels=mid_channel,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=mid_channel,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg),
|
||||
)
|
||||
self.shortcut = nn.Sequential(
|
||||
DepthwiseSeparableConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
dw_norm_cfg=norm_cfg,
|
||||
dw_act_cfg=None,
|
||||
pw_norm_cfg=norm_cfg,
|
||||
pw_act_cfg=None,
|
||||
))
|
||||
|
||||
self.conv2 = nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=mid_channel,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=False,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None,
|
||||
))
|
||||
|
||||
self.act = build_activation_layer(act_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
x = self.conv1(x)
|
||||
x = self.dwconv(x)
|
||||
x = self.conv2(x)
|
||||
if self.shortcut is not None:
|
||||
shortcut = self.shortcut(identity)
|
||||
x = x + shortcut
|
||||
else:
|
||||
x = x + identity
|
||||
x = self.act(x)
|
||||
return x
|
||||
|
||||
|
||||
class CEBlock(BaseModule):
|
||||
"""Context Embedding Block for large receptive filed in Semantic Branch.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
Default: 3.
|
||||
out_channels (int): Number of output channels.
|
||||
Default: 16.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
Returns:
|
||||
x (torch.Tensor): Last feature map in Semantic Branch.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
out_channels=16,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.gap = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d((1, 1)),
|
||||
build_norm_layer(norm_cfg, self.in_channels)[1])
|
||||
self.conv_gap = ConvModule(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
# Note: in paper here is naive conv2d, no bn-relu
|
||||
self.conv_last = ConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
x = self.gap(x)
|
||||
x = self.conv_gap(x)
|
||||
x = identity + x
|
||||
x = self.conv_last(x)
|
||||
return x
|
||||
|
||||
|
||||
class SemanticBranch(BaseModule):
|
||||
"""Semantic Branch which is lightweight with narrow channels and deep
|
||||
layers to obtain high-level semantic context.
|
||||
|
||||
Args:
|
||||
semantic_channels(Tuple[int]): Size of channel numbers of
|
||||
various stages in Semantic Branch.
|
||||
Default: (16, 32, 64, 128).
|
||||
in_channels (int): Number of channels of input image. Default: 3.
|
||||
exp_ratio (int): Expansion ratio for middle channels.
|
||||
Default: 6.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
Returns:
|
||||
semantic_outs (List[torch.Tensor]): List of several feature maps
|
||||
for auxiliary heads (Booster) and Bilateral
|
||||
Guided Aggregation Layer.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
semantic_channels=(16, 32, 64, 128),
|
||||
in_channels=3,
|
||||
exp_ratio=6,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.in_channels = in_channels
|
||||
self.semantic_channels = semantic_channels
|
||||
self.semantic_stages = []
|
||||
for i in range(len(semantic_channels)):
|
||||
stage_name = f'stage{i + 1}'
|
||||
self.semantic_stages.append(stage_name)
|
||||
if i == 0:
|
||||
self.add_module(
|
||||
stage_name,
|
||||
StemBlock(self.in_channels, semantic_channels[i]))
|
||||
elif i == (len(semantic_channels) - 1):
|
||||
self.add_module(
|
||||
stage_name,
|
||||
nn.Sequential(
|
||||
GELayer(semantic_channels[i - 1], semantic_channels[i],
|
||||
exp_ratio, 2),
|
||||
GELayer(semantic_channels[i], semantic_channels[i],
|
||||
exp_ratio, 1),
|
||||
GELayer(semantic_channels[i], semantic_channels[i],
|
||||
exp_ratio, 1),
|
||||
GELayer(semantic_channels[i], semantic_channels[i],
|
||||
exp_ratio, 1)))
|
||||
else:
|
||||
self.add_module(
|
||||
stage_name,
|
||||
nn.Sequential(
|
||||
GELayer(semantic_channels[i - 1], semantic_channels[i],
|
||||
exp_ratio, 2),
|
||||
GELayer(semantic_channels[i], semantic_channels[i],
|
||||
exp_ratio, 1)))
|
||||
|
||||
self.add_module(f'stage{len(semantic_channels)}_CEBlock',
|
||||
CEBlock(semantic_channels[-1], semantic_channels[-1]))
|
||||
self.semantic_stages.append(f'stage{len(semantic_channels)}_CEBlock')
|
||||
|
||||
def forward(self, x):
|
||||
semantic_outs = []
|
||||
for stage_name in self.semantic_stages:
|
||||
semantic_stage = getattr(self, stage_name)
|
||||
x = semantic_stage(x)
|
||||
semantic_outs.append(x)
|
||||
return semantic_outs
|
||||
|
||||
|
||||
class BGALayer(BaseModule):
|
||||
"""Bilateral Guided Aggregation Layer to fuse the complementary information
|
||||
from both Detail Branch and Semantic Branch.
|
||||
|
||||
Args:
|
||||
out_channels (int): Number of output channels.
|
||||
Default: 128.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
Returns:
|
||||
output (torch.Tensor): Output feature map for Segment heads.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
out_channels=128,
|
||||
align_corners=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.out_channels = out_channels
|
||||
self.align_corners = align_corners
|
||||
self.detail_dwconv = nn.Sequential(
|
||||
DepthwiseSeparableConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dw_norm_cfg=norm_cfg,
|
||||
dw_act_cfg=None,
|
||||
pw_norm_cfg=None,
|
||||
pw_act_cfg=None,
|
||||
))
|
||||
self.detail_down = nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
bias=False,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None),
|
||||
nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False))
|
||||
self.semantic_conv = nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=False,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None))
|
||||
self.semantic_dwconv = nn.Sequential(
|
||||
DepthwiseSeparableConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dw_norm_cfg=norm_cfg,
|
||||
dw_act_cfg=None,
|
||||
pw_norm_cfg=None,
|
||||
pw_act_cfg=None,
|
||||
))
|
||||
self.conv = ConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
inplace=True,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
)
|
||||
|
||||
def forward(self, x_d, x_s):
|
||||
detail_dwconv = self.detail_dwconv(x_d)
|
||||
detail_down = self.detail_down(x_d)
|
||||
semantic_conv = self.semantic_conv(x_s)
|
||||
semantic_dwconv = self.semantic_dwconv(x_s)
|
||||
semantic_conv = resize(
|
||||
input=semantic_conv,
|
||||
size=detail_dwconv.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
fuse_1 = detail_dwconv * torch.sigmoid(semantic_conv)
|
||||
fuse_2 = detail_down * torch.sigmoid(semantic_dwconv)
|
||||
fuse_2 = resize(
|
||||
input=fuse_2,
|
||||
size=fuse_1.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
output = self.conv(fuse_1 + fuse_2)
|
||||
return output
|
||||
|
||||
|
||||
@MODELS.register_module()
class BiSeNetV2(BaseModule):
    """BiSeNetV2: Bilateral Network with Guided Aggregation for
    Real-time Semantic Segmentation.

    This backbone is the implementation of
    `BiSeNetV2 <https://arxiv.org/abs/2004.02147>`_.

    Args:
        in_channels (int): Number of channels of the input image. Default: 3.
        detail_channels (Tuple[int], optional): Channels of each stage
            in Detail Branch. Default: (64, 64, 128).
        semantic_channels (Tuple[int], optional): Channels of each stage
            in Semantic Branch. Default: (16, 32, 64, 128).
            See Table 1 and Figure 3 of the paper for more details.
        semantic_expansion_ratio (int, optional): The expansion factor
            expanding channel number of middle channels in Semantic Branch.
            Default: 6.
        bga_channels (int, optional): Number of middle channels in
            Bilateral Guided Aggregation Layer. Default: 128.
        out_indices (Tuple[int] | int, optional): Output from which stages.
            Default: (0, 1, 2, 3, 4).
        align_corners (bool, optional): The align_corners argument of
            resize operation in Bilateral Guided Aggregation Layer.
            Default: False.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels=3,
                 detail_channels=(64, 64, 128),
                 semantic_channels=(16, 32, 64, 128),
                 semantic_expansion_ratio=6,
                 bga_channels=128,
                 out_indices=(0, 1, 2, 3, 4),
                 align_corners=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        if init_cfg is None:
            init_cfg = [
                dict(type='Kaiming', layer='Conv2d'),
                dict(
                    type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
            ]
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_indices = out_indices
        self.detail_channels = detail_channels
        self.semantic_channels = semantic_channels
        self.semantic_expansion_ratio = semantic_expansion_ratio
        self.bga_channels = bga_channels
        self.align_corners = align_corners
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        self.detail = DetailBranch(self.detail_channels, self.in_channels)
        self.semantic = SemanticBranch(self.semantic_channels,
                                       self.in_channels,
                                       self.semantic_expansion_ratio)
        self.bga = BGALayer(self.bga_channels, self.align_corners)

    def forward(self, x):
        # stole refactoring code from Coin Cheung, thanks
        x_detail = self.detail(x)
        x_semantic_lst = self.semantic(x)
        x_head = self.bga(x_detail, x_semantic_lst[-1])
        outs = [x_head] + x_semantic_lst[:-1]
        outs = [outs[i] for i in self.out_indices]
        return tuple(outs)
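
# Usage sketch (editor's addition, not in the upstream file). With the
# defaults, outs[0] is the BGA output at 1/8 resolution and outs[1:] are the
# Semantic Branch stage features at 1/4, 1/8, 1/16 and 1/32:
#
#   import torch
#   self = BiSeNetV2(in_channels=3)
#   self.eval()
#   outs = self.forward(torch.rand(1, 3, 1024, 2048))
#   # -> (1, 128, 128, 256), (1, 16, 256, 512), (1, 32, 128, 256),
#   #    (1, 64, 64, 128), (1, 128, 32, 64)
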
372
finetune/mmseg/models/backbones/cgnet.py
Normal file
@@ -0,0 +1,372 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer
from mmengine.model import BaseModule
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm

from mmseg.registry import MODELS


class GlobalContextExtractor(nn.Module):
    """Global Context Extractor for CGNet.

    This class is employed to refine the joint feature of both local feature
    and surrounding context.

    Args:
        channel (int): Number of input feature channels.
        reduction (int): Reduction for global context extractor. Default: 16.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self, channel, reduction=16, with_cp=False):
        super().__init__()
        self.channel = channel
        self.reduction = reduction
        assert reduction >= 1 and channel >= reduction
        self.with_cp = with_cp
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel), nn.Sigmoid())

    def forward(self, x):

        def _inner_forward(x):
            num_batch, num_channel = x.size()[:2]
            y = self.avg_pool(x).view(num_batch, num_channel)
            y = self.fc(y).view(num_batch, num_channel, 1, 1)
            return x * y

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out
class ContextGuidedBlock(nn.Module):
    """Context Guided Block for CGNet.

    This class consists of four components: local feature extractor,
    surrounding feature extractor, joint feature extractor and global
    context extractor.

    Args:
        in_channels (int): Number of input feature channels.
        out_channels (int): Number of output feature channels.
        dilation (int): Dilation rate for surrounding context extractor.
            Default: 2.
        reduction (int): Reduction for global context extractor. Default: 16.
        skip_connect (bool): Add input to output or not. Default: True.
        downsample (bool): Downsample the input to 1/2 or not. Default: False.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN', requires_grad=True).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='PReLU').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 dilation=2,
                 reduction=16,
                 skip_connect=True,
                 downsample=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='PReLU'),
                 with_cp=False):
        super().__init__()
        self.with_cp = with_cp
        self.downsample = downsample

        channels = out_channels if downsample else out_channels // 2
        if 'type' in act_cfg and act_cfg['type'] == 'PReLU':
            act_cfg['num_parameters'] = channels
        kernel_size = 3 if downsample else 1
        stride = 2 if downsample else 1
        padding = (kernel_size - 1) // 2

        self.conv1x1 = ConvModule(
            in_channels,
            channels,
            kernel_size,
            stride,
            padding,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        self.f_loc = build_conv_layer(
            conv_cfg,
            channels,
            channels,
            kernel_size=3,
            padding=1,
            groups=channels,
            bias=False)
        self.f_sur = build_conv_layer(
            conv_cfg,
            channels,
            channels,
            kernel_size=3,
            padding=dilation,
            groups=channels,
            dilation=dilation,
            bias=False)

        self.bn = build_norm_layer(norm_cfg, 2 * channels)[1]
        self.activate = nn.PReLU(2 * channels)

        if downsample:
            self.bottleneck = build_conv_layer(
                conv_cfg,
                2 * channels,
                out_channels,
                kernel_size=1,
                bias=False)

        self.skip_connect = skip_connect and not downsample
        self.f_glo = GlobalContextExtractor(out_channels, reduction, with_cp)

    def forward(self, x):

        def _inner_forward(x):
            out = self.conv1x1(x)
            loc = self.f_loc(out)
            sur = self.f_sur(out)

            joi_feat = torch.cat([loc, sur], 1)  # the joint feature
            joi_feat = self.bn(joi_feat)
            joi_feat = self.activate(joi_feat)
            if self.downsample:
                joi_feat = self.bottleneck(joi_feat)  # channel = out_channels
            # f_glo is employed to refine the joint feature
            out = self.f_glo(joi_feat)

            if self.skip_connect:
                return x + out
            else:
                return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out
class InputInjection(nn.Module):
    """Downsampling module for CGNet."""

    def __init__(self, num_downsampling):
        super().__init__()
        self.pool = nn.ModuleList()
        for i in range(num_downsampling):
            self.pool.append(nn.AvgPool2d(3, stride=2, padding=1))

    def forward(self, x):
        for pool in self.pool:
            x = pool(x)
        return x
@MODELS.register_module()
class CGNet(BaseModule):
    """CGNet backbone.

    This backbone is the implementation of `A Light-weight Context Guided
    Network for Semantic Segmentation <https://arxiv.org/abs/1811.08201>`_.

    Args:
        in_channels (int): Number of input image channels. Normally 3.
        num_channels (tuple[int]): Numbers of feature channels at each stage.
            Default: (32, 64, 128).
        num_blocks (tuple[int]): Numbers of CG blocks at stage 1 and stage 2.
            Default: (3, 21).
        dilations (tuple[int]): Dilation rates for surrounding context
            extractors at stage 1 and stage 2. Default: (2, 4).
        reductions (tuple[int]): Reductions for global context extractors at
            stage 1 and stage 2. Default: (8, 16).
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN', requires_grad=True).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='PReLU').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        pretrained (str, optional): Model pretrained path. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels=3,
                 num_channels=(32, 64, 128),
                 num_blocks=(3, 21),
                 dilations=(2, 4),
                 reductions=(8, 16),
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='PReLU'),
                 norm_eval=False,
                 with_cp=False,
                 pretrained=None,
                 init_cfg=None):

        super().__init__(init_cfg)

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is None:
            if init_cfg is None:
                self.init_cfg = [
                    dict(type='Kaiming', layer=['Conv2d', 'Linear']),
                    dict(
                        type='Constant',
                        val=1,
                        layer=['_BatchNorm', 'GroupNorm']),
                    dict(type='Constant', val=0, layer='PReLU')
                ]
        else:
            raise TypeError('pretrained must be a str or None')

        self.in_channels = in_channels
        self.num_channels = num_channels
        assert isinstance(self.num_channels, tuple) and len(
            self.num_channels) == 3
        self.num_blocks = num_blocks
        assert isinstance(self.num_blocks, tuple) and len(self.num_blocks) == 2
        self.dilations = dilations
        assert isinstance(self.dilations, tuple) and len(self.dilations) == 2
        self.reductions = reductions
        assert isinstance(self.reductions, tuple) and len(self.reductions) == 2
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        if 'type' in self.act_cfg and self.act_cfg['type'] == 'PReLU':
            self.act_cfg['num_parameters'] = num_channels[0]
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        cur_channels = in_channels
        self.stem = nn.ModuleList()
        for i in range(3):
            self.stem.append(
                ConvModule(
                    cur_channels,
                    num_channels[0],
                    3,
                    2 if i == 0 else 1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
            cur_channels = num_channels[0]

        self.inject_2x = InputInjection(1)  # down-sample for Input, factor=2
        self.inject_4x = InputInjection(2)  # down-sample for Input, factor=4

        cur_channels += in_channels
        self.norm_prelu_0 = nn.Sequential(
            build_norm_layer(norm_cfg, cur_channels)[1],
            nn.PReLU(cur_channels))

        # stage 1
        self.level1 = nn.ModuleList()
        for i in range(num_blocks[0]):
            self.level1.append(
                ContextGuidedBlock(
                    cur_channels if i == 0 else num_channels[1],
                    num_channels[1],
                    dilations[0],
                    reductions[0],
                    downsample=(i == 0),
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    with_cp=with_cp))  # CG block

        cur_channels = 2 * num_channels[1] + in_channels
        self.norm_prelu_1 = nn.Sequential(
            build_norm_layer(norm_cfg, cur_channels)[1],
            nn.PReLU(cur_channels))

        # stage 2
        self.level2 = nn.ModuleList()
        for i in range(num_blocks[1]):
            self.level2.append(
                ContextGuidedBlock(
                    cur_channels if i == 0 else num_channels[2],
                    num_channels[2],
                    dilations[1],
                    reductions[1],
                    downsample=(i == 0),
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    with_cp=with_cp))  # CG block

        cur_channels = 2 * num_channels[2]
        self.norm_prelu_2 = nn.Sequential(
            build_norm_layer(norm_cfg, cur_channels)[1],
            nn.PReLU(cur_channels))

    def forward(self, x):
        output = []

        # stage 0
        inp_2x = self.inject_2x(x)
        inp_4x = self.inject_4x(x)
        for layer in self.stem:
            x = layer(x)
        x = self.norm_prelu_0(torch.cat([x, inp_2x], 1))
        output.append(x)

        # stage 1
        for i, layer in enumerate(self.level1):
            x = layer(x)
            if i == 0:
                down1 = x
        x = self.norm_prelu_1(torch.cat([x, down1, inp_4x], 1))
        output.append(x)

        # stage 2
        for i, layer in enumerate(self.level2):
            x = layer(x)
            if i == 0:
                down2 = x
        x = self.norm_prelu_2(torch.cat([down2, x], 1))
        output.append(x)

        return output

    def train(self, mode=True):
        """Convert the model into training mode while keeping the
        normalization layers frozen."""
        super().train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval has an effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
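
# Usage sketch (editor's addition, not in the upstream file). CGNet returns
# three stage features at 1/2, 1/4 and 1/8 resolution; their widths follow
# the concatenations in forward(): num_channels[0] + in_channels,
# 2 * num_channels[1] + in_channels and 2 * num_channels[2]:
#
#   import torch
#   self = CGNet(in_channels=3)
#   self.eval()
#   outs = self.forward(torch.rand(1, 3, 64, 64))
#   # -> shapes (1, 35, 32, 32), (1, 131, 16, 16), (1, 256, 8, 8)
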
222
finetune/mmseg/models/backbones/ddrnet.py
Normal file
@@ -0,0 +1,222 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer
from mmengine.model import BaseModule

from mmseg.models.utils import DAPPM, BasicBlock, Bottleneck, resize
from mmseg.registry import MODELS
from mmseg.utils import OptConfigType


@MODELS.register_module()
class DDRNet(BaseModule):
    """DDRNet backbone.

    This backbone is the implementation of `Deep Dual-resolution Networks for
    Real-time and Accurate Semantic Segmentation of Road Scenes
    <http://arxiv.org/abs/2101.06085>`_.
    Modified from https://github.com/ydhongHIT/DDRNet.

    Args:
        in_channels (int): Number of input image channels. Default: 3.
        channels (int): The base channels of DDRNet. Default: 32.
        ppm_channels (int): The channels of PPM module. Default: 128.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        norm_cfg (dict): Config dict to build norm layer.
            Default: dict(type='BN', requires_grad=True).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
        init_cfg (dict, optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels: int = 3,
                 channels: int = 32,
                 ppm_channels: int = 128,
                 align_corners: bool = False,
                 norm_cfg: OptConfigType = dict(type='BN', requires_grad=True),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg)

        self.in_channels = in_channels
        self.ppm_channels = ppm_channels

        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.align_corners = align_corners

        # stage 0-2
        self.stem = self._make_stem_layer(in_channels, channels, num_blocks=2)
        self.relu = nn.ReLU()

        # low resolution (context) branch
        self.context_branch_layers = nn.ModuleList()
        for i in range(3):
            self.context_branch_layers.append(
                self._make_layer(
                    block=BasicBlock if i < 2 else Bottleneck,
                    inplanes=channels * 2**(i + 1),
                    planes=channels * 8 if i > 0 else channels * 4,
                    num_blocks=2 if i < 2 else 1,
                    stride=2))

        # bilateral fusion
        self.compression_1 = ConvModule(
            channels * 4,
            channels * 2,
            kernel_size=1,
            norm_cfg=self.norm_cfg,
            act_cfg=None)
        self.down_1 = ConvModule(
            channels * 2,
            channels * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            norm_cfg=self.norm_cfg,
            act_cfg=None)

        self.compression_2 = ConvModule(
            channels * 8,
            channels * 2,
            kernel_size=1,
            norm_cfg=self.norm_cfg,
            act_cfg=None)
        self.down_2 = nn.Sequential(
            ConvModule(
                channels * 2,
                channels * 4,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            ConvModule(
                channels * 4,
                channels * 8,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=None))

        # high resolution (spatial) branch
        self.spatial_branch_layers = nn.ModuleList()
        for i in range(3):
            self.spatial_branch_layers.append(
                self._make_layer(
                    block=BasicBlock if i < 2 else Bottleneck,
                    inplanes=channels * 2,
                    planes=channels * 2,
                    num_blocks=2 if i < 2 else 1,
                ))

        self.spp = DAPPM(
            channels * 16, ppm_channels, channels * 4, num_scales=5)

    def _make_stem_layer(self, in_channels, channels, num_blocks):
        layers = [
            ConvModule(
                in_channels,
                channels,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            ConvModule(
                channels,
                channels,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        ]

        layers.extend([
            self._make_layer(BasicBlock, channels, channels, num_blocks),
            nn.ReLU(),
            self._make_layer(
                BasicBlock, channels, channels * 2, num_blocks, stride=2),
            nn.ReLU(),
        ])

        return nn.Sequential(*layers)

    def _make_layer(self, block, inplanes, planes, num_blocks, stride=1):
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                build_norm_layer(self.norm_cfg, planes * block.expansion)[1])

        layers = [
            block(
                in_channels=inplanes,
                channels=planes,
                stride=stride,
                downsample=downsample)
        ]
        inplanes = planes * block.expansion
        for i in range(1, num_blocks):
            layers.append(
                block(
                    in_channels=inplanes,
                    channels=planes,
                    stride=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg_out=None if i == num_blocks - 1 else self.act_cfg))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Forward function."""
        out_size = (x.shape[-2] // 8, x.shape[-1] // 8)

        # stage 0-2
        x = self.stem(x)

        # stage3
        x_c = self.context_branch_layers[0](x)
        x_s = self.spatial_branch_layers[0](x)
        comp_c = self.compression_1(self.relu(x_c))
        x_c += self.down_1(self.relu(x_s))
        x_s += resize(
            comp_c,
            size=out_size,
            mode='bilinear',
            align_corners=self.align_corners)
        if self.training:
            temp_context = x_s.clone()

        # stage4
        x_c = self.context_branch_layers[1](self.relu(x_c))
        x_s = self.spatial_branch_layers[1](self.relu(x_s))
        comp_c = self.compression_2(self.relu(x_c))
        x_c += self.down_2(self.relu(x_s))
        x_s += resize(
            comp_c,
            size=out_size,
            mode='bilinear',
            align_corners=self.align_corners)

        # stage5
        x_s = self.spatial_branch_layers[2](self.relu(x_s))
        x_c = self.context_branch_layers[2](self.relu(x_c))
        x_c = self.spp(x_c)
        x_c = resize(
            x_c,
            size=out_size,
            mode='bilinear',
            align_corners=self.align_corners)

        return (temp_context, x_s + x_c) if self.training else x_s + x_c
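
# Usage sketch (editor's addition, not in the upstream file; the output width
# assumes the Bottleneck from mmseg.models.utils with expansion 2). In eval
# mode DDRNet returns a single fused feature at 1/8 resolution with
# channels * 4 channels; in train mode it also returns the intermediate
# spatial feature for an auxiliary head:
#
#   import torch
#   self = DDRNet(in_channels=3, channels=32)
#   self.eval()
#   out = self.forward(torch.rand(1, 3, 256, 256))
#   # -> shape (1, 128, 32, 32)
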
329
finetune/mmseg/models/backbones/erfnet.py
Normal file
@@ -0,0 +1,329 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
from mmengine.model import BaseModule

from mmseg.registry import MODELS
from ..utils import resize


class DownsamplerBlock(BaseModule):
    """Downsampler block of ERFNet.

    This module is a little different from the basic ConvModule.
    The features from the Conv and MaxPool layers are
    concatenated before BatchNorm.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', eps=1e-3),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        self.conv = build_conv_layer(
            self.conv_cfg,
            in_channels,
            out_channels - in_channels,
            kernel_size=3,
            stride=2,
            padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bn = build_norm_layer(self.norm_cfg, out_channels)[1]
        self.act = build_activation_layer(self.act_cfg)

    def forward(self, input):
        conv_out = self.conv(input)
        pool_out = self.pool(input)
        pool_out = resize(
            input=pool_out,
            size=conv_out.size()[2:],
            mode='bilinear',
            align_corners=False)
        output = torch.cat([conv_out, pool_out], 1)
        output = self.bn(output)
        output = self.act(output)
        return output
class NonBottleneck1d(BaseModule):
    """Non-bottleneck block of ERFNet.

    Args:
        channels (int): Number of channels in Non-bottleneck block.
        drop_rate (float): Probability of an element to be zeroed.
            Default 0.
        dilation (int): Dilation rate for last two conv layers.
            Default 1.
        num_conv_layer (int): Number of 3x1 and 1x3 convolution layers.
            Default 2.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 channels,
                 drop_rate=0,
                 dilation=1,
                 num_conv_layer=2,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', eps=1e-3),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.act = build_activation_layer(self.act_cfg)

        self.convs_layers = nn.ModuleList()
        for conv_layer in range(num_conv_layer):
            first_conv_padding = (1, 0) if conv_layer == 0 else (dilation, 0)
            first_conv_dilation = 1 if conv_layer == 0 else (dilation, 1)
            second_conv_padding = (0, 1) if conv_layer == 0 else (0, dilation)
            second_conv_dilation = 1 if conv_layer == 0 else (1, dilation)

            self.convs_layers.append(
                build_conv_layer(
                    self.conv_cfg,
                    channels,
                    channels,
                    kernel_size=(3, 1),
                    stride=1,
                    padding=first_conv_padding,
                    bias=True,
                    dilation=first_conv_dilation))
            self.convs_layers.append(self.act)
            self.convs_layers.append(
                build_conv_layer(
                    self.conv_cfg,
                    channels,
                    channels,
                    kernel_size=(1, 3),
                    stride=1,
                    padding=second_conv_padding,
                    bias=True,
                    dilation=second_conv_dilation))
            self.convs_layers.append(
                build_norm_layer(self.norm_cfg, channels)[1])
            if conv_layer == 0:
                self.convs_layers.append(self.act)
            else:
                self.convs_layers.append(nn.Dropout(p=drop_rate))

    def forward(self, input):
        output = input
        for conv in self.convs_layers:
            output = conv(output)
        output = self.act(output + input)
        return output
class UpsamplerBlock(BaseModule):
    """Upsampler block of ERFNet.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', eps=1e-3),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        self.conv = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            bias=True)
        self.bn = build_norm_layer(self.norm_cfg, out_channels)[1]
        self.act = build_activation_layer(self.act_cfg)

    def forward(self, input):
        output = self.conv(input)
        output = self.bn(output)
        output = self.act(output)
        return output
@MODELS.register_module()
class ERFNet(BaseModule):
    """ERFNet backbone.

    This backbone is the implementation of `ERFNet: Efficient Residual
    Factorized ConvNet for Real-time Semantic Segmentation
    <https://ieeexplore.ieee.org/document/8063438>`_.

    Args:
        in_channels (int): The number of channels of input
            image. Default: 3.
        enc_downsample_channels (Tuple[int]): Size of channel
            numbers of various Downsampler block in encoder.
            Default: (16, 64, 128).
        enc_stage_non_bottlenecks (Tuple[int]): Number of stages of
            Non-bottleneck block in encoder.
            Default: (5, 8).
        enc_non_bottleneck_dilations (Tuple[int]): Dilation rate of each
            stage of Non-bottleneck block of encoder.
            Default: (2, 4, 8, 16).
        enc_non_bottleneck_channels (Tuple[int]): Size of channel
            numbers of various Non-bottleneck block in encoder.
            Default: (64, 128).
        dec_upsample_channels (Tuple[int]): Size of channel numbers of
            various Deconvolution block in decoder.
            Default: (64, 16).
        dec_stages_non_bottleneck (Tuple[int]): Number of stages of
            Non-bottleneck block in decoder.
            Default: (2, 2).
        dec_non_bottleneck_channels (Tuple[int]): Size of channel
            numbers of various Non-bottleneck block in decoder.
            Default: (64, 16).
        dropout_ratio (float): Probability of an element to be zeroed.
            Default 0.1.
    """

    def __init__(self,
                 in_channels=3,
                 enc_downsample_channels=(16, 64, 128),
                 enc_stage_non_bottlenecks=(5, 8),
                 enc_non_bottleneck_dilations=(2, 4, 8, 16),
                 enc_non_bottleneck_channels=(64, 128),
                 dec_upsample_channels=(64, 16),
                 dec_stages_non_bottleneck=(2, 2),
                 dec_non_bottleneck_channels=(64, 16),
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):

        super().__init__(init_cfg=init_cfg)
        assert len(enc_downsample_channels) \
            == len(dec_upsample_channels) + 1, \
            'Number of downsample blocks of encoder does not ' \
            'match number of upsample blocks of decoder!'
        assert len(enc_downsample_channels) \
            == len(enc_stage_non_bottlenecks) + 1, \
            'Number of downsample blocks of encoder does not match ' \
            'number of Non-bottleneck stages of encoder!'
        assert len(enc_downsample_channels) \
            == len(enc_non_bottleneck_channels) + 1, \
            'Number of downsample blocks of encoder does not match ' \
            'number of channels of Non-bottleneck blocks of encoder!'
        assert enc_stage_non_bottlenecks[-1] \
            % len(enc_non_bottleneck_dilations) == 0, \
            'Number of Non-bottleneck blocks in the last encoder stage ' \
            'must be a multiple of the number of dilation rates!'
        assert len(dec_upsample_channels) \
            == len(dec_stages_non_bottleneck), \
            'Number of upsample blocks of decoder does not match ' \
            'number of Non-bottleneck stages of decoder!'
        assert len(dec_stages_non_bottleneck) \
            == len(dec_non_bottleneck_channels), \
            'Number of Non-bottleneck stages of decoder does not match ' \
            'number of channels of Non-bottleneck blocks of decoder!'

        self.in_channels = in_channels
        self.enc_downsample_channels = enc_downsample_channels
        self.enc_stage_non_bottlenecks = enc_stage_non_bottlenecks
        self.enc_non_bottleneck_dilations = enc_non_bottleneck_dilations
        self.enc_non_bottleneck_channels = enc_non_bottleneck_channels
        self.dec_upsample_channels = dec_upsample_channels
        self.dec_stages_non_bottleneck = dec_stages_non_bottleneck
        self.dec_non_bottleneck_channels = dec_non_bottleneck_channels
        self.dropout_ratio = dropout_ratio

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        self.encoder.append(
            DownsamplerBlock(self.in_channels, enc_downsample_channels[0]))

        for i in range(len(enc_downsample_channels) - 1):
            self.encoder.append(
                DownsamplerBlock(enc_downsample_channels[i],
                                 enc_downsample_channels[i + 1]))
            # Last part of encoder is some dilated NonBottleneck1d blocks.
            if i == len(enc_downsample_channels) - 2:
                iteration_times = int(enc_stage_non_bottlenecks[-1] /
                                      len(enc_non_bottleneck_dilations))
                for j in range(iteration_times):
                    for k in range(len(enc_non_bottleneck_dilations)):
                        self.encoder.append(
                            NonBottleneck1d(enc_downsample_channels[-1],
                                            self.dropout_ratio,
                                            enc_non_bottleneck_dilations[k]))
            else:
                for j in range(enc_stage_non_bottlenecks[i]):
                    self.encoder.append(
                        NonBottleneck1d(enc_downsample_channels[i + 1],
                                        self.dropout_ratio))

        for i in range(len(dec_upsample_channels)):
            if i == 0:
                self.decoder.append(
                    UpsamplerBlock(enc_downsample_channels[-1],
                                   dec_non_bottleneck_channels[i]))
            else:
                self.decoder.append(
                    UpsamplerBlock(dec_non_bottleneck_channels[i - 1],
                                   dec_non_bottleneck_channels[i]))
            for j in range(dec_stages_non_bottleneck[i]):
                self.decoder.append(
                    NonBottleneck1d(dec_non_bottleneck_channels[i]))

    def forward(self, x):
        for enc in self.encoder:
            x = enc(x)
        for dec in self.decoder:
            x = dec(x)
        return [x]
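
# Usage sketch (editor's addition, not in the upstream file). The encoder
# downsamples three times (to 1/8) and the decoder upsamples twice, so ERFNet
# returns a one-element list holding a dec_non_bottleneck_channels[-1]-channel
# feature at 1/2 input resolution:
#
#   import torch
#   self = ERFNet(in_channels=3)
#   self.eval()
#   outs = self.forward(torch.rand(1, 3, 512, 512))
#   # -> [tensor of shape (1, 16, 256, 256)]
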
408
finetune/mmseg/models/backbones/fast_scnn.py
Normal file
@@ -0,0 +1,408 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmengine.model import BaseModule

from mmseg.models.decode_heads.psp_head import PPM
from mmseg.registry import MODELS
from ..utils import InvertedResidual, resize


class LearningToDownsample(nn.Module):
    """Learning to downsample module.

    Args:
        in_channels (int): Number of input channels.
        dw_channels (tuple[int]): Number of output channels of the first and
            the second depthwise conv (dwconv) layers.
        out_channels (int): Number of output channels of the whole
            'learning to downsample' module.
        conv_cfg (dict | None): Config of conv layers. Default: None
        norm_cfg (dict | None): Config of norm layers. Default:
            dict(type='BN')
        act_cfg (dict): Config of activation layers. Default:
            dict(type='ReLU')
        dw_act_cfg (dict): In DepthwiseSeparableConvModule, activation config
            of depthwise ConvModule. If it is 'default', it will be the same
            as `act_cfg`. Default: None.
    """

    def __init__(self,
                 in_channels,
                 dw_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 dw_act_cfg=None):
        super().__init__()
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.dw_act_cfg = dw_act_cfg
        dw_channels1 = dw_channels[0]
        dw_channels2 = dw_channels[1]

        self.conv = ConvModule(
            in_channels,
            dw_channels1,
            3,
            stride=2,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.dsconv1 = DepthwiseSeparableConvModule(
            dw_channels1,
            dw_channels2,
            kernel_size=3,
            stride=2,
            padding=1,
            norm_cfg=self.norm_cfg,
            dw_act_cfg=self.dw_act_cfg)

        self.dsconv2 = DepthwiseSeparableConvModule(
            dw_channels2,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            norm_cfg=self.norm_cfg,
            dw_act_cfg=self.dw_act_cfg)

    def forward(self, x):
        x = self.conv(x)
        x = self.dsconv1(x)
        x = self.dsconv2(x)
        return x
class GlobalFeatureExtractor(nn.Module):
    """Global feature extractor module.

    Args:
        in_channels (int): Number of input channels of the GFE module.
            Default: 64
        block_channels (tuple[int]): Tuple of ints. Each int specifies the
            number of output channels of each Inverted Residual module.
            Default: (64, 96, 128)
        out_channels (int): Number of output channels of the GFE module.
            Default: 128
        expand_ratio (int): Adjusts number of channels of the hidden layer
            in InvertedResidual by this amount.
            Default: 6
        num_blocks (tuple[int]): Tuple of ints. Each int specifies the
            number of times each Inverted Residual module is repeated.
            The repeated Inverted Residual modules are called a 'group'.
            Default: (3, 3, 3)
        strides (tuple[int]): Tuple of ints. Each int specifies
            the downsampling factor of each 'group'.
            Default: (2, 2, 1)
        pool_scales (tuple[int]): Tuple of ints. Each int specifies
            the parameter required in 'global average pooling' within PPM.
            Default: (1, 2, 3, 6)
        conv_cfg (dict | None): Config of conv layers. Default: None
        norm_cfg (dict | None): Config of norm layers. Default:
            dict(type='BN')
        act_cfg (dict): Config of activation layers. Default:
            dict(type='ReLU')
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False
    """

    def __init__(self,
                 in_channels=64,
                 block_channels=(64, 96, 128),
                 out_channels=128,
                 expand_ratio=6,
                 num_blocks=(3, 3, 3),
                 strides=(2, 2, 1),
                 pool_scales=(1, 2, 3, 6),
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 align_corners=False):
        super().__init__()
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        assert len(block_channels) == len(num_blocks) == 3
        self.bottleneck1 = self._make_layer(in_channels, block_channels[0],
                                            num_blocks[0], strides[0],
                                            expand_ratio)
        self.bottleneck2 = self._make_layer(block_channels[0],
                                            block_channels[1], num_blocks[1],
                                            strides[1], expand_ratio)
        self.bottleneck3 = self._make_layer(block_channels[1],
                                            block_channels[2], num_blocks[2],
                                            strides[2], expand_ratio)
        self.ppm = PPM(
            pool_scales,
            block_channels[2],
            block_channels[2] // 4,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=align_corners)

        self.out = ConvModule(
            block_channels[2] * 2,
            out_channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def _make_layer(self,
                    in_channels,
                    out_channels,
                    blocks,
                    stride=1,
                    expand_ratio=6):
        layers = [
            InvertedResidual(
                in_channels,
                out_channels,
                stride,
                expand_ratio,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        ]
        for i in range(1, blocks):
            layers.append(
                InvertedResidual(
                    out_channels,
                    out_channels,
                    1,
                    expand_ratio,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.bottleneck1(x)
        x = self.bottleneck2(x)
        x = self.bottleneck3(x)
        x = torch.cat([x, *self.ppm(x)], dim=1)
        x = self.out(x)
        return x
class FeatureFusionModule(nn.Module):
    """Feature fusion module.

    Args:
        higher_in_channels (int): Number of input channels of the
            higher-resolution branch.
        lower_in_channels (int): Number of input channels of the
            lower-resolution branch.
        out_channels (int): Number of output channels.
        conv_cfg (dict | None): Config of conv layers. Default: None
        norm_cfg (dict | None): Config of norm layers. Default:
            dict(type='BN')
        dwconv_act_cfg (dict): Config of activation layers in 3x3 conv.
            Default: dict(type='ReLU').
        conv_act_cfg (dict): Config of activation layers in the two 1x1 conv.
            Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
    """

    def __init__(self,
                 higher_in_channels,
                 lower_in_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dwconv_act_cfg=dict(type='ReLU'),
                 conv_act_cfg=None,
                 align_corners=False):
        super().__init__()
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.dwconv_act_cfg = dwconv_act_cfg
        self.conv_act_cfg = conv_act_cfg
        self.align_corners = align_corners
        self.dwconv = ConvModule(
            lower_in_channels,
            out_channels,
            3,
            padding=1,
            groups=out_channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.dwconv_act_cfg)
        self.conv_lower_res = ConvModule(
            out_channels,
            out_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.conv_act_cfg)

        self.conv_higher_res = ConvModule(
            higher_in_channels,
            out_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.conv_act_cfg)

        self.relu = nn.ReLU(True)

    def forward(self, higher_res_feature, lower_res_feature):
        lower_res_feature = resize(
            lower_res_feature,
            size=higher_res_feature.size()[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        lower_res_feature = self.dwconv(lower_res_feature)
        lower_res_feature = self.conv_lower_res(lower_res_feature)

        higher_res_feature = self.conv_higher_res(higher_res_feature)
        out = higher_res_feature + lower_res_feature
        return self.relu(out)
@MODELS.register_module()
class FastSCNN(BaseModule):
    """Fast-SCNN Backbone.

    This backbone is the implementation of `Fast-SCNN: Fast Semantic
    Segmentation Network <https://arxiv.org/abs/1902.04502>`_.

    Args:
        in_channels (int): Number of input image channels. Default: 3.
        downsample_dw_channels (tuple[int]): Number of output channels after
            the first conv layer & the second conv layer in
            Learning-To-Downsample (LTD) module.
            Default: (32, 48).
        global_in_channels (int): Number of input channels of
            Global Feature Extractor (GFE).
            Equal to number of output channels of LTD.
            Default: 64.
        global_block_channels (tuple[int]): Tuple of integers that describe
            the output channels for each of the MobileNet-v2 bottleneck
            residual blocks in GFE.
            Default: (64, 96, 128).
        global_block_strides (tuple[int]): Tuple of integers
            that describe the strides (downsampling factors) for each of the
            MobileNet-v2 bottleneck residual blocks in GFE.
            Default: (2, 2, 1).
        global_out_channels (int): Number of output channels of GFE.
            Default: 128.
        higher_in_channels (int): Number of input channels of the higher
            resolution branch in FFM.
            Equal to global_in_channels.
            Default: 64.
        lower_in_channels (int): Number of input channels of the lower
            resolution branch in FFM.
            Equal to global_out_channels.
            Default: 128.
        fusion_out_channels (int): Number of output channels of FFM.
            Default: 128.
        out_indices (tuple): Tuple of indices of list
            [higher_res_features, lower_res_features, fusion_output].
            Often set to (0, 1, 2) to enable aux. heads.
            Default: (0, 1, 2).
        conv_cfg (dict | None): Config of conv layers. Default: None
        norm_cfg (dict | None): Config of norm layers. Default:
            dict(type='BN')
        act_cfg (dict): Config of activation layers. Default:
            dict(type='ReLU')
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False
        dw_act_cfg (dict): In DepthwiseSeparableConvModule, activation config
            of depthwise ConvModule. If it is 'default', it will be the same
            as `act_cfg`. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
    """

    def __init__(self,
                 in_channels=3,
                 downsample_dw_channels=(32, 48),
                 global_in_channels=64,
                 global_block_channels=(64, 96, 128),
                 global_block_strides=(2, 2, 1),
                 global_out_channels=128,
                 higher_in_channels=64,
                 lower_in_channels=128,
                 fusion_out_channels=128,
                 out_indices=(0, 1, 2),
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 align_corners=False,
                 dw_act_cfg=None,
                 init_cfg=None):

        super().__init__(init_cfg)

        if init_cfg is None:
            self.init_cfg = [
                dict(type='Kaiming', layer='Conv2d'),
                dict(
                    type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
            ]

        if global_in_channels != higher_in_channels:
            raise AssertionError('Global Input Channels must be the same '
                                 'as Higher Input Channels!')
        elif global_out_channels != lower_in_channels:
            raise AssertionError('Global Output Channels must be the same '
                                 'as Lower Input Channels!')

        self.in_channels = in_channels
        self.downsample_dw_channels1 = downsample_dw_channels[0]
        self.downsample_dw_channels2 = downsample_dw_channels[1]
        self.global_in_channels = global_in_channels
        self.global_block_channels = global_block_channels
        self.global_block_strides = global_block_strides
        self.global_out_channels = global_out_channels
        self.higher_in_channels = higher_in_channels
        self.lower_in_channels = lower_in_channels
        self.fusion_out_channels = fusion_out_channels
        self.out_indices = out_indices
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.align_corners = align_corners
        self.learning_to_downsample = LearningToDownsample(
            in_channels,
            downsample_dw_channels,
            global_in_channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            dw_act_cfg=dw_act_cfg)
        self.global_feature_extractor = GlobalFeatureExtractor(
            global_in_channels,
            global_block_channels,
            global_out_channels,
            strides=self.global_block_strides,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)
        self.feature_fusion = FeatureFusionModule(
            higher_in_channels,
            lower_in_channels,
            fusion_out_channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            dwconv_act_cfg=self.act_cfg,
            align_corners=self.align_corners)

    def forward(self, x):
        higher_res_features = self.learning_to_downsample(x)
        lower_res_features = self.global_feature_extractor(higher_res_features)
        fusion_output = self.feature_fusion(higher_res_features,
                                            lower_res_features)

        outs = [higher_res_features, lower_res_features, fusion_output]
        outs = [outs[i] for i in self.out_indices]
        return tuple(outs)
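
# Usage sketch (editor's addition, not in the upstream file). With the
# defaults, LTD produces a 64-channel map at 1/8 resolution, GFE a
# 128-channel map at 1/32, and FFM fuses them back at 1/8:
#
#   import torch
#   self = FastSCNN(in_channels=3)
#   self.eval()
#   outs = self.forward(torch.rand(1, 3, 512, 1024))
#   # -> shapes (1, 64, 64, 128), (1, 128, 16, 32), (1, 128, 64, 128)
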
642
finetune/mmseg/models/backbones/hrnet.py
Normal file
@@ -0,0 +1,642 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch.nn as nn
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmengine.model import BaseModule, ModuleList, Sequential
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm

from mmseg.registry import MODELS
from ..utils import Upsample, resize
from .resnet import BasicBlock, Bottleneck


class HRModule(BaseModule):
    """High-Resolution Module for HRNet.

    In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange
    is in this module.
    """

    def __init__(self,
                 num_branches,
                 blocks,
                 num_blocks,
                 in_channels,
                 num_channels,
                 multiscale_output=True,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 block_init_cfg=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        self.block_init_cfg = block_init_cfg
        self._check_branches(num_branches, num_blocks, in_channels,
                             num_channels)

        self.in_channels = in_channels
        self.num_branches = num_branches

        self.multiscale_output = multiscale_output
        self.norm_cfg = norm_cfg
        self.conv_cfg = conv_cfg
        self.with_cp = with_cp
        self.branches = self._make_branches(num_branches, blocks, num_blocks,
                                            num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=False)

    def _check_branches(self, num_branches, num_blocks, in_channels,
                        num_channels):
        """Check branches configuration."""
        if num_branches != len(num_blocks):
            error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_BLOCKS(' \
                        f'{len(num_blocks)})'
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_CHANNELS(' \
                        f'{len(num_channels)})'
            raise ValueError(error_msg)

        if num_branches != len(in_channels):
            error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_INCHANNELS(' \
                        f'{len(in_channels)})'
            raise ValueError(error_msg)

    def _make_one_branch(self,
                         branch_index,
                         block,
                         num_blocks,
                         num_channels,
                         stride=1):
        """Build one branch."""
        downsample = None
        if stride != 1 or \
                self.in_channels[branch_index] != \
                num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                build_conv_layer(
                    self.conv_cfg,
                    self.in_channels[branch_index],
                    num_channels[branch_index] * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                build_norm_layer(self.norm_cfg, num_channels[branch_index] *
                                 block.expansion)[1])

        layers = []
        layers.append(
            block(
                self.in_channels[branch_index],
                num_channels[branch_index],
                stride,
                downsample=downsample,
                with_cp=self.with_cp,
                norm_cfg=self.norm_cfg,
                conv_cfg=self.conv_cfg,
                init_cfg=self.block_init_cfg))
        self.in_channels[branch_index] = \
            num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(
                block(
                    self.in_channels[branch_index],
                    num_channels[branch_index],
                    with_cp=self.with_cp,
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg,
                    init_cfg=self.block_init_cfg))

        return Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        """Build multiple branches."""
        branches = []

        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels))

        return ModuleList(branches)

    def _make_fuse_layers(self):
        """Build fuse layers."""
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        in_channels = self.in_channels
        fuse_layers = []
        num_out_branches = num_branches if self.multiscale_output else 1
        for i in range(num_out_branches):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels[j],
                                in_channels[i],
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False),
                            build_norm_layer(self.norm_cfg, in_channels[i])[1],
                            # we set align_corners=False for HRNet
                            Upsample(
                                scale_factor=2**(j - i),
                                mode='bilinear',
                                align_corners=False)))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv_downsamples = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[i],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[i])[1]))
                        else:
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[j],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[j])[1],
                                    nn.ReLU(inplace=False)))
                    fuse_layer.append(nn.Sequential(*conv_downsamples))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def forward(self, x):
        """Forward function."""
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = 0
            for j in range(self.num_branches):
                if i == j:
                    y += x[j]
                elif j > i:
                    y = y + resize(
                        self.fuse_layers[i][j](x[j]),
                        size=x[i].shape[2:],
                        mode='bilinear',
                        align_corners=False)
                else:
                    y += self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))
        return x_fuse
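
# Editor's note (not in the upstream file): each branch's working width is
# num_channels[i] * block.expansion, which is why _make_one_branch rewrites
# self.in_channels[branch_index] after the first block. For example, a
# hypothetical two-branch module
#   HRModule(num_branches=2, blocks=BasicBlock, num_blocks=(4, 4),
#            in_channels=[32, 64], num_channels=(32, 64))
# keeps 32- and 64-channel branches and fuses them with the
# 1x1-conv + bilinear-upsample path (j > i) and the strided 3x3-conv path
# (j < i) built in _make_fuse_layers above.
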
@MODELS.register_module()
|
||||
class HRNet(BaseModule):
|
||||
"""HRNet backbone.
|
||||
|
||||
This backbone is the implementation of `High-Resolution Representations
|
||||
for Labeling Pixels and Regions <https://arxiv.org/abs/1904.04514>`_.
|
||||
|
||||
Args:
|
||||
extra (dict): Detailed configuration for each stage of HRNet.
|
||||
There must be 4 stages, the configuration for each stage must have
|
||||
5 keys:
|
||||
|
||||
- num_modules (int): The number of HRModule in this stage.
|
||||
- num_branches (int): The number of branches in the HRModule.
|
||||
- block (str): The type of convolution block.
|
||||
- num_blocks (tuple): The number of blocks in each branch.
|
||||
The length must be equal to num_branches.
|
||||
- num_channels (tuple): The number of channels in each branch.
|
||||
The length must be equal to num_branches.
|
||||
in_channels (int): Number of input image channels. Normally 3.
|
||||
conv_cfg (dict): Dictionary to construct and config conv layer.
|
||||
Default: None.
|
||||
norm_cfg (dict): Dictionary to construct and config norm layer.
|
||||
Use `BN` by default.
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Default: False.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
||||
-1 means not freezing any parameters. Default: -1.
|
||||
zero_init_residual (bool): Whether to use zero init for last norm layer
|
||||
in resblocks to let them behave as identity. Default: False.
|
||||
multiscale_output (bool): Whether to output multi-level features
|
||||
produced by multiple branches. If False, only the first level
|
||||
feature will be output. Default: True.
|
||||
pretrained (str, optional): Model pretrained path. Default: None.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
|
||||
Example:
|
||||
>>> from mmseg.models import HRNet
|
||||
>>> import torch
|
||||
>>> extra = dict(
|
||||
>>> stage1=dict(
|
||||
>>> num_modules=1,
|
||||
>>> num_branches=1,
|
||||
>>> block='BOTTLENECK',
|
||||
>>> num_blocks=(4, ),
|
||||
>>> num_channels=(64, )),
|
||||
>>> stage2=dict(
|
||||
>>> num_modules=1,
|
||||
>>> num_branches=2,
|
||||
>>> block='BASIC',
|
||||
>>> num_blocks=(4, 4),
|
||||
>>> num_channels=(32, 64)),
|
||||
>>> stage3=dict(
|
||||
>>> num_modules=4,
|
||||
>>> num_branches=3,
|
||||
>>> block='BASIC',
|
||||
>>> num_blocks=(4, 4, 4),
|
||||
>>> num_channels=(32, 64, 128)),
|
||||
>>> stage4=dict(
|
||||
>>> num_modules=3,
|
||||
>>> num_branches=4,
|
||||
>>> block='BASIC',
|
||||
>>> num_blocks=(4, 4, 4, 4),
|
||||
>>> num_channels=(32, 64, 128, 256)))
|
||||
>>> self = HRNet(extra, in_channels=1)
|
||||
>>> self.eval()
|
||||
>>> inputs = torch.rand(1, 1, 32, 32)
|
||||
>>> level_outputs = self.forward(inputs)
|
||||
>>> for level_out in level_outputs:
|
||||
... print(tuple(level_out.shape))
|
||||
(1, 32, 8, 8)
|
||||
(1, 64, 4, 4)
|
||||
(1, 128, 2, 2)
|
||||
(1, 256, 1, 1)
|
||||
"""

    blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}

    def __init__(self,
                 extra,
                 in_channels=3,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=False,
                 with_cp=False,
                 frozen_stages=-1,
                 zero_init_residual=False,
                 multiscale_output=True,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg)

        self.pretrained = pretrained
        self.zero_init_residual = zero_init_residual
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is None:
            if init_cfg is None:
                self.init_cfg = [
                    dict(type='Kaiming', layer='Conv2d'),
                    dict(
                        type='Constant',
                        val=1,
                        layer=['_BatchNorm', 'GroupNorm'])
                ]
        else:
            raise TypeError('pretrained must be a str or None')

        # Assert configurations of 4 stages are in extra
        assert 'stage1' in extra and 'stage2' in extra \
            and 'stage3' in extra and 'stage4' in extra
        # Assert whether the length of `num_blocks` and `num_channels` are
        # equal to `num_branches`
        for i in range(4):
            cfg = extra[f'stage{i + 1}']
            assert len(cfg['num_blocks']) == cfg['num_branches'] and \
                len(cfg['num_channels']) == cfg['num_branches']

        self.extra = extra
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.frozen_stages = frozen_stages

        # stem net
        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)

        self.conv1 = build_conv_layer(
            self.conv_cfg,
            in_channels,
            64,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False)

        self.add_module(self.norm1_name, norm1)
        self.conv2 = build_conv_layer(
            self.conv_cfg,
            64,
            64,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False)

        self.add_module(self.norm2_name, norm2)
        self.relu = nn.ReLU(inplace=True)

        # stage 1
        self.stage1_cfg = self.extra['stage1']
        num_channels = self.stage1_cfg['num_channels'][0]
        block_type = self.stage1_cfg['block']
        num_blocks = self.stage1_cfg['num_blocks'][0]

        block = self.blocks_dict[block_type]
        stage1_out_channels = num_channels * block.expansion
        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)

        # stage 2
        self.stage2_cfg = self.extra['stage2']
        num_channels = self.stage2_cfg['num_channels']
        block_type = self.stage2_cfg['block']

        block = self.blocks_dict[block_type]
        num_channels = [channel * block.expansion for channel in num_channels]
        self.transition1 = self._make_transition_layer([stage1_out_channels],
                                                       num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels)

        # stage 3
        self.stage3_cfg = self.extra['stage3']
        num_channels = self.stage3_cfg['num_channels']
        block_type = self.stage3_cfg['block']

        block = self.blocks_dict[block_type]
        num_channels = [channel * block.expansion for channel in num_channels]
        self.transition2 = self._make_transition_layer(pre_stage_channels,
                                                       num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels)

        # stage 4
        self.stage4_cfg = self.extra['stage4']
        num_channels = self.stage4_cfg['num_channels']
        block_type = self.stage4_cfg['block']

        block = self.blocks_dict[block_type]
        num_channels = [channel * block.expansion for channel in num_channels]
        self.transition3 = self._make_transition_layer(pre_stage_channels,
                                                       num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multiscale_output=multiscale_output)

        self._freeze_stages()

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """nn.Module: the normalization layer named "norm2" """
        return getattr(self, self.norm2_name)

    def _make_transition_layer(self, num_channels_pre_layer,
                               num_channels_cur_layer):
        """Make transition layer."""
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=False),
                            build_norm_layer(self.norm_cfg,
                                             num_channels_cur_layer[i])[1],
                            nn.ReLU(inplace=True)))
                else:
                    transition_layers.append(None)
            else:
                conv_downsamples = []
                for j in range(i + 1 - num_branches_pre):
                    in_channels = num_channels_pre_layer[-1]
                    out_channels = num_channels_cur_layer[i] \
                        if j == i - num_branches_pre else in_channels
                    conv_downsamples.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels,
                                out_channels,
                                kernel_size=3,
                                stride=2,
                                padding=1,
                                bias=False),
                            build_norm_layer(self.norm_cfg, out_channels)[1],
                            nn.ReLU(inplace=True)))
                transition_layers.append(nn.Sequential(*conv_downsamples))

        return nn.ModuleList(transition_layers)
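
    # Hedged worked example (not part of the original commit) of the rule
    # above, for stage-1 channels [256] -> stage-2 channels [32, 64]:
    #   branch 0: 256 != 32, so a 3x3 conv (stride 1) maps 256 -> 32;
    #   branch 1: a new branch, so one strided 3x3 conv maps 256 -> 64.
    # An existing branch whose channel count already matches gets `None`
    # (identity) instead.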

    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        """Make each layer."""
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                build_conv_layer(
                    self.conv_cfg,
                    inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                build_norm_layer(self.norm_cfg, planes * block.expansion)[1])

        layers = []
        block_init_cfg = None
        if self.pretrained is None and not hasattr(
                self, 'init_cfg') and self.zero_init_residual:
            if block is BasicBlock:
                block_init_cfg = dict(
                    type='Constant', val=0, override=dict(name='norm2'))
            elif block is Bottleneck:
                block_init_cfg = dict(
                    type='Constant', val=0, override=dict(name='norm3'))

        layers.append(
            block(
                inplanes,
                planes,
                stride,
                downsample=downsample,
                with_cp=self.with_cp,
                norm_cfg=self.norm_cfg,
                conv_cfg=self.conv_cfg,
                init_cfg=block_init_cfg))
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(
                block(
                    inplanes,
                    planes,
                    with_cp=self.with_cp,
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg,
                    init_cfg=block_init_cfg))

        return Sequential(*layers)

    def _make_stage(self, layer_config, in_channels, multiscale_output=True):
        """Make each stage."""
        num_modules = layer_config['num_modules']
        num_branches = layer_config['num_branches']
        num_blocks = layer_config['num_blocks']
        num_channels = layer_config['num_channels']
        block = self.blocks_dict[layer_config['block']]

        hr_modules = []
        block_init_cfg = None
        if self.pretrained is None and not hasattr(
                self, 'init_cfg') and self.zero_init_residual:
            if block is BasicBlock:
                block_init_cfg = dict(
                    type='Constant', val=0, override=dict(name='norm2'))
            elif block is Bottleneck:
                block_init_cfg = dict(
                    type='Constant', val=0, override=dict(name='norm3'))

        for i in range(num_modules):
            # multi_scale_output is only used for the last module
            if not multiscale_output and i == num_modules - 1:
                reset_multiscale_output = False
            else:
                reset_multiscale_output = True

            hr_modules.append(
                HRModule(
                    num_branches,
                    block,
                    num_blocks,
                    in_channels,
                    num_channels,
                    reset_multiscale_output,
                    with_cp=self.with_cp,
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg,
                    block_init_cfg=block_init_cfg))

        return Sequential(*hr_modules), in_channels

    def _freeze_stages(self):
        """Freeze stages param and norm stats."""
        if self.frozen_stages >= 0:

            self.norm1.eval()
            self.norm2.eval()
            for m in [self.conv1, self.norm1, self.conv2, self.norm2]:
                for param in m.parameters():
                    param.requires_grad = False

            for i in range(1, self.frozen_stages + 1):
                if i == 1:
                    m = getattr(self, f'layer{i}')
                    t = getattr(self, f'transition{i}')
                elif i == 4:
                    m = getattr(self, f'stage{i}')
                else:
                    m = getattr(self, f'stage{i}')
                    t = getattr(self, f'transition{i}')
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False
                t.eval()
                for param in t.parameters():
                    param.requires_grad = False

    def forward(self, x):
        """Forward function."""

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.relu(x)
        x = self.layer1(x)

        x_list = []
        for i in range(self.stage2_cfg['num_branches']):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)

        x_list = []
        for i in range(self.stage3_cfg['num_branches']):
            if self.transition2[i] is not None:
                x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.stage4_cfg['num_branches']):
            if self.transition3[i] is not None:
                x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage4(x_list)

        return y_list

    def train(self, mode=True):
        """Convert the model into training mode while keeping the
        normalization layers frozen."""
        super().train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval has effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
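

def _demo_freeze_stages():
    """Hedged usage sketch (not part of the original commit), assuming
    mmseg's dependencies are installed: with ``frozen_stages=1`` the stem,
    ``layer1`` and ``transition1`` stop receiving gradients while later
    stages stay trainable, per ``_freeze_stages`` above."""
    extra = dict(
        stage1=dict(num_modules=1, num_branches=1, block='BOTTLENECK',
                    num_blocks=(4, ), num_channels=(64, )),
        stage2=dict(num_modules=1, num_branches=2, block='BASIC',
                    num_blocks=(4, 4), num_channels=(32, 64)),
        stage3=dict(num_modules=4, num_branches=3, block='BASIC',
                    num_blocks=(4, 4, 4), num_channels=(32, 64, 128)),
        stage4=dict(num_modules=3, num_branches=4, block='BASIC',
                    num_blocks=(4, 4, 4, 4), num_channels=(32, 64, 128, 256)))
    net = HRNet(extra, frozen_stages=1)
    assert not any(p.requires_grad for p in net.layer1.parameters())
    assert any(p.requires_grad for p in net.stage2.parameters())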
166
finetune/mmseg/models/backbones/icnet.py
Normal file
@@ -0,0 +1,166 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule

from mmseg.registry import MODELS
from ..decode_heads.psp_head import PPM
from ..utils import resize


@MODELS.register_module()
class ICNet(BaseModule):
    """ICNet for Real-Time Semantic Segmentation on High-Resolution Images.

    This backbone is the implementation of
    `ICNet <https://arxiv.org/abs/1704.08545>`_.

    Args:
        backbone_cfg (dict): Config dict to build backbone. Usually it is
            ResNet but it can also be other backbones.
        in_channels (int): The number of input image channels. Default: 3.
        layer_channels (Sequence[int]): The numbers of feature channels at
            layer 2 and layer 4 in ResNet. It can also be other backbones.
            Default: (512, 2048).
        light_branch_middle_channels (int): The number of channels of the
            middle layer in light branch. Default: 32.
        psp_out_channels (int): The number of channels of the output of PSP
            module. Default: 512.
        out_channels (Sequence[int]): The numbers of output feature channels
            at each branch. Default: (64, 256, 256).
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module. Default: (1, 2, 3, 6).
        conv_cfg (dict): Dictionary to construct and config conv layer.
            Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN').
        act_cfg (dict): Dictionary to construct and config act layer.
            Default: dict(type='ReLU').
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 backbone_cfg,
                 in_channels=3,
                 layer_channels=(512, 2048),
                 light_branch_middle_channels=32,
                 psp_out_channels=512,
                 out_channels=(64, 256, 256),
                 pool_scales=(1, 2, 3, 6),
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='ReLU'),
                 align_corners=False,
                 init_cfg=None):
        if backbone_cfg is None:
            raise TypeError('backbone_cfg must be passed from config file!')
        if init_cfg is None:
            init_cfg = [
                dict(type='Kaiming', mode='fan_out', layer='Conv2d'),
                dict(type='Constant', val=1, layer='_BatchNorm'),
                dict(type='Normal', mean=0.01, layer='Linear')
            ]
        super().__init__(init_cfg=init_cfg)
        self.align_corners = align_corners
        self.backbone = MODELS.build(backbone_cfg)

        # Note: Default `ceil_mode` is false in nn.MaxPool2d, set
        # `ceil_mode=True` to keep information in the corner of feature map.
        self.backbone.maxpool = nn.MaxPool2d(
            kernel_size=3, stride=2, padding=1, ceil_mode=True)

        self.psp_modules = PPM(
            pool_scales=pool_scales,
            in_channels=layer_channels[1],
            channels=psp_out_channels,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            align_corners=align_corners)

        self.psp_bottleneck = ConvModule(
            layer_channels[1] + len(pool_scales) * psp_out_channels,
            psp_out_channels,
            3,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        self.conv_sub1 = nn.Sequential(
            ConvModule(
                in_channels=in_channels,
                out_channels=light_branch_middle_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg),
            ConvModule(
                in_channels=light_branch_middle_channels,
                out_channels=light_branch_middle_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg),
            ConvModule(
                in_channels=light_branch_middle_channels,
                out_channels=out_channels[0],
                kernel_size=3,
                stride=2,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg))

        self.conv_sub2 = ConvModule(
            layer_channels[0],
            out_channels[1],
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)

        self.conv_sub4 = ConvModule(
            psp_out_channels,
            out_channels[2],
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)

    def forward(self, x):
        output = []

        # sub 1
        output.append(self.conv_sub1(x))

        # sub 2
        x = resize(
            x,
            scale_factor=0.5,
            mode='bilinear',
            align_corners=self.align_corners)
        x = self.backbone.stem(x)
        x = self.backbone.maxpool(x)
        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        output.append(self.conv_sub2(x))

        # sub 4
        x = resize(
            x,
            scale_factor=0.5,
            mode='bilinear',
            align_corners=self.align_corners)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)
        psp_outs = self.psp_modules(x) + [x]
        psp_outs = torch.cat(psp_outs, dim=1)
        x = self.psp_bottleneck(psp_outs)

        output.append(self.conv_sub4(x))

        return output
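

def _demo_icnet_outputs():
    """Hedged usage sketch (not part of the original commit), assuming mmseg
    is fully importable; the backbone_cfg below is illustrative, not the
    project's config. The three branches return features at roughly 1/8,
    1/16 and 1/32 of the input resolution."""
    import torch
    backbone_cfg = dict(
        type='ResNetV1c',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=dict(type='BN'))
    net = ICNet(backbone_cfg=backbone_cfg)
    outs = net(torch.rand(1, 3, 256, 256))
    print([tuple(o.shape) for o in outs])  # three multi-scale feature maps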
260
finetune/mmseg/models/backbones/mae.py
Normal file
@@ -0,0 +1,260 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
from mmengine.model import ModuleList
from mmengine.model.weight_init import (constant_init, kaiming_init,
                                        trunc_normal_)
from mmengine.runner.checkpoint import _load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm

from mmseg.registry import MODELS
from .beit import BEiT, BEiTAttention, BEiTTransformerEncoderLayer


class MAEAttention(BEiTAttention):
    """Multi-head self-attention with relative position bias used in MAE.

    This module is different from ``BEiTAttention`` by initializing the
    relative bias table with zeros.
    """

    def init_weights(self):
        """Initialize relative position bias with zeros."""

        # As MAE initializes relative position bias as zeros and this class
        # is inherited from BEiT which initializes relative position bias
        # with `trunc_normal`, `init_weights` here does
        # nothing and just passes directly

        pass


class MAETransformerEncoderLayer(BEiTTransformerEncoderLayer):
    """Implements one encoder layer in Vision Transformer.

    This module is different from ``BEiTTransformerEncoderLayer`` by replacing
    ``BEiTAttention`` with ``MAEAttention``.
    """

    def build_attn(self, attn_cfg):
        self.attn = MAEAttention(**attn_cfg)


@MODELS.register_module()
class MAE(BEiT):
    """VisionTransformer with support for patch.

    Args:
        img_size (int | tuple): Input image size. Default: 224.
        patch_size (int): The patch size. Default: 16.
        in_channels (int): Number of input channels. Default: 3.
        embed_dims (int): embedding dimension. Default: 768.
        num_layers (int): depth of transformer. Default: 12.
        num_heads (int): number of attention heads. Default: 12.
        mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
            Default: 4.
        out_indices (list | tuple | int): Output from which stages.
            Default: -1.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default 0.0
        drop_path_rate (float): stochastic depth rate. Default 0.0.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN')
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        patch_norm (bool): Whether to add a norm in PatchEmbed Block.
            Default: False.
        final_norm (bool): Whether to add an additional layer to normalize
            final feature map. Default: False.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default: 2.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        pretrained (str, optional): model pretrained path. Default: None.
        init_values (float): Initialize the values of Attention and FFN
            with learnable scaling. Defaults to 0.1.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 embed_dims=768,
                 num_layers=12,
                 num_heads=12,
                 mlp_ratio=4,
                 out_indices=-1,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_cfg=dict(type='LN'),
                 act_cfg=dict(type='GELU'),
                 patch_norm=False,
                 final_norm=False,
                 num_fcs=2,
                 norm_eval=False,
                 pretrained=None,
                 init_values=0.1,
                 init_cfg=None):
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dims=embed_dims,
            num_layers=num_layers,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            out_indices=out_indices,
            qv_bias=False,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            patch_norm=patch_norm,
            final_norm=final_norm,
            num_fcs=num_fcs,
            norm_eval=norm_eval,
            pretrained=pretrained,
            init_values=init_values,
            init_cfg=init_cfg)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))

        self.num_patches = self.patch_shape[0] * self.patch_shape[1]
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, embed_dims))

    def _build_layers(self):
        dpr = [
            x.item()
            for x in torch.linspace(0, self.drop_path_rate, self.num_layers)
        ]
        self.layers = ModuleList()
        for i in range(self.num_layers):
            self.layers.append(
                MAETransformerEncoderLayer(
                    embed_dims=self.embed_dims,
                    num_heads=self.num_heads,
                    feedforward_channels=self.mlp_ratio * self.embed_dims,
                    attn_drop_rate=self.attn_drop_rate,
                    drop_path_rate=dpr[i],
                    num_fcs=self.num_fcs,
                    bias=True,
                    act_cfg=self.act_cfg,
                    norm_cfg=self.norm_cfg,
                    window_size=self.patch_shape,
                    init_values=self.init_values))

    def fix_init_weight(self):
        """Rescale the initialization according to layer id.

        This function is copied from https://github.com/microsoft/unilm/blob/master/beit/modeling_pretrain.py. # noqa: E501
        Copyright (c) Microsoft Corporation
        Licensed under the MIT License
        """

        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.layers):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.ffn.layers[1].weight.data, layer_id + 1)
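
    # Hedged numeric note (not part of the original commit): the rescale
    # above divides layer l's residual-branch weights by sqrt(2 * l), i.e.
    # by ~1.41 for layer 1 and ~4.90 for layer 12, so deeper layers start
    # with proportionally smaller residual contributions.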

    def init_weights(self):

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        self.apply(_init_weights)
        self.fix_init_weight()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg.get('type') == 'Pretrained'):
            checkpoint = _load_checkpoint(
                self.init_cfg['checkpoint'], logger=None, map_location='cpu')
            state_dict = self.resize_rel_pos_embed(checkpoint)
            state_dict = self.resize_abs_pos_embed(state_dict)
            self.load_state_dict(state_dict, False)
        elif self.init_cfg is not None:
            super().init_weights()
        else:
            # We only implement the 'jax_impl' initialization implemented at
            # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
            # Copyright 2019 Ross Wightman
            # Licensed under the Apache License, Version 2.0 (the "License")
            trunc_normal_(self.cls_token, std=.02)
            for n, m in self.named_modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_(m.weight, std=.02)
                    if m.bias is not None:
                        if 'ffn' in n:
                            nn.init.normal_(m.bias, mean=0., std=1e-6)
                        else:
                            nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Conv2d):
                    kaiming_init(m, mode='fan_in', bias=0.)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                    constant_init(m, val=1.0, bias=0.)

    def resize_abs_pos_embed(self, state_dict):
        if 'pos_embed' in state_dict:
            pos_embed_checkpoint = state_dict['pos_embed']
            embedding_size = pos_embed_checkpoint.shape[-1]
            num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches
            # height (== width) for the checkpoint position embedding
            orig_size = int(
                (pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5)
            # height (== width) for the new position embedding
            new_size = int(self.num_patches**0.5)
            # class_token and dist_token are kept unchanged
            if orig_size != new_size:
                extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
                # only the position tokens are interpolated
                pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
                pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size,
                                                embedding_size).permute(
                                                    0, 3, 1, 2)
                pos_tokens = torch.nn.functional.interpolate(
                    pos_tokens,
                    size=(new_size, new_size),
                    mode='bicubic',
                    align_corners=False)
                pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
                new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
                state_dict['pos_embed'] = new_pos_embed
        return state_dict

    def forward(self, inputs):
        B = inputs.shape[0]

        x, hw_shape = self.patch_embed(inputs)

        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i == len(self.layers) - 1:
                if self.final_norm:
                    x = self.norm1(x)
            if i in self.out_indices:
                out = x[:, 1:]
                B, _, C = out.shape
                out = out.reshape(B, hw_shape[0], hw_shape[1],
                                  C).permute(0, 3, 1, 2).contiguous()
                outs.append(out)

        return tuple(outs)
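

def _demo_resize_abs_pos_embed():
    """Hedged plain-torch sketch (not part of the original commit) of the
    interpolation in ``resize_abs_pos_embed`` above: a 14x14 position grid
    (224 // 16 per side) resized to 16x16, with the single extra (cls)
    token carried over unchanged."""
    import torch
    import torch.nn.functional as F
    pos = torch.randn(1, 1 + 14 * 14, 768)  # checkpoint layout
    extra, grid = pos[:, :1], pos[:, 1:]
    grid = grid.reshape(1, 14, 14, 768).permute(0, 3, 1, 2)
    grid = F.interpolate(
        grid, size=(16, 16), mode='bicubic', align_corners=False)
    grid = grid.permute(0, 2, 3, 1).flatten(1, 2)
    new_pos = torch.cat((extra, grid), dim=1)
    assert new_pos.shape == (1, 1 + 16 * 16, 768)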
450
finetune/mmseg/models/backbones/mit.py
Normal file
@@ -0,0 +1,450 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import MultiheadAttention
from mmengine.model import BaseModule, ModuleList, Sequential
from mmengine.model.weight_init import (constant_init, normal_init,
                                        trunc_normal_init)

from mmseg.registry import MODELS
from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw


class MixFFN(BaseModule):
    """An implementation of MixFFN of Segformer.

    The differences between MixFFN & FFN:
        1. Use 1X1 Conv to replace Linear layer.
        2. Introduce 3X3 Conv to encode positional information.
    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`. Defaults: 256.
        feedforward_channels (int): The hidden dimension of FFNs.
            Defaults: 1024.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='ReLU')
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 act_cfg=dict(type='GELU'),
                 ffn_drop=0.,
                 dropout_layer=None,
                 init_cfg=None):
        super().__init__(init_cfg)

        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.act_cfg = act_cfg
        self.activate = build_activation_layer(act_cfg)

        in_channels = embed_dims
        fc1 = Conv2d(
            in_channels=in_channels,
            out_channels=feedforward_channels,
            kernel_size=1,
            stride=1,
            bias=True)
        # 3x3 depth wise conv to provide positional encode information
        pe_conv = Conv2d(
            in_channels=feedforward_channels,
            out_channels=feedforward_channels,
            kernel_size=3,
            stride=1,
            padding=(3 - 1) // 2,
            bias=True,
            groups=feedforward_channels)
        fc2 = Conv2d(
            in_channels=feedforward_channels,
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            bias=True)
        drop = nn.Dropout(ffn_drop)
        layers = [fc1, pe_conv, self.activate, drop, fc2, drop]
        self.layers = Sequential(*layers)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else torch.nn.Identity()

    def forward(self, x, hw_shape, identity=None):
        out = nlc_to_nchw(x, hw_shape)
        out = self.layers(out)
        out = nchw_to_nlc(out)
        if identity is None:
            identity = x
        return identity + self.dropout_layer(out)
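

def _demo_mixffn_layout():
    """Hedged plain-torch sketch (not part of the original commit) of the
    layout round trip above: token sequences are reshaped into feature maps
    so the 3x3 depth-wise conv can inject positional information, then
    flattened back."""
    import torch
    B, H, W, C = 2, 16, 16, 64
    x = torch.randn(B, H * W, C)  # (batch, n, channels), i.e. nlc
    nchw = x.transpose(1, 2).reshape(B, C, H, W)  # nlc_to_nchw
    back = nchw.flatten(2).transpose(1, 2)  # nchw_to_nlc
    assert torch.equal(x, back)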


class EfficientMultiheadAttention(MultiheadAttention):
    """An implementation of Efficient Multi-head Attention of Segformer.

    This module is modified from MultiheadAttention which is a module from
    mmcv.cnn.bricks.transformer.
    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Default: 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut. Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim)
            or (n, batch, embed_dim). Default: False.
        qkv_bias (bool): enable bias for qkv if True. Default True.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
            Attention of Segformer. Default: 1.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=None,
                 init_cfg=None,
                 batch_first=True,
                 qkv_bias=False,
                 norm_cfg=dict(type='LN'),
                 sr_ratio=1):
        super().__init__(
            embed_dims,
            num_heads,
            attn_drop,
            proj_drop,
            dropout_layer=dropout_layer,
            init_cfg=init_cfg,
            batch_first=batch_first,
            bias=qkv_bias)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = Conv2d(
                in_channels=embed_dims,
                out_channels=embed_dims,
                kernel_size=sr_ratio,
                stride=sr_ratio)
            # The ret[0] of build_norm_layer is norm name.
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]

        # handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa
        from mmseg import digit_version, mmcv_version
        if mmcv_version < digit_version('1.3.17'):
            warnings.warn('The legacy version of forward function in '
                          'EfficientMultiheadAttention is deprecated in '
                          'mmcv>=1.3.17 and will no longer be supported in '
                          'the future. Please upgrade your mmcv.')
            self.forward = self.legacy_forward

    def forward(self, x, hw_shape, identity=None):

        x_q = x
        if self.sr_ratio > 1:
            x_kv = nlc_to_nchw(x, hw_shape)
            x_kv = self.sr(x_kv)
            x_kv = nchw_to_nlc(x_kv)
            x_kv = self.norm(x_kv)
        else:
            x_kv = x

        if identity is None:
            identity = x_q

        # Because the dataflow('key', 'query', 'value') of
        # ``torch.nn.MultiheadAttention`` is (num_query, batch,
        # embed_dims), We should adjust the shape of dataflow from
        # batch_first (batch, num_query, embed_dims) to num_query_first
        # (num_query ,batch, embed_dims), and recover ``attn_output``
        # from num_query_first to batch_first.
        if self.batch_first:
            x_q = x_q.transpose(0, 1)
            x_kv = x_kv.transpose(0, 1)

        out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]

        if self.batch_first:
            out = out.transpose(0, 1)

        return identity + self.dropout_layer(self.proj_drop(out))

    def legacy_forward(self, x, hw_shape, identity=None):
        """multi head attention forward in mmcv version < 1.3.17."""

        x_q = x
        if self.sr_ratio > 1:
            x_kv = nlc_to_nchw(x, hw_shape)
            x_kv = self.sr(x_kv)
            x_kv = nchw_to_nlc(x_kv)
            x_kv = self.norm(x_kv)
        else:
            x_kv = x

        if identity is None:
            identity = x_q

        # `need_weights=True` will let nn.MultiHeadAttention
        # `return attn_output, attn_output_weights.sum(dim=1) / num_heads`
        # The `attn_output_weights.sum(dim=1)` may cause cuda error. So, we set
        # `need_weights=False` to ignore `attn_output_weights.sum(dim=1)`.
        # This issue - `https://github.com/pytorch/pytorch/issues/37583` report
        # the error that large scale tensor sum operation may cause cuda error.
        out = self.attn(query=x_q, key=x_kv, value=x_kv, need_weights=False)[0]

        return identity + self.dropout_layer(self.proj_drop(out))
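

def _demo_spatial_reduction():
    """Hedged plain-torch sketch (not part of the original commit): the
    strided conv above shrinks the key/value sequence by sr_ratio**2 while
    queries keep full length, which is what makes the attention cheap on
    large feature maps."""
    import torch
    import torch.nn as nn
    B, H, W, C, sr = 2, 32, 32, 64, 4
    x_kv = torch.randn(B, H * W, C)
    x_kv = x_kv.transpose(1, 2).reshape(B, C, H, W)  # nlc -> nchw
    x_kv = nn.Conv2d(C, C, kernel_size=sr, stride=sr)(x_kv)
    x_kv = x_kv.flatten(2).transpose(1, 2)  # nchw -> nlc
    assert x_kv.shape == (B, (H // sr) * (W // sr), C)  # 1024 -> 64 tokens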


class TransformerEncoderLayer(BaseModule):
    """Implements one encoder layer in Segformer.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Default 0.0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default 0.0.
        drop_path_rate (float): stochastic depth rate. Default 0.0.
        qkv_bias (bool): enable bias for qkv if True.
            Default: True.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim)
            or (n, batch, embed_dim). Default: False.
        init_cfg (dict, optional): Initialization config dict.
            Default: None.
        sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
            Attention of Segformer. Default: 1.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 qkv_bias=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 batch_first=True,
                 sr_ratio=1,
                 with_cp=False):
        super().__init__()

        # The ret[0] of build_norm_layer is norm name.
        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]

        self.attn = EfficientMultiheadAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            batch_first=batch_first,
            qkv_bias=qkv_bias,
            norm_cfg=norm_cfg,
            sr_ratio=sr_ratio)

        # The ret[0] of build_norm_layer is norm name.
        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]

        self.ffn = MixFFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg)

        self.with_cp = with_cp

    def forward(self, x, hw_shape):

        def _inner_forward(x):
            x = self.attn(self.norm1(x), hw_shape, identity=x)
            x = self.ffn(self.norm2(x), hw_shape, identity=x)
            return x

        if self.with_cp and x.requires_grad:
            x = cp.checkpoint(_inner_forward, x)
        else:
            x = _inner_forward(x)
        return x


@MODELS.register_module()
class MixVisionTransformer(BaseModule):
    """The backbone of Segformer.

    This backbone is the implementation of `SegFormer: Simple and
    Efficient Design for Semantic Segmentation with
    Transformers <https://arxiv.org/abs/2105.15203>`_.
    Args:
        in_channels (int): Number of input channels. Default: 3.
        embed_dims (int): Embedding dimension. Default: 768.
        num_stages (int): The num of stages. Default: 4.
        num_layers (Sequence[int]): The layer number of each transformer encode
            layer. Default: [3, 4, 6, 3].
        num_heads (Sequence[int]): The attention heads of each transformer
            encode layer. Default: [1, 2, 4, 8].
        patch_sizes (Sequence[int]): The patch_size of each overlapped patch
            embedding. Default: [7, 3, 3, 3].
        strides (Sequence[int]): The stride of each overlapped patch embedding.
            Default: [4, 2, 2, 2].
        sr_ratios (Sequence[int]): The spatial reduction rate of each
            transformer encode layer. Default: [8, 4, 2, 1].
        out_indices (Sequence[int] | int): Output from which stages.
            Default: (0, 1, 2, 3).
        mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
            Default: 4.
        qkv_bias (bool): Enable bias for qkv if True. Default: True.
        drop_rate (float): Probability of an element to be zeroed.
            Default 0.0
        attn_drop_rate (float): The drop out rate for attention layer.
            Default 0.0
        drop_path_rate (float): stochastic depth rate. Default 0.0
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN')
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        pretrained (str, optional): model pretrained path. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 in_channels=3,
                 embed_dims=64,
                 num_stages=4,
                 num_layers=[3, 4, 6, 3],
                 num_heads=[1, 2, 4, 8],
                 patch_sizes=[7, 3, 3, 3],
                 strides=[4, 2, 2, 2],
                 sr_ratios=[8, 4, 2, 1],
                 out_indices=(0, 1, 2, 3),
                 mlp_ratio=4,
                 qkv_bias=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN', eps=1e-6),
                 pretrained=None,
                 init_cfg=None,
                 with_cp=False):
        super().__init__(init_cfg=init_cfg)

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is not None:
            raise TypeError('pretrained must be a str or None')

        self.embed_dims = embed_dims
        self.num_stages = num_stages
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.patch_sizes = patch_sizes
        self.strides = strides
        self.sr_ratios = sr_ratios
        self.with_cp = with_cp
        assert num_stages == len(num_layers) == len(num_heads) \
            == len(patch_sizes) == len(strides) == len(sr_ratios)

        self.out_indices = out_indices
        assert max(out_indices) < self.num_stages

        # transformer encoder
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, sum(num_layers))
        ]  # stochastic num_layer decay rule

        cur = 0
        self.layers = ModuleList()
        for i, num_layer in enumerate(num_layers):
            embed_dims_i = embed_dims * num_heads[i]
            patch_embed = PatchEmbed(
                in_channels=in_channels,
                embed_dims=embed_dims_i,
                kernel_size=patch_sizes[i],
                stride=strides[i],
                padding=patch_sizes[i] // 2,
                norm_cfg=norm_cfg)
            layer = ModuleList([
                TransformerEncoderLayer(
                    embed_dims=embed_dims_i,
                    num_heads=num_heads[i],
                    feedforward_channels=mlp_ratio * embed_dims_i,
                    drop_rate=drop_rate,
                    attn_drop_rate=attn_drop_rate,
                    drop_path_rate=dpr[cur + idx],
                    qkv_bias=qkv_bias,
                    act_cfg=act_cfg,
                    norm_cfg=norm_cfg,
                    with_cp=with_cp,
                    sr_ratio=sr_ratios[i]) for idx in range(num_layer)
            ])
            in_channels = embed_dims_i
            # The ret[0] of build_norm_layer is norm name.
            norm = build_norm_layer(norm_cfg, embed_dims_i)[1]
            self.layers.append(ModuleList([patch_embed, layer, norm]))
            cur += num_layer

    def init_weights(self):
        if self.init_cfg is None:
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, val=1.0, bias=0.)
                elif isinstance(m, nn.Conv2d):
                    fan_out = m.kernel_size[0] * m.kernel_size[
                        1] * m.out_channels
                    fan_out //= m.groups
                    normal_init(
                        m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
        else:
            super().init_weights()

    def forward(self, x):
        outs = []

        for i, layer in enumerate(self.layers):
            x, hw_shape = layer[0](x)
            for block in layer[1]:
                x = block(x, hw_shape)
            x = layer[2](x)
            x = nlc_to_nchw(x, hw_shape)
            if i in self.out_indices:
                outs.append(x)

        return outs
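

def _demo_mit_feature_pyramid():
    """Hedged usage sketch (not part of the original commit), assuming
    mmseg's dependencies are installed: with the MiT-B0-like settings below
    the four stages emit features at strides 4, 8, 16 and 32."""
    import torch
    net = MixVisionTransformer(
        embed_dims=32, num_layers=[2, 2, 2, 2], num_heads=[1, 2, 5, 8])
    outs = net(torch.rand(1, 3, 224, 224))
    assert [out.shape[2] for out in outs] == [56, 28, 14, 7]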
197
finetune/mmseg/models/backbones/mobilenet_v2.py
Normal file
@@ -0,0 +1,197 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm

from mmseg.registry import MODELS
from ..utils import InvertedResidual, make_divisible


@MODELS.register_module()
class MobileNetV2(BaseModule):
    """MobileNetV2 backbone.

    This backbone is the implementation of
    `MobileNetV2: Inverted Residuals and Linear Bottlenecks
    <https://arxiv.org/abs/1801.04381>`_.

    Args:
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Default: 1.0.
        strides (Sequence[int], optional): Strides of the first block of each
            layer. If not specified, default config in ``arch_setting`` will
            be used.
        dilations (Sequence[int]): Dilation of each layer.
        out_indices (None or Sequence[int]): Output from which stages.
            Default: (7, ).
        frozen_stages (int): Stages to be frozen (all param fixed).
            Default: -1, which means not freezing any parameters.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU6').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        pretrained (str, optional): model pretrained path. Default: None
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
    """

    # Parameters to build layers. 3 parameters are needed to construct a
    # layer, from left to right: expand_ratio, channel, num_blocks.
    arch_settings = [[1, 16, 1], [6, 24, 2], [6, 32, 3], [6, 64, 4],
                     [6, 96, 3], [6, 160, 3], [6, 320, 1]]

    def __init__(self,
                 widen_factor=1.,
                 strides=(1, 2, 2, 2, 1, 2, 1),
                 dilations=(1, 1, 1, 1, 1, 1, 1),
                 out_indices=(1, 2, 4, 6),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU6'),
                 norm_eval=False,
                 with_cp=False,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg)

        self.pretrained = pretrained
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is None:
            if init_cfg is None:
                self.init_cfg = [
                    dict(type='Kaiming', layer='Conv2d'),
                    dict(
                        type='Constant',
                        val=1,
                        layer=['_BatchNorm', 'GroupNorm'])
                ]
        else:
            raise TypeError('pretrained must be a str or None')

        self.widen_factor = widen_factor
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == len(self.arch_settings)
        self.out_indices = out_indices
        for index in out_indices:
            if index not in range(0, 7):
                raise ValueError('the item in out_indices must be in '
                                 f'range(0, 7). But received {index}')

        if frozen_stages not in range(-1, 7):
            raise ValueError('frozen_stages must be in range(-1, 7). '
                             f'But received {frozen_stages}')
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        self.in_channels = make_divisible(32 * widen_factor, 8)

        self.conv1 = ConvModule(
            in_channels=3,
            out_channels=self.in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.layers = []

        for i, layer_cfg in enumerate(self.arch_settings):
            expand_ratio, channel, num_blocks = layer_cfg
            stride = self.strides[i]
            dilation = self.dilations[i]
            out_channels = make_divisible(channel * widen_factor, 8)
            inverted_res_layer = self.make_layer(
                out_channels=out_channels,
                num_blocks=num_blocks,
                stride=stride,
                dilation=dilation,
                expand_ratio=expand_ratio)
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, inverted_res_layer)
            self.layers.append(layer_name)

    def make_layer(self, out_channels, num_blocks, stride, dilation,
                   expand_ratio):
        """Stack InvertedResidual blocks to build a layer for MobileNetV2.

        Args:
            out_channels (int): out_channels of block.
            num_blocks (int): Number of blocks.
            stride (int): Stride of the first block.
            dilation (int): Dilation of the first block.
            expand_ratio (int): Expand the number of channels of the
                hidden layer in InvertedResidual by this ratio.
        """
        layers = []
        for i in range(num_blocks):
            layers.append(
                InvertedResidual(
                    self.in_channels,
                    out_channels,
                    stride if i == 0 else 1,
                    expand_ratio=expand_ratio,
                    dilation=dilation if i == 0 else 1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    with_cp=self.with_cp))
            self.in_channels = out_channels

        return nn.Sequential(*layers)
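
    # Hedged note (not part of the original commit): only block 0 of each
    # layer applies `stride` and `dilation`; the remaining blocks use
    # stride 1. With the stem conv (stride 2) and layer strides
    # (1, 2, 2, 2, 1, 2, 1), the deepest features sit at 2**5 = 32x
    # downsampling.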

    def forward(self, x):
        x = self.conv1(x)

        outs = []
        for i, layer_name in enumerate(self.layers):
            layer = getattr(self, layer_name)
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)

        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            for param in self.conv1.parameters():
                param.requires_grad = False
            for i in range(1, self.frozen_stages + 1):
                layer = getattr(self, f'layer{i}')
                layer.eval()
                for param in layer.parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        super().train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
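

def _demo_mobilenet_v2_outputs():
    """Hedged usage sketch (not part of the original commit), assuming
    mmseg's dependencies are installed: the default out_indices (1, 2, 4, 6)
    yield a four-level pyramid at strides 4, 8, 16 and 32."""
    import torch
    net = MobileNetV2()
    outs = net(torch.rand(1, 3, 224, 224))
    assert [out.shape[2] for out in outs] == [56, 28, 14, 7]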
267
finetune/mmseg/models/backbones/mobilenet_v3.py
Normal file
@@ -0,0 +1,267 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import Conv2dAdaptivePadding
from mmengine.model import BaseModule
from mmengine.utils import is_tuple_of
from torch.nn.modules.batchnorm import _BatchNorm

from mmseg.registry import MODELS
from ..utils import InvertedResidualV3 as InvertedResidual


@MODELS.register_module()
class MobileNetV3(BaseModule):
    """MobileNetV3 backbone.

    This backbone is the improved implementation of `Searching for MobileNetV3
    <https://ieeexplore.ieee.org/document/9008835>`_.

    Args:
        arch (str): Architecture of MobileNetV3, from {'small', 'large'}.
            Default: 'small'.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        out_indices (tuple[int]): Output from which layer.
            Default: (0, 1, 12).
        frozen_stages (int): Stages to be frozen (all param fixed).
            Default: -1, which means not freezing any parameters.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
        pretrained (str, optional): model pretrained path. Default: None
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
    """
    # Parameters to build each block:
    # [kernel size, mid channels, out channels, with_se, act type, stride]
    arch_settings = {
        'small': [[3, 16, 16, True, 'ReLU', 2],  # block0 layer1 os=4
                  [3, 72, 24, False, 'ReLU', 2],  # block1 layer2 os=8
                  [3, 88, 24, False, 'ReLU', 1],
                  [5, 96, 40, True, 'HSwish', 2],  # block2 layer4 os=16
                  [5, 240, 40, True, 'HSwish', 1],
                  [5, 240, 40, True, 'HSwish', 1],
                  [5, 120, 48, True, 'HSwish', 1],  # block3 layer7 os=16
                  [5, 144, 48, True, 'HSwish', 1],
                  [5, 288, 96, True, 'HSwish', 2],  # block4 layer9 os=32
                  [5, 576, 96, True, 'HSwish', 1],
                  [5, 576, 96, True, 'HSwish', 1]],
        'large': [[3, 16, 16, False, 'ReLU', 1],  # block0 layer1 os=2
                  [3, 64, 24, False, 'ReLU', 2],  # block1 layer2 os=4
                  [3, 72, 24, False, 'ReLU', 1],
                  [5, 72, 40, True, 'ReLU', 2],  # block2 layer4 os=8
                  [5, 120, 40, True, 'ReLU', 1],
                  [5, 120, 40, True, 'ReLU', 1],
                  [3, 240, 80, False, 'HSwish', 2],  # block3 layer7 os=16
                  [3, 200, 80, False, 'HSwish', 1],
                  [3, 184, 80, False, 'HSwish', 1],
                  [3, 184, 80, False, 'HSwish', 1],
                  [3, 480, 112, True, 'HSwish', 1],  # block4 layer11 os=16
                  [3, 672, 112, True, 'HSwish', 1],
                  [5, 672, 160, True, 'HSwish', 2],  # block5 layer13 os=32
                  [5, 960, 160, True, 'HSwish', 1],
                  [5, 960, 160, True, 'HSwish', 1]]
    }  # yapf: disable

    def __init__(self,
                 arch='small',
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 out_indices=(0, 1, 12),
                 frozen_stages=-1,
                 reduction_factor=1,
                 norm_eval=False,
                 with_cp=False,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg)

        self.pretrained = pretrained
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is None:
            if init_cfg is None:
                self.init_cfg = [
                    dict(type='Kaiming', layer='Conv2d'),
                    dict(
                        type='Constant',
                        val=1,
                        layer=['_BatchNorm', 'GroupNorm'])
                ]
        else:
            raise TypeError('pretrained must be a str or None')

        assert arch in self.arch_settings
        assert isinstance(reduction_factor, int) and reduction_factor > 0
        assert is_tuple_of(out_indices, int)
        for index in out_indices:
            if index not in range(0, len(self.arch_settings[arch]) + 2):
                raise ValueError(
                    'the item in out_indices must be in '
                    f'range(0, {len(self.arch_settings[arch]) + 2}). '
                    f'But received {index}')

        if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2):
            raise ValueError('frozen_stages must be in range(-1, '
                             f'{len(self.arch_settings[arch]) + 2}). '
                             f'But received {frozen_stages}')
        self.arch = arch
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.reduction_factor = reduction_factor
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.layers = self._make_layer()

    def _make_layer(self):
        layers = []

        # build the first layer (layer0)
        in_channels = 16
        layer = ConvModule(
            in_channels=3,
            out_channels=in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=dict(type='Conv2dAdaptivePadding'),
            norm_cfg=self.norm_cfg,
            act_cfg=dict(type='HSwish'))
        self.add_module('layer0', layer)
        layers.append('layer0')

        layer_setting = self.arch_settings[self.arch]
        for i, params in enumerate(layer_setting):
            (kernel_size, mid_channels, out_channels, with_se, act,
             stride) = params

            if self.arch == 'large' and i >= 12 or self.arch == 'small' and \
                    i >= 8:
                mid_channels = mid_channels // self.reduction_factor
                out_channels = out_channels // self.reduction_factor

            if with_se:
                se_cfg = dict(
                    channels=mid_channels,
                    ratio=4,
                    act_cfg=(dict(type='ReLU'),
                             dict(type='HSigmoid', bias=3.0, divisor=6.0)))
            else:
                se_cfg = None

            layer = InvertedResidual(
                in_channels=in_channels,
                out_channels=out_channels,
                mid_channels=mid_channels,
                kernel_size=kernel_size,
                stride=stride,
                se_cfg=se_cfg,
                with_expand_conv=(in_channels != mid_channels),
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=dict(type=act),
                with_cp=self.with_cp)
            in_channels = out_channels
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, layer)
            layers.append(layer_name)

        # build the last layer
        # block5 layer12 os=32 for small model
        # block6 layer16 os=32 for large model
        layer = ConvModule(
            in_channels=in_channels,
            out_channels=576 if self.arch == 'small' else 960,
            kernel_size=1,
            stride=1,
            dilation=4,
            padding=0,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=dict(type='HSwish'))
        layer_name = f'layer{len(layer_setting) + 1}'
        self.add_module(layer_name, layer)
        layers.append(layer_name)

        # next, convert backbone MobileNetV3 to a semantic segmentation version
        if self.arch == 'small':
            self.layer4.depthwise_conv.conv.stride = (1, 1)
            self.layer9.depthwise_conv.conv.stride = (1, 1)
            for i in range(4, len(layers)):
                layer = getattr(self, layers[i])
                if isinstance(layer, InvertedResidual):
                    modified_module = layer.depthwise_conv.conv
                else:
                    modified_module = layer.conv

                if i < 9:
                    modified_module.dilation = (2, 2)
                    pad = 2
                else:
                    modified_module.dilation = (4, 4)
                    pad = 4

                if not isinstance(modified_module, Conv2dAdaptivePadding):
                    # Adjust padding
                    pad *= (modified_module.kernel_size[0] - 1) // 2
                    modified_module.padding = (pad, pad)
        else:
            self.layer7.depthwise_conv.conv.stride = (1, 1)
            self.layer13.depthwise_conv.conv.stride = (1, 1)
            for i in range(7, len(layers)):
                layer = getattr(self, layers[i])
                if isinstance(layer, InvertedResidual):
                    modified_module = layer.depthwise_conv.conv
                else:
                    modified_module = layer.conv

                if i < 13:
                    modified_module.dilation = (2, 2)
                    pad = 2
                else:
                    modified_module.dilation = (4, 4)
                    pad = 4

                if not isinstance(modified_module, Conv2dAdaptivePadding):
                    # Adjust padding
                    pad *= (modified_module.kernel_size[0] - 1) // 2
                    modified_module.padding = (pad, pad)

        return layers
|
||||
|
||||
def forward(self, x):
|
||||
outs = []
|
||||
for i, layer_name in enumerate(self.layers):
|
||||
layer = getattr(self, layer_name)
|
||||
x = layer(x)
|
||||
if i in self.out_indices:
|
||||
outs.append(x)
|
||||
return outs
|
||||
|
||||
def _freeze_stages(self):
|
||||
for i in range(self.frozen_stages + 1):
|
||||
layer = getattr(self, f'layer{i}')
|
||||
layer.eval()
|
||||
for param in layer.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def train(self, mode=True):
|
||||
super().train(mode)
|
||||
self._freeze_stages()
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
if isinstance(m, _BatchNorm):
|
||||
m.eval()
|
||||
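
# A minimal usage sketch, assuming this class is the MobileNetV3 backbone
# registered in MODELS (the class header sits above this hunk):
#   >>> self = MobileNetV3(arch='small')
#   >>> self.eval()
#   >>> outs = self(torch.rand(1, 3, 64, 64))  # one tensor per out_indices
# With the default out_indices=(0, 1, 12) this yields the 16-channel stem
# output, the first 16-channel inverted-residual stage, and the final
# 576-channel 1x1 conv feature.
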
467
finetune/mmseg/models/backbones/mscan.py
Normal file
@@ -0,0 +1,467 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Originally from https://github.com/visual-attention-network/segnext
# Licensed under the Apache License, Version 2.0 (the "License")
import math
import warnings

import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer, build_norm_layer
from mmcv.cnn.bricks import DropPath
from mmengine.model import BaseModule
from mmengine.model.weight_init import (constant_init, normal_init,
                                        trunc_normal_init)

from mmseg.registry import MODELS


class Mlp(BaseModule):
    """Multi Layer Perceptron (MLP) Module.

    Args:
        in_features (int): The dimension of input features.
        hidden_features (int): The dimension of hidden features.
            Defaults: None.
        out_features (int): The dimension of output features.
            Defaults: None.
        act_cfg (dict): Config dict for activation layer in block.
            Default: dict(type='GELU').
        drop (float): The dropout rate in the MLP block.
            Defaults: 0.0.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_cfg=dict(type='GELU'),
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.dwconv = nn.Conv2d(
            hidden_features,
            hidden_features,
            3,
            1,
            1,
            bias=True,
            groups=hidden_features)
        self.act = build_activation_layer(act_cfg)
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Forward function."""

        x = self.fc1(x)

        x = self.dwconv(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)

        return x
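
# A quick shape sketch (assumes torch and the defaults above): the MLP works
# on 4-D maps with 1x1 convs around a 3x3 depth-wise conv, so spatial size
# and (with out_features=None) channel count are both preserved:
#   >>> mlp = Mlp(in_features=64, hidden_features=256)
#   >>> mlp(torch.rand(2, 64, 32, 32)).shape
#   torch.Size([2, 64, 32, 32])
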
class StemConv(BaseModule):
    """Stem Block at the beginning of Semantic Branch.

    Args:
        in_channels (int): The dimension of input channels.
        out_channels (int): The dimension of output channels.
        act_cfg (dict): Config dict for activation layer in block.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Defaults: dict(type='SyncBN', requires_grad=True).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='SyncBN', requires_grad=True)):
        super().__init__()

        self.proj = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels // 2,
                kernel_size=(3, 3),
                stride=(2, 2),
                padding=(1, 1)),
            build_norm_layer(norm_cfg, out_channels // 2)[1],
            build_activation_layer(act_cfg),
            nn.Conv2d(
                out_channels // 2,
                out_channels,
                kernel_size=(3, 3),
                stride=(2, 2),
                padding=(1, 1)),
            build_norm_layer(norm_cfg, out_channels)[1],
        )

    def forward(self, x):
        """Forward function."""

        x = self.proj(x)
        _, _, H, W = x.size()
        x = x.flatten(2).transpose(1, 2)
        return x, H, W
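
# A shape sketch (assuming plain 'BN' instead of the 'SyncBN' default so it
# runs on a single device): two stride-2 convs reduce H and W by 4, and the
# result is flattened to tokens together with the reduced spatial size:
#   >>> stem = StemConv(3, 64, norm_cfg=dict(type='BN'))
#   >>> x, H, W = stem(torch.rand(2, 3, 64, 64))
#   >>> x.shape, H, W
#   (torch.Size([2, 256, 64]), 16, 16)
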
class MSCAAttention(BaseModule):
    """Attention Module in Multi-Scale Convolutional Attention Module (MSCA).

    Args:
        channels (int): The dimension of channels.
        kernel_sizes (list): The size of attention
            kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]].
        paddings (list): The corresponding padding values in the
            attention module.
            Defaults: [2, [0, 3], [0, 5], [0, 10]].
    """

    def __init__(self,
                 channels,
                 kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
                 paddings=[2, [0, 3], [0, 5], [0, 10]]):
        super().__init__()
        self.conv0 = nn.Conv2d(
            channels,
            channels,
            kernel_size=kernel_sizes[0],
            padding=paddings[0],
            groups=channels)
        for i, (kernel_size,
                padding) in enumerate(zip(kernel_sizes[1:], paddings[1:])):
            kernel_size_ = [kernel_size, kernel_size[::-1]]
            padding_ = [padding, padding[::-1]]
            conv_name = [f'conv{i}_1', f'conv{i}_2']
            for i_kernel, i_pad, i_conv in zip(kernel_size_, padding_,
                                               conv_name):
                self.add_module(
                    i_conv,
                    nn.Conv2d(
                        channels,
                        channels,
                        tuple(i_kernel),
                        padding=i_pad,
                        groups=channels))
        self.conv3 = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        """Forward function."""

        u = x.clone()

        attn = self.conv0(x)

        # Multi-Scale Feature extraction
        attn_0 = self.conv0_1(attn)
        attn_0 = self.conv0_2(attn_0)

        attn_1 = self.conv1_1(attn)
        attn_1 = self.conv1_2(attn_1)

        attn_2 = self.conv2_1(attn)
        attn_2 = self.conv2_2(attn_2)

        attn = attn + attn_0 + attn_1 + attn_2
        # Channel Mixing
        attn = self.conv3(attn)

        # Convolutional Attention
        x = attn * u

        return x
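
# A minimal sketch of the attention computation (assumes torch): each
# [1, k] / [k, 1] strip-conv pair approximates a k x k depth-wise conv, the
# branch outputs are summed, mixed by a 1x1 conv, and used to gate the input:
#   >>> attn = MSCAAttention(channels=64)
#   >>> attn(torch.rand(2, 64, 32, 32)).shape
#   torch.Size([2, 64, 32, 32])
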
class MSCASpatialAttention(BaseModule):
    """Spatial Attention Module in Multi-Scale Convolutional Attention Module
    (MSCA).

    Args:
        in_channels (int): The dimension of channels.
        attention_kernel_sizes (list): The size of attention
            kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]].
        attention_kernel_paddings (list): The corresponding padding values
            in the attention module.
            Defaults: [2, [0, 3], [0, 5], [0, 10]].
        act_cfg (dict): Config dict for activation layer in block.
            Default: dict(type='GELU').
    """

    def __init__(self,
                 in_channels,
                 attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
                 attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
                 act_cfg=dict(type='GELU')):
        super().__init__()
        self.proj_1 = nn.Conv2d(in_channels, in_channels, 1)
        self.activation = build_activation_layer(act_cfg)
        self.spatial_gating_unit = MSCAAttention(in_channels,
                                                 attention_kernel_sizes,
                                                 attention_kernel_paddings)
        self.proj_2 = nn.Conv2d(in_channels, in_channels, 1)

    def forward(self, x):
        """Forward function."""

        shortcut = x.clone()
        x = self.proj_1(x)
        x = self.activation(x)
        x = self.spatial_gating_unit(x)
        x = self.proj_2(x)
        x = x + shortcut
        return x
class MSCABlock(BaseModule):
    """Basic Multi-Scale Convolutional Attention Block. It leverages the
    large-kernel attention (LKA) mechanism to build both channel and spatial
    attention. In each branch, it uses two depth-wise strip convolutions to
    approximate standard depth-wise convolutions with large kernels. The
    kernel size for each branch is set to 7, 11, and 21, respectively.

    Args:
        channels (int): The dimension of channels.
        attention_kernel_sizes (list): The size of attention
            kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]].
        attention_kernel_paddings (list): The corresponding padding values
            in the attention module.
            Defaults: [2, [0, 3], [0, 5], [0, 10]].
        mlp_ratio (float): Ratio of the MLP hidden dimension to the input
            dimension. Defaults: 4.0.
        drop (float): The dropout rate in the MLP block.
            Defaults: 0.0.
        drop_path (float): The ratio of drop paths.
            Defaults: 0.0.
        act_cfg (dict): Config dict for activation layer in block.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Defaults: dict(type='SyncBN', requires_grad=True).
    """

    def __init__(self,
                 channels,
                 attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
                 attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='SyncBN', requires_grad=True)):
        super().__init__()
        self.norm1 = build_norm_layer(norm_cfg, channels)[1]
        self.attn = MSCASpatialAttention(channels, attention_kernel_sizes,
                                         attention_kernel_paddings, act_cfg)
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = build_norm_layer(norm_cfg, channels)[1]
        mlp_hidden_channels = int(channels * mlp_ratio)
        self.mlp = Mlp(
            in_features=channels,
            hidden_features=mlp_hidden_channels,
            act_cfg=act_cfg,
            drop=drop)
        layer_scale_init_value = 1e-2
        self.layer_scale_1 = nn.Parameter(
            layer_scale_init_value * torch.ones(channels), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(
            layer_scale_init_value * torch.ones(channels), requires_grad=True)

    def forward(self, x, H, W):
        """Forward function."""

        B, N, C = x.shape
        x = x.permute(0, 2, 1).view(B, C, H, W)
        x = x + self.drop_path(
            self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) *
            self.attn(self.norm1(x)))
        x = x + self.drop_path(
            self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) *
            self.mlp(self.norm2(x)))
        x = x.view(B, C, N).permute(0, 2, 1)
        return x
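
# A usage sketch (assumes torch; norm_cfg switched to 'BN' so the block runs
# without distributed training): the block consumes flattened tokens plus the
# spatial size, reshapes to (B, C, H, W) internally, and returns tokens:
#   >>> blk = MSCABlock(channels=64, norm_cfg=dict(type='BN'))
#   >>> blk(torch.rand(2, 16 * 16, 64), 16, 16).shape
#   torch.Size([2, 256, 64])
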
class OverlapPatchEmbed(BaseModule):
    """Image to Patch Embedding.

    Args:
        patch_size (int): The patch size.
            Defaults: 7.
        stride (int): Stride of the convolutional layer.
            Default: 4.
        in_channels (int): The number of input channels.
            Defaults: 3.
        embed_dim (int): The dimension of the embedding.
            Defaults: 768.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults: dict(type='SyncBN', requires_grad=True).
    """

    def __init__(self,
                 patch_size=7,
                 stride=4,
                 in_channels=3,
                 embed_dim=768,
                 norm_cfg=dict(type='SyncBN', requires_grad=True)):
        super().__init__()

        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=patch_size // 2)
        self.norm = build_norm_layer(norm_cfg, embed_dim)[1]

    def forward(self, x):
        """Forward function."""

        x = self.proj(x)
        _, _, H, W = x.shape
        x = self.norm(x)

        x = x.flatten(2).transpose(1, 2)

        return x, H, W
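
# A shape sketch (assumes torch, 'BN' norm): with patch_size=7 and stride=4
# the projection downsamples by 4 while keeping neighboring patches coupled
# through the padding of patch_size // 2:
#   >>> embed = OverlapPatchEmbed(norm_cfg=dict(type='BN'))
#   >>> x, H, W = embed(torch.rand(1, 3, 64, 64))
#   >>> x.shape, H, W
#   (torch.Size([1, 256, 768]), 16, 16)
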
@MODELS.register_module()
class MSCAN(BaseModule):
    """SegNeXt Multi-Scale Convolutional Attention Network (MSCAN) backbone.

    This backbone is the implementation of `SegNeXt: Rethinking
    Convolutional Attention Design for Semantic
    Segmentation <https://arxiv.org/abs/2209.08575>`_.
    Inspiration from https://github.com/visual-attention-network/segnext.

    Args:
        in_channels (int): The number of input channels. Defaults: 3.
        embed_dims (list[int]): Embedding dimension.
            Defaults: [64, 128, 256, 512].
        mlp_ratios (list[int]): Ratio of mlp hidden dim to embedding dim.
            Defaults: [4, 4, 4, 4].
        drop_rate (float): Dropout rate. Defaults: 0.
        drop_path_rate (float): Stochastic depth rate. Defaults: 0.
        depths (list[int]): Depths of each MSCAN stage.
            Default: [3, 4, 6, 3].
        num_stages (int): MSCAN stages. Default: 4.
        attention_kernel_sizes (list): Size of attention kernel in
            Attention Module (Figure 2(b) of original paper).
            Defaults: [5, [1, 7], [1, 11], [1, 21]].
        attention_kernel_paddings (list): Size of attention paddings
            in Attention Module (Figure 2(b) of original paper).
            Defaults: [2, [0, 3], [0, 5], [0, 10]].
        norm_cfg (dict): Config of norm layers.
            Defaults: dict(type='SyncBN', requires_grad=True).
        pretrained (str, optional): Model pretrained path.
            Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels=3,
                 embed_dims=[64, 128, 256, 512],
                 mlp_ratios=[4, 4, 4, 4],
                 drop_rate=0.,
                 drop_path_rate=0.,
                 depths=[3, 4, 6, 3],
                 num_stages=4,
                 attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
                 attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='SyncBN', requires_grad=True),
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is not None:
            raise TypeError('pretrained must be a str or None')

        self.depths = depths
        self.num_stages = num_stages

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule
        cur = 0

        for i in range(num_stages):
            if i == 0:
                patch_embed = StemConv(3, embed_dims[0], norm_cfg=norm_cfg)
            else:
                patch_embed = OverlapPatchEmbed(
                    patch_size=7 if i == 0 else 3,
                    stride=4 if i == 0 else 2,
                    in_channels=in_channels if i == 0 else embed_dims[i - 1],
                    embed_dim=embed_dims[i],
                    norm_cfg=norm_cfg)

            block = nn.ModuleList([
                MSCABlock(
                    channels=embed_dims[i],
                    attention_kernel_sizes=attention_kernel_sizes,
                    attention_kernel_paddings=attention_kernel_paddings,
                    mlp_ratio=mlp_ratios[i],
                    drop=drop_rate,
                    drop_path=dpr[cur + j],
                    act_cfg=act_cfg,
                    norm_cfg=norm_cfg) for j in range(depths[i])
            ])
            norm = nn.LayerNorm(embed_dims[i])
            cur += depths[i]

            setattr(self, f'patch_embed{i + 1}', patch_embed)
            setattr(self, f'block{i + 1}', block)
            setattr(self, f'norm{i + 1}', norm)

    def init_weights(self):
        """Initialize modules of MSCAN."""

        if self.init_cfg is None:
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, val=1.0, bias=0.)
                elif isinstance(m, nn.Conv2d):
                    fan_out = m.kernel_size[0] * m.kernel_size[
                        1] * m.out_channels
                    fan_out //= m.groups
                    normal_init(
                        m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
        else:
            super().init_weights()

    def forward(self, x):
        """Forward function."""

        B = x.shape[0]
        outs = []

        for i in range(self.num_stages):
            patch_embed = getattr(self, f'patch_embed{i + 1}')
            block = getattr(self, f'block{i + 1}')
            norm = getattr(self, f'norm{i + 1}')
            x, H, W = patch_embed(x)
            for blk in block:
                x = blk(x, H, W)
            x = norm(x)
            x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
            outs.append(x)

        return outs
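
# A usage sketch mirroring the doctest convention of the ResNet backbone in
# this repo (assumes torch; 'BN' norm so it runs on a single device):
#   >>> self = MSCAN(norm_cfg=dict(type='BN'))
#   >>> self.eval()
#   >>> outs = self(torch.rand(1, 3, 224, 224))
#   >>> [tuple(o.shape) for o in outs]
#   [(1, 64, 56, 56), (1, 128, 28, 28), (1, 256, 14, 14), (1, 512, 7, 7)]
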
522
finetune/mmseg/models/backbones/pidnet.py
Normal file
@@ -0,0 +1,522 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from mmengine.runner import CheckpointLoader
from torch import Tensor

from mmseg.registry import MODELS
from mmseg.utils import OptConfigType
from ..utils import DAPPM, PAPPM, BasicBlock, Bottleneck


class PagFM(BaseModule):
    """Pixel-attention-guided fusion module.

    Args:
        in_channels (int): The number of input channels.
        channels (int): The number of channels.
        after_relu (bool): Whether to use ReLU before attention.
            Default: False.
        with_channel (bool): Whether to use channel attention.
            Default: False.
        upsample_mode (str): The mode of upsample. Default: 'bilinear'.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
        init_cfg (dict): Config dict for initialization. Default: None.
    """

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 after_relu: bool = False,
                 with_channel: bool = False,
                 upsample_mode: str = 'bilinear',
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg)
        self.after_relu = after_relu
        self.with_channel = with_channel
        self.upsample_mode = upsample_mode
        self.f_i = ConvModule(
            in_channels, channels, 1, norm_cfg=norm_cfg, act_cfg=None)
        self.f_p = ConvModule(
            in_channels, channels, 1, norm_cfg=norm_cfg, act_cfg=None)
        if with_channel:
            self.up = ConvModule(
                channels, in_channels, 1, norm_cfg=norm_cfg, act_cfg=None)
        if after_relu:
            self.relu = MODELS.build(act_cfg)

    def forward(self, x_p: Tensor, x_i: Tensor) -> Tensor:
        """Forward function.

        Args:
            x_p (Tensor): The feature map from P branch.
            x_i (Tensor): The feature map from I branch.

        Returns:
            Tensor: The feature map with pixel-attention-guided fusion.
        """
        if self.after_relu:
            x_p = self.relu(x_p)
            x_i = self.relu(x_i)

        f_i = self.f_i(x_i)
        f_i = F.interpolate(
            f_i,
            size=x_p.shape[2:],
            mode=self.upsample_mode,
            align_corners=False)

        f_p = self.f_p(x_p)

        if self.with_channel:
            sigma = torch.sigmoid(self.up(f_p * f_i))
        else:
            sigma = torch.sigmoid(torch.sum(f_p * f_i, dim=1).unsqueeze(1))

        x_i = F.interpolate(
            x_i,
            size=x_p.shape[2:],
            mode=self.upsample_mode,
            align_corners=False)

        out = sigma * x_i + (1 - sigma) * x_p
        return out
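
# A minimal fusion sketch (assumes torch): sigma is a pixel-wise gate
# computed from the two projected branches, and the output is
# sigma * upsample(x_i) + (1 - sigma) * x_p at the P-branch resolution:
#   >>> pag = PagFM(in_channels=64, channels=32)
#   >>> pag(torch.rand(2, 64, 32, 32), torch.rand(2, 64, 16, 16)).shape
#   torch.Size([2, 64, 32, 32])
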
class Bag(BaseModule):
    """Boundary-attention-guided fusion module.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        kernel_size (int): The kernel size of the convolution. Default: 3.
        padding (int): The padding of the convolution. Default: 1.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
        conv_cfg (dict): Config dict for convolution layer.
            Default: dict(order=('norm', 'act', 'conv')).
        init_cfg (dict): Config dict for initialization. Default: None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int = 3,
                 padding: int = 1,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 conv_cfg: OptConfigType = dict(order=('norm', 'act', 'conv')),
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg)

        self.conv = ConvModule(
            in_channels,
            out_channels,
            kernel_size,
            padding=padding,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **conv_cfg)

    def forward(self, x_p: Tensor, x_i: Tensor, x_d: Tensor) -> Tensor:
        """Forward function.

        Args:
            x_p (Tensor): The feature map from P branch.
            x_i (Tensor): The feature map from I branch.
            x_d (Tensor): The feature map from D branch.

        Returns:
            Tensor: The feature map with boundary-attention-guided fusion.
        """
        sigma = torch.sigmoid(x_d)
        return self.conv(sigma * x_p + (1 - sigma) * x_i)
class LightBag(BaseModule):
    """Light Boundary-attention-guided fusion module.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer. Default: None.
        init_cfg (dict): Config dict for initialization. Default: None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = None,
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg)
        self.f_p = ConvModule(
            in_channels,
            out_channels,
            kernel_size=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.f_i = ConvModule(
            in_channels,
            out_channels,
            kernel_size=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    def forward(self, x_p: Tensor, x_i: Tensor, x_d: Tensor) -> Tensor:
        """Forward function.

        Args:
            x_p (Tensor): The feature map from P branch.
            x_i (Tensor): The feature map from I branch.
            x_d (Tensor): The feature map from D branch.

        Returns:
            Tensor: The feature map with light boundary-attention-guided
                fusion.
        """
        sigma = torch.sigmoid(x_d)

        f_p = self.f_p((1 - sigma) * x_i + x_p)
        f_i = self.f_i(x_i + sigma * x_p)

        return f_p + f_i
@MODELS.register_module()
class PIDNet(BaseModule):
    """PIDNet backbone.

    This backbone is the implementation of `PIDNet: A Real-time Semantic
    Segmentation Network Inspired from PID Controller
    <https://arxiv.org/abs/2206.02066>`_.
    Modified from https://github.com/XuJiacong/PIDNet.

    Licensed under the MIT License.

    Args:
        in_channels (int): The number of input channels. Default: 3.
        channels (int): The number of channels in the stem layer. Default: 64.
        ppm_channels (int): The number of channels in the PPM layer.
            Default: 96.
        num_stem_blocks (int): The number of blocks in the stem layer.
            Default: 2.
        num_branch_blocks (int): The number of blocks in the branch layer.
            Default: 3.
        align_corners (bool): The align_corners argument of F.interpolate.
            Default: False.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
        init_cfg (dict): Config dict for initialization. Default: None.
    """

    def __init__(self,
                 in_channels: int = 3,
                 channels: int = 64,
                 ppm_channels: int = 96,
                 num_stem_blocks: int = 2,
                 num_branch_blocks: int = 3,
                 align_corners: bool = False,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 init_cfg: OptConfigType = None,
                 **kwargs):
        super().__init__(init_cfg)
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.align_corners = align_corners

        # stem layer
        self.stem = self._make_stem_layer(in_channels, channels,
                                          num_stem_blocks)
        self.relu = nn.ReLU()

        # I Branch
        self.i_branch_layers = nn.ModuleList()
        for i in range(3):
            self.i_branch_layers.append(
                self._make_layer(
                    block=BasicBlock if i < 2 else Bottleneck,
                    in_channels=channels * 2**(i + 1),
                    channels=channels * 8 if i > 0 else channels * 4,
                    num_blocks=num_branch_blocks if i < 2 else 2,
                    stride=2))

        # P Branch
        self.p_branch_layers = nn.ModuleList()
        for i in range(3):
            self.p_branch_layers.append(
                self._make_layer(
                    block=BasicBlock if i < 2 else Bottleneck,
                    in_channels=channels * 2,
                    channels=channels * 2,
                    num_blocks=num_stem_blocks if i < 2 else 1))
        self.compression_1 = ConvModule(
            channels * 4,
            channels * 2,
            kernel_size=1,
            bias=False,
            norm_cfg=norm_cfg,
            act_cfg=None)
        self.compression_2 = ConvModule(
            channels * 8,
            channels * 2,
            kernel_size=1,
            bias=False,
            norm_cfg=norm_cfg,
            act_cfg=None)
        self.pag_1 = PagFM(channels * 2, channels)
        self.pag_2 = PagFM(channels * 2, channels)

        # D Branch
        if num_stem_blocks == 2:
            self.d_branch_layers = nn.ModuleList([
                self._make_single_layer(BasicBlock, channels * 2, channels),
                self._make_layer(Bottleneck, channels, channels, 1)
            ])
            channel_expand = 1
            spp_module = PAPPM
            dfm_module = LightBag
            act_cfg_dfm = None
        else:
            self.d_branch_layers = nn.ModuleList([
                self._make_single_layer(BasicBlock, channels * 2,
                                        channels * 2),
                self._make_single_layer(BasicBlock, channels * 2, channels * 2)
            ])
            channel_expand = 2
            spp_module = DAPPM
            dfm_module = Bag
            act_cfg_dfm = act_cfg

        self.diff_1 = ConvModule(
            channels * 4,
            channels * channel_expand,
            kernel_size=3,
            padding=1,
            bias=False,
            norm_cfg=norm_cfg,
            act_cfg=None)
        self.diff_2 = ConvModule(
            channels * 8,
            channels * 2,
            kernel_size=3,
            padding=1,
            bias=False,
            norm_cfg=norm_cfg,
            act_cfg=None)

        self.spp = spp_module(
            channels * 16, ppm_channels, channels * 4, num_scales=5)
        self.dfm = dfm_module(
            channels * 4, channels * 4, norm_cfg=norm_cfg, act_cfg=act_cfg_dfm)

        self.d_branch_layers.append(
            self._make_layer(Bottleneck, channels * 2, channels * 2, 1))

    def _make_stem_layer(self, in_channels: int, channels: int,
                         num_blocks: int) -> nn.Sequential:
        """Make stem layer.

        Args:
            in_channels (int): Number of input channels.
            channels (int): Number of output channels.
            num_blocks (int): Number of blocks.

        Returns:
            nn.Sequential: The stem layer.
        """

        layers = [
            ConvModule(
                in_channels,
                channels,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            ConvModule(
                channels,
                channels,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        ]

        layers.append(
            self._make_layer(BasicBlock, channels, channels, num_blocks))
        layers.append(nn.ReLU())
        layers.append(
            self._make_layer(
                BasicBlock, channels, channels * 2, num_blocks, stride=2))
        layers.append(nn.ReLU())

        return nn.Sequential(*layers)

    def _make_layer(self,
                    block: BasicBlock,
                    in_channels: int,
                    channels: int,
                    num_blocks: int,
                    stride: int = 1) -> nn.Sequential:
        """Make layer for PIDNet backbone.

        Args:
            block (BasicBlock): Basic block.
            in_channels (int): Number of input channels.
            channels (int): Number of output channels.
            num_blocks (int): Number of blocks.
            stride (int): Stride of the first block. Default: 1.

        Returns:
            nn.Sequential: The Branch Layer.
        """
        downsample = None
        if stride != 1 or in_channels != channels * block.expansion:
            downsample = ConvModule(
                in_channels,
                channels * block.expansion,
                kernel_size=1,
                stride=stride,
                norm_cfg=self.norm_cfg,
                act_cfg=None)

        layers = [block(in_channels, channels, stride, downsample)]
        in_channels = channels * block.expansion
        for i in range(1, num_blocks):
            layers.append(
                block(
                    in_channels,
                    channels,
                    stride=1,
                    act_cfg_out=None if i == num_blocks - 1 else self.act_cfg))
        return nn.Sequential(*layers)

    def _make_single_layer(self,
                           block: Union[BasicBlock, Bottleneck],
                           in_channels: int,
                           channels: int,
                           stride: int = 1) -> nn.Module:
        """Make single layer for PIDNet backbone.

        Args:
            block (BasicBlock or Bottleneck): Basic block or Bottleneck.
            in_channels (int): Number of input channels.
            channels (int): Number of output channels.
            stride (int): Stride of the first block. Default: 1.

        Returns:
            nn.Module
        """

        downsample = None
        if stride != 1 or in_channels != channels * block.expansion:
            downsample = ConvModule(
                in_channels,
                channels * block.expansion,
                kernel_size=1,
                stride=stride,
                norm_cfg=self.norm_cfg,
                act_cfg=None)
        return block(
            in_channels, channels, stride, downsample, act_cfg_out=None)

    def init_weights(self):
        """Initialize the weights in backbone.

        Since the D branch is not initialized by the pre-trained model, we
        initialize it with the same method as the ResNet.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        if self.init_cfg is not None:
            assert 'checkpoint' in self.init_cfg, \
                f'Only support specifying `Pretrained` in `init_cfg` in ' \
                f'{self.__class__.__name__}'
            ckpt = CheckpointLoader.load_checkpoint(
                self.init_cfg['checkpoint'], map_location='cpu')
            self.load_state_dict(ckpt, strict=False)

    def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor]]:
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (B, C, H, W).

        Returns:
            Tensor or tuple[Tensor]: If self.training is True, return
                tuple[Tensor], else return Tensor.
        """
        w_out = x.shape[-1] // 8
        h_out = x.shape[-2] // 8

        # stage 0-2
        x = self.stem(x)

        # stage 3
        x_i = self.relu(self.i_branch_layers[0](x))
        x_p = self.p_branch_layers[0](x)
        x_d = self.d_branch_layers[0](x)

        comp_i = self.compression_1(x_i)
        x_p = self.pag_1(x_p, comp_i)
        diff_i = self.diff_1(x_i)
        x_d += F.interpolate(
            diff_i,
            size=[h_out, w_out],
            mode='bilinear',
            align_corners=self.align_corners)
        if self.training:
            temp_p = x_p.clone()

        # stage 4
        x_i = self.relu(self.i_branch_layers[1](x_i))
        x_p = self.p_branch_layers[1](self.relu(x_p))
        x_d = self.d_branch_layers[1](self.relu(x_d))

        comp_i = self.compression_2(x_i)
        x_p = self.pag_2(x_p, comp_i)
        diff_i = self.diff_2(x_i)
        x_d += F.interpolate(
            diff_i,
            size=[h_out, w_out],
            mode='bilinear',
            align_corners=self.align_corners)
        if self.training:
            temp_d = x_d.clone()

        # stage 5
        x_i = self.i_branch_layers[2](x_i)
        x_p = self.p_branch_layers[2](self.relu(x_p))
        x_d = self.d_branch_layers[2](self.relu(x_d))

        x_i = self.spp(x_i)
        x_i = F.interpolate(
            x_i,
            size=[h_out, w_out],
            mode='bilinear',
            align_corners=self.align_corners)
        out = self.dfm(x_p, x_i, x_d)
        return (temp_p, out, temp_d) if self.training else out
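
# A usage sketch (assumes torch): in eval mode the backbone returns a single
# stride-8 feature map from the fusion module; in training mode it returns
# (temp_p, out, temp_d) so auxiliary heads can supervise the P and D branches:
#   >>> self = PIDNet()
#   >>> self.eval()
#   >>> tuple(self(torch.rand(1, 3, 256, 256)).shape)
#   (1, 256, 32, 32)
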
318
finetune/mmseg/models/backbones/resnest.py
Normal file
@@ -0,0 +1,318 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_conv_layer, build_norm_layer

from mmseg.registry import MODELS
from ..utils import ResLayer
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNetV1d


class RSoftmax(nn.Module):
    """Radix Softmax module in ``SplitAttentionConv2d``.

    Args:
        radix (int): Radix of input.
        groups (int): Groups of input.
    """

    def __init__(self, radix, groups):
        super().__init__()
        self.radix = radix
        self.groups = groups

    def forward(self, x):
        batch = x.size(0)
        if self.radix > 1:
            x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2)
            x = F.softmax(x, dim=1)
            x = x.reshape(batch, -1)
        else:
            x = torch.sigmoid(x)
        return x
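
# A small sketch (assumes torch): for radix > 1 the logits are grouped as
# (batch, groups, radix, -1) and softmax-normalized across the radix axis;
# for radix == 1 it degenerates to a plain sigmoid:
#   >>> rsoftmax = RSoftmax(radix=2, groups=1)
#   >>> rsoftmax(torch.rand(4, 8)).shape
#   torch.Size([4, 8])
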
class SplitAttentionConv2d(nn.Module):
    """Split-Attention Conv2d in ResNeSt.

    Args:
        in_channels (int): Same as nn.Conv2d.
        channels (int): Same as the out_channels of nn.Conv2d.
        kernel_size (int | tuple[int]): Same as nn.Conv2d.
        stride (int | tuple[int]): Same as nn.Conv2d.
        padding (int | tuple[int]): Same as nn.Conv2d.
        dilation (int | tuple[int]): Same as nn.Conv2d.
        groups (int): Same as nn.Conv2d.
        radix (int): Radix of SplitAttentionConv2d. Default: 2.
        reduction_factor (int): Reduction factor of inter_channels. Default: 4.
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        dcn (dict): Config dict for DCN. Default: None.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 radix=2,
                 reduction_factor=4,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None):
        super().__init__()
        inter_channels = max(in_channels * radix // reduction_factor, 32)
        self.radix = radix
        self.groups = groups
        self.channels = channels
        self.with_dcn = dcn is not None
        self.dcn = dcn
        fallback_on_stride = False
        if self.with_dcn:
            fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
        if self.with_dcn and not fallback_on_stride:
            assert conv_cfg is None, 'conv_cfg must be None for DCN'
            conv_cfg = dcn
        self.conv = build_conv_layer(
            conv_cfg,
            in_channels,
            channels * radix,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups * radix,
            bias=False)
        self.norm0_name, norm0 = build_norm_layer(
            norm_cfg, channels * radix, postfix=0)
        self.add_module(self.norm0_name, norm0)
        self.relu = nn.ReLU(inplace=True)
        self.fc1 = build_conv_layer(
            None, channels, inter_channels, 1, groups=self.groups)
        self.norm1_name, norm1 = build_norm_layer(
            norm_cfg, inter_channels, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.fc2 = build_conv_layer(
            None, inter_channels, channels * radix, 1, groups=self.groups)
        self.rsoftmax = RSoftmax(radix, groups)

    @property
    def norm0(self):
        """nn.Module: the normalization layer named "norm0" """
        return getattr(self, self.norm0_name)

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
        return getattr(self, self.norm1_name)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm0(x)
        x = self.relu(x)

        batch = x.size(0)
        if self.radix > 1:
            splits = x.view(batch, self.radix, -1, *x.shape[2:])
            gap = splits.sum(dim=1)
        else:
            gap = x
        gap = F.adaptive_avg_pool2d(gap, 1)
        gap = self.fc1(gap)

        gap = self.norm1(gap)
        gap = self.relu(gap)

        atten = self.fc2(gap)
        atten = self.rsoftmax(atten).view(batch, -1, 1, 1)

        if self.radix > 1:
            attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
            out = torch.sum(attens * splits, dim=1)
        else:
            out = atten * x
        return out.contiguous()
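
# A usage sketch (assumes torch): the conv produces radix * channels maps,
# per-group gates come from a squeezed two-layer 1x1-conv MLP plus RSoftmax,
# and the radix groups are re-weighted and summed back to `channels` outputs:
#   >>> conv = SplitAttentionConv2d(64, 64, kernel_size=3, padding=1)
#   >>> conv(torch.rand(2, 64, 32, 32)).shape
#   torch.Size([2, 64, 32, 32])
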
class Bottleneck(_Bottleneck):
    """Bottleneck block for ResNeSt.

    Args:
        inplanes (int): Input planes of this block.
        planes (int): Middle planes of this block.
        groups (int): Groups of conv2.
        base_width (int): Base width per group of conv2. 64x4d indicates
            ``groups=64, base_width=4`` and 32x8d indicates
            ``groups=32, base_width=8``.
        radix (int): Radix of SplitAttentionConv2d. Default: 2.
        reduction_factor (int): Reduction factor of inter_channels in
            SplitAttentionConv2d. Default: 4.
        avg_down_stride (bool): Whether to use average pool for stride in
            Bottleneck. Default: True.
        kwargs (dict): Key word arguments for base class.
    """
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 groups=1,
                 base_width=4,
                 base_channels=64,
                 radix=2,
                 reduction_factor=4,
                 avg_down_stride=True,
                 **kwargs):
        """Bottleneck block for ResNeSt."""
        super().__init__(inplanes, planes, **kwargs)

        if groups == 1:
            width = self.planes
        else:
            width = math.floor(self.planes *
                               (base_width / base_channels)) * groups

        self.avg_down_stride = avg_down_stride and self.conv2_stride > 1

        self.norm1_name, norm1 = build_norm_layer(
            self.norm_cfg, width, postfix=1)
        self.norm3_name, norm3 = build_norm_layer(
            self.norm_cfg, self.planes * self.expansion, postfix=3)

        self.conv1 = build_conv_layer(
            self.conv_cfg,
            self.inplanes,
            width,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        self.with_modulated_dcn = False
        self.conv2 = SplitAttentionConv2d(
            width,
            width,
            kernel_size=3,
            stride=1 if self.avg_down_stride else self.conv2_stride,
            padding=self.dilation,
            dilation=self.dilation,
            groups=groups,
            radix=radix,
            reduction_factor=reduction_factor,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            dcn=self.dcn)
        delattr(self, self.norm2_name)

        if self.avg_down_stride:
            self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)

        self.conv3 = build_conv_layer(
            self.conv_cfg,
            width,
            self.planes * self.expansion,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)

    def forward(self, x):

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv1_plugin_names)

            out = self.conv2(out)

            if self.avg_down_stride:
                out = self.avd_layer(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv2_plugin_names)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv3_plugin_names)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out
@MODELS.register_module()
class ResNeSt(ResNetV1d):
    """ResNeSt backbone.

    This backbone is the implementation of `ResNeSt:
    Split-Attention Networks <https://arxiv.org/abs/2004.08955>`_.

    Args:
        groups (int): Number of groups of Bottleneck. Default: 1.
        base_width (int): Base width of Bottleneck. Default: 4.
        radix (int): Radix of SplitAttentionConv2d. Default: 2.
        reduction_factor (int): Reduction factor of inter_channels in
            SplitAttentionConv2d. Default: 4.
        avg_down_stride (bool): Whether to use average pool for stride in
            Bottleneck. Default: True.
        kwargs (dict): Keyword arguments for ResNet.
    """

    arch_settings = {
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3)),
        200: (Bottleneck, (3, 24, 36, 3))
    }

    def __init__(self,
                 groups=1,
                 base_width=4,
                 radix=2,
                 reduction_factor=4,
                 avg_down_stride=True,
                 **kwargs):
        self.groups = groups
        self.base_width = base_width
        self.radix = radix
        self.reduction_factor = reduction_factor
        self.avg_down_stride = avg_down_stride
        super().__init__(**kwargs)

    def make_res_layer(self, **kwargs):
        """Pack all blocks in a stage into a ``ResLayer``."""
        return ResLayer(
            groups=self.groups,
            base_width=self.base_width,
            base_channels=self.base_channels,
            radix=self.radix,
            reduction_factor=self.reduction_factor,
            avg_down_stride=self.avg_down_stride,
            **kwargs)
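
# A usage sketch in the doctest style of the ResNet backbone below (assumes
# torch; ResNeSt inherits the deep-stem ResNetV1d layout and its defaults):
#   >>> self = ResNeSt(depth=50)
#   >>> self.eval()
#   >>> outs = self(torch.rand(1, 3, 64, 64))
#   >>> [tuple(o.shape) for o in outs]
#   [(1, 256, 16, 16), (1, 512, 8, 8), (1, 1024, 4, 4), (1, 2048, 2, 2)]
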
712
finetune/mmseg/models/backbones/resnet.py
Normal file
@@ -0,0 +1,712 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_conv_layer, build_norm_layer, build_plugin_layer
from mmengine.model import BaseModule
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm

from mmseg.registry import MODELS
from ..utils import ResLayer


class BasicBlock(BaseModule):
    """Basic block for ResNet."""

    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 plugins=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'

        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)

        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        self.conv2 = build_conv_layer(
            conv_cfg, planes, planes, 3, padding=1, bias=False)
        self.add_module(self.norm2_name, norm2)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        self.with_cp = with_cp

    @property
    def norm1(self):
        """nn.Module: normalization layer after the first convolution layer"""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """nn.Module: normalization layer after the second convolution layer"""
        return getattr(self, self.norm2_name)

    def forward(self, x):
        """Forward function."""

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.norm2(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out
class Bottleneck(BaseModule):
    """Bottleneck block for ResNet.

    If style is "pytorch", the stride-two layer is the 3x3 conv layer; if it
    is "caffe", the stride-two layer is the first 1x1 conv layer.
    """

    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 plugins=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        assert style in ['pytorch', 'caffe']
        assert dcn is None or isinstance(dcn, dict)
        assert plugins is None or isinstance(plugins, list)
        if plugins is not None:
            allowed_position = ['after_conv1', 'after_conv2', 'after_conv3']
            assert all(p['position'] in allowed_position for p in plugins)

        self.inplanes = inplanes
        self.planes = planes
        self.stride = stride
        self.dilation = dilation
        self.style = style
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.dcn = dcn
        self.with_dcn = dcn is not None
        self.plugins = plugins
        self.with_plugins = plugins is not None

        if self.with_plugins:
            # collect plugins for conv1/conv2/conv3
            self.after_conv1_plugins = [
                plugin['cfg'] for plugin in plugins
                if plugin['position'] == 'after_conv1'
            ]
            self.after_conv2_plugins = [
                plugin['cfg'] for plugin in plugins
                if plugin['position'] == 'after_conv2'
            ]
            self.after_conv3_plugins = [
                plugin['cfg'] for plugin in plugins
                if plugin['position'] == 'after_conv3'
            ]

        if self.style == 'pytorch':
            self.conv1_stride = 1
            self.conv2_stride = stride
        else:
            self.conv1_stride = stride
            self.conv2_stride = 1

        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
            norm_cfg, planes * self.expansion, postfix=3)

        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        fallback_on_stride = False
        if self.with_dcn:
            fallback_on_stride = dcn.pop('fallback_on_stride', False)
        if not self.with_dcn or fallback_on_stride:
            self.conv2 = build_conv_layer(
                conv_cfg,
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation,
                bias=False)
        else:
            assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
            self.conv2 = build_conv_layer(
                dcn,
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation,
                bias=False)

        self.add_module(self.norm2_name, norm2)
        self.conv3 = build_conv_layer(
            conv_cfg,
            planes,
            planes * self.expansion,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

        if self.with_plugins:
            self.after_conv1_plugin_names = self.make_block_plugins(
                planes, self.after_conv1_plugins)
            self.after_conv2_plugin_names = self.make_block_plugins(
                planes, self.after_conv2_plugins)
            self.after_conv3_plugin_names = self.make_block_plugins(
                planes * self.expansion, self.after_conv3_plugins)

    def make_block_plugins(self, in_channels, plugins):
        """make plugins for block.

        Args:
            in_channels (int): Input channels of plugin.
            plugins (list[dict]): List of plugins cfg to build.

        Returns:
            list[str]: List of the names of plugin.
        """
        assert isinstance(plugins, list)
        plugin_names = []
        for plugin in plugins:
            plugin = plugin.copy()
            name, layer = build_plugin_layer(
                plugin,
                in_channels=in_channels,
                postfix=plugin.pop('postfix', ''))
            assert not hasattr(self, name), f'duplicate plugin {name}'
            self.add_module(name, layer)
            plugin_names.append(name)
        return plugin_names

    def forward_plugin(self, x, plugin_names):
        """Forward function for plugins."""
        out = x
        for name in plugin_names:
            out = getattr(self, name)(out)
        return out

    @property
    def norm1(self):
        """nn.Module: normalization layer after the first convolution layer"""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """nn.Module: normalization layer after the second convolution layer"""
        return getattr(self, self.norm2_name)

    @property
    def norm3(self):
        """nn.Module: normalization layer after the third convolution layer"""
        return getattr(self, self.norm3_name)

    def forward(self, x):
        """Forward function."""

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv1_plugin_names)

            out = self.conv2(out)
            out = self.norm2(out)
            out = self.relu(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv2_plugin_names)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv3_plugin_names)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out
@MODELS.register_module()
class ResNet(BaseModule):
    """ResNet backbone.

    This backbone is the improved implementation of `Deep Residual Learning
    for Image Recognition <https://arxiv.org/abs/1512.03385>`_.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        in_channels (int): Number of input image channels. Default: 3.
        stem_channels (int): Number of stem channels. Default: 64.
        base_channels (int): Number of base channels of res layer. Default: 64.
        num_stages (int): Resnet stages, normally 4. Default: 4.
        strides (Sequence[int]): Strides of the first block of each stage.
            Default: (1, 2, 2, 2).
        dilations (Sequence[int]): Dilation of each stage.
            Default: (1, 1, 1, 1).
        out_indices (Sequence[int]): Output from which stages.
            Default: (0, 1, 2, 3).
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer. Default: 'pytorch'.
        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 convs.
            Default: False.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Default: False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Default: -1.
        conv_cfg (dict | None): Dictionary to construct and config conv layer.
            When conv_cfg is None, cfg will be set to dict(type='Conv2d').
            Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN', requires_grad=True).
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        dcn (dict | None): Dictionary to construct and config DCN conv layer.
            When dcn is not None, conv_cfg must be None. Default: None.
        stage_with_dcn (Sequence[bool]): Whether to set DCN conv for each
            stage. The length of stage_with_dcn is equal to num_stages.
            Default: (False, False, False, False).
        plugins (list[dict]): List of plugins for stages, each dict contains:

            - cfg (dict, required): Cfg dict to build plugin.

            - position (str, required): Position inside block to insert plugin,
              options: 'after_conv1', 'after_conv2', 'after_conv3'.

            - stages (tuple[bool], optional): Stages to apply plugin, length
              should be same as 'num_stages'.
            Default: None.
        multi_grid (Sequence[int] | None): Multi grid dilation rates of last
            stage. Default: None.
        contract_dilation (bool): Whether to contract the first dilation of
            each layer. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        zero_init_residual (bool): Whether to use zero init for the last norm
            layer in resblocks to let them behave as identity. Default: True.
        pretrained (str, optional): Model pretrained path. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Example:
        >>> from mmseg.models import ResNet
        >>> import torch
        >>> self = ResNet(depth=18)
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 64, 8, 8)
        (1, 128, 4, 4)
        (1, 256, 2, 2)
        (1, 512, 1, 1)
    """

    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

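    # A note on the settings above (explanatory comment, not from the
    # original source): each entry maps depth -> (block type, blocks per
    # stage). BasicBlock has expansion 1 and Bottleneck has expansion 4, so
    # with the default base_channels=64 a depth-18/34 model outputs
    # (64, 128, 256, 512) channels per stage while a depth-50/101/152 model
    # outputs (256, 512, 1024, 2048), matching feat_dim computed in __init__.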
    def __init__(self,
                 depth,
                 in_channels=3,
                 stem_channels=64,
                 base_channels=64,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 deep_stem=False,
                 avg_down=False,
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=False,
                 dcn=None,
                 stage_with_dcn=(False, False, False, False),
                 plugins=None,
                 multi_grid=None,
                 contract_dilation=False,
                 with_cp=False,
                 zero_init_residual=True,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for resnet')

        self.pretrained = pretrained
        self.zero_init_residual = zero_init_residual
        block_init_cfg = None
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is None:
            if init_cfg is None:
                self.init_cfg = [
                    dict(type='Kaiming', layer='Conv2d'),
                    dict(
                        type='Constant',
                        val=1,
                        layer=['_BatchNorm', 'GroupNorm'])
                ]
            block = self.arch_settings[depth][0]
            if self.zero_init_residual:
                if block is BasicBlock:
                    block_init_cfg = dict(
                        type='Constant',
                        val=0,
                        override=dict(name='norm2'))
                elif block is Bottleneck:
                    block_init_cfg = dict(
                        type='Constant',
                        val=0,
                        override=dict(name='norm3'))
        else:
            raise TypeError('pretrained must be a str or None')

        self.depth = depth
        self.stem_channels = stem_channels
        self.base_channels = base_channels
        self.num_stages = num_stages
        assert 1 <= num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.deep_stem = deep_stem
        self.avg_down = avg_down
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp
        self.norm_eval = norm_eval
        self.dcn = dcn
        self.stage_with_dcn = stage_with_dcn
        if dcn is not None:
            assert len(stage_with_dcn) == num_stages
        self.plugins = plugins
        self.multi_grid = multi_grid
        self.contract_dilation = contract_dilation
        self.block, stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        self.inplanes = stem_channels

        self._make_stem_layer(in_channels, stem_channels)

        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            dcn = self.dcn if self.stage_with_dcn[i] else None
            if plugins is not None:
                stage_plugins = self.make_stage_plugins(plugins, i)
            else:
                stage_plugins = None
            # multi grid is applied to the last stage only
            stage_multi_grid = multi_grid if i == len(
                self.stage_blocks) - 1 else None
            planes = base_channels * 2**i
            res_layer = self.make_res_layer(
                block=self.block,
                inplanes=self.inplanes,
                planes=planes,
                num_blocks=num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                avg_down=self.avg_down,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                dcn=dcn,
                plugins=stage_plugins,
                multi_grid=stage_multi_grid,
                contract_dilation=contract_dilation,
                init_cfg=block_init_cfg)
            self.inplanes = planes * self.block.expansion
            layer_name = f'layer{i+1}'
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

        self.feat_dim = self.block.expansion * base_channels * 2**(
            len(self.stage_blocks) - 1)

    def make_stage_plugins(self, plugins, stage_idx):
        """Make plugins for the ``stage_idx``-th ResNet stage.

        Currently we support inserting 'context_block',
        'empirical_attention_block' and 'nonlocal_block' into backbones like
        ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of
        Bottleneck.

        An example of the plugins format could be:
        >>> plugins=[
        ...     dict(cfg=dict(type='xxx', arg1='xxx'),
        ...          stages=(False, True, True, True),
        ...          position='after_conv2'),
        ...     dict(cfg=dict(type='yyy'),
        ...          stages=(True, True, True, True),
        ...          position='after_conv3'),
        ...     dict(cfg=dict(type='zzz', postfix='1'),
        ...          stages=(True, True, True, True),
        ...          position='after_conv3'),
        ...     dict(cfg=dict(type='zzz', postfix='2'),
        ...          stages=(True, True, True, True),
        ...          position='after_conv3')
        ... ]
        >>> self = ResNet(depth=18)
        >>> stage_plugins = self.make_stage_plugins(plugins, 0)
        >>> assert len(stage_plugins) == 3

        Suppose ``stage_idx=0``, the structure of blocks in the stage would be:
        conv1 -> conv2 -> conv3 -> yyy -> zzz1 -> zzz2
        Suppose ``stage_idx=1``, the structure of blocks in the stage would be:
        conv1 -> conv2 -> xxx -> conv3 -> yyy -> zzz1 -> zzz2

        If ``stages`` is missing, the plugin would be applied to all stages.

        Args:
            plugins (list[dict]): List of plugin cfgs to build. The postfix is
                required if multiple plugins of the same type are inserted.
            stage_idx (int): Index of the stage to build.

        Returns:
            list[dict]: Plugins for the current stage.
        """
        stage_plugins = []
        for plugin in plugins:
            plugin = plugin.copy()
            stages = plugin.pop('stages', None)
            assert stages is None or len(stages) == self.num_stages
            # whether to insert plugin into current stage
            if stages is None or stages[stage_idx]:
                stage_plugins.append(plugin)

        return stage_plugins

    def make_res_layer(self, **kwargs):
        """Pack all blocks in a stage into a ``ResLayer``."""
        return ResLayer(**kwargs)

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
        return getattr(self, self.norm1_name)

    def _make_stem_layer(self, in_channels, stem_channels):
        """Make stem layer for ResNet."""
        if self.deep_stem:
            self.stem = nn.Sequential(
                build_conv_layer(
                    self.conv_cfg,
                    in_channels,
                    stem_channels // 2,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    bias=False),
                build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
                nn.ReLU(inplace=True),
                build_conv_layer(
                    self.conv_cfg,
                    stem_channels // 2,
                    stem_channels // 2,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False),
                build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
                nn.ReLU(inplace=True),
                build_conv_layer(
                    self.conv_cfg,
                    stem_channels // 2,
                    stem_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False),
                build_norm_layer(self.norm_cfg, stem_channels)[1],
                nn.ReLU(inplace=True))
        else:
            self.conv1 = build_conv_layer(
                self.conv_cfg,
                in_channels,
                stem_channels,
                kernel_size=7,
                stride=2,
                padding=3,
                bias=False)
            self.norm1_name, norm1 = build_norm_layer(
                self.norm_cfg, stem_channels, postfix=1)
            self.add_module(self.norm1_name, norm1)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

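    # Note (explanatory comment, not from the original source): with
    # frozen_stages=1, ``_freeze_stages`` below puts the stem (or
    # conv1/norm1) and layer1 into eval mode and stops their gradients;
    # ``train()`` re-applies this freezing every time the mode changes.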
    def _freeze_stages(self):
        """Freeze stage params and norm stats."""
        if self.frozen_stages >= 0:
            if self.deep_stem:
                self.stem.eval()
                for param in self.stem.parameters():
                    param.requires_grad = False
            else:
                self.norm1.eval()
                for m in [self.conv1, self.norm1]:
                    for param in m.parameters():
                        param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, f'layer{i}')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Forward function."""
        if self.deep_stem:
            x = self.stem(x)
        else:
            x = self.conv1(x)
            x = self.norm1(x)
            x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)

    def train(self, mode=True):
        """Convert the model into training mode while keeping the
        normalization layers frozen."""
        super().train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval() has an effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()


@MODELS.register_module()
class ResNetV1c(ResNet):
    """ResNetV1c variant described in [1]_.

    Compared with the default ResNet (ResNetV1b), ResNetV1c replaces the 7x7
    conv in the input stem with three 3x3 convs. For more details, please
    refer to `Bag of Tricks for Image Classification with Convolutional
    Neural Networks <https://arxiv.org/abs/1812.01187>`_.
    """

    def __init__(self, **kwargs):
        super().__init__(deep_stem=True, avg_down=False, **kwargs)


@MODELS.register_module()
class ResNetV1d(ResNet):
    """ResNetV1d variant described in [1]_.

    Compared with the default ResNet (ResNetV1b), ResNetV1d replaces the 7x7
    conv in the input stem with three 3x3 convs. In addition, in the
    downsampling block, a 2x2 avg_pool with stride 2 is added before the
    conv, whose stride is changed to 1.
    """

    def __init__(self, **kwargs):
        super().__init__(deep_stem=True, avg_down=True, **kwargs)
150
finetune/mmseg/models/backbones/resnext.py
Normal file
@@ -0,0 +1,150 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math

from mmcv.cnn import build_conv_layer, build_norm_layer

from mmseg.registry import MODELS
from ..utils import ResLayer
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNet


class Bottleneck(_Bottleneck):
    """Bottleneck block for ResNeXt.

    If style is "pytorch", the stride-two layer is the 3x3 conv layer; if it
    is "caffe", the stride-two layer is the first 1x1 conv layer.
    """

    def __init__(self,
                 inplanes,
                 planes,
                 groups=1,
                 base_width=4,
                 base_channels=64,
                 **kwargs):
        super().__init__(inplanes, planes, **kwargs)

        if groups == 1:
            width = self.planes
        else:
            width = math.floor(self.planes *
                               (base_width / base_channels)) * groups
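        # Worked example (explanatory comment, not from the original
        # source): for the common ResNeXt 32x4d setting (groups=32,
        # base_width=4, base_channels=64) with planes=64, the grouped width
        # is floor(64 * 4 / 64) * 32 = 128, so the 3x3 conv below operates
        # on 128 channels split across 32 groups.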

        self.norm1_name, norm1 = build_norm_layer(
            self.norm_cfg, width, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(
            self.norm_cfg, width, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
            self.norm_cfg, self.planes * self.expansion, postfix=3)

        self.conv1 = build_conv_layer(
            self.conv_cfg,
            self.inplanes,
            width,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        fallback_on_stride = False
        self.with_modulated_dcn = False
        if self.with_dcn:
            fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
        if not self.with_dcn or fallback_on_stride:
            self.conv2 = build_conv_layer(
                self.conv_cfg,
                width,
                width,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=self.dilation,
                dilation=self.dilation,
                groups=groups,
                bias=False)
        else:
            assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
            self.conv2 = build_conv_layer(
                self.dcn,
                width,
                width,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=self.dilation,
                dilation=self.dilation,
                groups=groups,
                bias=False)

        self.add_module(self.norm2_name, norm2)
        self.conv3 = build_conv_layer(
            self.conv_cfg,
            width,
            self.planes * self.expansion,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)


@MODELS.register_module()
class ResNeXt(ResNet):
    """ResNeXt backbone.

    This backbone is the implementation of `Aggregated
    Residual Transformations for Deep Neural
    Networks <https://arxiv.org/abs/1611.05431>`_.

    Args:
        depth (int): Depth of resnext, from {50, 101, 152}.
        in_channels (int): Number of input image channels. Normally 3.
        num_stages (int): Resnet stages, normally 4.
        groups (int): Group of resnext.
        base_width (int): Base width of resnext.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters.
        norm_cfg (dict): Dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): Whether to use zero init for the last norm
            layer in resblocks to let them behave as identity.

    Example:
        >>> from mmseg.models import ResNeXt
        >>> import torch
        >>> self = ResNeXt(depth=50)
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 256, 8, 8)
        (1, 512, 4, 4)
        (1, 1024, 2, 2)
        (1, 2048, 1, 1)
    """

    arch_settings = {
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self, groups=1, base_width=4, **kwargs):
        self.groups = groups
        self.base_width = base_width
        super().__init__(**kwargs)

    def make_res_layer(self, **kwargs):
        """Pack all blocks in a stage into a ``ResLayer``."""
        return ResLayer(
            groups=self.groups,
            base_width=self.base_width,
            base_channels=self.base_channels,
            **kwargs)
422
finetune/mmseg/models/backbones/stdc.py
Normal file
@@ -0,0 +1,422 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from https://github.com/MichaelFan01/STDC-Seg."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule, ModuleList, Sequential

from mmseg.registry import MODELS
from ..utils import resize
from .bisenetv1 import AttentionRefinementModule


class STDCModule(BaseModule):
    """STDCModule.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels before scaling.
        stride (int): The stride of the first conv layer.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): The activation config for conv layers.
        num_convs (int): Number of conv layers.
        fusion_type (str): Type of fusion operation. Default: 'add'.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 norm_cfg=None,
                 act_cfg=None,
                 num_convs=4,
                 fusion_type='add',
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert num_convs > 1
        assert fusion_type in ['add', 'cat']
        self.stride = stride
        self.with_downsample = self.stride == 2
        self.fusion_type = fusion_type

        self.layers = ModuleList()
        conv_0 = ConvModule(
            in_channels, out_channels // 2, kernel_size=1, norm_cfg=norm_cfg)

        if self.with_downsample:
            self.downsample = ConvModule(
                out_channels // 2,
                out_channels // 2,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=out_channels // 2,
                norm_cfg=norm_cfg,
                act_cfg=None)

            if self.fusion_type == 'add':
                self.layers.append(nn.Sequential(conv_0, self.downsample))
                self.skip = Sequential(
                    ConvModule(
                        in_channels,
                        in_channels,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        groups=in_channels,
                        norm_cfg=norm_cfg,
                        act_cfg=None),
                    ConvModule(
                        in_channels,
                        out_channels,
                        1,
                        norm_cfg=norm_cfg,
                        act_cfg=None))
            else:
                self.layers.append(conv_0)
                self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        else:
            self.layers.append(conv_0)

        for i in range(1, num_convs):
            out_factor = 2**(i + 1) if i != num_convs - 1 else 2**i
            self.layers.append(
                ConvModule(
                    out_channels // 2**i,
                    out_channels // out_factor,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))

    def forward(self, inputs):
        if self.fusion_type == 'add':
            out = self.forward_add(inputs)
        else:
            out = self.forward_cat(inputs)
        return out

    def forward_add(self, inputs):
        layer_outputs = []
        x = inputs.clone()
        for layer in self.layers:
            x = layer(x)
            layer_outputs.append(x)
        if self.with_downsample:
            inputs = self.skip(inputs)

        return torch.cat(layer_outputs, dim=1) + inputs

    def forward_cat(self, inputs):
        x0 = self.layers[0](inputs)
        layer_outputs = [x0]
        for i, layer in enumerate(self.layers[1:]):
            if i == 0:
                if self.with_downsample:
                    x = layer(self.downsample(x0))
                else:
                    x = layer(x0)
            else:
                x = layer(x)
            layer_outputs.append(x)
        if self.with_downsample:
            layer_outputs[0] = self.skip(x0)
        return torch.cat(layer_outputs, dim=1)
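
    # Channel bookkeeping (explanatory comment, not from the original
    # source): with num_convs=4 and out_channels=C, the branch outputs have
    # C/2, C/4, C/8 and C/8 channels (the last conv keeps 2**i instead of
    # 2**(i + 1)), so the concatenation in both fusion modes yields exactly
    # C channels again.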


class FeatureFusionModule(BaseModule):
    """Feature Fusion Module.

    This module is different from the FeatureFusionModule in BiSeNetV1. It
    uses two ConvModules in ``self.attention``, whose intermediate channel
    number is computed from the given ``scale_factor``, while the
    FeatureFusionModule in BiSeNetV1 only uses one ConvModule in
    ``self.conv_atten``.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        scale_factor (int): The number of channel scale factor.
            Default: 4.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): The activation config for conv layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 scale_factor=4,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        channels = out_channels // scale_factor
        self.conv0 = ConvModule(
            in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            ConvModule(
                out_channels,
                channels,
                1,
                norm_cfg=None,
                bias=False,
                act_cfg=act_cfg),
            ConvModule(
                channels,
                out_channels,
                1,
                norm_cfg=None,
                bias=False,
                act_cfg=None), nn.Sigmoid())

    def forward(self, spatial_inputs, context_inputs):
        inputs = torch.cat([spatial_inputs, context_inputs], dim=1)
        x = self.conv0(inputs)
        attn = self.attention(x)
        x_attn = x * attn
        return x_attn + x
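
    # Note (explanatory comment, not from the original source): the return
    # value equals x * (1 + attn), i.e. a residual connection around the
    # SE-style channel attention, so the attention can only rescale features
    # relative to the identity rather than suppress them to zero.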


@MODELS.register_module()
class STDCNet(BaseModule):
    """This backbone is the implementation of `Rethinking BiSeNet For
    Real-time Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.

    Args:
        stdc_type (str): The type of backbone structure,
            `STDCNet1` and `STDCNet2` denote the two main backbones in the
            paper, whose FLOPs are 813M and 1446M, respectively.
        in_channels (int): The num of input_channels.
        channels (tuple[int]): The output channels for each stage.
        bottleneck_type (str): The type of STDC Module, the value must
            be 'add' or 'cat'.
        norm_cfg (dict): Config dict for normalization layer.
        act_cfg (dict): The activation config for conv layers.
        num_convs (int): Number of conv layers at each STDC Module.
            Default: 4.
        with_final_conv (bool): Whether to add a conv layer at the Module
            output. Default: False.
        pretrained (str, optional): Model pretrained path. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Example:
        >>> import torch
        >>> stdc_type = 'STDCNet1'
        >>> in_channels = 3
        >>> channels = (32, 64, 256, 512, 1024)
        >>> bottleneck_type = 'cat'
        >>> inputs = torch.rand(1, 3, 1024, 2048)
        >>> self = STDCNet(stdc_type, in_channels,
        ...                channels, bottleneck_type,
        ...                norm_cfg=dict(type='BN'),
        ...                act_cfg=dict(type='ReLU')).eval()
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 256, 128, 256])
        outputs[1].shape = torch.Size([1, 512, 64, 128])
        outputs[2].shape = torch.Size([1, 1024, 32, 64])
    """

    arch_settings = {
        'STDCNet1': [(2, 1), (2, 1), (2, 1)],
        'STDCNet2': [(2, 1, 1, 1), (2, 1, 1, 1, 1), (2, 1, 1)]
    }

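    # A note on the settings above (explanatory comment, not from the
    # original source): each inner tuple lists the per-module strides of one
    # deep stage, e.g. 'STDCNet2' stage 3 is (2, 1, 1, 1): one downsampling
    # STDCModule followed by three stride-1 STDCModules.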
    def __init__(self,
                 stdc_type,
                 in_channels,
                 channels,
                 bottleneck_type,
                 norm_cfg,
                 act_cfg,
                 num_convs=4,
                 with_final_conv=False,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert stdc_type in self.arch_settings, \
            f'invalid structure {stdc_type} for STDCNet.'
        assert bottleneck_type in ['add', 'cat'], \
            f'bottleneck_type must be `add` or `cat`, got {bottleneck_type}'

        assert len(channels) == 5, \
            f'invalid channels length {len(channels)} for STDCNet.'

        self.in_channels = in_channels
        self.channels = channels
        self.stage_strides = self.arch_settings[stdc_type]
        self.pretrained = pretrained
        self.num_convs = num_convs
        self.with_final_conv = with_final_conv

        self.stages = ModuleList([
            ConvModule(
                self.in_channels,
                self.channels[0],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg),
            ConvModule(
                self.channels[0],
                self.channels[1],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        ])
        # `self.num_shallow_features` is the number of shallow modules in
        # `STDCNet`, noted as `Stage1` and `Stage2` in the original paper.
        # Neither is used by the following modules such as the Attention
        # Refinement Module and the Feature Fusion Module, so they are cut
        # from `outs`. Please refer to Figure 4 of the original paper for
        # more details.
        self.num_shallow_features = len(self.stages)

        for strides in self.stage_strides:
            idx = len(self.stages) - 1
            self.stages.append(
                self._make_stage(self.channels[idx], self.channels[idx + 1],
                                 strides, norm_cfg, act_cfg, bottleneck_type))
        # After appending, `self.stages` is a ModuleList including several
        # shallow modules and STDCModules.
        # (len(self.stages) ==
        # self.num_shallow_features + len(self.stage_strides))
        if self.with_final_conv:
            self.final_conv = ConvModule(
                self.channels[-1],
                max(1024, self.channels[-1]),
                1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

    def _make_stage(self, in_channels, out_channels, strides, norm_cfg,
                    act_cfg, bottleneck_type):
        layers = []
        for i, stride in enumerate(strides):
            layers.append(
                STDCModule(
                    in_channels if i == 0 else out_channels,
                    out_channels,
                    stride,
                    norm_cfg,
                    act_cfg,
                    num_convs=self.num_convs,
                    fusion_type=bottleneck_type))
        return Sequential(*layers)

    def forward(self, x):
        outs = []
        for stage in self.stages:
            x = stage(x)
            outs.append(x)
        if self.with_final_conv:
            outs[-1] = self.final_conv(outs[-1])
        outs = outs[self.num_shallow_features:]
        return tuple(outs)


@MODELS.register_module()
class STDCContextPathNet(BaseModule):
    """STDCNet with Context Path.

    The `outs` below is a list of three feature maps from deep to shallow,
    whose heights and widths go from small to big, respectively. The biggest
    feature map of `outs` is fed to `STDCHead`, where the Detail Loss is
    calculated against the Detail Ground-truth. The other two feature maps
    each go through an Attention Refinement Module. Besides, the biggest
    feature map of `outs` and the last output of the Attention Refinement
    Modules are combined by the Feature Fusion Module, and the fused feature
    map `feat_fuse` is fed to the `decode_head`. For more details, please
    refer to Figure 4 of the original paper.

    Args:
        backbone_cfg (dict): Config dict for the stdc backbone.
        last_in_channels (tuple(int)): The number of channels of the last
            two feature maps from the stdc backbone. Default: (1024, 512).
        out_channels (int): The channels of output feature maps.
            Default: 128.
        ffm_cfg (dict): Config dict for Feature Fusion Module. Default:
            `dict(in_channels=512, out_channels=256, scale_factor=4)`.
        upsample_mode (str): Algorithm used for upsampling:
            ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
            ``'trilinear'``. Default: ``'nearest'``.
        align_corners (bool | None): align_corners argument of F.interpolate.
            It must be `None` if upsample_mode is ``'nearest'``. Default: None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Return:
        outputs (tuple): The tuple of lists of output feature maps for
            auxiliary heads and the decoder head.
    """

    def __init__(self,
                 backbone_cfg,
                 last_in_channels=(1024, 512),
                 out_channels=128,
                 ffm_cfg=dict(
                     in_channels=512, out_channels=256, scale_factor=4),
                 upsample_mode='nearest',
                 align_corners=None,
                 norm_cfg=dict(type='BN'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.backbone = MODELS.build(backbone_cfg)
        self.arms = ModuleList()
        self.convs = ModuleList()
        for channels in last_in_channels:
            self.arms.append(AttentionRefinementModule(channels, out_channels))
            self.convs.append(
                ConvModule(
                    out_channels,
                    out_channels,
                    3,
                    padding=1,
                    norm_cfg=norm_cfg))
        self.conv_avg = ConvModule(
            last_in_channels[0], out_channels, 1, norm_cfg=norm_cfg)

        self.ffm = FeatureFusionModule(**ffm_cfg)

        self.upsample_mode = upsample_mode
        self.align_corners = align_corners

    def forward(self, x):
        outs = list(self.backbone(x))
        avg = F.adaptive_avg_pool2d(outs[-1], 1)
        avg_feat = self.conv_avg(avg)

        feature_up = resize(
            avg_feat,
            size=outs[-1].shape[2:],
            mode=self.upsample_mode,
            align_corners=self.align_corners)
        arms_out = []
        for i in range(len(self.arms)):
            x_arm = self.arms[i](outs[len(outs) - 1 - i]) + feature_up
            feature_up = resize(
                x_arm,
                size=outs[len(outs) - 1 - i - 1].shape[2:],
                mode=self.upsample_mode,
                align_corners=self.align_corners)
            feature_up = self.convs[i](feature_up)
            arms_out.append(feature_up)

        feat_fuse = self.ffm(outs[0], arms_out[1])

        # `outputs` has four feature maps:
        # `outs[0]` goes to the `STDCHead` auxiliary head,
        # the two feature maps of `arms_out` go to auxiliary heads, and
        # `feat_fuse` goes to the decoder head.
        outputs = [outs[0]] + list(arms_out) + [feat_fuse]
        return tuple(outputs)
757
finetune/mmseg/models/backbones/swin.py
Normal file
@@ -0,0 +1,757 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from collections import OrderedDict
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, build_dropout
from mmengine.logging import print_log
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import (constant_init, trunc_normal_,
                                        trunc_normal_init)
from mmengine.runner import CheckpointLoader
from mmengine.utils import to_2tuple

from mmseg.registry import MODELS
from ..utils.embed import PatchEmbed, PatchMerging


class WindowMSA(BaseModule):
    """Window based multi-head self-attention (W-MSA) module with relative
    position bias.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (tuple[int]): The height and width of the window.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: True.
        qk_scale (float | None, optional): Override the default qk scale of
            head_dim ** -0.5 if set. Default: None.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Default: 0.0
        proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
        init_cfg (dict | None, optional): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 window_size,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop_rate=0.,
                 proj_drop_rate=0.,
                 init_cfg=None):

        super().__init__(init_cfg=init_cfg)
        self.embed_dims = embed_dims
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_embed_dims = embed_dims // num_heads
        self.scale = qk_scale or head_embed_dims**-0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                        num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # About 2x faster than original impl
        Wh, Ww = self.window_size
        rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
        rel_position_index = rel_index_coords + rel_index_coords.T
        rel_position_index = rel_position_index.flip(1).contiguous()
        self.register_buffer('relative_position_index', rel_position_index)

        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_rate)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(proj_drop_rate)

        self.softmax = nn.Softmax(dim=-1)

    def init_weights(self):
        trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x, mask=None):
        """
        Args:
            x (tensor): input features with shape of (num_windows*B, N, C)
            mask (tensor | None, Optional): mask with shape of (num_windows,
                Wh*Ww, Wh*Ww), value should be between (-inf, 0].
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        # make torchscript happy (cannot use tensor as tuple)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1],
                self.window_size[0] * self.window_size[1],
                -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B // nW, nW, self.num_heads, N,
                             N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    @staticmethod
    def double_step_seq(step1, len1, step2, len2):
        seq1 = torch.arange(0, step1 * len1, step1)
        seq2 = torch.arange(0, step2 * len2, step2)
        return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
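
    # How the index trick works (explanatory comment, not from the original
    # source): double_step_seq(2 * Ww - 1, Wh, 1, Ww) enumerates the
    # flattened offsets row * (2*Ww - 1) + col for every position in the
    # window. Adding the transposed sequence and flipping then yields the
    # (Wh*Ww, Wh*Ww) table of relative-position indices into
    # relative_position_bias_table, without building the usual 2D meshgrid
    # of coordinates.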


class ShiftWindowMSA(BaseModule):
    """Shifted Window Multihead Self-Attention Module.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): The height and width of the window.
        shift_size (int, optional): The shift step of each window towards
            right-bottom. If zero, act as regular window-msa. Defaults to 0.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: True
        qk_scale (float | None, optional): Override the default qk scale of
            head_dim ** -0.5 if set. Defaults: None.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Defaults: 0.
        proj_drop_rate (float, optional): Dropout ratio of output.
            Defaults: 0.
        dropout_layer (dict, optional): The dropout_layer used before output.
            Defaults: dict(type='DropPath', drop_prob=0.).
        init_cfg (dict, optional): The extra config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 window_size,
                 shift_size=0,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop_rate=0,
                 proj_drop_rate=0,
                 dropout_layer=dict(type='DropPath', drop_prob=0.),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        self.window_size = window_size
        self.shift_size = shift_size
        assert 0 <= self.shift_size < self.window_size

        self.w_msa = WindowMSA(
            embed_dims=embed_dims,
            num_heads=num_heads,
            window_size=to_2tuple(window_size),
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop_rate=attn_drop_rate,
            proj_drop_rate=proj_drop_rate,
            init_cfg=None)

        self.drop = build_dropout(dropout_layer)

    def forward(self, query, hw_shape):
        B, L, C = query.shape
        H, W = hw_shape
        assert L == H * W, 'input feature has wrong size'
        query = query.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))
        H_pad, W_pad = query.shape[1], query.shape[2]

        # cyclic shift
        if self.shift_size > 0:
            shifted_query = torch.roll(
                query,
                shifts=(-self.shift_size, -self.shift_size),
                dims=(1, 2))

            # calculate attention mask for SW-MSA
            img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device)
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size,
                              -self.shift_size), slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size,
                              -self.shift_size), slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            # nW, window_size, window_size, 1
            mask_windows = self.window_partition(img_mask)
            mask_windows = mask_windows.view(
                -1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0,
                                              float(-100.0)).masked_fill(
                                                  attn_mask == 0, float(0.0))
        else:
            shifted_query = query
            attn_mask = None

        # nW*B, window_size, window_size, C
        query_windows = self.window_partition(shifted_query)
        # nW*B, window_size*window_size, C
        query_windows = query_windows.view(-1, self.window_size**2, C)

        # W-MSA/SW-MSA (nW*B, window_size*window_size, C)
        attn_windows = self.w_msa(query_windows, mask=attn_mask)

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size,
                                         self.window_size, C)

        # B H' W' C
        shifted_x = self.window_reverse(attn_windows, H_pad, W_pad)
        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(
                shifted_x,
                shifts=(self.shift_size, self.shift_size),
                dims=(1, 2))
        else:
            x = shifted_x

        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        x = self.drop(x)
        return x

    def window_reverse(self, windows, H, W):
        """
        Args:
            windows: (num_windows*B, window_size, window_size, C)
            H (int): Height of image
            W (int): Width of image
        Returns:
            x: (B, H, W, C)
        """
        window_size = self.window_size
        B = int(windows.shape[0] / (H * W / window_size / window_size))
        x = windows.view(B, H // window_size, W // window_size, window_size,
                         window_size, -1)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
        return x

    def window_partition(self, x):
        """
        Args:
            x: (B, H, W, C)
        Returns:
            windows: (num_windows*B, window_size, window_size, C)
        """
        B, H, W, C = x.shape
        window_size = self.window_size
        x = x.view(B, H // window_size, window_size, W // window_size,
                   window_size, C)
        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
        windows = windows.view(-1, window_size, window_size, C)
        return windows
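
    # Shape check (explanatory comment, not from the original source): with
    # window_size=7, an input of shape (1, 14, 14, C) is partitioned into
    # (2*2, 7, 7, C) = (4, 7, 7, C) windows, and window_reverse(windows, 14,
    # 14) restores the original (1, 14, 14, C) tensor, so the two methods
    # are exact inverses whenever H and W are multiples of the window size.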


class SwinBlock(BaseModule):
    """Swin Transformer block.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        window_size (int, optional): The local window scale. Default: 7.
        shift (bool, optional): whether to shift window or not. Default False.
        qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
        qk_scale (float | None, optional): Override the default qk scale of
            head_dim ** -0.5 if set. Default: None.
        drop_rate (float, optional): Dropout rate. Default: 0.
        attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
        drop_path_rate (float, optional): Stochastic depth rate. Default: 0.
        act_cfg (dict, optional): The config dict of activation function.
            Default: dict(type='GELU').
        norm_cfg (dict, optional): The config dict of normalization.
            Default: dict(type='LN').
        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
            will save some memory while slowing down the training speed.
            Default: False.
        init_cfg (dict | list | None, optional): The init config.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 window_size=7,
                 shift=False,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 with_cp=False,
                 init_cfg=None):

        super().__init__(init_cfg=init_cfg)

        self.with_cp = with_cp

        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.attn = ShiftWindowMSA(
            embed_dims=embed_dims,
            num_heads=num_heads,
            window_size=window_size,
            shift_size=window_size // 2 if shift else 0,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop_rate=attn_drop_rate,
            proj_drop_rate=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            init_cfg=None)

        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=2,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg,
            add_identity=True,
            init_cfg=None)

    def forward(self, x, hw_shape):

        def _inner_forward(x):
            identity = x
            x = self.norm1(x)
            x = self.attn(x, hw_shape)

            x = x + identity

            identity = x
            x = self.norm2(x)
            x = self.ffn(x, identity=identity)

            return x

        if self.with_cp and x.requires_grad:
            x = cp.checkpoint(_inner_forward, x)
        else:
            x = _inner_forward(x)

        return x


class SwinBlockSequence(BaseModule):
    """Implements one stage in Swin Transformer.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        depth (int): The number of blocks in this stage.
        window_size (int, optional): The local window scale. Default: 7.
        qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
        qk_scale (float | None, optional): Override the default qk scale of
            head_dim ** -0.5 if set. Default: None.
        drop_rate (float, optional): Dropout rate. Default: 0.
        attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
        drop_path_rate (float | list[float], optional): Stochastic depth
            rate. Default: 0.
        downsample (BaseModule | None, optional): The downsample operation
            module. Default: None.
        act_cfg (dict, optional): The config dict of activation function.
            Default: dict(type='GELU').
        norm_cfg (dict, optional): The config dict of normalization.
            Default: dict(type='LN').
        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
            will save some memory while slowing down the training speed.
            Default: False.
        init_cfg (dict | list | None, optional): The init config.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 depth,
                 window_size=7,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 downsample=None,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 with_cp=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        if isinstance(drop_path_rate, list):
            drop_path_rates = drop_path_rate
            assert len(drop_path_rates) == depth
        else:
            drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)]

        self.blocks = ModuleList()
        for i in range(depth):
            block = SwinBlock(
                embed_dims=embed_dims,
                num_heads=num_heads,
                feedforward_channels=feedforward_channels,
                window_size=window_size,
                shift=False if i % 2 == 0 else True,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop_rate=drop_rate,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=drop_path_rates[i],
                act_cfg=act_cfg,
                norm_cfg=norm_cfg,
                with_cp=with_cp,
                init_cfg=None)
            self.blocks.append(block)

        self.downsample = downsample

    def forward(self, x, hw_shape):
        for block in self.blocks:
            x = block(x, hw_shape)

        if self.downsample:
            x_down, down_hw_shape = self.downsample(x, hw_shape)
            return x_down, down_hw_shape, x, hw_shape
        else:
            return x, hw_shape, x, hw_shape
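
    # Note (explanatory comment, not from the original source): the stage
    # returns both the downsampled tensor (fed to the next stage) and the
    # pre-downsample tensor with its hw_shape, which the surrounding
    # backbone can normalize (see the `norm{i}` layers built in
    # SwinTransformer.__init__) and expose as this stage's output feature
    # map.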


@MODELS.register_module()
class SwinTransformer(BaseModule):
    """Swin Transformer backbone.

    This backbone is the implementation of `Swin Transformer:
    Hierarchical Vision Transformer using Shifted
    Windows <https://arxiv.org/abs/2103.14030>`_.
    Inspiration from https://github.com/microsoft/Swin-Transformer.

    Args:
        pretrain_img_size (int | tuple[int]): The size of input image when
            pretraining. Defaults: 224.
        in_channels (int): The num of input channels.
            Defaults: 3.
        embed_dims (int): The feature dimension. Default: 96.
        patch_size (int | tuple[int]): Patch size. Default: 4.
        window_size (int): Window size. Default: 7.
        mlp_ratio (int | float): Ratio of mlp hidden dim to embedding dim.
            Default: 4.
        depths (tuple[int]): Depths of each Swin Transformer stage.
            Default: (2, 2, 6, 2).
        num_heads (tuple[int]): Parallel attention heads of each Swin
            Transformer stage. Default: (3, 6, 12, 24).
        strides (tuple[int]): The patch merging or patch embedding stride of
            each Swin Transformer stage. (In swin, we set kernel size equal to
            stride.) Default: (4, 2, 2, 2).
        out_indices (tuple[int]): Output from which stages.
            Default: (0, 1, 2, 3).
        qkv_bias (bool, optional): If True, add a learnable bias to query, key,
            value. Default: True
        qk_scale (float | None, optional): Override the default qk scale of
            head_dim ** -0.5 if set. Default: None.
        patch_norm (bool): Whether to add a norm layer for patch embed and
            patch merging. Default: True.
        drop_rate (float): Dropout rate. Defaults: 0.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Defaults: 0.1.
        use_abs_pos_embed (bool): If True, add absolute position embedding to
            the patch embedding. Defaults: False.
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer at the
            output of the backbone. Defaults: dict(type='LN').
        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
            will save some memory while slowing down the training speed.
            Default: False.
        pretrained (str, optional): Model pretrained path. Default: None.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 pretrain_img_size=224,
                 in_channels=3,
                 embed_dims=96,
                 patch_size=4,
                 window_size=7,
                 mlp_ratio=4,
                 depths=(2, 2, 6, 2),
                 num_heads=(3, 6, 12, 24),
                 strides=(4, 2, 2, 2),
                 out_indices=(0, 1, 2, 3),
                 qkv_bias=True,
                 qk_scale=None,
                 patch_norm=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 use_abs_pos_embed=False,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 with_cp=False,
                 pretrained=None,
                 frozen_stages=-1,
                 init_cfg=None):
        self.frozen_stages = frozen_stages

        if isinstance(pretrain_img_size, int):
            pretrain_img_size = to_2tuple(pretrain_img_size)
        elif isinstance(pretrain_img_size, tuple):
            if len(pretrain_img_size) == 1:
                pretrain_img_size = to_2tuple(pretrain_img_size[0])
            assert len(pretrain_img_size) == 2, \
                f'The size of image should have length 1 or 2, ' \
                f'but got {len(pretrain_img_size)}'

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be specified at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is None:
            init_cfg = init_cfg
        else:
            raise TypeError('pretrained must be a str or None')

        super().__init__(init_cfg=init_cfg)

        num_layers = len(depths)
        self.out_indices = out_indices
        self.use_abs_pos_embed = use_abs_pos_embed

        assert strides[0] == patch_size, 'Use non-overlapping patch embed.'

        self.patch_embed = PatchEmbed(
            in_channels=in_channels,
            embed_dims=embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size,
            stride=strides[0],
            padding='corner',
            norm_cfg=norm_cfg if patch_norm else None,
            init_cfg=None)

        if self.use_abs_pos_embed:
            patch_row = pretrain_img_size[0] // patch_size
            patch_col = pretrain_img_size[1] // patch_size
            num_patches = patch_row * patch_col
            self.absolute_pos_embed = nn.Parameter(
                torch.zeros((1, num_patches, embed_dims)))

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        # set stochastic depth decay rule
        total_depth = sum(depths)
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
        ]
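        # Worked example (explanatory comment, not from the original
        # source): with depths=(2, 2, 6, 2) and drop_path_rate=0.1, dpr is
        # a list of 12 values increasing linearly from 0.0 to 0.1, so
        # earlier blocks are almost never dropped while the deepest blocks
        # use the full stochastic-depth rate.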

        self.stages = ModuleList()
        in_channels = embed_dims
        for i in range(num_layers):
            if i < num_layers - 1:
                downsample = PatchMerging(
                    in_channels=in_channels,
                    out_channels=2 * in_channels,
                    stride=strides[i + 1],
                    norm_cfg=norm_cfg if patch_norm else None,
                    init_cfg=None)
            else:
                downsample = None

            stage = SwinBlockSequence(
                embed_dims=in_channels,
                num_heads=num_heads[i],
                feedforward_channels=int(mlp_ratio * in_channels),
                depth=depths[i],
                window_size=window_size,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop_rate=drop_rate,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=dpr[sum(depths[:i]):sum(depths[:i + 1])],
                downsample=downsample,
                act_cfg=act_cfg,
                norm_cfg=norm_cfg,
                with_cp=with_cp,
                init_cfg=None)
            self.stages.append(stage)
            if downsample:
                in_channels = downsample.out_channels

        self.num_features = [int(embed_dims * 2**i) for i in range(num_layers)]
        # Add a norm layer for each output
        for i in out_indices:
            layer = build_norm_layer(norm_cfg, self.num_features[i])[1]
            layer_name = f'norm{i}'
            self.add_module(layer_name, layer)

    def train(self, mode=True):
        """Convert the model into training mode while keeping layers
        frozen."""
        super().train(mode)
        self._freeze_stages()

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False
            if self.use_abs_pos_embed:
                self.absolute_pos_embed.requires_grad = False
            self.drop_after_pos.eval()

        for i in range(1, self.frozen_stages + 1):

            if (i - 1) in self.out_indices:
                norm_layer = getattr(self, f'norm{i-1}')
                norm_layer.eval()
                for param in norm_layer.parameters():
                    param.requires_grad = False

            m = self.stages[i - 1]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

def init_weights(self):
|
||||
if self.init_cfg is None:
|
||||
print_log(f'No pre-trained weights for '
|
||||
f'{self.__class__.__name__}, '
|
||||
f'training start from scratch')
|
||||
if self.use_abs_pos_embed:
|
||||
trunc_normal_(self.absolute_pos_embed, std=0.02)
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_init(m, std=.02, bias=0.)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
constant_init(m, val=1.0, bias=0.)
|
||||
else:
|
||||
assert 'checkpoint' in self.init_cfg, f'Only support ' \
|
||||
f'specify `Pretrained` in ' \
|
||||
f'`init_cfg` in ' \
|
||||
f'{self.__class__.__name__} '
|
||||
ckpt = CheckpointLoader.load_checkpoint(
|
||||
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
|
||||
if 'state_dict' in ckpt:
|
||||
_state_dict = ckpt['state_dict']
|
||||
elif 'model' in ckpt:
|
||||
_state_dict = ckpt['model']
|
||||
else:
|
||||
_state_dict = ckpt
|
||||
|
||||
state_dict = OrderedDict()
|
||||
for k, v in _state_dict.items():
|
||||
if k.startswith('backbone.'):
|
||||
state_dict[k[9:]] = v
|
||||
else:
|
||||
state_dict[k] = v
|
||||
|
||||
# strip prefix of state_dict
|
||||
if list(state_dict.keys())[0].startswith('module.'):
|
||||
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
||||
|
||||
# reshape absolute position embedding
|
||||
if state_dict.get('absolute_pos_embed') is not None:
|
||||
absolute_pos_embed = state_dict['absolute_pos_embed']
|
||||
N1, L, C1 = absolute_pos_embed.size()
|
||||
N2, C2, H, W = self.absolute_pos_embed.size()
|
||||
if N1 != N2 or C1 != C2 or L != H * W:
|
||||
print_log('Error in loading absolute_pos_embed, pass')
|
||||
else:
|
||||
state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
|
||||
N2, H, W, C2).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
# interpolate position bias table if needed
|
||||
relative_position_bias_table_keys = [
|
||||
k for k in state_dict.keys()
|
||||
if 'relative_position_bias_table' in k
|
||||
]
|
||||
for table_key in relative_position_bias_table_keys:
|
||||
table_pretrained = state_dict[table_key]
|
||||
if table_key in self.state_dict():
|
||||
table_current = self.state_dict()[table_key]
|
||||
L1, nH1 = table_pretrained.size()
|
||||
L2, nH2 = table_current.size()
|
||||
if nH1 != nH2:
|
||||
print_log(f'Error in loading {table_key}, pass')
|
||||
elif L1 != L2:
|
||||
S1 = int(L1**0.5)
|
||||
S2 = int(L2**0.5)
|
||||
table_pretrained_resized = F.interpolate(
|
||||
table_pretrained.permute(1, 0).reshape(
|
||||
1, nH1, S1, S1),
|
||||
size=(S2, S2),
|
||||
mode='bicubic')
|
||||
state_dict[table_key] = table_pretrained_resized.view(
|
||||
nH2, L2).permute(1, 0).contiguous()
|
||||
|
||||
# load state_dict
|
||||
self.load_state_dict(state_dict, strict=False)
|
||||
|
||||
def forward(self, x):
|
||||
x, hw_shape = self.patch_embed(x)
|
||||
|
||||
if self.use_abs_pos_embed:
|
||||
x = x + self.absolute_pos_embed
|
||||
x = self.drop_after_pos(x)
|
||||
|
||||
outs = []
|
||||
for i, stage in enumerate(self.stages):
|
||||
x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
|
||||
if i in self.out_indices:
|
||||
norm_layer = getattr(self, f'norm{i}')
|
||||
out = norm_layer(out)
|
||||
out = out.view(-1, *out_hw_shape,
|
||||
self.num_features[i]).permute(0, 3, 1,
|
||||
2).contiguous()
|
||||
outs.append(out)
|
||||
|
||||
return outs
|
||||
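
Note: the `dpr` schedule built in `__init__` assigns each block a linearly increasing drop-path probability over the total depth. A minimal sketch of the rule, using Swin-T depths (2, 2, 6, 2) and an end rate of 0.3 purely as illustrative values, not values from this config:

    # Minimal sketch of the stochastic depth decay rule used above.
    import torch

    depths = (2, 2, 6, 2)          # illustrative
    drop_path_rate = 0.3           # illustrative
    total_depth = sum(depths)      # 12 blocks in total

    # Rates grow linearly from 0 for the first block to 0.3 for the last.
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)]

    # Each stage i consumes its slice of the schedule, exactly as in __init__.
    for i in range(len(depths)):
        stage_rates = dpr[sum(depths[:i]):sum(depths[:i + 1])]
        print(i, [round(r, 3) for r in stage_rates])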
63
finetune/mmseg/models/backbones/timm_backbone.py
Normal file
@@ -0,0 +1,63 @@
# Copyright (c) OpenMMLab. All rights reserved.
try:
    import timm
except ImportError:
    timm = None

from mmengine.model import BaseModule
from mmengine.registry import MODELS as MMENGINE_MODELS

from mmseg.registry import MODELS


@MODELS.register_module()
class TIMMBackbone(BaseModule):
    """Wrapper to use backbones from timm library. More details can be found
    in `timm <https://github.com/rwightman/pytorch-image-models>`_ .

    Args:
        model_name (str): Name of timm model to instantiate.
        features_only (bool): Whether to extract multi-scale feature maps
            instead of a classification output. Default: True.
        pretrained (bool): Load pretrained weights if True.
        checkpoint_path (str): Path of checkpoint to load after
            model is initialized.
        in_channels (int): Number of input image channels. Default: 3.
        init_cfg (dict, optional): Initialization config dict.
        **kwargs: Other timm & model specific arguments.
    """

    def __init__(
        self,
        model_name,
        features_only=True,
        pretrained=True,
        checkpoint_path='',
        in_channels=3,
        init_cfg=None,
        **kwargs,
    ):
        if timm is None:
            raise RuntimeError('timm is not installed')
        super().__init__(init_cfg)
        if 'norm_layer' in kwargs:
            kwargs['norm_layer'] = MMENGINE_MODELS.get(kwargs['norm_layer'])
        self.timm_model = timm.create_model(
            model_name=model_name,
            features_only=features_only,
            pretrained=pretrained,
            in_chans=in_channels,
            checkpoint_path=checkpoint_path,
            **kwargs,
        )

        # Make unused parameters None
        self.timm_model.global_pool = None
        self.timm_model.fc = None
        self.timm_model.classifier = None

        # Hack to use pretrained weights from timm
        if pretrained or checkpoint_path:
            self._is_init = True

    def forward(self, x):
        features = self.timm_model(x)
        return features
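
A quick sanity check of the wrapper, assuming timm is installed; the model name 'resnet18' and the 64x64 input are illustrative only:

    # Hypothetical usage sketch for TIMMBackbone; requires `pip install timm`.
    import torch

    backbone = TIMMBackbone(model_name='resnet18', pretrained=False)
    feats = backbone(torch.randn(1, 3, 64, 64))
    # With features_only=True, timm returns one feature map per stride level.
    for f in feats:
        print(f.shape)  # e.g. strides 2/4/8/16/32 for a ResNet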
588
finetune/mmseg/models/backbones/twins.py
Normal file
@@ -0,0 +1,588 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import FFN
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import (constant_init, normal_init,
                                        trunc_normal_init)
from torch.nn.modules.batchnorm import _BatchNorm

from mmseg.models.backbones.mit import EfficientMultiheadAttention
from mmseg.registry import MODELS
from ..utils.embed import PatchEmbed


class GlobalSubsampledAttention(EfficientMultiheadAttention):
    """Global Sub-sampled Attention (Spatial Reduction Attention)

    This module is modified from EfficientMultiheadAttention,
    which is a module from mmseg.models.backbones.mit.py.
    `GlobalSubsampledAttention` is functionally identical to
    `EfficientMultiheadAttention`; it is defined as a separate class only
    because the operation is renamed `Global sub-sampled attention (GSA)`
    in the paper.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Default: 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut. Default: None.
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dims)
            or (n, batch, embed_dims). Default: True.
        qkv_bias (bool): enable bias for qkv if True. Default: True.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        sr_ratio (int): The ratio of spatial reduction of GSA of PCPVT.
            Default: 1.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=None,
                 batch_first=True,
                 qkv_bias=True,
                 norm_cfg=dict(type='LN'),
                 sr_ratio=1,
                 init_cfg=None):
        super().__init__(
            embed_dims,
            num_heads,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            dropout_layer=dropout_layer,
            batch_first=batch_first,
            qkv_bias=qkv_bias,
            norm_cfg=norm_cfg,
            sr_ratio=sr_ratio,
            init_cfg=init_cfg)


class GSAEncoderLayer(BaseModule):
    """Implements one encoder layer with GSA.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Default: 0.0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default: 0.0.
        drop_path_rate (float): Stochastic depth rate. Default 0.0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default: 2.
        qkv_bias (bool): Enable bias for qkv if True. Default: True
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        sr_ratio (float): Kernel_size of conv in Attention modules. Default: 1.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 sr_ratio=1.,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1]
        self.attn = GlobalSubsampledAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            qkv_bias=qkv_bias,
            norm_cfg=norm_cfg,
            sr_ratio=sr_ratio)

        self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1]
        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg,
            add_identity=False)

        self.drop_path = build_dropout(
            dict(type='DropPath', drop_prob=drop_path_rate)
        ) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x, hw_shape):
        x = x + self.drop_path(self.attn(self.norm1(x), hw_shape, identity=0.))
        x = x + self.drop_path(self.ffn(self.norm2(x)))
        return x

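GSA inherits spatial-reduction attention: keys and values are downsampled by `sr_ratio` before attention, shrinking the attention matrix from N x N to N x (N / sr_ratio^2). A back-of-the-envelope check of that saving, with illustrative numbers only:

    # Illustrative cost check for spatial-reduction attention.
    h = w = 56          # e.g. first-stage token grid for a 224x224 input
    sr_ratio = 8        # first-stage reduction ratio from the defaults below
    n = h * w           # 3136 query tokens
    n_kv = (h // sr_ratio) * (w // sr_ratio)  # 49 key/value tokens

    full_cost = n * n       # ~9.8M attention entries without reduction
    gsa_cost = n * n_kv     # ~154K entries with sr_ratio=8
    print(full_cost // gsa_cost)  # 64 == sr_ratio ** 2
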
class LocallyGroupedSelfAttention(BaseModule):
    """Locally-grouped Self Attention (LSA) module.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads. Default: 8
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: False.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Default: 0.0
        proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
        window_size (int): Window size of LSA. Default: 1.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_rate=0.,
                 proj_drop_rate=0.,
                 window_size=1,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        assert embed_dims % num_heads == 0, f'dim {embed_dims} should be ' \
                                            f'divided by num_heads ' \
                                            f'{num_heads}.'
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        head_dim = embed_dims // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_rate)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(proj_drop_rate)
        self.window_size = window_size

    def forward(self, x, hw_shape):
        b, n, c = x.shape
        h, w = hw_shape
        x = x.view(b, h, w, c)

        # pad feature maps to multiples of Local-groups
        pad_l = pad_t = 0
        pad_r = (self.window_size - w % self.window_size) % self.window_size
        pad_b = (self.window_size - h % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))

        # calculate attention mask for LSA
        Hp, Wp = x.shape[1:-1]
        _h, _w = Hp // self.window_size, Wp // self.window_size
        mask = torch.zeros((1, Hp, Wp), device=x.device)
        mask[:, -pad_b:, :].fill_(1)
        mask[:, :, -pad_r:].fill_(1)

        # [B, _h, _w, window_size, window_size, C]
        x = x.reshape(b, _h, self.window_size, _w, self.window_size,
                      c).transpose(2, 3)
        mask = mask.reshape(1, _h, self.window_size, _w,
                            self.window_size).transpose(2, 3).reshape(
                                1, _h * _w,
                                self.window_size * self.window_size)
        # [1, _h*_w, window_size*window_size, window_size*window_size]
        attn_mask = mask.unsqueeze(2) - mask.unsqueeze(3)
        attn_mask = attn_mask.masked_fill(attn_mask != 0,
                                          float(-1000.0)).masked_fill(
                                              attn_mask == 0, float(0.0))

        # [3, B, _w*_h, nhead, window_size*window_size, dim]
        qkv = self.qkv(x).reshape(b, _h * _w,
                                  self.window_size * self.window_size, 3,
                                  self.num_heads, c // self.num_heads).permute(
                                      3, 0, 1, 4, 2, 5)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # [B, _h*_w, n_head, window_size*window_size, window_size*window_size]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn + attn_mask.unsqueeze(2)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        attn = (attn @ v).transpose(2, 3).reshape(b, _h, _w, self.window_size,
                                                  self.window_size, c)
        x = attn.transpose(2, 3).reshape(b, _h * self.window_size,
                                         _w * self.window_size, c)
        if pad_r > 0 or pad_b > 0:
            x = x[:, :h, :w, :].contiguous()

        x = x.reshape(b, n, c)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

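The padding in the forward above rounds the token grid up to a multiple of the window size, and the mask marks padded rows and columns so they cannot attend to real tokens. A small sketch of the arithmetic, with example sizes (h, w, and window_size are illustrative):

    # Example of the pad-to-window-multiple arithmetic used by LSA.
    h, w, window_size = 30, 45, 7

    pad_r = (window_size - w % window_size) % window_size  # 4 -> Wp = 49
    pad_b = (window_size - h % window_size) % window_size  # 5 -> Hp = 35

    Hp, Wp = h + pad_b, w + pad_r
    num_groups = (Hp // window_size) * (Wp // window_size)  # 5 * 7 = 35
    print(pad_r, pad_b, num_groups)
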
class LSAEncoderLayer(BaseModule):
    """Implements one encoder layer in Twins-SVT.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Default: 0.0.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Default: 0.0
        drop_path_rate (float): Stochastic depth rate. Default 0.0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default: 2.
        qkv_bias (bool): Enable bias for qkv if True. Default: True
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        window_size (int): Window size of LSA. Default: 1.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=True,
                 qk_scale=None,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 window_size=1,
                 init_cfg=None):

        super().__init__(init_cfg=init_cfg)

        self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1]
        self.attn = LocallyGroupedSelfAttention(embed_dims, num_heads,
                                                qkv_bias, qk_scale,
                                                attn_drop_rate, drop_rate,
                                                window_size)

        self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1]
        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg,
            add_identity=False)

        self.drop_path = build_dropout(
            dict(type='DropPath', drop_prob=drop_path_rate)
        ) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x, hw_shape):
        x = x + self.drop_path(self.attn(self.norm1(x), hw_shape))
        x = x + self.drop_path(self.ffn(self.norm2(x)))
        return x


class ConditionalPositionEncoding(BaseModule):
    """The Conditional Position Encoding (CPE) module.

    The CPE is the implementation of `Conditional Positional Encodings
    for Vision Transformers <https://arxiv.org/abs/2102.10882>`_.

    Args:
        in_channels (int): Number of input channels.
        embed_dims (int): The feature dimension. Default: 768.
        stride (int): Stride of conv layer. Default: 1.
    """

    def __init__(self, in_channels, embed_dims=768, stride=1, init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.proj = nn.Conv2d(
            in_channels,
            embed_dims,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=True,
            groups=embed_dims)
        self.stride = stride

    def forward(self, x, hw_shape):
        b, n, c = x.shape
        h, w = hw_shape
        feat_token = x
        cnn_feat = feat_token.transpose(1, 2).view(b, c, h, w)
        if self.stride == 1:
            x = self.proj(cnn_feat) + cnn_feat
        else:
            x = self.proj(cnn_feat)
        x = x.flatten(2).transpose(1, 2)
        return x

@MODELS.register_module()
class PCPVT(BaseModule):
    """The backbone of Twins-PCPVT.

    This backbone is the implementation of `Twins: Revisiting the Design
    of Spatial Attention in Vision Transformers
    <https://arxiv.org/abs/2104.13840>`_.

    Args:
        in_channels (int): Number of input channels. Default: 3.
        embed_dims (list): Embedding dimension. Default: [64, 128, 256, 512].
        patch_sizes (list): The patch sizes. Default: [4, 2, 2, 2].
        strides (list): The strides. Default: [4, 2, 2, 2].
        num_heads (list): Number of attention heads. Default: [1, 2, 4, 8].
        mlp_ratios (list): Ratio of mlp hidden dim to embedding dim.
            Default: [4, 4, 4, 4].
        out_indices (tuple[int]): Output from which stages.
            Default: (0, 1, 2, 3).
        qkv_bias (bool): Enable bias for qkv if True. Default: False.
        drop_rate (float): Probability of an element to be zeroed.
            Default 0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default 0.0
        drop_path_rate (float): Stochastic depth rate. Default 0.0
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN')
        depths (list): Depths of each stage. Default [3, 4, 6, 3]
        sr_ratios (list): Kernel_size of conv in each Attn module in
            Transformer encoder layer. Default: [8, 4, 2, 1].
        norm_after_stage (bool): Add extra norm. Default False.
        pretrained (str, optional): Model pretrained path. Default: None.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 in_channels=3,
                 embed_dims=[64, 128, 256, 512],
                 patch_sizes=[4, 2, 2, 2],
                 strides=[4, 2, 2, 2],
                 num_heads=[1, 2, 4, 8],
                 mlp_ratios=[4, 4, 4, 4],
                 out_indices=(0, 1, 2, 3),
                 qkv_bias=False,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_cfg=dict(type='LN'),
                 depths=[3, 4, 6, 3],
                 sr_ratios=[8, 4, 2, 1],
                 norm_after_stage=False,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is not None:
            raise TypeError('pretrained must be a str or None')
        self.depths = depths

        # patch_embed
        self.patch_embeds = ModuleList()
        self.position_encoding_drops = ModuleList()
        self.layers = ModuleList()

        for i in range(len(depths)):
            self.patch_embeds.append(
                PatchEmbed(
                    in_channels=in_channels if i == 0 else embed_dims[i - 1],
                    embed_dims=embed_dims[i],
                    conv_type='Conv2d',
                    kernel_size=patch_sizes[i],
                    stride=strides[i],
                    padding='corner',
                    norm_cfg=norm_cfg))

            self.position_encoding_drops.append(nn.Dropout(p=drop_rate))

        self.position_encodings = ModuleList([
            ConditionalPositionEncoding(embed_dim, embed_dim)
            for embed_dim in embed_dims
        ])

        # transformer encoder
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule
        cur = 0

        for k in range(len(depths)):
            _block = ModuleList([
                GSAEncoderLayer(
                    embed_dims=embed_dims[k],
                    num_heads=num_heads[k],
                    feedforward_channels=mlp_ratios[k] * embed_dims[k],
                    attn_drop_rate=attn_drop_rate,
                    drop_rate=drop_rate,
                    drop_path_rate=dpr[cur + i],
                    num_fcs=2,
                    qkv_bias=qkv_bias,
                    act_cfg=dict(type='GELU'),
                    norm_cfg=dict(type='LN'),
                    sr_ratio=sr_ratios[k]) for i in range(depths[k])
            ])
            self.layers.append(_block)
            cur += depths[k]

        self.norm_name, norm = build_norm_layer(
            norm_cfg, embed_dims[-1], postfix=1)

        self.out_indices = out_indices
        self.norm_after_stage = norm_after_stage
        if self.norm_after_stage:
            self.norm_list = ModuleList()
            for dim in embed_dims:
                self.norm_list.append(build_norm_layer(norm_cfg, dim)[1])

    def init_weights(self):
        if self.init_cfg is not None:
            super().init_weights()
        else:
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                    constant_init(m, val=1.0, bias=0.)
                elif isinstance(m, nn.Conv2d):
                    fan_out = m.kernel_size[0] * m.kernel_size[
                        1] * m.out_channels
                    fan_out //= m.groups
                    normal_init(
                        m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)

    def forward(self, x):
        outputs = list()

        b = x.shape[0]

        for i in range(len(self.depths)):
            x, hw_shape = self.patch_embeds[i](x)
            h, w = hw_shape
            x = self.position_encoding_drops[i](x)
            for j, blk in enumerate(self.layers[i]):
                x = blk(x, hw_shape)
                if j == 0:
                    x = self.position_encodings[i](x, hw_shape)
            if self.norm_after_stage:
                x = self.norm_list[i](x)
            x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous()

            if i in self.out_indices:
                outputs.append(x)

        return tuple(outputs)

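A shape sketch of PCPVT as a four-stage feature extractor; the 224x224 input is illustrative only:

    # Hypothetical shape check for PCPVT with its default configuration
    # (embed_dims=[64, 128, 256, 512], strides=[4, 2, 2, 2]).
    import torch

    model = PCPVT()
    model.init_weights()
    outs = model(torch.randn(1, 3, 224, 224))
    for o in outs:
        print(o.shape)
    # Expected strides 4/8/16/32:
    # (1, 64, 56, 56), (1, 128, 28, 28), (1, 256, 14, 14), (1, 512, 7, 7)
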
@MODELS.register_module()
class SVT(PCPVT):
    """The backbone of Twins-SVT.

    This backbone is the implementation of `Twins: Revisiting the Design
    of Spatial Attention in Vision Transformers
    <https://arxiv.org/abs/2104.13840>`_.

    Args:
        in_channels (int): Number of input channels. Default: 3.
        embed_dims (list): Embedding dimension. Default: [64, 128, 256].
        patch_sizes (list): The patch sizes. Default: [4, 2, 2, 2].
        strides (list): The strides. Default: [4, 2, 2, 2].
        num_heads (list): Number of attention heads. Default: [1, 2, 4].
        mlp_ratios (list): Ratio of mlp hidden dim to embedding dim.
            Default: [4, 4, 4].
        out_indices (tuple[int]): Output from which stages.
            Default: (0, 1, 2, 3).
        qkv_bias (bool): Enable bias for qkv if True. Default: False.
        drop_rate (float): Dropout rate. Default 0.
        attn_drop_rate (float): Dropout ratio of attention weight.
            Default 0.0
        drop_path_rate (float): Stochastic depth rate. Default 0.2.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN')
        depths (list): Depths of each stage. Default [4, 4, 4].
        sr_ratios (list): Kernel_size of conv in each Attn module in
            Transformer encoder layer. Default: [4, 2, 1].
        windiow_sizes (list): Window size of LSA. Default: [7, 7, 7].
        norm_after_stage (bool): Add extra norm. Default True.
        pretrained (str, optional): Model pretrained path. Default: None.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 in_channels=3,
                 embed_dims=[64, 128, 256],
                 patch_sizes=[4, 2, 2, 2],
                 strides=[4, 2, 2, 2],
                 num_heads=[1, 2, 4],
                 mlp_ratios=[4, 4, 4],
                 out_indices=(0, 1, 2, 3),
                 qkv_bias=False,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.2,
                 norm_cfg=dict(type='LN'),
                 depths=[4, 4, 4],
                 sr_ratios=[4, 2, 1],
                 windiow_sizes=[7, 7, 7],
                 norm_after_stage=True,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(in_channels, embed_dims, patch_sizes, strides,
                         num_heads, mlp_ratios, out_indices, qkv_bias,
                         drop_rate, attn_drop_rate, drop_path_rate, norm_cfg,
                         depths, sr_ratios, norm_after_stage, pretrained,
                         init_cfg)
        # transformer encoder
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule

        for k in range(len(depths)):
            for i in range(depths[k]):
                if i % 2 == 0:
                    self.layers[k][i] = \
                        LSAEncoderLayer(
                            embed_dims=embed_dims[k],
                            num_heads=num_heads[k],
                            feedforward_channels=mlp_ratios[k] * embed_dims[k],
                            drop_rate=drop_rate,
                            attn_drop_rate=attn_drop_rate,
                            drop_path_rate=dpr[sum(depths[:k]) + i],
                            qkv_bias=qkv_bias,
                            window_size=windiow_sizes[k])
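
SVT builds on PCPVT by swapping in an LSA layer at every even depth index, so each stage interleaves local (windowed) and global (sub-sampled) attention. A sketch of the resulting layer pattern, using the defaults above:

    # Layer pattern produced by the replacement loop for depths=[4, 4, 4]:
    # even indices become LSA, odd indices keep GSA.
    depths = [4, 4, 4]
    for k in range(len(depths)):
        pattern = ['LSA' if i % 2 == 0 else 'GSA' for i in range(depths[k])]
        print(f'stage {k}:', pattern)
    # stage 0: ['LSA', 'GSA', 'LSA', 'GSA']  (same for stages 1 and 2)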
436
finetune/mmseg/models/backbones/unet.py
Normal file
@@ -0,0 +1,436 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from mmengine.model import BaseModule
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm

from mmseg.registry import MODELS
from ..utils import UpConvBlock, Upsample


class BasicConvBlock(nn.Module):
    """Basic convolutional block for UNet.

    This module consists of several plain convolutional layers.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        num_convs (int): Number of convolutional layers. Default: 2.
        stride (int): Whether to use stride convolution to downsample the
            input feature map. If stride=2, it only uses stride convolution
            in the first convolutional layer to downsample the input feature
            map. Options are 1 or 2. Default: 1.
        dilation (int): Whether to use dilated convolution to expand the
            receptive field. Sets the dilation rate of each convolutional
            layer; the dilation rate of the first convolutional layer is
            always 1. Default: 1.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        dcn (bool): Use deformable convolution in convolutional layer or not.
            Default: None.
        plugins (dict): plugins for convolutional layers. Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_convs=2,
                 stride=1,
                 dilation=1,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 dcn=None,
                 plugins=None):
        super().__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'

        self.with_cp = with_cp
        convs = []
        for i in range(num_convs):
            convs.append(
                ConvModule(
                    in_channels=in_channels if i == 0 else out_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=stride if i == 0 else 1,
                    dilation=1 if i == 0 else dilation,
                    padding=1 if i == 0 else dilation,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))

        self.convs = nn.Sequential(*convs)

    def forward(self, x):
        """Forward function."""

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(self.convs, x)
        else:
            out = self.convs(x)
        return out


@MODELS.register_module()
class DeconvModule(nn.Module):
    """Deconvolution upsample module in decoder for UNet (2X upsample).

    This module uses deconvolution to upsample feature map in the decoder
    of UNet.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        kernel_size (int): Kernel size of the convolutional layer. Default: 4.
        scale_factor (int): Upsampling factor of the deconvolution.
            Default: 2.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 with_cp=False,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 *,
                 kernel_size=4,
                 scale_factor=2):
        super().__init__()

        assert (kernel_size - scale_factor >= 0) and \
               (kernel_size - scale_factor) % 2 == 0, \
            f'kernel_size should be greater than or equal to scale_factor '\
            f'and (kernel_size - scale_factor) should be even numbers, '\
            f'while the kernel size is {kernel_size} and scale_factor is '\
            f'{scale_factor}.'

        stride = scale_factor
        padding = (kernel_size - scale_factor) // 2
        self.with_cp = with_cp
        deconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding)

        norm_name, norm = build_norm_layer(norm_cfg, out_channels)
        activate = build_activation_layer(act_cfg)
        self.deconv_upsamping = nn.Sequential(deconv, norm, activate)

    def forward(self, x):
        """Forward function."""

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(self.deconv_upsamping, x)
        else:
            out = self.deconv_upsamping(x)
        return out

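The assertion above guarantees an exact `scale_factor`x upsample: with stride equal to the scale factor and padding of (kernel_size - scale_factor) / 2, the transposed convolution maps H to exactly scale_factor * H. A quick check of the output-size formula with the defaults:

    # Output size of ConvTranspose2d:
    # (H - 1) * stride - 2 * padding + kernel_size.
    # With the defaults kernel_size=4, scale_factor=2 (stride=2, padding=1):
    h = 7
    kernel_size, stride, padding = 4, 2, (4 - 2) // 2
    out = (h - 1) * stride - 2 * padding + kernel_size
    print(out)  # 14 == 2 * h, an exact 2x upsample with no size drift
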
@MODELS.register_module()
class InterpConv(nn.Module):
    """Interpolation upsample module in decoder for UNet.

    This module uses interpolation to upsample feature map in the decoder
    of UNet. It consists of one interpolation upsample layer and one
    convolutional layer. It can be one interpolation upsample layer followed
    by one convolutional layer (conv_first=False) or one convolutional layer
    followed by one interpolation upsample layer (conv_first=True).

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        conv_first (bool): Whether convolutional layer or interpolation
            upsample layer first. Default: False. It means interpolation
            upsample layer followed by one convolutional layer.
        kernel_size (int): Kernel size of the convolutional layer. Default: 1.
        stride (int): Stride of the convolutional layer. Default: 1.
        padding (int): Padding of the convolutional layer. Default: 0.
        upsample_cfg (dict): Interpolation config of the upsample layer.
            Default: dict(
                scale_factor=2, mode='bilinear', align_corners=False).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 with_cp=False,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 *,
                 conv_cfg=None,
                 conv_first=False,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 upsample_cfg=dict(
                     scale_factor=2, mode='bilinear', align_corners=False)):
        super().__init__()

        self.with_cp = with_cp
        conv = ConvModule(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        upsample = Upsample(**upsample_cfg)
        if conv_first:
            self.interp_upsample = nn.Sequential(conv, upsample)
        else:
            self.interp_upsample = nn.Sequential(upsample, conv)

    def forward(self, x):
        """Forward function."""

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(self.interp_upsample, x)
        else:
            out = self.interp_upsample(x)
        return out

@MODELS.register_module()
class UNet(BaseModule):
    """UNet backbone.

    This backbone is the implementation of `U-Net: Convolutional Networks
    for Biomedical Image Segmentation <https://arxiv.org/abs/1505.04597>`_.

    Args:
        in_channels (int): Number of input image channels. Default: 3.
        base_channels (int): Number of base channels of each stage.
            The output channels of the first stage. Default: 64.
        num_stages (int): Number of stages in encoder, normally 5. Default: 5.
        strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
            len(strides) is equal to num_stages. Normally the stride of the
            first stage in encoder is 1. If strides[i]=2, it uses stride
            convolution to downsample in the corresponding encoder stage.
            Default: (1, 1, 1, 1, 1).
        enc_num_convs (Sequence[int]): Number of convolutional layers in the
            convolution block of the corresponding encoder stage.
            Default: (2, 2, 2, 2, 2).
        dec_num_convs (Sequence[int]): Number of convolutional layers in the
            convolution block of the corresponding decoder stage.
            Default: (2, 2, 2, 2).
        downsamples (Sequence[int]): Whether to use MaxPool to downsample the
            feature map after the first stage of encoder
            (stages: [1, num_stages)). If the corresponding encoder stage
            uses stride convolution (strides[i]=2), it will never use MaxPool
            to downsample, even if downsamples[i-1]=True.
            Default: (True, True, True, True).
        enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
            Default: (1, 1, 1, 1, 1).
        dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
            Default: (1, 1, 1, 1).
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        upsample_cfg (dict): The upsample config of the upsample module in
            decoder. Default: dict(type='InterpConv').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        dcn (bool): Use deformable convolution in convolutional layer or not.
            Default: None.
        plugins (dict): plugins for convolutional layers. Default: None.
        pretrained (str, optional): model pretrained path. Default: None
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None

    Notice:
        The input image size should be divisible by the whole downsample rate
        of the encoder. More detail of the whole downsample rate can be found
        in UNet._check_input_divisible.
    """

    def __init__(self,
                 in_channels=3,
                 base_channels=64,
                 num_stages=5,
                 strides=(1, 1, 1, 1, 1),
                 enc_num_convs=(2, 2, 2, 2, 2),
                 dec_num_convs=(2, 2, 2, 2),
                 downsamples=(True, True, True, True),
                 enc_dilations=(1, 1, 1, 1, 1),
                 dec_dilations=(1, 1, 1, 1),
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 upsample_cfg=dict(type='InterpConv'),
                 norm_eval=False,
                 dcn=None,
                 plugins=None,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg)

        self.pretrained = pretrained
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is None:
            if init_cfg is None:
                self.init_cfg = [
                    dict(type='Kaiming', layer='Conv2d'),
                    dict(
                        type='Constant',
                        val=1,
                        layer=['_BatchNorm', 'GroupNorm'])
                ]
        else:
            raise TypeError('pretrained must be a str or None')

        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'
        assert len(strides) == num_stages, \
            'The length of strides should be equal to num_stages, '\
            f'while the strides is {strides}, the length of '\
            f'strides is {len(strides)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(enc_num_convs) == num_stages, \
            'The length of enc_num_convs should be equal to num_stages, '\
            f'while the enc_num_convs is {enc_num_convs}, the length of '\
            f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(dec_num_convs) == (num_stages-1), \
            'The length of dec_num_convs should be equal to (num_stages-1), '\
            f'while the dec_num_convs is {dec_num_convs}, the length of '\
            f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(downsamples) == (num_stages-1), \
            'The length of downsamples should be equal to (num_stages-1), '\
            f'while the downsamples is {downsamples}, the length of '\
            f'downsamples is {len(downsamples)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(enc_dilations) == num_stages, \
            'The length of enc_dilations should be equal to num_stages, '\
            f'while the enc_dilations is {enc_dilations}, the length of '\
            f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(dec_dilations) == (num_stages-1), \
            'The length of dec_dilations should be equal to (num_stages-1), '\
            f'while the dec_dilations is {dec_dilations}, the length of '\
            f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
            f'{num_stages}.'
        self.num_stages = num_stages
        self.strides = strides
        self.downsamples = downsamples
        self.norm_eval = norm_eval
        self.base_channels = base_channels

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        for i in range(num_stages):
            enc_conv_block = []
            if i != 0:
                if strides[i] == 1 and downsamples[i - 1]:
                    enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
                upsample = (strides[i] != 1 or downsamples[i - 1])
                self.decoder.append(
                    UpConvBlock(
                        conv_block=BasicConvBlock,
                        in_channels=base_channels * 2**i,
                        skip_channels=base_channels * 2**(i - 1),
                        out_channels=base_channels * 2**(i - 1),
                        num_convs=dec_num_convs[i - 1],
                        stride=1,
                        dilation=dec_dilations[i - 1],
                        with_cp=with_cp,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg,
                        upsample_cfg=upsample_cfg if upsample else None,
                        dcn=None,
                        plugins=None))

            enc_conv_block.append(
                BasicConvBlock(
                    in_channels=in_channels,
                    out_channels=base_channels * 2**i,
                    num_convs=enc_num_convs[i],
                    stride=strides[i],
                    dilation=enc_dilations[i],
                    with_cp=with_cp,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    dcn=None,
                    plugins=None))
            self.encoder.append(nn.Sequential(*enc_conv_block))
            in_channels = base_channels * 2**i

    def forward(self, x):
        self._check_input_divisible(x)
        enc_outs = []
        for enc in self.encoder:
            x = enc(x)
            enc_outs.append(x)
        dec_outs = [x]
        for i in reversed(range(len(self.decoder))):
            x = self.decoder[i](enc_outs[i], x)
            dec_outs.append(x)

        return dec_outs

    def train(self, mode=True):
        """Convert the model into training mode while keeping the
        normalization layers frozen."""
        super().train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval has an effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()

    def _check_input_divisible(self, x):
        h, w = x.shape[-2:]
        whole_downsample_rate = 1
        for i in range(1, self.num_stages):
            if self.strides[i] == 2 or self.downsamples[i - 1]:
                whole_downsample_rate *= 2
        assert (h % whole_downsample_rate == 0) \
            and (w % whole_downsample_rate == 0),\
            f'The input image size {(h, w)} should be divisible by the whole '\
            f'downsample rate {whole_downsample_rate}, when num_stages is '\
            f'{self.num_stages}, strides is {self.strides}, and downsamples '\
            f'is {self.downsamples}.'
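
With the defaults (num_stages=5, all-ones strides, four MaxPool downsamples), the whole downsample rate works out to 16, so inputs must be multiples of 16 on each side. A quick check of that computation:

    # Worked example of _check_input_divisible for the default UNet.
    num_stages = 5
    strides = (1, 1, 1, 1, 1)
    downsamples = (True, True, True, True)

    whole_downsample_rate = 1
    for i in range(1, num_stages):
        if strides[i] == 2 or downsamples[i - 1]:
            whole_downsample_rate *= 2
    print(whole_downsample_rate)  # 16: 224x224 is fine, 250x250 is not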
501
finetune/mmseg/models/backbones/vit.py
Normal file
@@ -0,0 +1,501 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmengine.logging import print_log
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import (constant_init, kaiming_init,
                                        trunc_normal_)
from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.utils import _pair as to_2tuple

from mmseg.registry import MODELS
from ..utils import PatchEmbed, resize


class TransformerEncoderLayer(BaseModule):
    """Implements one encoder layer in Vision Transformer.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Default: 0.0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default: 0.0.
        drop_path_rate (float): stochastic depth rate. Default 0.0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default: 2.
        qkv_bias (bool): enable bias for qkv if True. Default: True
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim)
            or (n, batch, embed_dim). Default: True.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 batch_first=True,
                 attn_cfg=dict(),
                 ffn_cfg=dict(),
                 with_cp=False):
        super().__init__()

        self.norm1_name, norm1 = build_norm_layer(
            norm_cfg, embed_dims, postfix=1)
        self.add_module(self.norm1_name, norm1)

        attn_cfg.update(
            dict(
                embed_dims=embed_dims,
                num_heads=num_heads,
                attn_drop=attn_drop_rate,
                proj_drop=drop_rate,
                batch_first=batch_first,
                bias=qkv_bias))

        self.build_attn(attn_cfg)

        self.norm2_name, norm2 = build_norm_layer(
            norm_cfg, embed_dims, postfix=2)
        self.add_module(self.norm2_name, norm2)

        ffn_cfg.update(
            dict(
                embed_dims=embed_dims,
                feedforward_channels=feedforward_channels,
                num_fcs=num_fcs,
                ffn_drop=drop_rate,
                dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate)
                if drop_path_rate > 0 else None,
                act_cfg=act_cfg))
        self.build_ffn(ffn_cfg)
        self.with_cp = with_cp

    def build_attn(self, attn_cfg):
        self.attn = MultiheadAttention(**attn_cfg)

    def build_ffn(self, ffn_cfg):
        self.ffn = FFN(**ffn_cfg)

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        return getattr(self, self.norm2_name)

    def forward(self, x):

        def _inner_forward(x):
            x = self.attn(self.norm1(x), identity=x)
            x = self.ffn(self.norm2(x), identity=x)
            return x

        if self.with_cp and x.requires_grad:
            x = cp.checkpoint(_inner_forward, x)
        else:
            x = _inner_forward(x)
        return x

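`_inner_forward` is the standard pre-norm residual block: each sublayer normalizes its input first and adds the unnormalized input back through the `identity` argument. A standalone sketch of one layer, with illustrative sizes:

    # Hypothetical standalone use of TransformerEncoderLayer (sizes are
    # illustrative); tokens are shaped (batch, n, embed_dims) since
    # batch_first=True.
    import torch

    layer = TransformerEncoderLayer(
        embed_dims=768, num_heads=12, feedforward_channels=4 * 768)
    tokens = torch.randn(2, 197, 768)  # e.g. 196 patches + 1 cls token
    out = layer(tokens)
    print(out.shape)  # torch.Size([2, 197, 768]); shape is preserved
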
@MODELS.register_module()
class VisionTransformer(BaseModule):
    """Vision Transformer.

    This backbone is the implementation of `An Image is Worth 16x16 Words:
    Transformers for Image Recognition at
    Scale <https://arxiv.org/abs/2010.11929>`_.

    Args:
        img_size (int | tuple): Input image size. Default: 224.
        patch_size (int): The patch size. Default: 16.
        patch_pad (str | int | None): The padding method in patch embedding.
            Default: 'corner'.
        in_channels (int): Number of input channels. Default: 3.
        embed_dims (int): embedding dimension. Default: 768.
        num_layers (int): depth of transformer. Default: 12.
        num_heads (int): number of attention heads. Default: 12.
        mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
            Default: 4.
        out_origin (bool): Whether to output the original input embedding.
            Default: False
        out_indices (list | tuple | int): Output from which stages.
            Default: -1.
        qkv_bias (bool): enable bias for qkv if True. Default: True.
        drop_rate (float): Probability of an element to be zeroed.
            Default 0.0
        attn_drop_rate (float): The drop out rate for attention layer.
            Default 0.0
        drop_path_rate (float): stochastic depth rate. Default 0.0
        with_cls_token (bool): Whether concatenating class token into image
            tokens as transformer input. Default: True.
        output_cls_token (bool): Whether to output the cls_token. If set True,
            `with_cls_token` must be True. Default: False.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN')
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        patch_norm (bool): Whether to add a norm in PatchEmbed Block.
            Default: False.
        patch_bias (bool): Whether use bias in convolution of PatchEmbed
            Block. Default: False.
        pre_norm (bool): Whether to add a norm before Transformer Layers.
            Default: False.
        final_norm (bool): Whether to add an additional layer to normalize
            final feature map. Default: False.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Default: bicubic.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default: 2.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed. Default: False.
        frozen_exclude (List): List of parameters that are not to be frozen.
            Default: ["all"], "all" means there are no frozen parameters.
        pretrained (str, optional): model pretrained path. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 patch_pad='corner',
                 in_channels=3,
                 embed_dims=768,
                 num_layers=12,
                 num_heads=12,
                 mlp_ratio=4,
                 out_origin=False,
                 out_indices=-1,
                 qkv_bias=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 with_cls_token=True,
                 output_cls_token=False,
                 norm_cfg=dict(type='LN'),
                 act_cfg=dict(type='GELU'),
                 patch_norm=False,
                 patch_bias=False,
                 pre_norm=False,
                 final_norm=False,
                 interpolate_mode='bicubic',
                 num_fcs=2,
                 norm_eval=False,
                 with_cp=False,
                 frozen_exclude=['all'],
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        if isinstance(img_size, int):
            img_size = to_2tuple(img_size)
        elif isinstance(img_size, tuple):
            if len(img_size) == 1:
                img_size = to_2tuple(img_size[0])
            assert len(img_size) == 2, \
                f'The size of image should have length 1 or 2, ' \
                f'but got {len(img_size)}'

        if output_cls_token:
            assert with_cls_token is True, f'with_cls_token must be True if ' \
                f'set output_cls_token to True, but got {with_cls_token}'

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is not None:
            raise TypeError('pretrained must be a str or None')

        self.img_size = img_size
        self.patch_size = patch_size
        self.interpolate_mode = interpolate_mode
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.pretrained = pretrained
        self.out_origin = out_origin
        self.frozen_exclude = frozen_exclude

        self.patch_embed = PatchEmbed(
            in_channels=in_channels,
            embed_dims=embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size,
            stride=patch_size,
            padding=patch_pad,
            bias=patch_bias,
            norm_cfg=norm_cfg if patch_norm else None,
            init_cfg=None,
        )

        num_patches = (img_size[0] // patch_size) * \
            (img_size[1] // patch_size)

        self.with_cls_token = with_cls_token
        self.output_cls_token = output_cls_token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dims))
        self.drop_after_pos = nn.Dropout(p=drop_rate)
        self.pre_norm = pre_norm

        if self.pre_norm:
            self.pre_ln_name, pre_ln = build_norm_layer(
                norm_cfg, embed_dims, postfix='_pre')
            self.add_module(self.pre_ln_name, pre_ln)

        if isinstance(out_indices, int):
            if out_indices == -1:
                out_indices = num_layers - 1
            self.out_indices = [out_indices]
        elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
            self.out_indices = out_indices
        else:
            raise TypeError('out_indices must be type of int, list or tuple')

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, num_layers)
        ]  # stochastic depth decay rule

        self.layers = ModuleList()
        for i in range(num_layers):
            self.layers.append(
                TransformerEncoderLayer(
                    embed_dims=embed_dims,
                    num_heads=num_heads,
                    feedforward_channels=mlp_ratio * embed_dims,
                    attn_drop_rate=attn_drop_rate,
                    drop_rate=drop_rate,
                    drop_path_rate=dpr[i],
                    num_fcs=num_fcs,
                    qkv_bias=qkv_bias,
                    act_cfg=act_cfg,
                    norm_cfg=norm_cfg,
                    with_cp=with_cp,
                    batch_first=True))

        self.final_norm = final_norm
        if final_norm:
            self.norm1_name, norm1 = build_norm_layer(
                norm_cfg, embed_dims, postfix=1)
            self.add_module(self.norm1_name, norm1)

        self._freeze()

    @property
    def pre_ln(self):
        return getattr(self, self.pre_ln_name)

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    def init_weights(self):
        if isinstance(self.init_cfg, dict) and \
                self.init_cfg.get('type') in ['Pretrained',
                                              'Pretrained_Part']:
            checkpoint = CheckpointLoader.load_checkpoint(
                self.init_cfg['checkpoint'], logger=None, map_location='cpu')

            if self.init_cfg.get('type') == 'Pretrained':
                if 'state_dict' in checkpoint:
                    state_dict = checkpoint['state_dict']
                else:
                    state_dict = checkpoint

            elif self.init_cfg.get('type') == 'Pretrained_Part':
                state_dict = checkpoint.copy()
                para_prefix = 'image_encoder'
                prefix_len = len(para_prefix) + 1
                for k, v in checkpoint.items():
                    state_dict.pop(k)
                    if para_prefix in k:
                        state_dict[k[prefix_len:]] = v

            if 'pos_embed' in state_dict.keys():
                if self.pos_embed.shape != state_dict['pos_embed'].shape:
                    print_log(msg=f'Resize the pos_embed shape from '
                              f'{state_dict["pos_embed"].shape} to '
                              f'{self.pos_embed.shape}')
                    h, w = self.img_size
                    pos_size = int(
                        math.sqrt(state_dict['pos_embed'].shape[1] - 1))
                    state_dict['pos_embed'] = self.resize_pos_embed(
                        state_dict['pos_embed'],
                        (h // self.patch_size, w // self.patch_size),
                        (pos_size, pos_size), self.interpolate_mode)

            load_state_dict(self, state_dict, strict=False, logger=None)
        elif self.init_cfg is not None:
            super().init_weights()
        else:
            # We only implement the 'jax_impl' initialization implemented at
            # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
            trunc_normal_(self.pos_embed, std=.02)
            trunc_normal_(self.cls_token, std=.02)
            for n, m in self.named_modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_(m.weight, std=.02)
                    if m.bias is not None:
                        if 'ffn' in n:
                            nn.init.normal_(m.bias, mean=0., std=1e-6)
                        else:
                            nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Conv2d):
                    kaiming_init(m, mode='fan_in', bias=0.)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                    constant_init(m, val=1.0, bias=0.)

    def _freeze(self):
        if 'all' in self.frozen_exclude:
            return
        for name, param in self.named_parameters():
            if not any([exclude in name for exclude in self.frozen_exclude]):
                param.requires_grad = False

    def _pos_embeding(self, patched_img, hw_shape, pos_embed):
        """Position embedding method.

        Resize the pos_embed if the input image size doesn't match
        the training size.
        Args:
            patched_img (torch.Tensor): The patched image, it should be
                shape of [B, L1, C].
            hw_shape (tuple): The downsampled image resolution.
            pos_embed (torch.Tensor): The pos_embed weights, it should be
                shape of [B, L2, C].
        Return:
            torch.Tensor: The pos encoded image feature.
        """
        assert patched_img.ndim == 3 and pos_embed.ndim == 3, \
            'the shapes of patched_img and pos_embed must be [B, L, C]'
        x_len, pos_len = patched_img.shape[1], pos_embed.shape[1]
        if x_len != pos_len:
            if pos_len == (self.img_size[0] // self.patch_size) * (
                    self.img_size[1] // self.patch_size) + 1:
                pos_h = self.img_size[0] // self.patch_size
                pos_w = self.img_size[1] // self.patch_size
            else:
                raise ValueError(
                    'Unexpected shape of pos_embed, got {}.'.format(
                        pos_embed.shape))
            pos_embed = self.resize_pos_embed(pos_embed, hw_shape,
                                              (pos_h, pos_w),
                                              self.interpolate_mode)
        return self.drop_after_pos(patched_img + pos_embed)
|
||||
|
||||
@staticmethod
|
||||
def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode):
|
||||
"""Resize pos_embed weights.
|
||||
|
||||
Resize pos_embed using bicubic interpolate method.
|
||||
Args:
|
||||
pos_embed (torch.Tensor): Position embedding weights.
|
||||
input_shpae (tuple): Tuple for (downsampled input image height,
|
||||
downsampled input image width).
|
||||
pos_shape (tuple): The resolution of downsampled origin training
|
||||
image.
|
||||
mode (str): Algorithm used for upsampling:
|
||||
``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
|
||||
``'trilinear'``. Default: ``'nearest'``
|
||||
Return:
|
||||
torch.Tensor: The resized pos_embed of shape [B, L_new, C]
|
||||
"""
|
||||
assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
|
||||
pos_h, pos_w = pos_shape
|
||||
cls_token_weight = pos_embed[:, 0]
|
||||
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
|
||||
pos_embed_weight = pos_embed_weight.reshape(
|
||||
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
|
||||
pos_embed_weight = resize(
|
||||
pos_embed_weight, size=input_shpae, align_corners=False, mode=mode)
|
||||
cls_token_weight = cls_token_weight.unsqueeze(1)
|
||||
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
|
||||
pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
|
||||
return pos_embed
|
||||
|
||||
def forward(self, inputs):
|
||||
B = inputs.shape[0]
|
||||
|
||||
x, hw_shape = self.patch_embed(inputs)
|
||||
|
||||
# stole cls_tokens impl from Phil Wang, thanks
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x = self._pos_embeding(x, hw_shape, self.pos_embed)
|
||||
|
||||
if not self.with_cls_token:
|
||||
# Remove class token for transformer encoder input
|
||||
x = x[:, 1:]
|
||||
|
||||
if self.pre_norm:
|
||||
x = self.pre_ln(x)
|
||||
|
||||
outs = []
|
||||
if self.out_origin:
|
||||
if self.with_cls_token:
|
||||
# Remove class token and reshape token for decoder head
|
||||
out = x[:, 1:]
|
||||
else:
|
||||
out = x
|
||||
B, _, C = out.shape
|
||||
out = out.reshape(B, hw_shape[0], hw_shape[1],
|
||||
C).permute(0, 3, 1, 2).contiguous()
|
||||
if self.output_cls_token:
|
||||
out = [out, x[:, 0]]
|
||||
outs.append(out)
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
if i == len(self.layers) - 1:
|
||||
if self.final_norm:
|
||||
x = self.norm1(x)
|
||||
if i in self.out_indices:
|
||||
if self.with_cls_token:
|
||||
# Remove class token and reshape token for decoder head
|
||||
out = x[:, 1:]
|
||||
else:
|
||||
out = x
|
||||
B, _, C = out.shape
|
||||
out = out.reshape(B, hw_shape[0], hw_shape[1],
|
||||
C).permute(0, 3, 1, 2).contiguous()
|
||||
if self.output_cls_token:
|
||||
out = [out, x[:, 0]]
|
||||
outs.append(out)
|
||||
|
||||
return tuple(outs)
|
||||
|
||||
def train(self, mode=True):
|
||||
super().train(mode)
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.LayerNorm):
|
||||
m.eval()
|
||||
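# Usage sketch (illustrative only, not part of the original file): building
# the backbone above with the 4-band, patch-4 settings used by the configs in
# this repo and running a dummy forward pass. Assumes the class defined in
# this file is named `VisionTransformer`.
if __name__ == '__main__':
    backbone = VisionTransformer(
        img_size=(256, 256),
        patch_size=4,
        in_channels=4,
        embed_dims=1024,
        num_layers=24,
        num_heads=16,
        out_indices=(5, 11, 17, 23),
        with_cls_token=False)
    feats = backbone(torch.randn(1, 4, 256, 256))
    # Four feature maps of shape torch.Size([1, 1024, 64, 64])
    print([f.shape for f in feats])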
395
finetune/mmseg/models/backbones/vpd.py
Normal file
@@ -0,0 +1,395 @@
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------------
# Adapted from https://github.com/wl-zhao/VPD/blob/main/vpd/models.py
# Original licence: MIT License
# ------------------------------------------------------------------------------

import math
from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
from mmengine.runner import CheckpointLoader, load_checkpoint

from mmseg.registry import MODELS
from mmseg.utils import ConfigType, OptConfigType

try:
    from ldm.modules.diffusionmodules.util import timestep_embedding
    from ldm.util import instantiate_from_config
    has_ldm = True
except ImportError:
    has_ldm = False


def register_attention_control(model, controller):
    """Registers a control function to manage attention within a model.

    Args:
        model: The model to which attention is to be registered.
        controller: The control function responsible for managing attention.
    """

    def ca_forward(self, place_in_unet):
        """Custom forward method for attention.

        Args:
            self: Reference to the current object.
            place_in_unet: The location in UNet (down/mid/up).

        Returns:
            The modified forward method.
        """

        def forward(x, context=None, mask=None):
            h = self.heads
            is_cross = context is not None
            # if context is None, fall back to self-attention on x
            context = context if is_cross else x

            q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
            q, k, v = (
                tensor.view(tensor.shape[0] * h, tensor.shape[1],
                            tensor.shape[2] // h) for tensor in [q, k, v])

            sim = torch.matmul(q, k.transpose(-2, -1)) * self.scale

            if mask is not None:
                mask = mask.flatten(1).unsqueeze(1).repeat(h, 1, 1)
                max_neg_value = -torch.finfo(sim.dtype).max
                sim.masked_fill_(~mask, max_neg_value)

            attn = sim.softmax(dim=-1)
            attn_mean = attn.view(h, attn.shape[0] // h,
                                  *attn.shape[1:]).mean(0)
            controller(attn_mean, is_cross, place_in_unet)

            out = torch.matmul(attn, v)
            out = out.view(out.shape[0] // h, out.shape[1], out.shape[2] * h)
            return self.to_out(out)

        return forward

    def register_recr(net_, count, place_in_unet):
        """Recursive function to register the custom forward method to all
        CrossAttention layers.

        Args:
            net_: The network layer currently being processed.
            count: The current count of layers processed.
            place_in_unet: The location in UNet (down/mid/up).

        Returns:
            The updated count of layers processed.
        """
        if net_.__class__.__name__ == 'CrossAttention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        if hasattr(net_, 'children'):
            return sum(
                register_recr(child, 0, place_in_unet)
                for child in net_.children())
        return count

    cross_att_count = sum(
        register_recr(net[1], 0, place) for net, place in [
            (child, 'down') if 'input_blocks' in name else
            (child, 'up') if 'output_blocks' in name else
            (child, 'mid') if 'middle_block' in name else
            (None, None)  # Default case
            for name, child in model.diffusion_model.named_children()
        ] if net is not None)

    controller.num_att_layers = cross_att_count


class AttentionStore:
    """A class for storing attention information in the UNet model.

    Attributes:
        base_size (int): Base size for storing attention information.
        max_size (int): Maximum size for storing attention information.
    """

    def __init__(self, base_size=64, max_size=None):
        """Initialize AttentionStore with default or custom sizes."""
        self.reset()
        self.base_size = base_size
        self.max_size = max_size or (base_size // 2)
        self.num_att_layers = -1

    @staticmethod
    def get_empty_store():
        """Returns an empty store for holding attention values."""
        return {
            key: []
            for key in [
                'down_cross', 'mid_cross', 'up_cross', 'down_self',
                'mid_self', 'up_self'
            ]
        }

    def reset(self):
        """Resets the step and attention stores to their initial states."""
        self.cur_step = 0
        self.cur_att_layer = 0
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Processes a single forward step, storing the attention.

        Args:
            attn: The attention tensor.
            is_cross (bool): Whether it's cross attention.
            place_in_unet (str): The location in UNet (down/mid/up).

        Returns:
            The unmodified attention tensor.
        """
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= self.max_size**2:
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        """Processes and stores attention information between steps."""
        if not self.attention_store:
            self.attention_store = self.step_store
        else:
            for key in self.attention_store:
                self.attention_store[key] = [
                    stored + step for stored, step in zip(
                        self.attention_store[key], self.step_store[key])
                ]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Calculates and returns the average attention across all steps."""
        return {
            key: [item for item in self.step_store[key]]
            for key in self.step_store
        }

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Allows the class instance to be callable."""
        return self.forward(attn, is_cross, place_in_unet)

    @property
    def num_uncond_att_layers(self):
        """Returns the number of unconditional attention layers (default is
        0)."""
        return 0

    def step_callback(self, x_t):
        """A placeholder for a step callback.

        Returns the input unchanged.
        """
        return x_t


class UNetWrapper(nn.Module):
    """A wrapper for UNet with optional attention mechanisms.

    Args:
        unet (nn.Module): The UNet model to wrap.
        use_attn (bool): Whether to use attention. Defaults to True.
        base_size (int): Base size for the attention store. Defaults to 512.
        max_attn_size (int, optional): Maximum size for the attention store.
            Defaults to None.
        attn_selector (str): The types of attention to use.
            Defaults to 'up_cross+down_cross'.
    """

    def __init__(self,
                 unet,
                 use_attn=True,
                 base_size=512,
                 max_attn_size=None,
                 attn_selector='up_cross+down_cross'):
        super().__init__()

        assert has_ldm, 'To use UNetWrapper, please install required ' \
            'packages via `pip install -r requirements/optional.txt`.'

        self.unet = unet
        self.attention_store = AttentionStore(
            base_size=base_size // 8, max_size=max_attn_size)
        self.attn_selector = attn_selector.split('+')
        self.use_attn = use_attn
        self.init_sizes(base_size)
        if self.use_attn:
            register_attention_control(unet, self.attention_store)

    def init_sizes(self, base_size):
        """Initialize sizes based on the base size."""
        self.size16 = base_size // 32
        self.size32 = base_size // 16
        self.size64 = base_size // 8

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """Forward pass through the model."""
        diffusion_model = self.unet.diffusion_model
        if self.use_attn:
            self.attention_store.reset()
        hs, emb, out_list = self._unet_forward(x, timesteps, context, y,
                                               diffusion_model)
        if self.use_attn:
            self._append_attn_to_output(out_list)
        return out_list[::-1]

    def _unet_forward(self, x, timesteps, context, y, diffusion_model):
        hs = []
        t_emb = timestep_embedding(
            timesteps, diffusion_model.model_channels, repeat_only=False)
        emb = diffusion_model.time_embed(t_emb)
        h = x.type(diffusion_model.dtype)
        for module in diffusion_model.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = diffusion_model.middle_block(h, emb, context)
        out_list = []
        for i_out, module in enumerate(diffusion_model.output_blocks):
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
            if i_out in [1, 4, 7]:
                out_list.append(h)
        h = h.type(x.dtype)
        out_list.append(h)
        return hs, emb, out_list

    def _append_attn_to_output(self, out_list):
        avg_attn = self.attention_store.get_average_attention()
        attns = {self.size16: [], self.size32: [], self.size64: []}
        for k in self.attn_selector:
            for up_attn in avg_attn[k]:
                size = int(math.sqrt(up_attn.shape[1]))
                up_attn = up_attn.transpose(-1, -2).reshape(
                    *up_attn.shape[:2], size, -1)
                attns[size].append(up_attn)
        attn16 = torch.stack(attns[self.size16]).mean(0)
        attn32 = torch.stack(attns[self.size32]).mean(0)
        attn64 = torch.stack(attns[self.size64]).mean(0) if len(
            attns[self.size64]) > 0 else None
        out_list[1] = torch.cat([out_list[1], attn16], dim=1)
        out_list[2] = torch.cat([out_list[2], attn32], dim=1)
        if attn64 is not None:
            out_list[3] = torch.cat([out_list[3], attn64], dim=1)


class TextAdapter(nn.Module):
    """A PyTorch Module that serves as a text adapter.

    This module takes text embeddings and adjusts them based on a scaling
    factor gamma.
    """

    def __init__(self, text_dim=768):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(text_dim, text_dim), nn.GELU(),
            nn.Linear(text_dim, text_dim))

    def forward(self, texts, gamma):
        texts_after = self.fc(texts)
        texts = texts + gamma * texts_after
        return texts


@MODELS.register_module()
class VPD(BaseModule):
    """VPD (Visual Perception Diffusion) model.

    .. _`VPD`: https://arxiv.org/abs/2303.02153

    Args:
        diffusion_cfg (dict): Configuration for diffusion model.
        class_embed_path (str): Path for class embeddings.
        unet_cfg (dict, optional): Configuration for U-Net.
        gamma (float, optional): Gamma for text adaptation. Defaults to 1e-4.
        class_embed_select (bool, optional): If True, enables class embedding
            selection. Defaults to False.
        pad_shape (Optional[Union[int, List[int]]], optional): Padding shape.
            Defaults to None.
        pad_val (Union[int, List[int]], optional): Padding value.
            Defaults to 0.
        init_cfg (dict, optional): Configuration for network initialization.
    """

    def __init__(self,
                 diffusion_cfg: ConfigType,
                 class_embed_path: str,
                 unet_cfg: OptConfigType = dict(),
                 gamma: float = 1e-4,
                 class_embed_select=False,
                 pad_shape: Optional[Union[int, List[int]]] = None,
                 pad_val: Union[int, List[int]] = 0,
                 init_cfg: OptConfigType = None):

        super().__init__(init_cfg=init_cfg)

        assert has_ldm, 'To use VPD model, please install required packages' \
            ' via `pip install -r requirements/optional.txt`.'

        if pad_shape is not None:
            if not isinstance(pad_shape, (list, tuple)):
                pad_shape = (pad_shape, pad_shape)

        self.pad_shape = pad_shape
        self.pad_val = pad_val

        # diffusion model
        diffusion_checkpoint = diffusion_cfg.pop('checkpoint', None)
        sd_model = instantiate_from_config(diffusion_cfg)
        if diffusion_checkpoint is not None:
            load_checkpoint(sd_model, diffusion_checkpoint, strict=False)

        self.encoder_vq = sd_model.first_stage_model
        self.unet = UNetWrapper(sd_model.model, **unet_cfg)

        # class embeddings & text adapter
        class_embeddings = CheckpointLoader.load_checkpoint(class_embed_path)
        text_dim = class_embeddings.size(-1)
        self.text_adapter = TextAdapter(text_dim=text_dim)
        self.class_embed_select = class_embed_select
        if class_embed_select:
            class_embeddings = torch.cat(
                (class_embeddings, class_embeddings.mean(dim=0,
                                                         keepdims=True)),
                dim=0)
        self.register_buffer('class_embeddings', class_embeddings)
        self.gamma = nn.Parameter(torch.ones(text_dim) * gamma)

    def forward(self, x):
        """Extract features from images."""

        # calculate cross-attn map
        if self.class_embed_select:
            if isinstance(x, (tuple, list)):
                x, class_ids = x[:2]
                class_ids = class_ids.tolist()
            else:
                class_ids = [-1] * x.size(0)
            class_embeddings = self.class_embeddings[class_ids]
            c_crossattn = self.text_adapter(class_embeddings, self.gamma)
            c_crossattn = c_crossattn.unsqueeze(1)
        else:
            class_embeddings = self.class_embeddings
            c_crossattn = self.text_adapter(class_embeddings, self.gamma)
            c_crossattn = c_crossattn.unsqueeze(0).repeat(x.size(0), 1, 1)

        # pad to required input shape for pretrained diffusion model
        if self.pad_shape is not None:
            pad_width = max(0, self.pad_shape[1] - x.shape[-1])
            pad_height = max(0, self.pad_shape[0] - x.shape[-2])
            x = F.pad(x, (0, pad_width, 0, pad_height), value=self.pad_val)

        # forward the denoising model
        with torch.no_grad():
            latents = self.encoder_vq.encode(x).mode().detach()
        t = torch.ones((x.shape[0], ), device=x.device).long()
        outs = self.unet(latents, t, context=c_crossattn)

        return outs
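# Usage sketch (illustrative only, not part of the original file):
# AttentionStore keeps only attention maps whose token count fits within
# `max_size ** 2`; everything else is silently dropped.
if __name__ == '__main__':
    store = AttentionStore(base_size=64)       # max_size defaults to 32
    attn = torch.rand(8, 32 * 32, 77)          # (heads * batch, tokens, ctx)
    store(attn, is_cross=True, place_in_unet='up')
    print(len(store.step_store['up_cross']))   # -> 1 (map was small enough)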
52
finetune/mmseg/models/builder.py
Normal file
@@ -0,0 +1,52 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

from mmseg.registry import MODELS

BACKBONES = MODELS
NECKS = MODELS
HEADS = MODELS
LOSSES = MODELS
SEGMENTORS = MODELS


def build_backbone(cfg):
    """Build backbone."""
    warnings.warn('``build_backbone`` would be deprecated soon, please use '
                  '``mmseg.registry.MODELS.build()`` ')
    return BACKBONES.build(cfg)


def build_neck(cfg):
    """Build neck."""
    warnings.warn('``build_neck`` would be deprecated soon, please use '
                  '``mmseg.registry.MODELS.build()`` ')
    return NECKS.build(cfg)


def build_head(cfg):
    """Build head."""
    warnings.warn('``build_head`` would be deprecated soon, please use '
                  '``mmseg.registry.MODELS.build()`` ')
    return HEADS.build(cfg)


def build_loss(cfg):
    """Build loss."""
    warnings.warn('``build_loss`` would be deprecated soon, please use '
                  '``mmseg.registry.MODELS.build()`` ')
    return LOSSES.build(cfg)


def build_segmentor(cfg, train_cfg=None, test_cfg=None):
    """Build segmentor."""
    if train_cfg is not None or test_cfg is not None:
        warnings.warn(
            'train_cfg and test_cfg are deprecated, '
            'please specify them in model', UserWarning)
    assert cfg.get('train_cfg') is None or train_cfg is None, \
        'train_cfg specified in both outer field and model field '
    assert cfg.get('test_cfg') is None or test_cfg is None, \
        'test_cfg specified in both outer field and model field '
    return SEGMENTORS.build(
        cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
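# Usage note (illustrative only, not part of the original file): these
# wrappers are thin backward-compatibility shims over a single registry; new
# code should build components through the registry directly, e.g.
#
#   from mmseg.registry import MODELS
#   head = MODELS.build(
#       dict(type='FCNHead', in_channels=512, channels=256, num_classes=2))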
151
finetune/mmseg/models/data_preprocessor.py
Normal file
@@ -0,0 +1,151 @@
# Copyright (c) OpenMMLab. All rights reserved.
from numbers import Number
from typing import Any, Dict, List, Optional, Sequence

import torch
from mmengine.model import BaseDataPreprocessor

from mmseg.registry import MODELS
from mmseg.utils import stack_batch


@MODELS.register_module()
class SegDataPreProcessor(BaseDataPreprocessor):
    """Image pre-processor for segmentation tasks.

    Compared with :class:`mmengine.ImgDataPreprocessor`,

    1. It won't do normalization if ``mean`` is not specified.
    2. It does normalization and color space conversion after stacking batch.
    3. It supports batch augmentations like mixup and cutmix.

    It provides the data pre-processing as follows

    - Collate and move data to the target device.
    - Pad inputs to the input size with defined ``pad_val``, and pad seg map
      with defined ``seg_pad_val``.
    - Stack inputs to batch_inputs.
    - Convert inputs from bgr to rgb if the shape of input is (3, H, W).
    - Normalize image with defined std and mean.
    - Do batch augmentations like Mixup and Cutmix during training.

    Args:
        mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
            Defaults to None.
        std (Sequence[Number], optional): The pixel standard deviation of
            R, G, B channels. Defaults to None.
        size (tuple, optional): Fixed padding size.
        size_divisor (int, optional): The divisor of padded size.
        pad_val (float, optional): Padding value. Default: 0.
        seg_pad_val (float, optional): Padding value of segmentation map.
            Default: 255.
        padding_mode (str): Type of padding. Default: constant.
            - constant: pads with a constant value, this value is specified
              with pad_val.
        bgr_to_rgb (bool): whether to convert image from BGR to RGB.
            Defaults to False.
        rgb_to_bgr (bool): whether to convert image from RGB to BGR.
            Defaults to False.
        batch_augments (list[dict], optional): Batch-level augmentations.
        test_cfg (dict, optional): The padding size config in testing, if not
            specify, will use `size` and `size_divisor` params as default.
            Defaults to None, only supports keys `size` or `size_divisor`.
    """

    def __init__(
        self,
        mean: Sequence[Number] = None,
        std: Sequence[Number] = None,
        size: Optional[tuple] = None,
        size_divisor: Optional[int] = None,
        pad_val: Number = 0,
        seg_pad_val: Number = 255,
        bgr_to_rgb: bool = False,
        rgb_to_bgr: bool = False,
        batch_augments: Optional[List[dict]] = None,
        test_cfg: dict = None,
    ):
        super().__init__()
        self.size = size
        self.size_divisor = size_divisor
        self.pad_val = pad_val
        self.seg_pad_val = seg_pad_val

        assert not (bgr_to_rgb and rgb_to_bgr), (
            '`bgr_to_rgb` and `rgb_to_bgr` cannot be set to True '
            'at the same time')
        self.channel_conversion = rgb_to_bgr or bgr_to_rgb

        if mean is not None:
            assert std is not None, 'To enable the normalization in ' \
                'preprocessing, please specify both ' \
                '`mean` and `std`.'
            # Enable the normalization in preprocessing.
            self._enable_normalize = True
            self.register_buffer('mean',
                                 torch.tensor(mean).view(-1, 1, 1), False)
            self.register_buffer('std',
                                 torch.tensor(std).view(-1, 1, 1), False)
        else:
            self._enable_normalize = False

        # TODO: support batch augmentations.
        self.batch_augments = batch_augments

        # Support different padding methods in testing
        self.test_cfg = test_cfg

    def forward(self, data: dict, training: bool = False) -> Dict[str, Any]:
        """Perform normalization, padding and bgr2rgb conversion based on
        ``BaseDataPreprocessor``.

        Args:
            data (dict): data sampled from dataloader.
            training (bool): Whether to enable training time augmentation.

        Returns:
            Dict: Data in the same format as the model input.
        """
        data = self.cast_data(data)  # type: ignore
        inputs = data['inputs']
        data_samples = data.get('data_samples', None)
        # TODO: whether normalize should be after stack_batch
        if self.channel_conversion and inputs[0].size(0) == 3:
            inputs = [_input[[2, 1, 0], ...] for _input in inputs]

        inputs = [_input.float() for _input in inputs]
        if self._enable_normalize:
            inputs = [(_input - self.mean) / self.std for _input in inputs]

        if training:
            assert data_samples is not None, ('During training, '
                                              '`data_samples` must be '
                                              'defined.')
            inputs, data_samples = stack_batch(
                inputs=inputs,
                data_samples=data_samples,
                size=self.size,
                size_divisor=self.size_divisor,
                pad_val=self.pad_val,
                seg_pad_val=self.seg_pad_val)

            if self.batch_augments is not None:
                inputs, data_samples = self.batch_augments(
                    inputs, data_samples)
        else:
            img_size = inputs[0].shape[1:]
            assert all(input_.shape[1:] == img_size for input_ in inputs), \
                'The image size in a batch should be the same.'
            # pad images when testing
            if self.test_cfg:
                inputs, padded_samples = stack_batch(
                    inputs=inputs,
                    size=self.test_cfg.get('size', None),
                    size_divisor=self.test_cfg.get('size_divisor', None),
                    pad_val=self.pad_val,
                    seg_pad_val=self.seg_pad_val)
                for data_sample, pad_info in zip(data_samples, padded_samples):
                    data_sample.set_metainfo({**pad_info})
            else:
                inputs = torch.stack(inputs, dim=0)

        return dict(inputs=inputs, data_samples=data_samples)
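# Usage sketch (illustrative only, not part of the original file): normalizing
# and stacking a two-image batch with the 4-band statistics used by the
# configs in this repo (values rounded). `data_samples` is omitted, so this
# exercises the inference branch.
if __name__ == '__main__':
    preproc = SegDataPreProcessor(
        mean=[537.94, 615.79, 343.45, 3010.64],
        std=[367.46, 254.25, 187.54, 921.08])
    data = dict(inputs=[torch.rand(4, 256, 256), torch.rand(4, 256, 256)])
    out = preproc(data, training=False)
    print(out['inputs'].shape)  # -> torch.Size([2, 4, 256, 256])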
48
finetune/mmseg/models/decode_heads/__init__.py
Normal file
@@ -0,0 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ann_head import ANNHead
from .apc_head import APCHead
from .aspp_head import ASPPHead
from .cc_head import CCHead
from .da_head import DAHead
from .ddr_head import DDRHead
from .dm_head import DMHead
from .dnl_head import DNLHead
from .dpt_head import DPTHead
from .ema_head import EMAHead
from .enc_head import EncHead
from .fcn_head import FCNHead
from .fpn_head import FPNHead
from .gc_head import GCHead
from .ham_head import LightHamHead
from .isa_head import ISAHead
from .knet_head import IterativeDecodeHead, KernelUpdateHead, KernelUpdator
from .lraspp_head import LRASPPHead
from .mask2former_head import Mask2FormerHead
from .maskformer_head import MaskFormerHead
from .nl_head import NLHead
from .ocr_head import OCRHead
from .pid_head import PIDHead
from .point_head import PointHead
from .psa_head import PSAHead
from .psp_head import PSPHead
from .san_head import SideAdapterCLIPHead
from .segformer_head import SegformerHead
from .segmenter_mask_head import SegmenterMaskTransformerHead
from .sep_aspp_head import DepthwiseSeparableASPPHead
from .sep_fcn_head import DepthwiseSeparableFCNHead
from .setr_mla_head import SETRMLAHead
from .setr_up_head import SETRUPHead
from .stdc_head import STDCHead
from .uper_head import UPerHead
from .vpd_depth_head import VPDDepthHead

__all__ = [
    'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
    'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
    'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
    'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
    'SETRMLAHead', 'DPTHead', 'SegmenterMaskTransformerHead',
    'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead',
    'KernelUpdateHead', 'KernelUpdator', 'MaskFormerHead', 'Mask2FormerHead',
    'LightHamHead', 'PIDHead', 'DDRHead', 'VPDDepthHead', 'SideAdapterCLIPHead'
]
245
finetune/mmseg/models/decode_heads/ann_head.py
Normal file
@@ -0,0 +1,245 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.registry import MODELS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .decode_head import BaseDecodeHead


class PPMConcat(nn.ModuleList):
    """Pyramid Pooling Module that only concat the features of each layer.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
    """

    def __init__(self, pool_scales=(1, 3, 6, 8)):
        super().__init__(
            [nn.AdaptiveAvgPool2d(pool_scale) for pool_scale in pool_scales])

    def forward(self, feats):
        """Forward function."""
        ppm_outs = []
        for ppm in self:
            ppm_out = ppm(feats)
            ppm_outs.append(ppm_out.view(*feats.shape[:2], -1))
        concat_outs = torch.cat(ppm_outs, dim=2)
        return concat_outs


class SelfAttentionBlock(_SelfAttentionBlock):
    """A self-attention block used by ANN.

    Args:
        low_in_channels (int): Input channels of lower level feature,
            which is the key feature for self-attention.
        high_in_channels (int): Input channels of higher level feature,
            which is the query feature for self-attention.
        channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
        share_key_query (bool): Whether share projection weight between key
            and query projection.
        query_scale (int): The scale of query feature map.
        key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module of key feature.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict|None): Config of activation layers.
    """

    def __init__(self, low_in_channels, high_in_channels, channels,
                 out_channels, share_key_query, query_scale, key_pool_scales,
                 conv_cfg, norm_cfg, act_cfg):
        key_psp = PPMConcat(key_pool_scales)
        if query_scale > 1:
            query_downsample = nn.MaxPool2d(kernel_size=query_scale)
        else:
            query_downsample = None
        super().__init__(
            key_in_channels=low_in_channels,
            query_in_channels=high_in_channels,
            channels=channels,
            out_channels=out_channels,
            share_key_query=share_key_query,
            query_downsample=query_downsample,
            key_downsample=key_psp,
            key_query_num_convs=1,
            key_query_norm=True,
            value_out_num_convs=1,
            value_out_norm=False,
            matmul_norm=True,
            with_out=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)


class AFNB(nn.Module):
    """Asymmetric Fusion Non-local Block (AFNB).

    Args:
        low_in_channels (int): Input channels of lower level feature,
            which is the key feature for self-attention.
        high_in_channels (int): Input channels of higher level feature,
            which is the query feature for self-attention.
        channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
        query_scales (tuple[int]): The scales of query feature map.
            Default: (1,)
        key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module of key feature.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict|None): Config of activation layers.
    """

    def __init__(self, low_in_channels, high_in_channels, channels,
                 out_channels, query_scales, key_pool_scales, conv_cfg,
                 norm_cfg, act_cfg):
        super().__init__()
        self.stages = nn.ModuleList()
        for query_scale in query_scales:
            self.stages.append(
                SelfAttentionBlock(
                    low_in_channels=low_in_channels,
                    high_in_channels=high_in_channels,
                    channels=channels,
                    out_channels=out_channels,
                    share_key_query=False,
                    query_scale=query_scale,
                    key_pool_scales=key_pool_scales,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
        self.bottleneck = ConvModule(
            out_channels + high_in_channels,
            out_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

    def forward(self, low_feats, high_feats):
        """Forward function."""
        priors = [stage(high_feats, low_feats) for stage in self.stages]
        context = torch.stack(priors, dim=0).sum(dim=0)
        output = self.bottleneck(torch.cat([context, high_feats], 1))
        return output


class APNB(nn.Module):
    """Asymmetric Pyramid Non-local Block (APNB).

    Args:
        in_channels (int): Input channels of key/query feature,
            which is the key feature for self-attention.
        channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
        query_scales (tuple[int]): The scales of query feature map.
            Default: (1,)
        key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module of key feature.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict|None): Config of activation layers.
    """

    def __init__(self, in_channels, channels, out_channels, query_scales,
                 key_pool_scales, conv_cfg, norm_cfg, act_cfg):
        super().__init__()
        self.stages = nn.ModuleList()
        for query_scale in query_scales:
            self.stages.append(
                SelfAttentionBlock(
                    low_in_channels=in_channels,
                    high_in_channels=in_channels,
                    channels=channels,
                    out_channels=out_channels,
                    share_key_query=True,
                    query_scale=query_scale,
                    key_pool_scales=key_pool_scales,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
        self.bottleneck = ConvModule(
            2 * in_channels,
            out_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    def forward(self, feats):
        """Forward function."""
        priors = [stage(feats, feats) for stage in self.stages]
        context = torch.stack(priors, dim=0).sum(dim=0)
        output = self.bottleneck(torch.cat([context, feats], 1))
        return output


@MODELS.register_module()
class ANNHead(BaseDecodeHead):
    """Asymmetric Non-local Neural Networks for Semantic Segmentation.

    This head is the implementation of `ANNNet
    <https://arxiv.org/abs/1908.07678>`_.

    Args:
        project_channels (int): Projection channels for Nonlocal.
        query_scales (tuple[int]): The scales of query feature map.
            Default: (1,)
        key_pool_scales (tuple[int]): The pooling scales of key feature map.
            Default: (1, 3, 6, 8).
    """

    def __init__(self,
                 project_channels,
                 query_scales=(1, ),
                 key_pool_scales=(1, 3, 6, 8),
                 **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)
        assert len(self.in_channels) == 2
        low_in_channels, high_in_channels = self.in_channels
        self.project_channels = project_channels
        self.fusion = AFNB(
            low_in_channels=low_in_channels,
            high_in_channels=high_in_channels,
            out_channels=high_in_channels,
            channels=project_channels,
            query_scales=query_scales,
            key_pool_scales=key_pool_scales,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.bottleneck = ConvModule(
            high_in_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.context = APNB(
            in_channels=self.channels,
            out_channels=self.channels,
            channels=project_channels,
            query_scales=query_scales,
            key_pool_scales=key_pool_scales,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        low_feats, high_feats = self._transform_inputs(inputs)
        output = self.fusion(low_feats, high_feats)
        output = self.dropout(output)
        output = self.bottleneck(output)
        output = self.context(output)
        output = self.cls_seg(output)

        return output
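# Usage sketch (illustrative only, not part of the original file): ANNHead
# fuses a low- and a high-level feature map, so `in_channels`/`in_index`
# select two levels from the backbone's output list.
if __name__ == '__main__':
    head = ANNHead(
        in_channels=[1024, 2048],
        in_index=[2, 3],
        channels=512,
        project_channels=256,
        num_classes=2)
    feats = [
        torch.rand(1, 256, 128, 128),
        torch.rand(1, 512, 64, 64),
        torch.rand(1, 1024, 32, 32),
        torch.rand(1, 2048, 16, 16),
    ]
    print(head(feats).shape)  # -> torch.Size([1, 2, 16, 16])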
159
finetune/mmseg/models/decode_heads/apc_head.py
Normal file
@@ -0,0 +1,159 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead


class ACM(nn.Module):
    """Adaptive Context Module used in APCNet.

    Args:
        pool_scale (int): Pooling scale used in Adaptive Context
            Module to extract region features.
        fusion (bool): Add one conv to fuse residual feature.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict | None): Config of conv layers.
        norm_cfg (dict | None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
    """

    def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg,
                 norm_cfg, act_cfg):
        super().__init__()
        self.pool_scale = pool_scale
        self.fusion = fusion
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.pooled_redu_conv = ConvModule(
            self.in_channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.input_redu_conv = ConvModule(
            self.in_channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.global_info = ConvModule(
            self.channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.gla = nn.Conv2d(self.channels, self.pool_scale**2, 1, 1, 0)

        self.residual_conv = ConvModule(
            self.channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        if self.fusion:
            self.fusion_conv = ConvModule(
                self.channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def forward(self, x):
        """Forward function."""
        pooled_x = F.adaptive_avg_pool2d(x, self.pool_scale)
        # [batch_size, channels, h, w]
        x = self.input_redu_conv(x)
        # [batch_size, channels, pool_scale, pool_scale]
        pooled_x = self.pooled_redu_conv(pooled_x)
        batch_size = x.size(0)
        # [batch_size, pool_scale * pool_scale, channels]
        pooled_x = pooled_x.view(batch_size, self.channels,
                                 -1).permute(0, 2, 1).contiguous()
        # [batch_size, h * w, pool_scale * pool_scale]
        affinity_matrix = self.gla(x + resize(
            self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:])
                                   ).permute(0, 2, 3, 1).reshape(
                                       batch_size, -1, self.pool_scale**2)
        affinity_matrix = F.sigmoid(affinity_matrix)
        # [batch_size, h * w, channels]
        z_out = torch.matmul(affinity_matrix, pooled_x)
        # [batch_size, channels, h * w]
        z_out = z_out.permute(0, 2, 1).contiguous()
        # [batch_size, channels, h, w]
        z_out = z_out.view(batch_size, self.channels, x.size(2), x.size(3))
        z_out = self.residual_conv(z_out)
        z_out = F.relu(z_out + x)
        if self.fusion:
            z_out = self.fusion_conv(z_out)

        return z_out


@MODELS.register_module()
class APCHead(BaseDecodeHead):
    """Adaptive Pyramid Context Network for Semantic Segmentation.

    This head is the implementation of
    `APCNet <https://openaccess.thecvf.com/content_CVPR_2019/papers/\
He_Adaptive_Pyramid_Context_Network_for_Semantic_Segmentation_\
CVPR_2019_paper.pdf>`_.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Adaptive Context
            Module. Default: (1, 2, 3, 6).
        fusion (bool): Add one conv to fuse residual feature.
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, **kwargs):
        super().__init__(**kwargs)
        assert isinstance(pool_scales, (list, tuple))
        self.pool_scales = pool_scales
        self.fusion = fusion
        acm_modules = []
        for pool_scale in self.pool_scales:
            acm_modules.append(
                ACM(pool_scale,
                    self.fusion,
                    self.in_channels,
                    self.channels,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
        self.acm_modules = nn.ModuleList(acm_modules)
        self.bottleneck = ConvModule(
            self.in_channels + len(pool_scales) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        acm_outs = [x]
        for acm_module in self.acm_modules:
            acm_outs.append(acm_module(x))
        acm_outs = torch.cat(acm_outs, dim=1)
        output = self.bottleneck(acm_outs)
        output = self.cls_seg(output)
        return output
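# Usage sketch (illustrative only, not part of the original file): APCHead
# with the default pool scales over a single 512-channel feature map.
if __name__ == '__main__':
    head = APCHead(in_channels=512, channels=128, num_classes=2, in_index=0)
    out = head([torch.rand(1, 512, 32, 32)])
    print(out.shape)  # -> torch.Size([1, 2, 32, 32])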
122
finetune/mmseg/models/decode_heads/aspp_head.py
Normal file
@@ -0,0 +1,122 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead


class ASPPModule(nn.ModuleList):
    """Atrous Spatial Pyramid Pooling (ASPP) Module.

    Args:
        dilations (tuple[int]): Dilation rate of each layer.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
    """

    def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg,
                 act_cfg):
        super().__init__()
        self.dilations = dilations
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        for dilation in dilations:
            self.append(
                ConvModule(
                    self.in_channels,
                    self.channels,
                    1 if dilation == 1 else 3,
                    dilation=dilation,
                    padding=0 if dilation == 1 else dilation,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))

    def forward(self, x):
        """Forward function."""
        aspp_outs = []
        for aspp_module in self:
            aspp_outs.append(aspp_module(x))

        return aspp_outs


@MODELS.register_module()
class ASPPHead(BaseDecodeHead):
    """Rethinking Atrous Convolution for Semantic Image Segmentation.

    This head is the implementation of `DeepLabV3
    <https://arxiv.org/abs/1706.05587>`_.

    Args:
        dilations (tuple[int]): Dilation rates for ASPP module.
            Default: (1, 6, 12, 18).
    """

    def __init__(self, dilations=(1, 6, 12, 18), **kwargs):
        super().__init__(**kwargs)
        assert isinstance(dilations, (list, tuple))
        self.dilations = dilations
        self.image_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvModule(
                self.in_channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))
        self.aspp_modules = ASPPModule(
            dilations,
            self.in_channels,
            self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.bottleneck = ConvModule(
            (len(dilations) + 1) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def _forward_feature(self, inputs):
        """Forward function for feature maps before classifying each pixel
        with ``self.cls_seg`` fc.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W) which is feature map for last layer of decoder head.
        """
        x = self._transform_inputs(inputs)
        aspp_outs = [
            resize(
                self.image_pool(x),
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        ]
        aspp_outs.extend(self.aspp_modules(x))
        aspp_outs = torch.cat(aspp_outs, dim=1)
        feats = self.bottleneck(aspp_outs)
        return feats

    def forward(self, inputs):
        """Forward function."""
        output = self._forward_feature(inputs)
        output = self.cls_seg(output)
        return output
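# Usage sketch (illustrative only, not part of the original file): DeepLabV3's
# ASPP head over a single high-level feature map.
if __name__ == '__main__':
    head = ASPPHead(
        dilations=(1, 6, 12, 18),
        in_channels=2048,
        channels=512,
        num_classes=2,
        in_index=0)
    out = head([torch.rand(1, 2048, 32, 32)])
    print(out.shape)  # -> torch.Size([1, 2, 32, 32])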
62
finetune/mmseg/models/decode_heads/cascade_decode_head.py
Normal file
@@ -0,0 +1,62 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import List

from torch import Tensor

from mmseg.utils import ConfigType
from .decode_head import BaseDecodeHead


class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
    """Base class for cascade decode head used in
    :class:`CascadeEncoderDecoder`."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @abstractmethod
    def forward(self, inputs, prev_output):
        """Placeholder of forward function."""
        pass

    def loss(self, inputs: List[Tensor], prev_output: Tensor,
             batch_data_samples: List[dict], train_cfg: ConfigType) -> Tensor:
        """Forward function for training.

        Args:
            inputs (List[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            batch_data_samples (List[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `metainfo` and `gt_sem_seg`.
            train_cfg (dict): The training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components.
        """
        seg_logits = self.forward(inputs, prev_output)
        losses = self.loss_by_feat(seg_logits, batch_data_samples)

        return losses

    def predict(self, inputs: List[Tensor], prev_output: Tensor,
                batch_img_metas: List[dict], test_cfg: ConfigType):
        """Forward function for testing.

        Args:
            inputs (List[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            batch_img_metas (dict): List Image info where each dict may also
                contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
                'ori_shape', and 'pad_shape'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
            test_cfg (dict): The testing config.

        Returns:
            Tensor: Output segmentation map.
        """
        seg_logits = self.forward(inputs, prev_output)

        return self.predict_by_feat(seg_logits, batch_img_metas)
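# Usage sketch (illustrative only, not part of the original file): a minimal
# cascade head that refines a previous head's logits; subclasses only have to
# implement `forward`, and `prev_output` is assumed to share the spatial size
# of the selected input feature.
#
#   class ConvRefineHead(BaseCascadeDecodeHead):
#
#       def __init__(self, **kwargs):
#           super().__init__(**kwargs)
#           self.refine = nn.Conv2d(self.in_channels + self.num_classes,
#                                   self.channels, 3, padding=1)
#
#       def forward(self, inputs, prev_output):
#           x = self._transform_inputs(inputs)
#           x = self.refine(torch.cat([x, prev_output], dim=1))
#           return self.cls_seg(x)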
43
finetune/mmseg/models/decode_heads/cc_head.py
Normal file
@@ -0,0 +1,43 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmseg.registry import MODELS
from .fcn_head import FCNHead

try:
    from mmcv.ops import CrissCrossAttention
except ModuleNotFoundError:
    CrissCrossAttention = None


@MODELS.register_module()
class CCHead(FCNHead):
    """CCNet: Criss-Cross Attention for Semantic Segmentation.

    This head is the implementation of `CCNet
    <https://arxiv.org/abs/1811.11721>`_.

    Args:
        recurrence (int): Number of recurrence of Criss Cross Attention
            module. Default: 2.
    """

    def __init__(self, recurrence=2, **kwargs):
        if CrissCrossAttention is None:
            raise RuntimeError('Please install mmcv-full for '
                               'CrissCrossAttention ops')
        super().__init__(num_convs=2, **kwargs)
        self.recurrence = recurrence
        self.cca = CrissCrossAttention(self.channels)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        output = self.convs[0](x)
        for _ in range(self.recurrence):
            output = self.cca(output)
        output = self.convs[1](output)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output
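# Usage sketch (illustrative only, not part of the original file; requires
# mmcv's CrissCrossAttention op to be available in the installed mmcv build):
#
#   head = CCHead(in_channels=2048, channels=512, num_classes=2, in_index=0)
#   out = head([torch.rand(1, 2048, 32, 32)])  # -> (1, 2, 32, 32)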
184
finetune/mmseg/models/decode_heads/da_head.py
Normal file
@@ -0,0 +1,184 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
import torch.nn.functional as F
from mmcv.cnn import ConvModule, Scale
from torch import Tensor, nn

from mmseg.registry import MODELS
from mmseg.utils import SampleList, add_prefix
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .decode_head import BaseDecodeHead


class PAM(_SelfAttentionBlock):
    """Position Attention Module (PAM)

    Args:
        in_channels (int): Input channels of key/query feature.
        channels (int): Output channels of key/query transform.
    """

    def __init__(self, in_channels, channels):
        super().__init__(
            key_in_channels=in_channels,
            query_in_channels=in_channels,
            channels=channels,
            out_channels=in_channels,
            share_key_query=False,
            query_downsample=None,
            key_downsample=None,
            key_query_num_convs=1,
            key_query_norm=False,
            value_out_num_convs=1,
            value_out_norm=False,
            matmul_norm=False,
            with_out=False,
            conv_cfg=None,
            norm_cfg=None,
            act_cfg=None)

        self.gamma = Scale(0)

    def forward(self, x):
        """Forward function."""
        out = super().forward(x, x)

        out = self.gamma(out) + x
        return out


class CAM(nn.Module):
    """Channel Attention Module (CAM)"""

    def __init__(self):
        super().__init__()
        self.gamma = Scale(0)

    def forward(self, x):
        """Forward function."""
        batch_size, channels, height, width = x.size()
        proj_query = x.view(batch_size, channels, -1)
        proj_key = x.view(batch_size, channels, -1).permute(0, 2, 1)
        energy = torch.bmm(proj_query, proj_key)
        energy_new = torch.max(
            energy, -1, keepdim=True)[0].expand_as(energy) - energy
        attention = F.softmax(energy_new, dim=-1)
        proj_value = x.view(batch_size, channels, -1)

        out = torch.bmm(attention, proj_value)
        out = out.view(batch_size, channels, height, width)

        out = self.gamma(out) + x
        return out


@MODELS.register_module()
class DAHead(BaseDecodeHead):
    """Dual Attention Network for Scene Segmentation.

    This head is the implementation of `DANet
    <https://arxiv.org/abs/1809.02983>`_.

    Args:
        pam_channels (int): The channels of Position Attention Module (PAM).
    """

    def __init__(self, pam_channels, **kwargs):
        super().__init__(**kwargs)
        self.pam_channels = pam_channels
        self.pam_in_conv = ConvModule(
            self.in_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.pam = PAM(self.channels, pam_channels)
        self.pam_out_conv = ConvModule(
            self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.pam_conv_seg = nn.Conv2d(
            self.channels, self.num_classes, kernel_size=1)

        self.cam_in_conv = ConvModule(
            self.in_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.cam = CAM()
        self.cam_out_conv = ConvModule(
            self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.cam_conv_seg = nn.Conv2d(
            self.channels, self.num_classes, kernel_size=1)

    def pam_cls_seg(self, feat):
        """PAM feature classification."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.pam_conv_seg(feat)
        return output

    def cam_cls_seg(self, feat):
        """CAM feature classification."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.cam_conv_seg(feat)
        return output

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        pam_feat = self.pam_in_conv(x)
        pam_feat = self.pam(pam_feat)
        pam_feat = self.pam_out_conv(pam_feat)
        pam_out = self.pam_cls_seg(pam_feat)

        cam_feat = self.cam_in_conv(x)
        cam_feat = self.cam(cam_feat)
        cam_feat = self.cam_out_conv(cam_feat)
        cam_out = self.cam_cls_seg(cam_feat)

        feat_sum = pam_feat + cam_feat
        pam_cam_out = self.cls_seg(feat_sum)

        return pam_cam_out, pam_out, cam_out

    def predict(self, inputs, batch_img_metas: List[dict], test_cfg,
                **kwargs) -> List[Tensor]:
        """Forward function for testing, only ``pam_cam`` is used."""
        seg_logits = self.forward(inputs)[0]
        return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs)

    def loss_by_feat(self, seg_logit: Tuple[Tensor],
                     batch_data_samples: SampleList, **kwargs) -> dict:
        """Compute ``pam_cam``, ``pam``, ``cam`` loss."""
        pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit
        loss = dict()
        loss.update(
            add_prefix(
                super().loss_by_feat(pam_cam_seg_logit, batch_data_samples),
                'pam_cam'))
        loss.update(
            add_prefix(
                super().loss_by_feat(pam_seg_logit, batch_data_samples),
                'pam'))
        loss.update(
            add_prefix(
                super().loss_by_feat(cam_seg_logit, batch_data_samples),
                'cam'))
        return loss
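A usage sketch for the head above (hyper-parameters are illustrative and assume the mmseg registry is importable so the default loss resolves): the training-mode forward returns three logit maps, of which `predict` keeps only the fused `pam_cam` branch.

import torch

head = DAHead(pam_channels=16, in_channels=64, channels=32, num_classes=2)
feats = [torch.randn(2, 64, 32, 32)]
pam_cam, pam, cam = head(feats)  # three branches, supervised separately
assert pam_cam.shape == (2, 2, 32, 32)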
116
finetune/mmseg/models/decode_heads/ddr_head.py
Normal file
@@ -0,0 +1,116 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union

import torch.nn as nn
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from torch import Tensor

from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.losses import accuracy
from mmseg.models.utils import resize
from mmseg.registry import MODELS
from mmseg.utils import OptConfigType, SampleList


@MODELS.register_module()
class DDRHead(BaseDecodeHead):
    """Decode head for DDRNet.

    Args:
        in_channels (int): Number of input channels.
        channels (int): Number of output channels.
        num_classes (int): Number of classes.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict, optional): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
    """

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 num_classes: int,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 **kwargs):
        super().__init__(
            in_channels,
            channels,
            num_classes=num_classes,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **kwargs)

        self.head = self._make_base_head(self.in_channels, self.channels)
        self.aux_head = self._make_base_head(self.in_channels // 2,
                                             self.channels)
        self.aux_cls_seg = nn.Conv2d(
            self.channels, self.out_channels, kernel_size=1)

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(
            self,
            inputs: Union[Tensor,
                          Tuple[Tensor]]) -> Union[Tensor, Tuple[Tensor]]:
        if self.training:
            c3_feat, c5_feat = inputs
            x_c = self.head(c5_feat)
            x_c = self.cls_seg(x_c)
            x_s = self.aux_head(c3_feat)
            x_s = self.aux_cls_seg(x_s)

            return x_c, x_s
        else:
            x_c = self.head(inputs)
            x_c = self.cls_seg(x_c)
            return x_c

    def _make_base_head(self, in_channels: int,
                        channels: int) -> nn.Sequential:
        layers = [
            ConvModule(
                in_channels,
                channels,
                kernel_size=3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                order=('norm', 'act', 'conv')),
            build_norm_layer(self.norm_cfg, channels)[1],
            build_activation_layer(self.act_cfg),
        ]

        return nn.Sequential(*layers)

    def loss_by_feat(self, seg_logits: Tuple[Tensor],
                     batch_data_samples: SampleList) -> dict:
        loss = dict()
        context_logit, spatial_logit = seg_logits
        seg_label = self._stack_batch_gt(batch_data_samples)

        context_logit = resize(
            context_logit,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        spatial_logit = resize(
            spatial_logit,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        seg_label = seg_label.squeeze(1)

        loss['loss_context'] = self.loss_decode[0](context_logit, seg_label)
        loss['loss_spatial'] = self.loss_decode[1](spatial_logit, seg_label)
        loss['acc_seg'] = accuracy(
            context_logit, seg_label, ignore_index=self.ignore_index)

        return loss
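A sketch of the two forward modes above (sizes illustrative; the two-entry `loss_decode` list mirrors the context/spatial split in `loss_by_feat`): eval mode takes a single deep feature, training mode a `(c3, c5)` pair.

import torch

head = DDRHead(
    in_channels=128, channels=64, num_classes=19,
    loss_decode=[dict(type='CrossEntropyLoss'),
                 dict(type='CrossEntropyLoss')])
head.eval()
logits = head(torch.randn(2, 128, 16, 16))
assert logits.shape == (2, 19, 16, 16)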
366
finetune/mmseg/models/decode_heads/decode_head.py
Normal file
@@ -0,0 +1,366 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from abc import ABCMeta, abstractmethod
from typing import List, Tuple

import torch
import torch.nn as nn
from mmengine.model import BaseModule
from torch import Tensor

from mmseg.registry import MODELS
from mmseg.structures import build_pixel_sampler
from mmseg.utils import ConfigType, SampleList
from ..losses import accuracy
from ..utils import resize


class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
    """Base class for BaseDecodeHead.

    1. The ``init_weights`` method is used to initialize decode_head's
    model parameters. After segmentor initialization, ``init_weights``
    is triggered when ``segmentor.init_weights()`` is called externally.

    2. The ``loss`` method is used to calculate the loss of decode_head,
    which includes two steps: (1) the decode_head model performs forward
    propagation to obtain the feature maps (2) The ``loss_by_feat`` method
    is called based on the feature maps to calculate the loss.

    .. code:: text

        loss(): forward() -> loss_by_feat()

    3. The ``predict`` method is used to predict segmentation results,
    which includes two steps: (1) the decode_head model performs forward
    propagation to obtain the feature maps (2) The ``predict_by_feat`` method
    is called based on the feature maps to predict segmentation results
    including post-processing.

    .. code:: text

        predict(): forward() -> predict_by_feat()

    Args:
        in_channels (int|Sequence[int]): Input channels.
        channels (int): Channels after modules, before conv_seg.
        num_classes (int): Number of classes.
        out_channels (int): Output channels of conv_seg. Default: None.
        threshold (float): Threshold for binary segmentation in the case of
            `num_classes==1`. Default: None.
        dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
        conv_cfg (dict|None): Config of conv layers. Default: None.
        norm_cfg (dict|None): Config of norm layers. Default: None.
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU')
        in_index (int|Sequence[int]): Input feature index. Default: -1
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            'resize_concat': Multiple feature maps will be resized to the
                same size as the first one and then concatenated together.
                Usually used in FCN head of HRNet.
            'multiple_select': Multiple feature maps will be bundled into
                a list and passed into decode head.
            None: Only one select feature map is allowed.
            Default: None.
        loss_decode (dict | Sequence[dict]): Config of decode loss.
            The `loss_name` is property of corresponding loss function which
            could be shown in training log. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_ce'.
            e.g. dict(type='CrossEntropyLoss'),
            [dict(type='CrossEntropyLoss', loss_name='loss_ce'),
             dict(type='DiceLoss', loss_name='loss_dice')]
            Default: dict(type='CrossEntropyLoss').
        ignore_index (int | None): The label index to be ignored. When using
            masked BCE loss, ignore_index should be set to None. Default: 255.
        sampler (dict|None): The config of segmentation map sampler.
            Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 *,
                 num_classes,
                 out_channels=None,
                 threshold=None,
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 in_index=-1,
                 input_transform=None,
                 loss_decode=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 ignore_index=255,
                 sampler=None,
                 align_corners=False,
                 init_cfg=dict(
                     type='Normal', std=0.01, override=dict(name='conv_seg'))):
        super().__init__(init_cfg)
        self._init_inputs(in_channels, in_index, input_transform)
        self.channels = channels
        self.dropout_ratio = dropout_ratio
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.in_index = in_index

        self.ignore_index = ignore_index
        self.align_corners = align_corners

        if out_channels is None:
            if num_classes == 2:
                warnings.warn('For binary segmentation, we suggest using '
                              '`out_channels = 1` to define the output '
                              'channels of segmentor, and use `threshold` '
                              'to convert `seg_logits` into a prediction '
                              'applying a threshold')
            out_channels = num_classes

        if out_channels != num_classes and out_channels != 1:
            raise ValueError(
                'out_channels should be equal to num_classes, '
                'except binary segmentation set out_channels == 1 and '
                f'num_classes == 2, but got out_channels={out_channels} '
                f'and num_classes={num_classes}')

        if out_channels == 1 and threshold is None:
            threshold = 0.3
            warnings.warn('threshold is not defined for binary, and defaults '
                          'to 0.3')
        self.num_classes = num_classes
        self.out_channels = out_channels
        self.threshold = threshold

        if isinstance(loss_decode, dict):
            self.loss_decode = MODELS.build(loss_decode)
        elif isinstance(loss_decode, (list, tuple)):
            self.loss_decode = nn.ModuleList()
            for loss in loss_decode:
                self.loss_decode.append(MODELS.build(loss))
        else:
            raise TypeError(f'loss_decode must be a dict or sequence of dict, '
                            f'but got {type(loss_decode)}')

        if sampler is not None:
            self.sampler = build_pixel_sampler(sampler, context=self)
        else:
            self.sampler = None

        self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            self.dropout = None

    def extra_repr(self):
        """Extra repr."""
        s = f'input_transform={self.input_transform}, ' \
            f'ignore_index={self.ignore_index}, ' \
            f'align_corners={self.align_corners}'
        return s

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only single feature map
        will be selected. So in_channels and in_index must be of type int.
        When input_transform is not None, in_channels and in_index must be
        list or tuple, with the same length.

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps will be resized to the
                    same size as the first one and then concatenated together.
                    Usually used in FCN head of HRNet.
                'multiple_select': Multiple feature maps will be bundled into
                    a list and passed into decode head.
                None: Only one select feature map is allowed.
        """

        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: The transformed inputs
        """

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    @abstractmethod
    def forward(self, inputs):
        """Placeholder of forward function."""
        pass

    def cls_seg(self, feat):
        """Classify each pixel."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.conv_seg(feat)
        return output

    def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList,
             train_cfg: ConfigType) -> dict:
        """Forward function for training.

        Args:
            inputs (Tuple[Tensor]): List of multi-level img features.
            batch_data_samples (list[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `img_metas` or `gt_semantic_seg`.
            train_cfg (dict): The training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self.forward(inputs)
        losses = self.loss_by_feat(seg_logits, batch_data_samples)
        return losses

    def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
                test_cfg: ConfigType) -> Tensor:
        """Forward function for prediction.

        Args:
            inputs (Tuple[Tensor]): List of multi-level img features.
            batch_img_metas (dict): List Image info where each dict may also
                contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
                'ori_shape', and 'pad_shape'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
            test_cfg (dict): The testing config.

        Returns:
            Tensor: Outputs segmentation logits map.
        """
        seg_logits = self.forward(inputs)

        return self.predict_by_feat(seg_logits, batch_img_metas)

    def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor:
        gt_semantic_segs = [
            data_sample.gt_sem_seg.data for data_sample in batch_data_samples
        ]
        return torch.stack(gt_semantic_segs, dim=0)

    def loss_by_feat(self, seg_logits: Tensor,
                     batch_data_samples: SampleList) -> dict:
        """Compute segmentation loss.

        Args:
            seg_logits (Tensor): The output from decode head forward function.
            batch_data_samples (List[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `metainfo` and `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """

        seg_label = self._stack_batch_gt(batch_data_samples)
        loss = dict()
        seg_logits = resize(
            input=seg_logits,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        if self.sampler is not None:
            seg_weight = self.sampler.sample(seg_logits, seg_label)
        else:
            seg_weight = None
        seg_label = seg_label.squeeze(1)

        if not isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = [self.loss_decode]
        else:
            losses_decode = self.loss_decode
        for loss_decode in losses_decode:
            if loss_decode.loss_name not in loss:
                loss[loss_decode.loss_name] = loss_decode(
                    seg_logits,
                    seg_label,
                    weight=seg_weight,
                    ignore_index=self.ignore_index)
            else:
                loss[loss_decode.loss_name] += loss_decode(
                    seg_logits,
                    seg_label,
                    weight=seg_weight,
                    ignore_index=self.ignore_index)

        loss['acc_seg'] = accuracy(
            seg_logits, seg_label, ignore_index=self.ignore_index)
        return loss

    def predict_by_feat(self, seg_logits: Tensor,
                        batch_img_metas: List[dict]) -> Tensor:
        """Transform a batch of output seg_logits to the input shape.

        Args:
            seg_logits (Tensor): The output from decode head forward function.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.

        Returns:
            Tensor: Outputs segmentation logits map.
        """

        if isinstance(batch_img_metas[0]['img_shape'], torch.Size):
            # slide inference
            size = batch_img_metas[0]['img_shape']
        elif 'pad_shape' in batch_img_metas[0]:
            size = batch_img_metas[0]['pad_shape'][:2]
        else:
            size = batch_img_metas[0]['img_shape']

        seg_logits = resize(
            input=seg_logits,
            size=size,
            mode='bilinear',
            align_corners=self.align_corners)
        return seg_logits
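Of the three input transforms above, only `resize_concat` changes the channel count, which is why `_init_inputs` sums `in_channels` for it; a plain-PyTorch sketch of that branch:

import torch
import torch.nn.functional as F

feats = [torch.randn(1, 8, 64, 64), torch.randn(1, 16, 32, 32)]
up = [F.interpolate(f, size=feats[0].shape[2:], mode='bilinear',
                    align_corners=False) for f in feats]
fused = torch.cat(up, dim=1)
assert fused.shape == (1, 24, 64, 64)  # in_channels == sum([8, 16])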
141
finetune/mmseg/models/decode_heads/dm_head.py
Normal file
@@ -0,0 +1,141 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer

from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead


class DCM(nn.Module):
    """Dynamic Convolutional Module used in DMNet.

    Args:
        filter_size (int): The filter size of generated convolution kernel
            used in Dynamic Convolutional Module.
        fusion (bool): Add one conv to fuse DCM output feature.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict | None): Config of conv layers.
        norm_cfg (dict | None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
    """

    def __init__(self, filter_size, fusion, in_channels, channels, conv_cfg,
                 norm_cfg, act_cfg):
        super().__init__()
        self.filter_size = filter_size
        self.fusion = fusion
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.filter_gen_conv = nn.Conv2d(self.in_channels, self.channels, 1,
                                         1, 0)

        self.input_redu_conv = ConvModule(
            self.in_channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        if self.norm_cfg is not None:
            self.norm = build_norm_layer(self.norm_cfg, self.channels)[1]
        else:
            self.norm = None
        self.activate = build_activation_layer(self.act_cfg)

        if self.fusion:
            self.fusion_conv = ConvModule(
                self.channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def forward(self, x):
        """Forward function."""
        generated_filter = self.filter_gen_conv(
            F.adaptive_avg_pool2d(x, self.filter_size))
        x = self.input_redu_conv(x)
        b, c, h, w = x.shape
        # [1, b * c, h, w], c = self.channels
        x = x.view(1, b * c, h, w)
        # [b * c, 1, filter_size, filter_size]
        generated_filter = generated_filter.view(b * c, 1, self.filter_size,
                                                 self.filter_size)
        pad = (self.filter_size - 1) // 2
        if (self.filter_size - 1) % 2 == 0:
            p2d = (pad, pad, pad, pad)
        else:
            p2d = (pad + 1, pad, pad + 1, pad)
        x = F.pad(input=x, pad=p2d, mode='constant', value=0)
        # [1, b * c, h, w]
        output = F.conv2d(input=x, weight=generated_filter, groups=b * c)
        # [b, c, h, w]
        output = output.view(b, c, h, w)
        if self.norm is not None:
            output = self.norm(output)
        output = self.activate(output)

        if self.fusion:
            output = self.fusion_conv(output)

        return output


@MODELS.register_module()
class DMHead(BaseDecodeHead):
    """Dynamic Multi-scale Filters for Semantic Segmentation.

    This head is the implementation of
    `DMNet <https://openaccess.thecvf.com/content_ICCV_2019/papers/\
He_Dynamic_Multi-Scale_Filters_for_Semantic_Segmentation_\
ICCV_2019_paper.pdf>`_.

    Args:
        filter_sizes (tuple[int]): The size of generated convolutional filters
            used in Dynamic Convolutional Module. Default: (1, 3, 5, 7).
        fusion (bool): Add one conv to fuse DCM output feature.
    """

    def __init__(self, filter_sizes=(1, 3, 5, 7), fusion=False, **kwargs):
        super().__init__(**kwargs)
        assert isinstance(filter_sizes, (list, tuple))
        self.filter_sizes = filter_sizes
        self.fusion = fusion
        dcm_modules = []
        for filter_size in self.filter_sizes:
            dcm_modules.append(
                DCM(filter_size,
                    self.fusion,
                    self.in_channels,
                    self.channels,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
        self.dcm_modules = nn.ModuleList(dcm_modules)
        self.bottleneck = ConvModule(
            self.in_channels + len(filter_sizes) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        dcm_outs = [x]
        for dcm_module in self.dcm_modules:
            dcm_outs.append(dcm_module(x))
        dcm_outs = torch.cat(dcm_outs, dim=1)
        output = self.bottleneck(dcm_outs)
        output = self.cls_seg(output)
        return output
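The reshaping in `DCM.forward` is the standard grouped-convolution trick for per-sample dynamic filters: folding the batch into the channel axis lets a single `conv2d` call apply a different generated kernel to every (sample, channel) slice. A minimal sketch:

import torch
import torch.nn.functional as F

b, c, h, w, k = 2, 4, 8, 8, 3
x = torch.randn(b, c, h, w).view(1, b * c, h, w)   # fold batch into channels
filters = torch.randn(b * c, 1, k, k)              # one kernel per slice
out = F.conv2d(F.pad(x, (1, 1, 1, 1)), filters, groups=b * c)
assert out.view(b, c, h, w).shape == (b, c, h, w)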
137
finetune/mmseg/models/decode_heads/dnl_head.py
Normal file
@@ -0,0 +1,137 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn import NonLocal2d
from torch import nn

from mmseg.registry import MODELS
from .fcn_head import FCNHead


class DisentangledNonLocal2d(NonLocal2d):
    """Disentangled Non-Local Blocks.

    Args:
        temperature (float): Temperature to adjust attention. Default: 0.05
    """

    def __init__(self, *arg, temperature, **kwargs):
        super().__init__(*arg, **kwargs)
        self.temperature = temperature
        self.conv_mask = nn.Conv2d(self.in_channels, 1, kernel_size=1)

    def embedded_gaussian(self, theta_x, phi_x):
        """Embedded gaussian with temperature."""

        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        pairwise_weight = torch.matmul(theta_x, phi_x)
        if self.use_scale:
            # theta_x.shape[-1] is `self.inter_channels`
            pairwise_weight /= torch.tensor(
                theta_x.shape[-1],
                dtype=torch.float,
                device=pairwise_weight.device)**torch.tensor(
                    0.5, device=pairwise_weight.device)
        pairwise_weight /= torch.tensor(
            self.temperature, device=pairwise_weight.device)
        pairwise_weight = pairwise_weight.softmax(dim=-1)
        return pairwise_weight

    def forward(self, x):
        # x: [N, C, H, W]
        n = x.size(0)

        # g_x: [N, HxW, C]
        g_x = self.g(x).view(n, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # theta_x: [N, HxW, C], phi_x: [N, C, HxW]
        if self.mode == 'gaussian':
            theta_x = x.view(n, self.in_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            if self.sub_sample:
                phi_x = self.phi(x).view(n, self.in_channels, -1)
            else:
                phi_x = x.view(n, self.in_channels, -1)
        elif self.mode == 'concatenation':
            theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
            phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
        else:
            theta_x = self.theta(x).view(n, self.inter_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            phi_x = self.phi(x).view(n, self.inter_channels, -1)

        # subtract mean
        theta_x -= theta_x.mean(dim=-2, keepdim=True)
        phi_x -= phi_x.mean(dim=-1, keepdim=True)

        pairwise_func = getattr(self, self.mode)
        # pairwise_weight: [N, HxW, HxW]
        pairwise_weight = pairwise_func(theta_x, phi_x)

        # y: [N, HxW, C]
        y = torch.matmul(pairwise_weight, g_x)
        # y: [N, C, H, W]
        y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
                                                    *x.size()[2:])

        # unary_mask: [N, 1, HxW]
        unary_mask = self.conv_mask(x)
        unary_mask = unary_mask.view(n, 1, -1)
        unary_mask = unary_mask.softmax(dim=-1)
        # unary_x: [N, 1, C]
        unary_x = torch.matmul(unary_mask, g_x)
        # unary_x: [N, C, 1, 1]
        unary_x = unary_x.permute(0, 2, 1).contiguous().reshape(
            n, self.inter_channels, 1, 1)

        output = x + self.conv_out(y + unary_x)

        return output


@MODELS.register_module()
class DNLHead(FCNHead):
    """Disentangled Non-Local Neural Networks.

    This head is the implementation of `DNLNet
    <https://arxiv.org/abs/2006.06668>`_.

    Args:
        reduction (int): Reduction factor of projection transform. Default: 2.
        use_scale (bool): Whether to scale pairwise_weight by
            sqrt(1/inter_channels). Default: True.
        mode (str): The nonlocal mode. Options are 'embedded_gaussian',
            'dot_product'. Default: 'embedded_gaussian'.
        temperature (float): Temperature to adjust attention. Default: 0.05
    """

    def __init__(self,
                 reduction=2,
                 use_scale=True,
                 mode='embedded_gaussian',
                 temperature=0.05,
                 **kwargs):
        super().__init__(num_convs=2, **kwargs)
        self.reduction = reduction
        self.use_scale = use_scale
        self.mode = mode
        self.temperature = temperature
        self.dnl_block = DisentangledNonLocal2d(
            in_channels=self.channels,
            reduction=self.reduction,
            use_scale=self.use_scale,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            mode=self.mode,
            temperature=self.temperature)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        output = self.convs[0](x)
        output = self.dnl_block(output)
        output = self.convs[1](output)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output
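The mean subtraction in `DisentangledNonLocal2d.forward` is the paper's whitening step: it strips the unary component out of the pairwise dot product, and that component is re-added explicitly via `conv_mask`. The pairwise half in isolation:

import torch

theta = torch.randn(1, 100, 32)  # [N, HxW, C']
phi = torch.randn(1, 32, 100)    # [N, C', HxW]
theta = theta - theta.mean(dim=-2, keepdim=True)  # whiten queries
phi = phi - phi.mean(dim=-1, keepdim=True)        # whiten keys
pairwise = torch.matmul(theta, phi).softmax(dim=-1)
assert pairwise.shape == (1, 100, 100)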
294
finetune/mmseg/models/decode_heads/dpt_head.py
Normal file
@@ -0,0 +1,294 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, Linear, build_activation_layer
from mmengine.model import BaseModule

from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead


class ReassembleBlocks(BaseModule):
    """ViTPostProcessBlock, process cls_token in ViT backbone output and
    rearrange the feature vector to feature map.

    Args:
        in_channels (int): ViT feature channels. Default: 768.
        out_channels (List): output channels of each stage.
            Default: [96, 192, 384, 768].
        readout_type (str): Type of readout operation. Default: 'ignore'.
        patch_size (int): The patch size. Default: 16.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self,
                 in_channels=768,
                 out_channels=[96, 192, 384, 768],
                 readout_type='ignore',
                 patch_size=16,
                 init_cfg=None):
        super().__init__(init_cfg)

        assert readout_type in ['ignore', 'add', 'project']
        self.readout_type = readout_type
        self.patch_size = patch_size

        self.projects = nn.ModuleList([
            ConvModule(
                in_channels=in_channels,
                out_channels=out_channel,
                kernel_size=1,
                act_cfg=None,
            ) for out_channel in out_channels
        ])

        self.resize_layers = nn.ModuleList([
            nn.ConvTranspose2d(
                in_channels=out_channels[0],
                out_channels=out_channels[0],
                kernel_size=4,
                stride=4,
                padding=0),
            nn.ConvTranspose2d(
                in_channels=out_channels[1],
                out_channels=out_channels[1],
                kernel_size=2,
                stride=2,
                padding=0),
            nn.Identity(),
            nn.Conv2d(
                in_channels=out_channels[3],
                out_channels=out_channels[3],
                kernel_size=3,
                stride=2,
                padding=1)
        ])
        if self.readout_type == 'project':
            self.readout_projects = nn.ModuleList()
            for _ in range(len(self.projects)):
                self.readout_projects.append(
                    nn.Sequential(
                        Linear(2 * in_channels, in_channels),
                        build_activation_layer(dict(type='GELU'))))

    def forward(self, inputs):
        assert isinstance(inputs, list)
        out = []
        for i, x in enumerate(inputs):
            assert len(x) == 2
            x, cls_token = x[0], x[1]
            feature_shape = x.shape
            if self.readout_type == 'project':
                x = x.flatten(2).permute((0, 2, 1))
                readout = cls_token.unsqueeze(1).expand_as(x)
                x = self.readout_projects[i](torch.cat((x, readout), -1))
                x = x.permute(0, 2, 1).reshape(feature_shape)
            elif self.readout_type == 'add':
                x = x.flatten(2) + cls_token.unsqueeze(-1)
                x = x.reshape(feature_shape)
            else:
                pass
            x = self.projects[i](x)
            x = self.resize_layers[i](x)
            out.append(x)
        return out


class PreActResidualConvUnit(BaseModule):
    """ResidualConvUnit, pre-activate residual unit.

    Args:
        in_channels (int): number of channels in the input feature map.
        act_cfg (dict): dictionary to construct and config activation layer.
        norm_cfg (dict): dictionary to construct and config norm layer.
        stride (int): stride of the first block. Default: 1
        dilation (int): dilation rate for convs layers. Default: 1.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self,
                 in_channels,
                 act_cfg,
                 norm_cfg,
                 stride=1,
                 dilation=1,
                 init_cfg=None):
        super().__init__(init_cfg)

        self.conv1 = ConvModule(
            in_channels,
            in_channels,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            bias=False,
            order=('act', 'conv', 'norm'))

        self.conv2 = ConvModule(
            in_channels,
            in_channels,
            3,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            bias=False,
            order=('act', 'conv', 'norm'))

    def forward(self, inputs):
        inputs_ = inputs.clone()
        x = self.conv1(inputs)
        x = self.conv2(x)
        return x + inputs_


class FeatureFusionBlock(BaseModule):
    """FeatureFusionBlock, merge feature map from different stages.

    Args:
        in_channels (int): Input channels.
        act_cfg (dict): The activation config for ResidualConvUnit.
        norm_cfg (dict): Config dict for normalization layer.
        expand (bool): Whether expand the channels in post process block.
            Default: False.
        align_corners (bool): align_corner setting for bilinear upsample.
            Default: True.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self,
                 in_channels,
                 act_cfg,
                 norm_cfg,
                 expand=False,
                 align_corners=True,
                 init_cfg=None):
        super().__init__(init_cfg)

        self.in_channels = in_channels
        self.expand = expand
        self.align_corners = align_corners

        self.out_channels = in_channels
        if self.expand:
            self.out_channels = in_channels // 2

        self.project = ConvModule(
            self.in_channels,
            self.out_channels,
            kernel_size=1,
            act_cfg=None,
            bias=True)

        self.res_conv_unit1 = PreActResidualConvUnit(
            in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
        self.res_conv_unit2 = PreActResidualConvUnit(
            in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)

    def forward(self, *inputs):
        x = inputs[0]
        if len(inputs) == 2:
            if x.shape != inputs[1].shape:
                res = resize(
                    inputs[1],
                    size=(x.shape[2], x.shape[3]),
                    mode='bilinear',
                    align_corners=False)
            else:
                res = inputs[1]
            x = x + self.res_conv_unit1(res)
        x = self.res_conv_unit2(x)
        x = resize(
            x,
            scale_factor=2,
            mode='bilinear',
            align_corners=self.align_corners)
        x = self.project(x)
        return x


@MODELS.register_module()
class DPTHead(BaseDecodeHead):
    """Vision Transformers for Dense Prediction.

    This head is the implementation of `DPT
    <https://arxiv.org/abs/2103.13413>`_.

    Args:
        embed_dims (int): The embed dimension of the ViT backbone.
            Default: 768.
        post_process_channels (List): Out channels of post process conv
            layers. Default: [96, 192, 384, 768].
        readout_type (str): Type of readout operation. Default: 'ignore'.
        patch_size (int): The patch size. Default: 16.
        expand_channels (bool): Whether expand the channels in post process
            block. Default: False.
        act_cfg (dict): The activation config for residual conv unit.
            Default dict(type='ReLU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
    """

    def __init__(self,
                 embed_dims=768,
                 post_process_channels=[96, 192, 384, 768],
                 readout_type='ignore',
                 patch_size=16,
                 expand_channels=False,
                 act_cfg=dict(type='ReLU'),
                 norm_cfg=dict(type='BN'),
                 **kwargs):
        super().__init__(**kwargs)

        self.in_channels = self.in_channels
        self.expand_channels = expand_channels
        self.reassemble_blocks = ReassembleBlocks(embed_dims,
                                                  post_process_channels,
                                                  readout_type, patch_size)

        self.post_process_channels = [
            channel * math.pow(2, i) if expand_channels else channel
            for i, channel in enumerate(post_process_channels)
        ]
        self.convs = nn.ModuleList()
        for channel in self.post_process_channels:
            self.convs.append(
                ConvModule(
                    channel,
                    self.channels,
                    kernel_size=3,
                    padding=1,
                    act_cfg=None,
                    bias=False))
        self.fusion_blocks = nn.ModuleList()
        for _ in range(len(self.convs)):
            self.fusion_blocks.append(
                FeatureFusionBlock(self.channels, act_cfg, norm_cfg))
        self.fusion_blocks[0].res_conv_unit1 = None
        self.project = ConvModule(
            self.channels,
            self.channels,
            kernel_size=3,
            padding=1,
            norm_cfg=norm_cfg)
        self.num_fusion_blocks = len(self.fusion_blocks)
        self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
        self.num_post_process_channels = len(self.post_process_channels)
        assert self.num_fusion_blocks == self.num_reassemble_blocks
        assert self.num_reassemble_blocks == self.num_post_process_channels

    def forward(self, inputs):
        assert len(inputs) == self.num_reassemble_blocks
        x = self._transform_inputs(inputs)
        x = self.reassemble_blocks(x)
        x = [self.convs[i](feature) for i, feature in enumerate(x)]
        out = self.fusion_blocks[0](x[-1])
        for i in range(1, len(self.fusion_blocks)):
            out = self.fusion_blocks[i](out, x[-(i + 1)])
        out = self.project(out)
        out = self.cls_seg(out)
        return out
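A feeding sketch for `ReassembleBlocks` above: it expects per-stage `(feature, cls_token)` pairs, as a ViT backbone produces when run with class-token output (a 14x14 token grid is assumed here):

import torch

blocks = ReassembleBlocks(in_channels=768, readout_type='project')
stages = [(torch.randn(1, 768, 14, 14), torch.randn(1, 768))] * 4
outs = blocks(stages)
# 4x-up, 2x-up, identity, 2x-down relative to the token grid
assert [o.shape[-1] for o in outs] == [56, 28, 14, 7]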
169
finetune/mmseg/models/decode_heads/ema_head.py
Normal file
@@ -0,0 +1,169 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead


def reduce_mean(tensor):
    """Reduce mean when distributed training."""
    if not (dist.is_available() and dist.is_initialized()):
        return tensor
    tensor = tensor.clone()
    dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
    return tensor


class EMAModule(nn.Module):
    """Expectation Maximization Attention Module used in EMANet.

    Args:
        channels (int): Channels of the whole module.
        num_bases (int): Number of bases.
        num_stages (int): Number of the EM iterations.
        momentum (float): Momentum to update the bases.
    """

    def __init__(self, channels, num_bases, num_stages, momentum):
        super().__init__()
        assert num_stages >= 1, 'num_stages must be at least 1!'
        self.num_bases = num_bases
        self.num_stages = num_stages
        self.momentum = momentum

        bases = torch.zeros(1, channels, self.num_bases)
        bases.normal_(0, math.sqrt(2. / self.num_bases))
        # [1, channels, num_bases]
        bases = F.normalize(bases, dim=1, p=2)
        self.register_buffer('bases', bases)

    def forward(self, feats):
        """Forward function."""
        batch_size, channels, height, width = feats.size()
        # [batch_size, channels, height*width]
        feats = feats.view(batch_size, channels, height * width)
        # [batch_size, channels, num_bases]
        bases = self.bases.repeat(batch_size, 1, 1)

        with torch.no_grad():
            for i in range(self.num_stages):
                # [batch_size, height*width, num_bases]
                attention = torch.einsum('bcn,bck->bnk', feats, bases)
                attention = F.softmax(attention, dim=2)
                # l1 norm
                attention_normed = F.normalize(attention, dim=1, p=1)
                # [batch_size, channels, num_bases]
                bases = torch.einsum('bcn,bnk->bck', feats, attention_normed)
                # l2 norm
                bases = F.normalize(bases, dim=1, p=2)

        feats_recon = torch.einsum('bck,bnk->bcn', bases, attention)
        feats_recon = feats_recon.view(batch_size, channels, height, width)

        if self.training:
            bases = bases.mean(dim=0, keepdim=True)
            bases = reduce_mean(bases)
            # l2 norm
            bases = F.normalize(bases, dim=1, p=2)
            self.bases = (1 -
                          self.momentum) * self.bases + self.momentum * bases

        return feats_recon


@MODELS.register_module()
class EMAHead(BaseDecodeHead):
    """Expectation Maximization Attention Networks for Semantic Segmentation.

    This head is the implementation of `EMANet
    <https://arxiv.org/abs/1907.13426>`_.

    Args:
        ema_channels (int): EMA module channels
        num_bases (int): Number of bases.
        num_stages (int): Number of the EM iterations.
        concat_input (bool): Whether concat the input and output of convs
            before classification layer. Default: True
        momentum (float): Momentum to update the base. Default: 0.1.
    """

    def __init__(self,
                 ema_channels,
                 num_bases,
                 num_stages,
                 concat_input=True,
                 momentum=0.1,
                 **kwargs):
        super().__init__(**kwargs)
        self.ema_channels = ema_channels
        self.num_bases = num_bases
        self.num_stages = num_stages
        self.concat_input = concat_input
        self.momentum = momentum
        self.ema_module = EMAModule(self.ema_channels, self.num_bases,
                                    self.num_stages, self.momentum)

        self.ema_in_conv = ConvModule(
            self.in_channels,
            self.ema_channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # project (0, inf) -> (-inf, inf)
        self.ema_mid_conv = ConvModule(
            self.ema_channels,
            self.ema_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=None,
            act_cfg=None)
        for param in self.ema_mid_conv.parameters():
            param.requires_grad = False

        self.ema_out_conv = ConvModule(
            self.ema_channels,
            self.ema_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=None)
        self.bottleneck = ConvModule(
            self.ema_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if self.concat_input:
            self.conv_cat = ConvModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        feats = self.ema_in_conv(x)
        identity = feats
        feats = self.ema_mid_conv(feats)
        recon = self.ema_module(feats)
        recon = F.relu(recon, inplace=True)
        recon = self.ema_out_conv(recon)
        output = F.relu(identity + recon, inplace=True)
        output = self.bottleneck(output)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output
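One EM iteration from `EMAModule.forward`, written out on random tensors (einsum indices as in the module: b batch, c channels, n pixels, k bases):

import torch
import torch.nn.functional as F

feats = torch.randn(2, 64, 100)                    # [B, C, HW]
bases = F.normalize(torch.randn(2, 64, 8), dim=1)  # [B, C, K]
attn = torch.einsum('bcn,bck->bnk', feats, bases).softmax(dim=2)  # E-step
attn_l1 = F.normalize(attn, dim=1, p=1)
bases = F.normalize(
    torch.einsum('bcn,bnk->bck', feats, attn_l1), dim=1)          # M-step
recon = torch.einsum('bck,bnk->bcn', bases, attn)  # low-rank reconstruction
assert recon.shape == feats.shape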
196
finetune/mmseg/models/decode_heads/enc_head.py
Normal file
@@ -0,0 +1,196 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_norm_layer
from torch import Tensor

from mmseg.registry import MODELS
from mmseg.utils import ConfigType, SampleList
from ..utils import Encoding, resize
from .decode_head import BaseDecodeHead


class EncModule(nn.Module):
    """Encoding Module used in EncNet.

    Args:
        in_channels (int): Input channels.
        num_codes (int): Number of code words.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
    """

    def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg):
        super().__init__()
        self.encoding_project = ConvModule(
            in_channels,
            in_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # TODO: resolve this hack
        # change to 1d
        if norm_cfg is not None:
            encoding_norm_cfg = norm_cfg.copy()
            if encoding_norm_cfg['type'] in ['BN', 'IN']:
                encoding_norm_cfg['type'] += '1d'
            else:
                encoding_norm_cfg['type'] = encoding_norm_cfg['type'].replace(
                    '2d', '1d')
        else:
            # fallback to BN1d
            encoding_norm_cfg = dict(type='BN1d')
        self.encoding = nn.Sequential(
            Encoding(channels=in_channels, num_codes=num_codes),
            build_norm_layer(encoding_norm_cfg, num_codes)[1],
            nn.ReLU(inplace=True))
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels), nn.Sigmoid())

    def forward(self, x):
        """Forward function."""
        encoding_projection = self.encoding_project(x)
        encoding_feat = self.encoding(encoding_projection).mean(dim=1)
        batch_size, channels, _, _ = x.size()
        gamma = self.fc(encoding_feat)
        y = gamma.view(batch_size, channels, 1, 1)
        output = F.relu_(x + x * y)
        return encoding_feat, output


@MODELS.register_module()
class EncHead(BaseDecodeHead):
    """Context Encoding for Semantic Segmentation.

    This head is the implementation of `EncNet
    <https://arxiv.org/abs/1803.08904>`_.

    Args:
        num_codes (int): Number of code words. Default: 32.
        use_se_loss (bool): Whether to use Semantic Encoding Loss (SE-loss)
            to regularize the training. Default: True.
        add_lateral (bool): Whether to use lateral connection to fuse
            features. Default: False.
        loss_se_decode (dict): Config of decode loss.
            Default: dict(type='CrossEntropyLoss', use_sigmoid=True).
    """

    def __init__(self,
                 num_codes=32,
                 use_se_loss=True,
                 add_lateral=False,
                 loss_se_decode=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=0.2),
                 **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)
        self.use_se_loss = use_se_loss
        self.add_lateral = add_lateral
        self.num_codes = num_codes
        self.bottleneck = ConvModule(
            self.in_channels[-1],
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if add_lateral:
            self.lateral_convs = nn.ModuleList()
            for in_channels in self.in_channels[:-1]:  # skip the last one
                self.lateral_convs.append(
                    ConvModule(
                        in_channels,
                        self.channels,
                        1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
            self.fusion = ConvModule(
                len(self.in_channels) * self.channels,
                self.channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        self.enc_module = EncModule(
            self.channels,
            num_codes=num_codes,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if self.use_se_loss:
            self.loss_se_decode = MODELS.build(loss_se_decode)
            self.se_layer = nn.Linear(self.channels, self.num_classes)

    def forward(self, inputs):
        """Forward function."""
        inputs = self._transform_inputs(inputs)
        feat = self.bottleneck(inputs[-1])
        if self.add_lateral:
            laterals = [
                resize(
                    lateral_conv(inputs[i]),
                    size=feat.shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners)
                for i, lateral_conv in enumerate(self.lateral_convs)
            ]
            feat = self.fusion(torch.cat([feat, *laterals], 1))
        encode_feat, output = self.enc_module(feat)
        output = self.cls_seg(output)
        if self.use_se_loss:
            se_output = self.se_layer(encode_feat)
            return output, se_output
        else:
            return output

    def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
                test_cfg: ConfigType):
        """Forward function for testing, ignore se_loss."""
        if self.use_se_loss:
            seg_logits = self.forward(inputs)[0]
        else:
            seg_logits = self.forward(inputs)
        return self.predict_by_feat(seg_logits, batch_img_metas)

    @staticmethod
    def _convert_to_onehot_labels(seg_label, num_classes):
        """Convert segmentation label to onehot.

        Args:
            seg_label (Tensor): Segmentation label of shape (N, H, W).
            num_classes (int): Number of classes.

        Returns:
            Tensor: Onehot labels of shape (N, num_classes).
        """

        batch_size = seg_label.size(0)
        onehot_labels = seg_label.new_zeros((batch_size, num_classes))
        for i in range(batch_size):
            hist = seg_label[i].float().histc(
                bins=num_classes, min=0, max=num_classes - 1)
            onehot_labels[i] = hist > 0
        return onehot_labels

    def loss_by_feat(self, seg_logit: Tuple[Tensor],
                     batch_data_samples: SampleList, **kwargs) -> dict:
        """Compute segmentation and semantic encoding loss."""
        seg_logit, se_seg_logit = seg_logit
        loss = dict()
        loss.update(super().loss_by_feat(seg_logit, batch_data_samples))

        seg_label = self._stack_batch_gt(batch_data_samples)
        se_loss = self.loss_se_decode(
            se_seg_logit,
            self._convert_to_onehot_labels(seg_label, self.num_classes))
        loss['loss_se'] = se_loss
        return loss
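The static helper above reduces a dense label map to a per-image "which classes occur" multi-hot vector, which is the SE-loss target; for example:

import torch

label = torch.tensor([[[0, 2], [2, 2]]])  # [N, H, W]
onehot = EncHead._convert_to_onehot_labels(label, num_classes=3)
assert onehot.tolist() == [[1, 0, 1]]  # classes 0 and 2 are present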
96
finetune/mmseg/models/decode_heads/fcn_head.py
Normal file
@@ -0,0 +1,96 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead


@MODELS.register_module()
class FCNHead(BaseDecodeHead):
    """Fully Convolution Networks for Semantic Segmentation.

    This head is the implementation of `FCN
    <https://arxiv.org/abs/1411.4038>`_.

    Args:
        num_convs (int): Number of convs in the head. Default: 2.
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        concat_input (bool): Whether concat the input and output of convs
            before classification layer.
        dilation (int): The dilation rate for convs in the head. Default: 1.
    """

    def __init__(self,
                 num_convs=2,
                 kernel_size=3,
                 concat_input=True,
                 dilation=1,
                 **kwargs):
        assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int)
        self.num_convs = num_convs
        self.concat_input = concat_input
        self.kernel_size = kernel_size
        super().__init__(**kwargs)
        if num_convs == 0:
            assert self.in_channels == self.channels

        conv_padding = (kernel_size // 2) * dilation
        convs = []
        convs.append(
            ConvModule(
                self.in_channels,
                self.channels,
                kernel_size=kernel_size,
                padding=conv_padding,
                dilation=dilation,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))
        for i in range(num_convs - 1):
            convs.append(
                ConvModule(
                    self.channels,
                    self.channels,
                    kernel_size=kernel_size,
                    padding=conv_padding,
                    dilation=dilation,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
        if num_convs == 0:
            self.convs = nn.Identity()
        else:
            self.convs = nn.Sequential(*convs)
        if self.concat_input:
            self.conv_cat = ConvModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def _forward_feature(self, inputs):
        """Forward function for feature maps before classifying each pixel
        with ``self.cls_seg`` fc.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W) which is feature map for last layer of decoder head.
        """
        x = self._transform_inputs(inputs)
        feats = self.convs(x)
        if self.concat_input:
            feats = self.conv_cat(torch.cat([x, feats], dim=1))
        return feats

    def forward(self, inputs):
        """Forward function."""
        output = self._forward_feature(inputs)
        output = self.cls_seg(output)
        return output
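Usage sketch (sizes illustrative, mmseg registry assumed for the default loss): with `num_convs=0` the convs collapse to `nn.Identity`, so `in_channels` must equal `channels` and the head becomes a bare 1x1 classifier.

import torch

head = FCNHead(num_convs=0, concat_input=False, in_channels=64,
               channels=64, num_classes=21)
logits = head([torch.randn(2, 64, 32, 32)])
assert logits.shape == (2, 21, 32, 32)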
68
finetune/mmseg/models/decode_heads/fpn_head.py
Normal file
@@ -0,0 +1,68 @@
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import Upsample, resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class FPNHead(BaseDecodeHead):
|
||||
"""Panoptic Feature Pyramid Networks.
|
||||
|
||||
This head is the implementation of `Semantic FPN
|
||||
<https://arxiv.org/abs/1901.02446>`_.
|
||||
|
||||
Args:
|
||||
feature_strides (tuple[int]): The strides for input feature maps.
|
||||
stack_lateral. All strides suppose to be power of 2. The first
|
||||
one is of largest resolution.
|
||||
"""
|
||||
|
||||
def __init__(self, feature_strides, **kwargs):
|
||||
super().__init__(input_transform='multiple_select', **kwargs)
|
||||
assert len(feature_strides) == len(self.in_channels)
|
||||
assert min(feature_strides) == feature_strides[0]
|
||||
self.feature_strides = feature_strides
|
||||
|
||||
self.scale_heads = nn.ModuleList()
|
||||
for i in range(len(feature_strides)):
|
||||
head_length = max(
|
||||
1,
|
||||
int(np.log2(feature_strides[i]) - np.log2(feature_strides[0])))
|
||||
scale_head = []
|
||||
for k in range(head_length):
|
||||
scale_head.append(
|
||||
ConvModule(
|
||||
self.in_channels[i] if k == 0 else self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg))
|
||||
if feature_strides[i] != feature_strides[0]:
|
||||
scale_head.append(
|
||||
Upsample(
|
||||
scale_factor=2,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners))
|
||||
self.scale_heads.append(nn.Sequential(*scale_head))
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
x = self._transform_inputs(inputs)
|
||||
|
||||
output = self.scale_heads[0](x[0])
|
||||
for i in range(1, len(self.feature_strides)):
|
||||
# non inplace
|
||||
output = output + resize(
|
||||
self.scale_heads[i](x[i]),
|
||||
size=output.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
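The scale-head sizing in FPNHead.__init__ is easy to sanity-check by hand. A short editorial sketch (the stride tuple is an illustrative assumption) showing that branch i gets log2(stride_i / stride_0) conv-and-upsample stages, so every branch ends at the finest stride:

import numpy as np

feature_strides = (4, 8, 16, 32)  # assumed, not from this commit
for i, s in enumerate(feature_strides):
    head_length = max(1, int(np.log2(s) - np.log2(feature_strides[0])))
    print(f'branch {i}: stride {s} -> {head_length} stage(s)')
# branch 0: 1, branch 1: 1, branch 2: 2, branch 3: 3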
48
finetune/mmseg/models/decode_heads/gc_head.py
Normal file
@@ -0,0 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn import ContextBlock

from mmseg.registry import MODELS
from .fcn_head import FCNHead


@MODELS.register_module()
class GCHead(FCNHead):
    """GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond.

    This head is the implementation of `GCNet
    <https://arxiv.org/abs/1904.11492>`_.

    Args:
        ratio (float): Multiplier of channels ratio. Default: 1/4.
        pooling_type (str): The pooling type of context aggregation.
            Options are 'att', 'avg'. Default: 'att'.
        fusion_types (tuple[str]): The fusion type for feature fusion.
            Options are 'channel_add', 'channel_mul'. Default: ('channel_add',)
    """

    def __init__(self,
                 ratio=1 / 4.,
                 pooling_type='att',
                 fusion_types=('channel_add', ),
                 **kwargs):
        super().__init__(num_convs=2, **kwargs)
        self.ratio = ratio
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types
        self.gc_block = ContextBlock(
            in_channels=self.channels,
            ratio=self.ratio,
            pooling_type=self.pooling_type,
            fusion_types=self.fusion_types)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        output = self.convs[0](x)
        output = self.gc_block(output)
        output = self.convs[1](output)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output
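For readers unfamiliar with ContextBlock, the core of its attention pooling can be written out in a few lines. This is an editorial simplification (the real mmcv block adds 1x1-conv transforms and the channel_add/channel_mul fusion branches), with illustrative sizes:

import torch

b, c, h, w = 1, 8, 4, 4
x = torch.randn(b, c, h, w)
attn_logits = torch.randn(b, 1, h, w)             # stand-in for a 1x1 conv over x
weights = attn_logits.view(b, 1, -1).softmax(-1)  # one weight per pixel
context = torch.bmm(x.view(b, c, -1), weights.transpose(1, 2))
print(context.view(b, c, 1, 1).shape)             # global vector, broadcast onto x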
255
finetune/mmseg/models/decode_heads/ham_head.py
Normal file
@@ -0,0 +1,255 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Originally from https://github.com/visual-attention-network/segnext
# Licensed under the Apache License, Version 2.0 (the "License")
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.device import get_device

from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead


class Matrix_Decomposition_2D_Base(nn.Module):
    """Base class of 2D Matrix Decomposition.

    Args:
        MD_S (int): The number of spatial coefficients in Matrix
            Decomposition; it may be used to calculate the number of latent
            dimensions D in Matrix Decomposition. Defaults: 1.
        MD_R (int): The number of latent dimensions R in Matrix
            Decomposition. Defaults: 64.
        train_steps (int): The number of iteration steps in
            Multiplicative Update (MU) rule to solve Non-negative
            Matrix Factorization (NMF) in training. Defaults: 6.
        eval_steps (int): The number of iteration steps in
            Multiplicative Update (MU) rule to solve Non-negative
            Matrix Factorization (NMF) in evaluation. Defaults: 7.
        inv_t (int): Multiplier applied to the coefficients before the
            softmax, making the distribution sharper. Defaults: 100.
        rand_init (bool): Whether to initialize bases randomly.
            Defaults: True.
    """

    def __init__(self,
                 MD_S=1,
                 MD_R=64,
                 train_steps=6,
                 eval_steps=7,
                 inv_t=100,
                 rand_init=True):
        super().__init__()

        self.S = MD_S
        self.R = MD_R

        self.train_steps = train_steps
        self.eval_steps = eval_steps

        self.inv_t = inv_t

        self.rand_init = rand_init

    def _build_bases(self, B, S, D, R, device=None):
        raise NotImplementedError

    def local_step(self, x, bases, coef):
        raise NotImplementedError

    def local_inference(self, x, bases):
        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
        coef = torch.bmm(x.transpose(1, 2), bases)
        coef = F.softmax(self.inv_t * coef, dim=-1)

        steps = self.train_steps if self.training else self.eval_steps
        for _ in range(steps):
            bases, coef = self.local_step(x, bases, coef)

        return bases, coef

    def compute_coef(self, x, bases, coef):
        raise NotImplementedError

    def forward(self, x, return_bases=False):
        """Forward Function."""
        B, C, H, W = x.shape

        # (B, C, H, W) -> (B * S, D, N)
        D = C // self.S
        N = H * W
        x = x.view(B * self.S, D, N)
        if not self.rand_init and not hasattr(self, 'bases'):
            bases = self._build_bases(1, self.S, D, self.R, device=x.device)
            self.register_buffer('bases', bases)

        # (S, D, R) -> (B * S, D, R)
        if self.rand_init:
            bases = self._build_bases(B, self.S, D, self.R, device=x.device)
        else:
            bases = self.bases.repeat(B, 1, 1)

        bases, coef = self.local_inference(x, bases)

        # (B * S, N, R)
        coef = self.compute_coef(x, bases, coef)

        # (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N)
        x = torch.bmm(bases, coef.transpose(1, 2))

        # (B * S, D, N) -> (B, C, H, W)
        x = x.view(B, C, H, W)

        return x


class NMF2D(Matrix_Decomposition_2D_Base):
    """Non-negative Matrix Factorization (NMF) module.

    It is inherited from ``Matrix_Decomposition_2D_Base`` module.
    """

    def __init__(self, args=dict()):
        super().__init__(**args)

        self.inv_t = 1

    def _build_bases(self, B, S, D, R, device=None):
        """Build bases in initialization."""
        if device is None:
            device = get_device()
        bases = torch.rand((B * S, D, R)).to(device)
        bases = F.normalize(bases, dim=1)

        return bases

    def local_step(self, x, bases, coef):
        """Local step in iteration to renew bases and coefficient."""
        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
        numerator = torch.bmm(x.transpose(1, 2), bases)
        # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R)
        denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
        # Multiplicative Update
        coef = coef * numerator / (denominator + 1e-6)

        # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)
        numerator = torch.bmm(x, coef)
        # (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R)
        denominator = bases.bmm(coef.transpose(1, 2).bmm(coef))
        # Multiplicative Update
        bases = bases * numerator / (denominator + 1e-6)

        return bases, coef

    def compute_coef(self, x, bases, coef):
        """Compute coefficient."""
        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
        numerator = torch.bmm(x.transpose(1, 2), bases)
        # (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R)
        denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
        # multiplicative update
        coef = coef * numerator / (denominator + 1e-6)

        return coef


class Hamburger(nn.Module):
    """Hamburger Module. It consists of one slice of "ham" (matrix
    decomposition) and two slices of "bread" (linear transformation).

    Args:
        ham_channels (int): Input and output channels of feature.
        ham_kwargs (dict): Config of matrix decomposition module.
        norm_cfg (dict | None): Config of norm layers.
    """

    def __init__(self,
                 ham_channels=512,
                 ham_kwargs=dict(),
                 norm_cfg=None,
                 **kwargs):
        super().__init__()

        self.ham_in = ConvModule(
            ham_channels, ham_channels, 1, norm_cfg=None, act_cfg=None)

        self.ham = NMF2D(ham_kwargs)

        self.ham_out = ConvModule(
            ham_channels, ham_channels, 1, norm_cfg=norm_cfg, act_cfg=None)

    def forward(self, x):
        enjoy = self.ham_in(x)
        enjoy = F.relu(enjoy, inplace=True)
        enjoy = self.ham(enjoy)
        enjoy = self.ham_out(enjoy)
        ham = F.relu(x + enjoy, inplace=True)

        return ham


@MODELS.register_module()
class LightHamHead(BaseDecodeHead):
    """SegNeXt decode head.

    This decode head is the implementation of `SegNeXt: Rethinking
    Convolutional Attention Design for Semantic
    Segmentation <https://arxiv.org/abs/2209.08575>`_.
    Inspiration from https://github.com/visual-attention-network/segnext.

    Specifically, LightHamHead is inspired by HamNet from
    `Is Attention Better Than Matrix Decomposition?
    <https://arxiv.org/abs/2109.04553>`_.

    Args:
        ham_channels (int): Input channels for Hamburger.
            Defaults: 512.
        ham_kwargs (dict): kwargs for Ham. Defaults: dict().
    """

    def __init__(self, ham_channels=512, ham_kwargs=dict(), **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)
        self.ham_channels = ham_channels

        self.squeeze = ConvModule(
            sum(self.in_channels),
            self.ham_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.hamburger = Hamburger(ham_channels, ham_kwargs, **kwargs)

        self.align = ConvModule(
            self.ham_channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        inputs = self._transform_inputs(inputs)

        inputs = [
            resize(
                level,
                size=inputs[0].shape[2:],
                mode='bilinear',
                align_corners=self.align_corners) for level in inputs
        ]

        inputs = torch.cat(inputs, dim=1)
        # apply a conv block to squeeze feature map
        x = self.squeeze(inputs)
        # apply hamburger module
        x = self.hamburger(x)

        # apply a conv block to align feature map
        output = self.align(x)
        output = self.cls_seg(output)
        return output
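The multiplicative-update rule that NMF2D iterates preserves non-negativity because every factor in the update is non-negative. An editorial sketch of the same loop on random data (sizes are illustrative), showing the reconstruction error falling as the steps run:

import torch

BS, D, N, R = 1, 8, 64, 4
x = torch.rand(BS, D, N)                      # non-negative input
bases = torch.nn.functional.normalize(torch.rand(BS, D, R), dim=1)
coef = torch.softmax(torch.bmm(x.transpose(1, 2), bases), dim=-1)

for step in range(6):                         # train_steps above
    coef = coef * torch.bmm(x.transpose(1, 2), bases) / (
        coef.bmm(bases.transpose(1, 2).bmm(bases)) + 1e-6)
    bases = bases * torch.bmm(x, coef) / (
        bases.bmm(coef.transpose(1, 2).bmm(coef)) + 1e-6)
    recon = torch.bmm(bases, coef.transpose(1, 2))   # rank-R approximation
    print(step, ((x - recon).norm() / x.norm()).item())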
143
finetune/mmseg/models/decode_heads/isa_head.py
Normal file
@@ -0,0 +1,143 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from mmseg.registry import MODELS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .decode_head import BaseDecodeHead


class SelfAttentionBlock(_SelfAttentionBlock):
    """Self-Attention Module.

    Args:
        in_channels (int): Input channels of key/query feature.
        channels (int): Output channels of key/query transform.
        conv_cfg (dict | None): Config of conv layers.
        norm_cfg (dict | None): Config of norm layers.
        act_cfg (dict | None): Config of activation layers.
    """

    def __init__(self, in_channels, channels, conv_cfg, norm_cfg, act_cfg):
        super().__init__(
            key_in_channels=in_channels,
            query_in_channels=in_channels,
            channels=channels,
            out_channels=in_channels,
            share_key_query=False,
            query_downsample=None,
            key_downsample=None,
            key_query_num_convs=2,
            key_query_norm=True,
            value_out_num_convs=1,
            value_out_norm=False,
            matmul_norm=True,
            with_out=False,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        self.output_project = self.build_project(
            in_channels,
            in_channels,
            num_convs=1,
            use_conv_module=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    def forward(self, x):
        """Forward function."""
        context = super().forward(x, x)
        return self.output_project(context)


@MODELS.register_module()
class ISAHead(BaseDecodeHead):
    """Interlaced Sparse Self-Attention for Semantic Segmentation.

    This head is the implementation of `ISA
    <https://arxiv.org/abs/1907.12273>`_.

    Args:
        isa_channels (int): The channels of ISA Module.
        down_factor (tuple[int]): The local group size of ISA.
    """

    def __init__(self, isa_channels, down_factor=(8, 8), **kwargs):
        super().__init__(**kwargs)
        self.down_factor = down_factor

        self.in_conv = ConvModule(
            self.in_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.global_relation = SelfAttentionBlock(
            self.channels,
            isa_channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.local_relation = SelfAttentionBlock(
            self.channels,
            isa_channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.out_conv = ConvModule(
            self.channels * 2,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x_ = self._transform_inputs(inputs)
        x = self.in_conv(x_)
        residual = x

        n, c, h, w = x.size()
        loc_h, loc_w = self.down_factor  # size of local group in H- and W-axes
        glb_h, glb_w = math.ceil(h / loc_h), math.ceil(w / loc_w)
        pad_h, pad_w = glb_h * loc_h - h, glb_w * loc_w - w
        if pad_h > 0 or pad_w > 0:  # pad if the size is not divisible
            padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
                       pad_h - pad_h // 2)
            x = F.pad(x, padding)

        # global relation
        x = x.view(n, c, glb_h, loc_h, glb_w, loc_w)
        # do permutation to gather global group
        x = x.permute(0, 3, 5, 1, 2, 4)  # (n, loc_h, loc_w, c, glb_h, glb_w)
        x = x.reshape(-1, c, glb_h, glb_w)
        # apply attention within each global group
        x = self.global_relation(x)  # (n * loc_h * loc_w, c, glb_h, glb_w)

        # local relation
        x = x.view(n, loc_h, loc_w, c, glb_h, glb_w)
        # do permutation to gather local group
        x = x.permute(0, 4, 5, 3, 1, 2)  # (n, glb_h, glb_w, c, loc_h, loc_w)
        x = x.reshape(-1, c, loc_h, loc_w)
        # apply attention within each local group
        x = self.local_relation(x)  # (n * glb_h * glb_w, c, loc_h, loc_w)

        # permute each pixel back to its original position
        x = x.view(n, glb_h, glb_w, c, loc_h, loc_w)
        x = x.permute(0, 3, 1, 4, 2, 5)  # (n, c, glb_h, loc_h, glb_w, loc_w)
        x = x.reshape(n, c, glb_h * loc_h, glb_w * loc_w)
        if pad_h > 0 or pad_w > 0:  # remove padding
            x = x[:, :, pad_h // 2:pad_h // 2 + h, pad_w // 2:pad_w // 2 + w]

        x = self.out_conv(torch.cat([x, residual], dim=1))
        out = self.cls_seg(x)

        return out
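The permutation bookkeeping in ISAHead.forward is the subtle part; an editorial shape walk-through on a toy tensor (sizes assumed, chosen so no padding is needed) makes the two groupings concrete:

import torch

n, c, h, w = 1, 4, 16, 16
loc_h, loc_w = 8, 8                     # down_factor
glb_h, glb_w = h // loc_h, w // loc_w   # 2 x 2 grid of local windows

x = torch.randn(n, c, h, w)

# global group: pixels at the same offset inside each window attend together
g = x.view(n, c, glb_h, loc_h, glb_w, loc_w)
g = g.permute(0, 3, 5, 1, 2, 4).reshape(-1, c, glb_h, glb_w)
print(g.shape)    # (n * loc_h * loc_w, c, glb_h, glb_w) == (64, 4, 2, 2)

# local group: pixels inside the same window attend together
loc = g.view(n, loc_h, loc_w, c, glb_h, glb_w)
loc = loc.permute(0, 4, 5, 3, 1, 2).reshape(-1, c, loc_h, loc_w)
print(loc.shape)  # (n * glb_h * glb_w, c, loc_h, loc_w) == (4, 4, 8, 8)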
461
finetune/mmseg/models/decode_heads/knet_head.py
Normal file
@@ -0,0 +1,461 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from mmcv.cnn.bricks.transformer import (FFN, MultiheadAttention,
                                         build_transformer_layer)
from mmengine.logging import print_log
from torch import Tensor

from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.registry import MODELS
from mmseg.utils import SampleList


@MODELS.register_module()
class KernelUpdator(nn.Module):
    """Dynamic Kernel Updator in Kernel Update Head.

    Args:
        in_channels (int): The number of channels of input feature map.
            Default: 256.
        feat_channels (int): The number of middle-stage channels in
            the kernel updator. Default: 64.
        out_channels (int): The number of output channels.
        gate_sigmoid (bool): Whether to use a sigmoid function in the gate
            mechanism. Default: True.
        gate_norm_act (bool): Whether to add normalization and activation
            layers in the gate mechanism. Default: False.
        activate_out (bool): Whether to add activation after the gate
            mechanism. Default: False.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='LN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
    """

    def __init__(
        self,
        in_channels=256,
        feat_channels=64,
        out_channels=None,
        gate_sigmoid=True,
        gate_norm_act=False,
        activate_out=False,
        norm_cfg=dict(type='LN'),
        act_cfg=dict(type='ReLU', inplace=True),
    ):
        super().__init__()
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.out_channels_raw = out_channels
        self.gate_sigmoid = gate_sigmoid
        self.gate_norm_act = gate_norm_act
        self.activate_out = activate_out
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.out_channels = out_channels if out_channels else in_channels

        self.num_params_in = self.feat_channels
        self.num_params_out = self.feat_channels
        self.dynamic_layer = nn.Linear(
            self.in_channels, self.num_params_in + self.num_params_out)
        self.input_layer = nn.Linear(self.in_channels,
                                     self.num_params_in + self.num_params_out,
                                     1)
        self.input_gate = nn.Linear(self.in_channels, self.feat_channels, 1)
        self.update_gate = nn.Linear(self.in_channels, self.feat_channels, 1)
        if self.gate_norm_act:
            self.gate_norm = build_norm_layer(norm_cfg, self.feat_channels)[1]

        self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
        self.norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]
        self.input_norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
        self.input_norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]

        self.activation = build_activation_layer(act_cfg)

        self.fc_layer = nn.Linear(self.feat_channels, self.out_channels, 1)
        self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]

    def forward(self, update_feature, input_feature):
        """Forward function of KernelUpdator.

        Args:
            update_feature (torch.Tensor): Feature map assembled from
                each group. It would be reshaped with last dimension
                shape: ``self.in_channels``.
            input_feature (torch.Tensor): Intermediate feature
                with shape: (N, num_classes, conv_kernel_size**2, channels).

        Returns:
            Tensor: The output tensor of shape (N*C1/C2, K*K, C2), where N is
            the number of classes, C1 and C2 are the feature map channels of
            KernelUpdateHead and KernelUpdator, respectively.
        """

        update_feature = update_feature.reshape(-1, self.in_channels)
        num_proposals = update_feature.size(0)
        # dynamic_layer works for
        # phi_1 and psi_3 in Eq.(4) and (5) of K-Net paper
        parameters = self.dynamic_layer(update_feature)
        param_in = parameters[:, :self.num_params_in].view(
            -1, self.feat_channels)
        param_out = parameters[:, -self.num_params_out:].view(
            -1, self.feat_channels)

        # input_layer works for
        # phi_2 and psi_4 in Eq.(4) and (5) of K-Net paper
        input_feats = self.input_layer(
            input_feature.reshape(num_proposals, -1, self.feat_channels))
        input_in = input_feats[..., :self.num_params_in]
        input_out = input_feats[..., -self.num_params_out:]

        # `gate_feats` is F^G in K-Net paper
        gate_feats = input_in * param_in.unsqueeze(-2)
        if self.gate_norm_act:
            gate_feats = self.activation(self.gate_norm(gate_feats))

        input_gate = self.input_norm_in(self.input_gate(gate_feats))
        update_gate = self.norm_in(self.update_gate(gate_feats))
        if self.gate_sigmoid:
            input_gate = input_gate.sigmoid()
            update_gate = update_gate.sigmoid()
        param_out = self.norm_out(param_out)
        input_out = self.input_norm_out(input_out)

        if self.activate_out:
            param_out = self.activation(param_out)
            input_out = self.activation(input_out)

        # Gate mechanism. Eq.(5) in original paper.
        # param_out has shape (batch_size, feat_channels, out_channels)
        features = update_gate * param_out.unsqueeze(
            -2) + input_gate * input_out

        features = self.fc_layer(features)
        features = self.fc_norm(features)
        features = self.activation(features)

        return features


@MODELS.register_module()
class KernelUpdateHead(nn.Module):
    """Kernel Update Head in K-Net.

    Args:
        num_classes (int): Number of classes. Default: 150.
        num_ffn_fcs (int): The number of fully-connected layers in
            FFNs. Default: 2.
        num_heads (int): The number of parallel attention heads.
            Default: 8.
        num_mask_fcs (int): The number of fully connected layers for
            mask prediction. Default: 3.
        feedforward_channels (int): The hidden dimension of FFNs.
            Defaults: 2048.
        in_channels (int): The number of channels of input feature map.
            Default: 256.
        out_channels (int): The number of output channels.
            Default: 256.
        dropout (float): The probability of an element to be
            zeroed in MultiheadAttention and FFN. Default 0.0.
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        ffn_act_cfg (dict): Config of activation layers in FFN.
            Default: dict(type='ReLU').
        conv_kernel_size (int): The kernel size of convolution in
            Kernel Update Head for dynamic kernel updates.
            Default: 1.
        feat_transform_cfg (dict | None): Config of feature transform.
            Default: None.
        kernel_init (bool): Whether to initialize the mask kernel in the
            mask head. Default: False.
        with_ffn (bool): Whether to add an FFN in the kernel update head.
            Default: True.
        feat_gather_stride (int): Stride of convolution in feature transform.
            Default: 1.
        mask_transform_stride (int): Stride of mask transform.
            Default: 1.
        kernel_updator_cfg (dict): Config of kernel updator.
            Default: dict(
                type='DynamicConv',
                in_channels=256,
                feat_channels=64,
                out_channels=256,
                act_cfg=dict(type='ReLU', inplace=True),
                norm_cfg=dict(type='LN')).
    """

    def __init__(self,
                 num_classes=150,
                 num_ffn_fcs=2,
                 num_heads=8,
                 num_mask_fcs=3,
                 feedforward_channels=2048,
                 in_channels=256,
                 out_channels=256,
                 dropout=0.0,
                 act_cfg=dict(type='ReLU', inplace=True),
                 ffn_act_cfg=dict(type='ReLU', inplace=True),
                 conv_kernel_size=1,
                 feat_transform_cfg=None,
                 kernel_init=False,
                 with_ffn=True,
                 feat_gather_stride=1,
                 mask_transform_stride=1,
                 kernel_updator_cfg=dict(
                     type='DynamicConv',
                     in_channels=256,
                     feat_channels=64,
                     out_channels=256,
                     act_cfg=dict(type='ReLU', inplace=True),
                     norm_cfg=dict(type='LN'))):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.fp16_enabled = False
        self.dropout = dropout
        self.num_heads = num_heads
        self.kernel_init = kernel_init
        self.with_ffn = with_ffn
        self.conv_kernel_size = conv_kernel_size
        self.feat_gather_stride = feat_gather_stride
        self.mask_transform_stride = mask_transform_stride

        self.attention = MultiheadAttention(in_channels * conv_kernel_size**2,
                                            num_heads, dropout)
        self.attention_norm = build_norm_layer(
            dict(type='LN'), in_channels * conv_kernel_size**2)[1]
        self.kernel_update_conv = build_transformer_layer(kernel_updator_cfg)

        if feat_transform_cfg is not None:
            kernel_size = feat_transform_cfg.pop('kernel_size', 1)
            transform_channels = in_channels
            self.feat_transform = ConvModule(
                transform_channels,
                in_channels,
                kernel_size,
                stride=feat_gather_stride,
                padding=int(feat_gather_stride // 2),
                **feat_transform_cfg)
        else:
            self.feat_transform = None

        if self.with_ffn:
            self.ffn = FFN(
                in_channels,
                feedforward_channels,
                num_ffn_fcs,
                act_cfg=ffn_act_cfg,
                dropout=dropout)
            self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1]

        self.mask_fcs = nn.ModuleList()
        for _ in range(num_mask_fcs):
            self.mask_fcs.append(
                nn.Linear(in_channels, in_channels, bias=False))
            self.mask_fcs.append(
                build_norm_layer(dict(type='LN'), in_channels)[1])
            self.mask_fcs.append(build_activation_layer(act_cfg))

        self.fc_mask = nn.Linear(in_channels, out_channels)

    def init_weights(self):
        """Use xavier initialization for all weight parameters and set the
        classification head bias to a specific value when using focal loss."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
            else:
                # adopt the default initialization for
                # the weight and bias of the layer norm
                pass
        if self.kernel_init:
            print_log(
                'mask kernel in mask head is normal initialized by std 0.01')
            nn.init.normal_(self.fc_mask.weight, mean=0, std=0.01)

    def forward(self, x, proposal_feat, mask_preds, mask_shape=None):
        """Forward function of Dynamic Instance Interactive Head.

        Args:
            x (Tensor): Feature map from FPN with shape
                (batch_size, feature_dimensions, H, W).
            proposal_feat (Tensor): Intermediate feature get from
                diihead in last stage, has shape
                (batch_size, num_proposals, feature_dimensions).
            mask_preds (Tensor): mask prediction from the former stage in shape
                (batch_size, num_proposals, H, W).

        Returns:
            Tuple: The first tensor is predicted mask with shape
            (N, num_classes, H, W), the second tensor is dynamic kernel
            with shape (N, num_classes, channels, K, K).
        """
        N, num_proposals = proposal_feat.shape[:2]
        if self.feat_transform is not None:
            x = self.feat_transform(x)

        C, H, W = x.shape[-3:]

        mask_h, mask_w = mask_preds.shape[-2:]
        if mask_h != H or mask_w != W:
            gather_mask = F.interpolate(
                mask_preds, (H, W), align_corners=False, mode='bilinear')
        else:
            gather_mask = mask_preds

        sigmoid_masks = gather_mask.softmax(dim=1)

        # Group Feature Assembling. Eq.(3) in original paper.
        # einsum is faster than bmm by 30%
        x_feat = torch.einsum('bnhw,bchw->bnc', sigmoid_masks, x)

        # obj_feat in shape [B, N, C, K, K] -> [B, N, C, K*K] -> [B, N, K*K, C]
        proposal_feat = proposal_feat.reshape(N, num_proposals,
                                              self.in_channels,
                                              -1).permute(0, 1, 3, 2)
        obj_feat = self.kernel_update_conv(x_feat, proposal_feat)

        # [B, N, K*K, C] -> [B, N, K*K*C] -> [N, B, K*K*C]
        obj_feat = obj_feat.reshape(N, num_proposals, -1).permute(1, 0, 2)
        obj_feat = self.attention_norm(self.attention(obj_feat))
        # [N, B, K*K*C] -> [B, N, K*K*C]
        obj_feat = obj_feat.permute(1, 0, 2)

        # obj_feat in shape [B, N, K*K*C] -> [B, N, K*K, C]
        obj_feat = obj_feat.reshape(N, num_proposals, -1, self.in_channels)

        # FFN
        if self.with_ffn:
            obj_feat = self.ffn_norm(self.ffn(obj_feat))

        mask_feat = obj_feat

        for reg_layer in self.mask_fcs:
            mask_feat = reg_layer(mask_feat)

        # [B, N, K*K, C] -> [B, N, C, K*K]
        mask_feat = self.fc_mask(mask_feat).permute(0, 1, 3, 2)

        if (self.mask_transform_stride == 2 and self.feat_gather_stride == 1):
            mask_x = F.interpolate(
                x, scale_factor=0.5, mode='bilinear', align_corners=False)
            H, W = mask_x.shape[-2:]
        else:
            mask_x = x
        # group conv is 5x faster than unfold and uses about 1/5 memory
        # Group conv vs. unfold vs. concat batch, 2.9ms :13.5ms :3.8ms
        # Group conv vs. unfold vs. concat batch, 278 : 1420 : 369
        # but in real training group conv is slower than concat batch
        # so we keep using concat batch.
        # fold_x = F.unfold(
        #     mask_x,
        #     self.conv_kernel_size,
        #     padding=int(self.conv_kernel_size // 2))
        # mask_feat = mask_feat.reshape(N, num_proposals, -1)
        # new_mask_preds = torch.einsum('bnc,bcl->bnl', mask_feat, fold_x)
        # [B, N, C, K*K] -> [B*N, C, K, K]
        mask_feat = mask_feat.reshape(N, num_proposals, C,
                                      self.conv_kernel_size,
                                      self.conv_kernel_size)
        # [B, C, H, W] -> [1, B*C, H, W]
        new_mask_preds = []
        for i in range(N):
            new_mask_preds.append(
                F.conv2d(
                    mask_x[i:i + 1],
                    mask_feat[i],
                    padding=int(self.conv_kernel_size // 2)))

        new_mask_preds = torch.cat(new_mask_preds, dim=0)
        new_mask_preds = new_mask_preds.reshape(N, num_proposals, H, W)
        if self.mask_transform_stride == 2:
            new_mask_preds = F.interpolate(
                new_mask_preds,
                scale_factor=2,
                mode='bilinear',
                align_corners=False)

        if mask_shape is not None and mask_shape[0] != H:
            new_mask_preds = F.interpolate(
                new_mask_preds,
                mask_shape,
                align_corners=False,
                mode='bilinear')

        return new_mask_preds, obj_feat.permute(0, 1, 3, 2).reshape(
            N, num_proposals, self.in_channels, self.conv_kernel_size,
            self.conv_kernel_size)


@MODELS.register_module()
class IterativeDecodeHead(BaseDecodeHead):
    """K-Net: Towards Unified Image Segmentation.

    This head is the implementation of
    `K-Net <https://arxiv.org/abs/2106.14855>`_.

    Args:
        num_stages (int): The number of stages (kernel update heads)
            in IterativeDecodeHead. Default: 3.
        kernel_generate_head (dict): Config of kernel generate head which
            generates mask predictions, dynamic kernels and class predictions
            for the next kernel update heads.
        kernel_update_head (dict): Config of kernel update head which refines
            dynamic kernels and class predictions iteratively.
    """

    def __init__(self, num_stages, kernel_generate_head, kernel_update_head,
                 **kwargs):
        # ``IterativeDecodeHead`` would skip initialization of
        # ``BaseDecodeHead`` which would be called when building
        # ``self.kernel_generate_head``.
        super(BaseDecodeHead, self).__init__(**kwargs)
        assert num_stages == len(kernel_update_head)
        self.num_stages = num_stages
        self.kernel_generate_head = MODELS.build(kernel_generate_head)
        self.kernel_update_head = nn.ModuleList()
        self.align_corners = self.kernel_generate_head.align_corners
        self.num_classes = self.kernel_generate_head.num_classes
        self.input_transform = self.kernel_generate_head.input_transform
        self.ignore_index = self.kernel_generate_head.ignore_index
        self.out_channels = self.num_classes

        for head_cfg in kernel_update_head:
            self.kernel_update_head.append(MODELS.build(head_cfg))

    def forward(self, inputs):
        """Forward function."""
        feats = self.kernel_generate_head._forward_feature(inputs)
        sem_seg = self.kernel_generate_head.cls_seg(feats)
        seg_kernels = self.kernel_generate_head.conv_seg.weight.clone()
        seg_kernels = seg_kernels[None].expand(
            feats.size(0), *seg_kernels.size())

        stage_segs = [sem_seg]
        for i in range(self.num_stages):
            sem_seg, seg_kernels = self.kernel_update_head[i](feats,
                                                              seg_kernels,
                                                              sem_seg)
            stage_segs.append(sem_seg)
        if self.training:
            return stage_segs
        # only return the prediction of the last stage during testing
        return stage_segs[-1]

    def loss_by_feat(self, seg_logits: List[Tensor],
                     batch_data_samples: SampleList, **kwargs) -> dict:
        losses = dict()
        for i, logit in enumerate(seg_logits):
            loss = self.kernel_generate_head.loss_by_feat(
                logit, batch_data_samples)
            for k, v in loss.items():
                losses[f'{k}.s{i}'] = v

        return losses
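The group feature assembling at the top of KernelUpdateHead.forward (the Eq.(3) step) is just a mask-weighted pooling. An editorial sketch with illustrative sizes:

import torch

b, n_prop, c, h, w = 1, 3, 8, 16, 16
x = torch.randn(b, c, h, w)                 # FPN feature
mask_preds = torch.randn(b, n_prop, h, w)   # one mask logit map per proposal

weights = mask_preds.softmax(dim=1)         # normalize across proposals
x_feat = torch.einsum('bnhw,bchw->bnc', weights, x)
print(x_feat.shape)  # (1, 3, 8): one c-dim descriptor per proposal kernel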
91
finetune/mmseg/models/decode_heads/lraspp_head.py
Normal file
@@ -0,0 +1,91 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.utils import is_tuple_of

from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead


@MODELS.register_module()
class LRASPPHead(BaseDecodeHead):
    """Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3.

    This head is the improved implementation of `Searching for MobileNetV3
    <https://ieeexplore.ieee.org/document/9008835>`_.

    Args:
        branch_channels (tuple[int]): The number of output channels of
            each branch. Default: (32, 64).
    """

    def __init__(self, branch_channels=(32, 64), **kwargs):
        super().__init__(**kwargs)
        if self.input_transform != 'multiple_select':
            raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform '
                             f'must be \'multiple_select\'. But received '
                             f'\'{self.input_transform}\'')
        assert is_tuple_of(branch_channels, int)
        assert len(branch_channels) == len(self.in_channels) - 1
        self.branch_channels = branch_channels

        self.convs = nn.Sequential()
        self.conv_ups = nn.Sequential()
        for i in range(len(branch_channels)):
            self.convs.add_module(
                f'conv{i}',
                nn.Conv2d(
                    self.in_channels[i], branch_channels[i], 1, bias=False))
            self.conv_ups.add_module(
                f'conv_up{i}',
                ConvModule(
                    self.channels + branch_channels[i],
                    self.channels,
                    1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    bias=False))

        self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1)

        self.aspp_conv = ConvModule(
            self.in_channels[-1],
            self.channels,
            1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            bias=False)
        self.image_pool = nn.Sequential(
            nn.AvgPool2d(kernel_size=49, stride=(16, 20)),
            ConvModule(
                self.in_channels[2],
                self.channels,
                1,
                act_cfg=dict(type='Sigmoid'),
                bias=False))

    def forward(self, inputs):
        """Forward function."""
        inputs = self._transform_inputs(inputs)

        x = inputs[-1]

        x = self.aspp_conv(x) * resize(
            self.image_pool(x),
            size=x.size()[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        x = self.conv_up_input(x)

        for i in range(len(self.branch_channels) - 1, -1, -1):
            x = resize(
                x,
                size=inputs[i].size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            x = torch.cat([x, self.convs[i](inputs[i])], 1)
            x = self.conv_ups[i](x)

        return self.cls_seg(x)
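The heart of LRASPPHead.forward is an attention-gated multiply: the 1x1 ASPP branch is modulated by a sigmoid image-pool branch resized to match. An editorial sketch with stand-in tensors (in the head both branches are ConvModules):

import torch
import torch.nn.functional as F

aspp = torch.randn(1, 8, 32, 32)                 # stand-in for self.aspp_conv(x)
pooled = torch.sigmoid(torch.randn(1, 8, 2, 2))  # stand-in for self.image_pool(x)

gate = F.interpolate(pooled, size=aspp.shape[2:], mode='bilinear',
                     align_corners=False)
out = aspp * gate                                # element-wise attention gating
print(out.shape)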
163
finetune/mmseg/models/decode_heads/mask2former_head.py
Normal file
@@ -0,0 +1,163 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule

try:
    from mmdet.models.dense_heads import \
        Mask2FormerHead as MMDET_Mask2FormerHead
except ModuleNotFoundError:
    MMDET_Mask2FormerHead = BaseModule

from mmengine.structures import InstanceData
from torch import Tensor

from mmseg.registry import MODELS
from mmseg.structures.seg_data_sample import SegDataSample
from mmseg.utils import ConfigType, SampleList


@MODELS.register_module()
class Mask2FormerHead(MMDET_Mask2FormerHead):
    """Implements the Mask2Former head.

    See `Mask2Former: Masked-attention Mask Transformer for Universal Image
    Segmentation <https://arxiv.org/abs/2112.01527>`_ for details.

    Args:
        num_classes (int): Number of classes. Default: 150.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        ignore_index (int): The label index to be ignored. Default: 255.
    """

    def __init__(self,
                 num_classes,
                 align_corners=False,
                 ignore_index=255,
                 **kwargs):
        super().__init__(**kwargs)

        self.num_classes = num_classes
        self.align_corners = align_corners
        self.out_channels = num_classes
        self.ignore_index = ignore_index

        feat_channels = kwargs['feat_channels']
        self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)

    def _seg_data_to_instance_data(self, batch_data_samples: SampleList):
        """Perform forward propagation to convert paradigm from MMSegmentation
        to MMDetection to ensure ``MMDET_Mask2FormerHead`` could be called
        normally. Specifically, ``batch_gt_instances`` would be added.

        Args:
            batch_data_samples (List[:obj:`SegDataSample`]): The Data
                Samples. It usually includes information such as
                ``gt_sem_seg``.

        Returns:
            tuple[Tensor]: A tuple contains two lists.

                - batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                    gt_instance. It usually includes ``labels``, each is
                    unique ground truth label id of images, with
                    shape (num_gt, ) and ``masks``, each is ground truth
                    masks of each instance of an image, shape (num_gt, h, w).
                - batch_img_metas (list[dict]): List of image meta information.
        """
        batch_img_metas = []
        batch_gt_instances = []

        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            gt_sem_seg = data_sample.gt_sem_seg.data
            classes = torch.unique(
                gt_sem_seg,
                sorted=False,
                return_inverse=False,
                return_counts=False)

            # remove ignored region
            gt_labels = classes[classes != self.ignore_index]

            masks = []
            for class_id in gt_labels:
                masks.append(gt_sem_seg == class_id)

            if len(masks) == 0:
                gt_masks = torch.zeros(
                    (0, gt_sem_seg.shape[-2],
                     gt_sem_seg.shape[-1])).to(gt_sem_seg).long()
            else:
                gt_masks = torch.stack(masks).squeeze(1).long()

            instance_data = InstanceData(labels=gt_labels, masks=gt_masks)
            batch_gt_instances.append(instance_data)
        return batch_gt_instances, batch_img_metas

    def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList,
             train_cfg: ConfigType) -> dict:
        """Perform forward propagation and loss calculation of the decoder
        head on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Multi-level features from the upstream
                network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`SegDataSample`]): The Data
                Samples. It usually includes information such as
                ``gt_sem_seg``.
            train_cfg (ConfigType): Training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components.
        """
        # batch SegDataSample to InstanceDataSample
        batch_gt_instances, batch_img_metas = self._seg_data_to_instance_data(
            batch_data_samples)

        # forward
        all_cls_scores, all_mask_preds = self(x, batch_data_samples)

        # loss
        losses = self.loss_by_feat(all_cls_scores, all_mask_preds,
                                   batch_gt_instances, batch_img_metas)

        return losses

    def predict(self, x: Tuple[Tensor], batch_img_metas: List[dict],
                test_cfg: ConfigType) -> Tuple[Tensor]:
        """Test without augmentation.

        Args:
            x (tuple[Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            batch_img_metas (List[:obj:`SegDataSample`]): The Data
                Samples. It usually includes information such as
                ``gt_sem_seg``.
            test_cfg (ConfigType): Test config.

        Returns:
            Tensor: A tensor of segmentation mask.
        """
        batch_data_samples = [
            SegDataSample(metainfo=metainfo) for metainfo in batch_img_metas
        ]

        all_cls_scores, all_mask_preds = self(x, batch_data_samples)
        mask_cls_results = all_cls_scores[-1]
        mask_pred_results = all_mask_preds[-1]
        if 'pad_shape' in batch_img_metas[0]:
            size = batch_img_metas[0]['pad_shape']
        else:
            size = batch_img_metas[0]['img_shape']
        # upsample mask
        mask_pred_results = F.interpolate(
            mask_pred_results, size=size, mode='bilinear', align_corners=False)
        cls_score = F.softmax(mask_cls_results, dim=-1)[..., :-1]
        mask_pred = mask_pred_results.sigmoid()
        seg_logits = torch.einsum('bqc,bqhw->bchw', cls_score, mask_pred)
        return seg_logits
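The semantic inference at the end of predict() mixes each query's class distribution with its mask via one einsum. An editorial sketch with assumed sizes (Q queries, one extra "no object" class slot that gets dropped):

import torch
import torch.nn.functional as F

b, q, num_cls, h, w = 2, 100, 2, 64, 64
mask_cls = torch.randn(b, q, num_cls + 1)   # last slot = "no object"
mask_pred = torch.randn(b, q, h, w)

cls_score = F.softmax(mask_cls, dim=-1)[..., :-1]            # (b, q, num_cls)
seg_logits = torch.einsum('bqc,bqhw->bchw', cls_score,
                          mask_pred.sigmoid())               # (b, num_cls, h, w)
print(seg_logits.shape)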
174
finetune/mmseg/models/decode_heads/maskformer_head.py
Normal file
@@ -0,0 +1,174 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule

try:
    from mmdet.models.dense_heads import MaskFormerHead as MMDET_MaskFormerHead
except ModuleNotFoundError:
    MMDET_MaskFormerHead = BaseModule

from mmengine.structures import InstanceData
from torch import Tensor

from mmseg.registry import MODELS
from mmseg.structures.seg_data_sample import SegDataSample
from mmseg.utils import ConfigType, SampleList


@MODELS.register_module()
class MaskFormerHead(MMDET_MaskFormerHead):
    """Implements the MaskFormer head.

    See `Per-Pixel Classification is Not All You Need for Semantic
    Segmentation <https://arxiv.org/pdf/2107.06278>`_ for details.

    Args:
        num_classes (int): Number of classes. Default: 150.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        ignore_index (int): The label index to be ignored. Default: 255.
    """

    def __init__(self,
                 num_classes: int = 150,
                 align_corners: bool = False,
                 ignore_index: int = 255,
                 **kwargs) -> None:
        super().__init__(**kwargs)

        self.num_classes = num_classes
        self.align_corners = align_corners
        self.out_channels = num_classes
        self.ignore_index = ignore_index

        feat_channels = kwargs['feat_channels']
        self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)

    def _seg_data_to_instance_data(self, batch_data_samples: SampleList):
        """Perform forward propagation to convert paradigm from MMSegmentation
        to MMDetection to ensure ``MMDET_MaskFormerHead`` could be called
        normally. Specifically, ``batch_gt_instances`` would be added.

        Args:
            batch_data_samples (List[:obj:`SegDataSample`]): The Data
                Samples. It usually includes information such as
                ``gt_sem_seg``.

        Returns:
            tuple[Tensor]: A tuple contains two lists.

                - batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                    gt_instance. It usually includes ``labels``, each is
                    unique ground truth label id of images, with
                    shape (num_gt, ) and ``masks``, each is ground truth
                    masks of each instance of an image, shape (num_gt, h, w).
                - batch_img_metas (list[dict]): List of image meta information.
        """
        batch_img_metas = []
        batch_gt_instances = []
        for data_sample in batch_data_samples:
            # Add `batch_input_shape` in metainfo of data_sample, which would
            # be used in MaskFormerHead of MMDetection.
            metainfo = data_sample.metainfo
            metainfo['batch_input_shape'] = metainfo['img_shape']
            data_sample.set_metainfo(metainfo)
            batch_img_metas.append(data_sample.metainfo)
            gt_sem_seg = data_sample.gt_sem_seg.data
            classes = torch.unique(
                gt_sem_seg,
                sorted=False,
                return_inverse=False,
                return_counts=False)

            # remove ignored region
            gt_labels = classes[classes != self.ignore_index]

            masks = []
            for class_id in gt_labels:
                masks.append(gt_sem_seg == class_id)

            if len(masks) == 0:
                gt_masks = torch.zeros((0, gt_sem_seg.shape[-2],
                                        gt_sem_seg.shape[-1])).to(gt_sem_seg)
            else:
                gt_masks = torch.stack(masks).squeeze(1)

            instance_data = InstanceData(
                labels=gt_labels, masks=gt_masks.long())
            batch_gt_instances.append(instance_data)
        return batch_gt_instances, batch_img_metas

    def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList,
             train_cfg: ConfigType) -> dict:
        """Perform forward propagation and loss calculation of the decoder
        head on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Multi-level features from the upstream
                network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`SegDataSample`]): The Data
                Samples. It usually includes information such as
                ``gt_sem_seg``.
            train_cfg (ConfigType): Training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components.
        """
        # batch SegDataSample to InstanceDataSample
        batch_gt_instances, batch_img_metas = self._seg_data_to_instance_data(
            batch_data_samples)

        # forward
        all_cls_scores, all_mask_preds = self(x, batch_data_samples)

        # loss
        losses = self.loss_by_feat(all_cls_scores, all_mask_preds,
                                   batch_gt_instances, batch_img_metas)

        return losses

    def predict(self, x: Tuple[Tensor], batch_img_metas: List[dict],
                test_cfg: ConfigType) -> Tuple[Tensor]:
        """Test without augmentation.

        Args:
            x (tuple[Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            batch_img_metas (List[:obj:`SegDataSample`]): The Data
                Samples. It usually includes information such as
                ``gt_sem_seg``.
            test_cfg (ConfigType): Test config.

        Returns:
            Tensor: A tensor of segmentation mask.
        """

        batch_data_samples = []
        for metainfo in batch_img_metas:
            metainfo['batch_input_shape'] = metainfo['img_shape']
            batch_data_samples.append(SegDataSample(metainfo=metainfo))
        # Forward function of MaskFormerHead from MMDetection needs
        # 'batch_data_samples' as inputs, which is image shape actually.
        all_cls_scores, all_mask_preds = self(x, batch_data_samples)
        mask_cls_results = all_cls_scores[-1]
        mask_pred_results = all_mask_preds[-1]

        # upsample masks
        img_shape = batch_img_metas[0]['batch_input_shape']
        mask_pred_results = F.interpolate(
            mask_pred_results,
            size=img_shape,
            mode='bilinear',
            align_corners=False)

        # semantic inference
        cls_score = F.softmax(mask_cls_results, dim=-1)[..., :-1]
        mask_pred = mask_pred_results.sigmoid()
        seg_logits = torch.einsum('bqc,bqhw->bchw', cls_score, mask_pred)
        return seg_logits
50
finetune/mmseg/models/decode_heads/nl_head.py
Normal file
@@ -0,0 +1,50 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn import NonLocal2d

from mmseg.registry import MODELS
from .fcn_head import FCNHead


@MODELS.register_module()
class NLHead(FCNHead):
    """Non-local Neural Networks.

    This head is the implementation of `NLNet
    <https://arxiv.org/abs/1711.07971>`_.

    Args:
        reduction (int): Reduction factor of projection transform. Default: 2.
        use_scale (bool): Whether to scale pairwise_weight by
            sqrt(1/inter_channels). Default: True.
        mode (str): The nonlocal mode. Options are 'embedded_gaussian',
            'dot_product'. Default: 'embedded_gaussian'.
    """

    def __init__(self,
                 reduction=2,
                 use_scale=True,
                 mode='embedded_gaussian',
                 **kwargs):
        super().__init__(num_convs=2, **kwargs)
        self.reduction = reduction
        self.use_scale = use_scale
        self.mode = mode
        self.nl_block = NonLocal2d(
            in_channels=self.channels,
            reduction=self.reduction,
            use_scale=self.use_scale,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            mode=self.mode)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        output = self.convs[0](x)
        output = self.nl_block(output)
        output = self.convs[1](output)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output
127
finetune/mmseg/models/decode_heads/ocr_head.py
Normal file
@@ -0,0 +1,127 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from mmseg.registry import MODELS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from ..utils import resize
from .cascade_decode_head import BaseCascadeDecodeHead


class SpatialGatherModule(nn.Module):
    """Aggregate the context features according to the initial predicted
    probability distribution.

    Employ the soft-weighted method to aggregate the context.
    """

    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, feats, probs):
        """Forward function."""
        batch_size, num_classes, height, width = probs.size()
        channels = feats.size(1)
        probs = probs.view(batch_size, num_classes, -1)
        feats = feats.view(batch_size, channels, -1)
        # [batch_size, height*width, channels]
        feats = feats.permute(0, 2, 1)
        # [batch_size, num_classes, height*width]
        probs = F.softmax(self.scale * probs, dim=2)
        # [batch_size, num_classes, channels]
        ocr_context = torch.matmul(probs, feats)
        ocr_context = ocr_context.permute(0, 2, 1).contiguous().unsqueeze(3)
        return ocr_context


class ObjectAttentionBlock(_SelfAttentionBlock):
    """Make an OCR-style SelfAttentionBlock."""

    def __init__(self, in_channels, channels, scale, conv_cfg, norm_cfg,
                 act_cfg):
        if scale > 1:
            query_downsample = nn.MaxPool2d(kernel_size=scale)
        else:
            query_downsample = None
        super().__init__(
            key_in_channels=in_channels,
            query_in_channels=in_channels,
            channels=channels,
            out_channels=in_channels,
            share_key_query=False,
            query_downsample=query_downsample,
            key_downsample=None,
            key_query_num_convs=2,
            key_query_norm=True,
            value_out_num_convs=1,
            value_out_norm=True,
            matmul_norm=True,
            with_out=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.bottleneck = ConvModule(
            in_channels * 2,
            in_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, query_feats, key_feats):
        """Forward function."""
        context = super().forward(query_feats, key_feats)
        if self.query_downsample is not None:
            # restore the query resolution before fusing with query_feats;
            # applying the resize after the concat (as the code originally
            # did) would fail for scale > 1
            context = resize(context, size=query_feats.shape[2:])
        output = self.bottleneck(torch.cat([context, query_feats], dim=1))

        return output


@MODELS.register_module()
class OCRHead(BaseCascadeDecodeHead):
    """Object-Contextual Representations for Semantic Segmentation.

    This head is the implementation of `OCRNet
    <https://arxiv.org/abs/1909.11065>`_.

    Args:
        ocr_channels (int): The intermediate channels of OCR block.
        scale (int): The scale of probability map in SpatialGatherModule.
            Default: 1.
    """

    def __init__(self, ocr_channels, scale=1, **kwargs):
        super().__init__(**kwargs)
        self.ocr_channels = ocr_channels
        self.scale = scale
        self.object_context_block = ObjectAttentionBlock(
            self.channels,
            self.ocr_channels,
            self.scale,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.spatial_gather_module = SpatialGatherModule(self.scale)

        self.bottleneck = ConvModule(
            self.in_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs, prev_output):
        """Forward function."""
        x = self._transform_inputs(inputs)
        feats = self.bottleneck(x)
        context = self.spatial_gather_module(feats, prev_output)
        object_context = self.object_context_block(feats, context)
        output = self.cls_seg(object_context)

        return output
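SpatialGatherModule is a soft class-wise pooling: every class collects a feature vector as a probability-weighted average over pixels. An editorial sketch with illustrative sizes, mirroring the shapes in forward():

import torch
import torch.nn.functional as F

b, c, k, h, w = 1, 8, 2, 4, 4
feats = torch.randn(b, c, h, w)
probs = torch.randn(b, k, h, w)             # coarse logits from the previous head

p = F.softmax(probs.view(b, k, -1), dim=2)  # (b, k, h*w)
f = feats.view(b, c, -1).permute(0, 2, 1)   # (b, h*w, c)
ocr_context = torch.matmul(p, f)            # (b, k, c): one vector per class
print(ocr_context.permute(0, 2, 1).unsqueeze(3).shape)  # (b, c, k, 1)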
183
finetune/mmseg/models/decode_heads/pid_head.py
Normal file
@@ -0,0 +1,183 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from mmengine.model import BaseModule
from torch import Tensor

from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.losses import accuracy
from mmseg.models.utils import resize
from mmseg.registry import MODELS
from mmseg.utils import OptConfigType, SampleList


class BasePIDHead(BaseModule):
    """Base class for PID head.

    Args:
        in_channels (int): Number of input channels.
        channels (int): Number of output channels.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
        init_cfg (dict or list[dict], optional): Init config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg)
        self.conv = ConvModule(
            in_channels,
            channels,
            kernel_size=3,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            order=('norm', 'act', 'conv'))
        _, self.norm = build_norm_layer(norm_cfg, num_features=channels)
        self.act = build_activation_layer(act_cfg)

    def forward(self, x: Tensor, cls_seg: Optional[nn.Module]) -> Tensor:
        """Forward function.
        Args:
            x (Tensor): Input tensor.
            cls_seg (nn.Module, optional): The classification head.

        Returns:
            Tensor: Output tensor.
        """
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        if cls_seg is not None:
            x = cls_seg(x)
        return x
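
# Design note: with order=('norm', 'act', 'conv') the ConvModule runs
# BN -> ReLU -> 3x3 conv (pre-activation), and forward() then applies one
# more BN -> ReLU before the optional classification conv passed as cls_seg.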


@MODELS.register_module()
class PIDHead(BaseDecodeHead):
    """Decode head for PIDNet.

    Args:
        in_channels (int): Number of input channels.
        channels (int): Number of output channels.
        num_classes (int): Number of classes.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
    """

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 num_classes: int,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 **kwargs):
        super().__init__(
            in_channels,
            channels,
            num_classes=num_classes,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **kwargs)
        self.i_head = BasePIDHead(in_channels, channels, norm_cfg, act_cfg)
        self.p_head = BasePIDHead(in_channels // 2, channels, norm_cfg,
                                  act_cfg)
        self.d_head = BasePIDHead(
            in_channels // 2,
            in_channels // 4,
            norm_cfg,
        )
        self.p_cls_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
        self.d_cls_seg = nn.Conv2d(in_channels // 4, 1, kernel_size=1)

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(
            self,
            inputs: Union[Tensor,
                          Tuple[Tensor]]) -> Union[Tensor, Tuple[Tensor]]:
        """Forward function.
        Args:
            inputs (Tensor | tuple[Tensor]): Input tensor or tuple of
                Tensor. When training, the input is a tuple of three tensors,
                (p_feat, i_feat, d_feat), and the output is a tuple of three
                tensors, (p_seg_logit, i_seg_logit, d_seg_logit).
                At inference, only the integral-branch head is used; the
                input is the integral feature map and the output is the
                segmentation logit.

        Returns:
            Tensor | tuple[Tensor]: Output tensor or tuple of tensors.
        """
        if self.training:
            x_p, x_i, x_d = inputs
            x_p = self.p_head(x_p, self.p_cls_seg)
            x_i = self.i_head(x_i, self.cls_seg)
            x_d = self.d_head(x_d, self.d_cls_seg)
            return x_p, x_i, x_d
        else:
            return self.i_head(inputs, self.cls_seg)

    def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tuple[Tensor]:
        gt_semantic_segs = [
            data_sample.gt_sem_seg.data for data_sample in batch_data_samples
        ]
        gt_edge_segs = [
            data_sample.gt_edge_map.data for data_sample in batch_data_samples
        ]
        gt_sem_segs = torch.stack(gt_semantic_segs, dim=0)
        gt_edge_segs = torch.stack(gt_edge_segs, dim=0)
        return gt_sem_segs, gt_edge_segs

    def loss_by_feat(self, seg_logits: Tuple[Tensor],
                     batch_data_samples: SampleList) -> dict:
        loss = dict()
        p_logit, i_logit, d_logit = seg_logits
        sem_label, bd_label = self._stack_batch_gt(batch_data_samples)
        p_logit = resize(
            input=p_logit,
            size=sem_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        i_logit = resize(
            input=i_logit,
            size=sem_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        d_logit = resize(
            input=d_logit,
            size=bd_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        sem_label = sem_label.squeeze(1)
        bd_label = bd_label.squeeze(1)
        loss['loss_sem_p'] = self.loss_decode[0](
            p_logit, sem_label, ignore_index=self.ignore_index)
        loss['loss_sem_i'] = self.loss_decode[1](i_logit, sem_label)
        loss['loss_bd'] = self.loss_decode[2](d_logit, bd_label)
        filler = torch.ones_like(sem_label) * self.ignore_index
        sem_bd_label = torch.where(
            torch.sigmoid(d_logit[:, 0, :, :]) > 0.8, sem_label, filler)
        loss['loss_sem_bd'] = self.loss_decode[3](i_logit, sem_bd_label)
        loss['acc_seg'] = accuracy(
            i_logit, sem_label, ignore_index=self.ignore_index)
        return loss
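
# Loss note (threshold from the code above): loss_sem_bd re-applies the
# semantic loss only where the boundary branch is confident, i.e. where
# sigmoid(d_logit) > 0.8; every other pixel is set to ignore_index and
# contributes nothing.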
367
finetune/mmseg/models/decode_heads/point_head.py
Normal file
@@ -0,0 +1,367 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py # noqa

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

try:
    from mmcv.ops import point_sample
except ModuleNotFoundError:
    point_sample = None

from typing import List

from mmseg.registry import MODELS
from mmseg.utils import SampleList
from ..losses import accuracy
from ..utils import resize
from .cascade_decode_head import BaseCascadeDecodeHead


def calculate_uncertainty(seg_logits):
    """Estimate uncertainty based on seg logits.

    For each location of the prediction ``seg_logits`` we estimate
    uncertainty as the difference between the top first and top second
    predicted logits.

    Args:
        seg_logits (Tensor): Semantic segmentation logits,
            shape (batch_size, num_classes, height, width).

    Returns:
        scores (Tensor): The uncertainty scores, where the most uncertain
            locations have the highest values, shape
            (batch_size, 1, height, width).
    """
    top2_scores = torch.topk(seg_logits, k=2, dim=1)[0]
    return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)
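
# Illustrative check: for logits [2.0, 0.5, 0.1] at one location, the top-2
# values are (2.0, 0.5) and the score is 0.5 - 2.0 = -1.5; a near-tie such as
# [1.0, 0.99, 0.1] scores about -0.01, i.e. higher and thus more uncertain.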


@MODELS.register_module()
class PointHead(BaseCascadeDecodeHead):
    """A mask point head used in PointRend.

    This head is the implementation of `PointRend: Image Segmentation as
    Rendering <https://arxiv.org/abs/1912.08193>`_.
    ``PointHead`` uses a shared multi-layer perceptron (equivalent to
    nn.Conv1d) to predict the logits of input points. The fine-grained
    feature and coarse feature are concatenated for prediction.

    Args:
        num_fcs (int): Number of fc layers in the head. Default: 3.
        in_channels (int): Number of input channels. Default: 256.
        fc_channels (int): Number of fc channels. Default: 256.
        num_classes (int): Number of classes for logits. Default: 80.
        class_agnostic (bool): Whether to use class-agnostic classification.
            If so, the output channels of logits will be 1. Default: False.
        coarse_pred_each_layer (bool): Whether to concatenate the coarse
            feature with the output of each fc layer. Default: True.
        conv_cfg (dict|None): Dictionary to construct and config conv layer.
            Default: dict(type='Conv1d'))
        norm_cfg (dict|None): Dictionary to construct and config norm layer.
            Default: None.
        loss_point (dict): Dictionary to construct and config loss layer of
            point head. Default: dict(type='CrossEntropyLoss', use_mask=True,
            loss_weight=1.0).
    """

    def __init__(self,
                 num_fcs=3,
                 coarse_pred_each_layer=True,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU', inplace=False),
                 **kwargs):
        super().__init__(
            input_transform='multiple_select',
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            init_cfg=dict(
                type='Normal', std=0.01, override=dict(name='fc_seg')),
            **kwargs)
        if point_sample is None:
            raise RuntimeError('Please install mmcv-full for '
                               'point_sample ops')

        self.num_fcs = num_fcs
        self.coarse_pred_each_layer = coarse_pred_each_layer

        fc_in_channels = sum(self.in_channels) + self.num_classes
        fc_channels = self.channels
        self.fcs = nn.ModuleList()
        for k in range(num_fcs):
            fc = ConvModule(
                fc_in_channels,
                fc_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
            self.fcs.append(fc)
            fc_in_channels = fc_channels
            fc_in_channels += self.num_classes if self.coarse_pred_each_layer \
                else 0
        self.fc_seg = nn.Conv1d(
            fc_in_channels,
            self.num_classes,
            kernel_size=1,
            stride=1,
            padding=0)
        if self.dropout_ratio > 0:
            self.dropout = nn.Dropout(self.dropout_ratio)
        delattr(self, 'conv_seg')

    def cls_seg(self, feat):
        """Classify each pixel with fc."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.fc_seg(feat)
        return output

    def forward(self, fine_grained_point_feats, coarse_point_feats):
        x = torch.cat([fine_grained_point_feats, coarse_point_feats], dim=1)
        for fc in self.fcs:
            x = fc(x)
            if self.coarse_pred_each_layer:
                x = torch.cat((x, coarse_point_feats), dim=1)
        return self.cls_seg(x)

    def _get_fine_grained_point_feats(self, x, points):
        """Sample from fine-grained features.

        Args:
            x (list[Tensor]): Feature pyramid from neck or backbone.
            points (Tensor): Point coordinates, shape (batch_size,
                num_points, 2).

        Returns:
            fine_grained_feats (Tensor): Sampled fine-grained feature,
                shape (batch_size, sum(channels of x), num_points).
        """

        fine_grained_feats_list = [
            point_sample(_, points, align_corners=self.align_corners)
            for _ in x
        ]
        if len(fine_grained_feats_list) > 1:
            fine_grained_feats = torch.cat(fine_grained_feats_list, dim=1)
        else:
            fine_grained_feats = fine_grained_feats_list[0]

        return fine_grained_feats

    def _get_coarse_point_feats(self, prev_output, points):
        """Sample from the coarse prediction of the previous decode head.

        Args:
            prev_output (list[Tensor]): Prediction of previous decode head.
            points (Tensor): Point coordinates, shape (batch_size,
                num_points, 2).

        Returns:
            coarse_feats (Tensor): Sampled coarse feature, shape (batch_size,
                num_classes, num_points).
        """

        coarse_feats = point_sample(
            prev_output, points, align_corners=self.align_corners)

        return coarse_feats

    def loss(self, inputs, prev_output, batch_data_samples: SampleList,
             train_cfg, **kwargs):
        """Forward function for training.
        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            batch_data_samples (list[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `img_metas` or `gt_semantic_seg`.
            train_cfg (dict): The training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        x = self._transform_inputs(inputs)
        with torch.no_grad():
            points = self.get_points_train(
                prev_output, calculate_uncertainty, cfg=train_cfg)
        fine_grained_point_feats = self._get_fine_grained_point_feats(
            x, points)
        coarse_point_feats = self._get_coarse_point_feats(prev_output, points)
        point_logits = self.forward(fine_grained_point_feats,
                                    coarse_point_feats)

        losses = self.loss_by_feat(point_logits, points, batch_data_samples)

        return losses

    def predict(self, inputs, prev_output, batch_img_metas: List[dict],
                test_cfg, **kwargs):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            batch_img_metas (list[dict]): List of image info dict where each
                dict has: 'img_shape', 'scale_factor', 'flip', and may also
                contain 'filename', 'ori_shape', 'pad_shape', and
                'img_norm_cfg'. For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            test_cfg (dict): The testing config.

        Returns:
            Tensor: Output segmentation map.
        """

        x = self._transform_inputs(inputs)
        refined_seg_logits = prev_output.clone()
        for _ in range(test_cfg.subdivision_steps):
            refined_seg_logits = resize(
                refined_seg_logits,
                scale_factor=test_cfg.scale_factor,
                mode='bilinear',
                align_corners=self.align_corners)
            batch_size, channels, height, width = refined_seg_logits.shape
            point_indices, points = self.get_points_test(
                refined_seg_logits, calculate_uncertainty, cfg=test_cfg)
            fine_grained_point_feats = self._get_fine_grained_point_feats(
                x, points)
            coarse_point_feats = self._get_coarse_point_feats(
                prev_output, points)
            point_logits = self.forward(fine_grained_point_feats,
                                        coarse_point_feats)

            point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1)
            refined_seg_logits = refined_seg_logits.reshape(
                batch_size, channels, height * width)
            refined_seg_logits = refined_seg_logits.scatter_(
                2, point_indices, point_logits)
            refined_seg_logits = refined_seg_logits.view(
                batch_size, channels, height, width)

        return self.predict_by_feat(refined_seg_logits, batch_img_metas,
                                    **kwargs)

    def loss_by_feat(self, point_logits, points, batch_data_samples, **kwargs):
        """Compute segmentation loss."""
        gt_semantic_seg = self._stack_batch_gt(batch_data_samples)
        point_label = point_sample(
            gt_semantic_seg.float(),
            points,
            mode='nearest',
            align_corners=self.align_corners)
        point_label = point_label.squeeze(1).long()

        loss = dict()
        if not isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = [self.loss_decode]
        else:
            losses_decode = self.loss_decode
        for loss_module in losses_decode:
            loss['point' + loss_module.loss_name] = loss_module(
                point_logits, point_label, ignore_index=self.ignore_index)

        loss['acc_point'] = accuracy(
            point_logits, point_label, ignore_index=self.ignore_index)
        return loss

    def get_points_train(self, seg_logits, uncertainty_func, cfg):
        """Sample points for training.

        Sample points in [0, 1] x [0, 1] coordinate space based on their
        uncertainty. The uncertainties are calculated for each point using
        the 'uncertainty_func' function that takes the point's logit
        prediction as input.

        Args:
            seg_logits (Tensor): Semantic segmentation logits, shape (
                batch_size, num_classes, height, width).
            uncertainty_func (func): uncertainty calculation function.
            cfg (dict): Training config of point head.

        Returns:
            point_coords (Tensor): A tensor of shape (batch_size, num_points,
                2) that contains the coordinates of ``num_points`` sampled
                points.
        """
        num_points = cfg.num_points
        oversample_ratio = cfg.oversample_ratio
        importance_sample_ratio = cfg.importance_sample_ratio
        assert oversample_ratio >= 1
        assert 0 <= importance_sample_ratio <= 1
        batch_size = seg_logits.shape[0]
        num_sampled = int(num_points * oversample_ratio)
        point_coords = torch.rand(
            batch_size, num_sampled, 2, device=seg_logits.device)
        point_logits = point_sample(seg_logits, point_coords)
        # It is crucial to calculate uncertainty based on the sampled
        # prediction value for the points. Calculating uncertainties of the
        # coarse predictions first and sampling them for points leads to
        # incorrect results. To illustrate this: assume uncertainty_func(
        # logits)=-abs(logits), a sampled point between two coarse
        # predictions with -1 and 1 logits has 0 logits, and therefore 0
        # uncertainty value. However, if we calculate uncertainties for the
        # coarse predictions first, both will have -1 uncertainty,
        # and the sampled point will get -1 uncertainty.
        point_uncertainties = uncertainty_func(point_logits)
        num_uncertain_points = int(importance_sample_ratio * num_points)
        num_random_points = num_points - num_uncertain_points
        idx = torch.topk(
            point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
        shift = num_sampled * torch.arange(
            batch_size, dtype=torch.long, device=seg_logits.device)
        idx += shift[:, None]
        point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
            batch_size, num_uncertain_points, 2)
        if num_random_points > 0:
            rand_point_coords = torch.rand(
                batch_size, num_random_points, 2, device=seg_logits.device)
            point_coords = torch.cat((point_coords, rand_point_coords), dim=1)
        return point_coords
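
    # Worked example (hypothetical cfg values): with num_points=2048,
    # oversample_ratio=3 and importance_sample_ratio=0.75, the method draws
    # 6144 random points, keeps the 1536 most uncertain of them and tops up
    # with 512 fresh random points, returning 2048 coordinates per image.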

    def get_points_test(self, seg_logits, uncertainty_func, cfg):
        """Sample points for testing.

        Find ``num_points`` most uncertain points from ``uncertainty_map``.

        Args:
            seg_logits (Tensor): A tensor of shape (batch_size, num_classes,
                height, width) for class-specific or class-agnostic
                prediction.
            uncertainty_func (func): uncertainty calculation function.
            cfg (dict): Testing config of point head.

        Returns:
            point_indices (Tensor): A tensor of shape (batch_size, num_points)
                that contains indices from [0, height x width) of the most
                uncertain points.
            point_coords (Tensor): A tensor of shape (batch_size, num_points,
                2) that contains [0, 1] x [0, 1] normalized coordinates of the
                most uncertain points from the ``height x width`` grid.
        """

        num_points = cfg.subdivision_num_points
        uncertainty_map = uncertainty_func(seg_logits)
        batch_size, _, height, width = uncertainty_map.shape
        h_step = 1.0 / height
        w_step = 1.0 / width

        uncertainty_map = uncertainty_map.view(batch_size, height * width)
        num_points = min(height * width, num_points)
        point_indices = uncertainty_map.topk(num_points, dim=1)[1]
        point_coords = torch.zeros(
            batch_size,
            num_points,
            2,
            dtype=torch.float,
            device=seg_logits.device)
        point_coords[:, :, 0] = w_step / 2.0 + (point_indices %
                                                width).float() * w_step
        point_coords[:, :, 1] = h_step / 2.0 + (point_indices //
                                                width).float() * h_step
        return point_indices, point_coords
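
    # Index-to-coordinate example (illustrative): on a 4x4 grid,
    # h_step = w_step = 0.25, so flat index 5 maps to the pixel centre
    # (x, y) = (0.125 + (5 % 4) * 0.25, 0.125 + (5 // 4) * 0.25)
    #        = (0.375, 0.375).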
197
finetune/mmseg/models/decode_heads/psa_head.py
Normal file
@@ -0,0 +1,197 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead

try:
    from mmcv.ops import PSAMask
except ModuleNotFoundError:
    PSAMask = None


@MODELS.register_module()
class PSAHead(BaseDecodeHead):
    """Point-wise Spatial Attention Network for Scene Parsing.

    This head is the implementation of `PSANet
    <https://hszhao.github.io/papers/eccv18_psanet.pdf>`_.

    Args:
        mask_size (tuple[int]): The PSA mask size. It usually equals input
            size.
        psa_type (str): The type of psa module. Options are 'collect',
            'distribute', 'bi-direction'. Default: 'bi-direction'
        compact (bool): Whether to use a compact map for 'collect' mode.
            Default: False.
        shrink_factor (int): The downsample factor of the psa mask.
            Default: 2.
        normalization_factor (float): The normalization factor of attention.
            Default: 1.0; if None, it falls back to mask_h * mask_w.
        psa_softmax (bool): Whether to use softmax for attention.
            Default: True.
    """

    def __init__(self,
                 mask_size,
                 psa_type='bi-direction',
                 compact=False,
                 shrink_factor=2,
                 normalization_factor=1.0,
                 psa_softmax=True,
                 **kwargs):
        if PSAMask is None:
            raise RuntimeError('Please install mmcv-full for PSAMask ops')
        super().__init__(**kwargs)
        assert psa_type in ['collect', 'distribute', 'bi-direction']
        self.psa_type = psa_type
        self.compact = compact
        self.shrink_factor = shrink_factor
        self.mask_size = mask_size
        mask_h, mask_w = mask_size
        self.psa_softmax = psa_softmax
        if normalization_factor is None:
            normalization_factor = mask_h * mask_w
        self.normalization_factor = normalization_factor

        self.reduce = ConvModule(
            self.in_channels,
            self.channels,
            kernel_size=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.attention = nn.Sequential(
            ConvModule(
                self.channels,
                self.channels,
                kernel_size=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            nn.Conv2d(
                self.channels, mask_h * mask_w, kernel_size=1, bias=False))
        if psa_type == 'bi-direction':
            self.reduce_p = ConvModule(
                self.in_channels,
                self.channels,
                kernel_size=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            self.attention_p = nn.Sequential(
                ConvModule(
                    self.channels,
                    self.channels,
                    kernel_size=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg),
                nn.Conv2d(
                    self.channels, mask_h * mask_w, kernel_size=1, bias=False))
            self.psamask_collect = PSAMask('collect', mask_size)
            self.psamask_distribute = PSAMask('distribute', mask_size)
        else:
            self.psamask = PSAMask(psa_type, mask_size)
        self.proj = ConvModule(
            self.channels * (2 if psa_type == 'bi-direction' else 1),
            self.in_channels,
            kernel_size=1,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.bottleneck = ConvModule(
            self.in_channels * 2,
            self.channels,
            kernel_size=3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        identity = x
        align_corners = self.align_corners
        if self.psa_type in ['collect', 'distribute']:
            out = self.reduce(x)
            n, c, h, w = out.size()
            if self.shrink_factor != 1:
                if h % self.shrink_factor and w % self.shrink_factor:
                    h = (h - 1) // self.shrink_factor + 1
                    w = (w - 1) // self.shrink_factor + 1
                    align_corners = True
                else:
                    h = h // self.shrink_factor
                    w = w // self.shrink_factor
                    align_corners = False
                out = resize(
                    out,
                    size=(h, w),
                    mode='bilinear',
                    align_corners=align_corners)
            y = self.attention(out)
            if self.compact:
                if self.psa_type == 'collect':
                    y = y.view(n, h * w,
                               h * w).transpose(1, 2).view(n, h * w, h, w)
            else:
                y = self.psamask(y)
            if self.psa_softmax:
                y = F.softmax(y, dim=1)
            out = torch.bmm(
                out.view(n, c, h * w), y.view(n, h * w, h * w)).view(
                    n, c, h, w) * (1.0 / self.normalization_factor)
        else:
            x_col = self.reduce(x)
            x_dis = self.reduce_p(x)
            n, c, h, w = x_col.size()
            if self.shrink_factor != 1:
                if h % self.shrink_factor and w % self.shrink_factor:
                    h = (h - 1) // self.shrink_factor + 1
                    w = (w - 1) // self.shrink_factor + 1
                    align_corners = True
                else:
                    h = h // self.shrink_factor
                    w = w // self.shrink_factor
                    align_corners = False
                x_col = resize(
                    x_col,
                    size=(h, w),
                    mode='bilinear',
                    align_corners=align_corners)
                x_dis = resize(
                    x_dis,
                    size=(h, w),
                    mode='bilinear',
                    align_corners=align_corners)
            y_col = self.attention(x_col)
            y_dis = self.attention_p(x_dis)
            if self.compact:
                y_dis = y_dis.view(n, h * w,
                                   h * w).transpose(1, 2).view(n, h * w, h, w)
            else:
                y_col = self.psamask_collect(y_col)
                y_dis = self.psamask_distribute(y_dis)
            if self.psa_softmax:
                y_col = F.softmax(y_col, dim=1)
                y_dis = F.softmax(y_dis, dim=1)
            x_col = torch.bmm(
                x_col.view(n, c, h * w), y_col.view(n, h * w, h * w)).view(
                    n, c, h, w) * (1.0 / self.normalization_factor)
            x_dis = torch.bmm(
                x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w)).view(
                    n, c, h, w) * (1.0 / self.normalization_factor)
            out = torch.cat([x_col, x_dis], 1)
        out = self.proj(out)
        out = resize(
            out,
            size=identity.shape[2:],
            mode='bilinear',
            align_corners=align_corners)
        out = self.bottleneck(torch.cat((identity, out), dim=1))
        out = self.cls_seg(out)
        return out
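
# Shape note (illustrative, assuming the mask size matches the shrunk map):
# with n=1, c=512 and a 32x32 shrunk map, the attention y has shape
# [1, 1024, 32, 32]; viewed as [1, 1024, 1024] and multiplied via torch.bmm
# with out.view(1, 512, 1024), every position receives a weighted sum over
# all 1024 positions -- the point-wise spatial attention.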
117
finetune/mmseg/models/decode_heads/psp_head.py
Normal file
@@ -0,0 +1,117 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead


class PPM(nn.ModuleList):
    """Pooling Pyramid Module used in PSPNet.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
        align_corners (bool): align_corners argument of F.interpolate.
    """

    def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
                 act_cfg, align_corners, **kwargs):
        super().__init__()
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        for pool_scale in pool_scales:
            self.append(
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(pool_scale),
                    ConvModule(
                        self.in_channels,
                        self.channels,
                        1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg,
                        **kwargs)))

    def forward(self, x):
        """Forward function."""
        ppm_outs = []
        for ppm in self:
            ppm_out = ppm(x)
            upsampled_ppm_out = resize(
                ppm_out,
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            ppm_outs.append(upsampled_ppm_out)
        return ppm_outs
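
# Pyramid example (illustrative): with pool_scales=(1, 2, 3, 6) and a 64x64
# input, AdaptiveAvgPool2d yields 1x1, 2x2, 3x3 and 6x6 maps; each is
# projected to `channels` by a 1x1 ConvModule and bilinearly upsampled back
# to 64x64 before concatenation in PSPHead.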


@MODELS.register_module()
class PSPHead(BaseDecodeHead):
    """Pyramid Scene Parsing Network.

    This head is the implementation of
    `PSPNet <https://arxiv.org/abs/1612.01105>`_.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module. Default: (1, 2, 3, 6).
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super().__init__(**kwargs)
        assert isinstance(pool_scales, (list, tuple))
        self.pool_scales = pool_scales
        self.psp_modules = PPM(
            self.pool_scales,
            self.in_channels,
            self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)
        self.bottleneck = ConvModule(
            self.in_channels + len(pool_scales) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def _forward_feature(self, inputs):
        """Forward function for feature maps before classifying each pixel
        with the ``self.cls_seg`` fc.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W) which is the feature map of the last decoder layer.
        """
        x = self._transform_inputs(inputs)
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x))
        psp_outs = torch.cat(psp_outs, dim=1)
        feats = self.bottleneck(psp_outs)
        return feats

    def forward(self, inputs):
        """Forward function."""
        output = self._forward_feature(inputs)
        output = self.cls_seg(output)
        return output
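
# Channel arithmetic (illustrative values, not from this commit's configs):
# with in_channels=2048, channels=512 and four pool scales, the bottleneck
# consumes 2048 + 4 * 512 = 4096 channels and reduces them to 512 before
# classification.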
736
finetune/mmseg/models/decode_heads/san_head.py
Normal file
@@ -0,0 +1,736 @@
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer
from mmcv.cnn.bricks.transformer import BaseTransformerLayer
from mmcv.ops import point_sample
from mmengine.dist import all_reduce
from mmengine.model.weight_init import (caffe2_xavier_init, normal_init,
                                        trunc_normal_)
from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
from mmengine.structures import InstanceData
from torch import Tensor
from torch.nn import functional as F

from mmseg.models.backbones.vit import TransformerEncoderLayer
from mmseg.registry import MODELS
from mmseg.utils import (ConfigType, MatchMasks, SampleList,
                         seg_data_to_instance_data)
from ..utils import (MLP, LayerNorm2d, PatchEmbed, cross_attn_layer,
                     get_uncertain_point_coords_with_randomness, resize)
from .decode_head import BaseDecodeHead


class MLPMaskDecoder(nn.Module):
    """Module for decoding query and visual features with MLP layers to
    generate the attention biases and the mask proposals."""

    def __init__(
        self,
        *,
        in_channels: int,
        total_heads: int = 1,
        total_layers: int = 1,
        embed_channels: int = 256,
        mlp_channels: int = 256,
        mlp_num_layers: int = 3,
        rescale_attn_bias: bool = False,
    ):
        super().__init__()
        self.total_heads = total_heads
        self.total_layers = total_layers

        dense_affine_func = partial(nn.Conv2d, kernel_size=1)
        # Query Branch
        self.query_mlp = MLP(in_channels, mlp_channels, embed_channels,
                             mlp_num_layers)
        # Pixel Branch
        self.pix_mlp = MLP(
            in_channels,
            mlp_channels,
            embed_channels,
            mlp_num_layers,
            affine_func=dense_affine_func,
        )
        # Attention Bias Branch
        self.attn_mlp = MLP(
            in_channels,
            mlp_channels,
            embed_channels * self.total_heads * self.total_layers,
            mlp_num_layers,
            affine_func=dense_affine_func,
        )
        if rescale_attn_bias:
            self.bias_scaling = nn.Linear(1, 1)
        else:
            self.bias_scaling = nn.Identity()

    def forward(self, query: torch.Tensor,
                x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward function.
        Args:
            query (Tensor): Query Tokens [B,N,C].
            x (Tensor): Visual features [B,C,H,W]

        Return:
            mask_preds (Tensor): Mask proposals.
            attn_bias (List[Tensor]): List of attention bias.
        """
        query = self.query_mlp(query)
        pix = self.pix_mlp(x)
        b, c, h, w = pix.shape
        # predict mask proposals
        mask_preds = torch.einsum('bqc,bchw->bqhw', query, pix)
        # generate attention bias
        attn = self.attn_mlp(x)
        attn = attn.reshape(b, self.total_layers, self.total_heads, c, h, w)
        attn_bias = torch.einsum('bqc,blnchw->blnqhw', query, attn)
        attn_bias = self.bias_scaling(attn_bias[..., None]).squeeze(-1)
        attn_bias = attn_bias.chunk(self.total_layers, dim=1)
        attn_bias = [attn.squeeze(1) for attn in attn_bias]
        return mask_preds, attn_bias
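
# Einsum note (shapes as documented above): 'bqc,bchw->bqhw' contracts the
# shared channel dim, so each of the N query tokens produces one HxW mask
# logit map; 'bqc,blnchw->blnqhw' does the same per transformer layer l and
# head n to produce the attention biases.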


class SideAdapterNetwork(nn.Module):
    """Side Adapter Network for predicting mask proposals and attention bias.

    Args:
        in_channels (int): Number of input channels. Default: 3.
        clip_channels (int): Number of channels of visual features.
            Default: 768.
        embed_dims (int): embedding dimension. Default: 240.
        patch_size (int): The patch size. Default: 16.
        patch_bias (bool): Whether to use bias in patch embedding.
            Default: True.
        num_queries (int): Number of queries for mask proposals.
            Default: 100.
        fusion_index (List[int]): The layer number of the encode
            transformer to fuse with the CLIP feature.
            Default: [0, 1, 2, 3].
        cfg_encoder (ConfigType): Configs for the encode layers.
        cfg_decoder (ConfigType): Configs for the decode layers.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
    """

    def __init__(
        self,
        in_channels: int = 3,
        clip_channels: int = 768,
        embed_dims: int = 240,
        patch_size: int = 16,
        patch_bias: bool = True,
        num_queries: int = 100,
        fusion_index: list = [0, 1, 2, 3],
        cfg_encoder: ConfigType = ...,
        cfg_decoder: ConfigType = ...,
        norm_cfg: dict = dict(type='LN'),
    ):
        super().__init__()

        self.patch_embed = PatchEmbed(
            in_channels=in_channels,
            embed_dims=embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size,
            stride=patch_size,
            padding=0,
            input_size=(640, 640),
            bias=patch_bias,
            norm_cfg=None,
            init_cfg=None,
        )
        ori_h, ori_w = self.patch_embed.init_out_size
        num_patches = ori_h * ori_w
        self.pos_embed = nn.Parameter(
            torch.randn(1, num_patches, embed_dims) * .02)
        self.query_pos_embed = nn.Parameter(
            torch.zeros(1, num_queries, embed_dims))
        self.query_embed = nn.Parameter(
            torch.zeros(1, num_queries, embed_dims))
        encode_layers = []
        for i in range(cfg_encoder.num_encode_layer):
            encode_layers.append(
                TransformerEncoderLayer(
                    embed_dims=embed_dims,
                    num_heads=cfg_encoder.num_heads,
                    feedforward_channels=cfg_encoder.mlp_ratio * embed_dims,
                    norm_cfg=norm_cfg))
        self.encode_layers = nn.ModuleList(encode_layers)
        conv_clips = []
        for i in range(len(fusion_index)):
            conv_clips.append(
                nn.Sequential(
                    LayerNorm2d(clip_channels),
                    ConvModule(
                        clip_channels,
                        embed_dims,
                        kernel_size=1,
                        norm_cfg=None,
                        act_cfg=None)))
        self.conv_clips = nn.ModuleList(conv_clips)
        self.fusion_index = fusion_index
        self.mask_decoder = MLPMaskDecoder(
            in_channels=embed_dims,
            total_heads=cfg_decoder.num_heads,
            total_layers=cfg_decoder.num_layers,
            embed_channels=cfg_decoder.embed_channels,
            mlp_channels=cfg_decoder.mlp_channels,
            mlp_num_layers=cfg_decoder.num_mlp,
            rescale_attn_bias=cfg_decoder.rescale)

    def init_weights(self):
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.query_embed, std=0.02)
        nn.init.normal_(self.query_pos_embed, std=0.02)
        for i in range(len(self.conv_clips)):
            caffe2_xavier_init(self.conv_clips[i][1].conv)

    def fuse_clip(self, fused_index: int, x: torch.Tensor,
                  clip_feature: torch.Tensor, hwshape: Tuple[int, int],
                  L: int):
        """Fuse CLIP feature and visual tokens."""
        fused_clip = (resize(
            self.conv_clips[fused_index](clip_feature.contiguous()),
            size=hwshape,
            mode='bilinear',
            align_corners=False)).permute(0, 2, 3, 1).reshape(x[:, -L:,
                                                                ...].shape)
        x = torch.cat([x[:, :-L, ...], x[:, -L:, ...] + fused_clip], dim=1)
        return x

    def encode_feature(self, image: torch.Tensor,
                       clip_features: List[torch.Tensor],
                       deep_supervision_idxs: List[int]) -> List[List]:
        """Encode images by a lightweight vision transformer."""
        assert len(self.fusion_index) == len(clip_features)
        x, hwshape = self.patch_embed(image)
        ori_h, ori_w = self.patch_embed.init_out_size
        pos_embed = self.pos_embed
        if self.pos_embed.shape[1] != x.shape[1]:
            # resize the position embedding
            pos_embed = (
                resize(
                    self.pos_embed.reshape(1, ori_h, ori_w,
                                           -1).permute(0, 3, 1, 2),
                    size=hwshape,
                    mode='bicubic',
                    align_corners=False,
                ).flatten(2).permute(0, 2, 1))
        pos_embed = torch.cat([
            self.query_pos_embed.expand(pos_embed.shape[0], -1, -1), pos_embed
        ],
                              dim=1)
        x = torch.cat([self.query_embed.expand(x.shape[0], -1, -1), x], dim=1)
        x = x + pos_embed
        L = hwshape[0] * hwshape[1]
        fused_index = 0
        if self.fusion_index[fused_index] == 0:
            x = self.fuse_clip(fused_index, x, clip_features[0][0], hwshape, L)
            fused_index += 1
        outs = []
        for index, block in enumerate(self.encode_layers, start=1):
            x = block(x)
            if index < len(self.fusion_index
                           ) and index == self.fusion_index[fused_index]:
                x = self.fuse_clip(fused_index, x,
                                   clip_features[fused_index][0], hwshape, L)
                fused_index += 1
            x_query = x[:, :-L, ...]
            x_feat = x[:, -L:, ...].permute(0, 2, 1)\
                .reshape(x.shape[0], x.shape[-1], hwshape[0], hwshape[1])

            if index in deep_supervision_idxs or index == len(
                    self.encode_layers):
                outs.append({'query': x_query, 'x': x_feat})

            if index < len(self.encode_layers):
                x = x + pos_embed
        return outs

    def decode_feature(self, features):
        mask_embeds = []
        attn_biases = []
        for feature in features:
            mask_embed, attn_bias = self.mask_decoder(**feature)
            mask_embeds.append(mask_embed)
            attn_biases.append(attn_bias)
        return mask_embeds, attn_biases

    def forward(
        self, image: torch.Tensor, clip_features: List[torch.Tensor],
        deep_supervision_idxs: List[int]
    ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]:
        """Forward function."""
        features = self.encode_feature(image, clip_features,
                                       deep_supervision_idxs)
        mask_embeds, attn_biases = self.decode_feature(features)
        return mask_embeds, attn_biases


class RecWithAttnbias(nn.Module):
    """Mask recognition module by applying the attention biases to the rest
    of the deeper CLIP layers.

    Args:
        sos_token_format (str): The format of sos token. It should be
            chosen from ["cls_token", "learnable_token", "pos_embedding"].
            Default: 'cls_token'.
        sos_token_num (int): Number of sos tokens. It should be equal to
            the number of queries. Default: 100.
        num_layers (int): Number of rest CLIP layers for mask recognition.
            Default: 3.
        cross_attn (bool): Whether to use cross attention to update the sos
            token. Default: False.
        embed_dims (int): The feature dimension of CLIP layers.
            Default: 768.
        num_heads (int): Parallel attention heads of CLIP layers.
            Default: 12.
        mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
            Default: 4.
        qkv_bias (bool): Whether to use bias in multihead-attention.
            Default: True.
        out_dims (int): Number of channels of the output mask proposals.
            It should be equal to the out_dims of text_encoder.
            Default: 512.
        final_norm (bool): Whether to use a norm layer for the sos token.
            Default: True.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        frozen_exclude (List): List of parameters that are not to be frozen.
    """

    def __init__(self,
                 sos_token_format: str = 'cls_token',
                 sos_token_num: int = 100,
                 num_layers: int = 3,
                 cross_attn: bool = False,
                 embed_dims: int = 768,
                 num_heads: int = 12,
                 mlp_ratio: int = 4,
                 num_fcs: int = 2,
                 qkv_bias: bool = True,
                 out_dims: int = 512,
                 final_norm: bool = True,
                 act_cfg: dict = dict(type='GELU'),
                 norm_cfg: dict = dict(type='LN'),
                 frozen_exclude: List = []):
        super().__init__()

        assert sos_token_format in [
            'cls_token', 'learnable_token', 'pos_embedding'
        ]
        self.sos_token_format = sos_token_format
        self.sos_token_num = sos_token_num
        self.frozen_exclude = frozen_exclude
        self.cross_attn = cross_attn
        self.num_layers = num_layers
        self.num_heads = num_heads
        if sos_token_format in ['learnable_token', 'pos_embedding']:
            # The sos tokens live in the CLIP embedding space (embed_dims)
            # and must stay trainable, so they join frozen_exclude.
            self.sos_token = nn.Parameter(
                torch.randn(sos_token_num, 1, embed_dims))
            self.frozen_exclude.append('sos_token')

        layers = []
        for i in range(num_layers):
            layers.append(
                BaseTransformerLayer(
                    attn_cfgs=dict(
                        type='MultiheadAttention',
                        embed_dims=embed_dims,
                        num_heads=num_heads,
                        batch_first=False,
                        bias=qkv_bias),
                    ffn_cfgs=dict(
                        type='FFN',
                        embed_dims=embed_dims,
                        feedforward_channels=mlp_ratio * embed_dims,
                        act_cfg=act_cfg),
                    operation_order=('norm', 'self_attn', 'norm', 'ffn')))
        self.layers = nn.ModuleList(layers)

        self.ln_post = build_norm_layer(norm_cfg, embed_dims)[1]
        self.proj = nn.Linear(embed_dims, out_dims, bias=False)

        self.final_norm = final_norm
        self._freeze()

    def init_weights(self, rec_state_dict):
        if hasattr(self, 'sos_token'):
            normal_init(self.sos_token, std=0.02)
        if rec_state_dict is not None:
            load_state_dict(self, rec_state_dict, strict=False, logger=None)
        else:
            super().init_weights()

    def _freeze(self):
        if 'all' in self.frozen_exclude:
            return
        for name, param in self.named_parameters():
            if not any([exclude in name for exclude in self.frozen_exclude]):
                param.requires_grad = False

    def _build_attn_biases(self, attn_biases, target_shape):
        formatted_attn_biases = []
        for attn_bias in attn_biases:
            # convert it to proper format: N*num_head,L,L
            # attn_bias: [N, num_head/1, num_sos,H,W]
            n, num_head, num_sos, h, w = attn_bias.shape
            # reshape and downsample
            attn_bias = F.adaptive_max_pool2d(
                attn_bias.reshape(n, num_head * num_sos, h, w),
                output_size=target_shape)
            attn_bias = attn_bias.reshape(n, num_head, num_sos, *target_shape)

            true_num_head = self.num_heads
            assert (num_head == 1 or num_head
                    == true_num_head), f'num_head={num_head} is not supported.'
            if num_head == 1:
                attn_bias = attn_bias.repeat(1, true_num_head, 1, 1, 1)
            attn_bias = attn_bias.reshape(n * true_num_head, num_sos, -1)
            L = attn_bias.shape[-1]
            if self.cross_attn:
                # [n*num_head, num_sos, L]
                formatted_attn_biases.append(attn_bias)
            else:
                # [n*num_head, num_sos+1+L, num_sos+1+L]
                new_attn_bias = attn_bias.new_zeros(num_sos + 1 + L,
                                                    num_sos + 1 + L)
                new_attn_bias[:, :num_sos] = -100
                new_attn_bias[torch.arange(num_sos), torch.arange(num_sos)] = 0
                new_attn_bias[:num_sos, num_sos] = -100
                new_attn_bias = (
                    new_attn_bias[None, ...].expand(n * true_num_head, -1,
                                                    -1).clone())
                new_attn_bias[..., :num_sos, -L:] = attn_bias
                formatted_attn_biases.append(new_attn_bias)

        if len(formatted_attn_biases) == 1:
            formatted_attn_biases = [
                formatted_attn_biases[0] for _ in range(self.num_layers)
            ]
        return formatted_attn_biases

    def forward(self, bias: List[Tensor], feature: List[Tensor]):
        """Forward function to recognize the category of masks.
        Args:
            bias (List[Tensor]): Attention bias for transformer layers.
            feature (List[Tensor]): Output of the image encoder,
                including cls_token and img_feature.
        """
        cls_token = feature[1].unsqueeze(0)
        img_feature = feature[0]
        b, c, h, w = img_feature.shape
        # construct clip shadow features
        x = torch.cat(
            [cls_token,
             img_feature.reshape(b, c, -1).permute(2, 0, 1)])

        # construct sos token
        if self.sos_token_format == 'cls_token':
            sos_token = cls_token.repeat(self.sos_token_num, 1, 1)
        elif self.sos_token_format == 'learnable_token':
            sos_token = self.sos_token.expand(-1, b, -1)
        elif self.sos_token_format == 'pos_embedding':
            sos_token = self.sos_token.expand(-1, b, -1) + cls_token

        # construct attn bias
        attn_biases = self._build_attn_biases(bias, target_shape=(h, w))

        if self.cross_attn:
            for i, block in enumerate(self.layers):
                sos_token = cross_attn_layer(
                    block,
                    sos_token,
                    x[1:, ],
                    attn_biases[i],
                )
                if i < len(self.layers) - 1:
                    x = block(x)
        else:
            x = torch.cat([sos_token, x], dim=0)
            for i, block in enumerate(self.layers):
                x = block(x, attn_masks=[attn_biases[i]])
            sos_token = x[:self.sos_token_num]

        sos_token = sos_token.permute(1, 0, 2)  # LND -> NLD
        sos_token = self.ln_post(sos_token)
        sos_token = self.proj(sos_token)
        if self.final_norm:
            sos_token = F.normalize(sos_token, dim=-1)
        return sos_token
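
# Attention-bias layout note (self-attention path above): the square bias has
# side num_sos + 1 + L (sos tokens, cls token, L patch tokens). A bias of
# -100 effectively masks an attention edge after softmax: no token attends to
# the sos tokens, while each sos token attends to itself and, through the
# learned biases written into its last L columns, to the patch tokens.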


@MODELS.register_module()
class SideAdapterCLIPHead(BaseDecodeHead):
    """Side Adapter Network (SAN) for open-vocabulary semantic segmentation
    with a pre-trained vision-language model.

    This decode head is the implementation of `Side Adapter Network
    for Open-Vocabulary Semantic Segmentation
    <https://arxiv.org/abs/2302.12242>`_.
    Modified from https://github.com/MendelXu/SAN/blob/main/san/model/side_adapter/side_adapter.py # noqa:E501
    Copyright (c) 2023 MendelXu.
    Licensed under the MIT License

    Args:
        num_classes (int): the number of classes.
        san_cfg (ConfigType): Configs for SideAdapterNetwork module
        maskgen_cfg (ConfigType): Configs for RecWithAttnbias module
    """

    def __init__(self, num_classes: int, san_cfg: ConfigType,
                 maskgen_cfg: ConfigType, deep_supervision_idxs: List[int],
                 train_cfg: ConfigType, **kwargs):
        super().__init__(
            in_channels=san_cfg.in_channels,
            channels=san_cfg.embed_dims,
            num_classes=num_classes,
            **kwargs)
        assert san_cfg.num_queries == maskgen_cfg.sos_token_num, \
            'num_queries in san_cfg should be equal to sos_token_num ' \
            'in maskgen_cfg'
        del self.conv_seg
        self.side_adapter_network = SideAdapterNetwork(**san_cfg)
        self.rec_with_attnbias = RecWithAttnbias(**maskgen_cfg)
        self.deep_supervision_idxs = deep_supervision_idxs
        self.train_cfg = train_cfg
        if train_cfg:
            self.match_masks = MatchMasks(
                num_points=train_cfg.num_points,
                num_queries=san_cfg.num_queries,
                num_classes=num_classes,
                assigner=train_cfg.assigner)

    def init_weights(self):

        rec_state_dict = None
        if isinstance(self.init_cfg, dict) and \
                self.init_cfg.get('type') == 'Pretrained_Part':
            checkpoint = CheckpointLoader.load_checkpoint(
                self.init_cfg['checkpoint'], logger=None, map_location='cpu')

            rec_state_dict = checkpoint.copy()
            para_prefix = 'decode_head.rec_with_attnbias'
            prefix_len = len(para_prefix) + 1
            for k, v in checkpoint.items():
                rec_state_dict.pop(k)
                if para_prefix in k:
                    rec_state_dict[k[prefix_len:]] = v

        self.side_adapter_network.init_weights()
        self.rec_with_attnbias.init_weights(rec_state_dict)

    def forward(self, inputs: Tuple[Tensor],
                deep_supervision_idxs) -> Tuple[List]:
        """Forward function.

        Args:
            inputs (Tuple[Tensor]): A triplet including images,
                list of multi-level visual features from image encoder and
                class embeddings from text_encoder.

        Returns:
            mask_props (List[Tensor]): Mask proposals predicted by SAN.
            mask_logits (List[Tensor]): Class logits of mask proposals.
        """
        imgs, clip_feature, class_embeds = inputs
        # predict mask proposals and attention bias
        mask_props, attn_biases = self.side_adapter_network(
            imgs, clip_feature, deep_supervision_idxs)

        # mask recognition with attention bias
        mask_embeds = [
            self.rec_with_attnbias(att_bias, clip_feature[-1])
            for att_bias in attn_biases
        ]
        # Obtain class prediction of masks by comparing the similarity
        # between the image token and the text embedding of class names.
        mask_logits = [
            torch.einsum('bqc,nc->bqn', mask_embed, class_embeds)
            for mask_embed in mask_embeds
        ]
        return mask_props, mask_logits

    def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
                test_cfg: ConfigType) -> Tensor:
        """Forward function for prediction.

        Args:
            inputs (Tuple[Tensor]): Images, visual features from image encoder
                and class embedding from text encoder.
            batch_img_metas (dict): List of image info where each dict may
                also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
                'ori_shape', and 'pad_shape'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
            test_cfg (dict): The testing config.

        Returns:
            Tensor: Output segmentation logits map.
        """
        mask_props, mask_logits = self.forward(inputs, [])

        return self.predict_by_feat([mask_props[-1], mask_logits[-1]],
                                    batch_img_metas)

    def predict_by_feat(self, seg_logits: List[Tensor],
                        batch_img_metas: List[dict]) -> Tensor:
        """1. Transform a batch of mask proposals to the input shape.
        2. Generate a segmentation map with mask proposals and class logits.
        """
        mask_pred = seg_logits[0]
        cls_score = seg_logits[1]
        if isinstance(batch_img_metas[0]['img_shape'], torch.Size):
            # slide inference
            size = batch_img_metas[0]['img_shape']
        elif 'pad_shape' in batch_img_metas[0]:
            size = batch_img_metas[0]['pad_shape'][:2]
        else:
            size = batch_img_metas[0]['img_shape']
        # upsample mask
        mask_pred = F.interpolate(
            mask_pred, size=size, mode='bilinear', align_corners=False)

        mask_cls = F.softmax(cls_score, dim=-1)[..., :-1]
        mask_pred = mask_pred.sigmoid()
        seg_logits = torch.einsum('bqc,bqhw->bchw', mask_cls, mask_pred)
        return seg_logits
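
    # Fusion note: 'bqc,bqhw->bchw' sums over the q mask proposals, weighting
    # each proposal's sigmoid mask by its softmax class scores (the trailing
    # background column is dropped), which yields per-class logit maps.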
|
||||
def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList,
|
||||
train_cfg: ConfigType) -> dict:
|
||||
"""Perform forward propagation and loss calculation of the decoder head
|
||||
on the features of the upstream network.
|
||||
|
||||
Args:
|
||||
x (tuple[Tensor]): Multi-level features from the upstream
|
||||
network, each is a 4D-tensor.
|
||||
batch_data_samples (List[:obj:`SegDataSample`]): The Data
|
||||
Samples. It usually includes information such as
|
||||
`gt_sem_seg`.
|
||||
train_cfg (ConfigType): Training config.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: a dictionary of loss components.
|
||||
"""
|
||||
# batch SegDataSample to InstanceDataSample
|
||||
batch_gt_instances = seg_data_to_instance_data(self.ignore_index,
|
||||
batch_data_samples)
|
||||
|
||||
# forward
|
||||
all_mask_props, all_mask_logits = self.forward(
|
||||
x, self.deep_supervision_idxs)
|
||||
|
||||
# loss
|
||||
losses = self.loss_by_feat(all_mask_logits, all_mask_props,
|
||||
batch_gt_instances)
|
||||
|
||||
return losses
|
||||
|
||||
def loss_by_feat(
|
||||
self, all_cls_scores: Tensor, all_mask_preds: Tensor,
|
||||
batch_gt_instances: List[InstanceData]) -> Dict[str, Tensor]:
|
||||
"""Loss function.
|
||||
|
||||
Args:
|
||||
all_cls_scores (Tensor): Classification scores for all decoder
|
||||
layers with shape (num_decoder, batch_size, num_queries,
|
||||
cls_out_channels). Note `cls_out_channels` should includes
|
||||
background.
|
||||
all_mask_preds (Tensor): Mask scores for all decoder layers with
|
||||
shape (num_decoder, batch_size, num_queries, h, w).
|
||||
batch_gt_instances (list[obj:`InstanceData`]): each contains
|
||||
``labels`` and ``masks``.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: A dictionary of loss components.
|
||||
"""
|
||||
        num_dec_layers = len(all_cls_scores)
        batch_gt_instances_list = [
            batch_gt_instances for _ in range(num_dec_layers)
        ]

        losses = []
        for i in range(num_dec_layers):
            cls_scores = all_cls_scores[i]
            mask_preds = all_mask_preds[i]
            # matching N mask predictions to K category labels
            (labels, mask_targets, mask_weights,
             avg_factor) = self.match_masks.get_targets(
                 cls_scores, mask_preds, batch_gt_instances_list[i])
            cls_scores = cls_scores.flatten(0, 1)
            labels = labels.flatten(0, 1)
            num_total_masks = cls_scores.new_tensor([avg_factor],
                                                    dtype=torch.float)
            all_reduce(num_total_masks, op='mean')
            num_total_masks = max(num_total_masks, 1)

            # extract positive ones
            # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
            mask_preds = mask_preds[mask_weights > 0]

            if mask_targets.shape[0] != 0:
                with torch.no_grad():
                    points_coords = get_uncertain_point_coords_with_randomness(
                        mask_preds.unsqueeze(1), None,
                        self.train_cfg.num_points,
                        self.train_cfg.oversample_ratio,
                        self.train_cfg.importance_sample_ratio)
                # shape (num_total_gts, h, w)
                # -> (num_total_gts, num_points)
                mask_point_targets = point_sample(
                    mask_targets.unsqueeze(1).float(),
                    points_coords).squeeze(1)
                # shape (num_queries, h, w) -> (num_queries, num_points)
                mask_point_preds = point_sample(
                    mask_preds.unsqueeze(1), points_coords).squeeze(1)

            if not isinstance(self.loss_decode, nn.ModuleList):
                losses_decode = [self.loss_decode]
            else:
                losses_decode = self.loss_decode
            loss = dict()
            for loss_decode in losses_decode:
                if 'loss_cls' in loss_decode.loss_name:
                    if loss_decode.loss_name == 'loss_cls_ce':
                        loss[loss_decode.loss_name] = loss_decode(
                            cls_scores, labels)
                    else:
                        assert False, "Only support 'CrossEntropyLoss' in" \
                            ' classification loss'

                elif 'loss_mask' in loss_decode.loss_name:
                    if mask_targets.shape[0] == 0:
                        loss[loss_decode.loss_name] = mask_preds.sum()
                    elif loss_decode.loss_name == 'loss_mask_ce':
                        loss[loss_decode.loss_name] = loss_decode(
                            mask_point_preds,
                            mask_point_targets,
                            avg_factor=num_total_masks *
                            self.train_cfg.num_points)
                    elif loss_decode.loss_name == 'loss_mask_dice':
                        loss[loss_decode.loss_name] = loss_decode(
                            mask_point_preds,
                            mask_point_targets,
                            avg_factor=num_total_masks)
                    else:
                        assert False, "Only support 'CrossEntropyLoss' and" \
                            " 'DiceLoss' in mask loss"
                else:
                    assert False, "Only support for 'loss_cls' and 'loss_mask'"

            losses.append(loss)

        loss_dict = dict()
        # loss from the last decoder layer
        loss_dict.update(losses[-1])
        # loss from other decoder layers
        for i, loss in enumerate(losses[:-1]):
            for k, v in loss.items():
                loss_dict[f'd{self.deep_supervision_idxs[i]}.{k}'] = v
        return loss_dict
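The point-sampled mask losses above rely on the PointRend-style helpers from mmcv/mmdet. A minimal standalone sketch of the sampling step, with assumed toy shapes (editor's illustration, not part of this commit):

import torch
from mmcv.ops import point_sample

num_masks, H, W, num_points = 4, 32, 32, 16
mask_preds = torch.randn(num_masks, H, W)
# normalized (x, y) point coordinates in [0, 1]
point_coords = torch.rand(num_masks, num_points, 2)
# (num_masks, 1, H, W) sampled at num_points locations -> (num_masks, num_points)
sampled = point_sample(mask_preds.unsqueeze(1), point_coords).squeeze(1)
assert sampled.shape == (num_masks, num_points)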
66
finetune/mmseg/models/decode_heads/segformer_head.py
Normal file
@@ -0,0 +1,66 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.registry import MODELS
from ..utils import resize


@MODELS.register_module()
class SegformerHead(BaseDecodeHead):
    """The all-MLP head of SegFormer.

    This head is the implementation of
    `SegFormer <https://arxiv.org/abs/2105.15203>`_.

    Args:
        interpolate_mode: The interpolate mode of MLP head upsample operation.
            Default: 'bilinear'.
    """

    def __init__(self, interpolate_mode='bilinear', **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)

        self.interpolate_mode = interpolate_mode
        num_inputs = len(self.in_channels)

        assert num_inputs == len(self.in_index)

        self.convs = nn.ModuleList()
        for i in range(num_inputs):
            self.convs.append(
                ConvModule(
                    in_channels=self.in_channels[i],
                    out_channels=self.channels,
                    kernel_size=1,
                    stride=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))

        self.fusion_conv = ConvModule(
            in_channels=self.channels * num_inputs,
            out_channels=self.channels,
            kernel_size=1,
            norm_cfg=self.norm_cfg)

    def forward(self, inputs):
        # Receive 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32
        inputs = self._transform_inputs(inputs)
        outs = []
        for idx in range(len(inputs)):
            x = inputs[idx]
            conv = self.convs[idx]
            outs.append(
                resize(
                    input=conv(x),
                    size=inputs[0].shape[2:],
                    mode=self.interpolate_mode,
                    align_corners=self.align_corners))

        out = self.fusion_conv(torch.cat(outs, dim=1))

        out = self.cls_seg(out)

        return out
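As a quick sanity check of the head above, it can be instantiated directly; all channel numbers and shapes below are illustrative placeholders chosen by the editor (MiT-B0-like strides), not values taken from this commit, and the sketch assumes this vendored mmseg package is importable:

import torch
from mmseg.models.decode_heads.segformer_head import SegformerHead

head = SegformerHead(
    in_channels=[32, 64, 160, 256],
    in_index=[0, 1, 2, 3],
    channels=256,
    num_classes=2,
    norm_cfg=dict(type='BN', requires_grad=True),
    align_corners=False)
head.eval()

# Four feature maps at strides 4/8/16/32 for a 256x256 input.
feats = [torch.randn(1, c, 256 // s, 256 // s)
         for c, s in zip([32, 64, 160, 256], [4, 8, 16, 32])]
logits = head(feats)  # -> (1, num_classes, 64, 64), i.e. 1/4 resolution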
Some files were not shown because too many files have changed in this diff