Files
SkySensePlusPlus/lib/models/metrics/sem_metrics.py
esenke 01adcfdf60 init
2025-12-08 22:16:31 +08:00

93 lines
3.2 KiB
Python

# coding: utf-8
# Copyright (c) Ant Group. All rights reserved.
import torch
from torch.distributed import all_reduce, ReduceOp
from antmmf.common.registry import registry
from antmmf.modules.metrics.base_metric import BaseMetric
@registry.register_metric("sem_metric")
class SemMetric(BaseMetric):
    """Segmentation evaluation metric that reports a running mean as ``'mae'``.

    Each call to :meth:`collect` adds the scalar returned by :meth:`calculate`
    to a running total and increments a batch counter; :meth:`summarize`
    reports ``total / number_of_batches`` under the key ``'mae'``.

    Args:
        name (str): Name of the metric.
        **kwargs: Extra configuration such as ``eval_type``, ``result_field``,
            ``target_field``, ``ignore_index`` or ``num_cls``. These are
            accepted but currently ignored by this class (presumably so that
            configs written for other segmentation metrics still load --
            TODO confirm).
    """

    def __init__(self, name="dummy_metric", **kwargs):
        super().__init__(name)
        self.reset()

    def calculate(self, sample_list, model_output, *args, **kwargs):
        """Compute the metric value for a single batch.

        NOTE(review): this is a placeholder -- it ignores its inputs and always
        returns zero. Presumably meant to be overridden or filled in; TODO
        confirm.

        Args:
            sample_list (Sample_List): data containing ground-truth
                segmentation maps.
            model_output (dict): data containing predicted segmentation maps.

        Returns:
            torch.Tensor: 0-dim float tensor holding the batch value.
        """
        return torch.tensor(0).float()

    def reset(self):
        """Clear the accumulated state before a new evaluation run."""
        # Sum of per-batch values returned by calculate().
        self.total_mask_mae = 0
        # Number of collect() calls (batches) seen so far.
        self.total_num = torch.tensor(0)

    def collect(self, sample_list, model_output, *args, **kwargs):
        """Accumulate the metric for one batch.

        Args:
            sample_list (Sample_List): data containing ground-truth
                segmentation maps.
            model_output (dict): dict returned by the model.

        Returns:
            None. The batch value is added to the running total.
        """
        batch_mask_mae = self.calculate(sample_list, model_output, *args, **kwargs)
        self.total_mask_mae += batch_mask_mae
        self.total_num += 1

    def format(self, *args):
        """Format evaluated metrics for reporting.

        Args:
            *args: ``args[0]`` is the mean value (tensor or number).

        Returns:
            dict: ``{'mae': float}``.
        """
        mae = args[0]
        # float() handles one-element tensors and plain Python numbers alike.
        return {'mae': float(mae)}

    def summarize(self, *args, **kwargs):
        """Compute the overall metric from the accumulated state.

        Returns:
            dict: ``{'mae': total / number_of_batches}``. The value is NaN if
            no batch was collected (0 / 0).
        """
        mae = self.total_mask_mae / self.total_num
        return self.format(mae)

    def all_reduce(self):
        """Sum the accumulators across all distributed ranks, in place.

        After this call, ``total_mask_mae`` and ``total_num`` hold the global
        sums as CPU tensors.
        """
        # Keep CUDA placement when it is available (e.g. the NCCL backend);
        # otherwise fall back to CPU instead of crashing on ``.cuda()``.
        if torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")
        # as_tensor handles the Python-int value left by reset() (collect()
        # never called on this rank) and moves each accumulator to one device
        # so torch.stack never mixes CPU and CUDA tensors.
        totals = torch.stack([
            torch.as_tensor(self.total_mask_mae, device=device),
            torch.as_tensor(self.total_num, device=device),
        ])
        # This resolves to the module-level torch.distributed.all_reduce import;
        # the method name does not shadow it inside the method body.
        all_reduce(totals, op=ReduceOp.SUM)
        self.total_mask_mae = totals[0].cpu()
        self.total_num = totals[1].cpu()