init
This commit is contained in:
128
finetune/mmseg/utils/misc.py
Normal file
128
finetune/mmseg/utils/misc.py
Normal file
@@ -0,0 +1,128 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .typing_utils import SampleList
|
||||
|
||||
|
||||
def add_prefix(inputs, prefix):
    """Prepend ``prefix`` (dot-separated) to every key of a dict.

    Args:
        inputs (dict): The input dict with str keys.
        prefix (str): The prefix to add.

    Returns:
        dict: A new dict whose keys are updated with ``prefix``.
    """
    # Build the renamed mapping in one pass; the input dict is not mutated.
    return {f'{prefix}.{key}': val for key, val in inputs.items()}
|
||||
|
||||
|
||||
def stack_batch(inputs: List[torch.Tensor],
                data_samples: Optional['SampleList'] = None,
                size: Optional[tuple] = None,
                size_divisor: Optional[int] = None,
                pad_val: Union[int, float] = 0,
                seg_pad_val: Union[int, float] = 255) -> tuple:
    """Stack multiple inputs to form a batch and pad the images and
    gt_sem_segs to the max shape using the right-bottom padding mode.

    Args:
        inputs (List[Tensor]): The input multiple tensors. Each is a
            CHW 3D-tensor.
        data_samples (list[:obj:`SegDataSample`], optional): The list of
            data samples. It usually includes information such as
            `gt_sem_seg`.
        size (tuple, optional): Fixed padding size.
        size_divisor (int, optional): The divisor of padded size.
        pad_val (int, float): The padding value for images. Defaults to 0.
        seg_pad_val (int, float): The padding value for the gt maps.
            Defaults to 255.

    Returns:
        tuple: A 2-tuple of
            - Tensor: The stacked and padded 4D-tensor (N, C, H, W).
            - List[:obj:`SegDataSample`] | List[dict]: The data samples with
              padded gt maps and updated meta info, or — when
              ``data_samples`` is None — a list of dicts recording each
              image's padding (``img_padding_size``, ``pad_shape``).
    """
    # NOTE: the return annotation was previously ``-> torch.Tensor`` although
    # the function returns (Tensor, list); fixed to ``-> tuple``.
    assert isinstance(inputs, list), \
        f'Expected input type to be list, but got {type(inputs)}'
    assert len({tensor.ndim for tensor in inputs}) == 1, \
        f'Expected the dimensions of all inputs must be the same, ' \
        f'but got {[tensor.ndim for tensor in inputs]}'
    assert inputs[0].ndim == 3, f'Expected tensor dimension to be 3, ' \
        f'but got {inputs[0].ndim}'
    assert len({tensor.shape[0] for tensor in inputs}) == 1, \
        f'Expected the channels of all inputs must be the same, ' \
        f'but got {[tensor.shape[0] for tensor in inputs]}'

    # only one of size and size_divisor should be valid
    assert (size is not None) ^ (size_divisor is not None), \
        'only one of size and size_divisor should be valid'

    padded_inputs = []
    padded_samples = []
    inputs_sizes = [(img.shape[-2], img.shape[-1]) for img in inputs]
    max_size = np.stack(inputs_sizes).max(0)
    if size_divisor is not None and size_divisor > 1:
        # the last two dims are H,W, both subject to divisibility requirement
        max_size = (max_size +
                    (size_divisor - 1)) // size_divisor * size_divisor

    for i in range(len(inputs)):
        tensor = inputs[i]
        if size is not None:
            # Pad up to the fixed target size; never crop (max(..., 0)).
            width = max(size[-1] - tensor.shape[-1], 0)
            height = max(size[-2] - tensor.shape[-2], 0)
            # (padding_left, padding_right, padding_top, padding_bottom)
            padding_size = (0, width, 0, height)
        elif size_divisor is not None:
            # Pad up to the divisor-aligned max size of the batch.
            width = max(max_size[-1] - tensor.shape[-1], 0)
            height = max(max_size[-2] - tensor.shape[-2], 0)
            padding_size = (0, width, 0, height)
        else:
            # Unreachable given the XOR assert above; kept defensively.
            padding_size = (0, 0, 0, 0)

        # pad img (right-bottom padding only)
        pad_img = F.pad(tensor, padding_size, value=pad_val)
        padded_inputs.append(pad_img)
        # pad gt maps with seg_pad_val (ignore-index style value)
        if data_samples is not None:
            data_sample = data_samples[i]
            pad_shape = None
            if 'gt_sem_seg' in data_sample:
                gt_sem_seg = data_sample.gt_sem_seg.data
                # del before re-assign to bypass the field's shape check
                del data_sample.gt_sem_seg.data
                data_sample.gt_sem_seg.data = F.pad(
                    gt_sem_seg, padding_size, value=seg_pad_val)
                pad_shape = data_sample.gt_sem_seg.shape
            if 'gt_edge_map' in data_sample:
                gt_edge_map = data_sample.gt_edge_map.data
                del data_sample.gt_edge_map.data
                data_sample.gt_edge_map.data = F.pad(
                    gt_edge_map, padding_size, value=seg_pad_val)
                pad_shape = data_sample.gt_edge_map.shape
            if 'gt_depth_map' in data_sample:
                gt_depth_map = data_sample.gt_depth_map.data
                del data_sample.gt_depth_map.data
                data_sample.gt_depth_map.data = F.pad(
                    gt_depth_map, padding_size, value=seg_pad_val)
                pad_shape = data_sample.gt_depth_map.shape
            data_sample.set_metainfo({
                'img_shape': tensor.shape[-2:],
                'pad_shape': pad_shape,
                'padding_size': padding_size
            })
            padded_samples.append(data_sample)
        else:
            # No samples given: still report how each image was padded.
            padded_samples.append(
                dict(
                    img_padding_size=padding_size,
                    pad_shape=pad_img.shape[-2:]))

    return torch.stack(padded_inputs, dim=0), padded_samples
|
||||
Reference in New Issue
Block a user