init
This commit is contained in:
90
finetune/mmseg/models/necks/fusion_multilevel_neck.py
Normal file
90
finetune/mmseg/models/necks/fusion_multilevel_neck.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .multilevel_neck import MultiLevelNeck
|
||||
from .fusion_transformer import FusionTransformer
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
class FusionMultiLevelNeck(nn.Module):
    """Neck that temporally fuses a sequence of multi-level feature maps.

    The input is expected to be a tuple of per-level feature maps whose
    leading dimension packs batch and time as ``b * ts_size``. Each level is
    first rescaled by :class:`MultiLevelNeck`; then, for every spatial
    location, its ``ts_size`` temporal feature vectors are fused into a
    single vector by :class:`FusionTransformer`.

    Args:
        ts_size (int): Number of temporal frames packed into the batch
            dimension of each input feature map. Defaults to 10.
        in_channels_ml (list[int] | None): Per-level input channels for the
            multi-level neck. Defaults to ``[768, 768, 768, 768]``.
        out_channels_ml (int): Output channels of the multi-level neck.
        scales_ml (list[float] | None): Per-level resize scales for the
            multi-level neck. Defaults to ``[0.5, 1, 2, 4]``.
        norm_cfg_ml (dict | None): Norm config for the multi-level neck.
        act_cfg_ml (dict | None): Activation config for the multi-level neck.
        input_dims (int): Input dims of the fusion transformer.
        embed_dims (int): Embedding dims of the fusion transformer.
        num_layers (int): Transformer depth.
        num_heads (int): Attention heads per layer.
        mlp_ratio (int): FFN hidden-dim ratio.
        qkv_bias (bool): Whether QKV projections use bias.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate.
        drop_path_rate (float): Stochastic-depth rate.
        with_cls_token (bool): Whether the transformer keeps a cls token.
        output_cls_token (bool): Whether the transformer outputs the cls token.
        norm_cfg (dict): Norm config for the transformer.
        act_cfg (dict): Activation config for the transformer.
        num_fcs (int): Number of FC layers in each FFN.
        norm_eval (bool): Whether norms stay in eval mode during training.
        with_cp (bool): Whether to use activation checkpointing.
        init_cfg (dict | None): Initialization config.
        *args: Ignored; accepted for config-file compatibility.
        **kwargs: Ignored; accepted for config-file compatibility.
    """

    def __init__(self,
                 ts_size=10,
                 in_channels_ml=None,
                 out_channels_ml=768,
                 scales_ml=None,
                 norm_cfg_ml=None,
                 act_cfg_ml=None,
                 input_dims=768,
                 embed_dims=768,
                 num_layers=4,
                 num_heads=12,
                 mlp_ratio=4,
                 qkv_bias=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 with_cls_token=True,
                 output_cls_token=True,
                 norm_cfg=dict(type='LN'),
                 act_cfg=dict(type='GELU'),
                 num_fcs=2,
                 norm_eval=False,
                 with_cp=False,
                 init_cfg=None,
                 *args,
                 **kwargs):
        super().__init__()
        # None-sentinel defaults avoid shared mutable list defaults.
        # Kept as lists (not tuples): MultiLevelNeck expects list inputs.
        if in_channels_ml is None:
            in_channels_ml = [768, 768, 768, 768]
        if scales_ml is None:
            scales_ml = [0.5, 1, 2, 4]
        self.in_channels = in_channels_ml
        self.ts_size = ts_size
        self.multilevel_neck = MultiLevelNeck(
            in_channels_ml,
            out_channels_ml,
            scales_ml,
            norm_cfg_ml,
            act_cfg_ml,
        )
        # self.up_head = UPHead(1024, 2816, 4)
        self.fusion_transformer = FusionTransformer(
            input_dims,
            embed_dims,
            num_layers,
            num_heads,
            mlp_ratio,
            qkv_bias,
            drop_rate,
            attn_drop_rate,
            drop_path_rate,
            with_cls_token,
            output_cls_token,
            norm_cfg,
            act_cfg,
            num_fcs,
            norm_eval,
            with_cp,
            init_cfg,
        )

    def init_weights(self):
        """Initialize weights by delegating to the fusion transformer."""
        self.fusion_transformer.init_weights()

    def forward(self, inputs, require_feat: bool = False, require_two: bool = False):
        """Rescale each level, then fuse its temporal frames per location.

        Args:
            inputs (tuple[Tensor]): Per-level features shaped
                ``(b * ts_size, c, h, w)``.
            require_feat (bool): Forwarded to the fusion transformer.
            require_two (bool): Forwarded to the fusion transformer.

        Returns:
            tuple[Tensor]: Per-level fused features shaped ``(b, c', h, w)``.
        """
        assert len(inputs) == len(self.in_channels)

        inputs = self.multilevel_neck(inputs)

        ts = self.ts_size
        b_total = inputs[-1].shape[0]
        # The leading dimension packs (batch, time); it must split evenly.
        assert b_total % ts == 0, (
            f'batch dim {b_total} is not divisible by ts_size {ts}')
        b = b_total // ts

        outs = []
        for feat in inputs:
            _, c, h, w = feat.shape
            # (b*ts, c, h, w) -> (b*h*w, ts, c): each spatial location
            # becomes a length-ts token sequence for temporal fusion.
            seq = feat.reshape(b, ts, c, h, w) \
                      .permute(0, 3, 4, 1, 2) \
                      .reshape(b * h * w, ts, c)
            fused = self.fusion_transformer(seq, require_feat, require_two)
            c_fused = fused.shape[-1]
            # (b*h*w, c_fused) -> (b, c_fused, h, w)
            fused = fused.reshape(b, h, w, c_fused).permute(0, 3, 1, 2)
            outs.append(fused)

        return tuple(outs)
|
||||
Reference in New Issue
Block a user