init
This commit is contained in:
253
finetune/mmseg/models/decode_heads/vpd_depth_head.py
Normal file
253
finetune/mmseg/models/decode_heads/vpd_depth_head.py
Normal file
@@ -0,0 +1,253 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, List, Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import build_conv_layer, build_norm_layer, build_upsample_layer
|
||||
from mmengine.model import BaseModule
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import SampleList
|
||||
from ..utils import resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class VPDDepthDecoder(BaseModule):
|
||||
"""VPD Depth Decoder class.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
num_deconv_layers (int): Number of deconvolution layers.
|
||||
num_deconv_filters (List[int]): List of output channels for
|
||||
deconvolution layers.
|
||||
init_cfg (Optional[Union[Dict, List[Dict]]], optional): Configuration
|
||||
for weight initialization. Defaults to Normal for Conv2d and
|
||||
ConvTranspose2d layers.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_deconv_layers: int,
|
||||
num_deconv_filters: List[int],
|
||||
init_cfg: Optional[Union[Dict, List[Dict]]] = dict(
|
||||
type='Normal',
|
||||
std=0.001,
|
||||
layer=['Conv2d', 'ConvTranspose2d'])):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.deconv_layers = self._make_deconv_layer(
|
||||
num_deconv_layers,
|
||||
num_deconv_filters,
|
||||
)
|
||||
|
||||
conv_layers = []
|
||||
conv_layers.append(
|
||||
build_conv_layer(
|
||||
dict(type='Conv2d'),
|
||||
in_channels=num_deconv_filters[-1],
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1))
|
||||
conv_layers.append(build_norm_layer(dict(type='BN'), out_channels)[1])
|
||||
conv_layers.append(nn.ReLU(inplace=True))
|
||||
self.conv_layers = nn.Sequential(*conv_layers)
|
||||
|
||||
self.up_sample = nn.Upsample(
|
||||
scale_factor=2, mode='bilinear', align_corners=False)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass through the decoder network."""
|
||||
out = self.deconv_layers(x)
|
||||
out = self.conv_layers(out)
|
||||
|
||||
out = self.up_sample(out)
|
||||
out = self.up_sample(out)
|
||||
|
||||
return out
|
||||
|
||||
def _make_deconv_layer(self, num_layers, num_deconv_filters):
|
||||
"""Make deconv layers."""
|
||||
|
||||
layers = []
|
||||
in_channels = self.in_channels
|
||||
for i in range(num_layers):
|
||||
|
||||
num_channels = num_deconv_filters[i]
|
||||
layers.append(
|
||||
build_upsample_layer(
|
||||
dict(type='deconv'),
|
||||
in_channels=in_channels,
|
||||
out_channels=num_channels,
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
padding=0,
|
||||
output_padding=0,
|
||||
bias=False))
|
||||
layers.append(nn.BatchNorm2d(num_channels))
|
||||
layers.append(nn.ReLU(inplace=True))
|
||||
in_channels = num_channels
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class VPDDepthHead(BaseDecodeHead):
|
||||
"""Depth Prediction Head for VPD.
|
||||
|
||||
.. _`VPD`: https://arxiv.org/abs/2303.02153
|
||||
|
||||
Args:
|
||||
max_depth (float): Maximum depth value. Defaults to 10.0.
|
||||
in_channels (Sequence[int]): Number of input channels for each
|
||||
convolutional layer.
|
||||
embed_dim (int): Dimension of embedding. Defaults to 192.
|
||||
feature_dim (int): Dimension of aggregated feature. Defaults to 1536.
|
||||
num_deconv_layers (int): Number of deconvolution layers in the
|
||||
decoder. Defaults to 3.
|
||||
num_deconv_filters (Sequence[int]): Number of filters for each deconv
|
||||
layer. Defaults to (32, 32, 32).
|
||||
fmap_border (Union[int, Sequence[int]]): Feature map border for
|
||||
cropping. Defaults to 0.
|
||||
align_corners (bool): Flag for align_corners in interpolation.
|
||||
Defaults to False.
|
||||
loss_decode (dict): Configurations for the loss function. Defaults to
|
||||
dict(type='SiLogLoss').
|
||||
init_cfg (dict): Initialization configurations. Defaults to
|
||||
dict(type='TruncNormal', std=0.02, layer=['Conv2d', 'Linear']).
|
||||
"""
|
||||
|
||||
num_classes = 1
|
||||
out_channels = 1
|
||||
input_transform = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_depth: float = 10.0,
|
||||
in_channels: Sequence[int] = [320, 640, 1280, 1280],
|
||||
embed_dim: int = 192,
|
||||
feature_dim: int = 1536,
|
||||
num_deconv_layers: int = 3,
|
||||
num_deconv_filters: Sequence[int] = (32, 32, 32),
|
||||
fmap_border: Union[int, Sequence[int]] = 0,
|
||||
align_corners: bool = False,
|
||||
loss_decode: dict = dict(type='SiLogLoss'),
|
||||
init_cfg=dict(
|
||||
type='TruncNormal', std=0.02, layer=['Conv2d', 'Linear']),
|
||||
):
|
||||
|
||||
super(BaseDecodeHead, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
# initialize parameters
|
||||
self.in_channels = in_channels
|
||||
self.max_depth = max_depth
|
||||
self.align_corners = align_corners
|
||||
|
||||
# feature map border
|
||||
if isinstance(fmap_border, int):
|
||||
fmap_border = (fmap_border, fmap_border)
|
||||
self.fmap_border = fmap_border
|
||||
|
||||
# define network layers
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv2d(in_channels[0], in_channels[0], 3, stride=2, padding=1),
|
||||
nn.GroupNorm(16, in_channels[0]),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels[0], in_channels[0], 3, stride=2, padding=1),
|
||||
)
|
||||
self.conv2 = nn.Conv2d(
|
||||
in_channels[1], in_channels[1], 3, stride=2, padding=1)
|
||||
|
||||
self.conv_aggregation = nn.Sequential(
|
||||
nn.Conv2d(sum(in_channels), feature_dim, 1),
|
||||
nn.GroupNorm(16, feature_dim),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
self.decoder = VPDDepthDecoder(
|
||||
in_channels=embed_dim * 8,
|
||||
out_channels=embed_dim,
|
||||
num_deconv_layers=num_deconv_layers,
|
||||
num_deconv_filters=num_deconv_filters)
|
||||
|
||||
self.depth_pred_layer = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
embed_dim, embed_dim, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(embed_dim, 1, kernel_size=3, stride=1, padding=1))
|
||||
|
||||
# build loss
|
||||
if isinstance(loss_decode, dict):
|
||||
self.loss_decode = MODELS.build(loss_decode)
|
||||
elif isinstance(loss_decode, (list, tuple)):
|
||||
self.loss_decode = nn.ModuleList()
|
||||
for loss in loss_decode:
|
||||
self.loss_decode.append(MODELS.build(loss))
|
||||
else:
|
||||
raise TypeError(f'loss_decode must be a dict or sequence of dict,\
|
||||
but got {type(loss_decode)}')
|
||||
|
||||
def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor:
|
||||
gt_depth_maps = [
|
||||
data_sample.gt_depth_map.data for data_sample in batch_data_samples
|
||||
]
|
||||
return torch.stack(gt_depth_maps, dim=0)
|
||||
|
||||
def forward(self, x):
|
||||
x = [
|
||||
x[0], x[1],
|
||||
torch.cat([x[2], F.interpolate(x[3], scale_factor=2)], dim=1)
|
||||
]
|
||||
x = torch.cat([self.conv1(x[0]), self.conv2(x[1]), x[2]], dim=1)
|
||||
x = self.conv_aggregation(x)
|
||||
|
||||
x = x[:, :, :x.size(2) - self.fmap_border[0], :x.size(3) -
|
||||
self.fmap_border[1]].contiguous()
|
||||
x = self.decoder(x)
|
||||
out = self.depth_pred_layer(x)
|
||||
|
||||
depth = torch.sigmoid(out) * self.max_depth
|
||||
|
||||
return depth
|
||||
|
||||
def loss_by_feat(self, pred_depth_map: Tensor,
|
||||
batch_data_samples: SampleList) -> dict:
|
||||
"""Compute depth estimation loss.
|
||||
|
||||
Args:
|
||||
pred_depth_map (Tensor): The output from decode head forward
|
||||
function.
|
||||
batch_data_samples (List[:obj:`SegDataSample`]): The seg
|
||||
data samples. It usually includes information such
|
||||
as `metainfo` and `gt_dpeth_map`.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: a dictionary of loss components
|
||||
"""
|
||||
|
||||
gt_depth_map = self._stack_batch_gt(batch_data_samples)
|
||||
loss = dict()
|
||||
pred_depth_map = resize(
|
||||
input=pred_depth_map,
|
||||
size=gt_depth_map.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
|
||||
if not isinstance(self.loss_decode, nn.ModuleList):
|
||||
losses_decode = [self.loss_decode]
|
||||
else:
|
||||
losses_decode = self.loss_decode
|
||||
for loss_decode in losses_decode:
|
||||
if loss_decode.loss_name not in loss:
|
||||
loss[loss_decode.loss_name] = loss_decode(
|
||||
pred_depth_map, gt_depth_map)
|
||||
else:
|
||||
loss[loss_decode.loss_name] += loss_decode(
|
||||
pred_depth_map, gt_depth_map)
|
||||
|
||||
return loss
|
||||
Reference in New Issue
Block a user