init
This commit is contained in:
48
finetune/mmseg/models/decode_heads/gc_head.py
Normal file
48
finetune/mmseg/models/decode_heads/gc_head.py
Normal file
@@ -0,0 +1,48 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from mmcv.cnn import ContextBlock
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .fcn_head import FCNHead
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class GCHead(FCNHead):
|
||||
"""GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond.
|
||||
|
||||
This head is the implementation of `GCNet
|
||||
<https://arxiv.org/abs/1904.11492>`_.
|
||||
|
||||
Args:
|
||||
ratio (float): Multiplier of channels ratio. Default: 1/4.
|
||||
pooling_type (str): The pooling type of context aggregation.
|
||||
Options are 'att', 'avg'. Default: 'avg'.
|
||||
fusion_types (tuple[str]): The fusion type for feature fusion.
|
||||
Options are 'channel_add', 'channel_mul'. Default: ('channel_add',)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
ratio=1 / 4.,
|
||||
pooling_type='att',
|
||||
fusion_types=('channel_add', ),
|
||||
**kwargs):
|
||||
super().__init__(num_convs=2, **kwargs)
|
||||
self.ratio = ratio
|
||||
self.pooling_type = pooling_type
|
||||
self.fusion_types = fusion_types
|
||||
self.gc_block = ContextBlock(
|
||||
in_channels=self.channels,
|
||||
ratio=self.ratio,
|
||||
pooling_type=self.pooling_type,
|
||||
fusion_types=self.fusion_types)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
x = self._transform_inputs(inputs)
|
||||
output = self.convs[0](x)
|
||||
output = self.gc_block(output)
|
||||
output = self.convs[1](output)
|
||||
if self.concat_input:
|
||||
output = self.conv_cat(torch.cat([x, output], dim=1))
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
Reference in New Issue
Block a user