init
This commit is contained in:
9
finetune/mmseg/apis/__init__.py
Normal file
@@ -0,0 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import inference_model, init_model, show_result_pyplot
from .mmseg_inferencer import MMSegInferencer
from .remote_sense_inferencer import RSImage, RSInferencer

__all__ = [
    'init_model', 'inference_model', 'show_result_pyplot', 'MMSegInferencer',
    'RSInferencer', 'RSImage'
]
189
finetune/mmseg/apis/inference.py
Normal file
@@ -0,0 +1,189 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from pathlib import Path
from typing import Optional, Union

import mmcv
import numpy as np
import torch
from mmengine import Config
from mmengine.registry import init_default_scope
from mmengine.runner import load_checkpoint
from mmengine.utils import mkdir_or_exist

from mmseg.models import BaseSegmentor
from mmseg.registry import MODELS
from mmseg.structures import SegDataSample
from mmseg.utils import SampleList, dataset_aliases, get_classes, get_palette
from mmseg.visualization import SegLocalVisualizer
from .utils import ImageType, _preprare_data


def init_model(config: Union[str, Path, Config],
               checkpoint: Optional[str] = None,
               device: str = 'cuda:0',
               cfg_options: Optional[dict] = None):
    """Initialize a segmentor from a config file.

    Args:
        config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file
            path, :obj:`Path`, or the config object.
        checkpoint (str, optional): Checkpoint path. If left as None, the
            model will not load any weights.
        device (str, optional): CPU/CUDA device option. Defaults to 'cuda:0'.
            Use 'cpu' to load the model on CPU.
        cfg_options (dict, optional): Options to override some settings in
            the used config.

    Returns:
        nn.Module: The constructed segmentor.
    """
    if isinstance(config, (str, Path)):
        config = Config.fromfile(config)
    elif not isinstance(config, Config):
        raise TypeError('config must be a filename or Config object, '
                        'but got {}'.format(type(config)))
    if cfg_options is not None:
        config.merge_from_dict(cfg_options)
    # Disable pretrained-weight initialization; weights come from
    # `checkpoint` below, if given.
    if config.model.type == 'EncoderDecoder':
        if 'init_cfg' in config.model.backbone:
            config.model.backbone.init_cfg = None
    elif config.model.type == 'MultimodalEncoderDecoder':
        for k, v in config.model.items():
            if isinstance(v, dict) and 'init_cfg' in v:
                config.model[k].init_cfg = None
    config.model.pretrained = None
    config.model.train_cfg = None
    init_default_scope(config.get('default_scope', 'mmseg'))

    model = MODELS.build(config.model)
    if checkpoint is not None:
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        dataset_meta = checkpoint['meta'].get('dataset_meta', None)
        # save the dataset_meta in the model for convenience
        if 'dataset_meta' in checkpoint.get('meta', {}):
            # mmseg 1.x
            model.dataset_meta = dataset_meta
        elif 'CLASSES' in checkpoint.get('meta', {}):
            # < mmseg 1.x
            classes = checkpoint['meta']['CLASSES']
            palette = checkpoint['meta']['PALETTE']
            model.dataset_meta = {'classes': classes, 'palette': palette}
        else:
            warnings.simplefilter('once')
            warnings.warn(
                'dataset_meta or class names are not saved in the '
                'checkpoint\'s meta data, classes and palette will be '
                'set according to num_classes.')
            num_classes = model.decode_head.num_classes
            dataset_name = None
            for name in dataset_aliases.keys():
                if len(get_classes(name)) == num_classes:
                    dataset_name = name
                    break
            if dataset_name is None:
                warnings.warn(
                    'No suitable dataset found, use Cityscapes by default')
                dataset_name = 'cityscapes'
            model.dataset_meta = {
                'classes': get_classes(dataset_name),
                'palette': get_palette(dataset_name)
            }
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model


def inference_model(model: BaseSegmentor,
                    img: ImageType) -> Union[SegDataSample, SampleList]:
    """Inference image(s) with the segmentor.

    Args:
        model (nn.Module): The loaded segmentor.
        img (str/ndarray or list[str/ndarray]): Either image files or loaded
            images.

    Returns:
        :obj:`SegDataSample` or list[:obj:`SegDataSample`]:
        If img is a list or tuple, a list of results of the same length is
        returned; otherwise the segmentation result is returned directly.
    """
    # prepare data
    data, is_batch = _preprare_data(img, model)

    # forward the model
    with torch.no_grad():
        results = model.test_step(data)

    return results if is_batch else results[0]


def show_result_pyplot(model: BaseSegmentor,
                       img: Union[str, np.ndarray],
                       result: SegDataSample,
                       opacity: float = 0.5,
                       title: str = '',
                       draw_gt: bool = True,
                       draw_pred: bool = True,
                       wait_time: float = 0,
                       show: bool = True,
                       with_labels: Optional[bool] = True,
                       save_dir=None,
                       out_file=None):
    """Visualize the segmentation results on the image.

    Args:
        model (nn.Module): The loaded segmentor.
        img (str or np.ndarray): Image filename or loaded image.
        result (SegDataSample): The prediction SegDataSample result.
        opacity (float): Opacity of the painted segmentation map.
            Must be in the (0, 1] range. Defaults to 0.5.
        title (str): The title of the pyplot figure. Defaults to ''.
        draw_gt (bool): Whether to draw the GT SegDataSample.
            Defaults to True.
        draw_pred (bool): Whether to draw the prediction SegDataSample.
            Defaults to True.
        wait_time (float): The interval of show (s). 0 is the special value
            that means "forever". Defaults to 0.
        show (bool): Whether to display the drawn image. Defaults to True.
        with_labels (bool, optional): Add semantic labels to the
            visualization result. Defaults to True.
        save_dir (str, optional): Save file dir for all storage backends.
            If it is None, the backend storage will not save any data.
        out_file (str, optional): Path to the output file. Defaults to None.

    Returns:
        np.ndarray: The drawn image in RGB channel order.
    """
    if hasattr(model, 'module'):
        model = model.module
    if isinstance(img, str):
        image = mmcv.imread(img, channel_order='rgb')
    else:
        image = img
    if save_dir is not None:
        mkdir_or_exist(save_dir)
    # init visualizer
    visualizer = SegLocalVisualizer(
        vis_backends=[dict(type='LocalVisBackend')],
        save_dir=save_dir,
        alpha=opacity)
    visualizer.dataset_meta = dict(
        classes=model.dataset_meta['classes'],
        palette=model.dataset_meta['palette'])
    visualizer.add_datasample(
        name=title,
        image=image,
        data_sample=result,
        draw_gt=draw_gt,
        draw_pred=draw_pred,
        wait_time=wait_time,
        out_file=out_file,
        show=show,
        with_labels=with_labels)
    vis_img = visualizer.get_image()

    return vis_img
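
For reference, a minimal usage sketch of the three functions above. The config and checkpoint paths are placeholders, not files shipped in this commit:

from mmseg.apis import inference_model, init_model, show_result_pyplot

config_file = 'finetune/configs/my_config.py'   # hypothetical path
checkpoint_file = 'work_dirs/latest.pth'        # hypothetical path

model = init_model(config_file, checkpoint_file, device='cuda:0')
result = inference_model(model, 'demo.png')
# Returns the drawn RGB image; set show=True to also open a popup window.
vis = show_result_pyplot(model, 'demo.png', result, opacity=0.5, show=False)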
382
finetune/mmseg/apis/mmseg_inferencer.py
Normal file
@@ -0,0 +1,382 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from typing import List, Optional, Sequence, Union

import mmcv
import mmengine
import numpy as np
import torch
import torch.nn as nn
from mmcv.transforms import Compose
from mmengine.infer.infer import BaseInferencer, ModelType
from mmengine.model import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.runner.checkpoint import _load_checkpoint_to_model
from PIL import Image

from mmseg.structures import SegDataSample
from mmseg.utils import ConfigType, SampleList, get_classes, get_palette
from mmseg.visualization import SegLocalVisualizer

InputType = Union[str, np.ndarray]
InputsType = Union[InputType, Sequence[InputType]]
PredType = Union[SegDataSample, SampleList]


class MMSegInferencer(BaseInferencer):
    """Semantic segmentation inferencer, which provides inference and
    visualization interfaces. Note: MMEngine >= 0.5.0 is required.

    Args:
        model (str, optional): Path to the config file or the model name
            defined in metafile. Take the `mmseg metafile <https://github.com/open-mmlab/mmsegmentation/blob/main/configs/fcn/metafile.yaml>`_
            as an example: the `model` could be
            "fcn_r50-d8_4xb2-40k_cityscapes-512x1024", and the weights of the
            model will be downloaded automatically. If a config file is used,
            like "configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py",
            the `weights` should be defined.
        weights (str, optional): Path to the checkpoint. If it is not
            specified and model is a model name of metafile, the weights will
            be loaded from metafile. Defaults to None.
        classes (list, optional): Input classes for result rendering. As the
            prediction of the segmentation model is a segment map with label
            indices, `classes` is a list which includes items corresponding
            to the label indices. If classes is not defined, the visualizer
            will take `cityscapes` classes by default. Defaults to None.
        palette (list, optional): Input palette for result rendering, which
            is a list of colors corresponding to the classes. If palette is
            not defined, the visualizer will take the `cityscapes` palette by
            default. Defaults to None.
        dataset_name (str, optional): `Dataset name or alias <https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/utils/class_names.py#L302-L317>`_.
            The visualizer will use the meta information of the dataset, i.e.
            classes and palette, but the `classes` and `palette` arguments
            have higher priority. Defaults to None.
        device (str, optional): Device to run inference. If None, the
            available device will be used automatically. Defaults to None.
        scope (str, optional): The scope of the model. Defaults to 'mmseg'.
    """  # noqa

    preprocess_kwargs: set = set()
    forward_kwargs: set = {'mode', 'out_dir'}
    visualize_kwargs: set = {
        'show', 'wait_time', 'img_out_dir', 'opacity', 'return_vis',
        'with_labels'
    }
    postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}

    def __init__(self,
                 model: Union[ModelType, str],
                 weights: Optional[str] = None,
                 classes: Optional[Union[str, List]] = None,
                 palette: Optional[Union[str, List]] = None,
                 dataset_name: Optional[str] = None,
                 device: Optional[str] = None,
                 scope: Optional[str] = 'mmseg') -> None:
        # Global counters tracking the number of images processed, used for
        # naming the output images.
        self.num_visualized_imgs = 0
        self.num_pred_imgs = 0
        init_default_scope(scope if scope else 'mmseg')
        super().__init__(
            model=model, weights=weights, device=device, scope=scope)

        # SyncBN cannot run in a single-device setting; fall back to BN.
        if device == 'cpu' or not torch.cuda.is_available():
            self.model = revert_sync_batchnorm(self.model)

        assert isinstance(self.visualizer, SegLocalVisualizer)
        self.visualizer.set_dataset_meta(classes, palette, dataset_name)

    def _load_weights_to_model(self, model: nn.Module,
                               checkpoint: Optional[dict],
                               cfg: Optional[ConfigType]) -> None:
        """Load model weights and meta information from cfg and checkpoint.

        Subclasses could override this method to load extra meta information
        from ``checkpoint`` and ``cfg`` to the model.

        Args:
            model (nn.Module): Model to load weights and meta information.
            checkpoint (dict, optional): The loaded checkpoint.
            cfg (Config or ConfigDict, optional): The loaded config.
        """

        if checkpoint is not None:
            _load_checkpoint_to_model(model, checkpoint)
            checkpoint_meta = checkpoint.get('meta', {})
            # save the dataset_meta in the model for convenience
            if 'dataset_meta' in checkpoint_meta:
                # mmsegmentation 1.x
                model.dataset_meta = {
                    'classes': checkpoint_meta['dataset_meta'].get('classes'),
                    'palette': checkpoint_meta['dataset_meta'].get('palette')
                }
            elif 'CLASSES' in checkpoint_meta:
                # mmsegmentation 0.x
                classes = checkpoint_meta['CLASSES']
                palette = checkpoint_meta.get('PALETTE', None)
                model.dataset_meta = {'classes': classes, 'palette': palette}
            else:
                warnings.warn(
                    'dataset_meta or class names are not saved in the '
                    'checkpoint\'s meta data, use classes of Cityscapes by '
                    'default.')
                model.dataset_meta = {
                    'classes': get_classes('cityscapes'),
                    'palette': get_palette('cityscapes')
                }
        else:
            warnings.warn('Checkpoint is not loaded, and the inference '
                          'result is calculated by the randomly initialized '
                          'model!')
            warnings.warn(
                'weights is None, use cityscapes classes by default.')
            model.dataset_meta = {
                'classes': get_classes('cityscapes'),
                'palette': get_palette('cityscapes')
            }

    def __call__(self,
                 inputs: InputsType,
                 return_datasamples: bool = False,
                 batch_size: int = 1,
                 return_vis: bool = False,
                 show: bool = False,
                 wait_time: int = 0,
                 out_dir: str = '',
                 img_out_dir: str = 'vis',
                 pred_out_dir: str = 'pred',
                 **kwargs) -> dict:
        """Call the inferencer.

        Args:
            inputs (Union[list, str, np.ndarray]): Inputs for the inferencer.
            return_datasamples (bool): Whether to return results as
                :obj:`SegDataSample`. Defaults to False.
            batch_size (int): Batch size. Defaults to 1.
            return_vis (bool): Whether to return the visualization results.
                Defaults to False.
            show (bool): Whether to display the rendered color segmentation
                mask in a popup window. Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            out_dir (str): Output directory of inference results. Defaults
                to ''.
            img_out_dir (str): Subdirectory of `out_dir`, used to save the
                rendered color segmentation masks, so `out_dir` must be
                defined if you would like to save them. Defaults to 'vis'.
            pred_out_dir (str): Subdirectory of `out_dir`, used to save the
                predicted mask files, so `out_dir` must be defined if you
                would like to save them. Defaults to 'pred'.

            **kwargs: Other keyword arguments passed to :meth:`preprocess`,
                :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
                Each key in kwargs should be in the corresponding set of
                ``preprocess_kwargs``, ``forward_kwargs``,
                ``visualize_kwargs`` and ``postprocess_kwargs``.

        Returns:
            dict: Inference and visualization results.
        """

        if out_dir != '':
            pred_out_dir = osp.join(out_dir, pred_out_dir)
            img_out_dir = osp.join(out_dir, img_out_dir)
        else:
            pred_out_dir = ''
            img_out_dir = ''

        return super().__call__(
            inputs=inputs,
            return_datasamples=return_datasamples,
            batch_size=batch_size,
            show=show,
            wait_time=wait_time,
            img_out_dir=img_out_dir,
            pred_out_dir=pred_out_dir,
            return_vis=return_vis,
            **kwargs)

    def visualize(self,
                  inputs: list,
                  preds: List[dict],
                  return_vis: bool = False,
                  show: bool = False,
                  wait_time: int = 0,
                  img_out_dir: str = '',
                  opacity: float = 0.8,
                  with_labels: Optional[bool] = True) -> List[np.ndarray]:
        """Visualize predictions.

        Args:
            inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
            preds (Any): Predictions of the model.
            return_vis (bool): Whether to return the visualization results.
                Defaults to False.
            show (bool): Whether to display the image in a popup window.
                Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            img_out_dir (str): Output directory of the rendered prediction,
                i.e. the color segmentation mask. Defaults to ''.
            opacity (float): The transparency of the segmentation mask.
                Defaults to 0.8.
            with_labels (bool, optional): Add semantic labels to the
                visualization result. Defaults to True.

        Returns:
            List[np.ndarray]: Visualization results.
        """
        if not show and img_out_dir == '' and not return_vis:
            return None
        if self.visualizer is None:
            raise ValueError('Visualization needs the "visualizer" term '
                             'defined in the config, but got None.')

        self.visualizer.set_dataset_meta(**self.model.dataset_meta)
        self.visualizer.alpha = opacity

        results = []

        for single_input, pred in zip(inputs, preds):
            if isinstance(single_input, str):
                img_bytes = mmengine.fileio.get(single_input)
                img = mmcv.imfrombytes(img_bytes)
                # Convert from BGR (mmcv default) to RGB.
                img = img[:, :, ::-1]
                img_name = osp.basename(single_input)
            elif isinstance(single_input, np.ndarray):
                img = single_input.copy()
                img_num = str(self.num_visualized_imgs).zfill(8) + '_vis'
                img_name = f'{img_num}.jpg'
            else:
                raise ValueError('Unsupported input type: '
                                 f'{type(single_input)}')

            out_file = osp.join(img_out_dir, img_name) if img_out_dir != '' \
                else None

            self.visualizer.add_datasample(
                img_name,
                img,
                pred,
                show=show,
                wait_time=wait_time,
                draw_gt=False,
                draw_pred=True,
                out_file=out_file,
                with_labels=with_labels)
            if return_vis:
                results.append(self.visualizer.get_image())
            self.num_visualized_imgs += 1

        return results if return_vis else None

    def postprocess(self,
                    preds: PredType,
                    visualization: List[np.ndarray],
                    return_datasample: bool = False,
                    pred_out_dir: str = '') -> dict:
        """Process the predictions and visualization results from ``forward``
        and ``visualize``.

        This method should be responsible for the following tasks:

        1. Pack the predictions and visualization results and return them.
        2. Save the predictions, if needed.

        Args:
            preds (List[Dict]): Predictions of the model.
            visualization (List[np.ndarray]): The list of rendered color
                segmentation masks.
            return_datasample (bool): Whether to return results as
                datasamples. Defaults to False.
            pred_out_dir (str): Directory to save the inference results
                without visualization. If left as empty, no file will be
                saved. Defaults to ''.

        Returns:
            dict: Inference and visualization results with the keys
            ``predictions`` and ``visualization``

            - ``visualization`` (Any): Returned by :meth:`visualize`
            - ``predictions`` (List[np.ndarray], np.ndarray): Returned by
              :meth:`forward` and processed in :meth:`postprocess`.
              If ``return_datasample=False``, it will be the segmentation
              mask with label indices.
        """
        if return_datasample:
            if len(preds) == 1:
                return preds[0]
            else:
                return preds

        results_dict = {}

        results_dict['predictions'] = []
        results_dict['visualization'] = []

        for i, pred in enumerate(preds):
            pred_data = dict()
            if 'pred_sem_seg' in pred.keys():
                pred_data['sem_seg'] = pred.pred_sem_seg.numpy().data[0]
            elif 'pred_depth_map' in pred.keys():
                pred_data['depth_map'] = pred.pred_depth_map.numpy().data[0]

            if visualization is not None:
                vis = visualization[i]
                results_dict['visualization'].append(vis)
            if pred_out_dir != '':
                mmengine.mkdir_or_exist(pred_out_dir)
                for key, data in pred_data.items():
                    # Semantic masks are saved as PNG, depth maps as .npy.
                    post_fix = '_pred.png' if key == 'sem_seg' else '_pred.npy'
                    img_name = str(self.num_pred_imgs).zfill(8) + post_fix
                    img_path = osp.join(pred_out_dir, img_name)
                    if key == 'sem_seg':
                        output = Image.fromarray(data.astype(np.uint8))
                        output.save(img_path)
                    else:
                        np.save(img_path, data)
            pred_data = next(iter(pred_data.values()))
            results_dict['predictions'].append(pred_data)
            self.num_pred_imgs += 1

        if len(results_dict['predictions']) == 1:
            results_dict['predictions'] = results_dict['predictions'][0]
            if visualization is not None:
                results_dict['visualization'] = \
                    results_dict['visualization'][0]
        return results_dict

    def _init_pipeline(self, cfg: ConfigType) -> Compose:
        """Initialize the test pipeline.

        Return a pipeline to handle various input data, such as ``str`` and
        ``np.ndarray``. It is an abstract method in BaseInferencer, and
        should be implemented in subclasses.

        The returned pipeline will be used to process a single data item.
        It will be used in :meth:`preprocess` like this:

        .. code-block:: python
            def preprocess(self, inputs, batch_size, **kwargs):
                ...
                dataset = map(self.pipeline, dataset)
                ...
        """
        pipeline_cfg = cfg.test_dataloader.dataset.pipeline
        # Loading annotations is not applicable at inference time.
        for transform in ('LoadAnnotations', 'LoadDepthAnnotation'):
            idx = self._get_transform_idx(pipeline_cfg, transform)
            if idx != -1:
                del pipeline_cfg[idx]

        load_img_idx = self._get_transform_idx(pipeline_cfg,
                                               'LoadImageFromFile')
        if load_img_idx == -1:
            raise ValueError(
                'LoadImageFromFile is not found in the test pipeline')
        pipeline_cfg[load_img_idx]['type'] = 'InferencerLoader'
        return Compose(pipeline_cfg)

    def _get_transform_idx(self, pipeline_cfg: ConfigType, name: str) -> int:
        """Return the index of the transform in a pipeline.

        If the transform is not found, returns -1.
        """
        for i, transform in enumerate(pipeline_cfg):
            if transform['type'] == name:
                return i
        return -1
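
For reference, a minimal usage sketch of MMSegInferencer following its docstring; the metafile model name is the example given there, and the demo image path is a placeholder:

from mmseg.apis import MMSegInferencer

# Weights are fetched automatically for a metafile model name.
inferencer = MMSegInferencer(model='fcn_r50-d8_4xb2-40k_cityscapes-512x1024')
results = inferencer('demo.png', out_dir='outputs', return_vis=True)
seg_mask = results['predictions']     # label-index mask (np.ndarray)
vis_image = results['visualization']  # rendered color mask (np.ndarray)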
279
finetune/mmseg/apis/remote_sense_inferencer.py
Normal file
@@ -0,0 +1,279 @@
# Copyright (c) OpenMMLab. All rights reserved.
import threading
from queue import Queue
from typing import List, Optional, Tuple

import numpy as np
import torch
from mmengine import Config
from mmengine.model import BaseModel
from mmengine.registry import init_default_scope
from mmengine.runner import load_checkpoint

try:
    from osgeo import gdal
except ImportError:
    gdal = None

from mmseg.registry import MODELS
from .utils import _preprare_data


class RSImage:
    """Remote sensing image class.

    Args:
        image (str or gdal.Dataset): Image file path or gdal.Dataset.
    """

    def __init__(self, image):
        self.dataset = gdal.Open(image, gdal.GA_ReadOnly) if isinstance(
            image, str) else image
        assert isinstance(self.dataset, gdal.Dataset), \
            f'{image} is not an image'
        self.width = self.dataset.RasterXSize
        self.height = self.dataset.RasterYSize
        self.channel = self.dataset.RasterCount
        self.trans = self.dataset.GetGeoTransform()
        self.proj = self.dataset.GetProjection()
        self.band_list = []
        self.band_list.extend(
            self.dataset.GetRasterBand(c + 1) for c in range(self.channel))
        self.grids = []

    def read(self, grid: Optional[List] = None) -> np.ndarray:
        """Read image data. If grid is None, read the whole image.

        Args:
            grid (Optional[List], optional): Grid to read. Defaults to None.

        Returns:
            np.ndarray: Image data.
        """
        if grid is None:
            # GDAL returns (channel, height, width); move channels last.
            return np.einsum('ijk->jki', self.dataset.ReadAsArray())
        assert len(
            grid) >= 4, 'grid must be a list containing at least 4 elements'
        data = self.dataset.ReadAsArray(*grid[:4])
        if data.ndim == 2:
            data = data[np.newaxis, ...]
        return np.einsum('ijk->jki', data)

    def write(self, data: Optional[np.ndarray], grid: Optional[List] = None):
        """Write image data.

        Args:
            data (Optional[np.ndarray], optional): Data to write.
                Defaults to None.
            grid (Optional[List], optional): Grid to write. Defaults to None.

        Raises:
            ValueError: Either grid or data must be provided.
        """
        if grid is not None:
            assert len(grid) == 8, 'grid must be a list of 8 elements'
            for band in self.band_list:
                # Write only the cropped (non-overlapping) part of the window
                # back to the corresponding location of the raster.
                band.WriteArray(
                    data[grid[5]:grid[5] + grid[7], grid[4]:grid[4] + grid[6]],
                    grid[0] + grid[4], grid[1] + grid[5])
        elif data is not None:
            for i in range(self.channel):
                self.band_list[i].WriteArray(data[..., i])
        else:
            raise ValueError('Either grid or data must be provided.')

    def create_seg_map(self, output_path: Optional[str] = None):
        """Create a single-band GTiff segmentation map with the same size,
        geo-transform and projection as this image."""
        if output_path is None:
            output_path = 'output_label.tif'
        driver = gdal.GetDriverByName('GTiff')
        seg_map = driver.Create(output_path, self.width, self.height, 1,
                                gdal.GDT_Byte)
        seg_map.SetGeoTransform(self.trans)
        seg_map.SetProjection(self.proj)
        seg_map_img = RSImage(seg_map)
        seg_map_img.path = output_path
        return seg_map_img

    def create_grids(self,
                     window_size: Tuple[int, int],
                     stride: Tuple[int, int] = (0, 0)):
        """Create grids for image inference.

        Each grid has 8 elements: [x_offset, y_offset, x_size, y_size,
        x_crop_off, y_crop_off, x_crop_size, y_crop_size], i.e. the window
        position and size, followed by the region of the window that is kept
        after cropping half of the overlap on each side.

        Args:
            window_size (Tuple[int, int]): The size of the sliding window.
            stride (Tuple[int, int], optional): The stride of the sliding
                window. Defaults to (0, 0).

        Raises:
            AssertionError: window_size must be a tuple of 2 elements.
            AssertionError: stride must be a tuple of 2 elements.
        """
        assert len(
            window_size) == 2, 'window_size must be a tuple of 2 elements'
        assert len(stride) == 2, 'stride must be a tuple of 2 elements'
        win_w, win_h = window_size
        stride_x, stride_y = stride

        # A stride of 0 means non-overlapping windows.
        stride_x = win_w if stride_x == 0 else stride_x
        stride_y = win_h if stride_y == 0 else stride_y

        x_half_overlap = (win_w - stride_x + 1) // 2
        y_half_overlap = (win_h - stride_y + 1) // 2

        for y in range(0, self.height, stride_y):
            y_end = y + win_h >= self.height
            y_offset = self.height - win_h if y_end else y
            y_size = win_h
            y_crop_off = 0 if y_offset == 0 else y_half_overlap
            y_crop_size = y_size if y_end else win_h - y_crop_off

            for x in range(0, self.width, stride_x):
                x_end = x + win_w >= self.width
                x_offset = self.width - win_w if x_end else x
                x_size = win_w
                x_crop_off = 0 if x_offset == 0 else x_half_overlap
                x_crop_size = x_size if x_end else win_w - x_crop_off

                self.grids.append([
                    x_offset, y_offset, x_size, y_size, x_crop_off,
                    y_crop_off, x_crop_size, y_crop_size
                ])


class RSInferencer:
    """Remote sensing inference class.

    Args:
        model (BaseModel): The loaded model.
        batch_size (int, optional): Batch size. Defaults to 1.
        thread (int, optional): Number of inference threads. Defaults to 1.
    """

    def __init__(self, model: BaseModel, batch_size: int = 1, thread: int = 1):
        self.model = model
        self.batch_size = batch_size
        # Sentinel object marking the end of the read queue.
        self.END_FLAG = object()
        self.read_buffer = Queue(self.batch_size)
        self.write_buffer = Queue(self.batch_size)
        self.thread = thread

    @classmethod
    def from_config_path(cls,
                         config_path: str,
                         checkpoint_path: str,
                         batch_size: int = 1,
                         thread: int = 1,
                         device: Optional[str] = 'cpu'):
        """Initialize a segmentor from a config file.

        Args:
            config_path (str): Config file path.
            checkpoint_path (str): Checkpoint path.
            batch_size (int, optional): Batch size. Defaults to 1.
            thread (int, optional): Number of inference threads.
                Defaults to 1.
            device (str, optional): Device to run inference.
                Defaults to 'cpu'.
        """
        init_default_scope('mmseg')
        cfg = Config.fromfile(config_path)
        model = MODELS.build(cfg.model)
        model.cfg = cfg
        load_checkpoint(model, checkpoint_path, map_location='cpu')
        model.to(device)
        model.eval()
        return cls(model, batch_size, thread)

    @classmethod
    def from_model(cls,
                   model: BaseModel,
                   checkpoint_path: Optional[str] = None,
                   batch_size: int = 1,
                   thread: int = 1,
                   device: Optional[str] = 'cpu'):
        """Initialize a segmentor from a loaded model.

        Args:
            model (BaseModel): The loaded model.
            checkpoint_path (Optional[str]): Checkpoint path.
            batch_size (int, optional): Batch size. Defaults to 1.
            thread (int, optional): Number of inference threads.
                Defaults to 1.
            device (str, optional): Device to run inference.
                Defaults to 'cpu'.
        """
        if checkpoint_path is not None:
            load_checkpoint(model, checkpoint_path, map_location='cpu')
        model.to(device)
        return cls(model, batch_size, thread)

    def read(self,
             image: RSImage,
             window_size: Tuple[int, int],
             strides: Tuple[int, int] = (0, 0)):
        """Load image data into the read buffer.

        Args:
            image (RSImage): The image to read.
            window_size (Tuple[int, int]): The size of the sliding window.
            strides (Tuple[int, int], optional): The stride of the sliding
                window. Defaults to (0, 0).
        """
        image.create_grids(window_size, strides)
        for grid in image.grids:
            self.read_buffer.put([grid, image.read(grid=grid)])
        self.read_buffer.put(self.END_FLAG)

    def inference(self):
        """Inference image data from the read buffer and put the result into
        the write buffer."""
        while True:
            item = self.read_buffer.get()
            if item == self.END_FLAG:
                # Re-queue the sentinel so sibling inference threads can also
                # terminate, and forward it to signal the writer.
                self.read_buffer.put(self.END_FLAG)
                self.write_buffer.put(item)
                break
            data, _ = _preprare_data(item[1], self.model)
            with torch.no_grad():
                result = self.model.test_step(data)
            item[1] = result[0].pred_sem_seg.cpu().data.numpy()[0]
            self.write_buffer.put(item)
            self.read_buffer.task_done()

    def write(self, image: RSImage, output_path: Optional[str] = None):
        """Write image data from the write buffer.

        Args:
            image (RSImage): The image to write.
            output_path (Optional[str], optional): The path to save the
                segmentation map. Defaults to None.
        """
        seg_map = image.create_seg_map(output_path)
        while True:
            item = self.write_buffer.get()
            if item == self.END_FLAG:
                break
            seg_map.write(data=item[1], grid=item[0])
            self.write_buffer.task_done()

    def run(self,
            image: RSImage,
            window_size: Tuple[int, int],
            strides: Tuple[int, int] = (0, 0),
            output_path: Optional[str] = None):
        """Run inference with multi-threading.

        Args:
            image (RSImage): The image to inference.
            window_size (Tuple[int, int]): The size of the sliding window.
            strides (Tuple[int, int], optional): The stride of the sliding
                window. Defaults to (0, 0).
            output_path (Optional[str], optional): The path to save the
                segmentation map. Defaults to None.
        """
        read_thread = threading.Thread(
            target=self.read, args=(image, window_size, strides))
        read_thread.start()
        inference_threads = []
        for _ in range(self.thread):
            inference_thread = threading.Thread(target=self.inference)
            inference_thread.start()
            inference_threads.append(inference_thread)
        write_thread = threading.Thread(
            target=self.write, args=(image, output_path))
        write_thread.start()
        read_thread.join()
        for inference_thread in inference_threads:
            inference_thread.join()
        write_thread.join()
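
For reference, a minimal usage sketch of the sliding-window remote sensing inference above; the config, checkpoint and GeoTIFF paths are placeholders, and GDAL must be installed:

from mmseg.apis import RSImage, RSInferencer

inferencer = RSInferencer.from_config_path(
    'finetune/configs/my_config.py',  # hypothetical path
    'work_dirs/latest.pth',           # hypothetical path
    batch_size=1,
    thread=2)
image = RSImage('scene.tif')          # hypothetical GeoTIFF
# 512x512 windows with a 400-pixel stride; the overlap is cropped away
# when writing, per the 8-element grids built by create_grids().
inferencer.run(image, window_size=(512, 512), strides=(400, 400),
               output_path='scene_label.tif')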
41
finetune/mmseg/apis/utils.py
Normal file
@@ -0,0 +1,41 @@
# Copyright (c) OpenMMLab. All rights reserved.
from collections import defaultdict
from typing import Sequence, Union

import numpy as np
from mmengine.dataset import Compose
from mmengine.model import BaseModel

ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]


def _preprare_data(imgs: ImageType, model: BaseModel):
    """Build the model's test pipeline and run the input image(s) through it.

    Returns the collated inputs/data_samples dict and a flag indicating
    whether the input was a batch (list or tuple).
    """
    cfg = model.cfg
    # Annotations are not available at inference time.
    for t in cfg.test_pipeline:
        if t.get('type') == 'LoadAnnotations':
            cfg.test_pipeline.remove(t)

    is_batch = True
    if not isinstance(imgs, (list, tuple)):
        imgs = [imgs]
        is_batch = False

    if isinstance(imgs[0], np.ndarray):
        cfg.test_pipeline[0]['type'] = 'LoadImageFromNDArray'

    # TODO: Consider using the singleton pattern to avoid building
    # a pipeline for each inference
    pipeline = Compose(cfg.test_pipeline)

    data = defaultdict(list)
    for img in imgs:
        if isinstance(img, np.ndarray):
            data_ = dict(img=img)
        else:
            data_ = dict(img_path=img)
        data_ = pipeline(data_)
        data['inputs'].append(data_['inputs'])
        data['data_samples'].append(data_['data_samples'])

    return data, is_batch
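
For reference, a sketch of how the helper behaves for single versus batched inputs; the model paths are placeholders:

import numpy as np

from mmseg.apis import init_model
from mmseg.apis.utils import _preprare_data

model = init_model('finetune/configs/my_config.py', 'work_dirs/latest.pth')
img = np.zeros((512, 512, 3), dtype=np.uint8)
data, is_batch = _preprare_data(img, model)          # single input -> False
data, is_batch = _preprare_data([img, img], model)   # list input -> True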