from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import math

import mmcv
import numpy as np
import torch
import torchvision
import torchvision.transforms as T
import torchvision.transforms.functional as F
from mmcv.utils import deprecated_api_warning
from numpy import random
from packaging.version import Version
from PIL import Image


class Compose(object):
    """Compose multiple transforms sequentially."""

    def __init__(self, transforms):
        assert isinstance(transforms, (list, tuple))
        self.transforms = transforms

    def __call__(self, sample):
        """Call function to apply transforms sequentially.

        Args:
            sample (Sample): A result dict containing the data to transform.

        Returns:
            dict: Transformed data, or None if any transform returns None.
        """
        for t in self.transforms:
            sample = t(sample)
            if sample is None:
                return None
        return sample

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += f'    {t}'
        format_string += '\n)'
        return format_string
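

# Usage sketch (illustrative, not part of the pipeline): composing a typical
# segmentation preprocessing chain. Assumes ``sample`` is this project's
# dict-like Sample with attribute access and an ``img_field`` entry.
#
#     pipeline = Compose([
#         SegResize(img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
#         SegRandomFlip(prob=0.5),
#         Normalize(mean=[123.675, 116.28, 103.53],
#                   std=[58.395, 57.12, 57.375]),
#     ])
#     sample = pipeline(sample)  # returns None if any transform returns None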


class SegResize(object):
    """Resize images & seg.

    This transform resizes the input image to some scale. If the input dict
    contains the key "scale", then the scale in the input dict is used,
    otherwise the specified scale in the init method is used.

    ``img_scale`` can be None, a tuple (single-scale) or a list of tuples
    (multi-scale). There are 4 multiscale modes:

    - ``ratio_range is not None``:
        1. When img_scale is None, img_scale is the shape of the image in
           sample (img_scale = sample.img.shape[:2]) and the image is
           resized based on the original size. (mode 1)
        2. When img_scale is a tuple (single-scale), randomly sample a ratio
           from the ratio range and multiply it with the image scale.
           (mode 2)

    - ``ratio_range is None and multiscale_mode == "range"``: randomly
      sample a scale from the given range. (mode 3)

    - ``ratio_range is None and multiscale_mode == "value"``: randomly
      sample a scale from multiple scales. (mode 4)

    Args:
        img_scale (tuple or list[tuple]): Image scales for resizing.
            Default: None.
        multiscale_mode (str): Either "range" or "value".
            Default: 'range'
        ratio_range (tuple[float]): (min_ratio, max_ratio).
            Default: None
        keep_ratio (bool): Whether to keep the aspect ratio when resizing
            the image. Default: True
    """

    def __init__(self,
                 img_scale=None,
                 multiscale_mode='range',
                 ratio_range=None,
                 keep_ratio=True):
        if img_scale is None:
            self.img_scale = None
        else:
            if isinstance(img_scale, list):
                self.img_scale = img_scale
            else:
                self.img_scale = [img_scale]
            assert mmcv.is_list_of(self.img_scale, tuple)

        if ratio_range is not None:
            # mode 1: given img_scale=None and a range of image ratios
            # mode 2: given a scale and a range of image ratios
            assert self.img_scale is None or len(self.img_scale) == 1
        else:
            # mode 3 and 4: given multiple scales or a range of scales
            assert multiscale_mode in ['value', 'range']

        self.multiscale_mode = multiscale_mode
        self.ratio_range = ratio_range
        self.keep_ratio = keep_ratio

    @staticmethod
    def random_select(img_scales):
        """Randomly select an img_scale from given candidates.

        Args:
            img_scales (list[tuple]): Image scales for selection.

        Returns:
            (tuple, int): Returns a tuple ``(img_scale, scale_idx)``,
                where ``img_scale`` is the selected image scale and
                ``scale_idx`` is the selected index in the given candidates.
        """
        assert mmcv.is_list_of(img_scales, tuple)
        scale_idx = np.random.randint(len(img_scales))
        img_scale = img_scales[scale_idx]
        return img_scale, scale_idx

    @staticmethod
    def random_sample(img_scales):
        """Randomly sample an img_scale when ``multiscale_mode=='range'``.

        Args:
            img_scales (list[tuple]): Image scale range for sampling.
                There must be two tuples in img_scales, which specify the
                lower and upper bound of image scales.

        Returns:
            (tuple, None): Returns a tuple ``(img_scale, None)``, where
                ``img_scale`` is the sampled scale and None is just a
                placeholder to be consistent with :func:`random_select`.
        """
        assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2
        img_scale_long = [max(s) for s in img_scales]
        img_scale_short = [min(s) for s in img_scales]
        long_edge = np.random.randint(min(img_scale_long),
                                      max(img_scale_long) + 1)
        short_edge = np.random.randint(min(img_scale_short),
                                       max(img_scale_short) + 1)
        img_scale = (long_edge, short_edge)
        return img_scale, None

    @staticmethod
    def random_sample_ratio(img_scale, ratio_range):
        """Randomly sample an img_scale when ``ratio_range`` is specified.

        A ratio will be randomly sampled from the range specified by
        ``ratio_range``. Then it would be multiplied with ``img_scale`` to
        generate the sampled scale.

        Args:
            img_scale (tuple): Image scale base to multiply with the ratio.
            ratio_range (tuple[float]): The minimum and maximum ratio to
                scale the ``img_scale``.

        Returns:
            (tuple, None): Returns a tuple ``(scale, None)``, where
                ``scale`` is the sampled ratio multiplied with ``img_scale``
                and None is just a placeholder to be consistent with
                :func:`random_select`.
        """
        assert isinstance(img_scale, tuple) and len(img_scale) == 2
        min_ratio, max_ratio = ratio_range
        assert min_ratio <= max_ratio
        ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
        scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)
        return scale, None

    def _random_scale(self, sample):
        """Randomly sample an img_scale according to ``ratio_range`` and
        ``multiscale_mode``.

        If ``ratio_range`` is specified, a ratio will be sampled and
        multiplied with ``img_scale``.
        If multiple scales are specified by ``img_scale``, a scale will be
        sampled according to ``multiscale_mode``.
        Otherwise, a single scale will be used.

        Args:
            sample (Sample): Sample data from :obj:`dataset`.

        Returns:
            dict: Two new keys ``scale`` and ``scale_idx`` are added into
                ``sample``, which would be used by subsequent pipelines.
        """
        if self.ratio_range is not None:
            if self.img_scale is None:
                h, w = sample.img.shape[:2]
                scale, scale_idx = self.random_sample_ratio((w, h),
                                                            self.ratio_range)
            else:
                scale, scale_idx = self.random_sample_ratio(
                    self.img_scale[0], self.ratio_range)
        elif len(self.img_scale) == 1:
            scale, scale_idx = self.img_scale[0], 0
        elif self.multiscale_mode == 'range':
            scale, scale_idx = self.random_sample(self.img_scale)
        elif self.multiscale_mode == 'value':
            scale, scale_idx = self.random_select(self.img_scale)
        else:
            raise NotImplementedError

        sample.scale = scale
        sample.scale_idx = scale_idx

    def _resize_img(self, sample):
        """Resize images with ``sample.scale``."""
        if self.keep_ratio:
            img, scale_factor = mmcv.imrescale(sample[sample.img_field],
                                               sample.scale,
                                               return_scale=True)
            # w_scale and h_scale may differ slightly from scale_factor;
            # a real fix should be done in mmcv.imrescale in the future
            new_h, new_w = img.shape[:2]
            h, w = sample[sample.img_field].shape[:2]
            w_scale = new_w / w
            h_scale = new_h / h
        else:
            img, w_scale, h_scale = mmcv.imresize(sample[sample.img_field],
                                                  sample.scale,
                                                  return_scale=True)
        scale_factor = np.array([w_scale, h_scale, w_scale, h_scale],
                                dtype=np.float32)
        sample[sample.img_field] = img
        sample.img_shape = img.shape
        sample.pad_shape = img.shape  # in case that there is no padding
        sample.scale_factor = scale_factor
        sample.keep_ratio = self.keep_ratio

    def _resize_seg(self, sample):
        """Resize the semantic segmentation map with ``sample.scale``."""
        for key in sample.get('ann_fields', []):
            if self.keep_ratio:
                gt_seg = mmcv.imrescale(sample[key],
                                        sample.scale,
                                        interpolation='nearest')
            else:
                gt_seg = mmcv.imresize(sample[key],
                                       sample.scale,
                                       interpolation='nearest')
            sample[key] = gt_seg

    def __call__(self, sample):
        """Call function to resize images and semantic segmentation maps.

        Args:
            sample (Sample): Sample dict from the loading pipeline.

        Returns:
            dict: Resized sample; 'img_shape', 'pad_shape', 'scale_factor'
                and 'keep_ratio' keys are added into the result dict.
        """
        if 'scale' not in sample:
            self._random_scale(sample)
        self._resize_img(sample)
        self._resize_seg(sample)
        return sample

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(img_scale={self.img_scale}, '
                     f'multiscale_mode={self.multiscale_mode}, '
                     f'ratio_range={self.ratio_range}, '
                     f'keep_ratio={self.keep_ratio})')
        return repr_str
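

# Mode reference (illustrative): how the four multiscale modes map to
# constructor arguments.
#
#     SegResize(ratio_range=(0.5, 2.0))                # mode 1: scale comes
#         # from the original image shape times a random ratio
#     SegResize(img_scale=(2048, 512),
#               ratio_range=(0.5, 2.0))                # mode 2
#     SegResize(img_scale=[(2048, 512), (2048, 1024)],
#               multiscale_mode='range')               # mode 3
#     SegResize(img_scale=[(2048, 512), (2048, 1024)],
#               multiscale_mode='value')               # mode 4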


class SegRandomFlip(object):
    """Flip the image & seg.

    If the input dict contains the key "flip", then the flag will be used,
    otherwise it will be randomly decided by a probability specified in the
    init method.

    Args:
        prob (float, optional): The flipping probability. Default: None.
        direction (str, optional): The flipping direction. Options are
            'horizontal' and 'vertical'. Default: 'horizontal'.
    """

    @deprecated_api_warning({'flip_ratio': 'prob'}, cls_name='SegRandomFlip')
    def __init__(self, prob=None, direction='horizontal'):
        self.prob = prob
        self.direction = direction
        if prob is not None:
            assert 0 <= prob <= 1
        assert direction in ['horizontal', 'vertical']

    def __call__(self, sample):
        """Call function to flip the image and semantic segmentation maps.

        Args:
            sample (Sample): Sample data from the loading pipeline.

        Returns:
            dict: Flipped sample; 'flip' and 'flip_direction' keys are added
                into the result dict.
        """
        if 'flip' not in sample:
            flip = np.random.rand() < self.prob
            sample.flip = flip
        if 'flip_direction' not in sample:
            sample.flip_direction = self.direction
        if sample.flip:
            # flip image
            sample[sample.img_field] = mmcv.imflip(
                sample[sample.img_field], direction=sample.flip_direction)

            # flip segs
            for key in sample.get('ann_fields', []):
                # use copy() to make the numpy strides positive
                sample[key] = mmcv.imflip(
                    sample[key], direction=sample.flip_direction).copy()
        return sample

    def __repr__(self):
        return self.__class__.__name__ + f'(prob={self.prob})'


class Normalize(object):
    """Normalize the image.

    Added key is "img_norm_cfg".

    Args:
        mean (sequence): Mean values of the 3 channels.
        std (sequence): Std values of the 3 channels.
        to_rgb (bool): Whether to convert the image from BGR to RGB.
            Default: True.
    """

    def __init__(self, mean, std, to_rgb=True):
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)
        self.to_rgb = to_rgb

    def __call__(self, sample):
        """Call function to normalize images.

        Args:
            sample (Sample): Sample data from the loading pipeline.

        Returns:
            dict: Normalized sample; the 'img_norm_cfg' key is added into
                the sample.
        """
        sample[sample.img_field] = mmcv.imnormalize(sample[sample.img_field],
                                                    self.mean, self.std,
                                                    self.to_rgb)
        sample.img_norm_cfg = dict(mean=self.mean,
                                   std=self.std,
                                   to_rgb=self.to_rgb)
        return sample

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(mean={self.mean}, std={self.std}, '
                     f'to_rgb={self.to_rgb})')
        return repr_str
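

# Worked example (illustrative): Normalize applies (img - mean) / std per
# channel after an optional BGR->RGB conversion. With ImageNet statistics:
#
#     norm = Normalize(mean=[123.675, 116.28, 103.53],
#                      std=[58.395, 57.12, 57.375], to_rgb=True)
#     sample = norm(sample)
#     # a raw value of 124.0 in the first channel maps to
#     # (124.0 - 123.675) / 58.395 ≈ 0.0056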


class MSNormalize(object):
    """Normalize multiple image fields, each with its own config.

    ``configs`` maps a sample key to a config with ``div_value``, ``mean``
    and ``std``; the image is divided by ``div_value`` first and then
    normalized per channel.
    """

    def __init__(self, configs):
        self.configs = configs
        self.keys = configs.keys()

    def normalize_(self, img, config):
        # expects a channel-first (C, H, W) layout so that iterating yields
        # one channel per (mean, std) pair
        if isinstance(img, np.ndarray) and img.dtype != np.float32:
            img = img.astype(np.float32)
        if isinstance(img, torch.Tensor):
            img = img.float()
        div_value = config.div_value
        mean = config.mean
        std = config.std
        img /= div_value
        for t, m, s in zip(img, mean, std):
            t -= m
            t /= s
        return img

    def __call__(self, sample):
        for key in self.keys:
            if isinstance(sample[key], list):
                for i in range(len(sample[key])):
                    sample[key][i] = self.normalize_(sample[key][i],
                                                     self.configs[key])
            else:
                sample[key] = self.normalize_(sample[key], self.configs[key])
        return sample
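

# Config sketch (assumption: ``configs`` is an attribute-accessible mapping
# such as an mmcv Config / addict dict, keyed by sample field; the key name
# 'hr_img' is hypothetical). Images are divided by ``div_value`` first and
# then normalized per channel:
#
#     configs = Config(dict(
#         hr_img=dict(div_value=255.0,
#                     mean=[0.485, 0.456, 0.406],
#                     std=[0.229, 0.224, 0.225])))
#     sample = MSNormalize(configs)(sample)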


class MSRandomCrop(object):
    """Randomly crop the hr_img / s2_img / targets fields.

    The crop box is expressed as ratios of the image size, so the same
    relative region is cropped from every key regardless of resolution.

    Args:
        crop_size (tuple): Expected size ratio after cropping, (h, w).
        keys (list[str]): Sample keys to crop.
    """

    def __init__(self, crop_size, keys):
        assert crop_size[0] > 0 and crop_size[1] > 0
        self.crop_size = crop_size
        self.keys = keys

    def get_crop_bbox(self):
        """Randomly get a crop bounding box (in ratios)."""
        margin_h = max(1.0 - self.crop_size[0], 0)
        margin_w = max(1.0 - self.crop_size[1], 0)
        offset_h = np.random.uniform(0, margin_h)
        offset_w = np.random.uniform(0, margin_w)
        crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0]
        crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]

        return crop_y1, crop_y2, crop_x1, crop_x2

    def crop(self, img, crop_bbox):
        """Crop from ``img``, converting the ratio box to pixels."""
        crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
        h, w = img.shape[-2:]
        crop_y1, crop_y2, crop_x1, crop_x2 = int(crop_y1 * h), int(
            crop_y2 * h), int(crop_x1 * w), int(crop_x2 * w)
        img = img[..., crop_y1:crop_y2, crop_x1:crop_x2]
        return img

    def __call__(self, sample):
        """Call function to randomly crop images and segmentation maps.

        Args:
            sample (Sample): Sample data from the loading pipeline.

        Returns:
            dict: Randomly cropped sample; all listed keys are cropped with
                the same relative box.
        """
        crop_bbox = self.get_crop_bbox()
        for key in self.keys:
            if isinstance(sample[key], list):
                for i in range(len(sample[key])):
                    sample[key][i] = self.crop(sample[key][i], crop_bbox)
            else:
                sample[key] = self.crop(sample[key], crop_bbox)
        return sample

    def __repr__(self):
        return self.__class__.__name__ + f'(crop_size={self.crop_size})'
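

# Worked example: crop coordinates are kept as ratios and only converted to
# pixels per image, so images of different resolutions get the same relative
# crop. With crop_size=(0.5, 0.5) and sampled offsets (0.2, 0.3), the crop
# box is (y1, y2, x1, x2) = (0.2, 0.7, 0.3, 0.8); on a 1000x1000 image that
# becomes rows 200:700 and cols 300:800, on a 250x250 image rows 50:175 and
# cols 75:200.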


class MSRandomRangeCrop(object):
    """Randomly crop the hr_img / s2_img / targets fields with a square
    crop whose side ratio is sampled from a range.

    Args:
        crop_size (tuple): Range of the size ratio after cropping,
            (min, max).
        keys (list[str]): Sample keys to crop.
    """

    def __init__(self, crop_size, keys):
        assert crop_size[0] > 0 and crop_size[1] > 0
        self.crop_size = crop_size
        self.keys = keys

    def get_crop_bbox(self):
        """Randomly get a crop bounding box (in ratios)."""
        crop_size_ = np.random.uniform(self.crop_size[0], self.crop_size[1])
        margin_h = max(1.0 - crop_size_, 0)
        margin_w = max(1.0 - crop_size_, 0)
        offset_h = np.random.uniform(0, margin_h)
        offset_w = np.random.uniform(0, margin_w)
        crop_y1, crop_y2 = offset_h, offset_h + crop_size_
        crop_x1, crop_x2 = offset_w, offset_w + crop_size_

        return crop_y1, crop_y2, crop_x1, crop_x2

    def crop(self, img, crop_bbox):
        """Crop from ``img``, converting the ratio box to pixels."""
        crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
        h, w = img.shape[-2:]
        crop_y1, crop_y2, crop_x1, crop_x2 = int(crop_y1 * h), int(
            crop_y2 * h), int(crop_x1 * w), int(crop_x2 * w)
        img = img[..., crop_y1:crop_y2, crop_x1:crop_x2]
        return img

    def __call__(self, sample):
        """Call function to randomly crop images and segmentation maps.

        Args:
            sample (Sample): Sample data from the loading pipeline.

        Returns:
            dict: Randomly cropped sample; all listed keys are cropped with
                the same relative box.
        """
        crop_bbox = self.get_crop_bbox()
        for key in self.keys:
            if isinstance(sample[key], list):
                for i in range(len(sample[key])):
                    sample[key][i] = self.crop(sample[key][i], crop_bbox)
            else:
                sample[key] = self.crop(sample[key], crop_bbox)
        return sample

    def __repr__(self):
        return self.__class__.__name__ + f'(crop_size={self.crop_size})'


class MSResize(object):
    """Resize the listed sample keys to a fixed target size.

    The 'targets' key is resized with nearest-neighbour interpolation so
    label values are never blended; all other keys use torchvision's
    default (bilinear) interpolation.
    """

    def __init__(self, target_size, keys):
        assert target_size[0] > 0 and target_size[1] > 0
        self.target_size = target_size
        self.keys = keys

    def __call__(self, sample):
        for key in self.keys:
            if key == 'targets':
                sample[key] = F.resize(
                    sample[key],
                    self.target_size,
                    interpolation=T.InterpolationMode.NEAREST
                    if Version(torchvision.__version__) >= Version('0.9.0')
                    else Image.NEAREST)
            else:
                sample[key] = F.resize(sample[key], self.target_size)
        return sample

    def __repr__(self):
        return self.__class__.__name__ + f'(target_size={self.target_size})'
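

# Usage sketch (illustrative; the key names are hypothetical):
#
#     resize = MSResize(target_size=(256, 256), keys=['hr_img', 'targets'])
#     sample = resize(sample)
#     # 'targets' keeps its integer class ids (NEAREST); image keys are
#     # interpolated bilinearly.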


class MSSSLRandomResizedCrop(object):
    """Random resized crop for multi-source self-supervised views.

    One crop box is sampled per view in normalized coordinates and then
    applied, at the appropriate resolution, to every source image, its
    distance map and the land-cover ('lc') map of that view.
    """

    def __init__(self, configs, global_crops_number, local_crops_number):
        self.configs = configs
        self.global_crops_number = global_crops_number
        self.local_crops_number = local_crops_number

    @staticmethod
    def get_params(scale: tuple, ratio: tuple):
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            scale (tuple): Range of the area of the crop, relative to the
                original area.
            ratio (tuple): Range of the aspect ratio of the crop.

        Returns:
            tuple: Params (i, j, h, w), in normalized coordinates, to be
                passed to ``crop`` for a random sized crop.
        """
        origin_h, origin_w = 1.0, 1.0
        area = 1.0

        # retry until the sampled crop fits inside the unit square
        while True:
            target_area = random.uniform(*scale) * area
            log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
            aspect_ratio = math.exp(random.uniform(*log_ratio))

            w = math.sqrt(target_area * aspect_ratio)
            h = math.sqrt(target_area / aspect_ratio)

            if w <= origin_w and h <= origin_h:
                i = random.uniform(0, origin_h - h)
                j = random.uniform(0, origin_w - w)
                return i, j, h, w

    def __call__(self, sample):
        for scope_view in self.configs.keys():
            for index in range(getattr(self, f'{scope_view}_crops_number')):
                i, j, h, w = self.get_params(self.configs[scope_view].scale,
                                             self.configs[scope_view].ratio)
                for source in self.configs[scope_view]['size'].keys():
                    # crop the source image (bicubic)
                    img_key = f'{scope_view}_{source}_img'
                    img = sample[img_key][index]
                    i_img, h_img = int(round(i * img.shape[-2])), int(
                        round(h * img.shape[-2]))
                    j_img, w_img = int(round(j * img.shape[-1])), int(
                        round(w * img.shape[-1]))
                    img = F.resized_crop(
                        img,
                        i_img,
                        j_img,
                        h_img,
                        w_img,
                        self.configs[scope_view]['size'][source],
                        interpolation=T.InterpolationMode.BICUBIC
                        if Version(torchvision.__version__) >= Version('0.9.0')
                        else Image.BICUBIC)
                    sample[img_key][index] = img
                    # crop the matching distance map with the same box
                    img_key = f'{scope_view}_{source}_distance'
                    img = sample[img_key][index]
                    img = F.resized_crop(
                        img,
                        i_img,
                        j_img,
                        h_img,
                        w_img,
                        self.configs[scope_view]['size'][source],
                        interpolation=T.InterpolationMode.BICUBIC
                        if Version(torchvision.__version__) >= Version('0.9.0')
                        else Image.BICUBIC)
                    sample[img_key][index] = img
                    # crop the land-cover map (nearest, to preserve labels;
                    # it shares the s2 output resolution)
                    img_key = f'{scope_view}_lc'
                    img = sample[img_key][index]
                    i_img, h_img = int(round(i * img.shape[-2])), int(
                        round(h * img.shape[-2]))
                    j_img, w_img = int(round(j * img.shape[-1])), int(
                        round(w * img.shape[-1]))
                    img = F.resized_crop(
                        img,
                        i_img,
                        j_img,
                        h_img,
                        w_img,
                        self.configs[scope_view]['size']['s2'],
                        interpolation=T.InterpolationMode.NEAREST
                        if Version(torchvision.__version__) >= Version('0.9.0')
                        else Image.NEAREST)
                    sample[img_key][index] = img
        return sample
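

# Worked example: get_params operates in normalized coordinates, so one crop
# box can be shared across sources of different resolutions. With
# scale=(0.4, 1.0) and ratio=(0.75, 1.333), a draw of target_area=0.64 and
# aspect_ratio=1.0 yields w = h = sqrt(0.64) = 0.8 and offsets sampled in
# [0, 0.2]. The same (i, j, h, w) is then mapped to pixels per source, e.g.
# i=0.1 becomes row 51 on a 512-pixel image and row 13 on a 128-pixel one.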


class MSSSLRandomFlip(object):
    """Random horizontal/vertical flips, applied consistently to every
    source image, distance map and land-cover map of a view."""

    def __init__(self, configs, global_crops_number, local_crops_number,
                 scope_views, sources):
        self.configs = configs
        self.global_crops_number = global_crops_number
        self.local_crops_number = local_crops_number
        self.scope_views = scope_views
        self.sources = sources

    def __call__(self, sample):
        for scope_view in self.scope_views:
            for index in range(getattr(self, f'{scope_view}_crops_number')):
                # draw the flips once per crop so all sources stay aligned
                hflip = False
                vflip = False
                for direction, prob in zip(self.configs['directions'],
                                           self.configs['probs']):
                    p = torch.rand(1)
                    if direction == 'horizontal' and p < prob:
                        hflip = True
                    if direction == 'vertical' and p < prob:
                        vflip = True
                for source in self.sources:
                    img_key = f'{scope_view}_{source}_img'
                    img = sample[img_key][index]
                    if hflip:
                        img = F.hflip(img)
                    if vflip:
                        img = F.vflip(img)
                    sample[img_key][index] = img
                    img_key = f'{scope_view}_{source}_distance'
                    img = sample[img_key][index]
                    if hflip:
                        img = F.hflip(img)
                    if vflip:
                        img = F.vflip(img)
                    sample[img_key][index] = img
                    img_key = f'{scope_view}_lc'
                    img = sample[img_key][index]
                    if hflip:
                        img = F.hflip(img)
                    if vflip:
                        img = F.vflip(img)
                    sample[img_key][index] = img
        return sample


class MSSSLRandomRotate(object):
    """Random 90/180/270-degree rotations, applied consistently to every
    source image, distance map and land-cover map of a view."""

    def __init__(self, configs, global_crops_number, local_crops_number,
                 scope_views, sources):
        self.configs = configs
        self.global_crops_number = global_crops_number
        self.local_crops_number = local_crops_number
        self.scope_views = scope_views
        self.sources = sources
        self.angle_set = [90, 180, 270]

    def __call__(self, sample):
        for scope_view in self.scope_views:
            for index in range(getattr(self, f'{scope_view}_crops_number')):
                p = torch.rand(1)
                if p > self.configs['probs']:
                    continue
                # one angle per crop so all sources stay aligned
                angle = self.angle_set[torch.randint(0, 3, (1,)).item()]
                for source in self.sources:
                    img_key = f'{scope_view}_{source}_img'
                    img = sample[img_key][index]
                    img = F.rotate(
                        img,
                        angle,
                        interpolation=T.InterpolationMode.BILINEAR
                        if Version(torchvision.__version__) >= Version('0.9.0')
                        else Image.BILINEAR)
                    sample[img_key][index] = img
                    img_key = f'{scope_view}_{source}_distance'
                    img = sample[img_key][index]
                    img = F.rotate(
                        img,
                        angle,
                        interpolation=T.InterpolationMode.BILINEAR
                        if Version(torchvision.__version__) >= Version('0.9.0')
                        else Image.BILINEAR)
                    sample[img_key][index] = img
                    img_key = f'{scope_view}_lc'
                    img = sample[img_key][index]
                    img = F.rotate(
                        img,
                        angle,
                        interpolation=T.InterpolationMode.NEAREST
                        if Version(torchvision.__version__) >= Version('0.9.0')
                        else Image.NEAREST)
                    sample[img_key][index] = img
        return sample


class MSSSLRandomColorJitter(object):
    """Random color jitter followed by a random grayscale conversion, per
    crop and per source."""

    def __init__(self, configs, global_crops_number, local_crops_number,
                 scope_views, sources):
        self.configs = configs
        self.global_crops_number = global_crops_number
        self.local_crops_number = local_crops_number
        self.scope_views = scope_views
        self.sources = sources
        self.color_prob = configs['color']['probs']
        self.brightness = configs['color']['brightness']
        self.contrast = configs['color']['contrast']
        self.saturation = configs['color']['saturation']
        self.hue = configs['color']['hue']
        self.gray_prob = configs['gray']['probs']

    def __call__(self, sample):
        for scope_view in self.scope_views:
            for index in range(getattr(self, f'{scope_view}_crops_number')):
                for source in self.sources:
                    p = torch.rand(1)
                    if p >= self.color_prob:
                        continue
                    brightness_factor = random.uniform(
                        max(0, 1 - self.brightness), 1 + self.brightness)
                    contrast_factor = random.uniform(
                        max(0, 1 - self.contrast), 1 + self.contrast)
                    saturation_factor = random.uniform(
                        max(0, 1 - self.saturation), 1 + self.saturation)
                    hue_factor = random.uniform(-self.hue, self.hue)
                    img_key = f'{scope_view}_{source}_img'
                    img = sample[img_key][index]
                    img = F.adjust_brightness(img, brightness_factor)
                    img = F.adjust_contrast(img, contrast_factor)
                    img = F.adjust_saturation(img, saturation_factor)
                    img = F.adjust_hue(img, hue_factor)
                    # write back here so the jitter survives even when the
                    # grayscale step below is skipped (previously the result
                    # was only stored after grayscale, losing the jitter)
                    sample[img_key][index] = img
                    p = torch.rand(1)
                    if p >= self.gray_prob:
                        continue
                    num_output_channels, _, _ = img.shape
                    img = F.rgb_to_grayscale(
                        img, num_output_channels=num_output_channels)
                    sample[img_key][index] = img
        return sample
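

# Worked example: with brightness=0.4 the brightness factor is drawn from
# U[0.6, 1.4] (clamped below at 0), and with hue=0.1 the hue factor from
# U[-0.1, 0.1]. The color jitter and the subsequent grayscale conversion
# are gated by independent Bernoulli draws (color_prob, gray_prob).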


class MSSSLRandomGaussianBlur(object):
    """Random Gaussian blur per crop and per source, with a per-view (and,
    for global views, per-crop) probability."""

    def __init__(self, configs, global_crops_number, local_crops_number,
                 scope_views, sources):
        self.configs = configs
        self.global_crops_number = global_crops_number
        self.local_crops_number = local_crops_number
        self.scope_views = scope_views
        self.sources = sources
        self.prob = configs['probs']
        self.sigma = configs['sigma']

    def __call__(self, sample):
        for scope_view in self.scope_views:
            for index in range(getattr(self, f'{scope_view}_crops_number')):
                for source in self.sources:
                    p = self.prob[scope_view]
                    if scope_view == 'global':
                        # global views carry one probability per crop
                        p = p[index]
                    if torch.rand(1) >= p:
                        continue
                    sigma = random.uniform(self.sigma[0], self.sigma[1])
                    # invert OpenCV's sigma heuristic to get a kernel size
                    # matching the sampled sigma, forced to be odd
                    kernel_size = max(int(2 * ((sigma - 0.8) / 0.3 + 1) + 1),
                                      1)
                    if kernel_size % 2 == 0:
                        kernel_size += 1
                    img_key = f'{scope_view}_{source}_img'
                    img = sample[img_key][index]
                    img = F.gaussian_blur(img, kernel_size, sigma)
                    sample[img_key][index] = img
        return sample
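

# Worked example: the kernel size inverts OpenCV's default rule
# sigma = 0.3 * ((ksize - 1) * 0.5 - 1) + 0.8, i.e.
# ksize = 2 * ((sigma - 0.8) / 0.3 + 1) + 1, rounded up to the next odd
# integer. For sigma = 1.1 this gives int(2 * 2 + 1) = 5; for sigma = 2.0,
# int(2 * 5 + 1) = 11.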


class MSSSLRandomSolarize(object):
    """Random solarization, applied only to the second crop (index 1) of
    each view.

    Only global views are supported, since ``local_crops_number`` is not
    stored.
    """

    def __init__(self, configs, global_crops_number, scope_views, sources):
        self.configs = configs
        self.global_crops_number = global_crops_number
        self.scope_views = scope_views
        self.sources = sources
        self.prob = configs['probs']
        self.threshold = 130

    def __call__(self, sample):
        for scope_view in self.scope_views:
            for index in range(getattr(self, f'{scope_view}_crops_number')):
                for source in self.sources:
                    # only the second crop (index 1) may be solarized
                    if index != 1:
                        continue
                    if torch.rand(1) >= self.prob:
                        continue
                    img_key = f'{scope_view}_{source}_img'
                    img = sample[img_key][index]
                    img = F.solarize(img, self.threshold)
                    sample[img_key][index] = img
        return sample


class MSSSLRandomChannelOut(object):
    """Randomly drop a few channels of each source image, filling them with
    their dataset mean value."""

    def __init__(self, configs, global_crops_number, local_crops_number,
                 scope_views, sources, mean):
        self.configs = configs
        self.global_crops_number = global_crops_number
        self.local_crops_number = local_crops_number
        self.scope_views = scope_views
        self.sources = sources
        self.mean = mean

    def __call__(self, sample):
        for scope_view in self.scope_views:
            for index in range(getattr(self, f'{scope_view}_crops_number')):
                for source in self.sources:
                    out_num = self.configs[scope_view]['out_num']
                    # ``random`` is numpy.random, whose randint upper bound
                    # is exclusive, so ``+ 1`` makes the range inclusive
                    out_num = random.randint(out_num[0], out_num[1] + 1)
                    out_index = sorted(
                        random.choice(len(self.mean), out_num, replace=False))
                    img_key = f'{scope_view}_{source}_img'
                    img = sample[img_key][index]
                    for i in out_index:
                        # dropped channels are filled with their mean value
                        img[i] = int(self.mean[i])
                    sample[img_key][index] = img
        return sample


class MaskGenerator:
    """Generate a random patch mask for masked image modeling (e.g.
    SimMIM): mask patches are sampled on a coarse grid and then upsampled
    to the model's patch resolution."""

    def __init__(self,
                 input_size=192,
                 mask_patch_size=32,
                 model_patch_size=4,
                 mask_ratio=0.6):
        self.input_size = input_size
        self.mask_patch_size = mask_patch_size
        self.model_patch_size = model_patch_size
        self.mask_ratio = mask_ratio

        assert self.input_size % self.mask_patch_size == 0
        assert self.mask_patch_size % self.model_patch_size == 0

        self.rand_size = self.input_size // self.mask_patch_size
        self.scale = self.mask_patch_size // self.model_patch_size

        self.token_count = self.rand_size**2
        self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))

    def __call__(self):
        mask_idx = np.random.permutation(self.token_count)[:self.mask_count]
        mask = np.zeros(self.token_count, dtype=int)
        mask[mask_idx] = 1

        mask = mask.reshape((self.rand_size, self.rand_size))
        mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)

        return mask
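

# Worked example with the defaults: input_size=192 and mask_patch_size=32
# give a 6x6 grid of 36 mask tokens; mask_ratio=0.6 masks
# ceil(36 * 0.6) = 22 of them. Each mask patch spans 32 / 4 = 8 model
# patches per side, so the returned mask has shape (48, 48):
#
#     mask = MaskGenerator()()
#     assert mask.shape == (48, 48) and mask.sum() == 22 * 8 * 8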


class DetResize:
    """Resize images & bbox & mask.

    This transform resizes the input image to some scale. Bboxes and masks
    are then resized with the same scale factor. If the input dict contains
    the key "scale", then the scale in the input dict is used, otherwise
    the specified scale in the init method is used. If the input dict
    contains the key "scale_factor" (if MultiScaleFlipAug does not give
    img_scale but scale_factor), the actual scale will be computed from the
    image shape and scale_factor.

    ``img_scale`` can either be a tuple (single-scale) or a list of tuples
    (multi-scale). There are 3 multiscale modes:

    - ``ratio_range is not None``: randomly sample a ratio from the ratio
      range and multiply it with the image scale.
    - ``ratio_range is None`` and ``multiscale_mode == "range"``: randomly
      sample a scale from the multiscale range.
    - ``ratio_range is None`` and ``multiscale_mode == "value"``: randomly
      sample a scale from multiple scales.

    Args:
        img_scale (tuple or list[tuple]): Image scales for resizing.
        multiscale_mode (str): Either "range" or "value".
        ratio_range (tuple[float]): (min_ratio, max_ratio)
        keep_ratio (bool): Whether to keep the aspect ratio when resizing
            the image.
        bbox_clip_border (bool, optional): Whether to clip the objects
            outside the border of the image. In some datasets like MOT17,
            the gt bboxes are allowed to cross the border of images, so the
            gt bboxes need not be clipped in those cases. Defaults to True.
        backend (str): Image resize backend, choices are 'cv2' and
            'pillow'. The two backends generate slightly different results.
            Defaults to 'cv2'.
        interpolation (str): Interpolation method; accepted values are
            "nearest", "bilinear", "bicubic", "area", "lanczos" for the
            'cv2' backend and "nearest", "bilinear" for the 'pillow'
            backend.
        override (bool, optional): Whether to override `scale` and
            `scale_factor` so as to call resize twice. Default False. If
            True, after the first resizing, the existing `scale` and
            `scale_factor` will be ignored so a second resizing is allowed.
            This option is a work-around for multiple resizes in DETR.
            Defaults to False.
    """

    def __init__(self,
                 img_scale=None,
                 multiscale_mode='range',
                 ratio_range=None,
                 keep_ratio=True,
                 bbox_clip_border=True,
                 backend='cv2',
                 interpolation='bilinear',
                 override=False):
        if img_scale is None:
            self.img_scale = None
        else:
            if isinstance(img_scale, list):
                self.img_scale = img_scale
            else:
                self.img_scale = [img_scale]
            assert mmcv.is_list_of(self.img_scale, tuple)

        if ratio_range is not None:
            # mode 1: given a scale and a range of image ratios
            assert len(self.img_scale) == 1
        else:
            # mode 2: given multiple scales or a range of scales
            assert multiscale_mode in ['value', 'range']

        self.backend = backend
        self.multiscale_mode = multiscale_mode
        self.ratio_range = ratio_range
        self.keep_ratio = keep_ratio
        self.interpolation = interpolation
        self.override = override
        self.bbox_clip_border = bbox_clip_border

    @staticmethod
    def random_select(img_scales):
        """Randomly select an img_scale from given candidates.

        Args:
            img_scales (list[tuple]): Image scales for selection.

        Returns:
            (tuple, int): Returns a tuple ``(img_scale, scale_idx)``,
                where ``img_scale`` is the selected image scale and
                ``scale_idx`` is the selected index in the given candidates.
        """
        assert mmcv.is_list_of(img_scales, tuple)
        scale_idx = np.random.randint(len(img_scales))
        img_scale = img_scales[scale_idx]
        return img_scale, scale_idx

    @staticmethod
    def random_sample(img_scales):
        """Randomly sample an img_scale when ``multiscale_mode=='range'``.

        Args:
            img_scales (list[tuple]): Image scale range for sampling.
                There must be two tuples in img_scales, which specify the
                lower and upper bound of image scales.

        Returns:
            (tuple, None): Returns a tuple ``(img_scale, None)``, where
                ``img_scale`` is the sampled scale and None is just a
                placeholder to be consistent with :func:`random_select`.
        """
        assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2
        img_scale_long = [max(s) for s in img_scales]
        img_scale_short = [min(s) for s in img_scales]
        long_edge = np.random.randint(min(img_scale_long),
                                      max(img_scale_long) + 1)
        short_edge = np.random.randint(min(img_scale_short),
                                       max(img_scale_short) + 1)
        img_scale = (long_edge, short_edge)
        return img_scale, None

    @staticmethod
    def random_sample_ratio(img_scale, ratio_range):
        """Randomly sample an img_scale when ``ratio_range`` is specified.

        A ratio will be randomly sampled from the range specified by
        ``ratio_range``. Then it would be multiplied with ``img_scale`` to
        generate the sampled scale.

        Args:
            img_scale (tuple): Image scale base to multiply with the ratio.
            ratio_range (tuple[float]): The minimum and maximum ratio to
                scale the ``img_scale``.

        Returns:
            (tuple, None): Returns a tuple ``(scale, None)``, where
                ``scale`` is the sampled ratio multiplied with ``img_scale``
                and None is just a placeholder to be consistent with
                :func:`random_select`.
        """
        assert isinstance(img_scale, tuple) and len(img_scale) == 2
        min_ratio, max_ratio = ratio_range
        assert min_ratio <= max_ratio
        ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
        scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)
        return scale, None

    def _random_scale(self, sample):
        """Randomly sample an img_scale according to ``ratio_range`` and
        ``multiscale_mode``.

        If ``ratio_range`` is specified, a ratio will be sampled and
        multiplied with ``img_scale``.
        If multiple scales are specified by ``img_scale``, a scale will be
        sampled according to ``multiscale_mode``.
        Otherwise, a single scale will be used.

        Args:
            sample (Sample): Result from :obj:`dataset`.

        Returns:
            dict: Two new keys ``scale`` and ``scale_idx`` are added into
                ``sample``, which would be used by subsequent pipelines.
        """
        if self.ratio_range is not None:
            scale, scale_idx = self.random_sample_ratio(
                self.img_scale[0], self.ratio_range)
        elif len(self.img_scale) == 1:
            scale, scale_idx = self.img_scale[0], 0
        elif self.multiscale_mode == 'range':
            scale, scale_idx = self.random_sample(self.img_scale)
        elif self.multiscale_mode == 'value':
            scale, scale_idx = self.random_select(self.img_scale)
        else:
            raise NotImplementedError

        sample.scale = scale
        sample.scale_idx = scale_idx

    def _resize_img(self, sample):
        """Resize images with ``sample.scale``."""
        if self.keep_ratio:
            img, scale_factor = mmcv.imrescale(
                sample[sample.img_field],
                sample.scale,
                return_scale=True,
                interpolation=self.interpolation,
                backend=self.backend)
            # w_scale and h_scale may differ slightly from scale_factor;
            # a real fix should be done in mmcv.imrescale in the future
            new_h, new_w = img.shape[:2]
            h, w = sample[sample.img_field].shape[:2]
            w_scale = new_w / w
            h_scale = new_h / h
        else:
            img, w_scale, h_scale = mmcv.imresize(
                sample[sample.img_field],
                sample.scale,
                return_scale=True,
                interpolation=self.interpolation,
                backend=self.backend)
        sample[sample.img_field] = img

        scale_factor = np.array([w_scale, h_scale, w_scale, h_scale],
                                dtype=np.float32)
        sample.img_shape = img.shape
        # in case that there is no padding
        sample.pad_shape = img.shape
        sample.scale_factor = scale_factor
        sample.keep_ratio = self.keep_ratio

    def _resize_bboxes(self, sample):
        """Resize bounding boxes with ``sample.scale_factor``."""
        for key in sample.get('bbox_fields', []):
            bboxes = sample[key] * sample.scale_factor
            if self.bbox_clip_border:
                img_shape = sample.img_shape
                bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])
                bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])
            sample[key] = bboxes

    def _resize_masks(self, sample):
        """Resize masks with ``sample.scale``."""
        for key in sample.get('mask_fields', []):
            if sample[key] is None:
                continue
            if self.keep_ratio:
                sample[key] = sample[key].rescale(sample.scale)
            else:
                sample[key] = sample[key].resize(sample.img_shape[:2])

    def _resize_seg(self, sample):
        """Resize the semantic segmentation map with ``sample.scale``."""
        for key in sample.get('seg_fields', []):
            if self.keep_ratio:
                gt_seg = mmcv.imrescale(sample[key],
                                        sample.scale,
                                        interpolation='nearest',
                                        backend=self.backend)
            else:
                gt_seg = mmcv.imresize(sample[key],
                                       sample.scale,
                                       interpolation='nearest',
                                       backend=self.backend)
            sample[key] = gt_seg

    def __call__(self, sample):
        """Call function to resize images, bounding boxes, masks and the
        semantic segmentation map.

        Args:
            sample (dict): Result dict from the loading pipeline.

        Returns:
            dict: Resized sample; 'img_shape', 'pad_shape', 'scale_factor'
                and 'keep_ratio' keys are added into the result dict.
        """
        sample.scale_idx = None
        if 'scale' not in sample:
            if 'scale_factor' in sample:
                img_shape = sample.hr_img.shape[:2]
                scale_factor = sample.scale_factor
                assert isinstance(scale_factor, float)
                sample.scale = tuple(
                    [int(x * scale_factor) for x in img_shape][::-1])
            else:
                self._random_scale(sample)
        else:
            if not self.override:
                assert 'scale_factor' not in sample, (
                    'scale and scale_factor cannot be both set.')
            else:
                sample.pop('scale')
                if 'scale_factor' in sample:
                    sample.pop('scale_factor')
                self._random_scale(sample)

        self._resize_img(sample)
        self._resize_bboxes(sample)
        self._resize_masks(sample)
        self._resize_seg(sample)
        return sample

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(img_scale={self.img_scale}, '
        repr_str += f'multiscale_mode={self.multiscale_mode}, '
        repr_str += f'ratio_range={self.ratio_range}, '
        repr_str += f'keep_ratio={self.keep_ratio}, '
        repr_str += f'bbox_clip_border={self.bbox_clip_border})'
        return repr_str
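

# Usage sketch (illustrative): a standard 1333x800 detection resize that
# keeps the aspect ratio and jitters the scale by a random ratio:
#
#     resize = DetResize(img_scale=[(1333, 800)],
#                        ratio_range=(0.8, 1.2),
#                        keep_ratio=True)
#     sample = resize(sample)  # adds scale, img_shape, scale_factor, ...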


class DetRandomFlip:
    """Flip the image & bbox & mask.

    If the input dict contains the key "flip", then the flag will be used,
    otherwise it will be randomly decided by a ratio specified in the init
    method.

    When random flip is enabled, ``flip_ratio``/``direction`` can either be
    a float/string or a list of float/string. There are 3 flip modes:

    - ``flip_ratio`` is float, ``direction`` is string: the image will be
      ``direction``ly flipped with probability of ``flip_ratio``.
      E.g., ``flip_ratio=0.5``, ``direction='horizontal'``,
      then the image will be horizontally flipped with probability 0.5.
    - ``flip_ratio`` is float, ``direction`` is a list of strings: the
      image will be ``direction[i]``ly flipped with probability of
      ``flip_ratio/len(direction)``.
      E.g., ``flip_ratio=0.5``, ``direction=['horizontal', 'vertical']``,
      then the image will be horizontally flipped with probability 0.25,
      vertically with probability 0.25.
    - ``flip_ratio`` is a list of floats, ``direction`` is a list of
      strings: given ``len(flip_ratio) == len(direction)``, the image will
      be ``direction[i]``ly flipped with probability of ``flip_ratio[i]``.
      E.g., ``flip_ratio=[0.3, 0.5]``, ``direction=['horizontal',
      'vertical']``, then the image will be horizontally flipped with
      probability 0.3, vertically with probability 0.5.

    Args:
        flip_ratio (float | list[float], optional): The flipping
            probability. Default: None.
        direction (str | list[str], optional): The flipping direction.
            Options are 'horizontal', 'vertical', 'diagonal'.
            Default: 'horizontal'. If the input is a list, its length must
            equal that of ``flip_ratio``. Each element in ``flip_ratio``
            indicates the flip probability of the corresponding direction.
    """

    def __init__(self, flip_ratio=None, direction='horizontal'):
        if isinstance(flip_ratio, list):
            assert mmcv.is_list_of(flip_ratio, float)
            assert 0 <= sum(flip_ratio) <= 1
        elif isinstance(flip_ratio, float):
            assert 0 <= flip_ratio <= 1
        elif flip_ratio is None:
            pass
        else:
            raise ValueError('flip_ratios must be None, float, '
                             'or list of float')
        self.flip_ratio = flip_ratio

        valid_directions = ['horizontal', 'vertical', 'diagonal']
        if isinstance(direction, str):
            assert direction in valid_directions
        elif isinstance(direction, list):
            assert mmcv.is_list_of(direction, str)
            assert set(direction).issubset(set(valid_directions))
        else:
            raise ValueError('direction must be either str or list of str')
        self.direction = direction

        if isinstance(flip_ratio, list):
            assert len(self.flip_ratio) == len(self.direction)

    def bbox_flip(self, bboxes, img_shape, direction):
        """Flip bboxes horizontally, vertically, or diagonally.

        Args:
            bboxes (numpy.ndarray): Bounding boxes, shape (..., 4*k)
            img_shape (tuple[int]): Image shape (height, width)
            direction (str): Flip direction. Options are 'horizontal',
                'vertical', 'diagonal'.

        Returns:
            numpy.ndarray: Flipped bounding boxes.
        """
        assert bboxes.shape[-1] % 4 == 0
        flipped = bboxes.copy()
        if direction == 'horizontal':
            w = img_shape[1]
            flipped[..., 0::4] = w - bboxes[..., 2::4]
            flipped[..., 2::4] = w - bboxes[..., 0::4]
        elif direction == 'vertical':
            h = img_shape[0]
            flipped[..., 1::4] = h - bboxes[..., 3::4]
            flipped[..., 3::4] = h - bboxes[..., 1::4]
        elif direction == 'diagonal':
            w = img_shape[1]
            h = img_shape[0]
            flipped[..., 0::4] = w - bboxes[..., 2::4]
            flipped[..., 1::4] = h - bboxes[..., 3::4]
            flipped[..., 2::4] = w - bboxes[..., 0::4]
            flipped[..., 3::4] = h - bboxes[..., 1::4]
        else:
            raise ValueError(f"Invalid flipping direction '{direction}'")
        return flipped

    def __call__(self, sample):
        """Call function to flip the image, bounding boxes, masks and
        semantic segmentation maps.

        Args:
            sample (Sample): Result from the loading pipeline.

        Returns:
            Sample: Flipped sample; 'flip' and 'flip_direction' keys are
                added into the sample.
        """
        if 'flip' not in sample:
            if isinstance(self.direction, list):
                # None means non-flip
                direction_list = self.direction + [None]
            else:
                # None means non-flip
                direction_list = [self.direction, None]

            if isinstance(self.flip_ratio, list):
                non_flip_ratio = 1 - sum(self.flip_ratio)
                flip_ratio_list = self.flip_ratio + [non_flip_ratio]
            else:
                non_flip_ratio = 1 - self.flip_ratio
                # exclude non-flip
                single_ratio = self.flip_ratio / (len(direction_list) - 1)
                flip_ratio_list = [single_ratio] * (len(direction_list) -
                                                    1) + [non_flip_ratio]

            cur_dir = np.random.choice(direction_list, p=flip_ratio_list)

            sample.flip = cur_dir is not None
        if 'flip_direction' not in sample:
            sample.flip_direction = cur_dir
        if sample.flip:
            # flip image
            sample[sample.img_field] = mmcv.imflip(
                sample[sample.img_field], direction=sample.flip_direction)
            # flip bboxes
            for key in sample.bbox_fields:
                sample[key] = self.bbox_flip(sample[key], sample.img_shape,
                                             sample.flip_direction)
            # flip masks
            for key in sample.mask_fields:
                sample[key] = sample[key].flip(sample.flip_direction)

            # flip segs
            for key in sample.seg_fields:
                sample[key] = mmcv.imflip(sample[key],
                                          direction=sample.flip_direction)
        return sample

    def __repr__(self):
        return self.__class__.__name__ + f'(flip_ratio={self.flip_ratio})'
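

# Worked example for bbox_flip: on a (H, W) = (100, 200) image, a horizontal
# flip maps the box (x1, y1, x2, y2) = (30, 10, 70, 50) to
# (200 - 70, 10, 200 - 30, 50) = (130, 10, 170, 50); the x-coordinates swap
# and mirror while the y-coordinates are untouched.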


class DetRandomCrop:
    """Random crop the image & bboxes & masks.

    The absolute ``crop_size`` is sampled based on ``crop_type`` and
    ``image_size``, then the cropped sample is generated.

    Args:
        crop_size (tuple): The relative ratio or absolute pixels of
            height and width.
        crop_type (str, optional): One of "relative_range", "relative",
            "absolute", "absolute_range". "relative" randomly crops
            (h * crop_size[0], w * crop_size[1]) part from an input of size
            (h, w). "relative_range" uniformly samples relative crop size
            from range [crop_size[0], 1] and [crop_size[1], 1] for height
            and width respectively. "absolute" crops from an input with
            absolute size (crop_size[0], crop_size[1]). "absolute_range"
            uniformly samples crop_h in range
            [crop_size[0], min(h, crop_size[1])] and crop_w in range
            [crop_size[0], min(w, crop_size[1])]. Default "absolute".
        allow_negative_crop (bool, optional): Whether to allow a crop that
            does not contain any bbox area. Default False.
        recompute_bbox (bool, optional): Whether to re-compute the boxes
            based on the cropped instance masks. Default False.
        bbox_clip_border (bool, optional): Whether to clip the objects
            outside the border of the image. Defaults to True.

    Note:
        - If the image is smaller than the absolute crop size, the original
          image is returned.
        - The keys for bboxes, labels and masks must be aligned. That is,
          ``gt_bboxes`` corresponds to ``gt_labels`` and ``gt_masks``, and
          ``gt_bboxes_ignore`` corresponds to ``gt_labels_ignore`` and
          ``gt_masks_ignore``.
        - If the crop does not contain any gt-bbox region and
          ``allow_negative_crop`` is set to False, this image is skipped.
    """

    def __init__(self,
                 crop_size,
                 crop_type='absolute',
                 allow_negative_crop=False,
                 recompute_bbox=False,
                 bbox_clip_border=True):
        if crop_type not in [
                'relative_range', 'relative', 'absolute', 'absolute_range'
        ]:
            raise ValueError(f'Invalid crop_type {crop_type}.')
        if crop_type in ['absolute', 'absolute_range']:
            assert crop_size[0] > 0 and crop_size[1] > 0
            assert isinstance(crop_size[0], int) and isinstance(
                crop_size[1], int)
        else:
            assert 0 < crop_size[0] <= 1 and 0 < crop_size[1] <= 1
        self.crop_size = crop_size
        self.crop_type = crop_type
        self.allow_negative_crop = allow_negative_crop
        self.bbox_clip_border = bbox_clip_border
        self.recompute_bbox = recompute_bbox
        # The key correspondence from bboxes to labels and masks.
        self.bbox2label = {
            'gt_bboxes': 'gt_labels',
            'gt_bboxes_ignore': 'gt_labels_ignore'
        }
        self.bbox2mask = {
            'gt_bboxes': 'gt_masks',
            'gt_bboxes_ignore': 'gt_masks_ignore'
        }

    def _crop_data(self, sample, crop_size, allow_negative_crop):
        """Function to randomly crop images, bounding boxes, masks and
        semantic segmentation maps.

        Args:
            sample (Sample): Result from the loading pipeline.
            crop_size (tuple): Expected absolute size after cropping,
                (h, w).
            allow_negative_crop (bool): Whether to allow a crop that does
                not contain any bbox area. Default to False.

        Returns:
            dict: Randomly cropped Sample data; the 'img_shape' key in
                sample is updated according to the crop size.
        """
        assert crop_size[0] > 0 and crop_size[1] > 0
        max_try_times = 20
        for crop_times in range(1, max_try_times + 1):
            img = sample[sample.img_field]
            margin_h = max(img.shape[0] - crop_size[0], 0)
            margin_w = max(img.shape[1] - crop_size[1], 0)
            offset_h = np.random.randint(0, margin_h + 1)
            offset_w = np.random.randint(0, margin_w + 1)
            crop_y1, crop_y2 = offset_h, offset_h + crop_size[0]
            crop_x1, crop_x2 = offset_w, offset_w + crop_size[1]

            # crop bboxes accordingly and clip to the image boundary; work
            # on copies so that a rejected crop leaves ``sample`` untouched
            # (previously the fields were overwritten on every retry, which
            # compounded the offsets across attempts)
            is_valid = False
            cropped_bboxes = {}
            cropped_valid_inds = {}
            for key in sample.get('bbox_fields', []):
                # e.g. gt_bboxes and gt_bboxes_ignore
                bbox_offset = np.array(
                    [offset_w, offset_h, offset_w, offset_h],
                    dtype=np.float32)
                bboxes = sample[key] - bbox_offset
                if self.bbox_clip_border:
                    bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0,
                                              crop_size[1])
                    bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0,
                                              crop_size[0])
                valid_inds = (bboxes[:, 2] > bboxes[:, 0]) & (bboxes[:, 3] >
                                                              bboxes[:, 1])
                cropped_bboxes[key] = bboxes[valid_inds, :]
                cropped_valid_inds[key] = valid_inds
                if valid_inds.any() and key == 'gt_bboxes':
                    is_valid = True
            if not (is_valid or allow_negative_crop
                    or crop_times == max_try_times):
                continue
            # commit the accepted crop
            for key in sample.get('bbox_fields', []):
                valid_inds = cropped_valid_inds[key]
                sample[key] = cropped_bboxes[key]
                # label fields, e.g. gt_labels and gt_labels_ignore
                label_key = self.bbox2label.get(key)
                if label_key in sample:
                    sample[label_key] = sample[label_key][valid_inds]
                # mask fields, e.g. gt_masks and gt_masks_ignore
                mask_key = self.bbox2mask.get(key)
                if mask_key in sample:
                    sample[mask_key] = sample[mask_key][
                        valid_inds.nonzero()[0]].crop(
                            np.asarray([crop_x1, crop_y1, crop_x2, crop_y2]))
                    if self.recompute_bbox:
                        sample[key] = sample[mask_key].get_bboxes()
            # crop the image
            img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
            sample[sample.img_field] = img
            sample.img_shape = img.shape
            # crop semantic seg
            for key in sample.get('seg_fields', []):
                sample[key] = sample[key][crop_y1:crop_y2, crop_x1:crop_x2]
            break
        return sample

    def _get_crop_size(self, image_size):
        """Randomly generates the absolute crop size based on `crop_type`
        and `image_size`.

        Args:
            image_size (tuple): (h, w).

        Returns:
            crop_size (tuple): (crop_h, crop_w) in absolute pixels.
        """
        h, w = image_size
        if self.crop_type == 'absolute':
            return (min(self.crop_size[0], h), min(self.crop_size[1], w))
        elif self.crop_type == 'absolute_range':
            assert self.crop_size[0] <= self.crop_size[1]
            crop_h = np.random.randint(min(h, self.crop_size[0]),
                                       min(h, self.crop_size[1]) + 1)
            crop_w = np.random.randint(min(w, self.crop_size[0]),
                                       min(w, self.crop_size[1]) + 1)
            return crop_h, crop_w
        elif self.crop_type == 'relative':
            crop_h, crop_w = self.crop_size
            return int(h * crop_h + 0.5), int(w * crop_w + 0.5)
        elif self.crop_type == 'relative_range':
            crop_size = np.asarray(self.crop_size, dtype=np.float32)
            crop_h, crop_w = crop_size + np.random.rand(2) * (1 - crop_size)
            return int(h * crop_h + 0.5), int(w * crop_w + 0.5)

    def __call__(self, sample):
        """Call function to randomly crop images, bounding boxes, masks and
        semantic segmentation maps.

        Args:
            sample (Sample): Result from the loading pipeline.

        Returns:
            Sample: Randomly cropped Sample data; the 'img_shape' key in
                sample is updated according to the crop size.
        """
        image_size = sample[sample.img_field].shape[:2]
        crop_size = self._get_crop_size(image_size)
        sample = self._crop_data(sample, crop_size, self.allow_negative_crop)
        sample.bbox_num = sample.gt_bboxes.shape[0]
        return sample

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(crop_size={self.crop_size}, '
        repr_str += f'crop_type={self.crop_type}, '
        repr_str += f'allow_negative_crop={self.allow_negative_crop}, '
        repr_str += f'bbox_clip_border={self.bbox_clip_border})'
        return repr_str
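

# crop_type reference (illustrative), for an input of size (h, w):
#
#     DetRandomCrop((384, 600), crop_type='absolute')        # exactly 384x600
#     DetRandomCrop((384, 600), crop_type='absolute_range')  # sides sampled
#         # uniformly from [384, 600], capped by the image size
#     DetRandomCrop((0.6, 0.6), crop_type='relative')        # 0.6h x 0.6w
#     DetRandomCrop((0.6, 0.6), crop_type='relative_range')  # ratios sampled
#         # uniformly from [0.6, 1]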


class AutoAugment:
    """Auto augmentation.

    This data augmentation is proposed in `Learning Data Augmentation
    Strategies for Object Detection <https://arxiv.org/pdf/1906.11172>`_.

    TODO: Implement 'Shear', 'Sharpness' and 'Rotate' transforms

    Args:
        policies (list[list]): The policies of auto augmentation. Each
            policy in ``policies`` is a specific augmentation policy,
            composed of several transforms. When AutoAugment is called, a
            random policy in ``policies`` is selected to augment the images.
    """

    def __init__(self, policies):
        assert isinstance(policies, list) and len(policies) > 0, \
            'Policies must be a non-empty list.'
        for policy in policies:
            assert isinstance(policy, list) and len(policy) > 0, \
                'Each policy in policies must be a non-empty list.'

        self.policies = copy.deepcopy(policies)
        self.transforms = [Compose(policy) for policy in self.policies]

    def __call__(self, sample):
        transform = np.random.choice(self.transforms)
        return transform(sample)
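

# Usage sketch (hedged: unlike mmdet's config-dict policies, each policy
# here must be a list of transform instances, since this Compose takes
# callables):
#
#     policies = [
#         [DetResize(img_scale=[(1333, 640), (1333, 800)],
#                    multiscale_mode='range', keep_ratio=True)],
#         [DetRandomCrop(crop_size=(384, 600), crop_type='absolute_range'),
#          DetResize(img_scale=[(1333, 800)], keep_ratio=True)],
#     ]
#     sample = AutoAugment(policies)(sample)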