init
finetune/mmseg/models/necks/mla_neck.py (new file, 118 lines)
@@ -0,0 +1,118 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer

from mmseg.registry import MODELS


class MLAModule(nn.Module):
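    """Multi-level Feature Aggregation module.

    Projects each input scale to ``out_channels`` with a 1x1 ConvModule,
    aggregates the projected maps top-down by cumulative summation, and
    refines each aggregated map with a 3x3 ConvModule.
    """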

    def __init__(self,
                 in_channels=[1024, 1024, 1024, 1024],
                 out_channels=256,
                 norm_cfg=None,
                 act_cfg=None):
        super().__init__()
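        # 1x1 convs that project every input scale to a common channel width.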
        self.channel_proj = nn.ModuleList()
        for i in range(len(in_channels)):
            self.channel_proj.append(
                ConvModule(
                    in_channels=in_channels[i],
                    out_channels=out_channels,
                    kernel_size=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
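        # 3x3 convs that refine each aggregated map after summation.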
        self.feat_extract = nn.ModuleList()
        for i in range(len(in_channels)):
            self.feat_extract.append(
                ConvModule(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))

    def forward(self, inputs):

        # feat_list -> [p2, p3, p4, p5]
        feat_list = []
        for x, conv in zip(inputs, self.channel_proj):
            feat_list.append(conv(x))

        # feat_list -> [p5, p4, p3, p2]
        # mid_list -> [m5, m4, m3, m2], where each m is the running sum of
        # the projected features from p5 down to that scale
        feat_list = feat_list[::-1]
        mid_list = []
        for feat in feat_list:
            if len(mid_list) == 0:
                mid_list.append(feat)
            else:
                mid_list.append(mid_list[-1] + feat)

        # mid_list -> [m5, m4, m3, m2]
        # out_list -> [o5, o4, o3, o2], in the same order as mid_list
        out_list = []
        for mid, conv in zip(mid_list, self.feat_extract):
            out_list.append(conv(mid))

        return tuple(out_list)


@MODELS.register_module()
class MLANeck(nn.Module):
    """Multi-level Feature Aggregation.

    This neck implements the `Multi-level Feature Aggregation (MLA)
    construction of SETR <https://arxiv.org/abs/2012.15840>`_.

    Args:
        in_channels (List[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale).
        norm_layer (dict): Config dict for input normalization.
            Default: norm_layer=dict(type='LN', eps=1e-6, requires_grad=True).
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): Config dict for activation layer in ConvModule.
            Default: None.
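
    Example (illustrative shapes, assuming four same-resolution ViT stage
    outputs; not taken from a shipped config):
        >>> import torch
        >>> feats = [torch.randn(1, 1024, 32, 32) for _ in range(4)]
        >>> neck = MLANeck(in_channels=[1024, 1024, 1024, 1024],
        ...                out_channels=256)
        >>> outs = neck(feats)
        >>> [tuple(o.shape) for o in outs]
        [(1, 256, 32, 32), (1, 256, 32, 32), (1, 256, 32, 32), (1, 256, 32, 32)]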
"""
|
||||
|
||||
    def __init__(self,
                 in_channels,
                 out_channels,
                 norm_layer=dict(type='LN', eps=1e-6, requires_grad=True),
                 norm_cfg=None,
                 act_cfg=None):
        super().__init__()
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels

        # To keep the vision transformer backbone generic, the MLA input
        # norms live in the neck rather than in the backbone.
        self.norm = nn.ModuleList([
            build_norm_layer(norm_layer, in_channels[i])[1]
            for i in range(len(in_channels))
        ])

        self.mla = MLAModule(
            in_channels=in_channels,
            out_channels=out_channels,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    def forward(self, inputs):
        assert len(inputs) == len(self.in_channels)

        # Convert from nchw to nlc: flatten each map to (n, h*w, c) so the
        # LayerNorm runs over the channel dim, then restore (n, c, h, w).
        outs = []
        for i in range(len(inputs)):
            x = inputs[i]
            n, c, h, w = x.shape
            x = x.reshape(n, c, h * w).transpose(2, 1).contiguous()
            x = self.norm[i](x)
            x = x.transpose(1, 2).reshape(n, c, h, w).contiguous()
            outs.append(x)

        outs = self.mla(outs)
        return tuple(outs)
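
For reference, a neck registered via MODELS.register_module() is normally
instantiated from a config dict by the mmseg registry; a minimal sketch
(field values are placeholders, not taken from this commit):

    model = dict(
        ...,
        neck=dict(
            type='MLANeck',
            in_channels=[1024, 1024, 1024, 1024],
            out_channels=256,
            norm_cfg=dict(type='BN', requires_grad=True),
            act_cfg=dict(type='ReLU')))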