init
This commit is contained in:
117
finetune/mmseg/models/decode_heads/psp_head.py
Normal file
117
finetune/mmseg/models/decode_heads/psp_head.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class PPM(nn.ModuleList):
|
||||
"""Pooling Pyramid Module used in PSPNet.
|
||||
|
||||
Args:
|
||||
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
||||
Module.
|
||||
in_channels (int): Input channels.
|
||||
channels (int): Channels after modules, before conv_seg.
|
||||
conv_cfg (dict|None): Config of conv layers.
|
||||
norm_cfg (dict|None): Config of norm layers.
|
||||
act_cfg (dict): Config of activation layers.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
"""
|
||||
|
||||
def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
|
||||
act_cfg, align_corners, **kwargs):
|
||||
super().__init__()
|
||||
self.pool_scales = pool_scales
|
||||
self.align_corners = align_corners
|
||||
self.in_channels = in_channels
|
||||
self.channels = channels
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
for pool_scale in pool_scales:
|
||||
self.append(
|
||||
nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d(pool_scale),
|
||||
ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
**kwargs)))
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
ppm_outs = []
|
||||
for ppm in self:
|
||||
ppm_out = ppm(x)
|
||||
upsampled_ppm_out = resize(
|
||||
ppm_out,
|
||||
size=x.size()[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
ppm_outs.append(upsampled_ppm_out)
|
||||
return ppm_outs
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class PSPHead(BaseDecodeHead):
|
||||
"""Pyramid Scene Parsing Network.
|
||||
|
||||
This head is the implementation of
|
||||
`PSPNet <https://arxiv.org/abs/1612.01105>`_.
|
||||
|
||||
Args:
|
||||
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
||||
Module. Default: (1, 2, 3, 6).
|
||||
"""
|
||||
|
||||
def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
assert isinstance(pool_scales, (list, tuple))
|
||||
self.pool_scales = pool_scales
|
||||
self.psp_modules = PPM(
|
||||
self.pool_scales,
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
align_corners=self.align_corners)
|
||||
self.bottleneck = ConvModule(
|
||||
self.in_channels + len(pool_scales) * self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def _forward_feature(self, inputs):
|
||||
"""Forward function for feature maps before classifying each pixel with
|
||||
``self.cls_seg`` fc.
|
||||
|
||||
Args:
|
||||
inputs (list[Tensor]): List of multi-level img features.
|
||||
|
||||
Returns:
|
||||
feats (Tensor): A tensor of shape (batch_size, self.channels,
|
||||
H, W) which is feature map for last layer of decoder head.
|
||||
"""
|
||||
x = self._transform_inputs(inputs)
|
||||
psp_outs = [x]
|
||||
psp_outs.extend(self.psp_modules(x))
|
||||
psp_outs = torch.cat(psp_outs, dim=1)
|
||||
feats = self.bottleneck(psp_outs)
|
||||
return feats
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
output = self._forward_feature(inputs)
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
Reference in New Issue
Block a user