init
130 finetune/tools/analysis_tools/analyze_logs.py Normal file
@@ -0,0 +1,130 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from https://github.com/open-
mmlab/mmdetection/blob/master/tools/analysis_tools/analyze_logs.py."""
import argparse
import json
from collections import defaultdict

import matplotlib.pyplot as plt
import seaborn as sns


def plot_curve(log_dicts, args):
    if args.backend is not None:
        plt.switch_backend(args.backend)
    sns.set_style(args.style)
    # if legend is None, use {filename}_{key} as legend
    legend = args.legend
    if legend is None:
        legend = []
        for json_log in args.json_logs:
            for metric in args.keys:
                legend.append(f'{json_log}_{metric}')
    assert len(legend) == (len(args.json_logs) * len(args.keys))
    metrics = args.keys

    num_metrics = len(metrics)
    for i, log_dict in enumerate(log_dicts):
        epochs = list(log_dict.keys())
        for j, metric in enumerate(metrics):
            print(f'plot curve of {args.json_logs[i]}, metric is {metric}')
            plot_epochs = []
            plot_iters = []
            plot_values = []
            # Some log files also contain validation lines, so collect
            # iteration numbers from training lines only.
            for epoch in epochs:
                epoch_logs = log_dict[epoch]
                if metric not in epoch_logs.keys():
                    continue
                if metric in ['mIoU', 'mAcc', 'aAcc']:
                    plot_epochs.append(epoch)
                    plot_values.append(epoch_logs[metric][0])
                else:
                    for idx in range(len(epoch_logs[metric])):
                        plot_iters.append(epoch_logs['step'][idx])
                        plot_values.append(epoch_logs[metric][idx])
            ax = plt.gca()
            label = legend[i * num_metrics + j]
            if metric in ['mIoU', 'mAcc', 'aAcc']:
                ax.set_xticks(plot_epochs)
                plt.xlabel('step')
                plt.plot(plot_epochs, plot_values, label=label, marker='o')
            else:
                plt.xlabel('iter')
                plt.plot(plot_iters, plot_values, label=label, linewidth=0.5)
        plt.legend()
        if args.title is not None:
            plt.title(args.title)
    if args.out is None:
        plt.show()
    else:
        print(f'save curve to: {args.out}')
        plt.savefig(args.out)
        plt.cla()


def parse_args():
    parser = argparse.ArgumentParser(description='Analyze Json Log')
    parser.add_argument(
        'json_logs',
        type=str,
        nargs='+',
        help='path of train log in json format')
    parser.add_argument(
        '--keys',
        type=str,
        nargs='+',
        default=['mIoU'],
        help='the metric that you want to plot')
    parser.add_argument('--title', type=str, help='title of figure')
    parser.add_argument(
        '--legend',
        type=str,
        nargs='+',
        default=None,
        help='legend of each plot')
    parser.add_argument(
        '--backend', type=str, default=None, help='backend of plt')
    parser.add_argument(
        '--style', type=str, default='dark', help='style of plt')
    parser.add_argument('--out', type=str, default=None)
    args = parser.parse_args()
    return args


def load_json_logs(json_logs):
    # Load and convert json_logs to log_dict: the key is the step and the
    # value is a sub dict whose keys are the different metrics and whose
    # values are lists of the metric values over all iterations.
    log_dicts = [dict() for _ in json_logs]
    prev_step = 0
    for json_log, log_dict in zip(json_logs, log_dicts):
        with open(json_log) as log_file:
            for line in log_file:
                log = json.loads(line.strip())
                # the final step in the json file is 0
                if 'step' in log and log['step'] != 0:
                    step = log['step']
                    prev_step = step
                else:
                    step = prev_step
                if step not in log_dict:
                    log_dict[step] = defaultdict(list)
                for k, v in log.items():
                    log_dict[step][k].append(v)
    return log_dicts


def main():
    args = parse_args()
    json_logs = args.json_logs
    for json_log in json_logs:
        assert json_log.endswith('.json')
    log_dicts = load_json_logs(json_logs)
    plot_curve(log_dicts, args)


if __name__ == '__main__':
    main()
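A minimal sketch of what load_json_logs returns, assuming a hypothetical
scalars.json log with one JSON object per line such as
{"step": 100, "loss": 0.42} (example only, not part of the commit):

# Hypothetical usage sketch of the helpers above.
log_dicts = load_json_logs(['scalars.json'])
# log_dicts[0] maps step -> defaultdict(list) of metric values,
# e.g. log_dicts[0][100]['loss'] == [0.42]
for step, metrics in list(log_dicts[0].items())[:3]:
    print(step, dict(metrics))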
121 finetune/tools/analysis_tools/benchmark.py Normal file
@@ -0,0 +1,121 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import time

import numpy as np
import torch
from mmengine import Config
from mmengine.fileio import dump
from mmengine.model.utils import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.runner import Runner, load_checkpoint
from mmengine.utils import mkdir_or_exist

from mmseg.registry import MODELS


def parse_args():
    parser = argparse.ArgumentParser(description='MMSeg benchmark a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--log-interval', type=int, default=50, help='interval of logging')
    parser.add_argument(
        '--work-dir',
        help=('if specified, the results will be dumped '
              'into the directory as json'))
    parser.add_argument('--repeat-times', type=int, default=1)
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)

    init_default_scope(cfg.get('default_scope', 'mmseg'))

    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    if args.work_dir is not None:
        mkdir_or_exist(osp.abspath(args.work_dir))
        json_file = osp.join(args.work_dir, f'fps_{timestamp}.json')
    else:
        # use config filename as default work_dir if cfg.work_dir is None
        work_dir = osp.join('./work_dirs',
                            osp.splitext(osp.basename(args.config))[0])
        mkdir_or_exist(osp.abspath(work_dir))
        json_file = osp.join(work_dir, f'fps_{timestamp}.json')

    repeat_times = args.repeat_times
    # set cudnn_benchmark
    torch.backends.cudnn.benchmark = False
    cfg.model.pretrained = None

    benchmark_dict = dict(config=args.config, unit='img / s')
    overall_fps_list = []
    cfg.test_dataloader.batch_size = 1
    for time_index in range(repeat_times):
        print(f'Run {time_index + 1}:')
        # build the dataloader
        data_loader = Runner.build_dataloader(cfg.test_dataloader)

        # build the model and load checkpoint
        cfg.model.train_cfg = None
        model = MODELS.build(cfg.model)

        if 'checkpoint' in args and osp.exists(args.checkpoint):
            load_checkpoint(model, args.checkpoint, map_location='cpu')

        if torch.cuda.is_available():
            model = model.cuda()

        model = revert_sync_batchnorm(model)

        model.eval()

        # the first several iterations may be very slow, so skip them
        num_warmup = 5
        pure_inf_time = 0
        total_iters = 200

        # benchmark with 200 batches and take the average
        for i, data in enumerate(data_loader):
            data = model.data_preprocessor(data, True)
            inputs = data['inputs']
            data_samples = data['data_samples']
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            start_time = time.perf_counter()

            with torch.no_grad():
                model(inputs, data_samples, mode='predict')

            if torch.cuda.is_available():
                torch.cuda.synchronize()
            elapsed = time.perf_counter() - start_time

            if i >= num_warmup:
                pure_inf_time += elapsed
                if (i + 1) % args.log_interval == 0:
                    fps = (i + 1 - num_warmup) / pure_inf_time
                    print(f'Done image [{i + 1:<3}/ {total_iters}], '
                          f'fps: {fps:.2f} img / s')

            if (i + 1) == total_iters:
                fps = (i + 1 - num_warmup) / pure_inf_time
                print(f'Overall fps: {fps:.2f} img / s\n')
                benchmark_dict[f'overall_fps_{time_index + 1}'] = round(fps, 2)
                overall_fps_list.append(fps)
                break
    benchmark_dict['average_fps'] = round(np.mean(overall_fps_list), 2)
    benchmark_dict['fps_variance'] = round(np.var(overall_fps_list), 4)
    print(f'Average fps of {repeat_times} evaluations: '
          f'{benchmark_dict["average_fps"]}')
    print(f'The variance of {repeat_times} evaluations: '
          f'{benchmark_dict["fps_variance"]}')
    dump(benchmark_dict, json_file, indent=4)


if __name__ == '__main__':
    main()
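The timing loop above follows a common warmup-then-average pattern; a
self-contained sketch where time.sleep stands in for a model forward pass
(example only, not part of the commit):

import time

num_warmup, total_iters, pure_inf_time = 5, 200, 0.0
for i in range(total_iters):
    start = time.perf_counter()
    time.sleep(0.001)  # stand-in for one forward pass
    if i >= num_warmup:  # exclude slow startup iterations from the average
        pure_inf_time += time.perf_counter() - start
print(f'fps: {(total_iters - num_warmup) / pure_inf_time:.2f} img / s')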
77 finetune/tools/analysis_tools/browse_dataset.py Normal file
@@ -0,0 +1,77 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp

from mmengine.config import Config, DictAction
from mmengine.utils import ProgressBar

from mmseg.registry import DATASETS, VISUALIZERS
from mmseg.utils import register_all_modules


def parse_args():
    parser = argparse.ArgumentParser(description='Browse a dataset')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--output-dir',
        default=None,
        type=str,
        help='If there is no display interface, you can save it')
    parser.add_argument('--not-show', default=False, action='store_true')
    parser.add_argument(
        '--show-interval',
        type=float,
        default=2,
        help='the interval of show (s)')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # register all modules in mmseg into the registries
    register_all_modules()

    dataset = DATASETS.build(cfg.train_dataloader.dataset)
    visualizer = VISUALIZERS.build(cfg.visualizer)
    visualizer.dataset_meta = dataset.metainfo

    progress_bar = ProgressBar(len(dataset))
    for item in dataset:
        img = item['inputs'].permute(1, 2, 0).numpy()
        img = img[..., [2, 1, 0]]  # bgr to rgb
        data_sample = item['data_samples'].numpy()
        img_path = osp.basename(item['data_samples'].img_path)

        out_file = osp.join(
            args.output_dir,
            osp.basename(img_path)) if args.output_dir is not None else None

        visualizer.add_datasample(
            name=osp.basename(img_path),
            image=img,
            data_sample=data_sample,
            draw_gt=True,
            draw_pred=False,
            wait_time=args.show_interval,
            out_file=out_file,
            show=not args.not_show)
        progress_bar.update()


if __name__ == '__main__':
    main()
197 finetune/tools/analysis_tools/confusion_matrix.py Normal file
@@ -0,0 +1,197 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator
from mmengine.config import Config, DictAction
from mmengine.registry import init_default_scope
from mmengine.utils import mkdir_or_exist, progressbar
from PIL import Image

from mmseg.registry import DATASETS

init_default_scope('mmseg')


def parse_args():
    parser = argparse.ArgumentParser(
        description='Generate confusion matrix from segmentation results')
    parser.add_argument('config', help='test config file path')
    parser.add_argument(
        'prediction_path',
        help='directory that contains the test prediction results')
    parser.add_argument(
        'save_dir', help='directory where confusion matrix will be saved')
    parser.add_argument(
        '--show', action='store_true', help='show confusion matrix')
    parser.add_argument(
        '--color-theme',
        default='winter',
        help='theme of the matrix color map')
    parser.add_argument(
        '--title',
        default='Normalized Confusion Matrix',
        help='title of the matrix color map')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()
    return args


def calculate_confusion_matrix(dataset, results):
    """Calculate the confusion matrix.

    Args:
        dataset (Dataset): Test or val dataset.
        results (list[ndarray]): A list of segmentation results in each image.
    """
    n = len(dataset.METAINFO['classes'])
    confusion_matrix = np.zeros(shape=[n, n])
    assert len(dataset) == len(results)
    ignore_index = dataset.ignore_index
    reduce_zero_label = dataset.reduce_zero_label
    prog_bar = progressbar.ProgressBar(len(results))
    for idx, per_img_res in enumerate(results):
        res_segm = per_img_res
        gt_segm = dataset[idx]['data_samples'] \
            .gt_sem_seg.data.squeeze().numpy().astype(np.uint8)
        gt_segm, res_segm = gt_segm.flatten(), res_segm.flatten()
        if reduce_zero_label:
            gt_segm = gt_segm - 1
        to_ignore = gt_segm == ignore_index

        gt_segm, res_segm = gt_segm[~to_ignore], res_segm[~to_ignore]
        inds = n * gt_segm + res_segm
        mat = np.bincount(inds, minlength=n**2).reshape(n, n)
        confusion_matrix += mat
        prog_bar.update()
    return confusion_matrix


def plot_confusion_matrix(confusion_matrix,
                          labels,
                          save_dir=None,
                          show=True,
                          title='Normalized Confusion Matrix',
                          color_theme='OrRd'):
    """Draw confusion matrix with matplotlib.

    Args:
        confusion_matrix (ndarray): The confusion matrix.
        labels (list[str]): List of class names.
        save_dir (str|optional): If set, save the confusion matrix plot to the
            given path. Default: None.
        show (bool): Whether to show the plot. Default: True.
        title (str): Title of the plot. Default: `Normalized Confusion Matrix`.
        color_theme (str): Theme of the matrix color map. Default: `OrRd`.
    """
    # normalize the confusion matrix
    per_label_sums = confusion_matrix.sum(axis=1)[:, np.newaxis]
    confusion_matrix = \
        confusion_matrix.astype(np.float32) / per_label_sums * 100

    num_classes = len(labels)
    fig, ax = plt.subplots(
        figsize=(2 * num_classes, 2 * num_classes * 0.8), dpi=300)
    cmap = plt.get_cmap(color_theme)
    im = ax.imshow(confusion_matrix, cmap=cmap)
    colorbar = plt.colorbar(mappable=im, ax=ax)
    colorbar.ax.tick_params(labelsize=20)  # set colorbar tick label font size

    title_font = {'weight': 'bold', 'size': 20}
    ax.set_title(title, fontdict=title_font)
    label_font = {'size': 40}
    plt.ylabel('Ground Truth Label', fontdict=label_font)
    plt.xlabel('Prediction Label', fontdict=label_font)

    # draw locator
    xmajor_locator = MultipleLocator(1)
    xminor_locator = MultipleLocator(0.5)
    ax.xaxis.set_major_locator(xmajor_locator)
    ax.xaxis.set_minor_locator(xminor_locator)
    ymajor_locator = MultipleLocator(1)
    yminor_locator = MultipleLocator(0.5)
    ax.yaxis.set_major_locator(ymajor_locator)
    ax.yaxis.set_minor_locator(yminor_locator)

    # draw grid
    ax.grid(True, which='minor', linestyle='-')

    # draw label
    ax.set_xticks(np.arange(num_classes))
    ax.set_yticks(np.arange(num_classes))
    ax.set_xticklabels(labels, fontsize=20)
    ax.set_yticklabels(labels, fontsize=20)

    ax.tick_params(
        axis='x', bottom=False, top=True, labelbottom=False, labeltop=True)
    plt.setp(
        ax.get_xticklabels(), rotation=45, ha='left', rotation_mode='anchor')

    # draw confusion matrix value
    for i in range(num_classes):
        for j in range(num_classes):
            ax.text(
                j,
                i,
                '{}%'.format(
                    round(confusion_matrix[i, j], 2
                          ) if not np.isnan(confusion_matrix[i, j]) else -1),
                ha='center',
                va='center',
                color='k',
                size=20)

    ax.set_ylim(len(confusion_matrix) - 0.5, -0.5)  # matplotlib>3.1.1

    fig.tight_layout()
    if save_dir is not None:
        mkdir_or_exist(save_dir)
        plt.savefig(
            os.path.join(save_dir, 'confusion_matrix.png'), format='png')
    if show:
        plt.show()


def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    results = []
    for img in sorted(os.listdir(args.prediction_path)):
        img = os.path.join(args.prediction_path, img)
        image = Image.open(img)
        image = np.copy(image)
        results.append(image)

    assert isinstance(results, list)
    if isinstance(results[0], np.ndarray):
        pass
    else:
        raise TypeError('invalid type of prediction results')

    dataset = DATASETS.build(cfg.test_dataloader.dataset)
    confusion_matrix = calculate_confusion_matrix(dataset, results)
    plot_confusion_matrix(
        confusion_matrix,
        dataset.METAINFO['classes'],
        save_dir=args.save_dir,
        show=args.show,
        title=args.title,
        color_theme=args.color_theme)


if __name__ == '__main__':
    main()
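The bincount trick used in calculate_confusion_matrix, shown on toy data with
n = 3 classes (example only, not part of the commit):

import numpy as np

n = 3
gt = np.array([0, 0, 1, 2, 2])
pred = np.array([0, 1, 1, 2, 0])
# Encode each (gt, pred) pair as a single index in [0, n*n), then count.
inds = n * gt + pred
mat = np.bincount(inds, minlength=n**2).reshape(n, n)
print(mat)  # rows: ground truth, columns: prediction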
124 finetune/tools/analysis_tools/get_flops.py Normal file
@@ -0,0 +1,124 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import tempfile
from pathlib import Path

import torch
from mmengine import Config, DictAction
from mmengine.logging import MMLogger
from mmengine.model import revert_sync_batchnorm
from mmengine.registry import init_default_scope

from mmseg.models import BaseSegmentor
from mmseg.registry import MODELS
from mmseg.structures import SegDataSample

try:
    from mmengine.analysis import get_model_complexity_info
    from mmengine.analysis.print_helper import _format_size
except ImportError:
    raise ImportError('Please upgrade mmengine >= 0.6.0 to use this script.')


def parse_args():
    parser = argparse.ArgumentParser(
        description='Get the FLOPs of a segmentor')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[2048, 1024],
        help='input image size')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()
    return args


def inference(args: argparse.Namespace, logger: MMLogger) -> dict:
    config_name = Path(args.config)

    if not config_name.exists():
        logger.error(f'Config file {config_name} does not exist')

    cfg: Config = Config.fromfile(config_name)
    cfg.work_dir = tempfile.TemporaryDirectory().name
    cfg.log_level = 'WARN'
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    init_default_scope(cfg.get('scope', 'mmseg'))

    if len(args.shape) == 1:
        input_shape = (3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (3, ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')
    result = {}

    model: BaseSegmentor = MODELS.build(cfg.model)
    if hasattr(model, 'auxiliary_head'):
        model.auxiliary_head = None
    if torch.cuda.is_available():
        model.cuda()
    model = revert_sync_batchnorm(model)
    result['ori_shape'] = input_shape[-2:]
    result['pad_shape'] = input_shape[-2:]
    data_batch = {
        'inputs': [torch.rand(input_shape)],
        'data_samples': [SegDataSample(metainfo=result)]
    }
    data = model.data_preprocessor(data_batch)
    model.eval()
    if cfg.model.decode_head.type in ['MaskFormerHead', 'Mask2FormerHead']:
        # TODO: Support MaskFormer and Mask2Former
        raise NotImplementedError('MaskFormer and Mask2Former are not '
                                  'supported yet.')
    outputs = get_model_complexity_info(
        model,
        input_shape=None,
        inputs=data['inputs'],
        show_table=False,
        show_arch=False)
    result['flops'] = _format_size(outputs['flops'])
    result['params'] = _format_size(outputs['params'])
    result['compute_type'] = 'direct: randomly generate a picture'
    return result


def main():

    args = parse_args()
    logger = MMLogger.get_instance(name='MMLogger')

    result = inference(args, logger)
    split_line = '=' * 30
    ori_shape = result['ori_shape']
    pad_shape = result['pad_shape']
    flops = result['flops']
    params = result['params']
    compute_type = result['compute_type']

    if pad_shape != ori_shape:
        print(f'{split_line}\nUse size divisor set input shape '
              f'from {ori_shape} to {pad_shape}')
    print(f'{split_line}\nCompute type: {compute_type}\n'
          f'Input shape: {pad_shape}\nFlops: {flops}\n'
          f'Params: {params}\n{split_line}')
    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify '
          'that the flops computation is correct.')


if __name__ == '__main__':
    main()
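The same mmengine complexity API can be exercised on a plain torch module; a
minimal sketch assuming mmengine >= 0.6.0 (example only, not part of the
commit):

import torch
from mmengine.analysis import get_model_complexity_info

model = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)
outputs = get_model_complexity_info(
    model, input_shape=(3, 64, 64), show_table=False, show_arch=False)
print(outputs['flops'], outputs['params'])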
127 finetune/tools/analysis_tools/visualization_cam.py Normal file
@@ -0,0 +1,127 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Use the pytorch-grad-cam tool to visualize Class Activation Maps (CAM).

requirement: pip install grad-cam
"""

from argparse import ArgumentParser

import numpy as np
import torch
import torch.nn.functional as F
from mmengine import Config
from mmengine.model import revert_sync_batchnorm
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image

from mmseg.apis import inference_model, init_model, show_result_pyplot
from mmseg.utils import register_all_modules


class SemanticSegmentationTarget:
    """Wrap the model output so Grad-CAM can score a single class.

    requirement: pip install grad-cam

    Args:
        category (int): Visualization class.
        mask (ndarray): Mask of class.
        size (tuple): Image size.
    """

    def __init__(self, category, mask, size):
        self.category = category
        self.mask = torch.from_numpy(mask)
        self.size = size
        if torch.cuda.is_available():
            self.mask = self.mask.cuda()

    def __call__(self, model_output):
        model_output = torch.unsqueeze(model_output, dim=0)
        model_output = F.interpolate(
            model_output, size=self.size, mode='bilinear')
        model_output = torch.squeeze(model_output, dim=0)

        return (model_output[self.category, :, :] * self.mask).sum()


def main():
    parser = ArgumentParser()
    parser.add_argument('img', help='Image file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--out-file',
        default='prediction.png',
        help='Path to output prediction file')
    parser.add_argument(
        '--cam-file', default='vis_cam.png', help='Path to output cam file')
    parser.add_argument(
        '--target-layers',
        default='backbone.layer4[2]',
        help='Target layers to visualize CAM')
    parser.add_argument(
        '--category-index', default='7', help='Category to visualize CAM')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    args = parser.parse_args()

    # build the model from a config file and a checkpoint file
    register_all_modules()
    model = init_model(args.config, args.checkpoint, device=args.device)
    if args.device == 'cpu':
        model = revert_sync_batchnorm(model)

    # test a single image
    result = inference_model(model, args.img)

    # show the results
    show_result_pyplot(
        model,
        args.img,
        result,
        draw_gt=False,
        show=False if args.out_file is not None else True,
        out_file=args.out_file)

    # result data conversion
    prediction_data = result.pred_sem_seg.data
    pre_np_data = prediction_data.cpu().numpy().squeeze(0)

    # resolve the dotted/indexed attribute path given on the command line
    target_layers = args.target_layers
    target_layers = [eval(f'model.{target_layers}')]

    category = int(args.category_index)
    mask_float = np.float32(pre_np_data == category)

    # data processing
    image = np.array(Image.open(args.img).convert('RGB'))
    height, width = image.shape[0], image.shape[1]
    rgb_img = np.float32(image) / 255
    config = Config.fromfile(args.config)
    image_mean = config.data_preprocessor['mean']
    image_std = config.data_preprocessor['std']
    input_tensor = preprocess_image(
        rgb_img,
        mean=[x / 255 for x in image_mean],
        std=[x / 255 for x in image_std])

    # Grad CAM(Class Activation Maps)
    # Can also be LayerCAM, XGradCAM, GradCAMPlusPlus, EigenCAM, EigenGradCAM
    targets = [
        SemanticSegmentationTarget(category, mask_float, (height, width))
    ]
    with GradCAM(
            model=model,
            target_layers=target_layers,
            use_cuda=torch.cuda.is_available()) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
        cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

    # save cam file
    Image.fromarray(cam_image).save(args.cam_file)


if __name__ == '__main__':
    main()
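The eval(f'model.{target_layers}') call resolves a dotted, possibly indexed,
attribute path from the command line. For purely dotted paths,
operator.attrgetter is a safer alternative; a minimal sketch (example only,
not part of the commit):

from operator import attrgetter

import torch.nn as nn

model = nn.Module()
model.backbone = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
# nn.Sequential exposes children under string indices, so 'backbone.0'
# resolves to model.backbone[0] without eval.
layer = attrgetter('backbone.0')(model)
print(layer)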
89 finetune/tools/dataset_converters/chase_db1.py Normal file
@@ -0,0 +1,89 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import tempfile
import zipfile

import mmcv
from mmengine.utils import mkdir_or_exist

CHASE_DB1_LEN = 28 * 3
TRAINING_LEN = 60


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert CHASE_DB1 dataset to mmsegmentation format')
    parser.add_argument('dataset_path', help='path of CHASEDB1.zip')
    parser.add_argument('--tmp_dir', help='path of the temporary directory')
    parser.add_argument('-o', '--out_dir', help='output path')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    dataset_path = args.dataset_path
    if args.out_dir is None:
        out_dir = osp.join('data', 'CHASE_DB1')
    else:
        out_dir = args.out_dir

    print('Making directories...')
    mkdir_or_exist(out_dir)
    mkdir_or_exist(osp.join(out_dir, 'images'))
    mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
    mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
    mkdir_or_exist(osp.join(out_dir, 'annotations'))
    mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
    mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))

    with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
        print('Extracting CHASEDB1.zip...')
        zip_file = zipfile.ZipFile(dataset_path)
        zip_file.extractall(tmp_dir)

        print('Generating training dataset...')

        assert len(os.listdir(tmp_dir)) == CHASE_DB1_LEN, \
            f'len(os.listdir(tmp_dir)) != {CHASE_DB1_LEN}'

        for img_name in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]:
            img = mmcv.imread(osp.join(tmp_dir, img_name))
            if osp.splitext(img_name)[1] == '.jpg':
                mmcv.imwrite(
                    img,
                    osp.join(out_dir, 'images', 'training',
                             osp.splitext(img_name)[0] + '.png'))
            else:
                # The annotation images should be binarized with a threshold
                # because some of them are not strictly 0/255. Integer
                # division by 128 is equivalent to '1 if value >= 128 else 0'.
                mmcv.imwrite(
                    img[:, :, 0] // 128,
                    osp.join(out_dir, 'annotations', 'training',
                             osp.splitext(img_name)[0] + '.png'))

        for img_name in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]:
            img = mmcv.imread(osp.join(tmp_dir, img_name))
            if osp.splitext(img_name)[1] == '.jpg':
                mmcv.imwrite(
                    img,
                    osp.join(out_dir, 'images', 'validation',
                             osp.splitext(img_name)[0] + '.png'))
            else:
                mmcv.imwrite(
                    img[:, :, 0] // 128,
                    osp.join(out_dir, 'annotations', 'validation',
                             osp.splitext(img_name)[0] + '.png'))

    print('Removing the temporary files...')

    print('Done!')


if __name__ == '__main__':
    main()
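The '// 128' binarization used for the annotation masks, shown on toy pixel
values (example only, not part of the commit):

import numpy as np

pixels = np.array([0, 3, 127, 128, 200, 255], dtype=np.uint8)
binary = pixels // 128  # 1 where value >= 128, else 0
print(binary)  # [0 0 0 1 1 1]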
56 finetune/tools/dataset_converters/cityscapes.py Normal file
@@ -0,0 +1,56 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp

from cityscapesscripts.preparation.json2labelImg import json2labelImg
from mmengine.utils import (mkdir_or_exist, scandir, track_parallel_progress,
                            track_progress)


def convert_json_to_label(json_file):
    label_file = json_file.replace('_polygons.json', '_labelTrainIds.png')
    json2labelImg(json_file, label_file, 'trainIds')


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert Cityscapes annotations to TrainIds')
    parser.add_argument('cityscapes_path', help='cityscapes data path')
    parser.add_argument('--gt-dir', default='gtFine', type=str)
    parser.add_argument('-o', '--out-dir', help='output path')
    parser.add_argument(
        '--nproc', default=1, type=int, help='number of processes')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    cityscapes_path = args.cityscapes_path
    out_dir = args.out_dir if args.out_dir else cityscapes_path
    mkdir_or_exist(out_dir)

    gt_dir = osp.join(cityscapes_path, args.gt_dir)

    poly_files = []
    for poly in scandir(gt_dir, '_polygons.json', recursive=True):
        poly_file = osp.join(gt_dir, poly)
        poly_files.append(poly_file)
    if args.nproc > 1:
        track_parallel_progress(convert_json_to_label, poly_files, args.nproc)
    else:
        track_progress(convert_json_to_label, poly_files)

    split_names = ['train', 'val', 'test']

    for split in split_names:
        filenames = []
        for poly in scandir(
                osp.join(gt_dir, split), '_polygons.json', recursive=True):
            filenames.append(poly.replace('_gtFine_polygons.json', ''))
        with open(osp.join(out_dir, f'{split}.txt'), 'w') as f:
            f.writelines(f + '\n' for f in filenames)


if __name__ == '__main__':
    main()
308 finetune/tools/dataset_converters/coco_stuff10k.py Normal file
@@ -0,0 +1,308 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import shutil
from functools import partial

import numpy as np
from mmengine.utils import (mkdir_or_exist, track_parallel_progress,
                            track_progress)
from PIL import Image
from scipy.io import loadmat

COCO_LEN = 10000

# Class IDs 12, 26, 29, 30, 45, 66, 68, 69, 71, 83 and 91 are unused in COCO
# Stuff 10k; the remaining IDs in [0, 182] are packed into consecutive train
# IDs. This comprehension is equivalent to the original explicit 172-entry
# mapping (0: 0, ..., 11: 11, 13: 12, ..., 182: 171).
_UNUSED_CLS_IDS = {12, 26, 29, 30, 45, 66, 68, 69, 71, 83, 91}
clsID_to_trID = {
    clsID: trID
    for trID, clsID in enumerate(
        c for c in range(183) if c not in _UNUSED_CLS_IDS)
}


def convert_to_trainID(tuple_path, in_img_dir, in_ann_dir, out_img_dir,
                       out_mask_dir, is_train):
    imgpath, maskpath = tuple_path
    shutil.copyfile(
        osp.join(in_img_dir, imgpath),
        osp.join(out_img_dir, 'train2014', imgpath) if is_train else osp.join(
            out_img_dir, 'test2014', imgpath))
    annotate = loadmat(osp.join(in_ann_dir, maskpath))
    mask = annotate['S'].astype(np.uint8)
    mask_copy = mask.copy()
    for clsID, trID in clsID_to_trID.items():
        mask_copy[mask == clsID] = trID
    seg_filename = osp.join(out_mask_dir, 'train2014',
                            maskpath.split('.')[0] +
                            '_labelTrainIds.png') if is_train else osp.join(
                                out_mask_dir, 'test2014',
                                maskpath.split('.')[0] + '_labelTrainIds.png')
    Image.fromarray(mask_copy).save(seg_filename, 'PNG')


def generate_coco_list(folder):
    train_list = osp.join(folder, 'imageLists', 'train.txt')
    test_list = osp.join(folder, 'imageLists', 'test.txt')
    train_paths = []
    test_paths = []

    with open(train_list) as f:
        for filename in f:
            basename = filename.strip()
            imgpath = basename + '.jpg'
            maskpath = basename + '.mat'
            train_paths.append((imgpath, maskpath))

    with open(test_list) as f:
        for filename in f:
            basename = filename.strip()
            imgpath = basename + '.jpg'
            maskpath = basename + '.mat'
            test_paths.append((imgpath, maskpath))

    return train_paths, test_paths


def parse_args():
    parser = argparse.ArgumentParser(
        description=\
        'Convert COCO Stuff 10k annotations to mmsegmentation format') # noqa
    parser.add_argument('coco_path', help='coco stuff path')
    parser.add_argument('-o', '--out_dir', help='output path')
    parser.add_argument(
        '--nproc', default=16, type=int, help='number of processes')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    coco_path = args.coco_path
    nproc = args.nproc

    out_dir = args.out_dir or coco_path
    out_img_dir = osp.join(out_dir, 'images')
    out_mask_dir = osp.join(out_dir, 'annotations')

    mkdir_or_exist(osp.join(out_img_dir, 'train2014'))
    mkdir_or_exist(osp.join(out_img_dir, 'test2014'))
    mkdir_or_exist(osp.join(out_mask_dir, 'train2014'))
    mkdir_or_exist(osp.join(out_mask_dir, 'test2014'))

    train_list, test_list = generate_coco_list(coco_path)
    assert (len(train_list) +
            len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format(
                len(train_list), len(test_list))

    if args.nproc > 1:
        track_parallel_progress(
            partial(
                convert_to_trainID,
                in_img_dir=osp.join(coco_path, 'images'),
                in_ann_dir=osp.join(coco_path, 'annotations'),
                out_img_dir=out_img_dir,
                out_mask_dir=out_mask_dir,
                is_train=True),
            train_list,
            nproc=nproc)
        track_parallel_progress(
            partial(
                convert_to_trainID,
                in_img_dir=osp.join(coco_path, 'images'),
                in_ann_dir=osp.join(coco_path, 'annotations'),
                out_img_dir=out_img_dir,
                out_mask_dir=out_mask_dir,
                is_train=False),
            test_list,
            nproc=nproc)
    else:
        track_progress(
            partial(
                convert_to_trainID,
                in_img_dir=osp.join(coco_path, 'images'),
                in_ann_dir=osp.join(coco_path, 'annotations'),
                out_img_dir=out_img_dir,
                out_mask_dir=out_mask_dir,
                is_train=True), train_list)
        track_progress(
            partial(
                convert_to_trainID,
                in_img_dir=osp.join(coco_path, 'images'),
                in_ann_dir=osp.join(coco_path, 'annotations'),
                out_img_dir=out_img_dir,
                out_mask_dir=out_mask_dir,
                is_train=False), test_list)

    print('Done!')


if __name__ == '__main__':
    main()
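The per-class remapping loop in convert_to_trainID can also be done with a
single lookup-table indexing pass; a minimal sketch on a toy subset of the
mapping (example only, not part of the commit):

import numpy as np

lut = np.arange(256, dtype=np.uint8)
for clsID, trID in {0: 0, 1: 1, 13: 12}.items():  # toy subset of the mapping
    lut[clsID] = trID
mask = np.array([[0, 13], [1, 0]], dtype=np.uint8)
remapped = lut[mask]  # one vectorized pass instead of a Python loop per class
print(remapped)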
265 finetune/tools/dataset_converters/coco_stuff164k.py Normal file
@@ -0,0 +1,265 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import shutil
from functools import partial
from glob import glob

import numpy as np
from mmengine.utils import (mkdir_or_exist, track_parallel_progress,
                            track_progress)
from PIL import Image

COCO_LEN = 123287

# Class IDs 11, 25, 28, 29, 44, 65, 67, 68, 70, 82 and 90 are unused in COCO
# Stuff 164k; the remaining IDs in [0, 181] are packed into consecutive train
# IDs and 255 stays the ignore label. This comprehension is equivalent to the
# original explicit mapping (0: 0, ..., 10: 10, 12: 11, ..., 181: 170,
# 255: 255).
_UNUSED_CLS_IDS = {11, 25, 28, 29, 44, 65, 67, 68, 70, 82, 90}
clsID_to_trID = {
    clsID: trID
    for trID, clsID in enumerate(
        c for c in range(182) if c not in _UNUSED_CLS_IDS)
}
clsID_to_trID[255] = 255


def convert_to_trainID(maskpath, out_mask_dir, is_train):
    mask = np.array(Image.open(maskpath))
    mask_copy = mask.copy()
    for clsID, trID in clsID_to_trID.items():
        mask_copy[mask == clsID] = trID
    seg_filename = osp.join(
        out_mask_dir, 'train2017',
        osp.basename(maskpath).split('.')[0] +
        '_labelTrainIds.png') if is_train else osp.join(
            out_mask_dir, 'val2017',
            osp.basename(maskpath).split('.')[0] + '_labelTrainIds.png')
    Image.fromarray(mask_copy).save(seg_filename, 'PNG')


def parse_args():
    parser = argparse.ArgumentParser(
        description=\
        'Convert COCO Stuff 164k annotations to mmsegmentation format') # noqa
    parser.add_argument('coco_path', help='coco stuff path')
    parser.add_argument('-o', '--out_dir', help='output path')
    parser.add_argument(
        '--nproc', default=16, type=int, help='number of processes')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    coco_path = args.coco_path
    nproc = args.nproc

    out_dir = args.out_dir or coco_path
    out_img_dir = osp.join(out_dir, 'images')
    out_mask_dir = osp.join(out_dir, 'annotations')

    mkdir_or_exist(osp.join(out_mask_dir, 'train2017'))
    mkdir_or_exist(osp.join(out_mask_dir, 'val2017'))

    if out_dir != coco_path:
        shutil.copytree(osp.join(coco_path, 'images'), out_img_dir)

    train_list = glob(osp.join(coco_path, 'annotations', 'train2017', '*.png'))
    train_list = [file for file in train_list if '_labelTrainIds' not in file]
    test_list = glob(osp.join(coco_path, 'annotations', 'val2017', '*.png'))
    test_list = [file for file in test_list if '_labelTrainIds' not in file]
    assert (len(train_list) +
            len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format(
                len(train_list), len(test_list))

    if args.nproc > 1:
        track_parallel_progress(
            partial(
                convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True),
            train_list,
            nproc=nproc)
        track_parallel_progress(
            partial(
                convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False),
            test_list,
            nproc=nproc)
    else:
        track_progress(
            partial(
                convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True),
            train_list)
        track_progress(
            partial(
                convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False),
            test_list)

    print('Done!')


if __name__ == '__main__':
    main()
114 finetune/tools/dataset_converters/drive.py Normal file
@@ -0,0 +1,114 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import tempfile
import zipfile

import cv2
import mmcv
from mmengine.utils import mkdir_or_exist


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert DRIVE dataset to mmsegmentation format')
    parser.add_argument(
        'training_path', help='the training part of DRIVE dataset')
    parser.add_argument(
        'testing_path', help='the testing part of DRIVE dataset')
    parser.add_argument('--tmp_dir', help='path of the temporary directory')
    parser.add_argument('-o', '--out_dir', help='output path')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    training_path = args.training_path
    testing_path = args.testing_path
    if args.out_dir is None:
        out_dir = osp.join('data', 'DRIVE')
    else:
        out_dir = args.out_dir

    print('Making directories...')
    mkdir_or_exist(out_dir)
    mkdir_or_exist(osp.join(out_dir, 'images'))
    mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
    mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
    mkdir_or_exist(osp.join(out_dir, 'annotations'))
    mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
    mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))

    with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
        print('Extracting training.zip...')
        zip_file = zipfile.ZipFile(training_path)
        zip_file.extractall(tmp_dir)

        print('Generating training dataset...')
        now_dir = osp.join(tmp_dir, 'training', 'images')
        for img_name in os.listdir(now_dir):
            img = mmcv.imread(osp.join(now_dir, img_name))
            mmcv.imwrite(
                img,
                osp.join(
                    out_dir, 'images', 'training',
                    osp.splitext(img_name)[0].replace('_training', '') +
                    '.png'))

        now_dir = osp.join(tmp_dir, 'training', '1st_manual')
        for img_name in os.listdir(now_dir):
            # The manual annotations are GIF files, which mmcv.imread cannot
            # decode, so read the first frame with cv2.VideoCapture instead.
            cap = cv2.VideoCapture(osp.join(now_dir, img_name))
            ret, img = cap.read()
            mmcv.imwrite(
                img[:, :, 0] // 128,
                osp.join(out_dir, 'annotations', 'training',
                         osp.splitext(img_name)[0] + '.png'))

        print('Extracting test.zip...')
        zip_file = zipfile.ZipFile(testing_path)
        zip_file.extractall(tmp_dir)

        print('Generating validation dataset...')
        now_dir = osp.join(tmp_dir, 'test', 'images')
        for img_name in os.listdir(now_dir):
            img = mmcv.imread(osp.join(now_dir, img_name))
            mmcv.imwrite(
                img,
                osp.join(
                    out_dir, 'images', 'validation',
                    osp.splitext(img_name)[0].replace('_test', '') + '.png'))

        now_dir = osp.join(tmp_dir, 'test', '1st_manual')
        if osp.exists(now_dir):
            for img_name in os.listdir(now_dir):
                cap = cv2.VideoCapture(osp.join(now_dir, img_name))
                ret, img = cap.read()
                # The annotation images should be binarized with a threshold
                # because some of them are not strictly 0/255. Integer
                # division by 128 is equivalent to '1 if value >= 128 else 0'.
                mmcv.imwrite(
                    img[:, :, 0] // 128,
                    osp.join(out_dir, 'annotations', 'validation',
                             osp.splitext(img_name)[0] + '.png'))

        now_dir = osp.join(tmp_dir, 'test', '2nd_manual')
        if osp.exists(now_dir):
            for img_name in os.listdir(now_dir):
                cap = cv2.VideoCapture(osp.join(now_dir, img_name))
                ret, img = cap.read()
                mmcv.imwrite(
                    img[:, :, 0] // 128,
                    osp.join(out_dir, 'annotations', 'validation',
                             osp.splitext(img_name)[0] + '.png'))

    print('Removing the temporary files...')

    print('Done!')


if __name__ == '__main__':
    main()
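A minimal sketch of the cv2.VideoCapture trick used above: it can decode
formats such as GIF that cv2.imread does not handle, returning a success flag
and the first frame as a BGR array (hypothetical path, example only, not part
of the commit):

import cv2

cap = cv2.VideoCapture('annotation.gif')  # hypothetical path
ret, frame = cap.read()
if ret:
    print(frame.shape)  # e.g. (584, 565, 3)
cap.release()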
112 finetune/tools/dataset_converters/hrf.py Normal file
@@ -0,0 +1,112 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import tempfile
import zipfile

import mmcv
from mmengine.utils import mkdir_or_exist

HRF_LEN = 15
TRAINING_LEN = 5


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert HRF dataset to mmsegmentation format')
    parser.add_argument('healthy_path', help='the path of healthy.zip')
    parser.add_argument(
        'healthy_manualsegm_path', help='the path of healthy_manualsegm.zip')
    parser.add_argument('glaucoma_path', help='the path of glaucoma.zip')
    parser.add_argument(
        'glaucoma_manualsegm_path', help='the path of glaucoma_manualsegm.zip')
    parser.add_argument(
        'diabetic_retinopathy_path',
        help='the path of diabetic_retinopathy.zip')
    parser.add_argument(
        'diabetic_retinopathy_manualsegm_path',
        help='the path of diabetic_retinopathy_manualsegm.zip')
    parser.add_argument('--tmp_dir', help='path of the temporary directory')
    parser.add_argument('-o', '--out_dir', help='output path')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    images_path = [
        args.healthy_path, args.glaucoma_path, args.diabetic_retinopathy_path
    ]
    annotations_path = [
        args.healthy_manualsegm_path, args.glaucoma_manualsegm_path,
        args.diabetic_retinopathy_manualsegm_path
    ]
    if args.out_dir is None:
        out_dir = osp.join('data', 'HRF')
    else:
        out_dir = args.out_dir

    print('Making directories...')
    mkdir_or_exist(out_dir)
    mkdir_or_exist(osp.join(out_dir, 'images'))
    mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
    mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
    mkdir_or_exist(osp.join(out_dir, 'annotations'))
    mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
    mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))

    print('Generating images...')
    for now_path in images_path:
        with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
            zip_file = zipfile.ZipFile(now_path)
            zip_file.extractall(tmp_dir)

            assert len(os.listdir(tmp_dir)) == HRF_LEN, \
                f'len(os.listdir(tmp_dir)) != {HRF_LEN}'

            for filename in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]:
                img = mmcv.imread(osp.join(tmp_dir, filename))
                mmcv.imwrite(
                    img,
                    osp.join(out_dir, 'images', 'training',
                             osp.splitext(filename)[0] + '.png'))
            for filename in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]:
                img = mmcv.imread(osp.join(tmp_dir, filename))
                mmcv.imwrite(
                    img,
                    osp.join(out_dir, 'images', 'validation',
                             osp.splitext(filename)[0] + '.png'))

    print('Generating annotations...')
    for now_path in annotations_path:
        with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
            zip_file = zipfile.ZipFile(now_path)
            zip_file.extractall(tmp_dir)

            assert len(os.listdir(tmp_dir)) == HRF_LEN, \
                f'len(os.listdir(tmp_dir)) != {HRF_LEN}'

            for filename in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]:
                img = mmcv.imread(osp.join(tmp_dir, filename))
                # The annotation images should be binarized with a threshold
                # because some of them are not strictly 0/255. Integer
                # division by 128 is equivalent to '1 if value >= 128 else 0'.
                mmcv.imwrite(
                    img[:, :, 0] // 128,
                    osp.join(out_dir, 'annotations', 'training',
                             osp.splitext(filename)[0] + '.png'))
            for filename in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]:
                img = mmcv.imread(osp.join(tmp_dir, filename))
                mmcv.imwrite(
                    img[:, :, 0] // 128,
                    osp.join(out_dir, 'annotations', 'validation',
                             osp.splitext(filename)[0] + '.png'))

    print('Done!')


if __name__ == '__main__':
    main()
|
||||
246
finetune/tools/dataset_converters/isaid.py
Normal file
246
finetune/tools/dataset_converters/isaid.py
Normal file
@@ -0,0 +1,246 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import os.path as osp
|
||||
import shutil
|
||||
import tempfile
|
||||
import zipfile
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmengine.utils import ProgressBar, mkdir_or_exist
|
||||
from PIL import Image
|
||||
|
||||
iSAID_palette = \
|
||||
{
|
||||
0: (0, 0, 0),
|
||||
1: (0, 0, 63),
|
||||
2: (0, 63, 63),
|
||||
3: (0, 63, 0),
|
||||
4: (0, 63, 127),
|
||||
5: (0, 63, 191),
|
||||
6: (0, 63, 255),
|
||||
7: (0, 127, 63),
|
||||
8: (0, 127, 127),
|
||||
9: (0, 0, 127),
|
||||
10: (0, 0, 191),
|
||||
11: (0, 0, 255),
|
||||
12: (0, 191, 127),
|
||||
13: (0, 127, 191),
|
||||
14: (0, 127, 255),
|
||||
15: (0, 100, 155)
|
||||
}

iSAID_invert_palette = {v: k for k, v in iSAID_palette.items()}


def iSAID_convert_from_color(arr_3d, palette=iSAID_invert_palette):
    """RGB-color encoding to grayscale labels."""
    arr_2d = np.zeros((arr_3d.shape[0], arr_3d.shape[1]), dtype=np.uint8)

    for c, i in palette.items():
        m = np.all(arr_3d == np.array(c).reshape(1, 1, 3), axis=2)
        arr_2d[m] = i

    return arr_2d


def slide_crop_image(src_path, out_dir, mode, patch_H, patch_W, overlap):
    img = np.asarray(Image.open(src_path).convert('RGB'))

    img_H, img_W, _ = img.shape

    if img_H < patch_H and img_W > patch_W:
        img = mmcv.impad(img, shape=(patch_H, img_W), pad_val=0)
        img_H, img_W, _ = img.shape
    elif img_H > patch_H and img_W < patch_W:
        img = mmcv.impad(img, shape=(img_H, patch_W), pad_val=0)
        img_H, img_W, _ = img.shape
    elif img_H < patch_H and img_W < patch_W:
        img = mmcv.impad(img, shape=(patch_H, patch_W), pad_val=0)
        img_H, img_W, _ = img.shape

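    # Slide a patch_H x patch_W window with the given overlap; windows that
    # would run past the right/bottom border are shifted back inside so every
    # patch keeps the full patch size.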
    for x in range(0, img_W, patch_W - overlap):
        for y in range(0, img_H, patch_H - overlap):
            x_str = x
            x_end = x + patch_W
            if x_end > img_W:
                diff_x = x_end - img_W
                x_str -= diff_x
                x_end = img_W
            y_str = y
            y_end = y + patch_H
            if y_end > img_H:
                diff_y = y_end - img_H
                y_str -= diff_y
                y_end = img_H

            img_patch = img[y_str:y_end, x_str:x_end, :]
            img_patch = Image.fromarray(img_patch.astype(np.uint8))
            image = osp.basename(src_path).split('.')[0] + '_' + str(
                y_str) + '_' + str(y_end) + '_' + str(x_str) + '_' + str(
                    x_end) + '.png'
            # print(image)
            save_path_image = osp.join(out_dir, 'img_dir', mode, str(image))
            img_patch.save(save_path_image, format='BMP')


def slide_crop_label(src_path, out_dir, mode, patch_H, patch_W, overlap):
    label = mmcv.imread(src_path, channel_order='rgb')
    label = iSAID_convert_from_color(label)
    img_H, img_W = label.shape

    if img_H < patch_H and img_W > patch_W:
        label = mmcv.impad(label, shape=(patch_H, img_W), pad_val=255)
        img_H = patch_H
    elif img_H > patch_H and img_W < patch_W:
        label = mmcv.impad(label, shape=(img_H, patch_W), pad_val=255)
        img_W = patch_W
    elif img_H < patch_H and img_W < patch_W:
        label = mmcv.impad(label, shape=(patch_H, patch_W), pad_val=255)
        img_H = patch_H
        img_W = patch_W

    for x in range(0, img_W, patch_W - overlap):
        for y in range(0, img_H, patch_H - overlap):
            x_str = x
            x_end = x + patch_W
            if x_end > img_W:
                diff_x = x_end - img_W
                x_str -= diff_x
                x_end = img_W
            y_str = y
            y_end = y + patch_H
            if y_end > img_H:
                diff_y = y_end - img_H
                y_str -= diff_y
                y_end = img_H

            lab_patch = label[y_str:y_end, x_str:x_end]
            lab_patch = Image.fromarray(lab_patch.astype(np.uint8), mode='P')

            image = osp.basename(src_path).split('.')[0].split(
                '_')[0] + '_' + str(y_str) + '_' + str(y_end) + '_' + str(
                    x_str) + '_' + str(x_end) + '_instance_color_RGB' + '.png'
            lab_patch.save(osp.join(out_dir, 'ann_dir', mode, str(image)))


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert iSAID dataset to mmsegmentation format')
    parser.add_argument('dataset_path', help='iSAID folder path')
    parser.add_argument('--tmp_dir', help='path of the temporary directory')
    parser.add_argument('-o', '--out_dir', help='output path')

    parser.add_argument(
        '--patch_width',
        default=896,
        type=int,
        help='Width of the cropped image patch')
    parser.add_argument(
        '--patch_height',
        default=896,
        type=int,
        help='Height of the cropped image patch')
    parser.add_argument(
        '--overlap_area', default=384, type=int, help='Overlap area')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    dataset_path = args.dataset_path
    # image patch width and height
    patch_H, patch_W = args.patch_width, args.patch_height

    overlap = args.overlap_area  # overlap area

    if args.out_dir is None:
        out_dir = osp.join('data', 'iSAID')
    else:
        out_dir = args.out_dir

    print('Making directories...')
    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'test'))

    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))
    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'test'))

    assert os.path.exists(os.path.join(dataset_path, 'train')), \
        f'train is not in {dataset_path}'
    assert os.path.exists(os.path.join(dataset_path, 'val')), \
        f'val is not in {dataset_path}'
    assert os.path.exists(os.path.join(dataset_path, 'test')), \
        f'test is not in {dataset_path}'

    with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
        for dataset_mode in ['train', 'val', 'test']:

            # for dataset_mode in [ 'test']:
            print(f'Extracting {dataset_mode} data...')
            img_zipp_list = glob.glob(
                os.path.join(dataset_path, dataset_mode, 'images', '*.zip'))
            print('Find the data', img_zipp_list)
            for img_zipp in img_zipp_list:
                zip_file = zipfile.ZipFile(img_zipp)
                zip_file.extractall(os.path.join(tmp_dir, dataset_mode, 'img'))
            src_path_list = glob.glob(
                os.path.join(tmp_dir, dataset_mode, 'img', 'images', '*.png'))

            src_prog_bar = ProgressBar(len(src_path_list))
            for i, img_path in enumerate(src_path_list):
                if dataset_mode != 'test':
                    slide_crop_image(img_path, out_dir, dataset_mode, patch_H,
                                     patch_W, overlap)
                else:
                    shutil.move(img_path,
                                os.path.join(out_dir, 'img_dir', dataset_mode))
                src_prog_bar.update()

            if dataset_mode != 'test':
                label_zipp_list = glob.glob(
                    os.path.join(dataset_path, dataset_mode, 'Semantic_masks',
                                 '*.zip'))
                for label_zipp in label_zipp_list:
                    zip_file = zipfile.ZipFile(label_zipp)
                    zip_file.extractall(
                        os.path.join(tmp_dir, dataset_mode, 'lab'))

                lab_path_list = glob.glob(
                    os.path.join(tmp_dir, dataset_mode, 'lab', 'images',
                                 '*.png'))
                lab_prog_bar = ProgressBar(len(lab_path_list))
                for i, lab_path in enumerate(lab_path_list):
                    slide_crop_label(lab_path, out_dir, dataset_mode, patch_H,
                                     patch_W, overlap)
                    lab_prog_bar.update()

    print('Removing the temporary files...')

    print('Done!')


if __name__ == '__main__':
    main()
99
finetune/tools/dataset_converters/levircd.py
Normal file
@@ -0,0 +1,99 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import glob
import math
import os
import os.path as osp

import mmcv
import numpy as np
from mmengine.utils import ProgressBar


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert levir-cd dataset to mmsegmentation format')
    parser.add_argument('--dataset_path', help='levir-cd folder path')
    parser.add_argument('-o', '--out_dir', help='output path')
    parser.add_argument(
        '--clip_size',
        type=int,
        help='clipped size of image after preparation',
        default=256)
    parser.add_argument(
        '--stride_size',
        type=int,
        help='stride of clipping original images',
        default=256)
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    input_folder = args.dataset_path
    png_files = glob.glob(
        os.path.join(input_folder, '**/*.png'), recursive=True)
    output_folder = args.out_dir
    prog_bar = ProgressBar(len(png_files))
    for png_file in png_files:
        new_path = os.path.join(
            output_folder,
            os.path.relpath(os.path.dirname(png_file), input_folder))
        os.makedirs(os.path.dirname(new_path), exist_ok=True)
        label = False
        if 'label' in png_file:
            label = True
        clip_big_image(png_file, new_path, args, label)
        prog_bar.update()


def clip_big_image(image_path, clip_save_dir, args, to_label=False):
    image = mmcv.imread(image_path)

    h, w, c = image.shape
    clip_size = args.clip_size
    stride_size = args.stride_size

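    # Number of window positions per axis; one extra row/column is added when
    # the windows at the computed count would not reach the image edge.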
    num_rows = math.ceil((h - clip_size) / stride_size) if math.ceil(
        (h - clip_size) /
        stride_size) * stride_size + clip_size >= h else math.ceil(
            (h - clip_size) / stride_size) + 1
    num_cols = math.ceil((w - clip_size) / stride_size) if math.ceil(
        (w - clip_size) /
        stride_size) * stride_size + clip_size >= w else math.ceil(
            (w - clip_size) / stride_size) + 1

    x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1))
    xmin = x * clip_size
    ymin = y * clip_size

    xmin = xmin.ravel()
    ymin = ymin.ravel()
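    # Windows that would overrun the right/bottom edge are shifted back so
    # every crop stays clip_size x clip_size inside the image.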
    xmin_offset = np.where(xmin + clip_size > w, w - xmin - clip_size,
                           np.zeros_like(xmin))
    ymin_offset = np.where(ymin + clip_size > h, h - ymin - clip_size,
                           np.zeros_like(ymin))
    boxes = np.stack([
        xmin + xmin_offset, ymin + ymin_offset,
        np.minimum(xmin + clip_size, w),
        np.minimum(ymin + clip_size, h)
    ],
                     axis=1)

    if to_label:
        image[image == 255] = 1
        image = image[:, :, 0]
    for box in boxes:
        start_x, start_y, end_x, end_y = box
        clipped_image = image[start_y:end_y, start_x:end_x] \
            if to_label else image[start_y:end_y, start_x:end_x, :]
        idx = osp.basename(image_path).split('.')[0]
        mmcv.imwrite(
            clipped_image.astype(np.uint8),
            osp.join(clip_save_dir,
                     f'{idx}_{start_x}_{start_y}_{end_x}_{end_y}.png'))


if __name__ == '__main__':
    main()
73
finetune/tools/dataset_converters/loveda.py
Normal file
@@ -0,0 +1,73 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import shutil
import tempfile
import zipfile

from mmengine.utils import mkdir_or_exist


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert LoveDA dataset to mmsegmentation format')
    parser.add_argument('dataset_path', help='LoveDA folder path')
    parser.add_argument('--tmp_dir', help='path of the temporary directory')
    parser.add_argument('-o', '--out_dir', help='output path')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    dataset_path = args.dataset_path
    if args.out_dir is None:
        out_dir = osp.join('data', 'loveDA')
    else:
        out_dir = args.out_dir

    print('Making directories...')
    mkdir_or_exist(out_dir)
    mkdir_or_exist(osp.join(out_dir, 'img_dir'))
    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'test'))
    mkdir_or_exist(osp.join(out_dir, 'ann_dir'))
    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))

    assert 'Train.zip' in os.listdir(dataset_path), \
        f'Train.zip is not in {dataset_path}'
    assert 'Val.zip' in os.listdir(dataset_path), \
        f'Val.zip is not in {dataset_path}'
    assert 'Test.zip' in os.listdir(dataset_path), \
        f'Test.zip is not in {dataset_path}'

    with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
        for dataset in ['Train', 'Val', 'Test']:
            zip_file = zipfile.ZipFile(
                os.path.join(dataset_path, dataset + '.zip'))
            zip_file.extractall(tmp_dir)
            data_type = dataset.lower()
            for location in ['Rural', 'Urban']:
                for image_type in ['images_png', 'masks_png']:
                    if image_type == 'images_png':
                        dst = osp.join(out_dir, 'img_dir', data_type)
                    else:
                        dst = osp.join(out_dir, 'ann_dir', data_type)
                    if dataset == 'Test' and image_type == 'masks_png':
                        continue
                    else:
                        src_dir = osp.join(tmp_dir, dataset, location,
                                           image_type)
                        src_lst = os.listdir(src_dir)
                        for file in src_lst:
                            shutil.move(osp.join(src_dir, file), dst)
        print('Removing the temporary files...')

    print('Done!')


if __name__ == '__main__':
    main()
89
finetune/tools/dataset_converters/nyu.py
Normal file
@@ -0,0 +1,89 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import shutil
import tempfile
import zipfile

from mmengine.utils import mkdir_or_exist


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert NYU Depth dataset to mmsegmentation format')
    parser.add_argument('raw_data', help='the path of raw data')
    parser.add_argument(
        '-o', '--out_dir', help='output path', default='./data/nyu')
    args = parser.parse_args()
    return args


def reorganize(raw_data_dir: str, out_dir: str):
    """Reorganize NYU Depth dataset files into the required directory
    structure.

    Args:
        raw_data_dir (str): Path to the raw data directory.
        out_dir (str): Output directory for the organized dataset.
    """

    def move_data(data_list, dst_prefix, fname_func):
        """Move data files from source to destination directory.

        Args:
            data_list (list): List of data file paths.
            dst_prefix (str): Prefix to be added to destination paths.
            fname_func (callable): Function to process file names.
        """
        for data_item in data_list:
            data_item = data_item.strip().strip('/')
            new_item = fname_func(data_item)
            shutil.move(
                osp.join(raw_data_dir, data_item),
                osp.join(out_dir, dst_prefix, new_item))

    def process_phase(phase):
        """Process a dataset phase (e.g., 'train' or 'test')."""
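        # Each non-empty line of nyu_{phase}.txt holds at least two
        # whitespace-separated fields: the RGB image path and the matching
        # depth annotation path.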
        with open(osp.join(raw_data_dir, f'nyu_{phase}.txt')) as f:
            data = filter(lambda x: len(x.strip()) > 0, f.readlines())
            data = map(lambda x: x.split()[:2], data)
            images, annos = zip(*data)

            move_data(images, f'images/{phase}',
                      lambda x: x.replace('/rgb', ''))
            move_data(annos, f'annotations/{phase}',
                      lambda x: x.replace('/sync_depth', ''))

    process_phase('train')
    process_phase('test')


def main():
    args = parse_args()

    print('Making directories...')
    mkdir_or_exist(args.out_dir)
    for subdir in [
            'images/train', 'images/test', 'annotations/train',
            'annotations/test'
    ]:
        mkdir_or_exist(osp.join(args.out_dir, subdir))

    print('Generating images and annotations...')

    if args.raw_data.endswith('.zip'):
        with tempfile.TemporaryDirectory() as tmp_dir:
            zip_file = zipfile.ZipFile(args.raw_data)
            zip_file.extractall(tmp_dir)
            reorganize(osp.join(tmp_dir, 'nyu'), args.out_dir)
    else:
        assert osp.isdir(
            args.raw_data
        ), 'the argument raw_data should be either a zip file or a directory.'
        reorganize(args.raw_data, args.out_dir)

    print('Done!')


if __name__ == '__main__':
    main()
87
finetune/tools/dataset_converters/pascal_context.py
Normal file
@@ -0,0 +1,87 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from functools import partial

import numpy as np
from detail import Detail
from mmengine.utils import mkdir_or_exist, track_progress
from PIL import Image

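# The 60 sorted ids kept from the original label space (0 for background plus
# the 59 PASCAL-Context classes); `_class_to_index` below remaps them to
# contiguous labels 0-59 with np.digitize.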
_mapping = np.sort(
    np.array([
        0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 19, 22, 23, 397, 25, 284,
        158, 159, 416, 33, 162, 420, 454, 295, 296, 427, 44, 45, 46, 308, 59,
        440, 445, 31, 232, 65, 354, 424, 68, 326, 72, 458, 34, 207, 80, 355,
        85, 347, 220, 349, 360, 98, 187, 104, 105, 366, 189, 368, 113, 115
    ]))
_key = np.array(range(len(_mapping))).astype('uint8')


def generate_labels(img_id, detail, out_dir):

    def _class_to_index(mask, _mapping, _key):
        # assert the values
        values = np.unique(mask)
        for i in range(len(values)):
            assert (values[i] in _mapping)
        index = np.digitize(mask.ravel(), _mapping, right=True)
        return _key[index].reshape(mask.shape)

    mask = Image.fromarray(
        _class_to_index(detail.getMask(img_id), _mapping=_mapping, _key=_key))
    filename = img_id['file_name']
    mask.save(osp.join(out_dir, filename.replace('jpg', 'png')))
    return osp.splitext(osp.basename(filename))[0]


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert PASCAL VOC annotations to mmsegmentation format')
    parser.add_argument('devkit_path', help='pascal voc devkit path')
    parser.add_argument('json_path', help='annotation json filepath')
    parser.add_argument('-o', '--out_dir', help='output path')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    devkit_path = args.devkit_path
    if args.out_dir is None:
        out_dir = osp.join(devkit_path, 'VOC2010', 'SegmentationClassContext')
    else:
        out_dir = args.out_dir
    json_path = args.json_path
    mkdir_or_exist(out_dir)
    img_dir = osp.join(devkit_path, 'VOC2010', 'JPEGImages')

    train_detail = Detail(json_path, img_dir, 'train')
    train_ids = train_detail.getImgs()

    val_detail = Detail(json_path, img_dir, 'val')
    val_ids = val_detail.getImgs()

    mkdir_or_exist(
        osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext'))

    train_list = track_progress(
        partial(generate_labels, detail=train_detail, out_dir=out_dir),
        train_ids)
    with open(
            osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext',
                     'train.txt'), 'w') as f:
        f.writelines(line + '\n' for line in sorted(train_list))

    val_list = track_progress(
        partial(generate_labels, detail=val_detail, out_dir=out_dir), val_ids)
    with open(
            osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext',
                     'val.txt'), 'w') as f:
        f.writelines(line + '\n' for line in sorted(val_list))

    print('Done!')


if __name__ == '__main__':
    main()
158
finetune/tools/dataset_converters/potsdam.py
Normal file
@@ -0,0 +1,158 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import glob
import math
import os
import os.path as osp
import tempfile
import zipfile

import mmcv
import numpy as np
from mmengine.utils import ProgressBar, mkdir_or_exist


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert potsdam dataset to mmsegmentation format')
    parser.add_argument('dataset_path', help='potsdam folder path')
    parser.add_argument('--tmp_dir', help='path of the temporary directory')
    parser.add_argument('-o', '--out_dir', help='output path')
    parser.add_argument(
        '--clip_size',
        type=int,
        help='clipped size of image after preparation',
        default=512)
    parser.add_argument(
        '--stride_size',
        type=int,
        help='stride of clipping original images',
        default=256)
    args = parser.parse_args()
    return args


def clip_big_image(image_path, clip_save_dir, args, to_label=False):
    # Original images of the Potsdam dataset are very large, so they are
    # clipped before training. Given a fixed clip size and stride size,
    # the number of clipped images along width and height is determined.
    # For example, for one 5120 x 5120 original image with clip size 512
    # and stride size 256, 20x20 = 400 images of size 512x512 would be
    # generated.
    image = mmcv.imread(image_path)

    h, w, c = image.shape
    clip_size = args.clip_size
    stride_size = args.stride_size

    num_rows = math.ceil((h - clip_size) / stride_size) if math.ceil(
        (h - clip_size) /
        stride_size) * stride_size + clip_size >= h else math.ceil(
            (h - clip_size) / stride_size) + 1
    num_cols = math.ceil((w - clip_size) / stride_size) if math.ceil(
        (w - clip_size) /
        stride_size) * stride_size + clip_size >= w else math.ceil(
            (w - clip_size) / stride_size) + 1

    x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1))
    xmin = x * clip_size
    ymin = y * clip_size

    xmin = xmin.ravel()
    ymin = ymin.ravel()
    xmin_offset = np.where(xmin + clip_size > w, w - xmin - clip_size,
                           np.zeros_like(xmin))
    ymin_offset = np.where(ymin + clip_size > h, h - ymin - clip_size,
                           np.zeros_like(ymin))
    boxes = np.stack([
        xmin + xmin_offset, ymin + ymin_offset,
        np.minimum(xmin + clip_size, w),
        np.minimum(ymin + clip_size, h)
    ],
                     axis=1)

    if to_label:
        color_map = np.array([[0, 0, 0], [255, 255, 255], [255, 0, 0],
                              [255, 255, 0], [0, 255, 0], [0, 255, 255],
                              [0, 0, 255]])
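        # Encode each pixel's three channels as one scalar via a dot product
        # with [2, 3, 4]; every palette color above yields a distinct scalar,
        # so class lookup reduces to a scalar comparison.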
        flatten_v = np.matmul(
            image.reshape(-1, c),
            np.array([2, 3, 4]).reshape(3, 1))
        out = np.zeros_like(flatten_v)
        for idx, class_color in enumerate(color_map):
            value_idx = np.matmul(class_color,
                                  np.array([2, 3, 4]).reshape(3, 1))
            out[flatten_v == value_idx] = idx
        image = out.reshape(h, w)

    for box in boxes:
        start_x, start_y, end_x, end_y = box
        clipped_image = image[start_y:end_y,
                              start_x:end_x] if to_label else image[
                                  start_y:end_y, start_x:end_x, :]
        idx_i, idx_j = osp.basename(image_path).split('_')[2:4]
        mmcv.imwrite(
            clipped_image.astype(np.uint8),
            osp.join(
                clip_save_dir,
                f'{idx_i}_{idx_j}_{start_x}_{start_y}_{end_x}_{end_y}.png'))


def main():
    args = parse_args()
    splits = {
        'train': [
            '2_10', '2_11', '2_12', '3_10', '3_11', '3_12', '4_10', '4_11',
            '4_12', '5_10', '5_11', '5_12', '6_10', '6_11', '6_12', '6_7',
            '6_8', '6_9', '7_10', '7_11', '7_12', '7_7', '7_8', '7_9'
        ],
        'val': [
            '5_15', '6_15', '6_13', '3_13', '4_14', '6_14', '5_14', '2_13',
            '4_15', '2_14', '5_13', '4_13', '3_14', '7_13'
        ]
    }

    dataset_path = args.dataset_path
    if args.out_dir is None:
        out_dir = osp.join('data', 'potsdam')
    else:
        out_dir = args.out_dir

    print('Making directories...')
    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))

    zipp_list = glob.glob(os.path.join(dataset_path, '*.zip'))
    print('Find the data', zipp_list)

    for zipp in zipp_list:
        with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
            zip_file = zipfile.ZipFile(zipp)
            zip_file.extractall(tmp_dir)
            src_path_list = glob.glob(os.path.join(tmp_dir, '*.tif'))
            if not len(src_path_list):
                sub_tmp_dir = os.path.join(tmp_dir, os.listdir(tmp_dir)[0])
                src_path_list = glob.glob(os.path.join(sub_tmp_dir, '*.tif'))

            prog_bar = ProgressBar(len(src_path_list))
            for i, src_path in enumerate(src_path_list):
                idx_i, idx_j = osp.basename(src_path).split('_')[2:4]
                data_type = 'train' if f'{idx_i}_{idx_j}' in splits[
                    'train'] else 'val'
                if 'label' in src_path:
                    dst_dir = osp.join(out_dir, 'ann_dir', data_type)
                    clip_big_image(src_path, dst_dir, args, to_label=True)
                else:
                    dst_dir = osp.join(out_dir, 'img_dir', data_type)
                    clip_big_image(src_path, dst_dir, args, to_label=False)
                prog_bar.update()

    print('Removing the temporary files...')

    print('Done!')


if __name__ == '__main__':
    main()
110
finetune/tools/dataset_converters/refuge.py
Normal file
@@ -0,0 +1,110 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import tempfile
import zipfile

import mmcv
import numpy as np
from mmengine.utils import mkdir_or_exist


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert REFUGE dataset to mmsegmentation format')
    parser.add_argument('--raw_data_root', help='the root path of raw data')

    parser.add_argument('--tmp_dir', help='path of the temporary directory')
    parser.add_argument('-o', '--out_dir', help='output path')
    args = parser.parse_args()
    return args


def extract_img(root: str,
                cur_dir: str,
                out_dir: str,
                mode: str = 'train',
                file_type: str = 'img') -> None:
    """Extract a zip archive and convert the files it contains.

    Args:
        root (str): Root where the extracted data is saved.
        cur_dir (str): Path of the zip file to extract.
        out_dir (str): Root dir where the converted data is saved.
        mode (str, optional): Dataset split. Defaults to 'train'.
        file_type (str, optional): 'images' or 'annotations'.
            Defaults to 'img'.
    """
    zip_file = zipfile.ZipFile(cur_dir)
    zip_file.extractall(root)
    for cur_dir, dirs, files in os.walk(root):
        # filter child dirs and directories with "Illustration" and "MACOSX"
        if len(dirs) == 0 and \
                cur_dir.split('\\')[-1].find('Illustration') == -1 and \
                cur_dir.find('MACOSX') == -1:

            file_names = [
                file for file in files
                if file.endswith('.jpg') or file.endswith('.bmp')
            ]
            for filename in sorted(file_names):
                img = mmcv.imread(osp.join(cur_dir, filename))

                if file_type == 'annotations':
                    img = img[:, :, 0]
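                    # Map the three mask gray levels to contiguous label
                    # indices: 0 -> 1, 128 -> 2 and 255 -> 0 (background).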
                    img[np.where(img == 0)] = 1
                    img[np.where(img == 128)] = 2
                    img[np.where(img == 255)] = 0
                mmcv.imwrite(
                    img,
                    osp.join(out_dir, file_type, mode,
                             osp.splitext(filename)[0] + '.png'))


def main():
    args = parse_args()

    raw_data_root = args.raw_data_root
    if args.out_dir is None:
        out_dir = osp.join('./data', 'REFUGE')
    else:
        out_dir = args.out_dir

    print('Making directories...')
    mkdir_or_exist(out_dir)
    mkdir_or_exist(osp.join(out_dir, 'images'))
    mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
    mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
    mkdir_or_exist(osp.join(out_dir, 'images', 'test'))
    mkdir_or_exist(osp.join(out_dir, 'annotations'))
    mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
    mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))
    mkdir_or_exist(osp.join(out_dir, 'annotations', 'test'))

    print('Generating images and annotations...')
    # process data from the child dir on the first rank
    cur_dir, dirs, files = list(os.walk(raw_data_root))[0]
    print('====================')

    files = list(filter(lambda x: x.endswith('.zip'), files))

    with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
        for file in files:
            # search data folders for training, validation and test
            mode = list(
                filter(lambda x: file.lower().find(x) != -1,
                       ['training', 'test', 'validation']))[0]
            file_root = osp.join(tmp_dir, file[:-4])
            file_type = 'images' if file.find('Anno') == -1 and file.find(
                'GT') == -1 else 'annotations'
            extract_img(file_root, osp.join(cur_dir, file), out_dir, mode,
                        file_type)

    print('Done!')


if __name__ == '__main__':
    main()
167
finetune/tools/dataset_converters/stare.py
Normal file
@@ -0,0 +1,167 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import gzip
import os
import os.path as osp
import tarfile
import tempfile

import mmcv
from mmengine.utils import mkdir_or_exist

STARE_LEN = 20
TRAINING_LEN = 10


def un_gz(src, dst):
    g_file = gzip.GzipFile(src)
    with open(dst, 'wb+') as f:
        f.write(g_file.read())
    g_file.close()


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert STARE dataset to mmsegmentation format')
    parser.add_argument('image_path', help='the path of stare-images.tar')
    parser.add_argument('labels_ah', help='the path of labels-ah.tar')
    parser.add_argument('labels_vk', help='the path of labels-vk.tar')
    parser.add_argument('--tmp_dir', help='path of the temporary directory')
    parser.add_argument('-o', '--out_dir', help='output path')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    image_path = args.image_path
    labels_ah = args.labels_ah
    labels_vk = args.labels_vk
    if args.out_dir is None:
        out_dir = osp.join('data', 'STARE')
    else:
        out_dir = args.out_dir

    print('Making directories...')
    mkdir_or_exist(out_dir)
    mkdir_or_exist(osp.join(out_dir, 'images'))
    mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
    mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
    mkdir_or_exist(osp.join(out_dir, 'annotations'))
    mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
    mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))

    with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
        mkdir_or_exist(osp.join(tmp_dir, 'gz'))
        mkdir_or_exist(osp.join(tmp_dir, 'files'))

        print('Extracting stare-images.tar...')
        with tarfile.open(image_path) as f:
            f.extractall(osp.join(tmp_dir, 'gz'))

        for filename in os.listdir(osp.join(tmp_dir, 'gz')):
            un_gz(
                osp.join(tmp_dir, 'gz', filename),
                osp.join(tmp_dir, 'files',
                         osp.splitext(filename)[0]))

        now_dir = osp.join(tmp_dir, 'files')

        assert len(os.listdir(now_dir)) == STARE_LEN, \
            f'len(os.listdir(now_dir)) != {STARE_LEN}'

        for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]:
            img = mmcv.imread(osp.join(now_dir, filename))
            mmcv.imwrite(
                img,
                osp.join(out_dir, 'images', 'training',
                         osp.splitext(filename)[0] + '.png'))

        for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]:
            img = mmcv.imread(osp.join(now_dir, filename))
            mmcv.imwrite(
                img,
                osp.join(out_dir, 'images', 'validation',
                         osp.splitext(filename)[0] + '.png'))

    print('Removing the temporary files...')

    with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
        mkdir_or_exist(osp.join(tmp_dir, 'gz'))
        mkdir_or_exist(osp.join(tmp_dir, 'files'))

        print('Extracting labels-ah.tar...')
        with tarfile.open(labels_ah) as f:
            f.extractall(osp.join(tmp_dir, 'gz'))

        for filename in os.listdir(osp.join(tmp_dir, 'gz')):
            un_gz(
                osp.join(tmp_dir, 'gz', filename),
                osp.join(tmp_dir, 'files',
                         osp.splitext(filename)[0]))

        now_dir = osp.join(tmp_dir, 'files')

        assert len(os.listdir(now_dir)) == STARE_LEN, \
            f'len(os.listdir(now_dir)) != {STARE_LEN}'

        for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]:
            img = mmcv.imread(osp.join(now_dir, filename))
            # The annotation images should be divided by 128 because some of
            # them are not standard: dividing by 128 acts as a threshold,
            # equivalent to '1 if value >= 128 else 0'.
            mmcv.imwrite(
                img[:, :, 0] // 128,
                osp.join(out_dir, 'annotations', 'training',
                         osp.splitext(filename)[0] + '.png'))

        for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]:
            img = mmcv.imread(osp.join(now_dir, filename))
            mmcv.imwrite(
                img[:, :, 0] // 128,
                osp.join(out_dir, 'annotations', 'validation',
                         osp.splitext(filename)[0] + '.png'))

    print('Removing the temporary files...')

    with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
        mkdir_or_exist(osp.join(tmp_dir, 'gz'))
        mkdir_or_exist(osp.join(tmp_dir, 'files'))

        print('Extracting labels-vk.tar...')
        with tarfile.open(labels_vk) as f:
            f.extractall(osp.join(tmp_dir, 'gz'))

        for filename in os.listdir(osp.join(tmp_dir, 'gz')):
            un_gz(
                osp.join(tmp_dir, 'gz', filename),
                osp.join(tmp_dir, 'files',
                         osp.splitext(filename)[0]))

        now_dir = osp.join(tmp_dir, 'files')

        assert len(os.listdir(now_dir)) == STARE_LEN, \
            f'len(os.listdir(now_dir)) != {STARE_LEN}'

        for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]:
            img = mmcv.imread(osp.join(now_dir, filename))
            mmcv.imwrite(
                img[:, :, 0] // 128,
                osp.join(out_dir, 'annotations', 'training',
                         osp.splitext(filename)[0] + '.png'))

        for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]:
            img = mmcv.imread(osp.join(now_dir, filename))
            mmcv.imwrite(
                img[:, :, 0] // 128,
                osp.join(out_dir, 'annotations', 'validation',
                         osp.splitext(filename)[0] + '.png'))

    print('Removing the temporary files...')

    print('Done!')


if __name__ == '__main__':
    main()
155
finetune/tools/dataset_converters/synapse.py
Normal file
@@ -0,0 +1,155 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp

import nibabel as nib
import numpy as np
from mmengine.utils import mkdir_or_exist
from PIL import Image


def read_files_from_txt(txt_path):
    with open(txt_path) as f:
        files = f.readlines()
    files = [file.strip() for file in files]
    return files


def read_nii_file(nii_path):
    img = nib.load(nii_path).get_fdata()
    return img


def split_3d_image(img):
    c, _, _ = img.shape
    res = []
    for i in range(c):
        res.append(img[i, :, :])
    return res


def label_mapping(label):
    """Label mapping from TransUNet paper setting. It only has 9 classes,
    which are 'background', 'aorta', 'gallbladder', 'left_kidney',
    'right_kidney', 'liver', 'pancreas', 'spleen', 'stomach', respectively.
    Other foreground classes in the original dataset are all set to
    background.

    More details could be found here: https://arxiv.org/abs/2102.04306
    """
    mapped_label = np.zeros_like(label)
    mapped_label[label == 8] = 1
    mapped_label[label == 4] = 2
    mapped_label[label == 3] = 3
    mapped_label[label == 2] = 4
    mapped_label[label == 6] = 5
    mapped_label[label == 11] = 6
    mapped_label[label == 1] = 7
    mapped_label[label == 7] = 8
    return mapped_label


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert synapse dataset to mmsegmentation format')
    parser.add_argument(
        '--dataset-path', type=str, help='synapse dataset path.')
    parser.add_argument(
        '--save-path',
        default='data/synapse',
        type=str,
        help='save path of the dataset.')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    dataset_path = args.dataset_path
    save_path = args.save_path

    if not osp.exists(dataset_path):
        raise ValueError('The dataset path does not exist. '
                         'Please enter a correct dataset path.')
    if not osp.exists(osp.join(dataset_path, 'img')) \
            or not osp.exists(osp.join(dataset_path, 'label')):
        raise FileNotFoundError('The dataset structure is incorrect. '
                                'Please check your dataset.')

    train_id = read_files_from_txt(osp.join(dataset_path, 'train.txt'))
    train_id = [idx[3:7] for idx in train_id]

    test_id = read_files_from_txt(osp.join(dataset_path, 'val.txt'))
    test_id = [idx[3:7] for idx in test_id]

    mkdir_or_exist(osp.join(save_path, 'img_dir/train'))
    mkdir_or_exist(osp.join(save_path, 'img_dir/val'))
    mkdir_or_exist(osp.join(save_path, 'ann_dir/train'))
    mkdir_or_exist(osp.join(save_path, 'ann_dir/val'))

    # It follows data preparation pipeline from here:
    # https://github.com/Beckschen/TransUNet/tree/main/datasets
    for i, idx in enumerate(train_id):
        img_3d = read_nii_file(
            osp.join(dataset_path, 'img', 'img' + idx + '.nii.gz'))
        label_3d = read_nii_file(
            osp.join(dataset_path, 'label', 'label' + idx + '.nii.gz'))

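        # Clip intensities to the [-125, 275] window used by the TransUNet
        # preprocessing, then rescale the result to [0, 255].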
        img_3d = np.clip(img_3d, -125, 275)
        img_3d = (img_3d + 125) / 400
        img_3d *= 255
        img_3d = np.transpose(img_3d, [2, 0, 1])
        img_3d = np.flip(img_3d, 2)

        label_3d = np.transpose(label_3d, [2, 0, 1])
        label_3d = np.flip(label_3d, 2)
        label_3d = label_mapping(label_3d)

        for c in range(img_3d.shape[0]):
            img = img_3d[c]
            label = label_3d[c]

            img = Image.fromarray(img).convert('RGB')
            label = Image.fromarray(label).convert('L')
            img.save(
                osp.join(
                    save_path, 'img_dir/train', 'case' + idx.zfill(4) +
                    '_slice' + str(c).zfill(3) + '.jpg'))
            label.save(
                osp.join(
                    save_path, 'ann_dir/train', 'case' + idx.zfill(4) +
                    '_slice' + str(c).zfill(3) + '.png'))

    for i, idx in enumerate(test_id):
        img_3d = read_nii_file(
            osp.join(dataset_path, 'img', 'img' + idx + '.nii.gz'))
        label_3d = read_nii_file(
            osp.join(dataset_path, 'label', 'label' + idx + '.nii.gz'))

        img_3d = np.clip(img_3d, -125, 275)
        img_3d = (img_3d + 125) / 400
        img_3d *= 255
        img_3d = np.transpose(img_3d, [2, 0, 1])
        img_3d = np.flip(img_3d, 2)

        label_3d = np.transpose(label_3d, [2, 0, 1])
        label_3d = np.flip(label_3d, 2)
        label_3d = label_mapping(label_3d)

        for c in range(img_3d.shape[0]):
            img = img_3d[c]
            label = label_3d[c]

            img = Image.fromarray(img).convert('RGB')
            label = Image.fromarray(label).convert('L')
            img.save(
                osp.join(
                    save_path, 'img_dir/val', 'case' + idx.zfill(4) +
                    '_slice' + str(c).zfill(3) + '.jpg'))
            label.save(
                osp.join(
                    save_path, 'ann_dir/val', 'case' + idx.zfill(4) +
                    '_slice' + str(c).zfill(3) + '.png'))


if __name__ == '__main__':
    main()
156
finetune/tools/dataset_converters/vaihingen.py
Normal file
@@ -0,0 +1,156 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import glob
import math
import os
import os.path as osp
import tempfile
import zipfile

import mmcv
import numpy as np
from mmengine.utils import ProgressBar, mkdir_or_exist


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert vaihingen dataset to mmsegmentation format')
    parser.add_argument('dataset_path', help='vaihingen folder path')
    parser.add_argument('--tmp_dir', help='path of the temporary directory')
    parser.add_argument('-o', '--out_dir', help='output path')
    parser.add_argument(
        '--clip_size',
        type=int,
        help='clipped size of image after preparation',
        default=512)
    parser.add_argument(
        '--stride_size',
        type=int,
        help='stride of clipping original images',
        default=256)
    args = parser.parse_args()
    return args


def clip_big_image(image_path, clip_save_dir, to_label=False):
    # Original images of the Vaihingen dataset are very large, so they are
    # clipped before training. Given a fixed clip size and stride size,
    # the number of clipped images along width and height is determined.
    # For example, for one 5120 x 5120 original image with clip size 512
    # and stride size 256, 20x20 = 400 images of size 512x512 would be
    # generated.
    image = mmcv.imread(image_path)

    h, w, c = image.shape
    cs = args.clip_size
    ss = args.stride_size

    num_rows = math.ceil((h - cs) / ss) if math.ceil(
        (h - cs) / ss) * ss + cs >= h else math.ceil((h - cs) / ss) + 1
    num_cols = math.ceil((w - cs) / ss) if math.ceil(
        (w - cs) / ss) * ss + cs >= w else math.ceil((w - cs) / ss) + 1

    x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1))
    xmin = x * cs
    ymin = y * cs

    xmin = xmin.ravel()
    ymin = ymin.ravel()
    xmin_offset = np.where(xmin + cs > w, w - xmin - cs, np.zeros_like(xmin))
    ymin_offset = np.where(ymin + cs > h, h - ymin - cs, np.zeros_like(ymin))
    boxes = np.stack([
        xmin + xmin_offset, ymin + ymin_offset,
        np.minimum(xmin + cs, w),
        np.minimum(ymin + cs, h)
    ],
                     axis=1)

    if to_label:
        color_map = np.array([[0, 0, 0], [255, 255, 255], [255, 0, 0],
                              [255, 255, 0], [0, 255, 0], [0, 255, 255],
                              [0, 0, 255]])
        flatten_v = np.matmul(
            image.reshape(-1, c),
            np.array([2, 3, 4]).reshape(3, 1))
        out = np.zeros_like(flatten_v)
        for idx, class_color in enumerate(color_map):
            value_idx = np.matmul(class_color,
                                  np.array([2, 3, 4]).reshape(3, 1))
            out[flatten_v == value_idx] = idx
        image = out.reshape(h, w)

    for box in boxes:
        start_x, start_y, end_x, end_y = box
        clipped_image = image[start_y:end_y,
                              start_x:end_x] if to_label else image[
                                  start_y:end_y, start_x:end_x, :]
        area_idx = osp.basename(image_path).split('_')[3].strip('.tif')
        mmcv.imwrite(
            clipped_image.astype(np.uint8),
            osp.join(clip_save_dir,
                     f'{area_idx}_{start_x}_{start_y}_{end_x}_{end_y}.png'))


def main():
    splits = {
        'train': [
            'area1', 'area11', 'area13', 'area15', 'area17', 'area21',
            'area23', 'area26', 'area28', 'area3', 'area30', 'area32',
            'area34', 'area37', 'area5', 'area7'
        ],
        'val': [
            'area6', 'area24', 'area35', 'area16', 'area14', 'area22',
            'area10', 'area4', 'area2', 'area20', 'area8', 'area31', 'area33',
            'area27', 'area38', 'area12', 'area29'
        ],
    }

    dataset_path = args.dataset_path
    if args.out_dir is None:
        out_dir = osp.join('data', 'vaihingen')
    else:
        out_dir = args.out_dir

    print('Making directories...')
    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))

    zipp_list = glob.glob(os.path.join(dataset_path, '*.zip'))
    print('Find the data', zipp_list)

    with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
        for zipp in zipp_list:
            zip_file = zipfile.ZipFile(zipp)
            zip_file.extractall(tmp_dir)
            src_path_list = glob.glob(os.path.join(tmp_dir, '*.tif'))
            if 'ISPRS_semantic_labeling_Vaihingen' in zipp:
                src_path_list = glob.glob(
                    os.path.join(os.path.join(tmp_dir, 'top'), '*.tif'))
            if 'ISPRS_semantic_labeling_Vaihingen_ground_truth_eroded_COMPLETE' in zipp:  # noqa
                src_path_list = glob.glob(os.path.join(tmp_dir, '*.tif'))
                # delete unused area9 ground truth (filter instead of
                # removing items while iterating)
                src_path_list = [
                    src_path for src_path in src_path_list
                    if 'area9' not in src_path
                ]
            prog_bar = ProgressBar(len(src_path_list))
            for i, src_path in enumerate(src_path_list):
                area_idx = osp.basename(src_path).split('_')[3].strip('.tif')
                data_type = 'train' if area_idx in splits['train'] else 'val'
                if 'noBoundary' in src_path:
                    dst_dir = osp.join(out_dir, 'ann_dir', data_type)
                    clip_big_image(src_path, dst_dir, to_label=True)
                else:
                    dst_dir = osp.join(out_dir, 'img_dir', data_type)
                    clip_big_image(src_path, dst_dir, to_label=False)
                prog_bar.update()

    print('Removing the temporary files...')

    print('Done!')


if __name__ == '__main__':
    args = parse_args()
    main()
92
finetune/tools/dataset_converters/voc_aug.py
Normal file
@@ -0,0 +1,92 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from functools import partial

import numpy as np
from mmengine.utils import mkdir_or_exist, scandir, track_parallel_progress
from PIL import Image
from scipy.io import loadmat

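# Size of the standard VOC2012 `trainaug` split: the original train list
# merged with the SBD (aug) annotations, minus the val images.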
AUG_LEN = 10582


def convert_mat(mat_file, in_dir, out_dir):
    data = loadmat(osp.join(in_dir, mat_file))
    mask = data['GTcls'][0]['Segmentation'][0].astype(np.uint8)
    seg_filename = osp.join(out_dir, mat_file.replace('.mat', '.png'))
    Image.fromarray(mask).save(seg_filename, 'PNG')


def generate_aug_list(merged_list, excluded_list):
    return list(set(merged_list) - set(excluded_list))


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert PASCAL VOC annotations to mmsegmentation format')
    parser.add_argument('devkit_path', help='pascal voc devkit path')
    parser.add_argument('aug_path', help='pascal voc aug path')
    parser.add_argument('-o', '--out_dir', help='output path')
    parser.add_argument(
        '--nproc', default=1, type=int, help='number of process')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    devkit_path = args.devkit_path
    aug_path = args.aug_path
    nproc = args.nproc
    if args.out_dir is None:
        out_dir = osp.join(devkit_path, 'VOC2012', 'SegmentationClassAug')
    else:
        out_dir = args.out_dir
    mkdir_or_exist(out_dir)
    in_dir = osp.join(aug_path, 'dataset', 'cls')

    track_parallel_progress(
        partial(convert_mat, in_dir=in_dir, out_dir=out_dir),
        list(scandir(in_dir, suffix='.mat')),
        nproc=nproc)

    full_aug_list = []
    with open(osp.join(aug_path, 'dataset', 'train.txt')) as f:
        full_aug_list += [line.strip() for line in f]
    with open(osp.join(aug_path, 'dataset', 'val.txt')) as f:
        full_aug_list += [line.strip() for line in f]

    with open(
            osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation',
                     'train.txt')) as f:
        ori_train_list = [line.strip() for line in f]
    with open(
            osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation',
                     'val.txt')) as f:
        val_list = [line.strip() for line in f]

    aug_train_list = generate_aug_list(ori_train_list + full_aug_list,
                                       val_list)
    assert len(aug_train_list) == AUG_LEN, 'len(aug_train_list) != {}'.format(
        AUG_LEN)

    with open(
            osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation',
                     'trainaug.txt'), 'w') as f:
        f.writelines(line + '\n' for line in aug_train_list)

    aug_list = generate_aug_list(full_aug_list, ori_train_list + val_list)
    assert len(aug_list) == AUG_LEN - len(
        ori_train_list), 'len(aug_list) != {}'.format(AUG_LEN -
                                                      len(ori_train_list))
    with open(
            osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', 'aug.txt'),
            'w') as f:
        f.writelines(line + '\n' for line in aug_list)

    print('Done!')


if __name__ == '__main__':
    main()
185
finetune/tools/deployment/pytorch2torchscript.py
Normal file
@@ -0,0 +1,185 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse

import numpy as np
import torch
import torch._C
import torch.serialization
from mmengine import Config
from mmengine.runner import load_checkpoint
from torch import nn

from mmseg.models import build_segmentor

torch.manual_seed(3)


def digit_version(version_str):
    digit_version = []
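    # Release candidates sort below the final release: e.g. '1.8.0rc1'
    # becomes [1, 8, -1, 1], while '1.8.0' becomes [1, 8, 0].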
    for x in version_str.split('.'):
        if x.isdigit():
            digit_version.append(int(x))
        elif x.find('rc') != -1:
            patch_version = x.split('rc')
            digit_version.append(int(patch_version[0]) - 1)
            digit_version.append(int(patch_version[1]))
    return digit_version


def check_torch_version():
    torch_minimum_version = '1.8.0'
    torch_version = digit_version(torch.__version__)

    assert (torch_version >= digit_version(torch_minimum_version)), \
        f'Torch=={torch.__version__} is not supported for converting to ' \
        f'torchscript. Please install pytorch>={torch_minimum_version}.'


def _convert_batchnorm(module):
    module_output = module
    if isinstance(module, torch.nn.SyncBatchNorm):
        module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
                                             module.momentum, module.affine,
                                             module.track_running_stats)
        if module.affine:
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
            # keep requires_grad unchanged
            module_output.weight.requires_grad = module.weight.requires_grad
            module_output.bias.requires_grad = module.bias.requires_grad
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        module_output.add_module(name, _convert_batchnorm(child))
    del module
    return module_output


def _demo_mm_inputs(input_shape, num_classes):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple):
            input batch dimensions
        num_classes (int):
            number of semantic classes
    """
    (N, C, H, W) = input_shape
    rng = np.random.RandomState(0)
    imgs = rng.rand(*input_shape)
    segs = rng.randint(
        low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)
    img_metas = [{
        'img_shape': (H, W, C),
        'ori_shape': (H, W, C),
        'pad_shape': (H, W, C),
        'filename': '<demo>.png',
        'scale_factor': 1.0,
        'flip': False,
    } for _ in range(N)]
    mm_inputs = {
        'imgs': torch.FloatTensor(imgs).requires_grad_(True),
        'img_metas': img_metas,
        'gt_semantic_seg': torch.LongTensor(segs)
    }
    return mm_inputs


def pytorch2libtorch(model,
                     input_shape,
                     show=False,
                     output_file='tmp.pt',
                     verify=False):
    """Export a PyTorch model to a TorchScript model and verify that the
    outputs are the same between PyTorch and TorchScript.

    Args:
        model (nn.Module): PyTorch model we want to export.
        input_shape (tuple): Use this input shape to construct
            the corresponding dummy input and execute the model.
        show (bool): Whether to print the computation graph. Default: False.
        output_file (string): The path to where we store the
            output TorchScript model. Default: `tmp.pt`.
        verify (bool): Whether to compare the outputs between
            PyTorch and TorchScript. Default: False.
    """
    if isinstance(model.decode_head, nn.ModuleList):
        num_classes = model.decode_head[-1].num_classes
    else:
        num_classes = model.decode_head.num_classes

    mm_inputs = _demo_mm_inputs(input_shape, num_classes)

    imgs = mm_inputs.pop('imgs')

    # replace the original forward with forward_dummy
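    # forward_dummy takes a plain image tensor and returns the raw seg
    # logits, giving torch.jit.trace a tensor-in/tensor-out callable.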
model.forward = model.forward_dummy
|
||||
model.eval()
|
||||
traced_model = torch.jit.trace(
|
||||
model,
|
||||
example_inputs=imgs,
|
||||
check_trace=verify,
|
||||
)
|
||||
|
||||
if show:
|
||||
print(traced_model.graph)
|
||||
|
||||
traced_model.save(output_file)
|
||||
print(f'Successfully exported TorchScript model: {output_file}')
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Convert MMSeg to TorchScript')
|
||||
parser.add_argument('config', help='test config file path')
|
||||
parser.add_argument('--checkpoint', help='checkpoint file', default=None)
|
||||
parser.add_argument(
|
||||
'--show', action='store_true', help='show TorchScript graph')
|
||||
parser.add_argument(
|
||||
'--verify', action='store_true', help='verify the TorchScript model')
|
||||
parser.add_argument('--output-file', type=str, default='tmp.pt')
|
||||
parser.add_argument(
|
||||
'--shape',
|
||||
type=int,
|
||||
nargs='+',
|
||||
default=[512, 512],
|
||||
help='input image size (height, width)')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
check_torch_version()
|
||||
|
||||
if len(args.shape) == 1:
|
||||
input_shape = (1, 3, args.shape[0], args.shape[0])
|
||||
elif len(args.shape) == 2:
|
||||
input_shape = (
|
||||
1,
|
||||
3,
|
||||
) + tuple(args.shape)
|
||||
else:
|
||||
raise ValueError('invalid input shape')
|
||||
|
||||
cfg = Config.fromfile(args.config)
|
||||
cfg.model.pretrained = None
|
||||
|
||||
# build the model and load checkpoint
|
||||
cfg.model.train_cfg = None
|
||||
segmentor = build_segmentor(
|
||||
cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
|
||||
# convert SyncBN to BN
|
||||
segmentor = _convert_batchnorm(segmentor)
|
||||
|
||||
if args.checkpoint:
|
||||
load_checkpoint(segmentor, args.checkpoint, map_location='cpu')
|
||||
|
||||
# convert the PyTorch model to LibTorch model
|
||||
pytorch2libtorch(
|
||||
segmentor,
|
||||
input_shape,
|
||||
show=args.show,
|
||||
output_file=args.output_file,
|
||||
verify=args.verify)
|
||||
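# A hypothetical invocation (the config and checkpoint paths below are
# placeholders, not files shipped with this commit):
#   python tools/pytorch2torchscript.py configs/some_config.py \
#       --checkpoint work_dirs/some_ckpt.pth --output-file model.pt --verify
# With --verify, torch.jit.trace re-runs the traced module (check_trace=True)
# and raises if the TorchScript outputs drift from the eager outputs.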
20
finetune/tools/dist_test.sh
Normal file
@@ -0,0 +1,20 @@
CONFIG=$1
CHECKPOINT=$2
GPUS=$3
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch \
    --nnodes=$NNODES \
    --node_rank=$NODE_RANK \
    --master_addr=$MASTER_ADDR \
    --nproc_per_node=$GPUS \
    --master_port=$PORT \
    $(dirname "$0")/test.py \
    $CONFIG \
    $CHECKPOINT \
    --launcher pytorch \
    ${@:4}
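# Example (hypothetical paths): evaluate a checkpoint on 8 GPUs; everything
# after the third argument is forwarded to tools/test.py:
#   bash tools/dist_test.sh configs/some_config.py work_dirs/ckpt.pth 8 --tta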
17
finetune/tools/dist_train.sh
Normal file
@@ -0,0 +1,17 @@
CONFIG=$1
GPUS=$2
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch \
    --nnodes=$NNODES \
    --node_rank=$NODE_RANK \
    --master_addr=$MASTER_ADDR \
    --nproc_per_node=$GPUS \
    --master_port=$PORT \
    $(dirname "$0")/train.py \
    $CONFIG \
    --launcher pytorch ${@:3}
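# Example (hypothetical config): single-node training on 4 GPUs; multi-node
# runs override NNODES, NODE_RANK and MASTER_ADDR via the environment:
#   bash tools/dist_train.sh configs/some_config.py 4 --amp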
73
finetune/tools/misc/browse_dataset.py
Normal file
@@ -0,0 +1,73 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp

from mmengine import Config, DictAction
from mmengine.registry import init_default_scope
from mmengine.utils import ProgressBar

from mmseg.registry import DATASETS, VISUALIZERS


def parse_args():
    parser = argparse.ArgumentParser(description='Browse a dataset')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--output-dir',
        default=None,
        type=str,
        help='If there is no display interface, you can save it')
    parser.add_argument('--not-show', default=False, action='store_true')
    parser.add_argument(
        '--show-interval',
        type=float,
        default=2,
        help='the interval of show (s)')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into the config file. If the value '
        'to be overwritten is a list, it should be like key="[a,b]" or '
        'key=a,b. It also allows nested list/tuple values, e.g. '
        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
        'and that no white space is allowed.')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # register all modules in mmseg into the registries
    init_default_scope('mmseg')

    dataset = DATASETS.build(cfg.train_dataloader.dataset)
    cfg.visualizer['save_dir'] = args.output_dir
    visualizer = VISUALIZERS.build(cfg.visualizer)
    visualizer.dataset_meta = dataset.METAINFO

    progress_bar = ProgressBar(len(dataset))
    for item in dataset:
        img = item['inputs'].permute(1, 2, 0).numpy()
        data_sample = item['data_samples'].numpy()
        img_path = osp.basename(item['data_samples'].img_path)

        img = img[..., [2, 1, 0]]  # bgr to rgb

        visualizer.add_datasample(
            osp.basename(img_path),
            img,
            data_sample,
            show=not args.not_show,
            wait_time=args.show_interval)

        progress_bar.update()


if __name__ == '__main__':
    main()
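# Example (hypothetical config): dump the augmented training samples to disk
# instead of opening a display window:
#   python tools/misc/browse_dataset.py configs/some_config.py \
#       --output-dir vis_out --not-show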
69
finetune/tools/misc/print_config.py
Normal file
@@ -0,0 +1,69 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import warnings

from mmengine import Config, DictAction

from mmseg.apis import init_model


def parse_args():
    parser = argparse.ArgumentParser(description='Print the whole config')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '--graph', action='store_true', help='print the model graph')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='--options is deprecated in favor of --cfg-options and it will '
        'not be supported in version v0.22.0. Override some settings in the '
        'used config, the key-value pair in xxx=yyy format will be merged '
        'into the config file. If the value to be overwritten is a list, it '
        'should be like key="[a,b]" or key=a,b. It also allows nested '
        'list/tuple values, e.g. key="[(a,b),(c,d)]". Note that the '
        'quotation marks are necessary and that no white space is allowed.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into the config file. If the value '
        'to be overwritten is a list, it should be like key="[a,b]" or '
        'key=a,b. It also allows nested list/tuple values, e.g. '
        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
        'and that no white space is allowed.')
    args = parser.parse_args()

    if args.options and args.cfg_options:
        raise ValueError(
            '--options and --cfg-options cannot be both '
            'specified, --options is deprecated in favor of --cfg-options. '
            '--options will not be supported in version v0.22.0.')
    if args.options:
        warnings.warn('--options is deprecated in favor of --cfg-options, '
                      '--options will not be supported in version v0.22.0.')
        args.cfg_options = args.options

    return args


def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    print(f'Config:\n{cfg.pretty_text}')
    # dump the config
    cfg.dump('example.py')
    # dump the model graph
    if args.graph:
        model = init_model(args.config, device='cpu')
        print(f'Model graph:\n{str(model)}')
        with open('example-graph.txt', 'w') as f:
            f.writelines(str(model))


if __name__ == '__main__':
    main()
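# Example (hypothetical config): print the merged config with an override;
# note the resolved config is also written to example.py as a side effect:
#   python tools/misc/print_config.py configs/some_config.py \
#       --cfg-options model.pretrained=None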
50
finetune/tools/misc/publish_model.py
Normal file
@@ -0,0 +1,50 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import subprocess
from hashlib import sha256

import torch

BLOCK_SIZE = 128 * 1024


def parse_args():
    parser = argparse.ArgumentParser(
        description='Process a checkpoint to be published')
    parser.add_argument('in_file', help='input checkpoint filename')
    parser.add_argument('out_file', help='output checkpoint filename')
    args = parser.parse_args()
    return args


def sha256sum(filename: str) -> str:
    """Compute the SHA256 message digest of a file."""
    hash_func = sha256()
    byte_array = bytearray(BLOCK_SIZE)
    memory_view = memoryview(byte_array)
    with open(filename, 'rb', buffering=0) as file:
        for block in iter(lambda: file.readinto(memory_view), 0):
            hash_func.update(memory_view[:block])
    return hash_func.hexdigest()


def process_checkpoint(in_file, out_file):
    checkpoint = torch.load(in_file, map_location='cpu')
    # remove the optimizer state for a smaller file size
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']
    # if it is necessary to remove some sensitive data in checkpoint['meta'],
    # add the code here.
    torch.save(checkpoint, out_file)
    # hash the published file (not the input) so the suffix verifies it
    sha = sha256sum(out_file)
    # str.rstrip strips a character set, not a suffix, so trim '.pth' exactly
    stem = out_file[:-4] if out_file.endswith('.pth') else out_file
    final_file = stem + f'-{sha[:8]}.pth'
    subprocess.run(['mv', out_file, final_file], check=True)


def main():
    args = parse_args()
    process_checkpoint(args.in_file, args.out_file)


if __name__ == '__main__':
    main()
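# Example (hypothetical filenames): strip the optimizer state and append the
# first 8 hex digits of the published file's SHA256 to its name:
#   python tools/misc/publish_model.py work_dirs/latest.pth seg_model.pth
# producing something like seg_model-0a1b2c3d.pth.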
56
finetune/tools/model_converters/beit2mmseg.py
Normal file
@@ -0,0 +1,56 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def convert_beit(ckpt):
    new_ckpt = OrderedDict()

    for k, v in ckpt.items():
        if k.startswith('patch_embed'):
            new_key = k.replace('patch_embed.proj', 'patch_embed.projection')
            new_ckpt[new_key] = v
        if k.startswith('blocks'):
            new_key = k.replace('blocks', 'layers')
            if 'norm' in new_key:
                new_key = new_key.replace('norm', 'ln')
            elif 'mlp.fc1' in new_key:
                new_key = new_key.replace('mlp.fc1', 'ffn.layers.0.0')
            elif 'mlp.fc2' in new_key:
                new_key = new_key.replace('mlp.fc2', 'ffn.layers.1')
            new_ckpt[new_key] = v
        else:
            new_key = k
            new_ckpt[new_key] = v

    return new_ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in official pretrained beit models to '
        'MMSegmentation style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    weight = convert_beit(state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)


if __name__ == '__main__':
    main()
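# A sketch of the renames this performs, for hypothetical BEiT keys:
#   patch_embed.proj.weight  -> patch_embed.projection.weight
#   blocks.0.mlp.fc1.weight  -> layers.0.ffn.layers.0.0.weight
#   blocks.0.norm1.weight    -> layers.0.ln1.weight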
163
finetune/tools/model_converters/clip2mmseg.py
Normal file
@@ -0,0 +1,163 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def convert_vitlayer(paras):
    new_para_name = ''
    if paras[0] == 'ln_1':
        new_para_name = '.'.join(['ln1'] + paras[1:])
    elif paras[0] == 'attn':
        new_para_name = '.'.join(['attn.attn'] + paras[1:])
    elif paras[0] == 'ln_2':
        new_para_name = '.'.join(['ln2'] + paras[1:])
    elif paras[0] == 'mlp':
        if paras[1] == 'c_fc':
            new_para_name = '.'.join(['ffn.layers.0.0'] + paras[-1:])
        else:
            new_para_name = '.'.join(['ffn.layers.1'] + paras[-1:])
    else:
        print(f'Wrong for {paras}')
    return new_para_name


def convert_translayer(paras):
    new_para_name = ''
    if paras[0] == 'attn':
        new_para_name = '.'.join(['attentions.0.attn'] + paras[1:])
    elif paras[0] == 'ln_1':
        new_para_name = '.'.join(['norms.0'] + paras[1:])
    elif paras[0] == 'ln_2':
        new_para_name = '.'.join(['norms.1'] + paras[1:])
    elif paras[0] == 'mlp':
        if paras[1] == 'c_fc':
            new_para_name = '.'.join(['ffns.0.layers.0.0'] + paras[2:])
        elif paras[1] == 'c_proj':
            new_para_name = '.'.join(['ffns.0.layers.1'] + paras[2:])
        else:
            print(f'Wrong for {paras}')
    else:
        print(f'Wrong for {paras}')
    return new_para_name


def convert_key_name(ckpt, visual_split):
    new_ckpt = OrderedDict()
    for k, v in ckpt.items():
        key_list = k.split('.')
        if key_list[0] == 'visual':
            new_transform_name = 'image_encoder'
            if key_list[1] == 'class_embedding':
                new_name = '.'.join([new_transform_name, 'cls_token'])
            elif key_list[1] == 'positional_embedding':
                new_name = '.'.join([new_transform_name, 'pos_embed'])
            elif key_list[1] == 'conv1':
                new_name = '.'.join([
                    new_transform_name, 'patch_embed.projection', key_list[2]
                ])
            elif key_list[1] == 'ln_pre':
                new_name = '.'.join(
                    [new_transform_name, key_list[1], key_list[2]])
            elif key_list[1] == 'transformer':
                new_layer_name = 'layers'
                layer_index = key_list[3]
                paras = key_list[4:]
                if int(layer_index) < visual_split:
                    new_para_name = convert_vitlayer(paras)
                    new_name = '.'.join([
                        new_transform_name, new_layer_name, layer_index,
                        new_para_name
                    ])
                else:
                    new_para_name = convert_translayer(paras)
                    new_transform_name = 'decode_head.rec_with_attnbias'
                    new_layer_name = 'layers'
                    layer_index = str(int(layer_index) - visual_split)
                    new_name = '.'.join([
                        new_transform_name, new_layer_name, layer_index,
                        new_para_name
                    ])
            elif key_list[1] == 'proj':
                new_name = 'decode_head.rec_with_attnbias.proj.weight'
            elif key_list[1] == 'ln_post':
                new_name = k.replace('visual', 'decode_head.rec_with_attnbias')
            else:
                print(f'pop parameter: {k}')
                continue
        else:
            text_encoder_name = 'text_encoder'
            if key_list[0] == 'transformer':
                layer_name = 'transformer'
                layer_index = key_list[2]
                paras = key_list[3:]
                new_para_name = convert_translayer(paras)
                new_name = '.'.join([
                    text_encoder_name, layer_name, layer_index, new_para_name
                ])
            elif key_list[0] in [
                    'positional_embedding', 'text_projection', 'bg_embed',
                    'attn_mask', 'logit_scale', 'token_embedding', 'ln_final'
            ]:
                new_name = 'text_encoder.' + k
            else:
                print(f'pop parameter: {k}')
                continue
        new_ckpt[new_name] = v

    return new_ckpt


def convert_tensor(ckpt):
    cls_token = ckpt['image_encoder.cls_token']
    new_cls_token = cls_token.unsqueeze(0).unsqueeze(0)
    ckpt['image_encoder.cls_token'] = new_cls_token
    pos_embed = ckpt['image_encoder.pos_embed']
    new_pos_embed = pos_embed.unsqueeze(0)
    ckpt['image_encoder.pos_embed'] = new_pos_embed
    proj_weight = ckpt['decode_head.rec_with_attnbias.proj.weight']
    new_proj_weight = proj_weight.transpose(1, 0)
    ckpt['decode_head.rec_with_attnbias.proj.weight'] = new_proj_weight
    return ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in pretrained CLIP models to '
        'MMSegmentation style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    if any([s in args.src for s in ['B-16', 'b16', 'base_patch16']]):
        visual_split = 9
    elif any([s in args.src for s in ['L-14', 'l14', 'large_patch14']]):
        visual_split = 18
    else:
        print('Make sure the clip model is ViT-B/16 or ViT-L/14!')
        visual_split = -1
    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    if isinstance(checkpoint, torch.jit.RecursiveScriptModule):
        state_dict = checkpoint.state_dict()
    else:
        if 'state_dict' in checkpoint:
            # timm checkpoint
            state_dict = checkpoint['state_dict']
        elif 'model' in checkpoint:
            # deit checkpoint
            state_dict = checkpoint['model']
        else:
            state_dict = checkpoint
    weight = convert_key_name(state_dict, visual_split)
    weight = convert_tensor(weight)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)


if __name__ == '__main__':
    main()
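# visual_split decides where the CLIP visual tower is cut: for a hypothetical
# ViT-B/16 checkpoint, visual.transformer.resblocks.{0..8} become
# image_encoder.layers.{0..8}, while resblocks.{9..11} are re-indexed from 0
# and moved into decode_head.rec_with_attnbias.layers.* for the
# recognition head.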
82
finetune/tools/model_converters/mit2mmseg.py
Normal file
@@ -0,0 +1,82 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def convert_mit(ckpt):
    new_ckpt = OrderedDict()
    # Process the concat between q linear weights and kv linear weights
    for k, v in ckpt.items():
        if k.startswith('head'):
            continue
        # patch embedding conversion
        elif k.startswith('patch_embed'):
            stage_i = int(k.split('.')[0].replace('patch_embed', ''))
            new_k = k.replace(f'patch_embed{stage_i}', f'layers.{stage_i-1}.0')
            new_v = v
            if 'proj.' in new_k:
                new_k = new_k.replace('proj.', 'projection.')
        # transformer encoder layer conversion
        elif k.startswith('block'):
            stage_i = int(k.split('.')[0].replace('block', ''))
            new_k = k.replace(f'block{stage_i}', f'layers.{stage_i-1}.1')
            new_v = v
            if 'attn.q.' in new_k:
                sub_item_k = k.replace('q.', 'kv.')
                new_k = new_k.replace('q.', 'attn.in_proj_')
                new_v = torch.cat([v, ckpt[sub_item_k]], dim=0)
            elif 'attn.kv.' in new_k:
                continue
            elif 'attn.proj.' in new_k:
                new_k = new_k.replace('proj.', 'attn.out_proj.')
            elif 'mlp.' in new_k:
                new_k = new_k.replace('mlp.', 'ffn.layers.')
                if 'fc1.weight' in new_k or 'fc2.weight' in new_k:
                    new_v = v.reshape((*v.shape, 1, 1))
                new_k = new_k.replace('fc1.', '0.')
                new_k = new_k.replace('dwconv.dwconv.', '1.')
                new_k = new_k.replace('fc2.', '4.')
            # 'attn.sr.' keys keep their name, so no branch is needed
        # norm layer conversion
        elif k.startswith('norm'):
            stage_i = int(k.split('.')[0].replace('norm', ''))
            new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i-1}.2')
            new_v = v
        else:
            new_k = k
            new_v = v
        new_ckpt[new_k] = new_v
    return new_ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in official pretrained segformer to '
        'MMSegmentation style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    weight = convert_mit(state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)


if __name__ == '__main__':
    main()
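# The q/kv concatenation mirrors how torch.nn.MultiheadAttention packs its
# input projection: for a hypothetical key block1.0.attn.q.weight of shape
# (C, C), the matching kv weight of shape (2C, C) is appended to form the
# (3C, C) in_proj_weight, and the standalone kv keys are then skipped.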
220
finetune/tools/model_converters/san2mmseg.py
Normal file
@@ -0,0 +1,220 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def convert_key_name(ckpt):
    new_ckpt = OrderedDict()

    for k, v in ckpt.items():
        key_list = k.split('.')
        if key_list[0] == 'clip_visual_extractor':
            new_transform_name = 'image_encoder'
            if key_list[1] == 'class_embedding':
                new_name = '.'.join([new_transform_name, 'cls_token'])
            elif key_list[1] == 'positional_embedding':
                new_name = '.'.join([new_transform_name, 'pos_embed'])
            elif key_list[1] == 'conv1':
                new_name = '.'.join([
                    new_transform_name, 'patch_embed.projection', key_list[2]
                ])
            elif key_list[1] == 'ln_pre':
                new_name = '.'.join(
                    [new_transform_name, key_list[1], key_list[2]])
            elif key_list[1] == 'resblocks':
                new_layer_name = 'layers'
                layer_index = key_list[2]
                paras = key_list[3:]
                if paras[0] == 'ln_1':
                    new_para_name = '.'.join(['ln1'] + key_list[4:])
                elif paras[0] == 'attn':
                    new_para_name = '.'.join(['attn.attn'] + key_list[4:])
                elif paras[0] == 'ln_2':
                    new_para_name = '.'.join(['ln2'] + key_list[4:])
                elif paras[0] == 'mlp':
                    if paras[1] == 'c_fc':
                        new_para_name = '.'.join(['ffn.layers.0.0'] +
                                                 key_list[-1:])
                    else:
                        new_para_name = '.'.join(['ffn.layers.1'] +
                                                 key_list[-1:])
                new_name = '.'.join([
                    new_transform_name, new_layer_name, layer_index,
                    new_para_name
                ])
        elif key_list[0] == 'side_adapter_network':
            decode_head_name = 'decode_head'
            module_name = 'side_adapter_network'
            if key_list[1] == 'vit_model':
                if key_list[2] == 'blocks':
                    layer_name = 'encode_layers'
                    layer_index = key_list[3]
                    paras = key_list[4:]
                    if paras[0] == 'norm1':
                        new_para_name = '.'.join(['ln1'] + key_list[5:])
                    elif paras[0] == 'attn':
                        new_para_name = '.'.join(key_list[4:])
                        new_para_name = new_para_name.replace(
                            'attn.qkv.', 'attn.attn.in_proj_')
                        new_para_name = new_para_name.replace(
                            'attn.proj', 'attn.attn.out_proj')
                    elif paras[0] == 'norm2':
                        new_para_name = '.'.join(['ln2'] + key_list[5:])
                    elif paras[0] == 'mlp':
                        new_para_name = '.'.join(['ffn'] + key_list[5:])
                        new_para_name = new_para_name.replace(
                            'fc1', 'layers.0.0')
                        new_para_name = new_para_name.replace(
                            'fc2', 'layers.1')
                    else:
                        print(f'Wrong for {k}')
                    new_name = '.'.join([
                        decode_head_name, module_name, layer_name, layer_index,
                        new_para_name
                    ])
                elif key_list[2] == 'pos_embed':
                    new_name = '.'.join(
                        [decode_head_name, module_name, 'pos_embed'])
                elif key_list[2] == 'patch_embed':
                    new_name = '.'.join([
                        decode_head_name, module_name, 'patch_embed',
                        'projection', key_list[4]
                    ])
                else:
                    print(f'Wrong for {k}')
            elif key_list[1] == 'query_embed' or key_list[
                    1] == 'query_pos_embed':
                new_name = '.'.join(
                    [decode_head_name, module_name, key_list[1]])
            elif key_list[1] == 'fusion_layers':
                layer_name = 'conv_clips'
                layer_index = key_list[2][-1]
                paras = '.'.join(key_list[3:])
                new_para_name = paras.replace('input_proj.0', '0')
                new_para_name = new_para_name.replace('input_proj.1', '1.conv')
                new_name = '.'.join([
                    decode_head_name, module_name, layer_name, layer_index,
                    new_para_name
                ])
            elif key_list[1] == 'mask_decoder':
                new_name = 'decode_head.' + k
            else:
                print(f'Wrong for {k}')
        elif key_list[0] == 'clip_rec_head':
            # defined locally so this branch does not depend on the
            # side_adapter_network branch having run first
            decode_head_name = 'decode_head'
            module_name = 'rec_with_attnbias'
            if key_list[1] == 'proj':
                new_name = '.'.join(
                    [decode_head_name, module_name, 'proj.weight'])
            elif key_list[1] == 'ln_post':
                new_name = '.'.join(
                    [decode_head_name, module_name, 'ln_post', key_list[2]])
            elif key_list[1] == 'resblocks':
                new_layer_name = 'layers'
                layer_index = key_list[2]
                paras = key_list[3:]
                if paras[0] == 'ln_1':
                    new_para_name = '.'.join(['norms.0'] + paras[1:])
                elif paras[0] == 'attn':
                    new_para_name = '.'.join(['attentions.0.attn'] + paras[1:])
                elif paras[0] == 'ln_2':
                    new_para_name = '.'.join(['norms.1'] + paras[1:])
                elif paras[0] == 'mlp':
                    if paras[1] == 'c_fc':
                        new_para_name = '.'.join(['ffns.0.layers.0.0'] +
                                                 paras[2:])
                    elif paras[1] == 'c_proj':
                        new_para_name = '.'.join(['ffns.0.layers.1'] +
                                                 paras[2:])
                    else:
                        print(f'Wrong for {k}')
                new_name = '.'.join([
                    decode_head_name, module_name, new_layer_name, layer_index,
                    new_para_name
                ])
            else:
                print(f'Wrong for {k}')
        elif key_list[0] == 'ov_classifier':
            text_encoder_name = 'text_encoder'
            if key_list[1] == 'transformer':
                layer_name = 'transformer'
                layer_index = key_list[3]
                paras = key_list[4:]
                if paras[0] == 'attn':
                    new_para_name = '.'.join(['attentions.0.attn'] + paras[1:])
                elif paras[0] == 'ln_1':
                    new_para_name = '.'.join(['norms.0'] + paras[1:])
                elif paras[0] == 'ln_2':
                    new_para_name = '.'.join(['norms.1'] + paras[1:])
                elif paras[0] == 'mlp':
                    if paras[1] == 'c_fc':
                        new_para_name = '.'.join(['ffns.0.layers.0.0'] +
                                                 paras[2:])
                    elif paras[1] == 'c_proj':
                        new_para_name = '.'.join(['ffns.0.layers.1'] +
                                                 paras[2:])
                    else:
                        print(f'Wrong for {k}')
                else:
                    print(f'Wrong for {k}')
                new_name = '.'.join([
                    text_encoder_name, layer_name, layer_index, new_para_name
                ])
            elif key_list[1] in [
                    'positional_embedding', 'text_projection', 'bg_embed',
                    'attn_mask', 'logit_scale', 'token_embedding', 'ln_final'
            ]:
                new_name = k.replace('ov_classifier', 'text_encoder')
            else:
                print(f'Wrong for {k}')
        elif key_list[0] == 'criterion':
            new_name = k
        else:
            print(f'Wrong for {k}')
        new_ckpt[new_name] = v
    return new_ckpt


def convert_tensor(ckpt):
    cls_token = ckpt['image_encoder.cls_token']
    new_cls_token = cls_token.unsqueeze(0).unsqueeze(0)
    ckpt['image_encoder.cls_token'] = new_cls_token
    pos_embed = ckpt['image_encoder.pos_embed']
    new_pos_embed = pos_embed.unsqueeze(0)
    ckpt['image_encoder.pos_embed'] = new_pos_embed
    proj_weight = ckpt['decode_head.rec_with_attnbias.proj.weight']
    new_proj_weight = proj_weight.transpose(1, 0)
    ckpt['decode_head.rec_with_attnbias.proj.weight'] = new_proj_weight
    return ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in pretrained SAN models to '
        'MMSegmentation style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    if 'state_dict' in checkpoint:
        # timm checkpoint
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        # deit checkpoint
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    weight = convert_key_name(state_dict)
    weight = convert_tensor(weight)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)


if __name__ == '__main__':
    main()
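# Unlike clip2mmseg.py there is no layer split here: the SAN checkpoint
# already stores the CLIP tower (clip_visual_extractor), the side adapter,
# and the recognition head (clip_rec_head) as separate top-level modules, so
# the conversion is a pure per-module rename.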
71
finetune/tools/model_converters/stdc2mmseg.py
Normal file
@@ -0,0 +1,71 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def convert_stdc(ckpt, stdc_type):
    new_state_dict = {}
    if stdc_type == 'STDC1':
        stage_lst = ['0', '1', '2.0', '2.1', '3.0', '3.1', '4.0', '4.1']
    else:
        stage_lst = [
            '0', '1', '2.0', '2.1', '2.2', '2.3', '3.0', '3.1', '3.2', '3.3',
            '3.4', '4.0', '4.1', '4.2'
        ]
    for k, v in ckpt.items():
        ori_k = k
        flag = False
        if 'cp.' in k:
            k = k.replace('cp.', '')
        if 'features.' in k:
            num_layer = int(k.split('.')[1])
            feature_key_lst = 'features.' + str(num_layer) + '.'
            stages_key_lst = 'stages.' + stage_lst[num_layer] + '.'
            k = k.replace(feature_key_lst, stages_key_lst)
            flag = True
        if 'conv_list' in k:
            k = k.replace('conv_list', 'layers')
            flag = True
        if 'avd_layer.' in k:
            if 'avd_layer.0' in k:
                k = k.replace('avd_layer.0', 'downsample.conv')
            elif 'avd_layer.1' in k:
                k = k.replace('avd_layer.1', 'downsample.bn')
            flag = True
        if flag:
            new_state_dict[k] = ckpt[ori_k]

    return new_state_dict


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in official pretrained STDC1/2 to '
        'MMSegmentation style.')
    parser.add_argument('src', help='src model path')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    parser.add_argument('type', help='model type: STDC1 or STDC2')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    assert args.type in ['STDC1',
                         'STDC2'], 'STDC type should be STDC1 or STDC2!'
    weight = convert_stdc(state_dict, args.type)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)


if __name__ == '__main__':
    main()
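# Example (hypothetical source checkpoint): convert an official STDC1
# backbone and save it where a config's `pretrained` field can point:
#   python tools/model_converters/stdc2mmseg.py stdc1_official.pth \
#       pretrain/stdc1.pth STDC1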
87
finetune/tools/model_converters/swin2mmseg.py
Normal file
@@ -0,0 +1,87 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def convert_swin(ckpt):
    new_ckpt = OrderedDict()

    def correct_unfold_reduction_order(x):
        out_channel, in_channel = x.shape
        x = x.reshape(out_channel, 4, in_channel // 4)
        x = x[:, [0, 2, 1, 3], :].transpose(1, 2).reshape(
            out_channel, in_channel)
        return x

    def correct_unfold_norm_order(x):
        in_channel = x.shape[0]
        x = x.reshape(4, in_channel // 4)
        x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
        return x

    for k, v in ckpt.items():
        if k.startswith('head'):
            continue
        elif k.startswith('layers'):
            new_v = v
            if 'attn.' in k:
                new_k = k.replace('attn.', 'attn.w_msa.')
            elif 'mlp.' in k:
                if 'mlp.fc1.' in k:
                    new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')
                elif 'mlp.fc2.' in k:
                    new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')
                else:
                    new_k = k.replace('mlp.', 'ffn.')
            elif 'downsample' in k:
                new_k = k
                if 'reduction.' in k:
                    new_v = correct_unfold_reduction_order(v)
                elif 'norm.' in k:
                    new_v = correct_unfold_norm_order(v)
            else:
                new_k = k
            new_k = new_k.replace('layers', 'stages', 1)
        elif k.startswith('patch_embed'):
            new_v = v
            if 'proj' in k:
                new_k = k.replace('proj', 'projection')
            else:
                new_k = k
        else:
            new_v = v
            new_k = k

        new_ckpt[new_k] = new_v

    return new_ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in official pretrained swin models to '
        'MMSegmentation style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    weight = convert_swin(state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)


if __name__ == '__main__':
    main()
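# Why the [0, 2, 1, 3] shuffle (a reading of the code, not a documented
# guarantee): the official Swin PatchMerging and the MMSeg one appear to
# gather the four pixels of each 2x2 window in different orders, so channel
# groups 1 and 2 of the reduction weight (and of its preceding norm) are
# swapped to keep the converted weights numerically equivalent.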
87
finetune/tools/model_converters/twins2mmseg.py
Normal file
@@ -0,0 +1,87 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def convert_twins(args, ckpt):

    new_ckpt = OrderedDict()

    for k, v in list(ckpt.items()):
        new_v = v
        if k.startswith('head'):
            continue
        elif k.startswith('patch_embeds'):
            if 'proj.' in k:
                new_k = k.replace('proj.', 'projection.')
            else:
                new_k = k
        elif k.startswith('blocks'):
            # Union
            if 'attn.q.' in k:
                new_k = k.replace('q.', 'attn.in_proj_')
                new_v = torch.cat([v, ckpt[k.replace('attn.q.', 'attn.kv.')]],
                                  dim=0)
            elif 'mlp.fc1' in k:
                new_k = k.replace('mlp.fc1', 'ffn.layers.0.0')
            elif 'mlp.fc2' in k:
                new_k = k.replace('mlp.fc2', 'ffn.layers.1')
            # Only pcpvt
            elif args.model == 'pcpvt':
                if 'attn.proj.' in k:
                    new_k = k.replace('proj.', 'attn.out_proj.')
                else:
                    new_k = k

            # Only svt
            else:
                if 'attn.proj.' in k:
                    k_lst = k.split('.')
                    if int(k_lst[2]) % 2 == 1:
                        new_k = k.replace('proj.', 'attn.out_proj.')
                    else:
                        new_k = k
                else:
                    new_k = k
            new_k = new_k.replace('blocks.', 'layers.')
        elif k.startswith('pos_block'):
            new_k = k.replace('pos_block', 'position_encodings')
            if 'proj.0.' in new_k:
                new_k = new_k.replace('proj.0.', 'proj.')
        else:
            new_k = k
        if 'attn.kv.' not in k:
            new_ckpt[new_k] = new_v
    return new_ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in official pretrained Twins (PCPVT or '
        'SVT) models to MMSegmentation style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    parser.add_argument('model', help='model: pcpvt or svt')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')

    if 'state_dict' in checkpoint:
        # timm checkpoint
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    weight = convert_twins(args, state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)


if __name__ == '__main__':
    main()
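# The parity test on the block index handles SVT's alternating layout: local
# and globally-sub-sampled attention blocks interleave, and only the
# odd-indexed (global) blocks get the MultiheadAttention-style out_proj
# rename. This reading of the layout is inferred from the code above, so
# treat it as an assumption rather than documented behavior.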
70
finetune/tools/model_converters/vit2mmseg.py
Normal file
@@ -0,0 +1,70 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def convert_vit(ckpt):

    new_ckpt = OrderedDict()

    for k, v in ckpt.items():
        if k.startswith('head'):
            continue
        if k.startswith('norm'):
            new_k = k.replace('norm.', 'ln1.')
        elif k.startswith('patch_embed'):
            if 'proj' in k:
                new_k = k.replace('proj', 'projection')
            else:
                new_k = k
        elif k.startswith('blocks'):
            if 'norm' in k:
                new_k = k.replace('norm', 'ln')
            elif 'mlp.fc1' in k:
                new_k = k.replace('mlp.fc1', 'ffn.layers.0.0')
            elif 'mlp.fc2' in k:
                new_k = k.replace('mlp.fc2', 'ffn.layers.1')
            elif 'attn.qkv' in k:
                new_k = k.replace('attn.qkv.', 'attn.attn.in_proj_')
            elif 'attn.proj' in k:
                new_k = k.replace('attn.proj', 'attn.attn.out_proj')
            else:
                new_k = k
            new_k = new_k.replace('blocks.', 'layers.')
        else:
            new_k = k
        new_ckpt[new_k] = v

    return new_ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in timm pretrained vit models to '
        'MMSegmentation style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    if 'state_dict' in checkpoint:
        # timm checkpoint
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        # deit checkpoint
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    weight = convert_vit(state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)


if __name__ == '__main__':
    main()
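# Sketch of the resulting key layout for hypothetical timm ViT keys:
#   blocks.3.attn.qkv.weight -> layers.3.attn.attn.in_proj_weight
#   blocks.3.norm1.weight    -> layers.3.ln1.weight
#   norm.weight              -> ln1.weight  (the final encoder norm)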
123
finetune/tools/model_converters/vitjax2mmseg.py
Normal file
@@ -0,0 +1,123 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp

import mmengine
import numpy as np
import torch


def vit_jax_to_torch(jax_weights, num_layer=12):
    torch_weights = dict()

    # patch embedding
    conv_filters = jax_weights['embedding/kernel']
    conv_filters = conv_filters.permute(3, 2, 0, 1)
    torch_weights['patch_embed.projection.weight'] = conv_filters
    torch_weights['patch_embed.projection.bias'] = jax_weights[
        'embedding/bias']

    # pos embedding
    torch_weights['pos_embed'] = jax_weights[
        'Transformer/posembed_input/pos_embedding']

    # cls token
    torch_weights['cls_token'] = jax_weights['cls']

    # head
    torch_weights['ln1.weight'] = jax_weights['Transformer/encoder_norm/scale']
    torch_weights['ln1.bias'] = jax_weights['Transformer/encoder_norm/bias']

    # transformer blocks
    for i in range(num_layer):
        jax_block = f'Transformer/encoderblock_{i}'
        torch_block = f'layers.{i}'

        # attention norm
        torch_weights[f'{torch_block}.ln1.weight'] = jax_weights[
            f'{jax_block}/LayerNorm_0/scale']
        torch_weights[f'{torch_block}.ln1.bias'] = jax_weights[
            f'{jax_block}/LayerNorm_0/bias']

        # attention
        query_weight = jax_weights[
            f'{jax_block}/MultiHeadDotProductAttention_1/query/kernel']
        query_bias = jax_weights[
            f'{jax_block}/MultiHeadDotProductAttention_1/query/bias']
        key_weight = jax_weights[
            f'{jax_block}/MultiHeadDotProductAttention_1/key/kernel']
        key_bias = jax_weights[
            f'{jax_block}/MultiHeadDotProductAttention_1/key/bias']
        value_weight = jax_weights[
            f'{jax_block}/MultiHeadDotProductAttention_1/value/kernel']
        value_bias = jax_weights[
            f'{jax_block}/MultiHeadDotProductAttention_1/value/bias']

        qkv_weight = torch.from_numpy(
            np.stack((query_weight, key_weight, value_weight), 1))
        qkv_weight = torch.flatten(qkv_weight, start_dim=1)
        qkv_bias = torch.from_numpy(
            np.stack((query_bias, key_bias, value_bias), 0))
        qkv_bias = torch.flatten(qkv_bias, start_dim=0)

        torch_weights[f'{torch_block}.attn.attn.in_proj_weight'] = qkv_weight
        torch_weights[f'{torch_block}.attn.attn.in_proj_bias'] = qkv_bias
        to_out_weight = jax_weights[
            f'{jax_block}/MultiHeadDotProductAttention_1/out/kernel']
        to_out_weight = torch.flatten(to_out_weight, start_dim=0, end_dim=1)
        torch_weights[
            f'{torch_block}.attn.attn.out_proj.weight'] = to_out_weight
        torch_weights[f'{torch_block}.attn.attn.out_proj.bias'] = jax_weights[
            f'{jax_block}/MultiHeadDotProductAttention_1/out/bias']

        # mlp norm
        torch_weights[f'{torch_block}.ln2.weight'] = jax_weights[
            f'{jax_block}/LayerNorm_2/scale']
        torch_weights[f'{torch_block}.ln2.bias'] = jax_weights[
            f'{jax_block}/LayerNorm_2/bias']

        # mlp
        torch_weights[f'{torch_block}.ffn.layers.0.0.weight'] = jax_weights[
            f'{jax_block}/MlpBlock_3/Dense_0/kernel']
        torch_weights[f'{torch_block}.ffn.layers.0.0.bias'] = jax_weights[
            f'{jax_block}/MlpBlock_3/Dense_0/bias']
        torch_weights[f'{torch_block}.ffn.layers.1.weight'] = jax_weights[
            f'{jax_block}/MlpBlock_3/Dense_1/kernel']
        torch_weights[f'{torch_block}.ffn.layers.1.bias'] = jax_weights[
            f'{jax_block}/MlpBlock_3/Dense_1/bias']

    # transpose weights
    for k, v in torch_weights.items():
        if 'weight' in k and 'patch_embed' not in k and 'ln' not in k:
            v = v.permute(1, 0)
        torch_weights[k] = v

    return torch_weights


def main():
    # stole refactoring code from Robin Strudel, thanks
    parser = argparse.ArgumentParser(
        description='Convert keys from jax official pretrained vit models to '
        'MMSegmentation style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    jax_weights = np.load(args.src)
    jax_weights_tensor = {}
    for key in jax_weights.files:
        value = torch.from_numpy(jax_weights[key])
        jax_weights_tensor[key] = value
    if 'L_16-i21k' in args.src:
        num_layer = 24
    else:
        num_layer = 12
    torch_weights = vit_jax_to_torch(jax_weights_tensor, num_layer)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(torch_weights, args.dst)


if __name__ == '__main__':
    main()
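# Shape bookkeeping (assuming the usual JAX ViT layout): each q/k/v kernel is
# (embed_dim, num_heads, head_dim); stacking them on axis 1 and flattening
# from dim 1 gives (embed_dim, 3*embed_dim), and the final transpose loop
# turns that into the (3*embed_dim, embed_dim) in_proj_weight that
# torch.nn.MultiheadAttention expects.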
24
finetune/tools/slurm_test.sh
Normal file
@@ -0,0 +1,24 @@
#!/usr/bin/env bash

set -x

PARTITION=$1
JOB_NAME=$2
CONFIG=$3
CHECKPOINT=$4
GPUS=${GPUS:-4}
GPUS_PER_NODE=${GPUS_PER_NODE:-4}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
PY_ARGS=${@:5}
SRUN_ARGS=${SRUN_ARGS:-""}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
srun -p ${PARTITION} \
    --job-name=${JOB_NAME} \
    --gres=gpu:${GPUS_PER_NODE} \
    --ntasks=${GPUS} \
    --ntasks-per-node=${GPUS_PER_NODE} \
    --cpus-per-task=${CPUS_PER_TASK} \
    --kill-on-bad-exit=1 \
    ${SRUN_ARGS} \
    python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS}
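# Example (hypothetical partition and paths): evaluate on 8 GPUs spread over
# two nodes:
#   GPUS=8 GPUS_PER_NODE=4 bash tools/slurm_test.sh my_partition test_job \
#       configs/some_config.py work_dirs/ckpt.pth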
23
finetune/tools/slurm_train.sh
Normal file
@@ -0,0 +1,23 @@
#!/usr/bin/env bash

set -x

PARTITION=$1
JOB_NAME=$2
CONFIG=$3
GPUS=${GPUS:-4}
GPUS_PER_NODE=${GPUS_PER_NODE:-4}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
SRUN_ARGS=${SRUN_ARGS:-""}
PY_ARGS=${@:4}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
srun -p ${PARTITION} \
    --job-name=${JOB_NAME} \
    --gres=gpu:${GPUS_PER_NODE} \
    --ntasks=${GPUS} \
    --ntasks-per-node=${GPUS_PER_NODE} \
    --cpus-per-task=${CPUS_PER_TASK} \
    --kill-on-bad-exit=1 \
    ${SRUN_ARGS} \
    python -u tools/train.py ${CONFIG} --launcher="slurm" ${PY_ARGS}
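# Example (hypothetical partition/config), forwarding extra srun flags via
# SRUN_ARGS:
#   SRUN_ARGS="--exclude=faulty-node" bash tools/slurm_train.sh my_partition \
#       train_job configs/some_config.py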
123
finetune/tools/test.py
Normal file
@@ -0,0 +1,123 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp

from mmengine.config import Config, DictAction
from mmengine.runner import Runner


# TODO: support fuse_conv_bn, visualization, and format_only
def parse_args():
    parser = argparse.ArgumentParser(
        description='MMSeg test (and eval) a model')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--work-dir',
        help=('if specified, the evaluation metric results will be dumped '
              'into the directory as json'))
    parser.add_argument(
        '--out',
        type=str,
        help='The directory to save output prediction for offline evaluation')
    parser.add_argument(
        '--show', action='store_true', help='show prediction results')
    parser.add_argument(
        '--show-dir',
        help='directory where painted images will be saved. '
        'If specified, it will be automatically saved '
        'to the work_dir/timestamp/show_dir')
    parser.add_argument(
        '--wait-time', type=float, default=2, help='the interval of show (s)')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into the config file. If the value '
        'to be overwritten is a list, it should be like key="[a,b]" or '
        'key=a,b. It also allows nested list/tuple values, e.g. '
        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
        'and that no white space is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument(
        '--tta', action='store_true', help='Test time augmentation')
    # When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
    # will pass the `--local-rank` parameter to `tools/train.py` instead
    # of `--local_rank`.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args


def trigger_visualization_hook(cfg, args):
    default_hooks = cfg.default_hooks
    if 'visualization' in default_hooks:
        visualization_hook = default_hooks['visualization']
        # Turn on visualization
        visualization_hook['draw'] = True
        if args.show:
            visualization_hook['show'] = True
            visualization_hook['wait_time'] = args.wait_time
        if args.show_dir:
            visualizer = cfg.visualizer
            visualizer['save_dir'] = args.show_dir
    else:
        raise RuntimeError(
            'VisualizationHook must be included in default_hooks. '
            'Refer to usage '
            '"visualization=dict(type=\'VisualizationHook\')"')

    return cfg


def main():
    args = parse_args()

    # load config
    cfg = Config.fromfile(args.config)
    cfg.launcher = args.launcher
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])

    cfg.load_from = args.checkpoint

    if args.show or args.show_dir:
        cfg = trigger_visualization_hook(cfg, args)

    if args.tta:
        cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline
        cfg.tta_model.module = cfg.model
        cfg.model = cfg.tta_model

    # add output_dir in metric
    if args.out is not None:
        cfg.test_evaluator['output_dir'] = args.out
        cfg.test_evaluator['keep_results'] = True

    # build the runner from config
    runner = Runner.from_cfg(cfg)

    # start testing
    runner.test()


if __name__ == '__main__':
    main()
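# Example (hypothetical paths): evaluate with test-time augmentation and keep
# per-image predictions for offline evaluation:
#   python tools/test.py configs/some_config.py work_dirs/ckpt.pth \
#       --tta --out preds/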
112
finetune/tools/torchserve/mmseg2torchserve.py
Normal file
@@ -0,0 +1,112 @@
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser, Namespace
from pathlib import Path
from tempfile import TemporaryDirectory

from mmengine import Config
from mmengine.utils import mkdir_or_exist

try:
    from model_archiver.model_packaging import package_model
    from model_archiver.model_packaging_utils import ModelExportUtils
except ImportError:
    package_model = None


def mmseg2torchserve(
    config_file: str,
    checkpoint_file: str,
    output_folder: str,
    model_name: str,
    model_version: str = '1.0',
    force: bool = False,
):
    """Converts an mmsegmentation model (config + checkpoint) to a TorchServe
    `.mar` archive.

    Args:
        config_file:
            In MMSegmentation config format.
            The contents vary for each task repository.
        checkpoint_file:
            In MMSegmentation checkpoint format.
            The contents vary for each task repository.
        output_folder:
            Folder where `{model_name}.mar` will be created.
            The file created will be in TorchServe archive format.
        model_name:
            If not None, used for naming the `{model_name}.mar` file
            that will be created under `output_folder`.
            If None, `{Path(checkpoint_file).stem}` will be used.
        model_version:
            Model's version.
        force:
            If True, an existing `{model_name}.mar` file under
            `output_folder` will be overwritten.
    """
    mkdir_or_exist(output_folder)

    config = Config.fromfile(config_file)

    with TemporaryDirectory() as tmpdir:
        config.dump(f'{tmpdir}/config.py')

        args = Namespace(
            **{
                'model_file': f'{tmpdir}/config.py',
                'serialized_file': checkpoint_file,
                'handler': f'{Path(__file__).parent}/mmseg_handler.py',
                'model_name': model_name or Path(checkpoint_file).stem,
                'version': model_version,
                'export_path': output_folder,
                'force': force,
                'requirements_file': None,
                'extra_files': None,
                'runtime': 'python',
                'archive_format': 'default'
            })
        manifest = ModelExportUtils.generate_manifest_json(args)
        package_model(args, manifest)


def parse_args():
    parser = ArgumentParser(
        description='Convert mmseg models to TorchServe `.mar` format.')
    parser.add_argument('config', type=str, help='config file path')
    parser.add_argument('checkpoint', type=str, help='checkpoint file path')
    parser.add_argument(
        '--output-folder',
        type=str,
        required=True,
        help='Folder where `{model_name}.mar` will be created.')
    parser.add_argument(
        '--model-name',
        type=str,
        default=None,
        help='If not None, used for naming the `{model_name}.mar` '
        'file that will be created under `output_folder`. '
        'If None, `{Path(checkpoint_file).stem}` will be used.')
    parser.add_argument(
        '--model-version',
        type=str,
        default='1.0',
        help='Number used for versioning.')
    parser.add_argument(
        '-f',
        '--force',
        action='store_true',
        help='overwrite the existing `{model_name}.mar`')
    args = parser.parse_args()

    return args


if __name__ == '__main__':
    args = parse_args()

    if package_model is None:
        raise ImportError('`torch-model-archiver` is required. '
                          'Try: pip install torch-model-archiver')

    mmseg2torchserve(args.config, args.checkpoint, args.output_folder,
                     args.model_name, args.model_version, args.force)
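# Example (hypothetical paths): package a model, then serve the archive:
#   python tools/torchserve/mmseg2torchserve.py configs/some_config.py \
#       work_dirs/ckpt.pth --output-folder model_store --model-name seg_model
#   torchserve --start --model-store model_store --models seg_model.mar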
56
finetune/tools/torchserve/mmseg_handler.py
Normal file
@@ -0,0 +1,56 @@
# Copyright (c) OpenMMLab. All rights reserved.
import base64
import os

import cv2
import mmcv
import torch
from mmengine.model.utils import revert_sync_batchnorm
from ts.torch_handler.base_handler import BaseHandler

from mmseg.apis import inference_model, init_model


class MMsegHandler(BaseHandler):

    def initialize(self, context):
        properties = context.system_properties
        self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(
            self.map_location + ':' + str(properties.get('gpu_id'))
            if torch.cuda.is_available() else self.map_location)
        self.manifest = context.manifest

        model_dir = properties.get('model_dir')
        serialized_file = self.manifest['model']['serializedFile']
        checkpoint = os.path.join(model_dir, serialized_file)
        self.config_file = os.path.join(model_dir, 'config.py')

        self.model = init_model(self.config_file, checkpoint, self.device)
        self.model = revert_sync_batchnorm(self.model)
        self.initialized = True

    def preprocess(self, data):
        images = []

        for row in data:
            image = row.get('data') or row.get('body')
            if isinstance(image, str):
                image = base64.b64decode(image)
            image = mmcv.imfrombytes(image)
            images.append(image)

        return images

    def inference(self, data, *args, **kwargs):
        results = [inference_model(self.model, img) for img in data]
        return results

    def postprocess(self, data):
        output = []

        for image_result in data:
            _, buffer = cv2.imencode('.png', image_result[0].astype('uint8'))
            content = buffer.tobytes()
            output.append(content)
        return output
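# TorchServe drives the handler as preprocess -> inference -> postprocess:
# request bodies (raw or base64-encoded image bytes) are decoded with mmcv,
# run through inference_model, and returned as PNG-encoded segmentation maps.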
58
finetune/tools/torchserve/test_torchserve.py
Normal file
@@ -0,0 +1,58 @@
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser
from io import BytesIO

import matplotlib.pyplot as plt
import mmcv
import requests

from mmseg.apis import inference_model, init_model


def parse_args():
    parser = ArgumentParser(
        description='Compare the results of torchserve and pytorch, '
        'and visualize them.')
    parser.add_argument('img', help='Image file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument('model_name', help='The model name in the server')
    parser.add_argument(
        '--inference-addr',
        default='127.0.0.1:8080',
        help='Address and port of the inference server')
    parser.add_argument(
        '--result-image',
        type=str,
        default=None,
        help='save the server output in result-image')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')

    args = parser.parse_args()
    return args


def main(args):
    # Query the TorchServe endpoint and visualize its prediction.
    url = 'http://' + args.inference_addr + '/predictions/' + args.model_name
    with open(args.img, 'rb') as image:
        tmp_res = requests.post(url, image)
    content = tmp_res.content
    if args.result_image:
        with open(args.result_image, 'wb') as out_image:
            out_image.write(content)
        plt.imshow(mmcv.imread(args.result_image, 'grayscale'))
        plt.show()
    else:
        plt.imshow(plt.imread(BytesIO(content)))
        plt.show()
    # Run the same image through the local PyTorch model for comparison.
    model = init_model(args.config, args.checkpoint, args.device)
    image = mmcv.imread(args.img)
    result = inference_model(model, image)
    plt.imshow(result[0])
    plt.show()


if __name__ == '__main__':
    args = parse_args()
    main(args)
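An example invocation of the comparison script above (paths and the model name are placeholders):

# python finetune/tools/torchserve/test_torchserve.py demo.png ${CONFIG} \
#     ${CHECKPOINT} ${MODEL_NAME} --inference-addr 127.0.0.1:8080 \
#     --result-image result.png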
104
finetune/tools/train.py
Normal file
@@ -0,0 +1,104 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import logging
import os
import os.path as osp

from mmengine.config import Config, DictAction
from mmengine.logging import print_log
from mmengine.runner import Runner

from mmseg.registry import RUNNERS


def parse_args():
    parser = argparse.ArgumentParser(description='Train a segmentor')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume',
        action='store_true',
        default=False,
        help='resume from the latest checkpoint in the work_dir automatically')
    parser.add_argument(
        '--amp',
        action='store_true',
        default=False,
        help='enable automatic-mixed-precision training')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into the config file. If the value '
        'to be overwritten is a list, it should be like key="[a,b]" or '
        'key=a,b. It also allows nested list/tuple values, e.g. '
        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
        'and that no white space is allowed.')
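    # Example of the xxx=yyy syntax (keys are illustrative and depend on the
    # chosen config):
    #   --cfg-options train_dataloader.batch_size=2 \
    #       optim_wrapper.optimizer.lr=1e-4
    #   --cfg-options model.backbone.depths="[2,2,18,2]"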
parser.add_argument(
|
||||
'--launcher',
|
||||
choices=['none', 'pytorch', 'slurm', 'mpi'],
|
||||
default='none',
|
||||
help='job launcher')
|
||||
# When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
|
||||
# will pass the `--local-rank` parameter to `tools/train.py` instead
|
||||
# of `--local_rank`.
|
||||
parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
if 'LOCAL_RANK' not in os.environ:
|
||||
os.environ['LOCAL_RANK'] = str(args.local_rank)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# load config
|
||||
cfg = Config.fromfile(args.config)
|
||||
cfg.launcher = args.launcher
|
||||
if args.cfg_options is not None:
|
||||
cfg.merge_from_dict(args.cfg_options)
|
||||
|
||||
# work_dir is determined in this priority: CLI > segment in file > filename
|
||||
if args.work_dir is not None:
|
||||
# update configs according to CLI args if args.work_dir is not None
|
||||
cfg.work_dir = args.work_dir
|
||||
elif cfg.get('work_dir', None) is None:
|
||||
# use config filename as default work_dir if cfg.work_dir is None
|
||||
cfg.work_dir = osp.join('./work_dirs',
|
||||
osp.splitext(osp.basename(args.config))[0])
|
||||
|
||||
# enable automatic-mixed-precision training
|
||||
if args.amp is True:
|
||||
optim_wrapper = cfg.optim_wrapper.type
|
||||
if optim_wrapper == 'AmpOptimWrapper':
|
||||
print_log(
|
||||
'AMP training is already enabled in your config.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
else:
|
||||
assert optim_wrapper == 'OptimWrapper', (
|
||||
'`--amp` is only supported when the optimizer wrapper type is '
|
||||
f'`OptimWrapper` but got {optim_wrapper}.')
|
||||
cfg.optim_wrapper.type = 'AmpOptimWrapper'
|
||||
cfg.optim_wrapper.loss_scale = 'dynamic'
|
||||
|
||||
# resume training
|
||||
cfg.resume = args.resume
|
||||
|
||||
# build the runner from config
|
||||
if 'runner_type' not in cfg:
|
||||
# build the default runner
|
||||
runner = Runner.from_cfg(cfg)
|
||||
else:
|
||||
# build customized runner from the registry
|
||||
# if 'runner_type' is set in the cfg
|
||||
runner = RUNNERS.build(cfg)
|
||||
|
||||
# start training
|
||||
runner.train()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
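A sketch of how this entry point is typically launched (config path and GPU count are placeholders):

# Single GPU:
#   python finetune/tools/train.py ${CONFIG} --work-dir ./work_dirs/demo --amp
# Multiple GPUs via torchrun, which sets LOCAL_RANK and passes --local-rank:
#   torchrun --nproc_per_node=4 finetune/tools/train.py ${CONFIG} \
#       --launcher pytorch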