init
This commit is contained in:
43
finetune/mmseg/models/decode_heads/cc_head.py
Normal file
43
finetune/mmseg/models/decode_heads/cc_head.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .fcn_head import FCNHead
|
||||
|
||||
try:
|
||||
from mmcv.ops import CrissCrossAttention
|
||||
except ModuleNotFoundError:
|
||||
CrissCrossAttention = None
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class CCHead(FCNHead):
|
||||
"""CCNet: Criss-Cross Attention for Semantic Segmentation.
|
||||
|
||||
This head is the implementation of `CCNet
|
||||
<https://arxiv.org/abs/1811.11721>`_.
|
||||
|
||||
Args:
|
||||
recurrence (int): Number of recurrence of Criss Cross Attention
|
||||
module. Default: 2.
|
||||
"""
|
||||
|
||||
def __init__(self, recurrence=2, **kwargs):
|
||||
if CrissCrossAttention is None:
|
||||
raise RuntimeError('Please install mmcv-full for '
|
||||
'CrissCrossAttention ops')
|
||||
super().__init__(num_convs=2, **kwargs)
|
||||
self.recurrence = recurrence
|
||||
self.cca = CrissCrossAttention(self.channels)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
x = self._transform_inputs(inputs)
|
||||
output = self.convs[0](x)
|
||||
for _ in range(self.recurrence):
|
||||
output = self.cca(output)
|
||||
output = self.convs[1](output)
|
||||
if self.concat_input:
|
||||
output = self.conv_cat(torch.cat([x, output], dim=1))
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
Reference in New Issue
Block a user