494 lines
16 KiB
Python
494 lines
16 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import warnings
|
|
from typing import Dict, Optional, Union
|
|
import io
|
|
|
|
import mmcv
|
|
import mmengine.fileio as fileio
|
|
import numpy as np
|
|
from mmcv.transforms import BaseTransform
|
|
from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations
|
|
from mmcv.transforms import LoadImageFromFile
|
|
import imageio
|
|
from mmseg.registry import TRANSFORMS
|
|
from mmseg.utils import datafrombytes
|
|
|
|
try:
|
|
from osgeo import gdal
|
|
except ImportError:
|
|
gdal = None
|
|
|
|
|
|
@TRANSFORMS.register_module()
class LoadAnnotationsNpz(MMCV_LoadAnnotations):
    """Load a semantic segmentation map stored in a NumPy ``.npz`` archive.

    The annotation format is as the following:

    .. code-block:: python

        {
            # Filename of semantic segmentation ground truth file.
            'seg_map_path': 'a/b/c'
        }

    After this module, the annotation has been changed to the format below:

    .. code-block:: python

        {
            # in str
            'seg_fields': List
            # In uint8 type.
            'gt_seg_map': np.ndarray (H, W)
        }

    Required Keys:

    - seg_map_path (str): Path of semantic segmentation ground truth file.
    - reduce_zero_label (bool): Dataset-level setting the transform is
      checked against.

    Added Keys:

    - seg_fields (List)
    - gt_seg_map (np.uint8)

    Args:
        reduce_zero_label (bool, optional): Whether reduce all label value
            by 1. Usually used for datasets where 0 is background label.
            Defaults to None.
        backend_args (dict): Arguments to instantiate a file backend.
            See https://mmengine.readthedocs.io/en/latest/api/fileio.html
            for details. Forwarded to the parent class; the archive itself
            is read with ``np.load`` directly from ``seg_map_path``.
            Defaults to None.
        data_key (str): Name of the array inside the ``.npz`` archive that
            holds the segmentation map. Defaults to 'arr_0'.
        imdecode_backend (str): The image decoding backend type. Forwarded
            to the parent class; this class does not decode images with it.
            See :func:``mmcv.imfrombytes`` for details.
            Defaults to 'pillow'.

    Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
    """

    def __init__(
        self,
        reduce_zero_label=None,
        backend_args=None,
        data_key='arr_0',
        imdecode_backend='pillow',
    ) -> None:
        super().__init__(
            with_bbox=False,
            with_label=False,
            with_seg=True,
            with_keypoints=False,
            imdecode_backend=imdecode_backend,
            backend_args=backend_args)
        self.reduce_zero_label = reduce_zero_label
        if self.reduce_zero_label is not None:
            warnings.warn('`reduce_zero_label` will be deprecated, '
                          'if you would like to ignore the zero label, please '
                          'set `reduce_zero_label=True` when dataset '
                          'initialized')
        self.imdecode_backend = imdecode_backend
        self.data_key = data_key

    def _load_seg_map(self, results: dict) -> None:
        """Private function to load semantic segmentation annotations.

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``.
                It is updated in place with ``gt_seg_map`` and
                ``seg_fields`` gains the entry ``'gt_seg_map'``.
        """
        # ``np.load`` on an .npz returns a lazy archive backed by an open
        # file handle; close it deterministically so that long training
        # runs do not accumulate open descriptors.
        with np.load(results['seg_map_path']) as archive:
            raw_map = archive[self.data_key]
        gt_semantic_seg = raw_map.squeeze().astype(np.uint8)

        # reduce zero_label: adopt the dataset setting on first use, then
        # require that transform and dataset agree.
        if self.reduce_zero_label is None:
            self.reduce_zero_label = results['reduce_zero_label']
        assert self.reduce_zero_label == results['reduce_zero_label'], \
            'Initialize dataset with `reduce_zero_label` as ' \
            f'{results["reduce_zero_label"]} but when load annotation ' \
            f'the `reduce_zero_label` is {self.reduce_zero_label}'
        if self.reduce_zero_label:
            # avoid using underflow conversion: park label 0 at 255, shift
            # everything down by one, then restore the parked value to 255.
            gt_semantic_seg[gt_semantic_seg == 0] = 255
            gt_semantic_seg = gt_semantic_seg - 1
            gt_semantic_seg[gt_semantic_seg == 254] = 255

        # modify if custom classes
        label_map = results.get('label_map', None)
        if label_map is not None:
            # Compare against an untouched copy so that an id which was
            # already remapped is not remapped a second time, see
            # https://github.com/open-mmlab/mmsegmentation/pull/1445/
            gt_semantic_seg_copy = gt_semantic_seg.copy()
            for old_id, new_id in label_map.items():
                gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id

        results['gt_seg_map'] = gt_semantic_seg
        results['seg_fields'].append('gt_seg_map')

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(reduce_zero_label={self.reduce_zero_label}, '
        repr_str += f"data_key='{self.data_key}', "
        repr_str += f"imdecode_backend='{self.imdecode_backend}', "
        repr_str += f'backend_args={self.backend_args})'
        return repr_str
|
|
|
|
@TRANSFORMS.register_module()
class LoadImageSingleChannel(BaseTransform):
    """Load an image from file and keep only its first channel.

    The file is read with ``imageio.imread``. When the decoded array has a
    channel axis (H, W, C) only channel 0 is kept; an array that is already
    two-dimensional is used unchanged. The resulting image is 2-D (H, W).

    Required Keys:

    - img_path

    Modified Keys:

    - img
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the image keeps the dtype it was
            read with. Defaults to True.
    """

    def __init__(self, to_float32: bool = True):
        # This transform only relies on ``imageio``, so it deliberately does
        # not require gdal to be installed.
        self.to_float32 = to_float32

    def transform(self, results: Dict) -> Dict:
        """Functions to load image.

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``.

        Returns:
            dict: The dict contains loaded image and meta information.
        """
        filename = results['img_path']
        img = imageio.imread(filename)  # (h, w, c) or (h, w)
        if img.ndim == 3:
            # Multi-channel file: keep the first channel only.
            img = img[:, :, 0]

        if self.to_float32:
            img = img.astype(np.float32)

        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'to_float32={self.to_float32})')
        return repr_str
|
|
|
|
@TRANSFORMS.register_module()
class LoadAnnotationsOil(MMCV_LoadAnnotations):
    """Load a binary segmentation map from a raster file opened with GDAL.

    The raster is read with ``gdal.Open(...).ReadAsArray()``. Pixels whose
    raw value equals 3 become label 1 and every other pixel becomes label 0.

    The annotation format is as the following:

    .. code-block:: python

        {
            # Filename of semantic segmentation ground truth file.
            'seg_map_path': 'a/b/c'
        }

    After this module, the annotation has been changed to the format below:

    .. code-block:: python

        {
            # in str
            'seg_fields': List
            # In uint8 type.
            'gt_seg_map': np.ndarray (H, W)
        }

    Required Keys:

    - seg_map_path (str): Path of semantic segmentation ground truth file.
    - reduce_zero_label (bool): Dataset-level setting the transform is
      checked against.

    Added Keys:

    - seg_fields (List)
    - gt_seg_map (np.uint8)

    Args:
        reduce_zero_label (bool, optional): Whether reduce all label value
            by 1. Usually used for datasets where 0 is background label.
            Defaults to None.
        backend_args (dict): Arguments to instantiate a file backend.
            See https://mmengine.readthedocs.io/en/latest/api/fileio.html
            for details. Forwarded to the parent class; the raster itself
            is opened by GDAL directly from ``seg_map_path``.
            Defaults to None.
        data_key (str): Stored on the instance for signature parity with
            the ``.npz`` loaders; not used when reading the raster.
            Defaults to 'arr_0'.
        imdecode_backend (str): The image decoding backend type. Forwarded
            to the parent class; this class does not decode images with it.
            See :func:``mmcv.imfrombytes`` for details.
            Defaults to 'pillow'.

    Raises:
        RuntimeError: If gdal is not installed, or if GDAL cannot open
            ``seg_map_path``.

    Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
    """

    def __init__(
        self,
        reduce_zero_label=None,
        backend_args=None,
        data_key='arr_0',
        imdecode_backend='pillow',
    ) -> None:
        super().__init__(
            with_bbox=False,
            with_label=False,
            with_seg=True,
            with_keypoints=False,
            imdecode_backend=imdecode_backend,
            backend_args=backend_args)
        self.reduce_zero_label = reduce_zero_label
        if self.reduce_zero_label is not None:
            warnings.warn('`reduce_zero_label` will be deprecated, '
                          'if you would like to ignore the zero label, please '
                          'set `reduce_zero_label=True` when dataset '
                          'initialized')
        self.imdecode_backend = imdecode_backend
        self.data_key = data_key
        # Fail early with a clear message instead of an AttributeError on
        # ``None`` at the first sample.
        if gdal is None:
            raise RuntimeError('gdal is not installed')

    def _load_seg_map(self, results: dict) -> None:
        """Private function to load semantic segmentation annotations.

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``.
                It is updated in place with ``gt_seg_map`` and
                ``seg_fields`` gains the entry ``'gt_seg_map'``.

        Raises:
            RuntimeError: If GDAL cannot open ``results['seg_map_path']``.
        """
        seg_map_path = results['seg_map_path']
        dataset = gdal.Open(seg_map_path)
        if dataset is None:
            # Unless GDAL exceptions are enabled, ``gdal.Open`` signals
            # failure by returning None.
            raise RuntimeError(f'Unable to open file: {seg_map_path}')
        seg_map = dataset.ReadAsArray()

        # Binarise: raw value 3 -> label 1, everything else -> label 0.
        # NOTE(review): presumably 3 is the oil class code in the source
        # rasters - confirm against the dataset specification.
        gt_semantic_seg = np.zeros(seg_map.shape, dtype=np.uint8)
        gt_semantic_seg[seg_map == 3.] = 1

        # reduce zero_label: adopt the dataset setting on first use, then
        # require that transform and dataset agree.
        if self.reduce_zero_label is None:
            self.reduce_zero_label = results['reduce_zero_label']
        assert self.reduce_zero_label == results['reduce_zero_label'], \
            'Initialize dataset with `reduce_zero_label` as ' \
            f'{results["reduce_zero_label"]} but when load annotation ' \
            f'the `reduce_zero_label` is {self.reduce_zero_label}'
        if self.reduce_zero_label:
            # avoid using underflow conversion: park label 0 at 255, shift
            # everything down by one, then restore the parked value to 255.
            gt_semantic_seg[gt_semantic_seg == 0] = 255
            gt_semantic_seg = gt_semantic_seg - 1
            gt_semantic_seg[gt_semantic_seg == 254] = 255

        # modify if custom classes
        label_map = results.get('label_map', None)
        if label_map is not None:
            # Compare against an untouched copy so that an id which was
            # already remapped is not remapped a second time, see
            # https://github.com/open-mmlab/mmsegmentation/pull/1445/
            gt_semantic_seg_copy = gt_semantic_seg.copy()
            for old_id, new_id in label_map.items():
                gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id

        results['gt_seg_map'] = gt_semantic_seg
        results['seg_fields'].append('gt_seg_map')

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(reduce_zero_label={self.reduce_zero_label}, '
        repr_str += f"imdecode_backend='{self.imdecode_backend}', "
        repr_str += f'backend_args={self.backend_args})'
        return repr_str
|
|
|
|
|
|
@TRANSFORMS.register_module()
class LoadImageFromNpz(BaseTransform):
    """Load an image stored in a NumPy ``.npz`` archive.

    The array stored under ``data_key`` must be three-dimensional. Its first
    axis is moved to the last position (``ijk -> jki``) before the image is
    stored in ``results``.
    NOTE(review): this presumably converts channel-first (C, H, W) data to
    (H, W, C) - confirm against the files being loaded.

    Required Keys:

    - img_path

    Modified Keys:

    - img
    - img_shape
    - ori_shape

    Args:
        data_key (str): Name of the array inside the ``.npz`` archive.
            Defaults to 'arr_0'.
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the image keeps the dtype it was
            stored with. Defaults to True.
    """

    def __init__(self, data_key='arr_0', to_float32: bool = True):
        # Only NumPy is needed to read the archive, so gdal is deliberately
        # not required here.
        self.to_float32 = to_float32
        self.data_key = data_key

    def transform(self, results: Dict) -> Dict:
        """Functions to load image.

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``.

        Returns:
            dict: The dict contains loaded image and meta information.
        """
        filename = results['img_path']
        # Close the archive's file handle as soon as the array is read.
        with np.load(filename) as archive:
            img = archive[self.data_key]
        # Move axis 0 to the end: (i, j, k) -> (j, k, i).
        img = np.einsum('ijk->jki', img)

        if self.to_float32:
            img = img.astype(np.float32)

        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f"data_key='{self.data_key}', "
                    f'to_float32={self.to_float32})')
        return repr_str
|
|
|
|
@TRANSFORMS.register_module()
class LoadImageOil(BaseTransform):
    """Load a raster image from file with GDAL.

    The raster is read with ``gdal.Open(...).ReadAsArray()``. A
    two-dimensional result gets a trailing channel axis, giving (H, W, 1).
    A three-dimensional result (GDAL returns multi-band rasters with the
    band axis first) is transposed so the band axis is last.

    Required Keys:

    - img_path

    Modified Keys:

    - img
    - img_shape
    - ori_shape

    Args:
        data_key (str): Stored on the instance for signature parity with
            the ``.npz`` loaders; not used when reading the raster.
            Defaults to 'arr_0'.
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the image keeps the dtype GDAL
            returned. Defaults to True.

    Raises:
        RuntimeError: If gdal is not installed, or if GDAL cannot open
            ``img_path``.
    """

    def __init__(self, data_key='arr_0', to_float32: bool = True):
        self.to_float32 = to_float32
        self.data_key = data_key
        if gdal is None:
            raise RuntimeError('gdal is not installed')

    def transform(self, results: Dict) -> Dict:
        """Functions to load image.

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``.

        Returns:
            dict: The dict contains loaded image and meta information.

        Raises:
            RuntimeError: If GDAL cannot open ``results['img_path']``.
        """
        filename = results['img_path']
        dataset = gdal.Open(filename)
        if dataset is None:
            # Unless GDAL exceptions are enabled, ``gdal.Open`` signals
            # failure by returning None.
            raise RuntimeError(f'Unable to open file: {filename}')
        img = dataset.ReadAsArray()

        if img.ndim == 2:
            # Single band: (h, w) -> (h, w, 1).
            img = img[:, :, None]
        else:
            # Multi band: (bands, h, w) -> (h, w, bands).
            img = np.transpose(img, (1, 2, 0))

        if self.to_float32:
            img = img.astype(np.float32)

        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'to_float32={self.to_float32})')
        return repr_str
|
|
|
|
@TRANSFORMS.register_module()
class LoadTsImageFromNpz(BaseTransform):
    """Load an image time series from a ``.npz`` archive as one image.

    The array stored under ``data_key`` must have four axes, unpacked as
    (ts, c, h, w). ``ts_size`` time steps are drawn at random, put in
    ascending order, and folded into the channel axis. The output image has
    shape (h, w, ts_size * c).

    Time steps are drawn without replacement when the series holds at least
    ``ts_size`` steps. Otherwise they are drawn with replacement, so some
    steps repeat. Sampling uses NumPy's global random state.

    Required Keys:

    - img_path

    Modified Keys:

    - img
    - img_shape
    - ori_shape

    Args:
        data_key (str): Name of the array inside the ``.npz`` archive.
            Defaults to 'arr_0'.
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the image keeps the dtype it was
            stored with. Defaults to True.
        ts_size (int): Number of time steps to sample. Defaults to 10.
    """

    def __init__(self,
                 data_key='arr_0',
                 to_float32: bool = True,
                 ts_size: int = 10):
        # Only NumPy is needed to read the archive, so gdal is deliberately
        # not required here.
        self.to_float32 = to_float32
        self.data_key = data_key
        self.ts_size = ts_size

    def transform(self, results: Dict) -> Dict:
        """Functions to load image.

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``.

        Returns:
            dict: The dict contains loaded image and meta information.
        """
        filename = results['img_path']
        # Close the archive's file handle as soon as the array is read.
        with np.load(filename) as archive:
            img = archive[self.data_key]
        ts, c, h, w = img.shape

        # Fall back to sampling with replacement only when the series is
        # shorter than the requested number of time steps.
        selected_indices = np.random.choice(
            ts, size=self.ts_size, replace=ts < self.ts_size)
        selected_indices.sort()
        img = img[selected_indices, :, :, :]

        # (ts, c, h, w) -> (h, w, ts, c) -> (h, w, ts * c)
        img = img.transpose(2, 3, 0, 1).reshape(h, w, self.ts_size * c)

        if self.to_float32:
            img = img.astype(np.float32)

        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f"data_key='{self.data_key}', "
                    f'to_float32={self.to_float32}, '
                    f'ts_size={self.ts_size})')
        return repr_str
|
|
|
|
|