init
This commit is contained in:
68
finetune/mmseg/models/decode_heads/fpn_head.py
Normal file
68
finetune/mmseg/models/decode_heads/fpn_head.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import Upsample, resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class FPNHead(BaseDecodeHead):
|
||||
"""Panoptic Feature Pyramid Networks.
|
||||
|
||||
This head is the implementation of `Semantic FPN
|
||||
<https://arxiv.org/abs/1901.02446>`_.
|
||||
|
||||
Args:
|
||||
feature_strides (tuple[int]): The strides for input feature maps.
|
||||
stack_lateral. All strides suppose to be power of 2. The first
|
||||
one is of largest resolution.
|
||||
"""
|
||||
|
||||
def __init__(self, feature_strides, **kwargs):
|
||||
super().__init__(input_transform='multiple_select', **kwargs)
|
||||
assert len(feature_strides) == len(self.in_channels)
|
||||
assert min(feature_strides) == feature_strides[0]
|
||||
self.feature_strides = feature_strides
|
||||
|
||||
self.scale_heads = nn.ModuleList()
|
||||
for i in range(len(feature_strides)):
|
||||
head_length = max(
|
||||
1,
|
||||
int(np.log2(feature_strides[i]) - np.log2(feature_strides[0])))
|
||||
scale_head = []
|
||||
for k in range(head_length):
|
||||
scale_head.append(
|
||||
ConvModule(
|
||||
self.in_channels[i] if k == 0 else self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg))
|
||||
if feature_strides[i] != feature_strides[0]:
|
||||
scale_head.append(
|
||||
Upsample(
|
||||
scale_factor=2,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners))
|
||||
self.scale_heads.append(nn.Sequential(*scale_head))
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
x = self._transform_inputs(inputs)
|
||||
|
||||
output = self.scale_heads[0](x[0])
|
||||
for i in range(1, len(self.feature_strides)):
|
||||
# non inplace
|
||||
output = output + resize(
|
||||
self.scale_heads[i](x[i]),
|
||||
size=output.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
Reference in New Issue
Block a user