init
This commit is contained in:
197
finetune/mmseg/models/decode_heads/psa_head.py
Normal file
197
finetune/mmseg/models/decode_heads/psa_head.py
Normal file
@@ -0,0 +1,197 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
try:
|
||||
from mmcv.ops import PSAMask
|
||||
except ModuleNotFoundError:
|
||||
PSAMask = None
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class PSAHead(BaseDecodeHead):
|
||||
"""Point-wise Spatial Attention Network for Scene Parsing.
|
||||
|
||||
This head is the implementation of `PSANet
|
||||
<https://hszhao.github.io/papers/eccv18_psanet.pdf>`_.
|
||||
|
||||
Args:
|
||||
mask_size (tuple[int]): The PSA mask size. It usually equals input
|
||||
size.
|
||||
psa_type (str): The type of psa module. Options are 'collect',
|
||||
'distribute', 'bi-direction'. Default: 'bi-direction'
|
||||
compact (bool): Whether use compact map for 'collect' mode.
|
||||
Default: True.
|
||||
shrink_factor (int): The downsample factors of psa mask. Default: 2.
|
||||
normalization_factor (float): The normalize factor of attention.
|
||||
psa_softmax (bool): Whether use softmax for attention.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
mask_size,
|
||||
psa_type='bi-direction',
|
||||
compact=False,
|
||||
shrink_factor=2,
|
||||
normalization_factor=1.0,
|
||||
psa_softmax=True,
|
||||
**kwargs):
|
||||
if PSAMask is None:
|
||||
raise RuntimeError('Please install mmcv-full for PSAMask ops')
|
||||
super().__init__(**kwargs)
|
||||
assert psa_type in ['collect', 'distribute', 'bi-direction']
|
||||
self.psa_type = psa_type
|
||||
self.compact = compact
|
||||
self.shrink_factor = shrink_factor
|
||||
self.mask_size = mask_size
|
||||
mask_h, mask_w = mask_size
|
||||
self.psa_softmax = psa_softmax
|
||||
if normalization_factor is None:
|
||||
normalization_factor = mask_h * mask_w
|
||||
self.normalization_factor = normalization_factor
|
||||
|
||||
self.reduce = ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
kernel_size=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.attention = nn.Sequential(
|
||||
ConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
kernel_size=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg),
|
||||
nn.Conv2d(
|
||||
self.channels, mask_h * mask_w, kernel_size=1, bias=False))
|
||||
if psa_type == 'bi-direction':
|
||||
self.reduce_p = ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
kernel_size=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.attention_p = nn.Sequential(
|
||||
ConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
kernel_size=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg),
|
||||
nn.Conv2d(
|
||||
self.channels, mask_h * mask_w, kernel_size=1, bias=False))
|
||||
self.psamask_collect = PSAMask('collect', mask_size)
|
||||
self.psamask_distribute = PSAMask('distribute', mask_size)
|
||||
else:
|
||||
self.psamask = PSAMask(psa_type, mask_size)
|
||||
self.proj = ConvModule(
|
||||
self.channels * (2 if psa_type == 'bi-direction' else 1),
|
||||
self.in_channels,
|
||||
kernel_size=1,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.bottleneck = ConvModule(
|
||||
self.in_channels * 2,
|
||||
self.channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
x = self._transform_inputs(inputs)
|
||||
identity = x
|
||||
align_corners = self.align_corners
|
||||
if self.psa_type in ['collect', 'distribute']:
|
||||
out = self.reduce(x)
|
||||
n, c, h, w = out.size()
|
||||
if self.shrink_factor != 1:
|
||||
if h % self.shrink_factor and w % self.shrink_factor:
|
||||
h = (h - 1) // self.shrink_factor + 1
|
||||
w = (w - 1) // self.shrink_factor + 1
|
||||
align_corners = True
|
||||
else:
|
||||
h = h // self.shrink_factor
|
||||
w = w // self.shrink_factor
|
||||
align_corners = False
|
||||
out = resize(
|
||||
out,
|
||||
size=(h, w),
|
||||
mode='bilinear',
|
||||
align_corners=align_corners)
|
||||
y = self.attention(out)
|
||||
if self.compact:
|
||||
if self.psa_type == 'collect':
|
||||
y = y.view(n, h * w,
|
||||
h * w).transpose(1, 2).view(n, h * w, h, w)
|
||||
else:
|
||||
y = self.psamask(y)
|
||||
if self.psa_softmax:
|
||||
y = F.softmax(y, dim=1)
|
||||
out = torch.bmm(
|
||||
out.view(n, c, h * w), y.view(n, h * w, h * w)).view(
|
||||
n, c, h, w) * (1.0 / self.normalization_factor)
|
||||
else:
|
||||
x_col = self.reduce(x)
|
||||
x_dis = self.reduce_p(x)
|
||||
n, c, h, w = x_col.size()
|
||||
if self.shrink_factor != 1:
|
||||
if h % self.shrink_factor and w % self.shrink_factor:
|
||||
h = (h - 1) // self.shrink_factor + 1
|
||||
w = (w - 1) // self.shrink_factor + 1
|
||||
align_corners = True
|
||||
else:
|
||||
h = h // self.shrink_factor
|
||||
w = w // self.shrink_factor
|
||||
align_corners = False
|
||||
x_col = resize(
|
||||
x_col,
|
||||
size=(h, w),
|
||||
mode='bilinear',
|
||||
align_corners=align_corners)
|
||||
x_dis = resize(
|
||||
x_dis,
|
||||
size=(h, w),
|
||||
mode='bilinear',
|
||||
align_corners=align_corners)
|
||||
y_col = self.attention(x_col)
|
||||
y_dis = self.attention_p(x_dis)
|
||||
if self.compact:
|
||||
y_dis = y_dis.view(n, h * w,
|
||||
h * w).transpose(1, 2).view(n, h * w, h, w)
|
||||
else:
|
||||
y_col = self.psamask_collect(y_col)
|
||||
y_dis = self.psamask_distribute(y_dis)
|
||||
if self.psa_softmax:
|
||||
y_col = F.softmax(y_col, dim=1)
|
||||
y_dis = F.softmax(y_dis, dim=1)
|
||||
x_col = torch.bmm(
|
||||
x_col.view(n, c, h * w), y_col.view(n, h * w, h * w)).view(
|
||||
n, c, h, w) * (1.0 / self.normalization_factor)
|
||||
x_dis = torch.bmm(
|
||||
x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w)).view(
|
||||
n, c, h, w) * (1.0 / self.normalization_factor)
|
||||
out = torch.cat([x_col, x_dis], 1)
|
||||
out = self.proj(out)
|
||||
out = resize(
|
||||
out,
|
||||
size=identity.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=align_corners)
|
||||
out = self.bottleneck(torch.cat((identity, out), dim=1))
|
||||
out = self.cls_seg(out)
|
||||
return out
|
||||
Reference in New Issue
Block a user