This commit is contained in:
esenke
2025-12-08 22:16:31 +08:00
commit 01adcfdf60
305 changed files with 50879 additions and 0 deletions

View File

@@ -0,0 +1,130 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from https://github.com/open-
mmlab/mmdetection/blob/master/tools/analysis_tools/analyze_logs.py."""
import argparse
import json
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
def plot_curve(log_dicts, args):
if args.backend is not None:
plt.switch_backend(args.backend)
sns.set_style(args.style)
# if legend is None, use {filename}_{key} as legend
legend = args.legend
if legend is None:
legend = []
for json_log in args.json_logs:
for metric in args.keys:
legend.append(f'{json_log}_{metric}')
assert len(legend) == (len(args.json_logs) * len(args.keys))
metrics = args.keys
num_metrics = len(metrics)
for i, log_dict in enumerate(log_dicts):
epochs = list(log_dict.keys())
for j, metric in enumerate(metrics):
print(f'plot curve of {args.json_logs[i]}, metric is {metric}')
plot_epochs = []
plot_iters = []
plot_values = []
# In some log files exist lines of validation,
# `mode` list is used to only collect iter number
# of training line.
for epoch in epochs:
epoch_logs = log_dict[epoch]
if metric not in epoch_logs.keys():
continue
if metric in ['mIoU', 'mAcc', 'aAcc']:
plot_epochs.append(epoch)
plot_values.append(epoch_logs[metric][0])
else:
for idx in range(len(epoch_logs[metric])):
plot_iters.append(epoch_logs['step'][idx])
plot_values.append(epoch_logs[metric][idx])
ax = plt.gca()
label = legend[i * num_metrics + j]
if metric in ['mIoU', 'mAcc', 'aAcc']:
ax.set_xticks(plot_epochs)
plt.xlabel('step')
plt.plot(plot_epochs, plot_values, label=label, marker='o')
else:
plt.xlabel('iter')
plt.plot(plot_iters, plot_values, label=label, linewidth=0.5)
plt.legend()
if args.title is not None:
plt.title(args.title)
if args.out is None:
plt.show()
else:
print(f'save curve to: {args.out}')
plt.savefig(args.out)
plt.cla()
def parse_args():
parser = argparse.ArgumentParser(description='Analyze Json Log')
parser.add_argument(
'json_logs',
type=str,
nargs='+',
help='path of train log in json format')
parser.add_argument(
'--keys',
type=str,
nargs='+',
default=['mIoU'],
help='the metric that you want to plot')
parser.add_argument('--title', type=str, help='title of figure')
parser.add_argument(
'--legend',
type=str,
nargs='+',
default=None,
help='legend of each plot')
parser.add_argument(
'--backend', type=str, default=None, help='backend of plt')
parser.add_argument(
'--style', type=str, default='dark', help='style of plt')
parser.add_argument('--out', type=str, default=None)
args = parser.parse_args()
return args
def load_json_logs(json_logs):
# load and convert json_logs to log_dict, key is step, value is a sub dict
# keys of sub dict is different metrics
# value of sub dict is a list of corresponding values of all iterations
log_dicts = [dict() for _ in json_logs]
prev_step = 0
for json_log, log_dict in zip(json_logs, log_dicts):
with open(json_log) as log_file:
for line in log_file:
log = json.loads(line.strip())
# the final step in json file is 0.
if 'step' in log and log['step'] != 0:
step = log['step']
prev_step = step
else:
step = prev_step
if step not in log_dict:
log_dict[step] = defaultdict(list)
for k, v in log.items():
log_dict[step][k].append(v)
return log_dicts
def main():
args = parse_args()
json_logs = args.json_logs
for json_log in json_logs:
assert json_log.endswith('.json')
log_dicts = load_json_logs(json_logs)
plot_curve(log_dicts, args)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,121 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import time
import numpy as np
import torch
from mmengine import Config
from mmengine.fileio import dump
from mmengine.model.utils import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.runner import Runner, load_checkpoint
from mmengine.utils import mkdir_or_exist
from mmseg.registry import MODELS
def parse_args():
parser = argparse.ArgumentParser(description='MMSeg benchmark a model')
parser.add_argument('config', help='test config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument(
'--log-interval', type=int, default=50, help='interval of logging')
parser.add_argument(
'--work-dir',
help=('if specified, the results will be dumped '
'into the directory as json'))
parser.add_argument('--repeat-times', type=int, default=1)
args = parser.parse_args()
return args
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
init_default_scope(cfg.get('default_scope', 'mmseg'))
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
if args.work_dir is not None:
mkdir_or_exist(osp.abspath(args.work_dir))
json_file = osp.join(args.work_dir, f'fps_{timestamp}.json')
else:
# use config filename as default work_dir if cfg.work_dir is None
work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0])
mkdir_or_exist(osp.abspath(work_dir))
json_file = osp.join(work_dir, f'fps_{timestamp}.json')
repeat_times = args.repeat_times
# set cudnn_benchmark
torch.backends.cudnn.benchmark = False
cfg.model.pretrained = None
benchmark_dict = dict(config=args.config, unit='img / s')
overall_fps_list = []
cfg.test_dataloader.batch_size = 1
for time_index in range(repeat_times):
print(f'Run {time_index + 1}:')
# build the dataloader
data_loader = Runner.build_dataloader(cfg.test_dataloader)
# build the model and load checkpoint
cfg.model.train_cfg = None
model = MODELS.build(cfg.model)
if 'checkpoint' in args and osp.exists(args.checkpoint):
load_checkpoint(model, args.checkpoint, map_location='cpu')
if torch.cuda.is_available():
model = model.cuda()
model = revert_sync_batchnorm(model)
model.eval()
# the first several iterations may be very slow so skip them
num_warmup = 5
pure_inf_time = 0
total_iters = 200
# benchmark with 200 batches and take the average
for i, data in enumerate(data_loader):
data = model.data_preprocessor(data, True)
inputs = data['inputs']
data_samples = data['data_samples']
if torch.cuda.is_available():
torch.cuda.synchronize()
start_time = time.perf_counter()
with torch.no_grad():
model(inputs, data_samples, mode='predict')
if torch.cuda.is_available():
torch.cuda.synchronize()
elapsed = time.perf_counter() - start_time
if i >= num_warmup:
pure_inf_time += elapsed
if (i + 1) % args.log_interval == 0:
fps = (i + 1 - num_warmup) / pure_inf_time
print(f'Done image [{i + 1:<3}/ {total_iters}], '
f'fps: {fps:.2f} img / s')
if (i + 1) == total_iters:
fps = (i + 1 - num_warmup) / pure_inf_time
print(f'Overall fps: {fps:.2f} img / s\n')
benchmark_dict[f'overall_fps_{time_index + 1}'] = round(fps, 2)
overall_fps_list.append(fps)
break
benchmark_dict['average_fps'] = round(np.mean(overall_fps_list), 2)
benchmark_dict['fps_variance'] = round(np.var(overall_fps_list), 4)
print(f'Average fps of {repeat_times} evaluations: '
f'{benchmark_dict["average_fps"]}')
print(f'The variance of {repeat_times} evaluations: '
f'{benchmark_dict["fps_variance"]}')
dump(benchmark_dict, json_file, indent=4)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,77 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from mmengine.config import Config, DictAction
from mmengine.utils import ProgressBar
from mmseg.registry import DATASETS, VISUALIZERS
from mmseg.utils import register_all_modules
def parse_args():
parser = argparse.ArgumentParser(description='Browse a dataset')
parser.add_argument('config', help='train config file path')
parser.add_argument(
'--output-dir',
default=None,
type=str,
help='If there is no display interface, you can save it')
parser.add_argument('--not-show', default=False, action='store_true')
parser.add_argument(
'--show-interval',
type=float,
default=2,
help='the interval of show (s)')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
args = parser.parse_args()
return args
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
# register all modules in mmdet into the registries
register_all_modules()
dataset = DATASETS.build(cfg.train_dataloader.dataset)
visualizer = VISUALIZERS.build(cfg.visualizer)
visualizer.dataset_meta = dataset.metainfo
progress_bar = ProgressBar(len(dataset))
for item in dataset:
img = item['inputs'].permute(1, 2, 0).numpy()
img = img[..., [2, 1, 0]] # bgr to rgb
data_sample = item['data_samples'].numpy()
img_path = osp.basename(item['data_samples'].img_path)
out_file = osp.join(
args.output_dir,
osp.basename(img_path)) if args.output_dir is not None else None
visualizer.add_datasample(
name=osp.basename(img_path),
image=img,
data_sample=data_sample,
draw_gt=True,
draw_pred=False,
wait_time=args.show_interval,
out_file=out_file,
show=not args.not_show)
progress_bar.update()
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,197 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator
from mmengine.config import Config, DictAction
from mmengine.registry import init_default_scope
from mmengine.utils import mkdir_or_exist, progressbar
from PIL import Image
from mmseg.registry import DATASETS
init_default_scope('mmseg')
def parse_args():
parser = argparse.ArgumentParser(
description='Generate confusion matrix from segmentation results')
parser.add_argument('config', help='test config file path')
parser.add_argument(
'prediction_path', help='prediction path where test folder result')
parser.add_argument(
'save_dir', help='directory where confusion matrix will be saved')
parser.add_argument(
'--show', action='store_true', help='show confusion matrix')
parser.add_argument(
'--color-theme',
default='winter',
help='theme of the matrix color map')
parser.add_argument(
'--title',
default='Normalized Confusion Matrix',
help='title of the matrix color map')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
args = parser.parse_args()
return args
def calculate_confusion_matrix(dataset, results):
"""Calculate the confusion matrix.
Args:
dataset (Dataset): Test or val dataset.
results (list[ndarray]): A list of segmentation results in each image.
"""
n = len(dataset.METAINFO['classes'])
confusion_matrix = np.zeros(shape=[n, n])
assert len(dataset) == len(results)
ignore_index = dataset.ignore_index
reduce_zero_label = dataset.reduce_zero_label
prog_bar = progressbar.ProgressBar(len(results))
for idx, per_img_res in enumerate(results):
res_segm = per_img_res
gt_segm = dataset[idx]['data_samples'] \
.gt_sem_seg.data.squeeze().numpy().astype(np.uint8)
gt_segm, res_segm = gt_segm.flatten(), res_segm.flatten()
if reduce_zero_label:
gt_segm = gt_segm - 1
to_ignore = gt_segm == ignore_index
gt_segm, res_segm = gt_segm[~to_ignore], res_segm[~to_ignore]
inds = n * gt_segm + res_segm
mat = np.bincount(inds, minlength=n**2).reshape(n, n)
confusion_matrix += mat
prog_bar.update()
return confusion_matrix
def plot_confusion_matrix(confusion_matrix,
labels,
save_dir=None,
show=True,
title='Normalized Confusion Matrix',
color_theme='OrRd'):
"""Draw confusion matrix with matplotlib.
Args:
confusion_matrix (ndarray): The confusion matrix.
labels (list[str]): List of class names.
save_dir (str|optional): If set, save the confusion matrix plot to the
given path. Default: None.
show (bool): Whether to show the plot. Default: True.
title (str): Title of the plot. Default: `Normalized Confusion Matrix`.
color_theme (str): Theme of the matrix color map. Default: `winter`.
"""
# normalize the confusion matrix
per_label_sums = confusion_matrix.sum(axis=1)[:, np.newaxis]
confusion_matrix = \
confusion_matrix.astype(np.float32) / per_label_sums * 100
num_classes = len(labels)
fig, ax = plt.subplots(
figsize=(2 * num_classes, 2 * num_classes * 0.8), dpi=300)
cmap = plt.get_cmap(color_theme)
im = ax.imshow(confusion_matrix, cmap=cmap)
colorbar = plt.colorbar(mappable=im, ax=ax)
colorbar.ax.tick_params(labelsize=20) # 设置 colorbar 标签的字体大小
title_font = {'weight': 'bold', 'size': 20}
ax.set_title(title, fontdict=title_font)
label_font = {'size': 40}
plt.ylabel('Ground Truth Label', fontdict=label_font)
plt.xlabel('Prediction Label', fontdict=label_font)
# draw locator
xmajor_locator = MultipleLocator(1)
xminor_locator = MultipleLocator(0.5)
ax.xaxis.set_major_locator(xmajor_locator)
ax.xaxis.set_minor_locator(xminor_locator)
ymajor_locator = MultipleLocator(1)
yminor_locator = MultipleLocator(0.5)
ax.yaxis.set_major_locator(ymajor_locator)
ax.yaxis.set_minor_locator(yminor_locator)
# draw grid
ax.grid(True, which='minor', linestyle='-')
# draw label
ax.set_xticks(np.arange(num_classes))
ax.set_yticks(np.arange(num_classes))
ax.set_xticklabels(labels, fontsize=20)
ax.set_yticklabels(labels, fontsize=20)
ax.tick_params(
axis='x', bottom=False, top=True, labelbottom=False, labeltop=True)
plt.setp(
ax.get_xticklabels(), rotation=45, ha='left', rotation_mode='anchor')
# draw confusion matrix value
for i in range(num_classes):
for j in range(num_classes):
ax.text(
j,
i,
'{}%'.format(
round(confusion_matrix[i, j], 2
) if not np.isnan(confusion_matrix[i, j]) else -1),
ha='center',
va='center',
color='k',
size=20)
ax.set_ylim(len(confusion_matrix) - 0.5, -0.5) # matplotlib>3.1.1
fig.tight_layout()
if save_dir is not None:
mkdir_or_exist(save_dir)
plt.savefig(
os.path.join(save_dir, 'confusion_matrix.png'), format='png')
if show:
plt.show()
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
results = []
for img in sorted(os.listdir(args.prediction_path)):
img = os.path.join(args.prediction_path, img)
image = Image.open(img)
image = np.copy(image)
results.append(image)
assert isinstance(results, list)
if isinstance(results[0], np.ndarray):
pass
else:
raise TypeError('invalid type of prediction results')
dataset = DATASETS.build(cfg.test_dataloader.dataset)
confusion_matrix = calculate_confusion_matrix(dataset, results)
plot_confusion_matrix(
confusion_matrix,
dataset.METAINFO['classes'],
save_dir=args.save_dir,
show=args.show,
title=args.title,
color_theme=args.color_theme)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,124 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import tempfile
from pathlib import Path
import torch
from mmengine import Config, DictAction
from mmengine.logging import MMLogger
from mmengine.model import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmseg.models import BaseSegmentor
from mmseg.registry import MODELS
from mmseg.structures import SegDataSample
try:
from mmengine.analysis import get_model_complexity_info
from mmengine.analysis.print_helper import _format_size
except ImportError:
raise ImportError('Please upgrade mmengine >= 0.6.0 to use this script.')
def parse_args():
parser = argparse.ArgumentParser(
description='Get the FLOPs of a segmentor')
parser.add_argument('config', help='train config file path')
parser.add_argument(
'--shape',
type=int,
nargs='+',
default=[2048, 1024],
help='input image size')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
args = parser.parse_args()
return args
def inference(args: argparse.Namespace, logger: MMLogger) -> dict:
config_name = Path(args.config)
if not config_name.exists():
logger.error(f'Config file {config_name} does not exist')
cfg: Config = Config.fromfile(config_name)
cfg.work_dir = tempfile.TemporaryDirectory().name
cfg.log_level = 'WARN'
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
init_default_scope(cfg.get('scope', 'mmseg'))
if len(args.shape) == 1:
input_shape = (3, args.shape[0], args.shape[0])
elif len(args.shape) == 2:
input_shape = (3, ) + tuple(args.shape)
else:
raise ValueError('invalid input shape')
result = {}
model: BaseSegmentor = MODELS.build(cfg.model)
if hasattr(model, 'auxiliary_head'):
model.auxiliary_head = None
if torch.cuda.is_available():
model.cuda()
model = revert_sync_batchnorm(model)
result['ori_shape'] = input_shape[-2:]
result['pad_shape'] = input_shape[-2:]
data_batch = {
'inputs': [torch.rand(input_shape)],
'data_samples': [SegDataSample(metainfo=result)]
}
data = model.data_preprocessor(data_batch)
model.eval()
if cfg.model.decode_head.type in ['MaskFormerHead', 'Mask2FormerHead']:
# TODO: Support MaskFormer and Mask2Former
raise NotImplementedError('MaskFormer and Mask2Former are not '
'supported yet.')
outputs = get_model_complexity_info(
model,
input_shape=None,
inputs=data['inputs'],
show_table=False,
show_arch=False)
result['flops'] = _format_size(outputs['flops'])
result['params'] = _format_size(outputs['params'])
result['compute_type'] = 'direct: randomly generate a picture'
return result
def main():
args = parse_args()
logger = MMLogger.get_instance(name='MMLogger')
result = inference(args, logger)
split_line = '=' * 30
ori_shape = result['ori_shape']
pad_shape = result['pad_shape']
flops = result['flops']
params = result['params']
compute_type = result['compute_type']
if pad_shape != ori_shape:
print(f'{split_line}\nUse size divisor set input shape '
f'from {ori_shape} to {pad_shape}')
print(f'{split_line}\nCompute type: {compute_type}\n'
f'Input shape: {pad_shape}\nFlops: {flops}\n'
f'Params: {params}\n{split_line}')
print('!!!Please be cautious if you use the results in papers. '
'You may need to check if all ops are supported and verify '
'that the flops computation is correct.')
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,127 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Use the pytorch-grad-cam tool to visualize Class Activation Maps (CAM).
requirement: pip install grad-cam
"""
from argparse import ArgumentParser
import numpy as np
import torch
import torch.nn.functional as F
from mmengine import Config
from mmengine.model import revert_sync_batchnorm
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image
from mmseg.apis import inference_model, init_model, show_result_pyplot
from mmseg.utils import register_all_modules
class SemanticSegmentationTarget:
"""wrap the model.
requirement: pip install grad-cam
Args:
category (int): Visualization class.
mask (ndarray): Mask of class.
size (tuple): Image size.
"""
def __init__(self, category, mask, size):
self.category = category
self.mask = torch.from_numpy(mask)
self.size = size
if torch.cuda.is_available():
self.mask = self.mask.cuda()
def __call__(self, model_output):
model_output = torch.unsqueeze(model_output, dim=0)
model_output = F.interpolate(
model_output, size=self.size, mode='bilinear')
model_output = torch.squeeze(model_output, dim=0)
return (model_output[self.category, :, :] * self.mask).sum()
def main():
parser = ArgumentParser()
parser.add_argument('img', help='Image file')
parser.add_argument('config', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument(
'--out-file',
default='prediction.png',
help='Path to output prediction file')
parser.add_argument(
'--cam-file', default='vis_cam.png', help='Path to output cam file')
parser.add_argument(
'--target-layers',
default='backbone.layer4[2]',
help='Target layers to visualize CAM')
parser.add_argument(
'--category-index', default='7', help='Category to visualize CAM')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
args = parser.parse_args()
# build the model from a config file and a checkpoint file
register_all_modules()
model = init_model(args.config, args.checkpoint, device=args.device)
if args.device == 'cpu':
model = revert_sync_batchnorm(model)
# test a single image
result = inference_model(model, args.img)
# show the results
show_result_pyplot(
model,
args.img,
result,
draw_gt=False,
show=False if args.out_file is not None else True,
out_file=args.out_file)
# result data conversion
prediction_data = result.pred_sem_seg.data
pre_np_data = prediction_data.cpu().numpy().squeeze(0)
target_layers = args.target_layers
target_layers = [eval(f'model.{target_layers}')]
category = int(args.category_index)
mask_float = np.float32(pre_np_data == category)
# data processing
image = np.array(Image.open(args.img).convert('RGB'))
height, width = image.shape[0], image.shape[1]
rgb_img = np.float32(image) / 255
config = Config.fromfile(args.config)
image_mean = config.data_preprocessor['mean']
image_std = config.data_preprocessor['std']
input_tensor = preprocess_image(
rgb_img,
mean=[x / 255 for x in image_mean],
std=[x / 255 for x in image_std])
# Grad CAM(Class Activation Maps)
# Can also be LayerCAM, XGradCAM, GradCAMPlusPlus, EigenCAM, EigenGradCAM
targets = [
SemanticSegmentationTarget(category, mask_float, (height, width))
]
with GradCAM(
model=model,
target_layers=target_layers,
use_cuda=torch.cuda.is_available()) as cam:
grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
# save cam file
Image.fromarray(cam_image).save(args.cam_file)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,89 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import tempfile
import zipfile
import mmcv
from mmengine.utils import mkdir_or_exist
CHASE_DB1_LEN = 28 * 3
TRAINING_LEN = 60
def parse_args():
parser = argparse.ArgumentParser(
description='Convert CHASE_DB1 dataset to mmsegmentation format')
parser.add_argument('dataset_path', help='path of CHASEDB1.zip')
parser.add_argument('--tmp_dir', help='path of the temporary directory')
parser.add_argument('-o', '--out_dir', help='output path')
args = parser.parse_args()
return args
def main():
args = parse_args()
dataset_path = args.dataset_path
if args.out_dir is None:
out_dir = osp.join('data', 'CHASE_DB1')
else:
out_dir = args.out_dir
print('Making directories...')
mkdir_or_exist(out_dir)
mkdir_or_exist(osp.join(out_dir, 'images'))
mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
mkdir_or_exist(osp.join(out_dir, 'annotations'))
mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))
with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
print('Extracting CHASEDB1.zip...')
zip_file = zipfile.ZipFile(dataset_path)
zip_file.extractall(tmp_dir)
print('Generating training dataset...')
assert len(os.listdir(tmp_dir)) == CHASE_DB1_LEN, \
f'len(os.listdir(tmp_dir)) != {CHASE_DB1_LEN}'
for img_name in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]:
img = mmcv.imread(osp.join(tmp_dir, img_name))
if osp.splitext(img_name)[1] == '.jpg':
mmcv.imwrite(
img,
osp.join(out_dir, 'images', 'training',
osp.splitext(img_name)[0] + '.png'))
else:
# The annotation img should be divided by 128, because some of
# the annotation imgs are not standard. We should set a
# threshold to convert the nonstandard annotation imgs. The
# value divided by 128 is equivalent to '1 if value >= 128
# else 0'
mmcv.imwrite(
img[:, :, 0] // 128,
osp.join(out_dir, 'annotations', 'training',
osp.splitext(img_name)[0] + '.png'))
for img_name in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]:
img = mmcv.imread(osp.join(tmp_dir, img_name))
if osp.splitext(img_name)[1] == '.jpg':
mmcv.imwrite(
img,
osp.join(out_dir, 'images', 'validation',
osp.splitext(img_name)[0] + '.png'))
else:
mmcv.imwrite(
img[:, :, 0] // 128,
osp.join(out_dir, 'annotations', 'validation',
osp.splitext(img_name)[0] + '.png'))
print('Removing the temporary files...')
print('Done!')
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,56 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from cityscapesscripts.preparation.json2labelImg import json2labelImg
from mmengine.utils import (mkdir_or_exist, scandir, track_parallel_progress,
track_progress)
def convert_json_to_label(json_file):
label_file = json_file.replace('_polygons.json', '_labelTrainIds.png')
json2labelImg(json_file, label_file, 'trainIds')
def parse_args():
parser = argparse.ArgumentParser(
description='Convert Cityscapes annotations to TrainIds')
parser.add_argument('cityscapes_path', help='cityscapes data path')
parser.add_argument('--gt-dir', default='gtFine', type=str)
parser.add_argument('-o', '--out-dir', help='output path')
parser.add_argument(
'--nproc', default=1, type=int, help='number of process')
args = parser.parse_args()
return args
def main():
args = parse_args()
cityscapes_path = args.cityscapes_path
out_dir = args.out_dir if args.out_dir else cityscapes_path
mkdir_or_exist(out_dir)
gt_dir = osp.join(cityscapes_path, args.gt_dir)
poly_files = []
for poly in scandir(gt_dir, '_polygons.json', recursive=True):
poly_file = osp.join(gt_dir, poly)
poly_files.append(poly_file)
if args.nproc > 1:
track_parallel_progress(convert_json_to_label, poly_files, args.nproc)
else:
track_progress(convert_json_to_label, poly_files)
split_names = ['train', 'val', 'test']
for split in split_names:
filenames = []
for poly in scandir(
osp.join(gt_dir, split), '_polygons.json', recursive=True):
filenames.append(poly.replace('_gtFine_polygons.json', ''))
with open(osp.join(out_dir, f'{split}.txt'), 'w') as f:
f.writelines(f + '\n' for f in filenames)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,308 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import shutil
from functools import partial
import numpy as np
from mmengine.utils import (mkdir_or_exist, track_parallel_progress,
track_progress)
from PIL import Image
from scipy.io import loadmat
COCO_LEN = 10000
clsID_to_trID = {
0: 0,
1: 1,
2: 2,
3: 3,
4: 4,
5: 5,
6: 6,
7: 7,
8: 8,
9: 9,
10: 10,
11: 11,
13: 12,
14: 13,
15: 14,
16: 15,
17: 16,
18: 17,
19: 18,
20: 19,
21: 20,
22: 21,
23: 22,
24: 23,
25: 24,
27: 25,
28: 26,
31: 27,
32: 28,
33: 29,
34: 30,
35: 31,
36: 32,
37: 33,
38: 34,
39: 35,
40: 36,
41: 37,
42: 38,
43: 39,
44: 40,
46: 41,
47: 42,
48: 43,
49: 44,
50: 45,
51: 46,
52: 47,
53: 48,
54: 49,
55: 50,
56: 51,
57: 52,
58: 53,
59: 54,
60: 55,
61: 56,
62: 57,
63: 58,
64: 59,
65: 60,
67: 61,
70: 62,
72: 63,
73: 64,
74: 65,
75: 66,
76: 67,
77: 68,
78: 69,
79: 70,
80: 71,
81: 72,
82: 73,
84: 74,
85: 75,
86: 76,
87: 77,
88: 78,
89: 79,
90: 80,
92: 81,
93: 82,
94: 83,
95: 84,
96: 85,
97: 86,
98: 87,
99: 88,
100: 89,
101: 90,
102: 91,
103: 92,
104: 93,
105: 94,
106: 95,
107: 96,
108: 97,
109: 98,
110: 99,
111: 100,
112: 101,
113: 102,
114: 103,
115: 104,
116: 105,
117: 106,
118: 107,
119: 108,
120: 109,
121: 110,
122: 111,
123: 112,
124: 113,
125: 114,
126: 115,
127: 116,
128: 117,
129: 118,
130: 119,
131: 120,
132: 121,
133: 122,
134: 123,
135: 124,
136: 125,
137: 126,
138: 127,
139: 128,
140: 129,
141: 130,
142: 131,
143: 132,
144: 133,
145: 134,
146: 135,
147: 136,
148: 137,
149: 138,
150: 139,
151: 140,
152: 141,
153: 142,
154: 143,
155: 144,
156: 145,
157: 146,
158: 147,
159: 148,
160: 149,
161: 150,
162: 151,
163: 152,
164: 153,
165: 154,
166: 155,
167: 156,
168: 157,
169: 158,
170: 159,
171: 160,
172: 161,
173: 162,
174: 163,
175: 164,
176: 165,
177: 166,
178: 167,
179: 168,
180: 169,
181: 170,
182: 171
}
def convert_to_trainID(tuple_path, in_img_dir, in_ann_dir, out_img_dir,
out_mask_dir, is_train):
imgpath, maskpath = tuple_path
shutil.copyfile(
osp.join(in_img_dir, imgpath),
osp.join(out_img_dir, 'train2014', imgpath) if is_train else osp.join(
out_img_dir, 'test2014', imgpath))
annotate = loadmat(osp.join(in_ann_dir, maskpath))
mask = annotate['S'].astype(np.uint8)
mask_copy = mask.copy()
for clsID, trID in clsID_to_trID.items():
mask_copy[mask == clsID] = trID
seg_filename = osp.join(out_mask_dir, 'train2014',
maskpath.split('.')[0] +
'_labelTrainIds.png') if is_train else osp.join(
out_mask_dir, 'test2014',
maskpath.split('.')[0] + '_labelTrainIds.png')
Image.fromarray(mask_copy).save(seg_filename, 'PNG')
def generate_coco_list(folder):
train_list = osp.join(folder, 'imageLists', 'train.txt')
test_list = osp.join(folder, 'imageLists', 'test.txt')
train_paths = []
test_paths = []
with open(train_list) as f:
for filename in f:
basename = filename.strip()
imgpath = basename + '.jpg'
maskpath = basename + '.mat'
train_paths.append((imgpath, maskpath))
with open(test_list) as f:
for filename in f:
basename = filename.strip()
imgpath = basename + '.jpg'
maskpath = basename + '.mat'
test_paths.append((imgpath, maskpath))
return train_paths, test_paths
def parse_args():
parser = argparse.ArgumentParser(
description=\
'Convert COCO Stuff 10k annotations to mmsegmentation format') # noqa
parser.add_argument('coco_path', help='coco stuff path')
parser.add_argument('-o', '--out_dir', help='output path')
parser.add_argument(
'--nproc', default=16, type=int, help='number of process')
args = parser.parse_args()
return args
def main():
args = parse_args()
coco_path = args.coco_path
nproc = args.nproc
out_dir = args.out_dir or coco_path
out_img_dir = osp.join(out_dir, 'images')
out_mask_dir = osp.join(out_dir, 'annotations')
mkdir_or_exist(osp.join(out_img_dir, 'train2014'))
mkdir_or_exist(osp.join(out_img_dir, 'test2014'))
mkdir_or_exist(osp.join(out_mask_dir, 'train2014'))
mkdir_or_exist(osp.join(out_mask_dir, 'test2014'))
train_list, test_list = generate_coco_list(coco_path)
assert (len(train_list) +
len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format(
len(train_list), len(test_list))
if args.nproc > 1:
track_parallel_progress(
partial(
convert_to_trainID,
in_img_dir=osp.join(coco_path, 'images'),
in_ann_dir=osp.join(coco_path, 'annotations'),
out_img_dir=out_img_dir,
out_mask_dir=out_mask_dir,
is_train=True),
train_list,
nproc=nproc)
track_parallel_progress(
partial(
convert_to_trainID,
in_img_dir=osp.join(coco_path, 'images'),
in_ann_dir=osp.join(coco_path, 'annotations'),
out_img_dir=out_img_dir,
out_mask_dir=out_mask_dir,
is_train=False),
test_list,
nproc=nproc)
else:
track_progress(
partial(
convert_to_trainID,
in_img_dir=osp.join(coco_path, 'images'),
in_ann_dir=osp.join(coco_path, 'annotations'),
out_img_dir=out_img_dir,
out_mask_dir=out_mask_dir,
is_train=True), train_list)
track_progress(
partial(
convert_to_trainID,
in_img_dir=osp.join(coco_path, 'images'),
in_ann_dir=osp.join(coco_path, 'annotations'),
out_img_dir=out_img_dir,
out_mask_dir=out_mask_dir,
is_train=False), test_list)
print('Done!')
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,265 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import shutil
from functools import partial
from glob import glob
import numpy as np
from mmengine.utils import (mkdir_or_exist, track_parallel_progress,
track_progress)
from PIL import Image
COCO_LEN = 123287
clsID_to_trID = {
0: 0,
1: 1,
2: 2,
3: 3,
4: 4,
5: 5,
6: 6,
7: 7,
8: 8,
9: 9,
10: 10,
12: 11,
13: 12,
14: 13,
15: 14,
16: 15,
17: 16,
18: 17,
19: 18,
20: 19,
21: 20,
22: 21,
23: 22,
24: 23,
26: 24,
27: 25,
30: 26,
31: 27,
32: 28,
33: 29,
34: 30,
35: 31,
36: 32,
37: 33,
38: 34,
39: 35,
40: 36,
41: 37,
42: 38,
43: 39,
45: 40,
46: 41,
47: 42,
48: 43,
49: 44,
50: 45,
51: 46,
52: 47,
53: 48,
54: 49,
55: 50,
56: 51,
57: 52,
58: 53,
59: 54,
60: 55,
61: 56,
62: 57,
63: 58,
64: 59,
66: 60,
69: 61,
71: 62,
72: 63,
73: 64,
74: 65,
75: 66,
76: 67,
77: 68,
78: 69,
79: 70,
80: 71,
81: 72,
83: 73,
84: 74,
85: 75,
86: 76,
87: 77,
88: 78,
89: 79,
91: 80,
92: 81,
93: 82,
94: 83,
95: 84,
96: 85,
97: 86,
98: 87,
99: 88,
100: 89,
101: 90,
102: 91,
103: 92,
104: 93,
105: 94,
106: 95,
107: 96,
108: 97,
109: 98,
110: 99,
111: 100,
112: 101,
113: 102,
114: 103,
115: 104,
116: 105,
117: 106,
118: 107,
119: 108,
120: 109,
121: 110,
122: 111,
123: 112,
124: 113,
125: 114,
126: 115,
127: 116,
128: 117,
129: 118,
130: 119,
131: 120,
132: 121,
133: 122,
134: 123,
135: 124,
136: 125,
137: 126,
138: 127,
139: 128,
140: 129,
141: 130,
142: 131,
143: 132,
144: 133,
145: 134,
146: 135,
147: 136,
148: 137,
149: 138,
150: 139,
151: 140,
152: 141,
153: 142,
154: 143,
155: 144,
156: 145,
157: 146,
158: 147,
159: 148,
160: 149,
161: 150,
162: 151,
163: 152,
164: 153,
165: 154,
166: 155,
167: 156,
168: 157,
169: 158,
170: 159,
171: 160,
172: 161,
173: 162,
174: 163,
175: 164,
176: 165,
177: 166,
178: 167,
179: 168,
180: 169,
181: 170,
255: 255
}
def convert_to_trainID(maskpath, out_mask_dir, is_train):
mask = np.array(Image.open(maskpath))
mask_copy = mask.copy()
for clsID, trID in clsID_to_trID.items():
mask_copy[mask == clsID] = trID
seg_filename = osp.join(
out_mask_dir, 'train2017',
osp.basename(maskpath).split('.')[0] +
'_labelTrainIds.png') if is_train else osp.join(
out_mask_dir, 'val2017',
osp.basename(maskpath).split('.')[0] + '_labelTrainIds.png')
Image.fromarray(mask_copy).save(seg_filename, 'PNG')
def parse_args():
parser = argparse.ArgumentParser(
description=\
'Convert COCO Stuff 164k annotations to mmsegmentation format') # noqa
parser.add_argument('coco_path', help='coco stuff path')
parser.add_argument('-o', '--out_dir', help='output path')
parser.add_argument(
'--nproc', default=16, type=int, help='number of process')
args = parser.parse_args()
return args
def main():
args = parse_args()
coco_path = args.coco_path
nproc = args.nproc
out_dir = args.out_dir or coco_path
out_img_dir = osp.join(out_dir, 'images')
out_mask_dir = osp.join(out_dir, 'annotations')
mkdir_or_exist(osp.join(out_mask_dir, 'train2017'))
mkdir_or_exist(osp.join(out_mask_dir, 'val2017'))
if out_dir != coco_path:
shutil.copytree(osp.join(coco_path, 'images'), out_img_dir)
train_list = glob(osp.join(coco_path, 'annotations', 'train2017', '*.png'))
train_list = [file for file in train_list if '_labelTrainIds' not in file]
test_list = glob(osp.join(coco_path, 'annotations', 'val2017', '*.png'))
test_list = [file for file in test_list if '_labelTrainIds' not in file]
assert (len(train_list) +
len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format(
len(train_list), len(test_list))
if args.nproc > 1:
track_parallel_progress(
partial(
convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True),
train_list,
nproc=nproc)
track_parallel_progress(
partial(
convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False),
test_list,
nproc=nproc)
else:
track_progress(
partial(
convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True),
train_list)
track_progress(
partial(
convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False),
test_list)
print('Done!')
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,114 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import tempfile
import zipfile
import cv2
import mmcv
from mmengine.utils import mkdir_or_exist
def parse_args():
parser = argparse.ArgumentParser(
description='Convert DRIVE dataset to mmsegmentation format')
parser.add_argument(
'training_path', help='the training part of DRIVE dataset')
parser.add_argument(
'testing_path', help='the testing part of DRIVE dataset')
parser.add_argument('--tmp_dir', help='path of the temporary directory')
parser.add_argument('-o', '--out_dir', help='output path')
args = parser.parse_args()
return args
def main():
args = parse_args()
training_path = args.training_path
testing_path = args.testing_path
if args.out_dir is None:
out_dir = osp.join('data', 'DRIVE')
else:
out_dir = args.out_dir
print('Making directories...')
mkdir_or_exist(out_dir)
mkdir_or_exist(osp.join(out_dir, 'images'))
mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
mkdir_or_exist(osp.join(out_dir, 'annotations'))
mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))
with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
print('Extracting training.zip...')
zip_file = zipfile.ZipFile(training_path)
zip_file.extractall(tmp_dir)
print('Generating training dataset...')
now_dir = osp.join(tmp_dir, 'training', 'images')
for img_name in os.listdir(now_dir):
img = mmcv.imread(osp.join(now_dir, img_name))
mmcv.imwrite(
img,
osp.join(
out_dir, 'images', 'training',
osp.splitext(img_name)[0].replace('_training', '') +
'.png'))
now_dir = osp.join(tmp_dir, 'training', '1st_manual')
for img_name in os.listdir(now_dir):
cap = cv2.VideoCapture(osp.join(now_dir, img_name))
ret, img = cap.read()
mmcv.imwrite(
img[:, :, 0] // 128,
osp.join(out_dir, 'annotations', 'training',
osp.splitext(img_name)[0] + '.png'))
print('Extracting test.zip...')
zip_file = zipfile.ZipFile(testing_path)
zip_file.extractall(tmp_dir)
print('Generating validation dataset...')
now_dir = osp.join(tmp_dir, 'test', 'images')
for img_name in os.listdir(now_dir):
img = mmcv.imread(osp.join(now_dir, img_name))
mmcv.imwrite(
img,
osp.join(
out_dir, 'images', 'validation',
osp.splitext(img_name)[0].replace('_test', '') + '.png'))
now_dir = osp.join(tmp_dir, 'test', '1st_manual')
if osp.exists(now_dir):
for img_name in os.listdir(now_dir):
cap = cv2.VideoCapture(osp.join(now_dir, img_name))
ret, img = cap.read()
# The annotation img should be divided by 128, because some of
# the annotation imgs are not standard. We should set a
# threshold to convert the nonstandard annotation imgs. The
# value divided by 128 is equivalent to '1 if value >= 128
# else 0'
mmcv.imwrite(
img[:, :, 0] // 128,
osp.join(out_dir, 'annotations', 'validation',
osp.splitext(img_name)[0] + '.png'))
now_dir = osp.join(tmp_dir, 'test', '2nd_manual')
if osp.exists(now_dir):
for img_name in os.listdir(now_dir):
cap = cv2.VideoCapture(osp.join(now_dir, img_name))
ret, img = cap.read()
mmcv.imwrite(
img[:, :, 0] // 128,
osp.join(out_dir, 'annotations', 'validation',
osp.splitext(img_name)[0] + '.png'))
print('Removing the temporary files...')
print('Done!')
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,112 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import tempfile
import zipfile
import mmcv
from mmengine.utils import mkdir_or_exist
HRF_LEN = 15
TRAINING_LEN = 5
def parse_args():
parser = argparse.ArgumentParser(
description='Convert HRF dataset to mmsegmentation format')
parser.add_argument('healthy_path', help='the path of healthy.zip')
parser.add_argument(
'healthy_manualsegm_path', help='the path of healthy_manualsegm.zip')
parser.add_argument('glaucoma_path', help='the path of glaucoma.zip')
parser.add_argument(
'glaucoma_manualsegm_path', help='the path of glaucoma_manualsegm.zip')
parser.add_argument(
'diabetic_retinopathy_path',
help='the path of diabetic_retinopathy.zip')
parser.add_argument(
'diabetic_retinopathy_manualsegm_path',
help='the path of diabetic_retinopathy_manualsegm.zip')
parser.add_argument('--tmp_dir', help='path of the temporary directory')
parser.add_argument('-o', '--out_dir', help='output path')
args = parser.parse_args()
return args
def main():
args = parse_args()
images_path = [
args.healthy_path, args.glaucoma_path, args.diabetic_retinopathy_path
]
annotations_path = [
args.healthy_manualsegm_path, args.glaucoma_manualsegm_path,
args.diabetic_retinopathy_manualsegm_path
]
if args.out_dir is None:
out_dir = osp.join('data', 'HRF')
else:
out_dir = args.out_dir
print('Making directories...')
mkdir_or_exist(out_dir)
mkdir_or_exist(osp.join(out_dir, 'images'))
mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
mkdir_or_exist(osp.join(out_dir, 'annotations'))
mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))
print('Generating images...')
for now_path in images_path:
with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
zip_file = zipfile.ZipFile(now_path)
zip_file.extractall(tmp_dir)
assert len(os.listdir(tmp_dir)) == HRF_LEN, \
f'len(os.listdir(tmp_dir)) != {HRF_LEN}'
for filename in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]:
img = mmcv.imread(osp.join(tmp_dir, filename))
mmcv.imwrite(
img,
osp.join(out_dir, 'images', 'training',
osp.splitext(filename)[0] + '.png'))
for filename in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]:
img = mmcv.imread(osp.join(tmp_dir, filename))
mmcv.imwrite(
img,
osp.join(out_dir, 'images', 'validation',
osp.splitext(filename)[0] + '.png'))
print('Generating annotations...')
for now_path in annotations_path:
with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
zip_file = zipfile.ZipFile(now_path)
zip_file.extractall(tmp_dir)
assert len(os.listdir(tmp_dir)) == HRF_LEN, \
f'len(os.listdir(tmp_dir)) != {HRF_LEN}'
for filename in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]:
img = mmcv.imread(osp.join(tmp_dir, filename))
# The annotation img should be divided by 128, because some of
# the annotation imgs are not standard. We should set a
# threshold to convert the nonstandard annotation imgs. The
# value divided by 128 is equivalent to '1 if value >= 128
# else 0'
mmcv.imwrite(
img[:, :, 0] // 128,
osp.join(out_dir, 'annotations', 'training',
osp.splitext(filename)[0] + '.png'))
for filename in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]:
img = mmcv.imread(osp.join(tmp_dir, filename))
mmcv.imwrite(
img[:, :, 0] // 128,
osp.join(out_dir, 'annotations', 'validation',
osp.splitext(filename)[0] + '.png'))
print('Done!')
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,246 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import glob
import os
import os.path as osp
import shutil
import tempfile
import zipfile
import mmcv
import numpy as np
from mmengine.utils import ProgressBar, mkdir_or_exist
from PIL import Image
iSAID_palette = \
{
0: (0, 0, 0),
1: (0, 0, 63),
2: (0, 63, 63),
3: (0, 63, 0),
4: (0, 63, 127),
5: (0, 63, 191),
6: (0, 63, 255),
7: (0, 127, 63),
8: (0, 127, 127),
9: (0, 0, 127),
10: (0, 0, 191),
11: (0, 0, 255),
12: (0, 191, 127),
13: (0, 127, 191),
14: (0, 127, 255),
15: (0, 100, 155)
}
iSAID_invert_palette = {v: k for k, v in iSAID_palette.items()}
def iSAID_convert_from_color(arr_3d, palette=iSAID_invert_palette):
"""RGB-color encoding to grayscale labels."""
arr_2d = np.zeros((arr_3d.shape[0], arr_3d.shape[1]), dtype=np.uint8)
for c, i in palette.items():
m = np.all(arr_3d == np.array(c).reshape(1, 1, 3), axis=2)
arr_2d[m] = i
return arr_2d
def slide_crop_image(src_path, out_dir, mode, patch_H, patch_W, overlap):
img = np.asarray(Image.open(src_path).convert('RGB'))
img_H, img_W, _ = img.shape
if img_H < patch_H and img_W > patch_W:
img = mmcv.impad(img, shape=(patch_H, img_W), pad_val=0)
img_H, img_W, _ = img.shape
elif img_H > patch_H and img_W < patch_W:
img = mmcv.impad(img, shape=(img_H, patch_W), pad_val=0)
img_H, img_W, _ = img.shape
elif img_H < patch_H and img_W < patch_W:
img = mmcv.impad(img, shape=(patch_H, patch_W), pad_val=0)
img_H, img_W, _ = img.shape
for x in range(0, img_W, patch_W - overlap):
for y in range(0, img_H, patch_H - overlap):
x_str = x
x_end = x + patch_W
if x_end > img_W:
diff_x = x_end - img_W
x_str -= diff_x
x_end = img_W
y_str = y
y_end = y + patch_H
if y_end > img_H:
diff_y = y_end - img_H
y_str -= diff_y
y_end = img_H
img_patch = img[y_str:y_end, x_str:x_end, :]
img_patch = Image.fromarray(img_patch.astype(np.uint8))
image = osp.basename(src_path).split('.')[0] + '_' + str(
y_str) + '_' + str(y_end) + '_' + str(x_str) + '_' + str(
x_end) + '.png'
# print(image)
save_path_image = osp.join(out_dir, 'img_dir', mode, str(image))
img_patch.save(save_path_image, format='BMP')
def slide_crop_label(src_path, out_dir, mode, patch_H, patch_W, overlap):
label = mmcv.imread(src_path, channel_order='rgb')
label = iSAID_convert_from_color(label)
img_H, img_W = label.shape
if img_H < patch_H and img_W > patch_W:
label = mmcv.impad(label, shape=(patch_H, img_W), pad_val=255)
img_H = patch_H
elif img_H > patch_H and img_W < patch_W:
label = mmcv.impad(label, shape=(img_H, patch_W), pad_val=255)
img_W = patch_W
elif img_H < patch_H and img_W < patch_W:
label = mmcv.impad(label, shape=(patch_H, patch_W), pad_val=255)
img_H = patch_H
img_W = patch_W
for x in range(0, img_W, patch_W - overlap):
for y in range(0, img_H, patch_H - overlap):
x_str = x
x_end = x + patch_W
if x_end > img_W:
diff_x = x_end - img_W
x_str -= diff_x
x_end = img_W
y_str = y
y_end = y + patch_H
if y_end > img_H:
diff_y = y_end - img_H
y_str -= diff_y
y_end = img_H
lab_patch = label[y_str:y_end, x_str:x_end]
lab_patch = Image.fromarray(lab_patch.astype(np.uint8), mode='P')
image = osp.basename(src_path).split('.')[0].split(
'_')[0] + '_' + str(y_str) + '_' + str(y_end) + '_' + str(
x_str) + '_' + str(x_end) + '_instance_color_RGB' + '.png'
lab_patch.save(osp.join(out_dir, 'ann_dir', mode, str(image)))
def parse_args():
parser = argparse.ArgumentParser(
description='Convert iSAID dataset to mmsegmentation format')
parser.add_argument('dataset_path', help='iSAID folder path')
parser.add_argument('--tmp_dir', help='path of the temporary directory')
parser.add_argument('-o', '--out_dir', help='output path')
parser.add_argument(
'--patch_width',
default=896,
type=int,
help='Width of the cropped image patch')
parser.add_argument(
'--patch_height',
default=896,
type=int,
help='Height of the cropped image patch')
parser.add_argument(
'--overlap_area', default=384, type=int, help='Overlap area')
args = parser.parse_args()
return args
def main():
args = parse_args()
dataset_path = args.dataset_path
# image patch width and height
patch_H, patch_W = args.patch_width, args.patch_height
overlap = args.overlap_area # overlap area
if args.out_dir is None:
out_dir = osp.join('data', 'iSAID')
else:
out_dir = args.out_dir
print('Making directories...')
mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
mkdir_or_exist(osp.join(out_dir, 'img_dir', 'test'))
mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))
mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'test'))
assert os.path.exists(os.path.join(dataset_path, 'train')), \
f'train is not in {dataset_path}'
assert os.path.exists(os.path.join(dataset_path, 'val')), \
f'val is not in {dataset_path}'
assert os.path.exists(os.path.join(dataset_path, 'test')), \
f'test is not in {dataset_path}'
with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
for dataset_mode in ['train', 'val', 'test']:
# for dataset_mode in [ 'test']:
print(f'Extracting {dataset_mode}ing.zip...')
img_zipp_list = glob.glob(
os.path.join(dataset_path, dataset_mode, 'images', '*.zip'))
print('Find the data', img_zipp_list)
for img_zipp in img_zipp_list:
zip_file = zipfile.ZipFile(img_zipp)
zip_file.extractall(os.path.join(tmp_dir, dataset_mode, 'img'))
src_path_list = glob.glob(
os.path.join(tmp_dir, dataset_mode, 'img', 'images', '*.png'))
src_prog_bar = ProgressBar(len(src_path_list))
for i, img_path in enumerate(src_path_list):
if dataset_mode != 'test':
slide_crop_image(img_path, out_dir, dataset_mode, patch_H,
patch_W, overlap)
else:
shutil.move(img_path,
os.path.join(out_dir, 'img_dir', dataset_mode))
src_prog_bar.update()
if dataset_mode != 'test':
label_zipp_list = glob.glob(
os.path.join(dataset_path, dataset_mode, 'Semantic_masks',
'*.zip'))
for label_zipp in label_zipp_list:
zip_file = zipfile.ZipFile(label_zipp)
zip_file.extractall(
os.path.join(tmp_dir, dataset_mode, 'lab'))
lab_path_list = glob.glob(
os.path.join(tmp_dir, dataset_mode, 'lab', 'images',
'*.png'))
lab_prog_bar = ProgressBar(len(lab_path_list))
for i, lab_path in enumerate(lab_path_list):
slide_crop_label(lab_path, out_dir, dataset_mode, patch_H,
patch_W, overlap)
lab_prog_bar.update()
print('Removing the temporary files...')
print('Done!')
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,99 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import glob
import math
import os
import os.path as osp
import mmcv
import numpy as np
from mmengine.utils import ProgressBar
def parse_args():
parser = argparse.ArgumentParser(
description='Convert levir-cd dataset to mmsegmentation format')
parser.add_argument('--dataset_path', help='potsdam folder path')
parser.add_argument('-o', '--out_dir', help='output path')
parser.add_argument(
'--clip_size',
type=int,
help='clipped size of image after preparation',
default=256)
parser.add_argument(
'--stride_size',
type=int,
help='stride of clipping original images',
default=256)
args = parser.parse_args()
return args
def main():
args = parse_args()
input_folder = args.dataset_path
png_files = glob.glob(
os.path.join(input_folder, '**/*.png'), recursive=True)
output_folder = args.out_dir
prog_bar = ProgressBar(len(png_files))
for png_file in png_files:
new_path = os.path.join(
output_folder,
os.path.relpath(os.path.dirname(png_file), input_folder))
os.makedirs(os.path.dirname(new_path), exist_ok=True)
label = False
if 'label' in png_file:
label = True
clip_big_image(png_file, new_path, args, label)
prog_bar.update()
def clip_big_image(image_path, clip_save_dir, args, to_label=False):
image = mmcv.imread(image_path)
h, w, c = image.shape
clip_size = args.clip_size
stride_size = args.stride_size
num_rows = math.ceil((h - clip_size) / stride_size) if math.ceil(
(h - clip_size) /
stride_size) * stride_size + clip_size >= h else math.ceil(
(h - clip_size) / stride_size) + 1
num_cols = math.ceil((w - clip_size) / stride_size) if math.ceil(
(w - clip_size) /
stride_size) * stride_size + clip_size >= w else math.ceil(
(w - clip_size) / stride_size) + 1
x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1))
xmin = x * clip_size
ymin = y * clip_size
xmin = xmin.ravel()
ymin = ymin.ravel()
xmin_offset = np.where(xmin + clip_size > w, w - xmin - clip_size,
np.zeros_like(xmin))
ymin_offset = np.where(ymin + clip_size > h, h - ymin - clip_size,
np.zeros_like(ymin))
boxes = np.stack([
xmin + xmin_offset, ymin + ymin_offset,
np.minimum(xmin + clip_size, w),
np.minimum(ymin + clip_size, h)
],
axis=1)
if to_label:
image[image == 255] = 1
image = image[:, :, 0]
for box in boxes:
start_x, start_y, end_x, end_y = box
clipped_image = image[start_y:end_y, start_x:end_x] \
if to_label else image[start_y:end_y, start_x:end_x, :]
idx = osp.basename(image_path).split('.')[0]
mmcv.imwrite(
clipped_image.astype(np.uint8),
osp.join(clip_save_dir,
f'{idx}_{start_x}_{start_y}_{end_x}_{end_y}.png'))
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,73 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import shutil
import tempfile
import zipfile
from mmengine.utils import mkdir_or_exist
def parse_args():
parser = argparse.ArgumentParser(
description='Convert LoveDA dataset to mmsegmentation format')
parser.add_argument('dataset_path', help='LoveDA folder path')
parser.add_argument('--tmp_dir', help='path of the temporary directory')
parser.add_argument('-o', '--out_dir', help='output path')
args = parser.parse_args()
return args
def main():
args = parse_args()
dataset_path = args.dataset_path
if args.out_dir is None:
out_dir = osp.join('data', 'loveDA')
else:
out_dir = args.out_dir
print('Making directories...')
mkdir_or_exist(out_dir)
mkdir_or_exist(osp.join(out_dir, 'img_dir'))
mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
mkdir_or_exist(osp.join(out_dir, 'img_dir', 'test'))
mkdir_or_exist(osp.join(out_dir, 'ann_dir'))
mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))
assert 'Train.zip' in os.listdir(dataset_path), \
f'Train.zip is not in {dataset_path}'
assert 'Val.zip' in os.listdir(dataset_path), \
f'Val.zip is not in {dataset_path}'
assert 'Test.zip' in os.listdir(dataset_path), \
f'Test.zip is not in {dataset_path}'
with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
for dataset in ['Train', 'Val', 'Test']:
zip_file = zipfile.ZipFile(
os.path.join(dataset_path, dataset + '.zip'))
zip_file.extractall(tmp_dir)
data_type = dataset.lower()
for location in ['Rural', 'Urban']:
for image_type in ['images_png', 'masks_png']:
if image_type == 'images_png':
dst = osp.join(out_dir, 'img_dir', data_type)
else:
dst = osp.join(out_dir, 'ann_dir', data_type)
if dataset == 'Test' and image_type == 'masks_png':
continue
else:
src_dir = osp.join(tmp_dir, dataset, location,
image_type)
src_lst = os.listdir(src_dir)
for file in src_lst:
shutil.move(osp.join(src_dir, file), dst)
print('Removing the temporary files...')
print('Done!')
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,89 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import shutil
import tempfile
import zipfile
from mmengine.utils import mkdir_or_exist
def parse_args():
parser = argparse.ArgumentParser(
description='Convert NYU Depth dataset to mmsegmentation format')
parser.add_argument('raw_data', help='the path of raw data')
parser.add_argument(
'-o', '--out_dir', help='output path', default='./data/nyu')
args = parser.parse_args()
return args
def reorganize(raw_data_dir: str, out_dir: str):
"""Reorganize NYU Depth dataset files into the required directory
structure.
Args:
raw_data_dir (str): Path to the raw data directory.
out_dir (str): Output directory for the organized dataset.
"""
def move_data(data_list, dst_prefix, fname_func):
"""Move data files from source to destination directory.
Args:
data_list (list): List of data file paths.
dst_prefix (str): Prefix to be added to destination paths.
fname_func (callable): Function to process file names
"""
for data_item in data_list:
data_item = data_item.strip().strip('/')
new_item = fname_func(data_item)
shutil.move(
osp.join(raw_data_dir, data_item),
osp.join(out_dir, dst_prefix, new_item))
def process_phase(phase):
"""Process a dataset phase (e.g., 'train' or 'test')."""
with open(osp.join(raw_data_dir, f'nyu_{phase}.txt')) as f:
data = filter(lambda x: len(x.strip()) > 0, f.readlines())
data = map(lambda x: x.split()[:2], data)
images, annos = zip(*data)
move_data(images, f'images/{phase}',
lambda x: x.replace('/rgb', ''))
move_data(annos, f'annotations/{phase}',
lambda x: x.replace('/sync_depth', ''))
process_phase('train')
process_phase('test')
def main():
args = parse_args()
print('Making directories...')
mkdir_or_exist(args.out_dir)
for subdir in [
'images/train', 'images/test', 'annotations/train',
'annotations/test'
]:
mkdir_or_exist(osp.join(args.out_dir, subdir))
print('Generating images and annotations...')
if args.raw_data.endswith('.zip'):
with tempfile.TemporaryDirectory() as tmp_dir:
zip_file = zipfile.ZipFile(args.raw_data)
zip_file.extractall(tmp_dir)
reorganize(osp.join(tmp_dir, 'nyu'), args.out_dir)
else:
assert osp.isdir(
args.raw_data
), 'the argument --raw-data should be either a zip file or directory.'
reorganize(args.raw_data, args.out_dir)
print('Done!')
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,87 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from functools import partial
import numpy as np
from detail import Detail
from mmengine.utils import mkdir_or_exist, track_progress
from PIL import Image
_mapping = np.sort(
np.array([
0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 19, 22, 23, 397, 25, 284,
158, 159, 416, 33, 162, 420, 454, 295, 296, 427, 44, 45, 46, 308, 59,
440, 445, 31, 232, 65, 354, 424, 68, 326, 72, 458, 34, 207, 80, 355,
85, 347, 220, 349, 360, 98, 187, 104, 105, 366, 189, 368, 113, 115
]))
_key = np.array(range(len(_mapping))).astype('uint8')
def generate_labels(img_id, detail, out_dir):
def _class_to_index(mask, _mapping, _key):
# assert the values
values = np.unique(mask)
for i in range(len(values)):
assert (values[i] in _mapping)
index = np.digitize(mask.ravel(), _mapping, right=True)
return _key[index].reshape(mask.shape)
mask = Image.fromarray(
_class_to_index(detail.getMask(img_id), _mapping=_mapping, _key=_key))
filename = img_id['file_name']
mask.save(osp.join(out_dir, filename.replace('jpg', 'png')))
return osp.splitext(osp.basename(filename))[0]
def parse_args():
parser = argparse.ArgumentParser(
description='Convert PASCAL VOC annotations to mmsegmentation format')
parser.add_argument('devkit_path', help='pascal voc devkit path')
parser.add_argument('json_path', help='annoation json filepath')
parser.add_argument('-o', '--out_dir', help='output path')
args = parser.parse_args()
return args
def main():
args = parse_args()
devkit_path = args.devkit_path
if args.out_dir is None:
out_dir = osp.join(devkit_path, 'VOC2010', 'SegmentationClassContext')
else:
out_dir = args.out_dir
json_path = args.json_path
mkdir_or_exist(out_dir)
img_dir = osp.join(devkit_path, 'VOC2010', 'JPEGImages')
train_detail = Detail(json_path, img_dir, 'train')
train_ids = train_detail.getImgs()
val_detail = Detail(json_path, img_dir, 'val')
val_ids = val_detail.getImgs()
mkdir_or_exist(
osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext'))
train_list = track_progress(
partial(generate_labels, detail=train_detail, out_dir=out_dir),
train_ids)
with open(
osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext',
'train.txt'), 'w') as f:
f.writelines(line + '\n' for line in sorted(train_list))
val_list = track_progress(
partial(generate_labels, detail=val_detail, out_dir=out_dir), val_ids)
with open(
osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext',
'val.txt'), 'w') as f:
f.writelines(line + '\n' for line in sorted(val_list))
print('Done!')
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,158 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import glob
import math
import os
import os.path as osp
import tempfile
import zipfile
import mmcv
import numpy as np
from mmengine.utils import ProgressBar, mkdir_or_exist
def parse_args():
parser = argparse.ArgumentParser(
description='Convert potsdam dataset to mmsegmentation format')
parser.add_argument('dataset_path', help='potsdam folder path')
parser.add_argument('--tmp_dir', help='path of the temporary directory')
parser.add_argument('-o', '--out_dir', help='output path')
parser.add_argument(
'--clip_size',
type=int,
help='clipped size of image after preparation',
default=512)
parser.add_argument(
'--stride_size',
type=int,
help='stride of clipping original images',
default=256)
args = parser.parse_args()
return args
def clip_big_image(image_path, clip_save_dir, args, to_label=False):
# Original image of Potsdam dataset is very large, thus pre-processing
# of them is adopted. Given fixed clip size and stride size to generate
# clipped image, the intersection of width and height is determined.
# For example, given one 5120 x 5120 original image, the clip size is
# 512 and stride size is 256, thus it would generate 20x20 = 400 images
# whose size are all 512x512.
image = mmcv.imread(image_path)
h, w, c = image.shape
clip_size = args.clip_size
stride_size = args.stride_size
num_rows = math.ceil((h - clip_size) / stride_size) if math.ceil(
(h - clip_size) /
stride_size) * stride_size + clip_size >= h else math.ceil(
(h - clip_size) / stride_size) + 1
num_cols = math.ceil((w - clip_size) / stride_size) if math.ceil(
(w - clip_size) /
stride_size) * stride_size + clip_size >= w else math.ceil(
(w - clip_size) / stride_size) + 1
x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1))
xmin = x * clip_size
ymin = y * clip_size
xmin = xmin.ravel()
ymin = ymin.ravel()
xmin_offset = np.where(xmin + clip_size > w, w - xmin - clip_size,
np.zeros_like(xmin))
ymin_offset = np.where(ymin + clip_size > h, h - ymin - clip_size,
np.zeros_like(ymin))
boxes = np.stack([
xmin + xmin_offset, ymin + ymin_offset,
np.minimum(xmin + clip_size, w),
np.minimum(ymin + clip_size, h)
],
axis=1)
if to_label:
color_map = np.array([[0, 0, 0], [255, 255, 255], [255, 0, 0],
[255, 255, 0], [0, 255, 0], [0, 255, 255],
[0, 0, 255]])
flatten_v = np.matmul(
image.reshape(-1, c),
np.array([2, 3, 4]).reshape(3, 1))
out = np.zeros_like(flatten_v)
for idx, class_color in enumerate(color_map):
value_idx = np.matmul(class_color,
np.array([2, 3, 4]).reshape(3, 1))
out[flatten_v == value_idx] = idx
image = out.reshape(h, w)
for box in boxes:
start_x, start_y, end_x, end_y = box
clipped_image = image[start_y:end_y,
start_x:end_x] if to_label else image[
start_y:end_y, start_x:end_x, :]
idx_i, idx_j = osp.basename(image_path).split('_')[2:4]
mmcv.imwrite(
clipped_image.astype(np.uint8),
osp.join(
clip_save_dir,
f'{idx_i}_{idx_j}_{start_x}_{start_y}_{end_x}_{end_y}.png'))
def main():
args = parse_args()
splits = {
'train': [
'2_10', '2_11', '2_12', '3_10', '3_11', '3_12', '4_10', '4_11',
'4_12', '5_10', '5_11', '5_12', '6_10', '6_11', '6_12', '6_7',
'6_8', '6_9', '7_10', '7_11', '7_12', '7_7', '7_8', '7_9'
],
'val': [
'5_15', '6_15', '6_13', '3_13', '4_14', '6_14', '5_14', '2_13',
'4_15', '2_14', '5_13', '4_13', '3_14', '7_13'
]
}
dataset_path = args.dataset_path
if args.out_dir is None:
out_dir = osp.join('data', 'potsdam')
else:
out_dir = args.out_dir
print('Making directories...')
mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))
zipp_list = glob.glob(os.path.join(dataset_path, '*.zip'))
print('Find the data', zipp_list)
for zipp in zipp_list:
with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
zip_file = zipfile.ZipFile(zipp)
zip_file.extractall(tmp_dir)
src_path_list = glob.glob(os.path.join(tmp_dir, '*.tif'))
if not len(src_path_list):
sub_tmp_dir = os.path.join(tmp_dir, os.listdir(tmp_dir)[0])
src_path_list = glob.glob(os.path.join(sub_tmp_dir, '*.tif'))
prog_bar = ProgressBar(len(src_path_list))
for i, src_path in enumerate(src_path_list):
idx_i, idx_j = osp.basename(src_path).split('_')[2:4]
data_type = 'train' if f'{idx_i}_{idx_j}' in splits[
'train'] else 'val'
if 'label' in src_path:
dst_dir = osp.join(out_dir, 'ann_dir', data_type)
clip_big_image(src_path, dst_dir, args, to_label=True)
else:
dst_dir = osp.join(out_dir, 'img_dir', data_type)
clip_big_image(src_path, dst_dir, args, to_label=False)
prog_bar.update()
print('Removing the temporary files...')
print('Done!')
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,110 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import tempfile
import zipfile
import mmcv
import numpy as np
from mmengine.utils import mkdir_or_exist
def parse_args():
parser = argparse.ArgumentParser(
description='Convert REFUGE dataset to mmsegmentation format')
parser.add_argument('--raw_data_root', help='the root path of raw data')
parser.add_argument('--tmp_dir', help='path of the temporary directory')
parser.add_argument('-o', '--out_dir', help='output path')
args = parser.parse_args()
return args
def extract_img(root: str,
cur_dir: str,
out_dir: str,
mode: str = 'train',
file_type: str = 'img') -> None:
"""_summary_
Args:
Args:
root (str): root where the extracted data is saved
cur_dir (cur_dir): dir where the zip_file exists
out_dir (str): root dir where the data is saved
mode (str, optional): Defaults to 'train'.
file_type (str, optional): Defaults to 'img',else to 'mask'.
"""
zip_file = zipfile.ZipFile(cur_dir)
zip_file.extractall(root)
for cur_dir, dirs, files in os.walk(root):
# filter child dirs and directories with "Illustration" and "MACOSX"
if len(dirs) == 0 and \
cur_dir.split('\\')[-1].find('Illustration') == -1 and \
cur_dir.find('MACOSX') == -1:
file_names = [
file for file in files
if file.endswith('.jpg') or file.endswith('.bmp')
]
for filename in sorted(file_names):
img = mmcv.imread(osp.join(cur_dir, filename))
if file_type == 'annotations':
img = img[:, :, 0]
img[np.where(img == 0)] = 1
img[np.where(img == 128)] = 2
img[np.where(img == 255)] = 0
mmcv.imwrite(
img,
osp.join(out_dir, file_type, mode,
osp.splitext(filename)[0] + '.png'))
def main():
args = parse_args()
raw_data_root = args.raw_data_root
if args.out_dir is None:
out_dir = osp.join('./data', 'REFUGE')
else:
out_dir = args.out_dir
print('Making directories...')
mkdir_or_exist(out_dir)
mkdir_or_exist(osp.join(out_dir, 'images'))
mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
mkdir_or_exist(osp.join(out_dir, 'images', 'test'))
mkdir_or_exist(osp.join(out_dir, 'annotations'))
mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))
mkdir_or_exist(osp.join(out_dir, 'annotations', 'test'))
print('Generating images and annotations...')
# process data from the child dir on the first rank
cur_dir, dirs, files = list(os.walk(raw_data_root))[0]
print('====================')
files = list(filter(lambda x: x.endswith('.zip'), files))
with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
for file in files:
# search data folders for training,validation,test
mode = list(
filter(lambda x: file.lower().find(x) != -1,
['training', 'test', 'validation']))[0]
file_root = osp.join(tmp_dir, file[:-4])
file_type = 'images' if file.find('Anno') == -1 and file.find(
'GT') == -1 else 'annotations'
extract_img(file_root, osp.join(cur_dir, file), out_dir, mode,
file_type)
print('Done!')
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,167 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import gzip
import os
import os.path as osp
import tarfile
import tempfile
import mmcv
from mmengine.utils import mkdir_or_exist
STARE_LEN = 20
TRAINING_LEN = 10
def un_gz(src, dst):
g_file = gzip.GzipFile(src)
with open(dst, 'wb+') as f:
f.write(g_file.read())
g_file.close()
def parse_args():
parser = argparse.ArgumentParser(
description='Convert STARE dataset to mmsegmentation format')
parser.add_argument('image_path', help='the path of stare-images.tar')
parser.add_argument('labels_ah', help='the path of labels-ah.tar')
parser.add_argument('labels_vk', help='the path of labels-vk.tar')
parser.add_argument('--tmp_dir', help='path of the temporary directory')
parser.add_argument('-o', '--out_dir', help='output path')
args = parser.parse_args()
return args
def main():
args = parse_args()
image_path = args.image_path
labels_ah = args.labels_ah
labels_vk = args.labels_vk
if args.out_dir is None:
out_dir = osp.join('data', 'STARE')
else:
out_dir = args.out_dir
print('Making directories...')
mkdir_or_exist(out_dir)
mkdir_or_exist(osp.join(out_dir, 'images'))
mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
mkdir_or_exist(osp.join(out_dir, 'annotations'))
mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))
with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
mkdir_or_exist(osp.join(tmp_dir, 'gz'))
mkdir_or_exist(osp.join(tmp_dir, 'files'))
print('Extracting stare-images.tar...')
with tarfile.open(image_path) as f:
f.extractall(osp.join(tmp_dir, 'gz'))
for filename in os.listdir(osp.join(tmp_dir, 'gz')):
un_gz(
osp.join(tmp_dir, 'gz', filename),
osp.join(tmp_dir, 'files',
osp.splitext(filename)[0]))
now_dir = osp.join(tmp_dir, 'files')
assert len(os.listdir(now_dir)) == STARE_LEN, \
f'len(os.listdir(now_dir)) != {STARE_LEN}'
for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]:
img = mmcv.imread(osp.join(now_dir, filename))
mmcv.imwrite(
img,
osp.join(out_dir, 'images', 'training',
osp.splitext(filename)[0] + '.png'))
for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]:
img = mmcv.imread(osp.join(now_dir, filename))
mmcv.imwrite(
img,
osp.join(out_dir, 'images', 'validation',
osp.splitext(filename)[0] + '.png'))
print('Removing the temporary files...')
with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
mkdir_or_exist(osp.join(tmp_dir, 'gz'))
mkdir_or_exist(osp.join(tmp_dir, 'files'))
print('Extracting labels-ah.tar...')
with tarfile.open(labels_ah) as f:
f.extractall(osp.join(tmp_dir, 'gz'))
for filename in os.listdir(osp.join(tmp_dir, 'gz')):
un_gz(
osp.join(tmp_dir, 'gz', filename),
osp.join(tmp_dir, 'files',
osp.splitext(filename)[0]))
now_dir = osp.join(tmp_dir, 'files')
assert len(os.listdir(now_dir)) == STARE_LEN, \
f'len(os.listdir(now_dir)) != {STARE_LEN}'
for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]:
img = mmcv.imread(osp.join(now_dir, filename))
# The annotation img should be divided by 128, because some of
# the annotation imgs are not standard. We should set a threshold
# to convert the nonstandard annotation imgs. The value divided by
# 128 equivalent to '1 if value >= 128 else 0'
mmcv.imwrite(
img[:, :, 0] // 128,
osp.join(out_dir, 'annotations', 'training',
osp.splitext(filename)[0] + '.png'))
for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]:
img = mmcv.imread(osp.join(now_dir, filename))
mmcv.imwrite(
img[:, :, 0] // 128,
osp.join(out_dir, 'annotations', 'validation',
osp.splitext(filename)[0] + '.png'))
print('Removing the temporary files...')
with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
mkdir_or_exist(osp.join(tmp_dir, 'gz'))
mkdir_or_exist(osp.join(tmp_dir, 'files'))
print('Extracting labels-vk.tar...')
with tarfile.open(labels_vk) as f:
f.extractall(osp.join(tmp_dir, 'gz'))
for filename in os.listdir(osp.join(tmp_dir, 'gz')):
un_gz(
osp.join(tmp_dir, 'gz', filename),
osp.join(tmp_dir, 'files',
osp.splitext(filename)[0]))
now_dir = osp.join(tmp_dir, 'files')
assert len(os.listdir(now_dir)) == STARE_LEN, \
f'len(os.listdir(now_dir)) != {STARE_LEN}'
for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]:
img = mmcv.imread(osp.join(now_dir, filename))
mmcv.imwrite(
img[:, :, 0] // 128,
osp.join(out_dir, 'annotations', 'training',
osp.splitext(filename)[0] + '.png'))
for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]:
img = mmcv.imread(osp.join(now_dir, filename))
mmcv.imwrite(
img[:, :, 0] // 128,
osp.join(out_dir, 'annotations', 'validation',
osp.splitext(filename)[0] + '.png'))
print('Removing the temporary files...')
print('Done!')
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,155 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import nibabel as nib
import numpy as np
from mmengine.utils import mkdir_or_exist
from PIL import Image
def read_files_from_txt(txt_path):
with open(txt_path) as f:
files = f.readlines()
files = [file.strip() for file in files]
return files
def read_nii_file(nii_path):
img = nib.load(nii_path).get_fdata()
return img
def split_3d_image(img):
c, _, _ = img.shape
res = []
for i in range(c):
res.append(img[i, :, :])
return res
def label_mapping(label):
"""Label mapping from TransUNet paper setting. It only has 9 classes, which
are 'background', 'aorta', 'gallbladder', 'left_kidney', 'right_kidney',
'liver', 'pancreas', 'spleen', 'stomach', respectively. Other foreground
classes in original dataset are all set to background.
More details could be found here: https://arxiv.org/abs/2102.04306
"""
maped_label = np.zeros_like(label)
maped_label[label == 8] = 1
maped_label[label == 4] = 2
maped_label[label == 3] = 3
maped_label[label == 2] = 4
maped_label[label == 6] = 5
maped_label[label == 11] = 6
maped_label[label == 1] = 7
maped_label[label == 7] = 8
return maped_label
def pares_args():
parser = argparse.ArgumentParser(
description='Convert synapse dataset to mmsegmentation format')
parser.add_argument(
'--dataset-path', type=str, help='synapse dataset path.')
parser.add_argument(
'--save-path',
default='data/synapse',
type=str,
help='save path of the dataset.')
args = parser.parse_args()
return args
def main():
args = pares_args()
dataset_path = args.dataset_path
save_path = args.save_path
if not osp.exists(dataset_path):
raise ValueError('The dataset path does not exist. '
'Please enter a correct dataset path.')
if not osp.exists(osp.join(dataset_path, 'img')) \
or not osp.exists(osp.join(dataset_path, 'label')):
raise FileNotFoundError('The dataset structure is incorrect. '
'Please check your dataset.')
train_id = read_files_from_txt(osp.join(dataset_path, 'train.txt'))
train_id = [idx[3:7] for idx in train_id]
test_id = read_files_from_txt(osp.join(dataset_path, 'val.txt'))
test_id = [idx[3:7] for idx in test_id]
mkdir_or_exist(osp.join(save_path, 'img_dir/train'))
mkdir_or_exist(osp.join(save_path, 'img_dir/val'))
mkdir_or_exist(osp.join(save_path, 'ann_dir/train'))
mkdir_or_exist(osp.join(save_path, 'ann_dir/val'))
# It follows data preparation pipeline from here:
# https://github.com/Beckschen/TransUNet/tree/main/datasets
for i, idx in enumerate(train_id):
img_3d = read_nii_file(
osp.join(dataset_path, 'img', 'img' + idx + '.nii.gz'))
label_3d = read_nii_file(
osp.join(dataset_path, 'label', 'label' + idx + '.nii.gz'))
img_3d = np.clip(img_3d, -125, 275)
img_3d = (img_3d + 125) / 400
img_3d *= 255
img_3d = np.transpose(img_3d, [2, 0, 1])
img_3d = np.flip(img_3d, 2)
label_3d = np.transpose(label_3d, [2, 0, 1])
label_3d = np.flip(label_3d, 2)
label_3d = label_mapping(label_3d)
for c in range(img_3d.shape[0]):
img = img_3d[c]
label = label_3d[c]
img = Image.fromarray(img).convert('RGB')
label = Image.fromarray(label).convert('L')
img.save(
osp.join(
save_path, 'img_dir/train', 'case' + idx.zfill(4) +
'_slice' + str(c).zfill(3) + '.jpg'))
label.save(
osp.join(
save_path, 'ann_dir/train', 'case' + idx.zfill(4) +
'_slice' + str(c).zfill(3) + '.png'))
for i, idx in enumerate(test_id):
img_3d = read_nii_file(
osp.join(dataset_path, 'img', 'img' + idx + '.nii.gz'))
label_3d = read_nii_file(
osp.join(dataset_path, 'label', 'label' + idx + '.nii.gz'))
img_3d = np.clip(img_3d, -125, 275)
img_3d = (img_3d + 125) / 400
img_3d *= 255
img_3d = np.transpose(img_3d, [2, 0, 1])
img_3d = np.flip(img_3d, 2)
label_3d = np.transpose(label_3d, [2, 0, 1])
label_3d = np.flip(label_3d, 2)
label_3d = label_mapping(label_3d)
for c in range(img_3d.shape[0]):
img = img_3d[c]
label = label_3d[c]
img = Image.fromarray(img).convert('RGB')
label = Image.fromarray(label).convert('L')
img.save(
osp.join(
save_path, 'img_dir/val', 'case' + idx.zfill(4) +
'_slice' + str(c).zfill(3) + '.jpg'))
label.save(
osp.join(
save_path, 'ann_dir/val', 'case' + idx.zfill(4) +
'_slice' + str(c).zfill(3) + '.png'))
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,156 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import glob
import math
import os
import os.path as osp
import tempfile
import zipfile
import mmcv
import numpy as np
from mmengine.utils import ProgressBar, mkdir_or_exist
def parse_args():
parser = argparse.ArgumentParser(
description='Convert vaihingen dataset to mmsegmentation format')
parser.add_argument('dataset_path', help='vaihingen folder path')
parser.add_argument('--tmp_dir', help='path of the temporary directory')
parser.add_argument('-o', '--out_dir', help='output path')
parser.add_argument(
'--clip_size',
type=int,
help='clipped size of image after preparation',
default=512)
parser.add_argument(
'--stride_size',
type=int,
help='stride of clipping original images',
default=256)
args = parser.parse_args()
return args
def clip_big_image(image_path, clip_save_dir, to_label=False):
# Original image of Vaihingen dataset is very large, thus pre-processing
# of them is adopted. Given fixed clip size and stride size to generate
# clipped image, the intersection of width and height is determined.
# For example, given one 5120 x 5120 original image, the clip size is
# 512 and stride size is 256, thus it would generate 20x20 = 400 images
# whose size are all 512x512.
image = mmcv.imread(image_path)
h, w, c = image.shape
cs = args.clip_size
ss = args.stride_size
num_rows = math.ceil((h - cs) / ss) if math.ceil(
(h - cs) / ss) * ss + cs >= h else math.ceil((h - cs) / ss) + 1
num_cols = math.ceil((w - cs) / ss) if math.ceil(
(w - cs) / ss) * ss + cs >= w else math.ceil((w - cs) / ss) + 1
x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1))
xmin = x * cs
ymin = y * cs
xmin = xmin.ravel()
ymin = ymin.ravel()
xmin_offset = np.where(xmin + cs > w, w - xmin - cs, np.zeros_like(xmin))
ymin_offset = np.where(ymin + cs > h, h - ymin - cs, np.zeros_like(ymin))
boxes = np.stack([
xmin + xmin_offset, ymin + ymin_offset,
np.minimum(xmin + cs, w),
np.minimum(ymin + cs, h)
],
axis=1)
if to_label:
color_map = np.array([[0, 0, 0], [255, 255, 255], [255, 0, 0],
[255, 255, 0], [0, 255, 0], [0, 255, 255],
[0, 0, 255]])
flatten_v = np.matmul(
image.reshape(-1, c),
np.array([2, 3, 4]).reshape(3, 1))
out = np.zeros_like(flatten_v)
for idx, class_color in enumerate(color_map):
value_idx = np.matmul(class_color,
np.array([2, 3, 4]).reshape(3, 1))
out[flatten_v == value_idx] = idx
image = out.reshape(h, w)
for box in boxes:
start_x, start_y, end_x, end_y = box
clipped_image = image[start_y:end_y,
start_x:end_x] if to_label else image[
start_y:end_y, start_x:end_x, :]
area_idx = osp.basename(image_path).split('_')[3].strip('.tif')
mmcv.imwrite(
clipped_image.astype(np.uint8),
osp.join(clip_save_dir,
f'{area_idx}_{start_x}_{start_y}_{end_x}_{end_y}.png'))
def main():
splits = {
'train': [
'area1', 'area11', 'area13', 'area15', 'area17', 'area21',
'area23', 'area26', 'area28', 'area3', 'area30', 'area32',
'area34', 'area37', 'area5', 'area7'
],
'val': [
'area6', 'area24', 'area35', 'area16', 'area14', 'area22',
'area10', 'area4', 'area2', 'area20', 'area8', 'area31', 'area33',
'area27', 'area38', 'area12', 'area29'
],
}
dataset_path = args.dataset_path
if args.out_dir is None:
out_dir = osp.join('data', 'vaihingen')
else:
out_dir = args.out_dir
print('Making directories...')
mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))
zipp_list = glob.glob(os.path.join(dataset_path, '*.zip'))
print('Find the data', zipp_list)
with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
for zipp in zipp_list:
zip_file = zipfile.ZipFile(zipp)
zip_file.extractall(tmp_dir)
src_path_list = glob.glob(os.path.join(tmp_dir, '*.tif'))
if 'ISPRS_semantic_labeling_Vaihingen' in zipp:
src_path_list = glob.glob(
os.path.join(os.path.join(tmp_dir, 'top'), '*.tif'))
if 'ISPRS_semantic_labeling_Vaihingen_ground_truth_eroded_COMPLETE' in zipp: # noqa
src_path_list = glob.glob(os.path.join(tmp_dir, '*.tif'))
# delete unused area9 ground truth
for area_ann in src_path_list:
if 'area9' in area_ann:
src_path_list.remove(area_ann)
prog_bar = ProgressBar(len(src_path_list))
for i, src_path in enumerate(src_path_list):
area_idx = osp.basename(src_path).split('_')[3].strip('.tif')
data_type = 'train' if area_idx in splits['train'] else 'val'
if 'noBoundary' in src_path:
dst_dir = osp.join(out_dir, 'ann_dir', data_type)
clip_big_image(src_path, dst_dir, to_label=True)
else:
dst_dir = osp.join(out_dir, 'img_dir', data_type)
clip_big_image(src_path, dst_dir, to_label=False)
prog_bar.update()
print('Removing the temporary files...')
print('Done!')
if __name__ == '__main__':
args = parse_args()
main()

View File

@@ -0,0 +1,92 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from functools import partial
import numpy as np
from mmengine.utils import mkdir_or_exist, scandir, track_parallel_progress
from PIL import Image
from scipy.io import loadmat
AUG_LEN = 10582
def convert_mat(mat_file, in_dir, out_dir):
data = loadmat(osp.join(in_dir, mat_file))
mask = data['GTcls'][0]['Segmentation'][0].astype(np.uint8)
seg_filename = osp.join(out_dir, mat_file.replace('.mat', '.png'))
Image.fromarray(mask).save(seg_filename, 'PNG')
def generate_aug_list(merged_list, excluded_list):
return list(set(merged_list) - set(excluded_list))
def parse_args():
parser = argparse.ArgumentParser(
description='Convert PASCAL VOC annotations to mmsegmentation format')
parser.add_argument('devkit_path', help='pascal voc devkit path')
parser.add_argument('aug_path', help='pascal voc aug path')
parser.add_argument('-o', '--out_dir', help='output path')
parser.add_argument(
'--nproc', default=1, type=int, help='number of process')
args = parser.parse_args()
return args
def main():
args = parse_args()
devkit_path = args.devkit_path
aug_path = args.aug_path
nproc = args.nproc
if args.out_dir is None:
out_dir = osp.join(devkit_path, 'VOC2012', 'SegmentationClassAug')
else:
out_dir = args.out_dir
mkdir_or_exist(out_dir)
in_dir = osp.join(aug_path, 'dataset', 'cls')
track_parallel_progress(
partial(convert_mat, in_dir=in_dir, out_dir=out_dir),
list(scandir(in_dir, suffix='.mat')),
nproc=nproc)
full_aug_list = []
with open(osp.join(aug_path, 'dataset', 'train.txt')) as f:
full_aug_list += [line.strip() for line in f]
with open(osp.join(aug_path, 'dataset', 'val.txt')) as f:
full_aug_list += [line.strip() for line in f]
with open(
osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation',
'train.txt')) as f:
ori_train_list = [line.strip() for line in f]
with open(
osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation',
'val.txt')) as f:
val_list = [line.strip() for line in f]
aug_train_list = generate_aug_list(ori_train_list + full_aug_list,
val_list)
assert len(aug_train_list) == AUG_LEN, 'len(aug_train_list) != {}'.format(
AUG_LEN)
with open(
osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation',
'trainaug.txt'), 'w') as f:
f.writelines(line + '\n' for line in aug_train_list)
aug_list = generate_aug_list(full_aug_list, ori_train_list + val_list)
assert len(aug_list) == AUG_LEN - len(
ori_train_list), 'len(aug_list) != {}'.format(AUG_LEN -
len(ori_train_list))
with open(
osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', 'aug.txt'),
'w') as f:
f.writelines(line + '\n' for line in aug_list)
print('Done!')
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,185 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import numpy as np
import torch
import torch._C
import torch.serialization
from mmengine import Config
from mmengine.runner import load_checkpoint
from torch import nn
from mmseg.models import build_segmentor
torch.manual_seed(3)
def digit_version(version_str):
digit_version = []
for x in version_str.split('.'):
if x.isdigit():
digit_version.append(int(x))
elif x.find('rc') != -1:
patch_version = x.split('rc')
digit_version.append(int(patch_version[0]) - 1)
digit_version.append(int(patch_version[1]))
return digit_version
def check_torch_version():
torch_minimum_version = '1.8.0'
torch_version = digit_version(torch.__version__)
assert (torch_version >= digit_version(torch_minimum_version)), \
f'Torch=={torch.__version__} is not support for converting to ' \
f'torchscript. Please install pytorch>={torch_minimum_version}.'
def _convert_batchnorm(module):
module_output = module
if isinstance(module, torch.nn.SyncBatchNorm):
module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
module.momentum, module.affine,
module.track_running_stats)
if module.affine:
module_output.weight.data = module.weight.data.clone().detach()
module_output.bias.data = module.bias.data.clone().detach()
# keep requires_grad unchanged
module_output.weight.requires_grad = module.weight.requires_grad
module_output.bias.requires_grad = module.bias.requires_grad
module_output.running_mean = module.running_mean
module_output.running_var = module.running_var
module_output.num_batches_tracked = module.num_batches_tracked
for name, child in module.named_children():
module_output.add_module(name, _convert_batchnorm(child))
del module
return module_output
def _demo_mm_inputs(input_shape, num_classes):
"""Create a superset of inputs needed to run test or train batches.
Args:
input_shape (tuple):
input batch dimensions
num_classes (int):
number of semantic classes
"""
(N, C, H, W) = input_shape
rng = np.random.RandomState(0)
imgs = rng.rand(*input_shape)
segs = rng.randint(
low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)
img_metas = [{
'img_shape': (H, W, C),
'ori_shape': (H, W, C),
'pad_shape': (H, W, C),
'filename': '<demo>.png',
'scale_factor': 1.0,
'flip': False,
} for _ in range(N)]
mm_inputs = {
'imgs': torch.FloatTensor(imgs).requires_grad_(True),
'img_metas': img_metas,
'gt_semantic_seg': torch.LongTensor(segs)
}
return mm_inputs
def pytorch2libtorch(model,
input_shape,
show=False,
output_file='tmp.pt',
verify=False):
"""Export Pytorch model to TorchScript model and verify the outputs are
same between Pytorch and TorchScript.
Args:
model (nn.Module): Pytorch model we want to export.
input_shape (tuple): Use this input shape to construct
the corresponding dummy input and execute the model.
show (bool): Whether print the computation graph. Default: False.
output_file (string): The path to where we store the
output TorchScript model. Default: `tmp.pt`.
verify (bool): Whether compare the outputs between
Pytorch and TorchScript. Default: False.
"""
if isinstance(model.decode_head, nn.ModuleList):
num_classes = model.decode_head[-1].num_classes
else:
num_classes = model.decode_head.num_classes
mm_inputs = _demo_mm_inputs(input_shape, num_classes)
imgs = mm_inputs.pop('imgs')
# replace the original forword with forward_dummy
model.forward = model.forward_dummy
model.eval()
traced_model = torch.jit.trace(
model,
example_inputs=imgs,
check_trace=verify,
)
if show:
print(traced_model.graph)
traced_model.save(output_file)
print(f'Successfully exported TorchScript model: {output_file}')
def parse_args():
parser = argparse.ArgumentParser(
description='Convert MMSeg to TorchScript')
parser.add_argument('config', help='test config file path')
parser.add_argument('--checkpoint', help='checkpoint file', default=None)
parser.add_argument(
'--show', action='store_true', help='show TorchScript graph')
parser.add_argument(
'--verify', action='store_true', help='verify the TorchScript model')
parser.add_argument('--output-file', type=str, default='tmp.pt')
parser.add_argument(
'--shape',
type=int,
nargs='+',
default=[512, 512],
help='input image size (height, width)')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
check_torch_version()
if len(args.shape) == 1:
input_shape = (1, 3, args.shape[0], args.shape[0])
elif len(args.shape) == 2:
input_shape = (
1,
3,
) + tuple(args.shape)
else:
raise ValueError('invalid input shape')
cfg = Config.fromfile(args.config)
cfg.model.pretrained = None
# build the model and load checkpoint
cfg.model.train_cfg = None
segmentor = build_segmentor(
cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
# convert SyncBN to BN
segmentor = _convert_batchnorm(segmentor)
if args.checkpoint:
load_checkpoint(segmentor, args.checkpoint, map_location='cpu')
# convert the PyTorch model to LibTorch model
pytorch2libtorch(
segmentor,
input_shape,
show=args.show,
output_file=args.output_file,
verify=args.verify)

View File

@@ -0,0 +1,20 @@
CONFIG=$1
CHECKPOINT=$2
GPUS=$3
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch \
--nnodes=$NNODES \
--node_rank=$NODE_RANK \
--master_addr=$MASTER_ADDR \
--nproc_per_node=$GPUS \
--master_port=$PORT \
$(dirname "$0")/test.py \
$CONFIG \
$CHECKPOINT \
--launcher pytorch \
${@:4}

View File

@@ -0,0 +1,17 @@
CONFIG=$1
GPUS=$2
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch \
--nnodes=$NNODES \
--node_rank=$NODE_RANK \
--master_addr=$MASTER_ADDR \
--nproc_per_node=$GPUS \
--master_port=$PORT \
$(dirname "$0")/train.py \
$CONFIG \
--launcher pytorch ${@:3}

View File

@@ -0,0 +1,73 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from mmengine import Config, DictAction
from mmengine.registry import init_default_scope
from mmengine.utils import ProgressBar
from mmseg.registry import DATASETS, VISUALIZERS
def parse_args():
parser = argparse.ArgumentParser(description='Browse a dataset')
parser.add_argument('config', help='train config file path')
parser.add_argument(
'--output-dir',
default=None,
type=str,
help='If there is no display interface, you can save it')
parser.add_argument('--not-show', default=False, action='store_true')
parser.add_argument(
'--show-interval',
type=float,
default=2,
help='the interval of show (s)')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
args = parser.parse_args()
return args
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
# register all modules in mmseg into the registries
init_default_scope('mmseg')
dataset = DATASETS.build(cfg.train_dataloader.dataset)
cfg.visualizer['save_dir'] = args.output_dir
visualizer = VISUALIZERS.build(cfg.visualizer)
visualizer.dataset_meta = dataset.METAINFO
progress_bar = ProgressBar(len(dataset))
for item in dataset:
img = item['inputs'].permute(1, 2, 0).numpy()
data_sample = item['data_samples'].numpy()
img_path = osp.basename(item['data_samples'].img_path)
img = img[..., [2, 1, 0]] # bgr to rgb
visualizer.add_datasample(
osp.basename(img_path),
img,
data_sample,
show=not args.not_show,
wait_time=args.show_interval)
progress_bar.update()
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,69 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import warnings
from mmengine import Config, DictAction
from mmseg.apis import init_model
def parse_args():
parser = argparse.ArgumentParser(description='Print the whole config')
parser.add_argument('config', help='config file path')
parser.add_argument(
'--graph', action='store_true', help='print the models graph')
parser.add_argument(
'--options',
nargs='+',
action=DictAction,
help="--options is deprecated in favor of --cfg_options' and it will "
'not be supported in version v0.22.0. Override some settings in the '
'used config, the key-value pair in xxx=yyy format will be merged '
'into config file. If the value to be overwritten is a list, it '
'should be like key="[a,b]" or key=a,b It also allows nested '
'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation '
'marks are necessary and that no white space is allowed.')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
args = parser.parse_args()
if args.options and args.cfg_options:
raise ValueError(
'--options and --cfg-options cannot be both '
'specified, --options is deprecated in favor of --cfg-options. '
'--options will not be supported in version v0.22.0.')
if args.options:
warnings.warn('--options is deprecated in favor of --cfg-options, '
'--options will not be supported in version v0.22.0.')
args.cfg_options = args.options
return args
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
print(f'Config:\n{cfg.pretty_text}')
# dump config
cfg.dump('example.py')
# dump models graph
if args.graph:
model = init_model(args.config, device='cpu')
print(f'Model graph:\n{str(model)}')
with open('example-graph.txt', 'w') as f:
f.writelines(str(model))
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,50 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import subprocess
from hashlib import sha256
import torch
BLOCK_SIZE = 128 * 1024
def parse_args():
parser = argparse.ArgumentParser(
description='Process a checkpoint to be published')
parser.add_argument('in_file', help='input checkpoint filename')
parser.add_argument('out_file', help='output checkpoint filename')
args = parser.parse_args()
return args
def sha256sum(filename: str) -> str:
"""Compute SHA256 message digest from a file."""
hash_func = sha256()
byte_array = bytearray(BLOCK_SIZE)
memory_view = memoryview(byte_array)
with open(filename, 'rb', buffering=0) as file:
for block in iter(lambda: file.readinto(memory_view), 0):
hash_func.update(memory_view[:block])
return hash_func.hexdigest()
def process_checkpoint(in_file, out_file):
checkpoint = torch.load(in_file, map_location='cpu')
# remove optimizer for smaller file size
if 'optimizer' in checkpoint:
del checkpoint['optimizer']
# if it is necessary to remove some sensitive data in checkpoint['meta'],
# add the code here.
torch.save(checkpoint, out_file)
sha = sha256sum(in_file)
final_file = out_file.rstrip('.pth') + f'-{sha[:8]}.pth'
subprocess.Popen(['mv', out_file, final_file])
def main():
args = parse_args()
process_checkpoint(args.in_file, args.out_file)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,56 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_beit(ckpt):
new_ckpt = OrderedDict()
for k, v in ckpt.items():
if k.startswith('patch_embed'):
new_key = k.replace('patch_embed.proj', 'patch_embed.projection')
new_ckpt[new_key] = v
if k.startswith('blocks'):
new_key = k.replace('blocks', 'layers')
if 'norm' in new_key:
new_key = new_key.replace('norm', 'ln')
elif 'mlp.fc1' in new_key:
new_key = new_key.replace('mlp.fc1', 'ffn.layers.0.0')
elif 'mlp.fc2' in new_key:
new_key = new_key.replace('mlp.fc2', 'ffn.layers.1')
new_ckpt[new_key] = v
else:
new_key = k
new_ckpt[new_key] = v
return new_ckpt
def main():
parser = argparse.ArgumentParser(
description='Convert keys in official pretrained beit models to'
'MMSegmentation style.')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
args = parser.parse_args()
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
else:
state_dict = checkpoint
weight = convert_beit(state_dict)
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,163 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_vitlayer(paras):
new_para_name = ''
if paras[0] == 'ln_1':
new_para_name = '.'.join(['ln1'] + paras[1:])
elif paras[0] == 'attn':
new_para_name = '.'.join(['attn.attn'] + paras[1:])
elif paras[0] == 'ln_2':
new_para_name = '.'.join(['ln2'] + paras[1:])
elif paras[0] == 'mlp':
if paras[1] == 'c_fc':
new_para_name = '.'.join(['ffn.layers.0.0'] + paras[-1:])
else:
new_para_name = '.'.join(['ffn.layers.1'] + paras[-1:])
else:
print(f'Wrong for {paras}')
return new_para_name
def convert_translayer(paras):
new_para_name = ''
if paras[0] == 'attn':
new_para_name = '.'.join(['attentions.0.attn'] + paras[1:])
elif paras[0] == 'ln_1':
new_para_name = '.'.join(['norms.0'] + paras[1:])
elif paras[0] == 'ln_2':
new_para_name = '.'.join(['norms.1'] + paras[1:])
elif paras[0] == 'mlp':
if paras[1] == 'c_fc':
new_para_name = '.'.join(['ffns.0.layers.0.0'] + paras[2:])
elif paras[1] == 'c_proj':
new_para_name = '.'.join(['ffns.0.layers.1'] + paras[2:])
else:
print(f'Wrong for {paras}')
else:
print(f'Wrong for {paras}')
return new_para_name
def convert_key_name(ckpt, visual_split):
new_ckpt = OrderedDict()
for k, v in ckpt.items():
key_list = k.split('.')
if key_list[0] == 'visual':
new_transform_name = 'image_encoder'
if key_list[1] == 'class_embedding':
new_name = '.'.join([new_transform_name, 'cls_token'])
elif key_list[1] == 'positional_embedding':
new_name = '.'.join([new_transform_name, 'pos_embed'])
elif key_list[1] == 'conv1':
new_name = '.'.join([
new_transform_name, 'patch_embed.projection', key_list[2]
])
elif key_list[1] == 'ln_pre':
new_name = '.'.join(
[new_transform_name, key_list[1], key_list[2]])
elif key_list[1] == 'transformer':
new_layer_name = 'layers'
layer_index = key_list[3]
paras = key_list[4:]
if int(layer_index) < visual_split:
new_para_name = convert_vitlayer(paras)
new_name = '.'.join([
new_transform_name, new_layer_name, layer_index,
new_para_name
])
else:
new_para_name = convert_translayer(paras)
new_transform_name = 'decode_head.rec_with_attnbias'
new_layer_name = 'layers'
layer_index = str(int(layer_index) - visual_split)
new_name = '.'.join([
new_transform_name, new_layer_name, layer_index,
new_para_name
])
elif key_list[1] == 'proj':
new_name = 'decode_head.rec_with_attnbias.proj.weight'
elif key_list[1] == 'ln_post':
new_name = k.replace('visual', 'decode_head.rec_with_attnbias')
else:
print(f'pop parameter: {k}')
continue
else:
text_encoder_name = 'text_encoder'
if key_list[0] == 'transformer':
layer_name = 'transformer'
layer_index = key_list[2]
paras = key_list[3:]
new_para_name = convert_translayer(paras)
new_name = '.'.join([
text_encoder_name, layer_name, layer_index, new_para_name
])
elif key_list[0] in [
'positional_embedding', 'text_projection', 'bg_embed',
'attn_mask', 'logit_scale', 'token_embedding', 'ln_final'
]:
new_name = 'text_encoder.' + k
else:
print(f'pop parameter: {k}')
continue
new_ckpt[new_name] = v
return new_ckpt
def convert_tensor(ckpt):
cls_token = ckpt['image_encoder.cls_token']
new_cls_token = cls_token.unsqueeze(0).unsqueeze(0)
ckpt['image_encoder.cls_token'] = new_cls_token
pos_embed = ckpt['image_encoder.pos_embed']
new_pos_embed = pos_embed.unsqueeze(0)
ckpt['image_encoder.pos_embed'] = new_pos_embed
proj_weight = ckpt['decode_head.rec_with_attnbias.proj.weight']
new_proj_weight = proj_weight.transpose(1, 0)
ckpt['decode_head.rec_with_attnbias.proj.weight'] = new_proj_weight
return ckpt
def main():
parser = argparse.ArgumentParser(
description='Convert keys in timm pretrained vit models to '
'MMSegmentation style.')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
args = parser.parse_args()
if any([s in args.src for s in ['B-16', 'b16', 'base_patch16']]):
visual_split = 9
elif any([s in args.src for s in ['L-14', 'l14', 'large_patch14']]):
visual_split = 18
else:
print('Make sure the clip model is ViT-B/16 or ViT-L/14!')
visual_split = -1
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
if isinstance(checkpoint, torch.jit.RecursiveScriptModule):
state_dict = checkpoint.state_dict()
else:
if 'state_dict' in checkpoint:
# timm checkpoint
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
# deit checkpoint
state_dict = checkpoint['model']
else:
state_dict = checkpoint
weight = convert_key_name(state_dict, visual_split)
weight = convert_tensor(weight)
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,82 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_mit(ckpt):
new_ckpt = OrderedDict()
# Process the concat between q linear weights and kv linear weights
for k, v in ckpt.items():
if k.startswith('head'):
continue
# patch embedding conversion
elif k.startswith('patch_embed'):
stage_i = int(k.split('.')[0].replace('patch_embed', ''))
new_k = k.replace(f'patch_embed{stage_i}', f'layers.{stage_i-1}.0')
new_v = v
if 'proj.' in new_k:
new_k = new_k.replace('proj.', 'projection.')
# transformer encoder layer conversion
elif k.startswith('block'):
stage_i = int(k.split('.')[0].replace('block', ''))
new_k = k.replace(f'block{stage_i}', f'layers.{stage_i-1}.1')
new_v = v
if 'attn.q.' in new_k:
sub_item_k = k.replace('q.', 'kv.')
new_k = new_k.replace('q.', 'attn.in_proj_')
new_v = torch.cat([v, ckpt[sub_item_k]], dim=0)
elif 'attn.kv.' in new_k:
continue
elif 'attn.proj.' in new_k:
new_k = new_k.replace('proj.', 'attn.out_proj.')
elif 'attn.sr.' in new_k:
new_k = new_k.replace('sr.', 'sr.')
elif 'mlp.' in new_k:
string = f'{new_k}-'
new_k = new_k.replace('mlp.', 'ffn.layers.')
if 'fc1.weight' in new_k or 'fc2.weight' in new_k:
new_v = v.reshape((*v.shape, 1, 1))
new_k = new_k.replace('fc1.', '0.')
new_k = new_k.replace('dwconv.dwconv.', '1.')
new_k = new_k.replace('fc2.', '4.')
string += f'{new_k} {v.shape}-{new_v.shape}'
# norm layer conversion
elif k.startswith('norm'):
stage_i = int(k.split('.')[0].replace('norm', ''))
new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i-1}.2')
new_v = v
else:
new_k = k
new_v = v
new_ckpt[new_k] = new_v
return new_ckpt
def main():
parser = argparse.ArgumentParser(
description='Convert keys in official pretrained segformer to '
'MMSegmentation style.')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
args = parser.parse_args()
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
else:
state_dict = checkpoint
weight = convert_mit(state_dict)
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,220 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_key_name(ckpt):
new_ckpt = OrderedDict()
for k, v in ckpt.items():
key_list = k.split('.')
if key_list[0] == 'clip_visual_extractor':
new_transform_name = 'image_encoder'
if key_list[1] == 'class_embedding':
new_name = '.'.join([new_transform_name, 'cls_token'])
elif key_list[1] == 'positional_embedding':
new_name = '.'.join([new_transform_name, 'pos_embed'])
elif key_list[1] == 'conv1':
new_name = '.'.join([
new_transform_name, 'patch_embed.projection', key_list[2]
])
elif key_list[1] == 'ln_pre':
new_name = '.'.join(
[new_transform_name, key_list[1], key_list[2]])
elif key_list[1] == 'resblocks':
new_layer_name = 'layers'
layer_index = key_list[2]
paras = key_list[3:]
if paras[0] == 'ln_1':
new_para_name = '.'.join(['ln1'] + key_list[4:])
elif paras[0] == 'attn':
new_para_name = '.'.join(['attn.attn'] + key_list[4:])
elif paras[0] == 'ln_2':
new_para_name = '.'.join(['ln2'] + key_list[4:])
elif paras[0] == 'mlp':
if paras[1] == 'c_fc':
new_para_name = '.'.join(['ffn.layers.0.0'] +
key_list[-1:])
else:
new_para_name = '.'.join(['ffn.layers.1'] +
key_list[-1:])
new_name = '.'.join([
new_transform_name, new_layer_name, layer_index,
new_para_name
])
elif key_list[0] == 'side_adapter_network':
decode_head_name = 'decode_head'
module_name = 'side_adapter_network'
if key_list[1] == 'vit_model':
if key_list[2] == 'blocks':
layer_name = 'encode_layers'
layer_index = key_list[3]
paras = key_list[4:]
if paras[0] == 'norm1':
new_para_name = '.'.join(['ln1'] + key_list[5:])
elif paras[0] == 'attn':
new_para_name = '.'.join(key_list[4:])
new_para_name = new_para_name.replace(
'attn.qkv.', 'attn.attn.in_proj_')
new_para_name = new_para_name.replace(
'attn.proj', 'attn.attn.out_proj')
elif paras[0] == 'norm2':
new_para_name = '.'.join(['ln2'] + key_list[5:])
elif paras[0] == 'mlp':
new_para_name = '.'.join(['ffn'] + key_list[5:])
new_para_name = new_para_name.replace(
'fc1', 'layers.0.0')
new_para_name = new_para_name.replace(
'fc2', 'layers.1')
else:
print(f'Wrong for {k}')
new_name = '.'.join([
decode_head_name, module_name, layer_name, layer_index,
new_para_name
])
elif key_list[2] == 'pos_embed':
new_name = '.'.join(
[decode_head_name, module_name, 'pos_embed'])
elif key_list[2] == 'patch_embed':
new_name = '.'.join([
decode_head_name, module_name, 'patch_embed',
'projection', key_list[4]
])
else:
print(f'Wrong for {k}')
elif key_list[1] == 'query_embed' or key_list[
1] == 'query_pos_embed':
new_name = '.'.join(
[decode_head_name, module_name, key_list[1]])
elif key_list[1] == 'fusion_layers':
layer_name = 'conv_clips'
layer_index = key_list[2][-1]
paras = '.'.join(key_list[3:])
new_para_name = paras.replace('input_proj.0', '0')
new_para_name = new_para_name.replace('input_proj.1', '1.conv')
new_name = '.'.join([
decode_head_name, module_name, layer_name, layer_index,
new_para_name
])
elif key_list[1] == 'mask_decoder':
new_name = 'decode_head.' + k
else:
print(f'Wrong for {k}')
elif key_list[0] == 'clip_rec_head':
module_name = 'rec_with_attnbias'
if key_list[1] == 'proj':
new_name = '.'.join(
[decode_head_name, module_name, 'proj.weight'])
elif key_list[1] == 'ln_post':
new_name = '.'.join(
[decode_head_name, module_name, 'ln_post', key_list[2]])
elif key_list[1] == 'resblocks':
new_layer_name = 'layers'
layer_index = key_list[2]
paras = key_list[3:]
if paras[0] == 'ln_1':
new_para_name = '.'.join(['norms.0'] + paras[1:])
elif paras[0] == 'attn':
new_para_name = '.'.join(['attentions.0.attn'] + paras[1:])
elif paras[0] == 'ln_2':
new_para_name = '.'.join(['norms.1'] + paras[1:])
elif paras[0] == 'mlp':
if paras[1] == 'c_fc':
new_para_name = '.'.join(['ffns.0.layers.0.0'] +
paras[2:])
elif paras[1] == 'c_proj':
new_para_name = '.'.join(['ffns.0.layers.1'] +
paras[2:])
else:
print(f'Wrong for {k}')
new_name = '.'.join([
decode_head_name, module_name, new_layer_name, layer_index,
new_para_name
])
else:
print(f'Wrong for {k}')
elif key_list[0] == 'ov_classifier':
text_encoder_name = 'text_encoder'
if key_list[1] == 'transformer':
layer_name = 'transformer'
layer_index = key_list[3]
paras = key_list[4:]
if paras[0] == 'attn':
new_para_name = '.'.join(['attentions.0.attn'] + paras[1:])
elif paras[0] == 'ln_1':
new_para_name = '.'.join(['norms.0'] + paras[1:])
elif paras[0] == 'ln_2':
new_para_name = '.'.join(['norms.1'] + paras[1:])
elif paras[0] == 'mlp':
if paras[1] == 'c_fc':
new_para_name = '.'.join(['ffns.0.layers.0.0'] +
paras[2:])
elif paras[1] == 'c_proj':
new_para_name = '.'.join(['ffns.0.layers.1'] +
paras[2:])
else:
print(f'Wrong for {k}')
else:
print(f'Wrong for {k}')
new_name = '.'.join([
text_encoder_name, layer_name, layer_index, new_para_name
])
elif key_list[1] in [
'positional_embedding', 'text_projection', 'bg_embed',
'attn_mask', 'logit_scale', 'token_embedding', 'ln_final'
]:
new_name = k.replace('ov_classifier', 'text_encoder')
else:
print(f'Wrong for {k}')
elif key_list[0] == 'criterion':
new_name = k
else:
print(f'Wrong for {k}')
new_ckpt[new_name] = v
return new_ckpt
def convert_tensor(ckpt):
cls_token = ckpt['image_encoder.cls_token']
new_cls_token = cls_token.unsqueeze(0).unsqueeze(0)
ckpt['image_encoder.cls_token'] = new_cls_token
pos_embed = ckpt['image_encoder.pos_embed']
new_pos_embed = pos_embed.unsqueeze(0)
ckpt['image_encoder.pos_embed'] = new_pos_embed
proj_weight = ckpt['decode_head.rec_with_attnbias.proj.weight']
new_proj_weight = proj_weight.transpose(1, 0)
ckpt['decode_head.rec_with_attnbias.proj.weight'] = new_proj_weight
return ckpt
def main():
parser = argparse.ArgumentParser(
description='Convert keys in timm pretrained vit models to '
'MMSegmentation style.')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
args = parser.parse_args()
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
if 'state_dict' in checkpoint:
# timm checkpoint
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
# deit checkpoint
state_dict = checkpoint['model']
else:
state_dict = checkpoint
weight = convert_key_name(state_dict)
weight = convert_tensor(weight)
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,71 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_stdc(ckpt, stdc_type):
new_state_dict = {}
if stdc_type == 'STDC1':
stage_lst = ['0', '1', '2.0', '2.1', '3.0', '3.1', '4.0', '4.1']
else:
stage_lst = [
'0', '1', '2.0', '2.1', '2.2', '2.3', '3.0', '3.1', '3.2', '3.3',
'3.4', '4.0', '4.1', '4.2'
]
for k, v in ckpt.items():
ori_k = k
flag = False
if 'cp.' in k:
k = k.replace('cp.', '')
if 'features.' in k:
num_layer = int(k.split('.')[1])
feature_key_lst = 'features.' + str(num_layer) + '.'
stages_key_lst = 'stages.' + stage_lst[num_layer] + '.'
k = k.replace(feature_key_lst, stages_key_lst)
flag = True
if 'conv_list' in k:
k = k.replace('conv_list', 'layers')
flag = True
if 'avd_layer.' in k:
if 'avd_layer.0' in k:
k = k.replace('avd_layer.0', 'downsample.conv')
elif 'avd_layer.1' in k:
k = k.replace('avd_layer.1', 'downsample.bn')
flag = True
if flag:
new_state_dict[k] = ckpt[ori_k]
return new_state_dict
def main():
parser = argparse.ArgumentParser(
description='Convert keys in official pretrained STDC1/2 to '
'MMSegmentation style.')
parser.add_argument('src', help='src model path')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
parser.add_argument('type', help='model type: STDC1 or STDC2')
args = parser.parse_args()
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
else:
state_dict = checkpoint
assert args.type in ['STDC1',
'STDC2'], 'STD type should be STDC1 or STDC2!'
weight = convert_stdc(state_dict, args.type)
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,87 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_swin(ckpt):
new_ckpt = OrderedDict()
def correct_unfold_reduction_order(x):
out_channel, in_channel = x.shape
x = x.reshape(out_channel, 4, in_channel // 4)
x = x[:, [0, 2, 1, 3], :].transpose(1,
2).reshape(out_channel, in_channel)
return x
def correct_unfold_norm_order(x):
in_channel = x.shape[0]
x = x.reshape(4, in_channel // 4)
x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
return x
for k, v in ckpt.items():
if k.startswith('head'):
continue
elif k.startswith('layers'):
new_v = v
if 'attn.' in k:
new_k = k.replace('attn.', 'attn.w_msa.')
elif 'mlp.' in k:
if 'mlp.fc1.' in k:
new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')
elif 'mlp.fc2.' in k:
new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')
else:
new_k = k.replace('mlp.', 'ffn.')
elif 'downsample' in k:
new_k = k
if 'reduction.' in k:
new_v = correct_unfold_reduction_order(v)
elif 'norm.' in k:
new_v = correct_unfold_norm_order(v)
else:
new_k = k
new_k = new_k.replace('layers', 'stages', 1)
elif k.startswith('patch_embed'):
new_v = v
if 'proj' in k:
new_k = k.replace('proj', 'projection')
else:
new_k = k
else:
new_v = v
new_k = k
new_ckpt[new_k] = new_v
return new_ckpt
def main():
parser = argparse.ArgumentParser(
description='Convert keys in official pretrained swin models to'
'MMSegmentation style.')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
args = parser.parse_args()
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
else:
state_dict = checkpoint
weight = convert_swin(state_dict)
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,87 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_twins(args, ckpt):
new_ckpt = OrderedDict()
for k, v in list(ckpt.items()):
new_v = v
if k.startswith('head'):
continue
elif k.startswith('patch_embeds'):
if 'proj.' in k:
new_k = k.replace('proj.', 'projection.')
else:
new_k = k
elif k.startswith('blocks'):
            # Keys shared by both pcpvt and svt
if 'attn.q.' in k:
new_k = k.replace('q.', 'attn.in_proj_')
new_v = torch.cat([v, ckpt[k.replace('attn.q.', 'attn.kv.')]],
dim=0)
elif 'mlp.fc1' in k:
new_k = k.replace('mlp.fc1', 'ffn.layers.0.0')
elif 'mlp.fc2' in k:
new_k = k.replace('mlp.fc2', 'ffn.layers.1')
# Only pcpvt
elif args.model == 'pcpvt':
if 'attn.proj.' in k:
new_k = k.replace('proj.', 'attn.out_proj.')
else:
new_k = k
# Only svt
else:
if 'attn.proj.' in k:
k_lst = k.split('.')
if int(k_lst[2]) % 2 == 1:
new_k = k.replace('proj.', 'attn.out_proj.')
else:
new_k = k
else:
new_k = k
new_k = new_k.replace('blocks.', 'layers.')
elif k.startswith('pos_block'):
new_k = k.replace('pos_block', 'position_encodings')
if 'proj.0.' in new_k:
new_k = new_k.replace('proj.0.', 'proj.')
else:
new_k = k
if 'attn.kv.' not in k:
new_ckpt[new_k] = new_v
return new_ckpt
def main():
parser = argparse.ArgumentParser(
        description='Convert keys in timm pretrained twins (pcpvt/svt) '
        'models to MMSegmentation style.')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
parser.add_argument('model', help='model: pcpvt or svt')
args = parser.parse_args()
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
if 'state_dict' in checkpoint:
# timm checkpoint
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
weight = convert_twins(args, state_dict)
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)
if __name__ == '__main__':
main()
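
# A minimal usage sketch (checkpoint paths are placeholders):
#   python twins2mmseg.py pretrain/pcpvt_small.pth \
#       pretrain/pcpvt_small_mmseg.pth pcpvt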

View File

@@ -0,0 +1,70 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_vit(ckpt):
new_ckpt = OrderedDict()
for k, v in ckpt.items():
if k.startswith('head'):
continue
if k.startswith('norm'):
new_k = k.replace('norm.', 'ln1.')
elif k.startswith('patch_embed'):
if 'proj' in k:
new_k = k.replace('proj', 'projection')
else:
new_k = k
elif k.startswith('blocks'):
if 'norm' in k:
new_k = k.replace('norm', 'ln')
elif 'mlp.fc1' in k:
new_k = k.replace('mlp.fc1', 'ffn.layers.0.0')
elif 'mlp.fc2' in k:
new_k = k.replace('mlp.fc2', 'ffn.layers.1')
elif 'attn.qkv' in k:
new_k = k.replace('attn.qkv.', 'attn.attn.in_proj_')
elif 'attn.proj' in k:
new_k = k.replace('attn.proj', 'attn.attn.out_proj')
else:
new_k = k
new_k = new_k.replace('blocks.', 'layers.')
else:
new_k = k
new_ckpt[new_k] = v
return new_ckpt
def main():
parser = argparse.ArgumentParser(
description='Convert keys in timm pretrained vit models to '
'MMSegmentation style.')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
args = parser.parse_args()
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
if 'state_dict' in checkpoint:
# timm checkpoint
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
# deit checkpoint
state_dict = checkpoint['model']
else:
state_dict = checkpoint
weight = convert_vit(state_dict)
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)
if __name__ == '__main__':
main()
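
# A minimal usage sketch (checkpoint paths are placeholders):
#   python vit2mmseg.py pretrain/timm_vit_base_p16.pth \
#       pretrain/vit_base_mmseg.pth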

View File

@@ -0,0 +1,123 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import mmengine
import numpy as np
import torch
def vit_jax_to_torch(jax_weights, num_layer=12):
torch_weights = dict()
# patch embedding
conv_filters = jax_weights['embedding/kernel']
conv_filters = conv_filters.permute(3, 2, 0, 1)
torch_weights['patch_embed.projection.weight'] = conv_filters
torch_weights['patch_embed.projection.bias'] = jax_weights[
'embedding/bias']
# pos embedding
torch_weights['pos_embed'] = jax_weights[
'Transformer/posembed_input/pos_embedding']
# cls token
torch_weights['cls_token'] = jax_weights['cls']
# head
torch_weights['ln1.weight'] = jax_weights['Transformer/encoder_norm/scale']
torch_weights['ln1.bias'] = jax_weights['Transformer/encoder_norm/bias']
# transformer blocks
for i in range(num_layer):
jax_block = f'Transformer/encoderblock_{i}'
torch_block = f'layers.{i}'
# attention norm
torch_weights[f'{torch_block}.ln1.weight'] = jax_weights[
f'{jax_block}/LayerNorm_0/scale']
torch_weights[f'{torch_block}.ln1.bias'] = jax_weights[
f'{jax_block}/LayerNorm_0/bias']
# attention
query_weight = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/query/kernel']
query_bias = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/query/bias']
key_weight = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/key/kernel']
key_bias = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/key/bias']
value_weight = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/value/kernel']
value_bias = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/value/bias']
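        # Each JAX q/k/v kernel is assumed to have shape
        # (embed_dim, num_heads, head_dim); stacking them along a new axis
        # and flattening yields one (embed_dim, 3 * embed_dim) matrix, which
        # the final transpose below turns into the (3 * embed_dim, embed_dim)
        # `in_proj` layout of torch.nn.MultiheadAttention.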
qkv_weight = torch.from_numpy(
np.stack((query_weight, key_weight, value_weight), 1))
qkv_weight = torch.flatten(qkv_weight, start_dim=1)
qkv_bias = torch.from_numpy(
np.stack((query_bias, key_bias, value_bias), 0))
qkv_bias = torch.flatten(qkv_bias, start_dim=0)
torch_weights[f'{torch_block}.attn.attn.in_proj_weight'] = qkv_weight
torch_weights[f'{torch_block}.attn.attn.in_proj_bias'] = qkv_bias
to_out_weight = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/out/kernel']
to_out_weight = torch.flatten(to_out_weight, start_dim=0, end_dim=1)
torch_weights[
f'{torch_block}.attn.attn.out_proj.weight'] = to_out_weight
torch_weights[f'{torch_block}.attn.attn.out_proj.bias'] = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/out/bias']
# mlp norm
torch_weights[f'{torch_block}.ln2.weight'] = jax_weights[
f'{jax_block}/LayerNorm_2/scale']
torch_weights[f'{torch_block}.ln2.bias'] = jax_weights[
f'{jax_block}/LayerNorm_2/bias']
# mlp
torch_weights[f'{torch_block}.ffn.layers.0.0.weight'] = jax_weights[
f'{jax_block}/MlpBlock_3/Dense_0/kernel']
torch_weights[f'{torch_block}.ffn.layers.0.0.bias'] = jax_weights[
f'{jax_block}/MlpBlock_3/Dense_0/bias']
torch_weights[f'{torch_block}.ffn.layers.1.weight'] = jax_weights[
f'{jax_block}/MlpBlock_3/Dense_1/kernel']
torch_weights[f'{torch_block}.ffn.layers.1.bias'] = jax_weights[
f'{jax_block}/MlpBlock_3/Dense_1/bias']
# transpose weights
for k, v in torch_weights.items():
if 'weight' in k and 'patch_embed' not in k and 'ln' not in k:
v = v.permute(1, 0)
torch_weights[k] = v
return torch_weights
def main():
# stole refactoring code from Robin Strudel, thanks
parser = argparse.ArgumentParser(
description='Convert keys from jax official pretrained vit models to '
'MMSegmentation style.')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
args = parser.parse_args()
jax_weights = np.load(args.src)
jax_weights_tensor = {}
for key in jax_weights.files:
value = torch.from_numpy(jax_weights[key])
jax_weights_tensor[key] = value
if 'L_16-i21k' in args.src:
num_layer = 24
else:
num_layer = 12
torch_weights = vit_jax_to_torch(jax_weights_tensor, num_layer)
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(torch_weights, args.dst)
if __name__ == '__main__':
main()
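
# A minimal usage sketch (the .npz path is a placeholder; num_layer is
# inferred from the filename, 24 for 'L_16-i21k' checkpoints, otherwise 12):
#   python vitjax2mmseg.py pretrain/ViT-B_16.npz pretrain/vit_jax_mmseg.pth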

View File

@@ -0,0 +1,24 @@
#!/usr/bin/env bash
set -x
PARTITION=$1
JOB_NAME=$2
CONFIG=$3
CHECKPOINT=$4
GPUS=${GPUS:-4}
GPUS_PER_NODE=${GPUS_PER_NODE:-4}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
PY_ARGS=${@:5}
SRUN_ARGS=${SRUN_ARGS:-""}
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
srun -p ${PARTITION} \
--job-name=${JOB_NAME} \
--gres=gpu:${GPUS_PER_NODE} \
--ntasks=${GPUS} \
--ntasks-per-node=${GPUS_PER_NODE} \
--cpus-per-task=${CPUS_PER_TASK} \
--kill-on-bad-exit=1 \
${SRUN_ARGS} \
python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS}
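
# A minimal usage sketch (partition, job name, and paths are placeholders):
#   GPUS=4 sh slurm_test.sh my_partition test_job \
#       configs/some_config.py work_dirs/some_ckpt.pth --tta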

View File

@@ -0,0 +1,23 @@
#!/usr/bin/env bash
set -x
PARTITION=$1
JOB_NAME=$2
CONFIG=$3
GPUS=${GPUS:-4}
GPUS_PER_NODE=${GPUS_PER_NODE:-4}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
SRUN_ARGS=${SRUN_ARGS:-""}
PY_ARGS=${@:4}
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
srun -p ${PARTITION} \
--job-name=${JOB_NAME} \
--gres=gpu:${GPUS_PER_NODE} \
--ntasks=${GPUS} \
--ntasks-per-node=${GPUS_PER_NODE} \
--cpus-per-task=${CPUS_PER_TASK} \
--kill-on-bad-exit=1 \
${SRUN_ARGS} \
python -u tools/train.py ${CONFIG} --launcher="slurm" ${PY_ARGS}
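
# A minimal usage sketch (partition, job name, and paths are placeholders):
#   GPUS=8 GPUS_PER_NODE=8 sh slurm_train.sh my_partition train_job \
#       configs/some_config.py --work-dir work_dirs/my_run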

finetune/tools/test.py Normal file
View File

@@ -0,0 +1,123 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
from mmengine.config import Config, DictAction
from mmengine.runner import Runner
# TODO: support fuse_conv_bn, visualization, and format_only
def parse_args():
parser = argparse.ArgumentParser(
description='MMSeg test (and eval) a model')
parser.add_argument('config', help='train config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument(
'--work-dir',
        help=('if specified, the evaluation metric results will be dumped '
              'into the directory as json'))
parser.add_argument(
'--out',
type=str,
help='The directory to save output prediction for offline evaluation')
parser.add_argument(
'--show', action='store_true', help='show prediction results')
parser.add_argument(
'--show-dir',
help='directory where painted images will be saved. '
'If specified, it will be automatically saved '
'to the work_dir/timestamp/show_dir')
parser.add_argument(
'--wait-time', type=float, default=2, help='the interval of show (s)')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
parser.add_argument(
'--tta', action='store_true', help='Test time augmentation')
    # When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
    # will pass the `--local-rank` parameter to `tools/test.py` instead
    # of `--local_rank`.
parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
return args
def trigger_visualization_hook(cfg, args):
default_hooks = cfg.default_hooks
if 'visualization' in default_hooks:
visualization_hook = default_hooks['visualization']
# Turn on visualization
visualization_hook['draw'] = True
if args.show:
visualization_hook['show'] = True
visualization_hook['wait_time'] = args.wait_time
if args.show_dir:
visualizer = cfg.visualizer
visualizer['save_dir'] = args.show_dir
else:
raise RuntimeError(
            'VisualizationHook must be included in default_hooks. '
            'Refer to usage: '
            '"visualization=dict(type=\'VisualizationHook\')"')
return cfg
def main():
args = parse_args()
# load config
cfg = Config.fromfile(args.config)
cfg.launcher = args.launcher
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
cfg.work_dir = args.work_dir
elif cfg.get('work_dir', None) is None:
# use config filename as default work_dir if cfg.work_dir is None
cfg.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0])
cfg.load_from = args.checkpoint
if args.show or args.show_dir:
cfg = trigger_visualization_hook(cfg, args)
if args.tta:
cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline
cfg.tta_model.module = cfg.model
cfg.model = cfg.tta_model
# add output_dir in metric
if args.out is not None:
cfg.test_evaluator['output_dir'] = args.out
cfg.test_evaluator['keep_results'] = True
# build the runner from config
runner = Runner.from_cfg(cfg)
# start testing
runner.test()
if __name__ == '__main__':
main()
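
# A minimal usage sketch (config and checkpoint paths are placeholders):
#   python tools/test.py configs/some_config.py work_dirs/some_ckpt.pth \
#       --work-dir work_dirs/eval --out work_dirs/preds --tta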

View File

@@ -0,0 +1,112 @@
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser, Namespace
from pathlib import Path
from tempfile import TemporaryDirectory
from mmengine import Config
from mmengine.utils import mkdir_or_exist
try:
from model_archiver.model_packaging import package_model
from model_archiver.model_packaging_utils import ModelExportUtils
except ImportError:
package_model = None
def mmseg2torchserve(
config_file: str,
checkpoint_file: str,
output_folder: str,
model_name: str,
model_version: str = '1.0',
force: bool = False,
):
"""Converts mmsegmentation model (config + checkpoint) to TorchServe
`.mar`.
Args:
config_file:
In MMSegmentation config format.
The contents vary for each task repository.
checkpoint_file:
In MMSegmentation checkpoint format.
The contents vary for each task repository.
output_folder:
Folder where `{model_name}.mar` will be created.
The file created will be in TorchServe archive format.
model_name:
If not None, used for naming the `{model_name}.mar` file
that will be created under `output_folder`.
If None, `{Path(checkpoint_file).stem}` will be used.
model_version:
Model's version.
force:
If True, if there is an existing `{model_name}.mar`
file under `output_folder` it will be overwritten.
"""
mkdir_or_exist(output_folder)
config = Config.fromfile(config_file)
with TemporaryDirectory() as tmpdir:
config.dump(f'{tmpdir}/config.py')
args = Namespace(
**{
'model_file': f'{tmpdir}/config.py',
'serialized_file': checkpoint_file,
'handler': f'{Path(__file__).parent}/mmseg_handler.py',
'model_name': model_name or Path(checkpoint_file).stem,
'version': model_version,
'export_path': output_folder,
'force': force,
'requirements_file': None,
'extra_files': None,
'runtime': 'python',
'archive_format': 'default'
})
manifest = ModelExportUtils.generate_manifest_json(args)
package_model(args, manifest)
def parse_args():
parser = ArgumentParser(
description='Convert mmseg models to TorchServe `.mar` format.')
parser.add_argument('config', type=str, help='config file path')
parser.add_argument('checkpoint', type=str, help='checkpoint file path')
parser.add_argument(
'--output-folder',
type=str,
required=True,
help='Folder where `{model_name}.mar` will be created.')
parser.add_argument(
'--model-name',
type=str,
default=None,
        help='If not None, used for naming the `{model_name}.mar` '
        'file that will be created under `output_folder`. '
        'If None, `{Path(checkpoint_file).stem}` will be used.')
parser.add_argument(
'--model-version',
type=str,
default='1.0',
help='Number used for versioning.')
parser.add_argument(
'-f',
'--force',
action='store_true',
help='overwrite the existing `{model_name}.mar`')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
if package_model is None:
        raise ImportError('`torch-model-archiver` is required. '
                          'Try: pip install torch-model-archiver')
mmseg2torchserve(args.config, args.checkpoint, args.output_folder,
args.model_name, args.model_version, args.force)
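
# A minimal usage sketch (paths and model name are placeholders):
#   python mmseg2torchserve.py configs/some_config.py some_ckpt.pth \
#       --output-folder ./model-store --model-name my_segmentor --force
# The resulting `my_segmentor.mar` can then be served with:
#   torchserve --start --model-store ./model-store --models my_segmentor.mar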

View File

@@ -0,0 +1,56 @@
# Copyright (c) OpenMMLab. All rights reserved.
import base64
import os
import cv2
import mmcv
import torch
from mmengine.model.utils import revert_sync_batchnorm
from ts.torch_handler.base_handler import BaseHandler
from mmseg.apis import inference_model, init_model
class MMsegHandler(BaseHandler):
def initialize(self, context):
properties = context.system_properties
self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = torch.device(self.map_location + ':' +
str(properties.get('gpu_id')) if torch.cuda.
is_available() else self.map_location)
self.manifest = context.manifest
model_dir = properties.get('model_dir')
serialized_file = self.manifest['model']['serializedFile']
checkpoint = os.path.join(model_dir, serialized_file)
self.config_file = os.path.join(model_dir, 'config.py')
self.model = init_model(self.config_file, checkpoint, self.device)
self.model = revert_sync_batchnorm(self.model)
self.initialized = True
def preprocess(self, data):
images = []
for row in data:
image = row.get('data') or row.get('body')
if isinstance(image, str):
image = base64.b64decode(image)
image = mmcv.imfrombytes(image)
images.append(image)
return images
def inference(self, data, *args, **kwargs):
results = [inference_model(self.model, img) for img in data]
return results
def postprocess(self, data):
output = []
for image_result in data:
_, buffer = cv2.imencode('.png', image_result[0].astype('uint8'))
content = buffer.tobytes()
output.append(content)
return output
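
# Once the archive is served, segmentation maps can be requested through the
# standard TorchServe inference API (model name and image are placeholders):
#   curl http://127.0.0.1:8080/predictions/my_segmentor -T demo.png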

View File

@@ -0,0 +1,58 @@
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser
from io import BytesIO
import matplotlib.pyplot as plt
import mmcv
import requests
from mmseg.apis import inference_model, init_model
def parse_args():
parser = ArgumentParser(
        description='Compare the results of TorchServe and PyTorch, '
        'and visualize them.')
parser.add_argument('img', help='Image file')
parser.add_argument('config', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument('model_name', help='The model name in the server')
parser.add_argument(
'--inference-addr',
default='127.0.0.1:8080',
help='Address and port of the inference server')
parser.add_argument(
'--result-image',
type=str,
default=None,
help='save server output in result-image')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
args = parser.parse_args()
return args
def main(args):
url = 'http://' + args.inference_addr + '/predictions/' + args.model_name
with open(args.img, 'rb') as image:
tmp_res = requests.post(url, image)
content = tmp_res.content
if args.result_image:
with open(args.result_image, 'wb') as out_image:
out_image.write(content)
plt.imshow(mmcv.imread(args.result_image, 'grayscale'))
plt.show()
else:
plt.imshow(plt.imread(BytesIO(content)))
plt.show()
model = init_model(args.config, args.checkpoint, args.device)
image = mmcv.imread(args.img)
result = inference_model(model, image)
plt.imshow(result[0])
plt.show()
if __name__ == '__main__':
args = parse_args()
main(args)
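
# A minimal usage sketch (placeholder values; the model must already be
# served at --inference-addr):
#   python test_torchserve.py demo.png configs/some_config.py \
#       some_ckpt.pth my_segmentor --result-image out.png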

finetune/tools/train.py Normal file
View File

@@ -0,0 +1,104 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import logging
import os
import os.path as osp
from mmengine.config import Config, DictAction
from mmengine.logging import print_log
from mmengine.runner import Runner
from mmseg.registry import RUNNERS
def parse_args():
parser = argparse.ArgumentParser(description='Train a segmentor')
parser.add_argument('config', help='train config file path')
parser.add_argument('--work-dir', help='the dir to save logs and models')
parser.add_argument(
'--resume',
action='store_true',
default=False,
help='resume from the latest checkpoint in the work_dir automatically')
parser.add_argument(
'--amp',
action='store_true',
default=False,
help='enable automatic-mixed-precision training')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
# When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
# will pass the `--local-rank` parameter to `tools/train.py` instead
# of `--local_rank`.
parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
return args
def main():
args = parse_args()
# load config
cfg = Config.fromfile(args.config)
cfg.launcher = args.launcher
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
cfg.work_dir = args.work_dir
elif cfg.get('work_dir', None) is None:
# use config filename as default work_dir if cfg.work_dir is None
cfg.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0])
# enable automatic-mixed-precision training
if args.amp is True:
optim_wrapper = cfg.optim_wrapper.type
if optim_wrapper == 'AmpOptimWrapper':
print_log(
'AMP training is already enabled in your config.',
logger='current',
level=logging.WARNING)
else:
assert optim_wrapper == 'OptimWrapper', (
'`--amp` is only supported when the optimizer wrapper type is '
f'`OptimWrapper` but got {optim_wrapper}.')
cfg.optim_wrapper.type = 'AmpOptimWrapper'
cfg.optim_wrapper.loss_scale = 'dynamic'
# resume training
cfg.resume = args.resume
# build the runner from config
if 'runner_type' not in cfg:
# build the default runner
runner = Runner.from_cfg(cfg)
else:
# build customized runner from the registry
# if 'runner_type' is set in the cfg
runner = RUNNERS.build(cfg)
# start training
runner.train()
if __name__ == '__main__':
main()
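
# A minimal usage sketch (config path is a placeholder):
#   python tools/train.py configs/some_config.py \
#       --work-dir work_dirs/my_run --amp --resume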