Files
SkySensePlusPlus/lib/predictors/flood3i_1shot.py
esenke 01adcfdf60 init
2025-12-08 22:16:31 +08:00

245 lines
8.5 KiB
Python

import os
import glob
import numpy as np
import yaml
import argparse
import oss2
import torch
from PIL import Image
from tqdm import tqdm
import concurrent.futures
from torchvision.transforms import functional as F
import random
from antmmf.common.registry import registry
from antmmf.common.report import Report, default_result_formater
from antmmf.structures import Sample, SampleList
from antmmf.predictors.base_predictor import BasePredictor
from antmmf.utils.timer import Timer
from antmmf.predictors.build import build_predictor
from antmmf.common.task_loader import build_collate_fn
from antmmf.datasets.samplers import SequentialSampler
from antmmf.common.build import build_config
from lib.utils.checkpoint import SegCheckpoint
from lib.datasets.loader.few_shot_flood3i_loader import FewShotFloodLoader
def seed_everything(seed=0):
    """Seed every RNG in use (python, numpy, torch) for reproducibility.

    Also forces deterministic cuDNN behavior, trading kernel-selection
    speed for run-to-run stability of CUDA convolutions.

    Args:
        seed (int): seed value applied to all generators.
    """
    # Make CUDA convolutions deterministic (disable autotuned kernels).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Seed the three independent RNG sources.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # NOTE(review): hash randomization of the *current* interpreter was
    # fixed at startup; this assignment only influences child processes
    # (e.g. DataLoader workers) — confirm that is the intent.
    os.environ['PYTHONHASHSEED'] = str(seed)
@registry.register_predictor("OneshotPredictor")
class OneshotPredictor(BasePredictor, FewShotFloodLoader):
    """1-shot segmentation predictor: runs a forward pass under bfloat16
    autocast and wraps model output in an antmmf ``Report``.
    """

    def __init__(self, config):
        # NOTE(review): neither BasePredictor.__init__ nor
        # FewShotFloodLoader.__init__ is invoked here — presumably the
        # framework's ``load()`` entry point completes setup; confirm.
        self.config = config
        self.predictor_parameters = self.config.predictor_parameters

    def _predict(self, sample_list):
        """Move ``sample_list`` to the model device and run one forward pass.

        Returns:
            (Report, SampleList): the prediction report and the
            device-resident sample list.
        """
        # The original ``torch.no_grad()`` context was deliberately
        # commented out (gradients kept); the dead ``if True:`` scaffold
        # it left behind has been removed.
        sample_list = sample_list.to(self.device)
        report = self._forward_pass(sample_list)
        return report, sample_list

    def _forward_pass(self, samplelist):
        """Run the model under bfloat16 CUDA autocast and build a Report."""
        autocast_dtype = torch.bfloat16
        with torch.cuda.amp.autocast(enabled=True, dtype=autocast_dtype):
            model_output = self.model(samplelist)
            report = Report(samplelist, model_output)
        return report

    def predict(self, data=None):
        """End-to-end prediction: build a sample, infer, format the result.

        Args:
            data: raw request payload; falls back to ``dummy_request()``.

        Returns:
            (dict, SampleList): formatted result and the sample list used.
        """
        if data is None:
            data = self.dummy_request()
        sample = self._build_sample(data)
        if not isinstance(sample, Sample):
            raise Exception(
                f"Method _build_sample is expected to return a instance of antmmf.structures.sample.Sample, "
                f"but got type {type(sample)} instead.")
        result, sample_list = self._predict(SampleList([sample]))
        np_result = default_result_formater(result)
        result = self.format_result(np_result)
        # Fix: message previously contained a stray literal "f" from a
        # malformed f-string ("got f<class ...>").
        assert isinstance(
            result, dict
        ), f"Result should be instance of Dict, but got {type(result)} instead"
        return result, sample_list

    def load_checkpoint(self):
        """Load model weights from ``predictor_parameters.model_dir``."""
        self.resume_file = self.config.predictor_parameters.model_dir
        self.checkpoint = SegCheckpoint(self, load_only=True)
        self.checkpoint.load_model_weights(self.resume_file, force=True)

    def covert_speedup_op(self):
        """Optionally swap model ops for optimized ones.

        (Name kept as-is — likely a typo for "convert" — because callers
        invoke it by this name.)
        """
        if self.config.predictor_parameters.replace_speedup_op:
            from lib.utils.optim_utils import replace_speedup_op
            self.model = replace_speedup_op(self.model)
def save_image(output_path, image_np_array):
    """Persist a numpy image array to ``output_path`` via PIL."""
    Image.fromarray(image_np_array).save(output_path)
def build_predictor_from_args(args, *rest, **kwargs):
    """Build a predictor from parsed CLI args and attach the args to it.

    Args:
        args: argparse.Namespace carrying ``config``, ``config_override``
            and ``opts`` overrides.

    Returns:
        The constructed predictor, with ``args`` stored on it.
    """
    merged_config = build_config(
        args.config,
        config_override=args.config_override,
        opts_override=args.opts,
        specific_override=args,
    )
    predictor = build_predictor(merged_config)
    predictor.args = args
    return predictor
def build_online_predictor(model_dir=None, config_yaml=None):
    """Create a predictor from a model directory and/or an explicit config.

    At least one of ``model_dir`` / ``config_yaml`` must be given; when
    ``config_yaml`` is omitted a ``config.yaml`` is expected inside
    ``model_dir``.
    """
    assert model_dir or config_yaml
    from antmmf.utils.flags import flags
    # Fall back to `<model_dir>/config.yaml` when no explicit config given.
    if config_yaml:
        config_path = config_yaml
    else:
        config_path = os.path.join(model_dir, "config.yaml")
    cli_args = ["--config", config_path]
    if model_dir is not None:
        cli_args.extend(["predictor_parameters.model_dir", model_dir])
    args = flags.get_parser().parse_args(cli_args)
    return build_predictor_from_args(args)
def profile(profiler, text):
    """Print the time elapsed since the profiler last reset, then reset it.

    Args:
        profiler: object exposing ``get_time_since_start()`` and ``reset()``
            (e.g. antmmf's Timer).
        text (str): label printed before the elapsed time.
    """
    elapsed = profiler.get_time_since_start()
    print(f'{text}: {elapsed}')
    profiler.reset()
def cvt_colors(img_2d, idx_2_color_rgb):
    """Convert a 2-D label map into an RGB image using a palette.

    Args:
        img_2d: (H, W) integer label array.
        idx_2_color_rgb: mapping of label index -> (r, g, b) tuple.

    Returns:
        (H, W, 3) uint8 RGB image; labels absent from the palette stay black.
    """
    height, width = img_2d.shape[0], img_2d.shape[1]
    rgb = np.zeros((height, width, 3), dtype=np.uint8)
    for label, triplet in idx_2_color_rgb.items():
        rgb[img_2d == label] = triplet
    return rgb
def process_results(preds, targets, input_imgs, img_names, save_dir, save_dir_vis, idx_2_color):
    """Colorize predictions and save per-image results plus a visualization.

    For each item in the batch: the bottom half of the argmax prediction is
    saved to ``save_dir``; a side-by-side panel (denormalized input | full
    prediction | target) is saved to ``save_dir_vis``.

    Args:
        preds: (B, C, H, W) float tensor of logits (CPU).
        targets: (B, H, W) tensor of label indices (CPU).
        input_imgs: (B, 3, H, W) ImageNet-normalized images (CPU).
        img_names: per-item output file stems.
        save_dir / save_dir_vis: existing output directories.
        idx_2_color: label index -> packed 24-bit integer color (0xRRGGBB).
    """
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    # Unpack the 24-bit packed colors into (r, g, b) tuples.
    palette = {}
    for label, packed in idx_2_color.items():
        r, rem = divmod(packed, 256 * 256)
        g, b = divmod(rem, 256)
        palette[label] = (r, g, b)
    for i in range(preds.size(0)):
        full_pred = preds[i].argmax(0)  # (H, W) label map
        # Only the bottom half of the prediction is the saved result;
        # presumably the top half corresponds to the support image — confirm.
        bottom_pred = full_pred[full_pred.shape[0] // 2:, :]
        bottom_rgb = cvt_colors(bottom_pred.numpy().astype(np.uint8), palette)
        full_rgb = cvt_colors(full_pred.numpy().astype(np.uint8), palette)
        target_rgb = cvt_colors(targets[i].numpy().astype(np.uint8), palette)
        # Undo ImageNet normalization and lay out CHW -> HWC for saving.
        img = torch.einsum('chw->hwc', input_imgs[i])
        img = torch.clip((img * std + mean) * 255, 0, 255)
        img = img.numpy().astype(np.uint8)
        panel = np.concatenate((img, full_rgb, target_rgb), axis=1)
        save_image(os.path.join(save_dir, f'{img_names[i]}.png'), bottom_rgb)
        save_image(os.path.join(save_dir_vis, f'{img_names[i]}.png'), panel)
def test(args):
    """Run 1-shot inference over the few-shot flood test split and save results.

    Builds the predictor from ``args.model_path``/``args.config``, iterates
    the test loader, and offloads result saving to a thread pool so disk I/O
    overlaps with GPU inference.
    """
    model_path = args.model_path
    config_path = args.config
    global_seed = args.seed
    predictor = build_online_predictor(model_path, config_path)
    # Seed after predictor construction so dataset/loader RNG below is fixed.
    seed_everything(global_seed)
    dataset = FewShotFloodLoader(
        "test", predictor.config.task_attributes.segmentation.dataset_attributes.few_shot_flood_segmentation)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=8,
        shuffle=False,
        sampler=SequentialSampler(dataset),
        collate_fn=build_collate_fn(dataset),
        num_workers=16,
        pin_memory=True,
        drop_last=False,
    )
    print(len(loader))
    predictor.load(with_ckpt=True)
    predictor.covert_speedup_op()
    # NOTE(review): --save_dir is optional with no default; if omitted,
    # os.path.exists(None) below raises TypeError — confirm intended usage.
    save_dir = args.save_dir
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    save_dir_vis = os.path.join(args.save_dir, 'vis_full')
    if not os.path.exists(save_dir_vis):
        os.makedirs(save_dir_vis)
    # Saving is pushed to worker threads; the executor's context exit joins
    # all pending save tasks before the model is deleted below.
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
        profiler2 = Timer()
        profiler2.reset()
        for sample_batched in tqdm(loader):
            profile(profiler2, "Build sample time")
            result, sample_list = predictor._predict(SampleList(sample_batched))
            profile(profiler2, "Infer time")
            # Move everything needed for saving to CPU float32 before handing
            # it to the thread pool, detaching from the autograd graph.
            preds = result["logits_hr"].to(torch.float32).detach().cpu()
            targets = result['mapped_targets'].to(torch.float32).detach().cpu()
            idx_2_color = result['idx_2_color']
            input_imgs = sample_list['hr_img'].to(torch.float32).detach().cpu()
            img_names = sample_list["img_name"]
            executor.submit(process_results, preds, targets, input_imgs, img_names, save_dir, save_dir_vis, idx_2_color)
            profile(profiler2, "Save results time")
    # Best-effort teardown to release GPU memory held by the model.
    try:
        del predictor.model
    except Exception as e:
        print('delete model error: ', e)
def parse_args():
    """Parse command-line arguments for the 1-shot predictor.

    Returns:
        argparse.Namespace with ``model_path``, ``seed``, ``config`` and
        ``save_dir``.
    """
    desc = '1-shot predictor'
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument('--model_path',
                        required=True,
                        type=str,
                        help='model directory')
    parser.add_argument('--seed',
                        default=0,
                        type=int,
                        help='seed')
    parser.add_argument('--config',
                        required=True,
                        type=str,
                        help='config path')
    # Fix: was required=False with no default, yet test() dereferences
    # save_dir unconditionally (os.path.exists(None) -> TypeError). Require
    # it so argparse reports a clear error instead of a deep crash.
    parser.add_argument('--save_dir',
                        required=True,
                        type=str,
                        help='save directory')
    return parser.parse_args()
if __name__ == "__main__":
    # CLI entry point: parse arguments, then run inference end to end.
    test(parse_args())