init
This commit is contained in:
67
finetune/mmseg/models/necks/featurepyramid.py
Normal file
67
finetune/mmseg/models/necks/featurepyramid.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import build_norm_layer
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class Feature2Pyramid(nn.Module):
|
||||
"""Feature2Pyramid.
|
||||
|
||||
A neck structure connect ViT backbone and decoder_heads.
|
||||
|
||||
Args:
|
||||
embed_dims (int): Embedding dimension.
|
||||
rescales (list[float]): Different sampling multiples were
|
||||
used to obtain pyramid features. Default: [4, 2, 1, 0.5].
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='SyncBN', requires_grad=True).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dim,
|
||||
rescales=[4, 2, 1, 0.5],
|
||||
norm_cfg=dict(type='SyncBN', requires_grad=True)):
|
||||
super().__init__()
|
||||
self.rescales = rescales
|
||||
self.upsample_4x = None
|
||||
for k in self.rescales:
|
||||
if k == 4:
|
||||
self.upsample_4x = nn.Sequential(
|
||||
nn.ConvTranspose2d(
|
||||
embed_dim, embed_dim, kernel_size=2, stride=2),
|
||||
build_norm_layer(norm_cfg, embed_dim)[1],
|
||||
nn.GELU(),
|
||||
nn.ConvTranspose2d(
|
||||
embed_dim, embed_dim, kernel_size=2, stride=2),
|
||||
)
|
||||
elif k == 2:
|
||||
self.upsample_2x = nn.Sequential(
|
||||
nn.ConvTranspose2d(
|
||||
embed_dim, embed_dim, kernel_size=2, stride=2))
|
||||
elif k == 1:
|
||||
self.identity = nn.Identity()
|
||||
elif k == 0.5:
|
||||
self.downsample_2x = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
elif k == 0.25:
|
||||
self.downsample_4x = nn.MaxPool2d(kernel_size=4, stride=4)
|
||||
else:
|
||||
raise KeyError(f'invalid {k} for feature2pyramid')
|
||||
|
||||
def forward(self, inputs):
|
||||
assert len(inputs) == len(self.rescales)
|
||||
outputs = []
|
||||
if self.upsample_4x is not None:
|
||||
ops = [
|
||||
self.upsample_4x, self.upsample_2x, self.identity,
|
||||
self.downsample_2x
|
||||
]
|
||||
else:
|
||||
ops = [
|
||||
self.upsample_2x, self.identity, self.downsample_2x,
|
||||
self.downsample_4x
|
||||
]
|
||||
for i in range(len(inputs)):
|
||||
outputs.append(ops[i](inputs[i]))
|
||||
return tuple(outputs)
|
||||
Reference in New Issue
Block a user