init
This commit is contained in:
102
finetune/mmseg/models/utils/up_conv_block.py
Normal file
102
finetune/mmseg/models/utils/up_conv_block.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, build_upsample_layer
|
||||
|
||||
|
||||
class UpConvBlock(nn.Module):
|
||||
"""Upsample convolution block in decoder for UNet.
|
||||
|
||||
This upsample convolution block consists of one upsample module
|
||||
followed by one convolution block. The upsample module expands the
|
||||
high-level low-resolution feature map and the convolution block fuses
|
||||
the upsampled high-level low-resolution feature map and the low-level
|
||||
high-resolution feature map from encoder.
|
||||
|
||||
Args:
|
||||
conv_block (nn.Sequential): Sequential of convolutional layers.
|
||||
in_channels (int): Number of input channels of the high-level
|
||||
skip_channels (int): Number of input channels of the low-level
|
||||
high-resolution feature map from encoder.
|
||||
out_channels (int): Number of output channels.
|
||||
num_convs (int): Number of convolutional layers in the conv_block.
|
||||
Default: 2.
|
||||
stride (int): Stride of convolutional layer in conv_block. Default: 1.
|
||||
dilation (int): Dilation rate of convolutional layer in conv_block.
|
||||
Default: 1.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
conv_cfg (dict | None): Config dict for convolution layer.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict | None): Config dict for activation layer in ConvModule.
|
||||
Default: dict(type='ReLU').
|
||||
upsample_cfg (dict): The upsample config of the upsample module in
|
||||
decoder. Default: dict(type='InterpConv'). If the size of
|
||||
high-level feature map is the same as that of skip feature map
|
||||
(low-level feature map from encoder), it does not need upsample the
|
||||
high-level feature map and the upsample_cfg is None.
|
||||
dcn (bool): Use deformable convolution in convolutional layer or not.
|
||||
Default: None.
|
||||
plugins (dict): plugins for convolutional layers. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
conv_block,
|
||||
in_channels,
|
||||
skip_channels,
|
||||
out_channels,
|
||||
num_convs=2,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
with_cp=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
upsample_cfg=dict(type='InterpConv'),
|
||||
dcn=None,
|
||||
plugins=None):
|
||||
super().__init__()
|
||||
assert dcn is None, 'Not implemented yet.'
|
||||
assert plugins is None, 'Not implemented yet.'
|
||||
|
||||
self.conv_block = conv_block(
|
||||
in_channels=2 * skip_channels,
|
||||
out_channels=out_channels,
|
||||
num_convs=num_convs,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
with_cp=with_cp,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
dcn=None,
|
||||
plugins=None)
|
||||
if upsample_cfg is not None:
|
||||
self.upsample = build_upsample_layer(
|
||||
cfg=upsample_cfg,
|
||||
in_channels=in_channels,
|
||||
out_channels=skip_channels,
|
||||
with_cp=with_cp,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
else:
|
||||
self.upsample = ConvModule(
|
||||
in_channels,
|
||||
skip_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
def forward(self, skip, x):
|
||||
"""Forward function."""
|
||||
|
||||
x = self.upsample(x)
|
||||
out = torch.cat([skip, x], dim=1)
|
||||
out = self.conv_block(out)
|
||||
|
||||
return out
|
||||
Reference in New Issue
Block a user