commit 01adcfdf60
Author: esenke
Date:   2025-12-08 22:16:31 +08:00

305 changed files with 50879 additions and 0 deletions

mmseg/apis/__init__.py

@@ -0,0 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import inference_model, init_model, show_result_pyplot
from .mmseg_inferencer import MMSegInferencer
from .remote_sense_inferencer import RSImage, RSInferencer
__all__ = [
'init_model', 'inference_model', 'show_result_pyplot', 'MMSegInferencer',
'RSInferencer', 'RSImage'
]
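
A minimal usage sketch of the high-level API exported here; the checkpoint path and demo image are placeholders, and the config path reuses the example quoted in the `MMSegInferencer` docstring below:

from mmseg.apis import inference_model, init_model, show_result_pyplot

config_path = 'configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py'
checkpoint_path = 'fcn_r50-d8_cityscapes.pth'  # placeholder checkpoint

model = init_model(config_path, checkpoint_path, device='cuda:0')
result = inference_model(model, 'demo/demo.png')  # placeholder image
vis = show_result_pyplot(model, 'demo/demo.png', result, show=False,
                         out_file='work_dirs/demo_result.png')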

mmseg/apis/inference.py

@@ -0,0 +1,189 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from pathlib import Path
from typing import Optional, Union
import mmcv
import numpy as np
import torch
from mmengine import Config
from mmengine.registry import init_default_scope
from mmengine.runner import load_checkpoint
from mmengine.utils import mkdir_or_exist
from mmseg.models import BaseSegmentor
from mmseg.registry import MODELS
from mmseg.structures import SegDataSample
from mmseg.utils import SampleList, dataset_aliases, get_classes, get_palette
from mmseg.visualization import SegLocalVisualizer
from .utils import ImageType, _preprare_data
def init_model(config: Union[str, Path, Config],
checkpoint: Optional[str] = None,
device: str = 'cuda:0',
cfg_options: Optional[dict] = None):
"""Initialize a segmentor from config file.
Args:
config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path,
:obj:`Path`, or the config object.
checkpoint (str, optional): Checkpoint path. If left as None, the model
will not load any weights.
        device (str, optional): CPU/CUDA device option. Defaults to 'cuda:0'.
            Use 'cpu' for loading the model on CPU.
cfg_options (dict, optional): Options to override some settings in
the used config.
Returns:
nn.Module: The constructed segmentor.
"""
if isinstance(config, (str, Path)):
config = Config.fromfile(config)
elif not isinstance(config, Config):
raise TypeError('config must be a filename or Config object, '
'but got {}'.format(type(config)))
if cfg_options is not None:
config.merge_from_dict(cfg_options)
if config.model.type == 'EncoderDecoder':
if 'init_cfg' in config.model.backbone:
config.model.backbone.init_cfg = None
elif config.model.type == 'MultimodalEncoderDecoder':
for k, v in config.model.items():
if isinstance(v, dict) and 'init_cfg' in v:
config.model[k].init_cfg = None
config.model.pretrained = None
config.model.train_cfg = None
init_default_scope(config.get('default_scope', 'mmseg'))
model = MODELS.build(config.model)
if checkpoint is not None:
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        dataset_meta = checkpoint.get('meta', {}).get('dataset_meta', None)
# save the dataset_meta in the model for convenience
if 'dataset_meta' in checkpoint.get('meta', {}):
# mmseg 1.x
model.dataset_meta = dataset_meta
elif 'CLASSES' in checkpoint.get('meta', {}):
# < mmseg 1.x
classes = checkpoint['meta']['CLASSES']
palette = checkpoint['meta']['PALETTE']
model.dataset_meta = {'classes': classes, 'palette': palette}
else:
warnings.simplefilter('once')
            warnings.warn(
                'dataset_meta or class names are not saved in the '
                'checkpoint\'s meta data, classes and palette will be '
                'set according to num_classes')
num_classes = model.decode_head.num_classes
dataset_name = None
for name in dataset_aliases.keys():
if len(get_classes(name)) == num_classes:
dataset_name = name
break
if dataset_name is None:
warnings.warn(
'No suitable dataset found, use Cityscapes by default')
dataset_name = 'cityscapes'
model.dataset_meta = {
'classes': get_classes(dataset_name),
'palette': get_palette(dataset_name)
}
model.cfg = config # save the config in the model for convenience
model.to(device)
model.eval()
return model
def inference_model(model: BaseSegmentor,
img: ImageType) -> Union[SegDataSample, SampleList]:
"""Inference image(s) with the segmentor.
Args:
model (nn.Module): The loaded segmentor.
        img (str, np.ndarray, or list[str/np.ndarray]): Either image files
            or loaded images.
    Returns:
        :obj:`SegDataSample` or list[:obj:`SegDataSample`]:
        If img is a list or tuple, a list of results with the same length
        is returned; otherwise the segmentation result is returned directly.
"""
# prepare data
data, is_batch = _preprare_data(img, model)
# forward the model
with torch.no_grad():
results = model.test_step(data)
return results if is_batch else results[0]
def show_result_pyplot(model: BaseSegmentor,
img: Union[str, np.ndarray],
result: SegDataSample,
opacity: float = 0.5,
title: str = '',
draw_gt: bool = True,
draw_pred: bool = True,
wait_time: float = 0,
show: bool = True,
with_labels: Optional[bool] = True,
save_dir=None,
out_file=None):
"""Visualize the segmentation results on the image.
Args:
model (nn.Module): The loaded segmentor.
img (str or np.ndarray): Image filename or loaded image.
result (SegDataSample): The prediction SegDataSample result.
        opacity (float): Opacity of the painted segmentation map.
            Must be in the (0, 1] range. Defaults to 0.5.
        title (str): The title of the pyplot figure. Defaults to ''.
        draw_gt (bool): Whether to draw the GT SegDataSample.
            Defaults to True.
        draw_pred (bool): Whether to draw the prediction SegDataSample.
            Defaults to True.
        wait_time (float): The interval of show (s). 0 is the special value
            that means "forever". Defaults to 0.
        show (bool): Whether to display the drawn image. Defaults to True.
        with_labels (bool, optional): Whether to add semantic labels to the
            visualization result. Defaults to True.
        save_dir (str, optional): Save file dir for all storage backends.
            If it is None, the backend storage will not save any data.
        out_file (str, optional): Path to the output file. Defaults to None.
Returns:
        np.ndarray: The drawn image in RGB channel order.
"""
if hasattr(model, 'module'):
model = model.module
if isinstance(img, str):
image = mmcv.imread(img, channel_order='rgb')
else:
image = img
if save_dir is not None:
mkdir_or_exist(save_dir)
# init visualizer
visualizer = SegLocalVisualizer(
vis_backends=[dict(type='LocalVisBackend')],
save_dir=save_dir,
alpha=opacity)
visualizer.dataset_meta = dict(
classes=model.dataset_meta['classes'],
palette=model.dataset_meta['palette'])
visualizer.add_datasample(
name=title,
image=image,
data_sample=result,
draw_gt=draw_gt,
draw_pred=draw_pred,
wait_time=wait_time,
out_file=out_file,
show=show,
with_labels=with_labels)
vis_img = visualizer.get_image()
return vis_img
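
Since `inference_model` mirrors its input type, a list in yields a list of results; a short sketch, assuming `model` was built as in the `__init__.py` example above and the image paths are placeholders:

# Batch inference: a list of image paths yields a list of SegDataSample.
results = inference_model(model, ['demo/a.png', 'demo/b.png'])
# Each result carries the predicted mask as a tensor of label indices.
seg_mask = results[0].pred_sem_seg.data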

mmseg/apis/mmseg_inferencer.py

@@ -0,0 +1,382 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from typing import List, Optional, Sequence, Union
import mmcv
import mmengine
import numpy as np
import torch
import torch.nn as nn
from mmcv.transforms import Compose
from mmengine.infer.infer import BaseInferencer, ModelType
from mmengine.model import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.runner.checkpoint import _load_checkpoint_to_model
from PIL import Image
from mmseg.structures import SegDataSample
from mmseg.utils import ConfigType, SampleList, get_classes, get_palette
from mmseg.visualization import SegLocalVisualizer
InputType = Union[str, np.ndarray]
InputsType = Union[InputType, Sequence[InputType]]
PredType = Union[SegDataSample, SampleList]
class MMSegInferencer(BaseInferencer):
"""Semantic segmentation inferencer, provides inference and visualization
interfaces. Note: MMEngine >= 0.5.0 is required.
Args:
        model (str, optional): Path to the config file or the model name
            defined in a metafile. Taking the `mmseg metafile <https://github.com/open-mmlab/mmsegmentation/blob/main/configs/fcn/metafile.yaml>`_
            as an example, the `model` could be
            "fcn_r50-d8_4xb2-40k_cityscapes-512x1024", and the model weights
            will be downloaded automatically. If a config file is used, like
            "configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py", then
            `weights` must be defined.
weights (str, optional): Path to the checkpoint. If it is not specified
and model is a model name of metafile, the weights will be loaded
from metafile. Defaults to None.
        classes (list, optional): Input classes for result rendering. As the
            prediction of the segmentation model is a segment map with label
            indices, `classes` is a list whose items correspond to the label
            indices. If classes is not defined, the visualizer will use
            `cityscapes` classes by default. Defaults to None.
        palette (list, optional): Input palette for result rendering, which
            is a list of colors corresponding to the classes. If palette is
            not defined, the visualizer will use the `cityscapes` palette by
            default. Defaults to None.
        dataset_name (str, optional): `Dataset name or alias <https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/utils/class_names.py#L302-L317>`_.
            The visualizer will use the meta information of the dataset,
            i.e. classes and palette, but `classes` and `palette` have
            higher priority. Defaults to None.
device (str, optional): Device to run inference. If None, the available
device will be automatically used. Defaults to None.
scope (str, optional): The scope of the model. Defaults to 'mmseg'.
""" # noqa
preprocess_kwargs: set = set()
forward_kwargs: set = {'mode', 'out_dir'}
visualize_kwargs: set = {
'show', 'wait_time', 'img_out_dir', 'opacity', 'return_vis',
'with_labels'
}
postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}
def __init__(self,
model: Union[ModelType, str],
weights: Optional[str] = None,
classes: Optional[Union[str, List]] = None,
palette: Optional[Union[str, List]] = None,
dataset_name: Optional[str] = None,
device: Optional[str] = None,
scope: Optional[str] = 'mmseg') -> None:
        # Global counters tracking the number of images processed, used for
        # naming the output images
self.num_visualized_imgs = 0
self.num_pred_imgs = 0
init_default_scope(scope if scope else 'mmseg')
super().__init__(
model=model, weights=weights, device=device, scope=scope)
if device == 'cpu' or not torch.cuda.is_available():
self.model = revert_sync_batchnorm(self.model)
assert isinstance(self.visualizer, SegLocalVisualizer)
self.visualizer.set_dataset_meta(classes, palette, dataset_name)
def _load_weights_to_model(self, model: nn.Module,
checkpoint: Optional[dict],
cfg: Optional[ConfigType]) -> None:
"""Loading model weights and meta information from cfg and checkpoint.
Subclasses could override this method to load extra meta information
from ``checkpoint`` and ``cfg`` to model.
Args:
model (nn.Module): Model to load weights and meta information.
checkpoint (dict, optional): The loaded checkpoint.
cfg (Config or ConfigDict, optional): The loaded config.
"""
if checkpoint is not None:
_load_checkpoint_to_model(model, checkpoint)
checkpoint_meta = checkpoint.get('meta', {})
# save the dataset_meta in the model for convenience
if 'dataset_meta' in checkpoint_meta:
# mmsegmentation 1.x
model.dataset_meta = {
'classes': checkpoint_meta['dataset_meta'].get('classes'),
'palette': checkpoint_meta['dataset_meta'].get('palette')
}
elif 'CLASSES' in checkpoint_meta:
# mmsegmentation 0.x
classes = checkpoint_meta['CLASSES']
palette = checkpoint_meta.get('PALETTE', None)
model.dataset_meta = {'classes': classes, 'palette': palette}
else:
warnings.warn(
'dataset_meta or class names are not saved in the '
'checkpoint\'s meta data, use classes of Cityscapes by '
'default.')
model.dataset_meta = {
'classes': get_classes('cityscapes'),
'palette': get_palette('cityscapes')
}
else:
warnings.warn('Checkpoint is not loaded, and the inference '
'result is calculated by the randomly initialized '
'model!')
warnings.warn(
'weights is None, use cityscapes classes by default.')
model.dataset_meta = {
'classes': get_classes('cityscapes'),
'palette': get_palette('cityscapes')
}
def __call__(self,
inputs: InputsType,
return_datasamples: bool = False,
batch_size: int = 1,
return_vis: bool = False,
show: bool = False,
wait_time: int = 0,
out_dir: str = '',
img_out_dir: str = 'vis',
pred_out_dir: str = 'pred',
**kwargs) -> dict:
"""Call the inferencer.
Args:
inputs (Union[list, str, np.ndarray]): Inputs for the inferencer.
return_datasamples (bool): Whether to return results as
:obj:`SegDataSample`. Defaults to False.
batch_size (int): Batch size. Defaults to 1.
            show (bool): Whether to display the rendered color segmentation
                mask in a popup window. Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            out_dir (str): Output directory of inference results. Defaults
                to ''.
            img_out_dir (str): Subdirectory of `out_dir` used to save the
                rendered color segmentation masks, so `out_dir` must be
                defined if you would like to save them. Defaults to 'vis'.
            pred_out_dir (str): Subdirectory of `out_dir` used to save the
                predicted mask files, so `out_dir` must be defined if you
                would like to save them. Defaults to 'pred'.
**kwargs: Other keyword arguments passed to :meth:`preprocess`,
:meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
Each key in kwargs should be in the corresponding set of
``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
and ``postprocess_kwargs``.
Returns:
dict: Inference and visualization results.
"""
if out_dir != '':
pred_out_dir = osp.join(out_dir, pred_out_dir)
img_out_dir = osp.join(out_dir, img_out_dir)
else:
pred_out_dir = ''
img_out_dir = ''
return super().__call__(
inputs=inputs,
return_datasamples=return_datasamples,
batch_size=batch_size,
show=show,
wait_time=wait_time,
img_out_dir=img_out_dir,
pred_out_dir=pred_out_dir,
return_vis=return_vis,
**kwargs)
def visualize(self,
inputs: list,
preds: List[dict],
return_vis: bool = False,
show: bool = False,
wait_time: int = 0,
img_out_dir: str = '',
opacity: float = 0.8,
with_labels: Optional[bool] = True) -> List[np.ndarray]:
"""Visualize predictions.
Args:
inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
preds (Any): Predictions of the model.
show (bool): Whether to display the image in a popup window.
Defaults to False.
wait_time (float): The interval of show (s). Defaults to 0.
            img_out_dir (str): Output directory of the rendered predictions,
                i.e. the color segmentation masks. Defaults to ''.
            opacity (int, float): The transparency of the segmentation mask.
                Defaults to 0.8.
Returns:
List[np.ndarray]: Visualization results.
"""
if not show and img_out_dir == '' and not return_vis:
return None
if self.visualizer is None:
            raise ValueError('Visualization needs the "visualizer" term '
                             'defined in the config, but got None.')
self.visualizer.set_dataset_meta(**self.model.dataset_meta)
self.visualizer.alpha = opacity
results = []
for single_input, pred in zip(inputs, preds):
if isinstance(single_input, str):
img_bytes = mmengine.fileio.get(single_input)
img = mmcv.imfrombytes(img_bytes)
img = img[:, :, ::-1]
img_name = osp.basename(single_input)
elif isinstance(single_input, np.ndarray):
img = single_input.copy()
img_num = str(self.num_visualized_imgs).zfill(8) + '_vis'
img_name = f'{img_num}.jpg'
else:
                raise ValueError('Unsupported input type: '
                                 f'{type(single_input)}')
out_file = osp.join(img_out_dir, img_name) if img_out_dir != ''\
else None
self.visualizer.add_datasample(
img_name,
img,
pred,
show=show,
wait_time=wait_time,
draw_gt=False,
draw_pred=True,
out_file=out_file,
with_labels=with_labels)
if return_vis:
results.append(self.visualizer.get_image())
self.num_visualized_imgs += 1
return results if return_vis else None
def postprocess(self,
preds: PredType,
visualization: List[np.ndarray],
return_datasample: bool = False,
pred_out_dir: str = '') -> dict:
"""Process the predictions and visualization results from ``forward``
and ``visualize``.
This method should be responsible for the following tasks:
1. Pack the predictions and visualization results and return them.
        2. Save the predictions, if needed.
Args:
preds (List[Dict]): Predictions of the model.
visualization (List[np.ndarray]): The list of rendering color
segmentation mask.
return_datasample (bool): Whether to return results as datasamples.
Defaults to False.
            pred_out_dir (str): Directory to save the inference results
                without visualization. If left empty, no file will be saved.
                Defaults to ''.
Returns:
dict: Inference and visualization results with key ``predictions``
and ``visualization``
            - ``visualization`` (Any): Returned by :meth:`visualize`.
            - ``predictions`` (List[np.ndarray] or np.ndarray): Returned by
              :meth:`forward` and processed in :meth:`postprocess`.
              If ``return_datasample=False``, it will be the segmentation
              mask with label indices.
"""
if return_datasample:
if len(preds) == 1:
return preds[0]
else:
return preds
results_dict = {}
results_dict['predictions'] = []
results_dict['visualization'] = []
for i, pred in enumerate(preds):
pred_data = dict()
if 'pred_sem_seg' in pred.keys():
pred_data['sem_seg'] = pred.pred_sem_seg.numpy().data[0]
elif 'pred_depth_map' in pred.keys():
pred_data['depth_map'] = pred.pred_depth_map.numpy().data[0]
if visualization is not None:
vis = visualization[i]
results_dict['visualization'].append(vis)
if pred_out_dir != '':
mmengine.mkdir_or_exist(pred_out_dir)
for key, data in pred_data.items():
post_fix = '_pred.png' if key == 'sem_seg' else '_pred.npy'
img_name = str(self.num_pred_imgs).zfill(8) + post_fix
img_path = osp.join(pred_out_dir, img_name)
if key == 'sem_seg':
output = Image.fromarray(data.astype(np.uint8))
output.save(img_path)
else:
np.save(img_path, data)
pred_data = next(iter(pred_data.values()))
results_dict['predictions'].append(pred_data)
self.num_pred_imgs += 1
if len(results_dict['predictions']) == 1:
results_dict['predictions'] = results_dict['predictions'][0]
if visualization is not None:
results_dict['visualization'] = \
results_dict['visualization'][0]
return results_dict
def _init_pipeline(self, cfg: ConfigType) -> Compose:
"""Initialize the test pipeline.
Return a pipeline to handle various input data, such as ``str``,
``np.ndarray``. It is an abstract method in BaseInferencer, and should
be implemented in subclasses.
The returned pipeline will be used to process a single data.
It will be used in :meth:`preprocess` like this:
.. code-block:: python
def preprocess(self, inputs, batch_size, **kwargs):
...
dataset = map(self.pipeline, dataset)
...
"""
pipeline_cfg = cfg.test_dataloader.dataset.pipeline
        # Loading annotations is not applicable during inference
for transform in ('LoadAnnotations', 'LoadDepthAnnotation'):
idx = self._get_transform_idx(pipeline_cfg, transform)
if idx != -1:
del pipeline_cfg[idx]
load_img_idx = self._get_transform_idx(pipeline_cfg,
'LoadImageFromFile')
if load_img_idx == -1:
raise ValueError(
'LoadImageFromFile is not found in the test pipeline')
pipeline_cfg[load_img_idx]['type'] = 'InferencerLoader'
return Compose(pipeline_cfg)
def _get_transform_idx(self, pipeline_cfg: ConfigType, name: str) -> int:
"""Returns the index of the transform in a pipeline.
If the transform is not found, returns -1.
"""
for i, transform in enumerate(pipeline_cfg):
if transform['type'] == name:
return i
return -1
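
A hedged usage sketch of `MMSegInferencer`; the model name is the metafile example quoted in the class docstring, and the input and output paths are placeholders:

from mmseg.apis import MMSegInferencer

# Weights are resolved from the metafile when a model name is given.
inferencer = MMSegInferencer(model='fcn_r50-d8_4xb2-40k_cityscapes-512x1024')
# Rendered masks land in work_dirs/infer/vis and raw predictions in
# work_dirs/infer/pred (placeholder paths).
out = inferencer('demo/demo.png', out_dir='work_dirs/infer', return_vis=True)
mask = out['predictions']     # label-index map (np.ndarray)
drawn = out['visualization']  # rendered color mask (np.ndarray)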

mmseg/apis/remote_sense_inferencer.py

@@ -0,0 +1,279 @@
# Copyright (c) OpenMMLab. All rights reserved.
import threading
from queue import Queue
from typing import List, Optional, Tuple
import numpy as np
import torch
from mmengine import Config
from mmengine.model import BaseModel
from mmengine.registry import init_default_scope
from mmengine.runner import load_checkpoint
try:
from osgeo import gdal
except ImportError:
gdal = None
from mmseg.registry import MODELS
from .utils import _preprare_data
class RSImage:
"""Remote sensing image class.
Args:
        image (str or gdal.Dataset): Image file path or gdal.Dataset.
"""
    def __init__(self, image):
        assert gdal is not None, \
            'gdal is not installed, which is required by RSImage'
        self.dataset = gdal.Open(image, gdal.GA_ReadOnly) if isinstance(
            image, str) else image
        assert isinstance(self.dataset, gdal.Dataset), \
            f'{image} is not an image'
self.width = self.dataset.RasterXSize
self.height = self.dataset.RasterYSize
self.channel = self.dataset.RasterCount
self.trans = self.dataset.GetGeoTransform()
self.proj = self.dataset.GetProjection()
self.band_list = []
self.band_list.extend(
self.dataset.GetRasterBand(c + 1) for c in range(self.channel))
self.grids = []
def read(self, grid: Optional[List] = None) -> np.ndarray:
"""Read image data. If grid is None, read the whole image.
Args:
grid (Optional[List], optional): Grid to read. Defaults to None.
Returns:
np.ndarray: Image data.
"""
if grid is None:
return np.einsum('ijk->jki', self.dataset.ReadAsArray())
assert len(
grid) >= 4, 'grid must be a list containing at least 4 elements'
data = self.dataset.ReadAsArray(*grid[:4])
if data.ndim == 2:
data = data[np.newaxis, ...]
return np.einsum('ijk->jki', data)
def write(self, data: Optional[np.ndarray], grid: Optional[List] = None):
"""Write image data.
        Args:
            data (Optional[np.ndarray], optional): Data to write.
                Defaults to None.
            grid (Optional[List], optional): Grid to write. Defaults to None.
        Raises:
            ValueError: If neither grid nor data is provided.
"""
if grid is not None:
assert len(grid) == 8, 'grid must be a list of 8 elements'
for band in self.band_list:
band.WriteArray(
data[grid[5]:grid[5] + grid[7], grid[4]:grid[4] + grid[6]],
grid[0] + grid[4], grid[1] + grid[5])
elif data is not None:
for i in range(self.channel):
self.band_list[i].WriteArray(data[..., i])
else:
raise ValueError('Either grid or data must be provided.')
def create_seg_map(self, output_path: Optional[str] = None):
if output_path is None:
output_path = 'output_label.tif'
driver = gdal.GetDriverByName('GTiff')
seg_map = driver.Create(output_path, self.width, self.height, 1,
gdal.GDT_Byte)
seg_map.SetGeoTransform(self.trans)
seg_map.SetProjection(self.proj)
seg_map_img = RSImage(seg_map)
seg_map_img.path = output_path
return seg_map_img
def create_grids(self,
window_size: Tuple[int, int],
stride: Tuple[int, int] = (0, 0)):
"""Create grids for image inference.
Args:
            window_size (Tuple[int, int]): The size of the sliding window.
            stride (Tuple[int, int], optional): The stride of the sliding
                window. Defaults to (0, 0).
Raises:
AssertionError: window_size must be a tuple of 2 elements.
AssertionError: stride must be a tuple of 2 elements.
"""
assert len(
window_size) == 2, 'window_size must be a tuple of 2 elements'
assert len(stride) == 2, 'stride must be a tuple of 2 elements'
win_w, win_h = window_size
stride_x, stride_y = stride
stride_x = win_w if stride_x == 0 else stride_x
stride_y = win_h if stride_y == 0 else stride_y
x_half_overlap = (win_w - stride_x + 1) // 2
y_half_overlap = (win_h - stride_y + 1) // 2
for y in range(0, self.height, stride_y):
y_end = y + win_h >= self.height
y_offset = self.height - win_h if y_end else y
y_size = win_h
y_crop_off = 0 if y_offset == 0 else y_half_overlap
y_crop_size = y_size if y_end else win_h - y_crop_off
for x in range(0, self.width, stride_x):
x_end = x + win_w >= self.width
x_offset = self.width - win_w if x_end else x
x_size = win_w
x_crop_off = 0 if x_offset == 0 else x_half_overlap
x_crop_size = x_size if x_end else win_w - x_crop_off
self.grids.append([
x_offset, y_offset, x_size, y_size, x_crop_off, y_crop_off,
x_crop_size, y_crop_size
])
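# Worked example (illustrative only): with height = width = 1000,
# window_size = (512, 512) and stride = (256, 256), the half-overlap is
# (512 - 256 + 1) // 2 = 128 pixels. A window starting at a non-zero offset
# therefore skips its first 128 rows/columns when written back
# (crop_off = 128, crop_size = 512 - 128 = 384), so the overlapped seam
# between neighbouring windows is written only once.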
class RSInferencer:
"""Remote sensing inference class.
Args:
model (BaseModel): The loaded model.
batch_size (int, optional): Batch size. Defaults to 1.
thread (int, optional): Number of threads. Defaults to 1.
"""
def __init__(self, model: BaseModel, batch_size: int = 1, thread: int = 1):
self.model = model
self.batch_size = batch_size
self.END_FLAG = object()
self.read_buffer = Queue(self.batch_size)
self.write_buffer = Queue(self.batch_size)
self.thread = thread
@classmethod
def from_config_path(cls,
config_path: str,
checkpoint_path: str,
batch_size: int = 1,
thread: int = 1,
device: Optional[str] = 'cpu'):
"""Initialize a segmentor from config file.
Args:
            config_path (str): Config file path.
            checkpoint_path (str): Checkpoint path.
            batch_size (int, optional): Batch size. Defaults to 1.
            thread (int, optional): Number of threads. Defaults to 1.
            device (str, optional): Device to run inference. Defaults to
                'cpu'.
"""
init_default_scope('mmseg')
cfg = Config.fromfile(config_path)
model = MODELS.build(cfg.model)
model.cfg = cfg
load_checkpoint(model, checkpoint_path, map_location='cpu')
model.to(device)
model.eval()
return cls(model, batch_size, thread)
@classmethod
def from_model(cls,
model: BaseModel,
checkpoint_path: Optional[str] = None,
batch_size: int = 1,
thread: int = 1,
device: Optional[str] = 'cpu'):
"""Initialize a segmentor from model.
Args:
            model (BaseModel): The loaded model.
            checkpoint_path (str, optional): Checkpoint path. Defaults to
                None.
            batch_size (int, optional): Batch size. Defaults to 1.
            thread (int, optional): Number of threads. Defaults to 1.
            device (str, optional): Device to run inference. Defaults to
                'cpu'.
"""
if checkpoint_path is not None:
load_checkpoint(model, checkpoint_path, map_location='cpu')
model.to(device)
return cls(model, batch_size, thread)
def read(self,
image: RSImage,
window_size: Tuple[int, int],
strides: Tuple[int, int] = (0, 0)):
"""Load image data to read buffer.
Args:
image (RSImage): The image to read.
window_size (Tuple[int, int]): The size of the sliding window.
strides (Tuple[int, int], optional): The stride of the sliding
window. Defaults to (0, 0).
"""
image.create_grids(window_size, strides)
for grid in image.grids:
self.read_buffer.put([grid, image.read(grid=grid)])
self.read_buffer.put(self.END_FLAG)
def inference(self):
"""Inference image data from read buffer and put the result to write
buffer."""
while True:
item = self.read_buffer.get()
            if item is self.END_FLAG:
self.read_buffer.put(self.END_FLAG)
self.write_buffer.put(item)
break
data, _ = _preprare_data(item[1], self.model)
with torch.no_grad():
result = self.model.test_step(data)
item[1] = result[0].pred_sem_seg.cpu().data.numpy()[0]
self.write_buffer.put(item)
self.read_buffer.task_done()
def write(self, image: RSImage, output_path: Optional[str] = None):
"""Write image data from write buffer.
Args:
image (RSImage): The image to write.
output_path (Optional[str], optional): The path to save the
segmentation map. Defaults to None.
"""
seg_map = image.create_seg_map(output_path)
while True:
item = self.write_buffer.get()
            if item is self.END_FLAG:
break
seg_map.write(data=item[1], grid=item[0])
self.write_buffer.task_done()
def run(self,
image: RSImage,
window_size: Tuple[int, int],
strides: Tuple[int, int] = (0, 0),
output_path: Optional[str] = None):
"""Run inference with multi-threading.
Args:
image (RSImage): The image to inference.
window_size (Tuple[int, int]): The size of the sliding window.
strides (Tuple[int, int], optional): The stride of the sliding
window. Defaults to (0, 0).
output_path (Optional[str], optional): The path to save the
segmentation map. Defaults to None.
"""
read_thread = threading.Thread(
target=self.read, args=(image, window_size, strides))
read_thread.start()
inference_threads = []
for _ in range(self.thread):
inference_thread = threading.Thread(target=self.inference)
inference_thread.start()
inference_threads.append(inference_thread)
write_thread = threading.Thread(
target=self.write, args=(image, output_path))
write_thread.start()
read_thread.join()
for inference_thread in inference_threads:
inference_thread.join()
write_thread.join()
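
An end-to-end sketch of `RSInferencer`, assuming GDAL is installed; the config, checkpoint, and GeoTIFF paths are placeholders:

from mmseg.apis import RSImage, RSInferencer

inferencer = RSInferencer.from_config_path(
    'configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py',  # placeholder
    'fcn_r50-d8_cityscapes.pth',  # placeholder checkpoint
    batch_size=1,
    thread=2,
    device='cuda:0')
image = RSImage('scene.tif')  # placeholder GeoTIFF
# Read, inference and write run in separate threads; the label map is saved
# as a single-band GeoTIFF carrying the source geotransform and projection.
inferencer.run(image, window_size=(512, 512), strides=(256, 256),
               output_path='scene_label.tif')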

mmseg/apis/utils.py

@@ -0,0 +1,41 @@
# Copyright (c) OpenMMLab. All rights reserved.
from collections import defaultdict
from typing import Sequence, Union
import numpy as np
from mmengine.dataset import Compose
from mmengine.model import BaseModel
ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
def _preprare_data(imgs: ImageType, model: BaseModel):
cfg = model.cfg
    # Iterate over a copy: removing items from the list while iterating it
    # would skip elements.
    for t in list(cfg.test_pipeline):
        if t.get('type') == 'LoadAnnotations':
            cfg.test_pipeline.remove(t)
is_batch = True
if not isinstance(imgs, (list, tuple)):
imgs = [imgs]
is_batch = False
if isinstance(imgs[0], np.ndarray):
cfg.test_pipeline[0]['type'] = 'LoadImageFromNDArray'
# TODO: Consider using the singleton pattern to avoid building
# a pipeline for each inference
pipeline = Compose(cfg.test_pipeline)
data = defaultdict(list)
for img in imgs:
if isinstance(img, np.ndarray):
data_ = dict(img=img)
else:
data_ = dict(img_path=img)
data_ = pipeline(data_)
data['inputs'].append(data_['inputs'])
data['data_samples'].append(data_['data_samples'])
return data, is_batch
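
For orientation, a minimal sketch of what `_preprare_data` hands back, assuming `model` is a segmentor built with `init_model` above; the zero image is a stand-in:

import numpy as np

img = np.zeros((512, 512, 3), dtype=np.uint8)  # placeholder image
data, is_batch = _preprare_data(img, model)
# A single input is wrapped in one-element lists and flagged as non-batch.
assert is_batch is False
assert len(data['inputs']) == len(data['data_samples']) == 1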