init
This commit is contained in:
62
finetune/mmseg/models/decode_heads/setr_mla_head.py
Normal file
62
finetune/mmseg/models/decode_heads/setr_mla_head.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import Upsample
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class SETRMLAHead(BaseDecodeHead):
|
||||
"""Multi level feature aggretation head of SETR.
|
||||
|
||||
MLA head of `SETR <https://arxiv.org/pdf/2012.15840.pdf>`_.
|
||||
|
||||
Args:
|
||||
mlahead_channels (int): Channels of conv-conv-4x of multi-level feature
|
||||
aggregation. Default: 128.
|
||||
up_scale (int): The scale factor of interpolate. Default:4.
|
||||
"""
|
||||
|
||||
def __init__(self, mla_channels=128, up_scale=4, **kwargs):
|
||||
super().__init__(input_transform='multiple_select', **kwargs)
|
||||
self.mla_channels = mla_channels
|
||||
|
||||
num_inputs = len(self.in_channels)
|
||||
|
||||
# Refer to self.cls_seg settings of BaseDecodeHead
|
||||
assert self.channels == num_inputs * mla_channels
|
||||
|
||||
self.up_convs = nn.ModuleList()
|
||||
for i in range(num_inputs):
|
||||
self.up_convs.append(
|
||||
nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=self.in_channels[i],
|
||||
out_channels=mla_channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg),
|
||||
ConvModule(
|
||||
in_channels=mla_channels,
|
||||
out_channels=mla_channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg),
|
||||
Upsample(
|
||||
scale_factor=up_scale,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)))
|
||||
|
||||
def forward(self, inputs):
|
||||
inputs = self._transform_inputs(inputs)
|
||||
outs = []
|
||||
for x, up_conv in zip(inputs, self.up_convs):
|
||||
outs.append(up_conv(x))
|
||||
out = torch.cat(outs, dim=1)
|
||||
out = self.cls_seg(out)
|
||||
return out
|
||||
Reference in New Issue
Block a user