init
35
finetune/mmseg/datasets/__init__.py
Normal file
@@ -0,0 +1,35 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .atlantic import AtlanticDataset
from .basesegdataset import BaseSegDataset
from .c2sfloods import C2SFloodDataset
from .cabuar import CABURADataset
from .germany import GermanyCropDataset
from .sos import SOSDataset
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad,
                         BioMedical3DRandomCrop, BioMedical3DRandomFlip,
                         BioMedicalGaussianBlur, BioMedicalGaussianNoise,
                         BioMedicalRandomGamma, ConcatCDInput, GenerateEdge,
                         LoadAnnotations, LoadBiomedicalAnnotation,
                         LoadBiomedicalData, LoadBiomedicalImageFromFile,
                         LoadImageFromNDArray, LoadMultipleRSImageFromFile,
                         LoadSingleRSImageFromFile, PackSegInputs,
                         PhotoMetricDistortion, RandomCrop, RandomCutOut,
                         RandomMosaic, RandomRotate, RandomRotFlip, Rerange,
                         ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
                         SegRescale)

# yapf: enable
__all__ = [
    'BaseSegDataset', 'BioMedical3DRandomCrop', 'BioMedical3DRandomFlip',
    'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
    'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
    'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
    'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
    'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
    'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
    'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip',
    'Albu', 'LoadMultipleRSImageFromFile', 'LoadSingleRSImageFromFile',
    'ConcatCDInput', 'AtlanticDataset', 'C2SFloodDataset',
    'CABURADataset', 'GermanyCropDataset', 'SOSDataset'
]
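
Editor's sketch (not part of the commit): the classes above register themselves with mmseg's DATASETS registry at import time, so configs refer to them by class name. A minimal build, assuming this package is importable in the finetune tree; the data paths are invented placeholders:

from mmseg.registry import DATASETS

# Importing the package runs every @DATASETS.register_module() above.
# data_root and ann_file are placeholder paths, not from this commit.
cfg = dict(
    type='AtlanticDataset',
    data_root='data/atlantic',            # placeholder
    ann_file='data/atlantic/train.json',  # placeholder
    pipeline=[],
    lazy_init=True)  # skip annotation loading; metadata only
dataset = DATASETS.build(cfg)
print(dataset.metainfo['classes'])  # ('Background', 'Deforestation area')
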
48
finetune/mmseg/datasets/atlantic.py
Normal file
@@ -0,0 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
from typing import List

from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset


@DATASETS.register_module()
class AtlanticDataset(BaseSegDataset):
    """Atlantic forest deforestation dataset.

    Binary segmentation with classes ('Background', 'Deforestation area').
    Samples are read from a JSON ``ann_file`` whose entries hold an
    ``s2_path`` image and an optional ``target_path`` mask; ``img_suffix``
    and ``seg_map_suffix`` are both '.tif'.
    """
    METAINFO = dict(
        classes=('Background', 'Deforestation area'),
        palette=[[0, 0, 0], [255, 255, 255]]
    )

    def __init__(self,
                 img_suffix='.tif',
                 seg_map_suffix='.tif',
                 reduce_zero_label=False,
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)

    def load_data_list(self) -> List[dict]:
        """Load annotation from directory or annotation file.

        Returns:
            list[dict]: All data info of dataset.
        """
        data_list = []
        if not osp.isdir(self.ann_file) and self.ann_file:
            assert osp.isfile(self.ann_file), \
                f'Failed to load `ann_file` {self.ann_file}'
            with open(self.ann_file) as f:
                lines = json.load(f)
            for line in lines:
                data_info = dict(
                    img_path=osp.join(self.data_root, line['s2_path']))
                if 'target_path' in line:
                    data_info['seg_map_path'] = osp.join(
                        self.data_root, line['target_path'])
                data_info['label_map'] = self.label_map
                data_info['reduce_zero_label'] = self.reduce_zero_label
                data_info['seg_fields'] = []
                data_list.append(data_info)
        data_list = sorted(data_list, key=lambda x: x['img_path'])
        return data_list
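
For reference, ``AtlanticDataset.load_data_list`` above (and the near-identical loaders later in this commit) assumes a JSON ``ann_file`` shaped roughly as below. The key names match the loader; the tile filenames are invented:

# Hypothetical contents of train.json: a list of records with paths
# relative to data_root. An entry without 'target_path' yields a
# sample with no segmentation map.
[
    {'s2_path': 'images/tile_0001.tif',
     'target_path': 'masks/tile_0001.tif'},
    {'s2_path': 'images/tile_0002.tif'}
]
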
552
finetune/mmseg/datasets/basesegdataset.py
Normal file
@@ -0,0 +1,552 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from typing import Callable, Dict, List, Optional, Sequence, Union

import mmengine
import mmengine.fileio as fileio
import numpy as np
from mmengine.dataset import BaseDataset, Compose

from mmseg.registry import DATASETS


@DATASETS.register_module()
class BaseSegDataset(BaseDataset):
    """Custom dataset for semantic segmentation. An example of file structure
    is as followed.

    .. code-block:: none

        ├── data
        │   ├── my_dataset
        │   │   ├── img_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{img_suffix}
        │   │   │   │   ├── yyy{img_suffix}
        │   │   │   │   ├── zzz{img_suffix}
        │   │   │   ├── val
        │   │   ├── ann_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{seg_map_suffix}
        │   │   │   │   ├── yyy{seg_map_suffix}
        │   │   │   │   ├── zzz{seg_map_suffix}
        │   │   │   ├── val

    An img/gt_semantic_seg filename pair of BaseSegDataset should be
    identical except for the suffix: a valid pair looks like
    ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (the extension is also
    included in the suffix). If a split is given, then ``xxx`` is specified
    in the txt file. Otherwise, all files in ``img_dir/`` and ``ann_dir``
    will be loaded. Please refer to ``docs/en/tutorials/new_dataset.md``
    for more details.


    Args:
        ann_file (str): Annotation file path. Defaults to ''.
        metainfo (dict, optional): Meta information for the dataset, such as
            the classes to load. Defaults to None.
        data_root (str, optional): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to None.
        data_prefix (dict, optional): Prefix for training data. Defaults to
            dict(img_path=None, seg_map_path=None).
        img_suffix (str): Suffix of images. Default: '.jpg'
        seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
        filter_cfg (dict, optional): Config for filtering data. Defaults to
            None.
        indices (int or Sequence[int], optional): Support using only the
            first few entries in the annotation file to facilitate
            training/testing on a smaller dataset. Defaults to None, which
            means using all ``data_infos``.
        serialize_data (bool, optional): Whether to hold memory using
            serialized objects; when enabled, data loader workers can use
            shared RAM from the master process instead of making a copy.
            Defaults to True.
        pipeline (list, optional): Processing pipeline. Defaults to [].
        test_mode (bool, optional): ``test_mode=True`` means in test phase.
            Defaults to False.
        lazy_init (bool, optional): Whether to load annotations during
            instantiation. In some cases, such as visualization, only the
            meta information of the dataset is needed, so it is not
            necessary to load the annotation file. ``Basedataset`` can skip
            loading annotations to save time by setting ``lazy_init=True``.
            Defaults to False.
        max_refetch (int, optional): If ``Basedataset.prepare_data`` gets a
            None img, the maximum extra number of cycles to get a valid
            image. Defaults to 1000.
        ignore_index (int): The label index to be ignored. Default: 255
        reduce_zero_label (bool): Whether to mark label zero as ignored.
            Defaults to False.
        backend_args (dict, Optional): Arguments to instantiate a file
            backend. See
            https://mmengine.readthedocs.io/en/latest/api/fileio.html
            for details. Defaults to None.
            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
    """
    METAINFO: dict = dict()

    def __init__(self,
                 ann_file: str = '',
                 img_suffix='.jpg',
                 seg_map_suffix='.png',
                 metainfo: Optional[dict] = None,
                 data_root: Optional[str] = None,
                 data_prefix: dict = dict(img_path='', seg_map_path=''),
                 filter_cfg: Optional[dict] = None,
                 indices: Optional[Union[int, Sequence[int]]] = None,
                 serialize_data: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 test_mode: bool = False,
                 lazy_init: bool = False,
                 max_refetch: int = 1000,
                 ignore_index: int = 255,
                 reduce_zero_label: bool = False,
                 backend_args: Optional[dict] = None) -> None:

        self.img_suffix = img_suffix
        self.seg_map_suffix = seg_map_suffix
        self.ignore_index = ignore_index
        self.reduce_zero_label = reduce_zero_label
        self.backend_args = backend_args.copy() if backend_args else None

        self.data_root = data_root
        self.data_prefix = copy.copy(data_prefix)
        self.ann_file = ann_file
        self.filter_cfg = copy.deepcopy(filter_cfg)
        self._indices = indices
        self.serialize_data = serialize_data
        self.test_mode = test_mode
        self.max_refetch = max_refetch
        self.data_list: List[dict] = []
        self.data_bytes: np.ndarray

        # Set meta information.
        self._metainfo = self._load_metainfo(copy.deepcopy(metainfo))

        # Get label map for custom classes
        new_classes = self._metainfo.get('classes', None)
        self.label_map = self.get_label_map(new_classes)
        self._metainfo.update(
            dict(
                label_map=self.label_map,
                reduce_zero_label=self.reduce_zero_label))

        # Update palette based on label map or generate palette
        # if it is not defined
        updated_palette = self._update_palette()
        self._metainfo.update(dict(palette=updated_palette))

        # Join paths.
        if self.data_root is not None:
            self._join_prefix()

        # Build pipeline.
        self.pipeline = Compose(pipeline)
        # Fully initialize the dataset.
        if not lazy_init:
            self.full_init()

        if test_mode:
            assert self._metainfo.get('classes') is not None, \
                'dataset metainfo `classes` should be specified when testing'

    @classmethod
    def get_label_map(cls,
                      new_classes: Optional[Sequence] = None
                      ) -> Union[Dict, None]:
        """Return the label mapping.

        The ``label_map`` is a dictionary whose keys are the old label ids
        and whose values are the new label ids; it is used for changing
        pixel labels in load_annotations. ``label_map`` is not None if and
        only if the old classes in ``cls.METAINFO`` differ from the new
        classes in ``self._metainfo`` and neither of them is None.

        Args:
            new_classes (list, tuple, optional): The new class names from
                metainfo. Defaults to None.


        Returns:
            dict, optional: The mapping from old classes in ``cls.METAINFO``
                to new classes in ``self._metainfo``.
        """
        old_classes = cls.METAINFO.get('classes', None)
        if (new_classes is not None and old_classes is not None
                and list(new_classes) != list(old_classes)):

            label_map = {}
            if not set(new_classes).issubset(cls.METAINFO['classes']):
                raise ValueError(
                    f'new classes {new_classes} is not a '
                    f'subset of classes {old_classes} in METAINFO.')
            for i, c in enumerate(old_classes):
                if c not in new_classes:
                    label_map[i] = 255
                else:
                    label_map[i] = new_classes.index(c)
            return label_map
        else:
            return None

    def _update_palette(self) -> list:
        """Update palette after loading metainfo.

        If the length of the palette equals the number of classes, just
        return the palette. If the palette is not defined, randomly generate
        one. If classes are updated by the user, return the matching subset
        of the palette.

        Returns:
            Sequence: Palette for current dataset.
        """
        palette = self._metainfo.get('palette', [])
        classes = self._metainfo.get('classes', [])
        # palette matches classes
        if len(palette) == len(classes):
            return palette

        if len(palette) == 0:
            # Get the random state before setting the seed, and restore the
            # random state afterwards.
            # This prevents loss of randomness, as the palette
            # may be different in each iteration if not specified.
            # See: https://github.com/open-mmlab/mmdetection/issues/5844
            state = np.random.get_state()
            np.random.seed(42)
            # random palette
            new_palette = np.random.randint(
                0, 255, size=(len(classes), 3)).tolist()
            np.random.set_state(state)
        elif len(palette) >= len(classes) and self.label_map is not None:
            new_palette = []
            # return subset of palette
            for old_id, new_id in sorted(
                    self.label_map.items(), key=lambda x: x[1]):
                if new_id != 255:
                    new_palette.append(palette[old_id])
            new_palette = type(palette)(new_palette)
        else:
            raise ValueError('palette does not match classes '
                             f'as metainfo is {self._metainfo}.')
        return new_palette

    def load_data_list(self) -> List[dict]:
        """Load annotation from directory or annotation file.

        Returns:
            list[dict]: All data info of dataset.
        """
        data_list = []
        img_dir = self.data_prefix.get('img_path', None)
        ann_dir = self.data_prefix.get('seg_map_path', None)
        if not osp.isdir(self.ann_file) and self.ann_file:
            assert osp.isfile(self.ann_file), \
                f'Failed to load `ann_file` {self.ann_file}'
            lines = mmengine.list_from_file(
                self.ann_file, backend_args=self.backend_args)
            for line in lines:
                img_name = line.strip()
                data_info = dict(
                    img_path=osp.join(img_dir, img_name + self.img_suffix))
                if ann_dir is not None:
                    seg_map = img_name + self.seg_map_suffix
                    data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
                data_info['label_map'] = self.label_map
                data_info['reduce_zero_label'] = self.reduce_zero_label
                data_info['seg_fields'] = []
                data_list.append(data_info)
        else:
            _suffix_len = len(self.img_suffix)
            for img in fileio.list_dir_or_file(
                    dir_path=img_dir,
                    list_dir=False,
                    suffix=self.img_suffix,
                    recursive=True,
                    backend_args=self.backend_args):
                data_info = dict(img_path=osp.join(img_dir, img))
                if ann_dir is not None:
                    seg_map = img[:-_suffix_len] + self.seg_map_suffix
                    data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
                data_info['label_map'] = self.label_map
                data_info['reduce_zero_label'] = self.reduce_zero_label
                data_info['seg_fields'] = []
                data_list.append(data_info)
        data_list = sorted(data_list, key=lambda x: x['img_path'])
        return data_list


@DATASETS.register_module()
class BaseCDDataset(BaseDataset):
    """Custom dataset for change detection. An example of file structure is as
    followed.

    .. code-block:: none

        ├── data
        │   ├── my_dataset
        │   │   ├── img_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{img_suffix}
        │   │   │   │   ├── yyy{img_suffix}
        │   │   │   │   ├── zzz{img_suffix}
        │   │   │   ├── val
        │   │   ├── img_dir2
        │   │   │   ├── train
        │   │   │   │   ├── xxx{img_suffix}
        │   │   │   │   ├── yyy{img_suffix}
        │   │   │   │   ├── zzz{img_suffix}
        │   │   │   ├── val
        │   │   ├── ann_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{seg_map_suffix}
        │   │   │   │   ├── yyy{seg_map_suffix}
        │   │   │   │   ├── zzz{seg_map_suffix}
        │   │   │   ├── val

    The image names in img_dir and img_dir2 should be consistent.
    An img/gt_semantic_seg filename pair should be identical except for the
    suffix: a valid pair looks like ``xxx{img_suffix}`` and
    ``xxx{seg_map_suffix}`` (the extension is also included in the suffix).
    If a split is given, then ``xxx`` is specified in the txt file.
    Otherwise, all files in ``img_dir/`` and ``ann_dir`` will be loaded.
    Please refer to ``docs/en/tutorials/new_dataset.md`` for more details.


    Args:
        ann_file (str): Annotation file path. Defaults to ''.
        metainfo (dict, optional): Meta information for the dataset, such as
            the classes to load. Defaults to None.
        data_root (str, optional): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to None.
        data_prefix (dict, optional): Prefix for training data. Defaults to
            dict(img_path=None, img_path2=None, seg_map_path=None).
        img_suffix (str): Suffix of images. Default: '.jpg'
        img_suffix2 (str): Suffix of images. Default: '.jpg'
        seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
        filter_cfg (dict, optional): Config for filtering data. Defaults to
            None.
        indices (int or Sequence[int], optional): Support using only the
            first few entries in the annotation file to facilitate
            training/testing on a smaller dataset. Defaults to None, which
            means using all ``data_infos``.
        serialize_data (bool, optional): Whether to hold memory using
            serialized objects; when enabled, data loader workers can use
            shared RAM from the master process instead of making a copy.
            Defaults to True.
        pipeline (list, optional): Processing pipeline. Defaults to [].
        test_mode (bool, optional): ``test_mode=True`` means in test phase.
            Defaults to False.
        lazy_init (bool, optional): Whether to load annotations during
            instantiation. In some cases, such as visualization, only the
            meta information of the dataset is needed, so it is not
            necessary to load the annotation file. ``Basedataset`` can skip
            loading annotations to save time by setting ``lazy_init=True``.
            Defaults to False.
        max_refetch (int, optional): If ``Basedataset.prepare_data`` gets a
            None img, the maximum extra number of cycles to get a valid
            image. Defaults to 1000.
        ignore_index (int): The label index to be ignored. Default: 255
        reduce_zero_label (bool): Whether to mark label zero as ignored.
            Defaults to False.
        backend_args (dict, Optional): Arguments to instantiate a file
            backend. See
            https://mmengine.readthedocs.io/en/latest/api/fileio.html
            for details. Defaults to None.
            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
    """
    METAINFO: dict = dict()

    def __init__(self,
                 ann_file: str = '',
                 img_suffix='.jpg',
                 img_suffix2='.jpg',
                 seg_map_suffix='.png',
                 metainfo: Optional[dict] = None,
                 data_root: Optional[str] = None,
                 data_prefix: dict = dict(
                     img_path='', img_path2='', seg_map_path=''),
                 filter_cfg: Optional[dict] = None,
                 indices: Optional[Union[int, Sequence[int]]] = None,
                 serialize_data: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 test_mode: bool = False,
                 lazy_init: bool = False,
                 max_refetch: int = 1000,
                 ignore_index: int = 255,
                 reduce_zero_label: bool = False,
                 backend_args: Optional[dict] = None) -> None:

        self.img_suffix = img_suffix
        self.img_suffix2 = img_suffix2
        self.seg_map_suffix = seg_map_suffix
        self.ignore_index = ignore_index
        self.reduce_zero_label = reduce_zero_label
        self.backend_args = backend_args.copy() if backend_args else None

        self.data_root = data_root
        self.data_prefix = copy.copy(data_prefix)
        self.ann_file = ann_file
        self.filter_cfg = copy.deepcopy(filter_cfg)
        self._indices = indices
        self.serialize_data = serialize_data
        self.test_mode = test_mode
        self.max_refetch = max_refetch
        self.data_list: List[dict] = []
        self.data_bytes: np.ndarray

        # Set meta information.
        self._metainfo = self._load_metainfo(copy.deepcopy(metainfo))

        # Get label map for custom classes
        new_classes = self._metainfo.get('classes', None)
        self.label_map = self.get_label_map(new_classes)
        self._metainfo.update(
            dict(
                label_map=self.label_map,
                reduce_zero_label=self.reduce_zero_label))

        # Update palette based on label map or generate palette
        # if it is not defined
        updated_palette = self._update_palette()
        self._metainfo.update(dict(palette=updated_palette))

        # Join paths.
        if self.data_root is not None:
            self._join_prefix()

        # Build pipeline.
        self.pipeline = Compose(pipeline)
        # Fully initialize the dataset.
        if not lazy_init:
            self.full_init()

        if test_mode:
            assert self._metainfo.get('classes') is not None, \
                'dataset metainfo `classes` should be specified when testing'

    @classmethod
    def get_label_map(cls,
                      new_classes: Optional[Sequence] = None
                      ) -> Union[Dict, None]:
        """Return the label mapping.

        The ``label_map`` is a dictionary whose keys are the old label ids
        and whose values are the new label ids; it is used for changing
        pixel labels in load_annotations. ``label_map`` is not None if and
        only if the old classes in ``cls.METAINFO`` differ from the new
        classes in ``self._metainfo`` and neither of them is None.

        Args:
            new_classes (list, tuple, optional): The new class names from
                metainfo. Defaults to None.


        Returns:
            dict, optional: The mapping from old classes in ``cls.METAINFO``
                to new classes in ``self._metainfo``.
        """
        old_classes = cls.METAINFO.get('classes', None)
        if (new_classes is not None and old_classes is not None
                and list(new_classes) != list(old_classes)):

            label_map = {}
            if not set(new_classes).issubset(cls.METAINFO['classes']):
                raise ValueError(
                    f'new classes {new_classes} is not a '
                    f'subset of classes {old_classes} in METAINFO.')
            for i, c in enumerate(old_classes):
                if c not in new_classes:
                    label_map[i] = 255
                else:
                    label_map[i] = new_classes.index(c)
            return label_map
        else:
            return None

    def _update_palette(self) -> list:
        """Update palette after loading metainfo.

        If the length of the palette equals the number of classes, just
        return the palette. If the palette is not defined, randomly generate
        one. If classes are updated by the user, return the matching subset
        of the palette.

        Returns:
            Sequence: Palette for current dataset.
        """
        palette = self._metainfo.get('palette', [])
        classes = self._metainfo.get('classes', [])
        # palette matches classes
        if len(palette) == len(classes):
            return palette

        if len(palette) == 0:
            # Get the random state before setting the seed, and restore the
            # random state afterwards.
            # This prevents loss of randomness, as the palette
            # may be different in each iteration if not specified.
            # See: https://github.com/open-mmlab/mmdetection/issues/5844
            state = np.random.get_state()
            np.random.seed(42)
            # random palette
            new_palette = np.random.randint(
                0, 255, size=(len(classes), 3)).tolist()
            np.random.set_state(state)
        elif len(palette) >= len(classes) and self.label_map is not None:
            new_palette = []
            # return subset of palette
            for old_id, new_id in sorted(
                    self.label_map.items(), key=lambda x: x[1]):
                if new_id != 255:
                    new_palette.append(palette[old_id])
            new_palette = type(palette)(new_palette)
        else:
            raise ValueError('palette does not match classes '
                             f'as metainfo is {self._metainfo}.')
        return new_palette

    def load_data_list(self) -> List[dict]:
        """Load annotation from directory or annotation file.

        Returns:
            list[dict]: All data info of dataset.
        """
        data_list = []
        img_dir = self.data_prefix.get('img_path', None)
        img_dir2 = self.data_prefix.get('img_path2', None)
        ann_dir = self.data_prefix.get('seg_map_path', None)
        if osp.isfile(self.ann_file):
            lines = mmengine.list_from_file(
                self.ann_file, backend_args=self.backend_args)
            for line in lines:
                img_name = line.strip()
                # If the listed name carries an extension, infer both image
                # suffixes from it.
                if '.' in osp.basename(img_name):
                    img_name, img_ext = osp.splitext(img_name)
                    self.img_suffix = img_ext
                    self.img_suffix2 = img_ext
                data_info = dict(
                    img_path=osp.join(img_dir, img_name + self.img_suffix),
                    img_path2=osp.join(img_dir2, img_name + self.img_suffix2))

                if ann_dir is not None:
                    seg_map = img_name + self.seg_map_suffix
                    data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
                data_info['label_map'] = self.label_map
                data_info['reduce_zero_label'] = self.reduce_zero_label
                data_info['seg_fields'] = []
                data_list.append(data_info)
        else:
            for img in fileio.list_dir_or_file(
                    dir_path=img_dir,
                    list_dir=False,
                    suffix=self.img_suffix,
                    recursive=True,
                    backend_args=self.backend_args):
                if '.' in osp.basename(img):
                    img, img_ext = osp.splitext(img)
                    self.img_suffix = img_ext
                    self.img_suffix2 = img_ext
                data_info = dict(
                    img_path=osp.join(img_dir, img + self.img_suffix),
                    img_path2=osp.join(img_dir2, img + self.img_suffix2))
                if ann_dir is not None:
                    seg_map = img + self.seg_map_suffix
                    data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
                data_info['label_map'] = self.label_map
                data_info['reduce_zero_label'] = self.reduce_zero_label
                data_info['seg_fields'] = []
                data_list.append(data_info)
        data_list = sorted(data_list, key=lambda x: x['img_path'])
        return data_list
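
A short worked example of the remapping implemented by ``get_label_map`` and ``_update_palette`` above; the toy class names and colors are invented:

# Sketch: keep only classes ('c', 'a') from a three-class dataset.
# Old id 0 ('a') -> new id 1; 1 ('b') -> 255 (ignored); 2 ('c') -> 0.
class ToyDataset(BaseSegDataset):
    METAINFO = dict(classes=('a', 'b', 'c'),
                    palette=[[10, 0, 0], [0, 10, 0], [0, 0, 10]])

label_map = ToyDataset.get_label_map(new_classes=('c', 'a'))
assert label_map == {0: 1, 1: 255, 2: 0}
# _update_palette would then keep [[0, 0, 10], [10, 0, 0]],
# i.e. the old colors reordered by their new ids, dropping 'b'.
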
65
finetune/mmseg/datasets/c2sfloods.py
Normal file
@@ -0,0 +1,65 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
from typing import List

from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset


@DATASETS.register_module()
class C2SFloodDataset(BaseSegDataset):
    """C2SFloods flood-mapping dataset.

    Segmentation with classes ('Background', 'Water', 'Cloud',
    'Cloud shadow'). Samples are read from a JSON ``ann_file`` whose
    entries hold an ``s2_path`` image and an optional ``target_path`` mask;
    ``img_suffix`` and ``seg_map_suffix`` are both '.npz'.
    """
    METAINFO = dict(
        classes=('Background', 'Water', 'Cloud', 'Cloud shadow'),
        palette=[[0, 0, 0], [255, 255, 255], [255, 0, 0], [0, 255, 0]]
    )

    def __init__(self,
                 img_suffix='.npz',
                 seg_map_suffix='.npz',
                 reduce_zero_label=False,
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)

    def load_data_list(self) -> List[dict]:
        """Load annotation from directory or annotation file.

        Returns:
            list[dict]: All data info of dataset.
        """
        data_list = []
        if not osp.isdir(self.ann_file) and self.ann_file:
            assert osp.isfile(self.ann_file), \
                f'Failed to load `ann_file` {self.ann_file}'
            with open(self.ann_file) as f:
                lines = json.load(f)
            for line in lines:
                data_info = dict(
                    img_path=osp.join(self.data_root, line['s2_path']))
                if 'target_path' in line:
                    data_info['seg_map_path'] = osp.join(
                        self.data_root, line['target_path'])
                data_info['label_map'] = self.label_map
                data_info['reduce_zero_label'] = self.reduce_zero_label
                data_info['seg_fields'] = []
                data_list.append(data_info)
        data_list = sorted(data_list, key=lambda x: x['img_path'])
        return data_list
54
finetune/mmseg/datasets/cabuar.py
Normal file
@@ -0,0 +1,54 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
from typing import List

from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset


@DATASETS.register_module()
class CABURADataset(BaseSegDataset):
    """Burned-area segmentation dataset (CaBuAr).

    Binary segmentation with classes ('Background', 'Burned area').
    Samples are read from a JSON ``ann_file`` whose entries hold a
    ``post_fire`` image and an optional ``mask``; ``img_suffix`` and
    ``seg_map_suffix`` are both '.npz'.
    """
    METAINFO = dict(
        classes=('Background', 'Burned area'),
        palette=[[0, 0, 0], [255, 255, 255]]
    )

    def __init__(self,
                 img_suffix='.npz',
                 seg_map_suffix='.npz',
                 reduce_zero_label=False,
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)

    def load_data_list(self) -> List[dict]:
        """Load annotation from directory or annotation file.

        Returns:
            list[dict]: All data info of dataset.
        """
        data_list = []
        if not osp.isdir(self.ann_file) and self.ann_file:
            assert osp.isfile(self.ann_file), \
                f'Failed to load `ann_file` {self.ann_file}'
            with open(self.ann_file) as f:
                lines = json.load(f)
            for line in lines:
                data_info = dict(
                    img_path=osp.join(self.data_root, line['post_fire']))
                if 'mask' in line:
                    data_info['seg_map_path'] = osp.join(
                        self.data_root, line['mask'])
                data_info['label_map'] = self.label_map
                data_info['reduce_zero_label'] = self.reduce_zero_label
                data_info['seg_fields'] = []
                data_list.append(data_info)
        data_list = sorted(data_list, key=lambda x: x['img_path'])
        return data_list
72
finetune/mmseg/datasets/germany.py
Normal file
@@ -0,0 +1,72 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
from typing import List

from mmengine.logging import print_log

from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset


@DATASETS.register_module()
class GermanyCropDataset(BaseSegDataset):
    """Crop-type segmentation dataset for Germany.

    In the segmentation maps, 0 is 'unknown' and is treated as the ignore
    index, so ``reduce_zero_label`` defaults to True. ``img_suffix`` and
    ``seg_map_suffix`` are both fixed to '.pickle'. Samples are read from a
    JSON ``ann_file`` whose entries hold an ``s2_path`` image and an
    optional ``target_path`` mask.
    """
    # Raw label ids in the source data:
    # {0: "unknown", 1: "sugar_beet", 2: "summer_oat", 3: "meadow",
    #  5: "rape", 8: "hop", 9: "winter_spelt", 12: "winter_triticale",
    #  13: "beans", 15: "peas", 16: "potatoes", 17: "soybeans",
    #  19: "asparagus", 22: "winter_wheat", 23: "winter_barley",
    #  24: "winter_rye", 25: "summer_barley", 26: "maize"}
    METAINFO = dict(
        classes=('sugar_beet', 'summer_oat', 'meadow', 'rape', 'hop',
                 'winter_spelt', 'winter_triticale', 'beans', 'peas',
                 'potatoes', 'soybeans', 'asparagus', 'winter_wheat',
                 'winter_barley', 'winter_rye', 'summer_barley', 'maize'),
        palette=[(255, 255, 255), (255, 255, 170), (255, 255, 85),
                 (255, 170, 255), (255, 170, 170), (255, 170, 85),
                 (255, 85, 255), (255, 85, 170), (255, 85, 85),
                 (170, 255, 255), (170, 255, 170), (170, 255, 85),
                 (170, 170, 255), (170, 170, 170), (170, 170, 85),
                 (170, 85, 255), (170, 85, 170)])

    def __init__(self,
                 img_suffix='.pickle',
                 seg_map_suffix='.pickle',
                 reduce_zero_label=True,
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)

    def load_data_list(self) -> List[dict]:
        """Load annotation from directory or annotation file.

        Returns:
            list[dict]: All data info of dataset.
        """
        data_list = []
        if not osp.isdir(self.ann_file) and self.ann_file:
            assert osp.isfile(self.ann_file), \
                f'Failed to load `ann_file` {self.ann_file}'
            with open(self.ann_file) as f:
                lines = json.load(f)
            print_log(f'dataset count: {len(lines)}')
            for line in lines:
                data_info = dict(
                    img_path=osp.join(self.data_root, line['s2_path']))
                if 'target_path' in line:
                    data_info['seg_map_path'] = osp.join(
                        self.data_root, line['target_path'])
                data_info['label_map'] = self.label_map
                data_info['reduce_zero_label'] = self.reduce_zero_label
                data_info['seg_fields'] = []
                data_list.append(data_info)
        data_list = sorted(data_list, key=lambda x: x['img_path'])
        return data_list
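
GermanyCropDataset is the only loader in this commit with ``reduce_zero_label=True``; the sketch below replays the exact shift that ``LoadAnnotations._load_seg_map`` (later in this diff) applies, with made-up label values:

import numpy as np

# reduce_zero_label: label 0 ("unknown") becomes the ignore index 255
# and every remaining id shifts down by one.
gt = np.array([[0, 1, 2], [26, 0, 3]], dtype=np.uint8)  # made-up map
gt[gt == 0] = 255
gt = gt - 1
gt[gt == 254] = 255
print(gt)  # [[255   0   1]
           #  [ 25 255   2]]
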
66
finetune/mmseg/datasets/sos.py
Normal file
@@ -0,0 +1,66 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
from typing import List

from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset


@DATASETS.register_module()
class SOSDataset(BaseSegDataset):
    """SOS oil-spill segmentation dataset.

    Binary segmentation with classes ('Background', 'Oil Spill Area').
    Samples are read from a JSON ``ann_file`` whose entries hold an
    ``s1_path`` image and an optional ``target_path`` mask; images use
    '.jpg' and masks use '.png'.
    """
    METAINFO = dict(
        classes=('Background', 'Oil Spill Area'),
        palette=[[0, 0, 0], [255, 255, 255]]
    )

    def __init__(self,
                 img_suffix='.jpg',
                 seg_map_suffix='.png',
                 reduce_zero_label=False,
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)

    def load_data_list(self) -> List[dict]:
        """Load annotation from directory or annotation file.

        Returns:
            list[dict]: All data info of dataset.
        """
        data_list = []
        if not osp.isdir(self.ann_file) and self.ann_file:
            assert osp.isfile(self.ann_file), \
                f'Failed to load `ann_file` {self.ann_file}'
            with open(self.ann_file) as f:
                lines = json.load(f)
            for line in lines:
                data_info = dict(
                    img_path=osp.join(self.data_root, line['s1_path']))
                if 'target_path' in line:
                    data_info['seg_map_path'] = osp.join(
                        self.data_root, line['target_path'])
                data_info['label_map'] = self.label_map
                data_info['reduce_zero_label'] = self.reduce_zero_label
                data_info['seg_fields'] = []
                data_list.append(data_info)
        data_list = sorted(data_list, key=lambda x: x['img_path'])
        return data_list
32
finetune/mmseg/datasets/transforms/__init__.py
Normal file
@@ -0,0 +1,32 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .formatting import PackSegInputs
from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
                      LoadBiomedicalData, LoadBiomedicalImageFromFile,
                      LoadDepthAnnotation, LoadImageFromNDArray,
                      LoadMultipleRSImageFromFile, LoadSingleRSImageFromFile)
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad,
                         BioMedical3DRandomCrop, BioMedical3DRandomFlip,
                         BioMedicalGaussianBlur, BioMedicalGaussianNoise,
                         BioMedicalRandomGamma, ConcatCDInput, GenerateEdge,
                         PhotoMetricDistortion, RandomCrop, RandomCutOut,
                         RandomDepthMix, RandomFlip, RandomMosaic,
                         RandomRotate, RandomRotFlip, Rerange, Resize,
                         ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
                         SegRescale)
from .loading_npz import (LoadAnnotationsNpz, LoadImageFromNpz,
                          LoadTsImageFromNpz, LoadAnnotationsOil,
                          LoadImageOil, LoadImageSingleChannel)

# yapf: enable
__all__ = [
    'LoadAnnotations', 'RandomCrop', 'BioMedical3DRandomCrop', 'SegRescale',
    'PhotoMetricDistortion', 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange',
    'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs',
    'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
    'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
    'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
    'BioMedical3DRandomFlip', 'BioMedicalRandomGamma', 'BioMedical3DPad',
    'RandomRotFlip', 'Albu', 'LoadSingleRSImageFromFile', 'ConcatCDInput',
    'LoadMultipleRSImageFromFile', 'LoadDepthAnnotation', 'RandomDepthMix',
    'RandomFlip', 'Resize', 'LoadAnnotationsNpz', 'LoadImageFromNpz',
    'LoadTsImageFromNpz', 'LoadAnnotationsOil', 'LoadImageOil',
    'LoadImageSingleChannel'
]
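
These transforms are consumed as config dicts in a dataset's ``pipeline``. Below is a hedged sketch of a training pipeline using the npz loaders exported above; ``loading_npz.py`` is not part of this diff, so those loaders are shown without arguments, and the crop size and probability are illustrative, not from this commit:

# Illustrative train pipeline built from the transforms exported above.
train_pipeline = [
    dict(type='LoadImageFromNpz'),        # project-specific loader
    dict(type='LoadAnnotationsNpz'),      # project-specific loader
    dict(type='RandomCrop', crop_size=(512, 512), cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='PackSegInputs')            # defined in formatting.py below
]
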
112
finetune/mmseg/datasets/transforms/formatting.py
Normal file
@@ -0,0 +1,112 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import numpy as np
from mmcv.transforms import to_tensor
from mmcv.transforms.base import BaseTransform
from mmengine.structures import PixelData

from mmseg.registry import TRANSFORMS
from mmseg.structures import SegDataSample


@TRANSFORMS.register_module()
class PackSegInputs(BaseTransform):
    """Pack the inputs data for semantic segmentation.

    The ``img_meta`` item is always populated. The contents of the
    ``img_meta`` dictionary depend on ``meta_keys``. By default this
    includes:

    - ``img_path``: filename of the image

    - ``ori_shape``: original shape of the image as a tuple (h, w, c)

    - ``img_shape``: shape of the image input to the network as a tuple \
        (h, w, c). Note that images may be zero padded on the \
        bottom/right if the batch tensor is larger than this shape.

    - ``pad_shape``: shape of padded images

    - ``scale_factor``: a float indicating the preprocessing scale

    - ``flip``: a boolean indicating if image flip transform was used

    - ``flip_direction``: the flipping direction

    Args:
        meta_keys (Sequence[str], optional): Meta keys to be packed from
            ``SegDataSample`` and collected in ``data[img_metas]``.
            Default: ``('img_path', 'seg_map_path', 'ori_shape',
            'img_shape', 'pad_shape', 'scale_factor', 'flip',
            'flip_direction', 'reduce_zero_label')``
    """

    def __init__(self,
                 meta_keys=('img_path', 'seg_map_path', 'ori_shape',
                            'img_shape', 'pad_shape', 'scale_factor', 'flip',
                            'flip_direction', 'reduce_zero_label')):
        self.meta_keys = meta_keys

    def transform(self, results: dict) -> dict:
        """Method to pack the input data.

        Args:
            results (dict): Result dict from the data pipeline.

        Returns:
            dict:

            - 'inputs' (obj:`torch.Tensor`): The forward data of models.
            - 'data_samples' (obj:`SegDataSample`): The annotation info of
              the sample.
        """
        packed_results = dict()
        if 'img' in results:
            img = results['img']
            if len(img.shape) < 3:
                img = np.expand_dims(img, -1)
            if not img.flags.c_contiguous:
                img = to_tensor(np.ascontiguousarray(img.transpose(2, 0, 1)))
            else:
                img = img.transpose(2, 0, 1)
                img = to_tensor(img).contiguous()
            packed_results['inputs'] = img

        data_sample = SegDataSample()
        if 'gt_seg_map' in results:
            if len(results['gt_seg_map'].shape) == 2:
                data = to_tensor(results['gt_seg_map'][None,
                                                       ...].astype(np.int64))
            else:
                warnings.warn('Please pay attention to your ground-truth '
                              'segmentation map; usually the segmentation '
                              'map is 2D, but got '
                              f'{results["gt_seg_map"].shape}')
                data = to_tensor(results['gt_seg_map'].astype(np.int64))
            gt_sem_seg_data = dict(data=data)
            data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)

        if 'gt_edge_map' in results:
            gt_edge_data = dict(
                data=to_tensor(results['gt_edge_map'][None,
                                                      ...].astype(np.int64)))
            data_sample.set_data(dict(gt_edge_map=PixelData(**gt_edge_data)))

        if 'gt_depth_map' in results:
            gt_depth_data = dict(
                data=to_tensor(results['gt_depth_map'][None, ...]))
            data_sample.set_data(dict(gt_depth_map=PixelData(**gt_depth_data)))

        img_meta = {}
        for key in self.meta_keys:
            if key in results:
                img_meta[key] = results[key]
        data_sample.set_metainfo(img_meta)
        packed_results['data_samples'] = data_sample

        return packed_results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(meta_keys={self.meta_keys})'
        return repr_str
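
A self-contained sketch of what ``PackSegInputs`` emits for one sample; the arrays are fake and only the shapes matter:

import numpy as np

# Assumes the mmseg/mmcv stack imported by formatting.py is installed.
pack = PackSegInputs()
results = dict(
    img=np.zeros((64, 64, 3), dtype=np.uint8),      # fake HWC image
    gt_seg_map=np.zeros((64, 64), dtype=np.uint8),  # fake 2D mask
    img_path='demo.png', ori_shape=(64, 64), img_shape=(64, 64))
packed = pack(results)
print(packed['inputs'].shape)                         # torch.Size([3, 64, 64])
print(packed['data_samples'].gt_sem_seg.data.shape)   # torch.Size([1, 64, 64])
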
771
finetune/mmseg/datasets/transforms/loading.py
Normal file
@@ -0,0 +1,771 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import mmcv
|
||||
import mmengine.fileio as fileio
|
||||
import numpy as np
|
||||
from mmcv.transforms import BaseTransform
|
||||
from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations
|
||||
from mmcv.transforms import LoadImageFromFile
|
||||
|
||||
from mmseg.registry import TRANSFORMS
|
||||
from mmseg.utils import datafrombytes
|
||||
|
||||
try:
|
||||
from osgeo import gdal
|
||||
except ImportError:
|
||||
gdal = None
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadAnnotations(MMCV_LoadAnnotations):
|
||||
"""Load annotations for semantic segmentation provided by dataset.
|
||||
|
||||
The annotation format is as the following:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
# Filename of semantic segmentation ground truth file.
|
||||
'seg_map_path': 'a/b/c'
|
||||
}
|
||||
|
||||
After this module, the annotation has been changed to the format below:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
# in str
|
||||
'seg_fields': List
|
||||
# In uint8 type.
|
||||
'gt_seg_map': np.ndarray (H, W)
|
||||
}
|
||||
|
||||
Required Keys:
|
||||
|
||||
- seg_map_path (str): Path of semantic segmentation ground truth file.
|
||||
|
||||
Added Keys:
|
||||
|
||||
- seg_fields (List)
|
||||
- gt_seg_map (np.uint8)
|
||||
|
||||
Args:
|
||||
reduce_zero_label (bool, optional): Whether reduce all label value
|
||||
by 1. Usually used for datasets where 0 is background label.
|
||||
Defaults to None.
|
||||
imdecode_backend (str): The image decoding backend type. The backend
|
||||
argument for :func:``mmcv.imfrombytes``.
|
||||
See :fun:``mmcv.imfrombytes`` for details.
|
||||
Defaults to 'pillow'.
|
||||
backend_args (dict): Arguments to instantiate a file backend.
|
||||
See https://mmengine.readthedocs.io/en/latest/api/fileio.htm
|
||||
for details. Defaults to None.
|
||||
Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reduce_zero_label=None,
|
||||
backend_args=None,
|
||||
imdecode_backend='pillow',
|
||||
) -> None:
|
||||
super().__init__(
|
||||
with_bbox=False,
|
||||
with_label=False,
|
||||
with_seg=True,
|
||||
with_keypoints=False,
|
||||
imdecode_backend=imdecode_backend,
|
||||
backend_args=backend_args)
|
||||
self.reduce_zero_label = reduce_zero_label
|
||||
if self.reduce_zero_label is not None:
|
||||
warnings.warn('`reduce_zero_label` will be deprecated, '
|
||||
'if you would like to ignore the zero label, please '
|
||||
'set `reduce_zero_label=True` when dataset '
|
||||
'initialized')
|
||||
self.imdecode_backend = imdecode_backend
|
||||
|
||||
def _load_seg_map(self, results: dict) -> None:
|
||||
"""Private function to load semantic segmentation annotations.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
|
||||
|
||||
Returns:
|
||||
dict: The dict contains loaded semantic segmentation annotations.
|
||||
"""
|
||||
|
||||
img_bytes = fileio.get(
|
||||
results['seg_map_path'], backend_args=self.backend_args)
|
||||
gt_semantic_seg = mmcv.imfrombytes(
|
||||
img_bytes, flag='unchanged',
|
||||
backend=self.imdecode_backend).squeeze().astype(np.uint8)
|
||||
|
||||
# reduce zero_label
|
||||
if self.reduce_zero_label is None:
|
||||
self.reduce_zero_label = results['reduce_zero_label']
|
||||
assert self.reduce_zero_label == results['reduce_zero_label'], \
|
||||
'Initialize dataset with `reduce_zero_label` as ' \
|
||||
f'{results["reduce_zero_label"]} but when load annotation ' \
|
||||
f'the `reduce_zero_label` is {self.reduce_zero_label}'
|
||||
if self.reduce_zero_label:
|
||||
# avoid using underflow conversion
|
||||
gt_semantic_seg[gt_semantic_seg == 0] = 255
|
||||
gt_semantic_seg = gt_semantic_seg - 1
|
||||
gt_semantic_seg[gt_semantic_seg == 254] = 255
|
||||
# modify if custom classes
|
||||
if results.get('label_map', None) is not None:
|
||||
# Add deep copy to solve bug of repeatedly
|
||||
# replace `gt_semantic_seg`, which is reported in
|
||||
# https://github.com/open-mmlab/mmsegmentation/pull/1445/
|
||||
gt_semantic_seg_copy = gt_semantic_seg.copy()
|
||||
for old_id, new_id in results['label_map'].items():
|
||||
gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
|
||||
results['gt_seg_map'] = gt_semantic_seg
|
||||
results['seg_fields'].append('gt_seg_map')
|
||||
|
||||
def __repr__(self) -> str:
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(reduce_zero_label={self.reduce_zero_label}, '
|
||||
repr_str += f"imdecode_backend='{self.imdecode_backend}', "
|
||||
repr_str += f'backend_args={self.backend_args})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadImageFromNDArray(LoadImageFromFile):
|
||||
"""Load an image from ``results['img']``.
|
||||
|
||||
Similar with :obj:`LoadImageFromFile`, but the image has been loaded as
|
||||
:obj:`np.ndarray` in ``results['img']``. Can be used when loading image
|
||||
from webcam.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
- img_path
|
||||
- img_shape
|
||||
- ori_shape
|
||||
|
||||
Args:
|
||||
to_float32 (bool): Whether to convert the loaded image to a float32
|
||||
numpy array. If set to False, the loaded image is an uint8 array.
|
||||
Defaults to False.
|
||||
"""
|
||||
|
||||
def transform(self, results: dict) -> dict:
|
||||
"""Transform function to add image meta information.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict with Webcam read image in
|
||||
``results['img']``.
|
||||
|
||||
Returns:
|
||||
dict: The dict contains loaded image and meta information.
|
||||
"""
|
||||
|
||||
img = results['img']
|
||||
if self.to_float32:
|
||||
img = img.astype(np.float32)
|
||||
|
||||
results['img_path'] = None
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape[:2]
|
||||
results['ori_shape'] = img.shape[:2]
|
||||
return results
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadBiomedicalImageFromFile(BaseTransform):
|
||||
"""Load an biomedical mage from file.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img_path
|
||||
|
||||
Added Keys:
|
||||
|
||||
- img (np.ndarray): Biomedical image with shape (N, Z, Y, X) by default,
|
||||
N is the number of modalities, and data type is float32
|
||||
if set to_float32 = True, or float64 if decode_backend is 'nifti' and
|
||||
to_float32 is False.
|
||||
- img_shape
|
||||
- ori_shape
|
||||
|
||||
Args:
|
||||
decode_backend (str): The data decoding backend type. Options are
|
||||
'numpy'and 'nifti', and there is a convention that when backend is
|
||||
'nifti' the axis of data loaded is XYZ, and when backend is
|
||||
'numpy', the the axis is ZYX. The data will be transposed if the
|
||||
backend is 'nifti'. Defaults to 'nifti'.
|
||||
to_xyz (bool): Whether transpose data from Z, Y, X to X, Y, Z.
|
||||
Defaults to False.
|
||||
to_float32 (bool): Whether to convert the loaded image to a float32
|
||||
numpy array. If set to False, the loaded image is an float64 array.
|
||||
Defaults to True.
|
||||
backend_args (dict, Optional): Arguments to instantiate a file backend.
|
||||
See https://mmengine.readthedocs.io/en/latest/api/fileio.htm
|
||||
for details. Defaults to None.
|
||||
Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
decode_backend: str = 'nifti',
|
||||
to_xyz: bool = False,
|
||||
to_float32: bool = True,
|
||||
backend_args: Optional[dict] = None) -> None:
|
||||
self.decode_backend = decode_backend
|
||||
self.to_xyz = to_xyz
|
||||
self.to_float32 = to_float32
|
||||
self.backend_args = backend_args.copy() if backend_args else None
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
"""Functions to load image.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
|
||||
|
||||
Returns:
|
||||
dict: The dict contains loaded image and meta information.
|
||||
"""
|
||||
|
||||
filename = results['img_path']
|
||||
|
||||
data_bytes = fileio.get(filename, self.backend_args)
|
||||
img = datafrombytes(data_bytes, backend=self.decode_backend)
|
||||
|
||||
if self.to_float32:
|
||||
img = img.astype(np.float32)
|
||||
|
||||
if len(img.shape) == 3:
|
||||
img = img[None, ...]
|
||||
|
||||
if self.decode_backend == 'nifti':
|
||||
img = img.transpose(0, 3, 2, 1)
|
||||
|
||||
if self.to_xyz:
|
||||
img = img.transpose(0, 3, 2, 1)
|
||||
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape[1:]
|
||||
results['ori_shape'] = img.shape[1:]
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = (f'{self.__class__.__name__}('
|
||||
f"decode_backend='{self.decode_backend}', "
|
||||
f'to_xyz={self.to_xyz}, '
|
||||
f'to_float32={self.to_float32}, '
|
||||
f'backend_args={self.backend_args})')
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadBiomedicalAnnotation(BaseTransform):
|
||||
"""Load ``seg_map`` annotation provided by biomedical dataset.
|
||||
|
||||
The annotation format is as the following:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
'gt_seg_map': np.ndarray (X, Y, Z) or (Z, Y, X)
|
||||
}
|
||||
|
||||
Required Keys:
|
||||
|
||||
- seg_map_path
|
||||
|
||||
Added Keys:
|
||||
|
||||
- gt_seg_map (np.ndarray): Biomedical seg map with shape (Z, Y, X) by
|
||||
default, and data type is float32 if set to_float32 = True, or
|
||||
float64 if decode_backend is 'nifti' and to_float32 is False.
|
||||
|
||||
Args:
|
||||
decode_backend (str): The data decoding backend type. Options are
|
||||
'numpy'and 'nifti', and there is a convention that when backend is
|
||||
'nifti' the axis of data loaded is XYZ, and when backend is
|
||||
'numpy', the the axis is ZYX. The data will be transposed if the
|
||||
backend is 'nifti'. Defaults to 'nifti'.
|
||||
to_xyz (bool): Whether transpose data from Z, Y, X to X, Y, Z.
|
||||
Defaults to False.
|
||||
to_float32 (bool): Whether to convert the loaded seg map to a float32
|
||||
numpy array. If set to False, the loaded image is an float64 array.
|
||||
Defaults to True.
|
||||
backend_args (dict, Optional): Arguments to instantiate a file backend.
|
||||
See :class:`mmengine.fileio` for details.
|
||||
Defaults to None.
|
||||
Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
decode_backend: str = 'nifti',
|
||||
to_xyz: bool = False,
|
||||
to_float32: bool = True,
|
||||
backend_args: Optional[dict] = None) -> None:
|
||||
super().__init__()
|
||||
self.decode_backend = decode_backend
|
||||
self.to_xyz = to_xyz
|
||||
self.to_float32 = to_float32
|
||||
self.backend_args = backend_args.copy() if backend_args else None
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
"""Functions to load image.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
|
||||
|
||||
Returns:
|
||||
dict: The dict contains loaded image and meta information.
|
||||
"""
|
||||
data_bytes = fileio.get(results['seg_map_path'], self.backend_args)
|
||||
gt_seg_map = datafrombytes(data_bytes, backend=self.decode_backend)
|
||||
|
||||
if self.to_float32:
|
||||
gt_seg_map = gt_seg_map.astype(np.float32)
|
||||
|
||||
if self.decode_backend == 'nifti':
|
||||
gt_seg_map = gt_seg_map.transpose(2, 1, 0)
|
||||
|
||||
if self.to_xyz:
|
||||
gt_seg_map = gt_seg_map.transpose(2, 1, 0)
|
||||
|
||||
results['gt_seg_map'] = gt_seg_map
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = (f'{self.__class__.__name__}('
|
||||
f"decode_backend='{self.decode_backend}', "
|
||||
f'to_xyz={self.to_xyz}, '
|
||||
f'to_float32={self.to_float32}, '
|
||||
f'backend_args={self.backend_args})')
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
class LoadBiomedicalData(BaseTransform):
    """Load a biomedical image and annotation from file.

    The loaded data format is as follows:

    .. code-block:: python

        {
            'img': np.ndarray data[:-1, X, Y, Z]
            'seg_map': np.ndarray data[-1, X, Y, Z]
        }

    Required Keys:

    - img_path

    Added Keys:

    - img (np.ndarray): Biomedical image with shape (N, Z, Y, X) by default,
        N is the number of modalities.
    - gt_seg_map (np.ndarray, optional): Biomedical seg map with shape
        (Z, Y, X) by default.
    - img_shape
    - ori_shape

    Args:
        with_seg (bool): Whether to parse and load the semantic segmentation
            annotation. Defaults to False.
        decode_backend (str): The data decoding backend type. Options are
            'numpy' and 'nifti', and there is a convention that when backend
            is 'nifti' the axis of data loaded is XYZ, and when backend is
            'numpy', the axis is ZYX. The data will be transposed if the
            backend is 'nifti'. Defaults to 'numpy'.
        to_xyz (bool): Whether transpose data from Z, Y, X to X, Y, Z.
            Defaults to False.
        backend_args (dict, Optional): Arguments to instantiate a file
            backend. See
            https://mmengine.readthedocs.io/en/latest/api/fileio.htm
            for details. Defaults to None.
            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
    """

    def __init__(self,
                 with_seg=False,
                 decode_backend: str = 'numpy',
                 to_xyz: bool = False,
                 backend_args: Optional[dict] = None) -> None:  # noqa
        self.with_seg = with_seg
        self.decode_backend = decode_backend
        self.to_xyz = to_xyz
        self.backend_args = backend_args.copy() if backend_args else None

    def transform(self, results: Dict) -> Dict:
        """Function to load the image and, optionally, the annotation.

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``.

        Returns:
            dict: The dict contains loaded image and meta information.
        """
        data_bytes = fileio.get(results['img_path'], self.backend_args)
        data = datafrombytes(data_bytes, backend=self.decode_backend)
        # data is 4D: (N + 1, X, Y, Z); the first N channels are image
        # modalities and the last channel is the segmentation map.
        img = data[:-1, :]

        if self.decode_backend == 'nifti':
            img = img.transpose(0, 3, 2, 1)

        if self.to_xyz:
            img = img.transpose(0, 3, 2, 1)

        results['img'] = img
        results['img_shape'] = img.shape[1:]
        results['ori_shape'] = img.shape[1:]

        if self.with_seg:
            gt_seg_map = data[-1, :]
            if self.decode_backend == 'nifti':
                gt_seg_map = gt_seg_map.transpose(2, 1, 0)

            if self.to_xyz:
                gt_seg_map = gt_seg_map.transpose(2, 1, 0)
            results['gt_seg_map'] = gt_seg_map
        return results

    def __repr__(self) -> str:
        repr_str = (f'{self.__class__.__name__}('
                    f'with_seg={self.with_seg}, '
                    f"decode_backend='{self.decode_backend}', "
                    f'to_xyz={self.to_xyz}, '
                    f'backend_args={self.backend_args})')
        return repr_str

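# Shape-convention sketch for ``LoadBiomedicalData``: the backend file is
# assumed to hold one (N + 1, X, Y, Z) array in which the first N channels
# are image modalities and the last channel is the label volume.
#
#   >>> import numpy as np
#   >>> data = np.zeros((3, 64, 64, 32))       # 2 modalities + 1 seg map
#   >>> img, seg = data[:-1], data[-1]
#   >>> img.shape, seg.shape
#   ((2, 64, 64, 32), (64, 64, 32))
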
@TRANSFORMS.register_module()
class InferencerLoader(BaseTransform):
    """Load an image from ``results['img']``.

    Similar to :obj:`LoadImageFromFile`, but the image has been loaded as
    :obj:`np.ndarray` in ``results['img']``. Can be used when loading an
    image from a webcam.

    Required Keys:

    - img

    Modified Keys:

    - img
    - img_path
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is an uint8 array.
            Defaults to False.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__()
        self.from_file = TRANSFORMS.build(
            dict(type='LoadImageFromFile', **kwargs))
        self.from_ndarray = TRANSFORMS.build(
            dict(type='LoadImageFromNDArray', **kwargs))

    def transform(self, single_input: Union[str, np.ndarray, dict]) -> dict:
        """Transform function to add image meta information.

        Args:
            single_input (str or np.ndarray or dict): The raw input: an
                image path, an image array (e.g. a webcam frame), or an
                already prepared result dict.

        Returns:
            dict: The dict contains loaded image and meta information.
        """
        if isinstance(single_input, str):
            inputs = dict(img_path=single_input)
        elif isinstance(single_input, np.ndarray):
            inputs = dict(img=single_input)
        elif isinstance(single_input, dict):
            inputs = single_input
        else:
            raise NotImplementedError

        # Dispatch on whether the image itself is already present.
        if 'img' in inputs:
            return self.from_ndarray(inputs)
        return self.from_file(inputs)

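# Dispatch sketch for ``InferencerLoader``: one instance accepts a file
# path, a raw array, or a prepared dict. The path below is a placeholder.
#
#   >>> loader = InferencerLoader()
#   >>> _ = loader('demo/demo.png')                        # LoadImageFromFile
#   >>> _ = loader(np.zeros((64, 64, 3), dtype=np.uint8))  # LoadImageFromNDArray
#   >>> _ = loader(dict(img=np.zeros((64, 64, 3), dtype=np.uint8)))
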
@TRANSFORMS.register_module()
class LoadSingleRSImageFromFile(BaseTransform):
    """Load a remote sensing image from file.

    Required Keys:

    - img_path

    Modified Keys:

    - img
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is a float64 array.
            Defaults to True.
    """

    def __init__(self, to_float32: bool = True):
        self.to_float32 = to_float32

        if gdal is None:
            raise RuntimeError('gdal is not installed')

    def transform(self, results: Dict) -> Dict:
        """Function to load the image.

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``.

        Returns:
            dict: The dict contains loaded image and meta information.
        """
        filename = results['img_path']
        ds = gdal.Open(filename)
        if ds is None:
            raise Exception(f'Unable to open file: {filename}')
        # GDAL reads band-first (C, H, W); convert to channel-last (H, W, C).
        img = np.einsum('ijk->jki', ds.ReadAsArray())

        if self.to_float32:
            img = img.astype(np.float32)

        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'to_float32={self.to_float32})')
        return repr_str

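# Axis-order sketch: ``np.einsum('ijk->jki', arr)`` used above is equivalent
# to ``arr.transpose(1, 2, 0)``, i.e. it turns GDAL's band-first (C, H, W)
# array into the channel-last (H, W, C) layout the rest of the pipeline
# expects.
#
#   >>> arr = np.arange(24).reshape(2, 3, 4)    # (C, H, W)
#   >>> np.array_equal(np.einsum('ijk->jki', arr), arr.transpose(1, 2, 0))
#   True
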
@TRANSFORMS.register_module()
class LoadMultipleRSImageFromFile(BaseTransform):
    """Load two remote sensing images from file.

    Required Keys:

    - img_path
    - img_path2

    Modified Keys:

    - img
    - img2
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to convert the loaded images to float32
            numpy arrays. If set to False, the loaded images are float64
            arrays. Defaults to True.
    """

    def __init__(self, to_float32: bool = True):
        if gdal is None:
            raise RuntimeError('gdal is not installed')
        self.to_float32 = to_float32

    def transform(self, results: Dict) -> Dict:
        """Function to load the two images.

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``.

        Returns:
            dict: The dict contains loaded images and meta information.
        """
        filename = results['img_path']
        filename2 = results['img_path2']

        ds = gdal.Open(filename)
        ds2 = gdal.Open(filename2)

        if ds is None:
            raise Exception(f'Unable to open file: {filename}')
        if ds2 is None:
            raise Exception(f'Unable to open file: {filename2}')

        # GDAL reads band-first (C, H, W); convert to channel-last (H, W, C).
        img = np.einsum('ijk->jki', ds.ReadAsArray())
        img2 = np.einsum('ijk->jki', ds2.ReadAsArray())

        if self.to_float32:
            img = img.astype(np.float32)
            img2 = img2.astype(np.float32)

        if img.shape != img2.shape:
            raise Exception(f'Image shapes do not match:'
                            f' {img.shape} vs {img2.shape}')

        results['img'] = img
        results['img2'] = img2
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'to_float32={self.to_float32})')
        return repr_str

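# Pipeline sketch for change detection with ``LoadMultipleRSImageFromFile``:
# the dataset is expected to provide both ``img_path`` and ``img_path2``,
# and the two co-registered images are later merged by ``ConcatCDInput``.
# The transform list below is an illustrative assumption, not a tested
# config.
#
#   train_pipeline = [
#       dict(type='LoadMultipleRSImageFromFile'),
#       dict(type='LoadAnnotations'),
#       dict(type='ConcatCDInput'),
#       dict(type='PackSegInputs'),
#   ]
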
@TRANSFORMS.register_module()
class LoadDepthAnnotation(BaseTransform):
    """Load ``depth_map`` annotation provided by depth estimation dataset.

    The annotation format is as follows:

    .. code-block:: python

        {
            'gt_depth_map': np.ndarray [Y, X]
        }

    Required Keys:

    - depth_map_path

    Added Keys:

    - gt_depth_map (np.ndarray): Depth map with shape (Y, X) by default,
        with data type float32 if ``to_float32=True``.
    - depth_rescale_factor (float): The rescale factor of the depth map,
        which can be used to recover the original value of the depth map.

    Args:
        decode_backend (str): The data decoding backend type. Options are
            'numpy', 'nifti', and 'cv2'. Defaults to 'cv2'.
        to_float32 (bool): Whether to convert the loaded depth map to a
            float32 numpy array. If set to False, the loaded image is an
            uint16 array. Defaults to True.
        depth_rescale_factor (float): Factor to rescale the depth value to
            limit the range. Defaults to 1.0.
        backend_args (dict, Optional): Arguments to instantiate a file
            backend. See :class:`mmengine.fileio` for details.
            Defaults to None.
            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
    """

    def __init__(self,
                 decode_backend: str = 'cv2',
                 to_float32: bool = True,
                 depth_rescale_factor: float = 1.0,
                 backend_args: Optional[dict] = None) -> None:
        super().__init__()
        self.decode_backend = decode_backend
        self.to_float32 = to_float32
        self.depth_rescale_factor = depth_rescale_factor
        self.backend_args = backend_args.copy() if backend_args else None

    def transform(self, results: Dict) -> Dict:
        """Function to load the depth map.

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``.

        Returns:
            dict: The dict contains the loaded depth map.
        """
        data_bytes = fileio.get(results['depth_map_path'], self.backend_args)
        gt_depth_map = datafrombytes(data_bytes, backend=self.decode_backend)

        if self.to_float32:
            gt_depth_map = gt_depth_map.astype(np.float32)

        gt_depth_map *= self.depth_rescale_factor
        results['gt_depth_map'] = gt_depth_map
        results['seg_fields'].append('gt_depth_map')
        results['depth_rescale_factor'] = self.depth_rescale_factor
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f"decode_backend='{self.decode_backend}', "
                    f'to_float32={self.to_float32}, '
                    f'depth_rescale_factor={self.depth_rescale_factor}, '
                    f'backend_args={self.backend_args})')
        return repr_str

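# Rescale sketch for ``LoadDepthAnnotation``: the stored factor makes the
# scaling invertible, so metric depth can be recovered after inference.
# The factor value is an illustrative assumption.
#
#   >>> factor = 0.001                          # e.g. millimetres -> metres
#   >>> gt = np.array([1500., 2000.]) * factor  # what the transform stores
#   >>> np.allclose(gt / factor, [1500., 2000.])
#   True
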
@TRANSFORMS.register_module()
class LoadImageFromNpyFile(LoadImageFromFile):
    """Load an image from ``results['img_path']``.

    Required Keys:

    - img_path

    Modified Keys:

    - img
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is an uint8 array.
            Defaults to False.
    """

    def transform(self, results: dict) -> Optional[dict]:
        """Function to load the image.

        Args:
            results (dict): Result dict from
                :class:`mmengine.dataset.BaseDataset`.

        Returns:
            dict: The dict contains loaded image and meta information.
        """
        filename = results['img_path']

        try:
            if Path(filename).suffix in ['.npy', '.npz']:
                img = np.load(filename)
                # ``np.load`` returns an ``NpzFile`` archive for ``.npz``
                # inputs; unwrap the first stored array in that case.
                if isinstance(img, np.lib.npyio.NpzFile):
                    img = img[img.files[0]]
            else:
                if self.file_client_args is not None:
                    file_client = fileio.FileClient.infer_client(
                        self.file_client_args, filename)
                    img_bytes = file_client.get(filename)
                else:
                    img_bytes = fileio.get(
                        filename, backend_args=self.backend_args)
                img = mmcv.imfrombytes(
                    img_bytes,
                    flag=self.color_type,
                    backend=self.imdecode_backend)
        except Exception as e:
            if self.ignore_empty:
                return None
            else:
                raise e

        # In some cases the image is not read successfully and ``img`` is
        # `None`; see https://github.com/open-mmlab/mmpretrain/issues/1427.
        assert img is not None, f'failed to load image: {filename}'
        if self.to_float32:
            img = img.astype(np.float32)

        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results

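# Loading sketch for ``LoadImageFromNpyFile``: ``np.load`` returns the array
# directly for ``.npy`` but an ``NpzFile`` archive for ``.npz``, which is why
# the transform above unwraps the first stored array in the ``.npz`` case.
# The temp paths are placeholders.
#
#   >>> np.save('/tmp/img.npy', np.zeros((4, 4)))
#   >>> np.savez('/tmp/img.npz', np.zeros((4, 4)))
#   >>> type(np.load('/tmp/img.npy')).__name__
#   'ndarray'
#   >>> type(np.load('/tmp/img.npz')).__name__
#   'NpzFile'
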
493
finetune/mmseg/datasets/transforms/loading_npz.py
Normal file
@@ -0,0 +1,493 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Dict

import imageio
import numpy as np
from mmcv.transforms import BaseTransform
from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations

from mmseg.registry import TRANSFORMS

try:
    from osgeo import gdal
except ImportError:
    gdal = None

@TRANSFORMS.register_module()
class LoadAnnotationsNpz(MMCV_LoadAnnotations):
    """Load annotations for semantic segmentation provided by dataset.

    The annotation format is as follows:

    .. code-block:: python

        {
            # Filename of semantic segmentation ground truth file.
            'seg_map_path': 'a/b/c'
        }

    After this module, the annotation has been changed to the format below:

    .. code-block:: python

        {
            # in str
            'seg_fields': List
            # In uint8 type.
            'gt_seg_map': np.ndarray (H, W)
        }

    Required Keys:

    - seg_map_path (str): Path of semantic segmentation ground truth file.

    Added Keys:

    - seg_fields (List)
    - gt_seg_map (np.uint8)

    Args:
        reduce_zero_label (bool, optional): Whether reduce all label value
            by 1. Usually used for datasets where 0 is background label.
            Defaults to None.
        data_key (str): Key of the array to read from the ``.npz`` archive.
            Defaults to 'arr_0'.
        imdecode_backend (str): The image decoding backend type. The backend
            argument for :func:``mmcv.imfrombytes``.
            See :func:``mmcv.imfrombytes`` for details.
            Defaults to 'pillow'.
        backend_args (dict): Arguments to instantiate a file backend.
            See https://mmengine.readthedocs.io/en/latest/api/fileio.htm
            for details. Defaults to None.
            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
    """

    def __init__(
        self,
        reduce_zero_label=None,
        backend_args=None,
        data_key='arr_0',
        imdecode_backend='pillow',
    ) -> None:
        super().__init__(
            with_bbox=False,
            with_label=False,
            with_seg=True,
            with_keypoints=False,
            imdecode_backend=imdecode_backend,
            backend_args=backend_args)
        self.reduce_zero_label = reduce_zero_label
        if self.reduce_zero_label is not None:
            warnings.warn('`reduce_zero_label` will be deprecated, '
                          'if you would like to ignore the zero label, please '
                          'set `reduce_zero_label=True` when dataset '
                          'initialized')
        self.imdecode_backend = imdecode_backend
        self.data_key = data_key

    def _load_seg_map(self, results: dict) -> None:
        """Private function to load semantic segmentation annotations.

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``.
                Modified in place: the loaded annotation is stored in
                ``results['gt_seg_map']``.
        """
        gt_semantic_seg = np.load(
            results['seg_map_path'])[self.data_key].squeeze().astype(np.uint8)
        # reduce zero_label
        if self.reduce_zero_label is None:
            self.reduce_zero_label = results['reduce_zero_label']
        assert self.reduce_zero_label == results['reduce_zero_label'], \
            'Initialize dataset with `reduce_zero_label` as ' \
            f'{results["reduce_zero_label"]} but when load annotation ' \
            f'the `reduce_zero_label` is {self.reduce_zero_label}'
        if self.reduce_zero_label:
            # avoid using underflow conversion
            gt_semantic_seg[gt_semantic_seg == 0] = 255
            gt_semantic_seg = gt_semantic_seg - 1
            gt_semantic_seg[gt_semantic_seg == 254] = 255
        # modify if custom classes
        if results.get('label_map', None) is not None:
            # Add deep copy to solve bug of repeatedly
            # replacing `gt_semantic_seg`, which is reported in
            # https://github.com/open-mmlab/mmsegmentation/pull/1445/
            gt_semantic_seg_copy = gt_semantic_seg.copy()
            for old_id, new_id in results['label_map'].items():
                gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
        results['gt_seg_map'] = gt_semantic_seg
        results['seg_fields'].append('gt_seg_map')

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(reduce_zero_label={self.reduce_zero_label}, '
        repr_str += f"imdecode_backend='{self.imdecode_backend}', "
        repr_str += f'backend_args={self.backend_args})'
        return repr_str

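# Label-mapping sketch for ``reduce_zero_label``: background 0 becomes the
# ignore index 255 and every other class id shifts down by one, with the
# 254 fix-up guarding values that were already 255 before the shift.
#
#   >>> seg = np.array([0, 1, 2, 255], dtype=np.uint8)
#   >>> seg[seg == 0] = 255
#   >>> seg = seg - 1
#   >>> seg[seg == 254] = 255
#   >>> seg
#   array([255,   0,   1, 255], dtype=uint8)
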
@TRANSFORMS.register_module()
class LoadImageSingleChannel(BaseTransform):
    """Load a single-channel image from file.

    Only the first channel of the decoded image is kept.

    Required Keys:

    - img_path

    Modified Keys:

    - img
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image keeps its original
            dtype. Defaults to True.
    """

    def __init__(self, to_float32: bool = True):
        self.to_float32 = to_float32

    def transform(self, results: Dict) -> Dict:
        """Function to load the image.

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``.

        Returns:
            dict: The dict contains loaded image and meta information.
        """
        filename = results['img_path']
        img = imageio.imread(filename)  # (h, w, c)
        # Keep only the first channel.
        img = img[:, :, 0]

        if self.to_float32:
            img = img.astype(np.float32)

        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'to_float32={self.to_float32})')
        return repr_str

@TRANSFORMS.register_module()
class LoadAnnotationsOil(MMCV_LoadAnnotations):
    """Load annotations for semantic segmentation provided by dataset.

    The raster value 3 in the ground truth file is mapped to foreground
    class 1; all other values are mapped to background class 0.

    The annotation format is as follows:

    .. code-block:: python

        {
            # Filename of semantic segmentation ground truth file.
            'seg_map_path': 'a/b/c'
        }

    After this module, the annotation has been changed to the format below:

    .. code-block:: python

        {
            # in str
            'seg_fields': List
            # In uint8 type.
            'gt_seg_map': np.ndarray (H, W)
        }

    Required Keys:

    - seg_map_path (str): Path of semantic segmentation ground truth file.

    Added Keys:

    - seg_fields (List)
    - gt_seg_map (np.uint8)

    Args:
        reduce_zero_label (bool, optional): Whether reduce all label value
            by 1. Usually used for datasets where 0 is background label.
            Defaults to None.
        imdecode_backend (str): The image decoding backend type. The backend
            argument for :func:``mmcv.imfrombytes``.
            See :func:``mmcv.imfrombytes`` for details.
            Defaults to 'pillow'.
        backend_args (dict): Arguments to instantiate a file backend.
            See https://mmengine.readthedocs.io/en/latest/api/fileio.htm
            for details. Defaults to None.
            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
    """

    def __init__(
        self,
        reduce_zero_label=None,
        backend_args=None,
        data_key='arr_0',
        imdecode_backend='pillow',
    ) -> None:
        super().__init__(
            with_bbox=False,
            with_label=False,
            with_seg=True,
            with_keypoints=False,
            imdecode_backend=imdecode_backend,
            backend_args=backend_args)
        if gdal is None:
            raise RuntimeError('gdal is not installed')
        self.reduce_zero_label = reduce_zero_label
        if self.reduce_zero_label is not None:
            warnings.warn('`reduce_zero_label` will be deprecated, '
                          'if you would like to ignore the zero label, please '
                          'set `reduce_zero_label=True` when dataset '
                          'initialized')
        self.imdecode_backend = imdecode_backend
        self.data_key = data_key

    def _load_seg_map(self, results: dict) -> None:
        """Private function to load semantic segmentation annotations.

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``.
                Modified in place: the loaded annotation is stored in
                ``results['gt_seg_map']``.
        """
        seg_map = gdal.Open(results['seg_map_path']).ReadAsArray()
        # Binarize: raster value 3 marks the foreground (oil) class.
        gt_semantic_seg = np.zeros_like(seg_map).astype(np.uint8)
        gt_semantic_seg[seg_map == 3.] = 1
        # reduce zero_label
        if self.reduce_zero_label is None:
            self.reduce_zero_label = results['reduce_zero_label']
        assert self.reduce_zero_label == results['reduce_zero_label'], \
            'Initialize dataset with `reduce_zero_label` as ' \
            f'{results["reduce_zero_label"]} but when load annotation ' \
            f'the `reduce_zero_label` is {self.reduce_zero_label}'
        if self.reduce_zero_label:
            # avoid using underflow conversion
            gt_semantic_seg[gt_semantic_seg == 0] = 255
            gt_semantic_seg = gt_semantic_seg - 1
            gt_semantic_seg[gt_semantic_seg == 254] = 255
        # modify if custom classes
        if results.get('label_map', None) is not None:
            # Add deep copy to solve bug of repeatedly
            # replacing `gt_semantic_seg`, which is reported in
            # https://github.com/open-mmlab/mmsegmentation/pull/1445/
            gt_semantic_seg_copy = gt_semantic_seg.copy()
            for old_id, new_id in results['label_map'].items():
                gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
        results['gt_seg_map'] = gt_semantic_seg
        results['seg_fields'].append('gt_seg_map')

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(reduce_zero_label={self.reduce_zero_label}, '
        repr_str += f"imdecode_backend='{self.imdecode_backend}', "
        repr_str += f'backend_args={self.backend_args})'
        return repr_str

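# Binarization sketch for ``LoadAnnotationsOil``: raster value 3 is assumed
# to mark the oil class; everything else collapses to background.
#
#   >>> seg_map = np.array([[0., 3.], [2., 3.]])
#   >>> gt = np.zeros_like(seg_map).astype(np.uint8)
#   >>> gt[seg_map == 3.] = 1
#   >>> gt
#   array([[0, 1],
#          [0, 1]], dtype=uint8)
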
@TRANSFORMS.register_module()
class LoadImageFromNpz(BaseTransform):
    """Load a remote sensing image from an ``.npz`` file.

    Required Keys:

    - img_path

    Modified Keys:

    - img
    - img_shape
    - ori_shape

    Args:
        data_key (str): Key of the array to read from the ``.npz`` archive.
            Defaults to 'arr_0'.
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image keeps its original
            dtype. Defaults to True.
    """

    def __init__(self, data_key='arr_0', to_float32: bool = True):
        self.to_float32 = to_float32
        self.data_key = data_key

    def transform(self, results: Dict) -> Dict:
        """Function to load the image.

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``.

        Returns:
            dict: The dict contains loaded image and meta information.
        """
        filename = results['img_path']
        img = np.load(filename)[self.data_key]
        # The stored array is band-first (C, H, W); convert to channel-last
        # (H, W, C).
        img = np.einsum('ijk->jki', img)

        if self.to_float32:
            img = img.astype(np.float32)

        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'to_float32={self.to_float32})')
        return repr_str

@TRANSFORMS.register_module()
class LoadImageOil(BaseTransform):
    """Load a single-band remote sensing image from file via GDAL.

    Required Keys:

    - img_path

    Modified Keys:

    - img
    - img_shape
    - ori_shape

    Args:
        data_key (str): Unused here; kept for interface consistency with the
            other loaders. Defaults to 'arr_0'.
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is a float64 array.
            Defaults to True.
    """

    def __init__(self, data_key='arr_0', to_float32: bool = True):
        self.to_float32 = to_float32
        self.data_key = data_key
        if gdal is None:
            raise RuntimeError('gdal is not installed')

    def transform(self, results: Dict) -> Dict:
        """Function to load the image.

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``.

        Returns:
            dict: The dict contains loaded image and meta information.
        """
        filename = results['img_path']
        img = gdal.Open(filename).ReadAsArray()
        # Single-band raster (H, W); add a trailing channel axis.
        img = img[:, :, None]

        if self.to_float32:
            img = img.astype(np.float32)

        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'to_float32={self.to_float32})')
        return repr_str

@TRANSFORMS.register_module()
class LoadTsImageFromNpz(BaseTransform):
    """Load a time-series remote sensing image from an ``.npz`` file.

    The stored array has shape (ts, c, h, w). ``ts_size`` frames are sampled
    from the series and flattened into a single (h, w, ts_size * c) image.

    Required Keys:

    - img_path

    Modified Keys:

    - img
    - img_shape
    - ori_shape

    Args:
        data_key (str): Key of the array to read from the ``.npz`` archive.
            Defaults to 'arr_0'.
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image keeps its original
            dtype. Defaults to True.
        ts_size (int): Number of time steps to sample. Defaults to 10.
    """

    def __init__(self, data_key='arr_0', to_float32: bool = True,
                 ts_size: int = 10):
        self.to_float32 = to_float32
        self.data_key = data_key
        self.ts_size = ts_size

    def transform(self, results: Dict) -> Dict:
        """Function to load the image.

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``.

        Returns:
            dict: The dict contains loaded image and meta information.
        """
        filename = results['img_path']
        img = np.load(filename)[self.data_key]
        ts, c, h, w = img.shape
        # Sample ``ts_size`` frames; sample with replacement when the series
        # is shorter than ``ts_size``.
        selected_indices = np.random.choice(
            ts, size=self.ts_size, replace=ts < self.ts_size)
        selected_indices.sort()
        img = img[selected_indices, :, :, :]
        # (ts, c, h, w) -> (h, w, ts, c) -> (h, w, ts * c)
        img = img.transpose(2, 3, 0, 1).reshape(h, w, self.ts_size * c)

        if self.to_float32:
            img = img.astype(np.float32)

        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'to_float32={self.to_float32})')
        return repr_str

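# Shape sketch for ``LoadTsImageFromNpz``: ``ts_size`` frames are sampled
# (with replacement when the series is short), then the (ts, c, h, w) stack
# is flattened into a single (h, w, ts * c) image. Shapes below are
# illustrative assumptions.
#
#   >>> img = np.zeros((7, 4, 32, 32))          # (ts, c, h, w), ts < ts_size
#   >>> idx = np.sort(np.random.choice(7, size=10, replace=True))
#   >>> out = img[idx].transpose(2, 3, 0, 1).reshape(32, 32, 10 * 4)
#   >>> out.shape
#   (32, 32, 40)
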
2537
finetune/mmseg/datasets/transforms/transforms.py
Normal file
File diff suppressed because it is too large