Files
SkySensePlusPlus/finetune/mmseg/models/necks/fusion_multilevel_neck.py
esenke 01adcfdf60 init
2025-12-08 22:16:31 +08:00

90 lines
2.9 KiB
Python

import torch
import torch.nn as nn
from .multilevel_neck import MultiLevelNeck
from .fusion_transformer import FusionTransformer
from mmseg.registry import MODELS
@MODELS.register_module()
class FusionMultiLevelNeck(nn.Module):
    """Multi-level neck followed by per-pixel temporal fusion.

    The backbone features are first passed through ``MultiLevelNeck``. Each
    resulting level, laid out as ``(B * ts_size, C, H, W)``, is rearranged so
    that every spatial location of every sample becomes one sequence of
    ``ts_size`` feature vectors. ``FusionTransformer`` fuses each sequence and
    the fused vectors are folded back into a ``(B, C_fusion, H, W)`` map.

    Args:
        ts_size (int): Number of time steps per sample. The leading dimension
            of every neck output is split as ``(B, ts_size)``, so it must be
            divisible by this value.
        in_channels_ml (list[int]): First positional argument of
            ``MultiLevelNeck``. Its length is also the number of inputs
            ``forward`` expects.
        out_channels_ml (int): Passed positionally to ``MultiLevelNeck``.
        scales_ml (list[float]): Passed positionally to ``MultiLevelNeck``.
        norm_cfg_ml (dict | None): Passed positionally to ``MultiLevelNeck``.
        act_cfg_ml (dict | None): Passed positionally to ``MultiLevelNeck``.
        input_dims, embed_dims, num_layers, num_heads, mlp_ratio, qkv_bias,
        drop_rate, attn_drop_rate, drop_path_rate, with_cls_token,
        output_cls_token, norm_cfg, act_cfg, num_fcs, norm_eval, with_cp,
        init_cfg: Passed positionally, in exactly this order, to
            ``FusionTransformer``.
        *args, **kwargs: Accepted and ignored (not forwarded anywhere).
    """

    def __init__(self,
                 ts_size=10,
                 in_channels_ml=[768, 768, 768, 768],
                 out_channels_ml=768,
                 scales_ml=[0.5, 1, 2, 4],
                 norm_cfg_ml=None,
                 act_cfg_ml=None,
                 input_dims=768,
                 embed_dims=768,
                 num_layers=4,
                 num_heads=12,
                 mlp_ratio=4,
                 qkv_bias=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 with_cls_token=True,
                 output_cls_token=True,
                 norm_cfg=dict(type='LN'),
                 act_cfg=dict(type='GELU'),
                 num_fcs=2,
                 norm_eval=False,
                 with_cp=False,
                 init_cfg=None,
                 *args,
                 **kwargs):
        import copy

        super().__init__()
        assert ts_size > 0, f'ts_size must be positive, got {ts_size}'

        # The list/dict defaults above are single objects shared by every
        # instance; copy them so no sub-module can mutate the shared default
        # (or a caller's config) through the stored reference.
        in_channels_ml = copy.deepcopy(in_channels_ml)
        scales_ml = copy.deepcopy(scales_ml)
        norm_cfg = copy.deepcopy(norm_cfg)
        act_cfg = copy.deepcopy(act_cfg)

        self.in_channels = in_channels_ml
        self.ts_size = ts_size
        self.multilevel_neck = MultiLevelNeck(
            in_channels_ml,
            out_channels_ml,
            scales_ml,
            norm_cfg_ml,
            act_cfg_ml,
        )
        # Arguments are positional, so this order must match the signature
        # of FusionTransformer.__init__.
        self.fusion_transformer = FusionTransformer(
            input_dims,
            embed_dims,
            num_layers,
            num_heads,
            mlp_ratio,
            qkv_bias,
            drop_rate,
            attn_drop_rate,
            drop_path_rate,
            with_cls_token,
            output_cls_token,
            norm_cfg,
            act_cfg,
            num_fcs,
            norm_eval,
            with_cp,
            init_cfg,
        )

    def init_weights(self):
        """Initialise the fusion transformer's weights.

        Only ``self.fusion_transformer.init_weights()`` is called here;
        ``self.multilevel_neck`` is not touched by this method.
        """
        self.fusion_transformer.init_weights()

    def forward(self, inputs, require_feat: bool = False, require_two: bool = False):
        """Run the multi-level neck, then fuse every level over time.

        Args:
            inputs (list | tuple[Tensor]): One feature map per entry of
                ``in_channels_ml``. After ``MultiLevelNeck`` each level must
                have a leading dimension of ``B * ts_size``, with the
                ``ts_size`` frames of one sample stored contiguously (that is
                how the reshape below splits it).
            require_feat (bool): Forwarded unchanged to ``FusionTransformer``.
            require_two (bool): Forwarded unchanged to ``FusionTransformer``.

        Returns:
            tuple[Tensor]: One ``(B, C_fusion, H_i, W_i)`` tensor per level
            returned by ``MultiLevelNeck``.
        """
        assert len(inputs) == len(self.in_channels), (
            f'expected {len(self.in_channels)} input levels, got {len(inputs)}')
        feats = self.multilevel_neck(inputs)

        ts = self.ts_size
        b_total = feats[-1].shape[0]
        b, remainder = divmod(b_total, ts)
        # Without this check the reshape below fails with an opaque
        # element-count error when the batch is not a multiple of ts_size.
        assert remainder == 0, (
            f'leading dim {b_total} is not divisible by ts_size={ts}')

        outs = []
        for feat in feats:
            _, c, h, w = feat.shape
            # (B*ts, C, H, W) -> (B, ts, C, H, W) -> (B, H, W, ts, C)
            # -> (B*H*W, ts, C): each pixel becomes one length-ts sequence.
            seq = (feat.reshape(b, ts, c, h, w)
                   .permute(0, 3, 4, 1, 2)
                   .reshape(b * h * w, ts, c))
            fused = self.fusion_transformer(seq, require_feat, require_two)
            # NOTE(review): assumes the transformer returns one vector per
            # sequence, i.e. (B*H*W, C_fusion); confirm this also holds when
            # require_feat / require_two are set.
            c_fusion = fused.shape[-1]
            # (B*H*W, C_fusion) -> (B, H, W, C_fusion) -> (B, C_fusion, H, W)
            outs.append(fused.reshape(b, h, w, c_fusion).permute(0, 3, 1, 2))
        return tuple(outs)