init

finetune/mmseg/models/decode_heads/ann_head.py (new file, 245 lines added)

@@ -0,0 +1,245 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.registry import MODELS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .decode_head import BaseDecodeHead

class PPMConcat(nn.ModuleList):
    """Pyramid Pooling Module that only concatenates the features of each
    layer.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
    """

    def __init__(self, pool_scales=(1, 3, 6, 8)):
        super().__init__(
            [nn.AdaptiveAvgPool2d(pool_scale) for pool_scale in pool_scales])

    def forward(self, feats):
        """Forward function."""
        ppm_outs = []
        for ppm in self:
            ppm_out = ppm(feats)
            ppm_outs.append(ppm_out.view(*feats.shape[:2], -1))
        concat_outs = torch.cat(ppm_outs, dim=2)
        return concat_outs

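# Shape sketch (illustrative comment, not part of the upstream module): with
# the default pool_scales=(1, 3, 6, 8), an (N, C, H, W) input is adaptively
# pooled to 1x1, 3x3, 6x6 and 8x8 maps, each flattened to (N, C, S * S) and
# concatenated along dim=2, i.e. (N, C, 1 + 9 + 36 + 64) = (N, C, 110):
#
#     feats = torch.rand(2, 64, 32, 32)
#     flat = PPMConcat((1, 3, 6, 8))(feats)
#     assert flat.shape == (2, 64, 110)
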
class SelfAttentionBlock(_SelfAttentionBlock):
    """Self-attention block used in ANN.

    Args:
        low_in_channels (int): Input channels of lower level feature,
            which is the key feature for self-attention.
        high_in_channels (int): Input channels of higher level feature,
            which is the query feature for self-attention.
        channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
        share_key_query (bool): Whether to share projection weights between
            the key and query projections.
        query_scale (int): The scale of query feature map.
        key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module of key feature.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict|None): Config of activation layers.
    """

    def __init__(self, low_in_channels, high_in_channels, channels,
                 out_channels, share_key_query, query_scale, key_pool_scales,
                 conv_cfg, norm_cfg, act_cfg):
        key_psp = PPMConcat(key_pool_scales)
        if query_scale > 1:
            query_downsample = nn.MaxPool2d(kernel_size=query_scale)
        else:
            query_downsample = None
        super().__init__(
            key_in_channels=low_in_channels,
            query_in_channels=high_in_channels,
            channels=channels,
            out_channels=out_channels,
            share_key_query=share_key_query,
            query_downsample=query_downsample,
            key_downsample=key_psp,
            key_query_num_convs=1,
            key_query_norm=True,
            value_out_num_convs=1,
            value_out_norm=False,
            matmul_norm=True,
            with_out=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

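# Note on the asymmetry (illustrative comment): the query path is only
# downsampled with MaxPool2d when query_scale > 1, while the key/value path
# is always reduced by PPMConcat, so attention is computed against just
# sum(s * s for s in key_pool_scales) sampled positions (110 with the default
# (1, 3, 6, 8)) rather than the full H * W key map.
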
class AFNB(nn.Module):
    """Asymmetric Fusion Non-local Block (AFNB).

    Args:
        low_in_channels (int): Input channels of lower level feature,
            which is the key feature for self-attention.
        high_in_channels (int): Input channels of higher level feature,
            which is the query feature for self-attention.
        channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
        query_scales (tuple[int]): The scales of query feature map.
            Default: (1,)
        key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module of key feature.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict|None): Config of activation layers.
    """

    def __init__(self, low_in_channels, high_in_channels, channels,
                 out_channels, query_scales, key_pool_scales, conv_cfg,
                 norm_cfg, act_cfg):
        super().__init__()
        self.stages = nn.ModuleList()
        for query_scale in query_scales:
            self.stages.append(
                SelfAttentionBlock(
                    low_in_channels=low_in_channels,
                    high_in_channels=high_in_channels,
                    channels=channels,
                    out_channels=out_channels,
                    share_key_query=False,
                    query_scale=query_scale,
                    key_pool_scales=key_pool_scales,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
        self.bottleneck = ConvModule(
            out_channels + high_in_channels,
            out_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

    def forward(self, low_feats, high_feats):
        """Forward function."""
        priors = [stage(high_feats, low_feats) for stage in self.stages]
        context = torch.stack(priors, dim=0).sum(dim=0)
        output = self.bottleneck(torch.cat([context, high_feats], 1))
        return output

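# Fusion sketch (illustrative comment, shapes are assumptions): each stage
# attends from high_feats (query) to the pooled low_feats (key/value), the
# per-scale contexts are summed, and cat([context, high_feats]) is fused by
# the 1x1 bottleneck back to out_channels. E.g. with low_feats of shape
# (N, 1024, H, W), high_feats of shape (N, 2048, H, W) and out_channels=2048,
# the output keeps shape (N, 2048, H, W).
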
class APNB(nn.Module):
    """Asymmetric Pyramid Non-local Block (APNB).

    Args:
        in_channels (int): Input channels of key/query feature,
            which is the key feature for self-attention.
        channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
        query_scales (tuple[int]): The scales of query feature map.
            Default: (1,)
        key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module of key feature.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict|None): Config of activation layers.
    """

    def __init__(self, in_channels, channels, out_channels, query_scales,
                 key_pool_scales, conv_cfg, norm_cfg, act_cfg):
        super().__init__()
        self.stages = nn.ModuleList()
        for query_scale in query_scales:
            self.stages.append(
                SelfAttentionBlock(
                    low_in_channels=in_channels,
                    high_in_channels=in_channels,
                    channels=channels,
                    out_channels=out_channels,
                    share_key_query=True,
                    query_scale=query_scale,
                    key_pool_scales=key_pool_scales,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
        self.bottleneck = ConvModule(
            2 * in_channels,
            out_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    def forward(self, feats):
        """Forward function."""
        priors = [stage(feats, feats) for stage in self.stages]
        context = torch.stack(priors, dim=0).sum(dim=0)
        output = self.bottleneck(torch.cat([context, feats], 1))
        return output

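# Usage note (illustrative comment): APNB attends a feature map onto itself
# with shared key/query projections, then fuses cat([context, feats]) through
# a 1x1 ConvModule declared with 2 * in_channels inputs. This implicitly
# assumes out_channels == in_channels; ANNHead satisfies it by passing
# self.channels for both.
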
@MODELS.register_module()
class ANNHead(BaseDecodeHead):
    """Asymmetric Non-local Neural Networks for Semantic Segmentation.

    This head is the implementation of `ANNNet
    <https://arxiv.org/abs/1908.07678>`_.

    Args:
        project_channels (int): Projection channels for the non-local blocks.
        query_scales (tuple[int]): The scales of query feature map.
            Default: (1,)
        key_pool_scales (tuple[int]): The pooling scales of key feature map.
            Default: (1, 3, 6, 8).
    """

    def __init__(self,
                 project_channels,
                 query_scales=(1, ),
                 key_pool_scales=(1, 3, 6, 8),
                 **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)
        assert len(self.in_channels) == 2
        low_in_channels, high_in_channels = self.in_channels
        self.project_channels = project_channels
        self.fusion = AFNB(
            low_in_channels=low_in_channels,
            high_in_channels=high_in_channels,
            out_channels=high_in_channels,
            channels=project_channels,
            query_scales=query_scales,
            key_pool_scales=key_pool_scales,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.bottleneck = ConvModule(
            high_in_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.context = APNB(
            in_channels=self.channels,
            out_channels=self.channels,
            channels=project_channels,
            query_scales=query_scales,
            key_pool_scales=key_pool_scales,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        low_feats, high_feats = self._transform_inputs(inputs)
        output = self.fusion(low_feats, high_feats)
        output = self.dropout(output)
        output = self.bottleneck(output)
        output = self.context(output)
        output = self.cls_seg(output)

        return output
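# Hedged config sketch (illustrative only, loosely following the upstream
# ann_r50-d8 configs; exact values in this repository may differ):
#
#     decode_head=dict(
#         type='ANNHead',
#         in_channels=[1024, 2048],
#         in_index=[2, 3],
#         channels=512,
#         project_channels=256,
#         query_scales=(1, ),
#         key_pool_scales=(1, 3, 6, 8),
#         dropout_ratio=0.1,
#         num_classes=19,
#         norm_cfg=dict(type='SyncBN', requires_grad=True),
#         align_corners=False,
#         loss_decode=dict(
#             type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))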