init
This commit is contained in:
116
finetune/mmseg/models/decode_heads/ddr_head.py
Normal file
116
finetune/mmseg/models/decode_heads/ddr_head.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
from mmseg.models.losses import accuracy
|
||||
from mmseg.models.utils import resize
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import OptConfigType, SampleList
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DDRHead(BaseDecodeHead):
|
||||
"""Decode head for DDRNet.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
channels (int): Number of output channels.
|
||||
num_classes (int): Number of classes.
|
||||
norm_cfg (dict, optional): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict, optional): Config dict for activation layer.
|
||||
Default: dict(type='ReLU', inplace=True).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
channels: int,
|
||||
num_classes: int,
|
||||
norm_cfg: OptConfigType = dict(type='BN'),
|
||||
act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
|
||||
**kwargs):
|
||||
super().__init__(
|
||||
in_channels,
|
||||
channels,
|
||||
num_classes=num_classes,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
**kwargs)
|
||||
|
||||
self.head = self._make_base_head(self.in_channels, self.channels)
|
||||
self.aux_head = self._make_base_head(self.in_channels // 2,
|
||||
self.channels)
|
||||
self.aux_cls_seg = nn.Conv2d(
|
||||
self.channels, self.out_channels, kernel_size=1)
|
||||
|
||||
def init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: Union[Tensor,
|
||||
Tuple[Tensor]]) -> Union[Tensor, Tuple[Tensor]]:
|
||||
if self.training:
|
||||
c3_feat, c5_feat = inputs
|
||||
x_c = self.head(c5_feat)
|
||||
x_c = self.cls_seg(x_c)
|
||||
x_s = self.aux_head(c3_feat)
|
||||
x_s = self.aux_cls_seg(x_s)
|
||||
|
||||
return x_c, x_s
|
||||
else:
|
||||
x_c = self.head(inputs)
|
||||
x_c = self.cls_seg(x_c)
|
||||
return x_c
|
||||
|
||||
def _make_base_head(self, in_channels: int,
|
||||
channels: int) -> nn.Sequential:
|
||||
layers = [
|
||||
ConvModule(
|
||||
in_channels,
|
||||
channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
order=('norm', 'act', 'conv')),
|
||||
build_norm_layer(self.norm_cfg, channels)[1],
|
||||
build_activation_layer(self.act_cfg),
|
||||
]
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def loss_by_feat(self, seg_logits: Tuple[Tensor],
|
||||
batch_data_samples: SampleList) -> dict:
|
||||
loss = dict()
|
||||
context_logit, spatial_logit = seg_logits
|
||||
seg_label = self._stack_batch_gt(batch_data_samples)
|
||||
|
||||
context_logit = resize(
|
||||
context_logit,
|
||||
size=seg_label.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
spatial_logit = resize(
|
||||
spatial_logit,
|
||||
size=seg_label.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
seg_label = seg_label.squeeze(1)
|
||||
|
||||
loss['loss_context'] = self.loss_decode[0](context_logit, seg_label)
|
||||
loss['loss_spatial'] = self.loss_decode[1](spatial_logit, seg_label)
|
||||
loss['acc_seg'] = accuracy(
|
||||
context_logit, seg_label, ignore_index=self.ignore_index)
|
||||
|
||||
return loss
|
||||
Reference in New Issue
Block a user