init
This commit is contained in:
184
finetune/mmseg/models/decode_heads/da_head.py
Normal file
184
finetune/mmseg/models/decode_heads/da_head.py
Normal file
@@ -0,0 +1,184 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule, Scale
|
||||
from torch import Tensor, nn
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import SampleList, add_prefix
|
||||
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class PAM(_SelfAttentionBlock):
|
||||
"""Position Attention Module (PAM)
|
||||
|
||||
Args:
|
||||
in_channels (int): Input channels of key/query feature.
|
||||
channels (int): Output channels of key/query transform.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, channels):
|
||||
super().__init__(
|
||||
key_in_channels=in_channels,
|
||||
query_in_channels=in_channels,
|
||||
channels=channels,
|
||||
out_channels=in_channels,
|
||||
share_key_query=False,
|
||||
query_downsample=None,
|
||||
key_downsample=None,
|
||||
key_query_num_convs=1,
|
||||
key_query_norm=False,
|
||||
value_out_num_convs=1,
|
||||
value_out_norm=False,
|
||||
matmul_norm=False,
|
||||
with_out=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=None,
|
||||
act_cfg=None)
|
||||
|
||||
self.gamma = Scale(0)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
out = super().forward(x, x)
|
||||
|
||||
out = self.gamma(out) + x
|
||||
return out
|
||||
|
||||
|
||||
class CAM(nn.Module):
|
||||
"""Channel Attention Module (CAM)"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.gamma = Scale(0)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
batch_size, channels, height, width = x.size()
|
||||
proj_query = x.view(batch_size, channels, -1)
|
||||
proj_key = x.view(batch_size, channels, -1).permute(0, 2, 1)
|
||||
energy = torch.bmm(proj_query, proj_key)
|
||||
energy_new = torch.max(
|
||||
energy, -1, keepdim=True)[0].expand_as(energy) - energy
|
||||
attention = F.softmax(energy_new, dim=-1)
|
||||
proj_value = x.view(batch_size, channels, -1)
|
||||
|
||||
out = torch.bmm(attention, proj_value)
|
||||
out = out.view(batch_size, channels, height, width)
|
||||
|
||||
out = self.gamma(out) + x
|
||||
return out
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DAHead(BaseDecodeHead):
|
||||
"""Dual Attention Network for Scene Segmentation.
|
||||
|
||||
This head is the implementation of `DANet
|
||||
<https://arxiv.org/abs/1809.02983>`_.
|
||||
|
||||
Args:
|
||||
pam_channels (int): The channels of Position Attention Module(PAM).
|
||||
"""
|
||||
|
||||
def __init__(self, pam_channels, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.pam_channels = pam_channels
|
||||
self.pam_in_conv = ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.pam = PAM(self.channels, pam_channels)
|
||||
self.pam_out_conv = ConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.pam_conv_seg = nn.Conv2d(
|
||||
self.channels, self.num_classes, kernel_size=1)
|
||||
|
||||
self.cam_in_conv = ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.cam = CAM()
|
||||
self.cam_out_conv = ConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.cam_conv_seg = nn.Conv2d(
|
||||
self.channels, self.num_classes, kernel_size=1)
|
||||
|
||||
def pam_cls_seg(self, feat):
|
||||
"""PAM feature classification."""
|
||||
if self.dropout is not None:
|
||||
feat = self.dropout(feat)
|
||||
output = self.pam_conv_seg(feat)
|
||||
return output
|
||||
|
||||
def cam_cls_seg(self, feat):
|
||||
"""CAM feature classification."""
|
||||
if self.dropout is not None:
|
||||
feat = self.dropout(feat)
|
||||
output = self.cam_conv_seg(feat)
|
||||
return output
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
x = self._transform_inputs(inputs)
|
||||
pam_feat = self.pam_in_conv(x)
|
||||
pam_feat = self.pam(pam_feat)
|
||||
pam_feat = self.pam_out_conv(pam_feat)
|
||||
pam_out = self.pam_cls_seg(pam_feat)
|
||||
|
||||
cam_feat = self.cam_in_conv(x)
|
||||
cam_feat = self.cam(cam_feat)
|
||||
cam_feat = self.cam_out_conv(cam_feat)
|
||||
cam_out = self.cam_cls_seg(cam_feat)
|
||||
|
||||
feat_sum = pam_feat + cam_feat
|
||||
pam_cam_out = self.cls_seg(feat_sum)
|
||||
|
||||
return pam_cam_out, pam_out, cam_out
|
||||
|
||||
def predict(self, inputs, batch_img_metas: List[dict], test_cfg,
|
||||
**kwargs) -> List[Tensor]:
|
||||
"""Forward function for testing, only ``pam_cam`` is used."""
|
||||
seg_logits = self.forward(inputs)[0]
|
||||
return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs)
|
||||
|
||||
def loss_by_feat(self, seg_logit: Tuple[Tensor],
|
||||
batch_data_samples: SampleList, **kwargs) -> dict:
|
||||
"""Compute ``pam_cam``, ``pam``, ``cam`` loss."""
|
||||
pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit
|
||||
loss = dict()
|
||||
loss.update(
|
||||
add_prefix(
|
||||
super().loss_by_feat(pam_cam_seg_logit, batch_data_samples),
|
||||
'pam_cam'))
|
||||
loss.update(
|
||||
add_prefix(super().loss_by_feat(pam_seg_logit, batch_data_samples),
|
||||
'pam'))
|
||||
loss.update(
|
||||
add_prefix(super().loss_by_feat(cam_seg_logit, batch_data_samples),
|
||||
'cam'))
|
||||
return loss
|
||||
Reference in New Issue
Block a user