init
finetune/mmseg/models/necks/jpu.py (new file, 131 lines)
@@ -0,0 +1,131 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmengine.model import BaseModule

from mmseg.registry import MODELS
from ..utils import resize


@MODELS.register_module()
class JPU(BaseModule):
    """FastFCN: Rethinking Dilated Convolution in the Backbone
    for Semantic Segmentation.

    This Joint Pyramid Upsampling (JPU) neck is the implementation of
    `FastFCN <https://arxiv.org/abs/1903.11816>`_.

    Args:
        in_channels (Tuple[int], optional): The number of input channels
            for each convolution operation before upsampling.
            Default: (512, 1024, 2048).
        mid_channels (int): The number of output channels of JPU.
            Default: 512.
        start_level (int): Index of the start input backbone level used to
            build the feature pyramid. Default: 0.
        end_level (int): Index of the end input backbone level (exclusive) to
            build the feature pyramid. Default: -1, which means the last level.
        dilations (tuple[int]): Dilation rate of each Depthwise
            Separable ConvModule. Default: (1, 2, 4, 8).
        align_corners (bool, optional): The align_corners argument of
            resize operation. Default: False.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels=(512, 1024, 2048),
                 mid_channels=512,
                 start_level=0,
                 end_level=-1,
                 dilations=(1, 2, 4, 8),
                 align_corners=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert isinstance(in_channels, tuple)
        assert isinstance(dilations, tuple)
        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.start_level = start_level
        self.num_ins = len(in_channels)
        if end_level == -1:
            self.backbone_end_level = self.num_ins
        else:
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)

        self.dilations = dilations
        self.align_corners = align_corners

        self.conv_layers = nn.ModuleList()
        self.dilation_layers = nn.ModuleList()
        for i in range(self.start_level, self.backbone_end_level):
            conv_layer = nn.Sequential(
                ConvModule(
                    self.in_channels[i],
                    self.mid_channels,
                    kernel_size=3,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
            self.conv_layers.append(conv_layer)
        for i in range(len(dilations)):
            dilation_layer = nn.Sequential(
                DepthwiseSeparableConvModule(
                    in_channels=(self.backbone_end_level - self.start_level) *
                    self.mid_channels,
                    out_channels=self.mid_channels,
                    kernel_size=3,
                    stride=1,
                    padding=dilations[i],
                    dilation=dilations[i],
                    dw_norm_cfg=norm_cfg,
                    dw_act_cfg=None,
                    pw_norm_cfg=norm_cfg,
                    pw_act_cfg=act_cfg))
            self.dilation_layers.append(dilation_layer)

    def forward(self, inputs):
        """Forward function."""
        assert len(inputs) == len(self.in_channels), 'Length of inputs must \
            be the same with self.in_channels!'

        feats = [
            self.conv_layers[i - self.start_level](inputs[i])
            for i in range(self.start_level, self.backbone_end_level)
        ]

        h, w = feats[0].shape[2:]
        for i in range(1, len(feats)):
            feats[i] = resize(
                feats[i],
                size=(h, w),
                mode='bilinear',
                align_corners=self.align_corners)

        feat = torch.cat(feats, dim=1)
        concat_feat = torch.cat([
            self.dilation_layers[i](feat) for i in range(len(self.dilations))
        ],
                                dim=1)

        outs = []

        # Default: outs[2] is the output of JPU for decoder head, outs[1] is
        # the feature map from backbone for auxiliary head. Additionally,
        # outs[0] can also be used for auxiliary head.
        for i in range(self.start_level, self.backbone_end_level - 1):
            outs.append(inputs[i])
        outs.append(concat_feat)
        return tuple(outs)
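A minimal usage sketch (not part of the commit): it exercises the neck in isolation with its default arguments, assuming the standard mmseg import path and illustrative feature-map sizes for a stride-8/16/32 backbone on a 512x512 input.

import torch
from mmseg.models.necks import JPU

# Dummy backbone outputs; channel counts match the default
# in_channels=(512, 1024, 2048), spatial sizes are illustrative.
c3 = torch.randn(1, 512, 64, 64)
c4 = torch.randn(1, 1024, 32, 32)
c5 = torch.randn(1, 2048, 16, 16)

neck = JPU()
outs = neck((c3, c4, c5))

# outs[0] and outs[1] are the untouched c3/c4 maps (usable by an
# auxiliary head); outs[2] is the fused JPU feature with
# len(dilations) * mid_channels = 4 * 512 = 2048 channels at the
# resolution of c3 (64x64 here).
for o in outs:
    print(o.shape)

As the forward pass shows, all selected backbone levels are projected to mid_channels, upsampled to the finest resolution, concatenated, and then passed through the four dilated depthwise-separable branches whose outputs are concatenated into the single feature consumed by the decode head.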