init
This commit is contained in:
102
finetune/mmseg/models/decode_heads/sep_aspp_head.py
Normal file
102
finetune/mmseg/models/decode_heads/sep_aspp_head.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
from .aspp_head import ASPPHead, ASPPModule
|
||||
|
||||
|
||||
class DepthwiseSeparableASPPModule(ASPPModule):
|
||||
"""Atrous Spatial Pyramid Pooling (ASPP) Module with depthwise separable
|
||||
conv."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for i, dilation in enumerate(self.dilations):
|
||||
if dilation > 1:
|
||||
self[i] = DepthwiseSeparableConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
3,
|
||||
dilation=dilation,
|
||||
padding=dilation,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DepthwiseSeparableASPPHead(ASPPHead):
|
||||
"""Encoder-Decoder with Atrous Separable Convolution for Semantic Image
|
||||
Segmentation.
|
||||
|
||||
This head is the implementation of `DeepLabV3+
|
||||
<https://arxiv.org/abs/1802.02611>`_.
|
||||
|
||||
Args:
|
||||
c1_in_channels (int): The input channels of c1 decoder. If is 0,
|
||||
the no decoder will be used.
|
||||
c1_channels (int): The intermediate channels of c1 decoder.
|
||||
"""
|
||||
|
||||
def __init__(self, c1_in_channels, c1_channels, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
assert c1_in_channels >= 0
|
||||
self.aspp_modules = DepthwiseSeparableASPPModule(
|
||||
dilations=self.dilations,
|
||||
in_channels=self.in_channels,
|
||||
channels=self.channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
if c1_in_channels > 0:
|
||||
self.c1_bottleneck = ConvModule(
|
||||
c1_in_channels,
|
||||
c1_channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
else:
|
||||
self.c1_bottleneck = None
|
||||
self.sep_bottleneck = nn.Sequential(
|
||||
DepthwiseSeparableConvModule(
|
||||
self.channels + c1_channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg),
|
||||
DepthwiseSeparableConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg))
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
x = self._transform_inputs(inputs)
|
||||
aspp_outs = [
|
||||
resize(
|
||||
self.image_pool(x),
|
||||
size=x.size()[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
]
|
||||
aspp_outs.extend(self.aspp_modules(x))
|
||||
aspp_outs = torch.cat(aspp_outs, dim=1)
|
||||
output = self.bottleneck(aspp_outs)
|
||||
if self.c1_bottleneck is not None:
|
||||
c1_output = self.c1_bottleneck(inputs[0])
|
||||
output = resize(
|
||||
input=output,
|
||||
size=c1_output.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
output = torch.cat([output, c1_output], dim=1)
|
||||
output = self.sep_bottleneck(output)
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
Reference in New Issue
Block a user