init
This commit is contained in:
208
lib/evaluation/segm_eval_base.py
Normal file
208
lib/evaluation/segm_eval_base.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import os
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from skimage.transform import resize
|
||||
|
||||
from skimage import io
|
||||
from multiprocessing import Pool
|
||||
from functools import partial
|
||||
import logging
|
||||
import datetime
|
||||
|
||||
def get_confusion_matrix(pred, gt, num_class):
|
||||
assert pred.shape == gt.shape, f"pred.shape: {pred.shape} != gt.shape: {gt.shape}"
|
||||
mask = (gt >= 0) & (gt < num_class) # 去掉为0的背景类别
|
||||
label = num_class * gt[mask] + pred[mask]
|
||||
count = np.bincount(label, minlength=num_class**2)
|
||||
confusion_matrix = count.reshape(num_class, num_class)
|
||||
return confusion_matrix
|
||||
|
||||
def get_miou(confusion_matrix):
|
||||
diagonal_elements = np.diag(confusion_matrix)
|
||||
column_sums = np.sum(confusion_matrix, axis=0)
|
||||
row_sums = np.sum(confusion_matrix, axis=1)
|
||||
ious = diagonal_elements/(column_sums + row_sums - diagonal_elements)
|
||||
m_iou = np.nanmean(ious)
|
||||
return m_iou
|
||||
|
||||
def get_mprecison(confusion_matrix):
|
||||
diagonal_elements = np.diag(confusion_matrix)
|
||||
column_sums = np.sum(confusion_matrix, axis=0)
|
||||
precisions = diagonal_elements / (column_sums + 1e-06)
|
||||
m_precision = np.nanmean(precisions)
|
||||
return m_precision
|
||||
|
||||
def get_mrecall(confusion_matrix):
|
||||
diagonal_elements = np.diag(confusion_matrix)
|
||||
row_sums = np.sum(confusion_matrix, axis=1)
|
||||
recalls= diagonal_elements / (row_sums + 1e-06)
|
||||
m_recall = np.nanmean(recalls)
|
||||
return m_recall
|
||||
|
||||
def get_macc(confusion_matrix):
|
||||
'''
|
||||
acc = tp/tp+fn 就是recall
|
||||
'''
|
||||
m_recall = get_mrecall(confusion_matrix)
|
||||
return m_recall
|
||||
|
||||
def get_per_class_iou(confusion_matrix):
|
||||
intersection = np.diag(confusion_matrix)
|
||||
union = np.sum(confusion_matrix, axis=0) + np.sum(confusion_matrix, axis=1) - intersection
|
||||
iou = intersection / (union.astype(np.float32) + 1e-6)
|
||||
return iou
|
||||
|
||||
def get_per_class_acc(confusion_matrix):
|
||||
total_acc = np.diag(confusion_matrix) / (np.sum(confusion_matrix, axis=1).astype(np.float32) + 1e-6)
|
||||
return total_acc
|
||||
|
||||
|
||||
def post_process_segm_output(segm, colors, dist_type='abs'):
|
||||
"""
|
||||
Post-processing to turn output segm image to class index map using NumPy
|
||||
Args:
|
||||
segm: (H, W, 3)
|
||||
Returns:
|
||||
class_map: (H, W)
|
||||
"""
|
||||
palette = np.array(colors)
|
||||
segm = segm.astype(np.float32) # (h, w, 3)
|
||||
h, w, k = segm.shape[0], segm.shape[1], palette.shape[0]
|
||||
if dist_type == 'abs':
|
||||
dist = np.abs(segm.reshape(h, w, 1, 3) - palette.reshape(1, 1, k, 3)) # (h, w, k)
|
||||
elif dist_type == 'square':
|
||||
dist = np.power(segm.reshape(h, w, 1, 3) - palette.reshape(1, 1, k, 3), 2) # (h, w, k)
|
||||
elif dist_type == 'mean':
|
||||
dist_abs = np.abs(segm.reshape(h, w, 1, 3) - palette.reshape(1, 1, k, 3)) # (h, w, k)
|
||||
dist_square = np.power(segm.reshape(h, w, 1, 3) - palette.reshape(1, 1, k, 3), 2) # (h, w, k)
|
||||
dist = (dist_abs + dist_square) / 2.
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
dist = np.sum(dist, axis=-1)
|
||||
pred = np.argmin(dist, axis=-1).astype(np.int)
|
||||
return pred
|
||||
|
||||
def get_args_parser():
|
||||
parser = argparse.ArgumentParser('semantic segmentation evaluation', add_help=False)
|
||||
parser.add_argument('--pred_dir', type=str, help='dir to pred', required=True)
|
||||
parser.add_argument('--gt_dir', type=str, help='dir to gt', required=True)
|
||||
parser.add_argument('--gt_list_path', type=str, help='dir to gt_list_path', required=True)
|
||||
parser.add_argument('--gt_suffix', type=str, help='suffix to gt', required=True)
|
||||
parser.add_argument('--dataset_name', type=str, help='dataset name', required=True)
|
||||
parser.add_argument('--model_name', type=str, help='model name', required=True)
|
||||
parser.add_argument('--dist_type', type=str, help='dist type',
|
||||
default='abs', choices=['abs', 'square', 'mean'])
|
||||
return parser.parse_args()
|
||||
|
||||
def process_file(file_dict, pred_dir, gt_dir, args, num_class):
|
||||
filename = file_dict['file_name']
|
||||
file_cls = file_dict['file_cls']
|
||||
|
||||
gt = io.imread(os.path.join(gt_dir, filename))
|
||||
gt_index = gt.copy()
|
||||
gt_index[gt_index != file_cls] = 0
|
||||
gt_index[gt_index == file_cls] = 1
|
||||
|
||||
try:
|
||||
pred = io.imread(os.path.join(pred_dir, filename.replace('.png', f'-{file_cls}.png')))
|
||||
pred = resize(pred, gt.shape[-2:], anti_aliasing=False, mode='reflect', order=0)
|
||||
|
||||
if len(pred.shape) == 3:
|
||||
pred_index = pred[:,:,0].copy()
|
||||
else:
|
||||
pred_index = pred.copy()
|
||||
pred_index[pred_index<=127] = 0
|
||||
pred_index[pred_index>127] = 1
|
||||
except:
|
||||
logging.info(filename.replace('.png', f'_{file_cls}.png'), 'not found!')
|
||||
pred_index = gt_index.copy()
|
||||
|
||||
pred_index = pred_index.flatten()
|
||||
gt_index = gt_index.flatten()
|
||||
confusion_matrix = get_confusion_matrix(pred_index, gt_index, num_class)
|
||||
return file_cls, confusion_matrix
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = get_args_parser()
|
||||
dataset_name = args.dataset_name
|
||||
pred_dir = args.pred_dir
|
||||
gt_dir = args.gt_dir
|
||||
gt_list_path = args.gt_list_path
|
||||
dist_type = args.dist_type
|
||||
model_name = args.model_name
|
||||
|
||||
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
os.makedirs('logs/eval', exist_ok=True)
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler(f'logs/eval/eval_{model_name}_{dataset_name}_{current_time}.log'),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
|
||||
output_folder = os.path.join(pred_dir, f'eval_{dataset_name}')
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
|
||||
|
||||
num_class = 2
|
||||
|
||||
with open(gt_list_path, 'r') as f:
|
||||
file_cls_list = f.readlines()
|
||||
|
||||
file_list = []
|
||||
for i in file_cls_list:
|
||||
i = i.strip()
|
||||
file_name = i[:-3]
|
||||
file_cls = i[-2:]
|
||||
file_list.append({'file_name': file_name, 'file_cls': int(file_cls)})
|
||||
|
||||
all_pred_labels = []
|
||||
all_gt_labels = []
|
||||
|
||||
process_file_partial = partial(process_file, pred_dir=pred_dir,gt_dir=gt_dir, args=args, num_class=num_class)
|
||||
|
||||
pool = Pool()
|
||||
|
||||
outputs = pool.map(process_file_partial, file_list)
|
||||
|
||||
pool.close()
|
||||
pool.join()
|
||||
logging.info(f'len outputs: {len(outputs)}')
|
||||
confusion_matrix_dict = {}
|
||||
for cls, confusion_matrix in outputs:
|
||||
if cls in confusion_matrix_dict.keys():
|
||||
confusion_matrix_dict[cls] += confusion_matrix
|
||||
else:
|
||||
confusion_matrix_dict[cls] = confusion_matrix
|
||||
|
||||
class_list = []
|
||||
iou_list = []
|
||||
acc_list = []
|
||||
for cls, confusion_matrix in confusion_matrix_dict.items():
|
||||
ious = get_per_class_iou(confusion_matrix)
|
||||
accs = get_per_class_acc(confusion_matrix)
|
||||
logging.info(f'cls: {cls}, ious: {ious}, accs: {accs}')
|
||||
class_list.append(cls)
|
||||
iou_list.append(ious[1])
|
||||
acc_list.append(accs[1])
|
||||
|
||||
miou = np.mean(iou_list)
|
||||
macc = np.mean(acc_list)
|
||||
|
||||
df_metrics = pd.DataFrame({
|
||||
'Class': class_list + ['Mean'],
|
||||
'IoU': iou_list + [miou],
|
||||
'Accuracy': acc_list + [macc],
|
||||
})
|
||||
pd.set_option('display.float_format', '{:.4f}%'.format)
|
||||
logging.info(df_metrics)
|
||||
pd.reset_option('display.float_format')
|
||||
df_metrics.to_csv(os.path.join(output_folder, 'eval.csv'), index=False, float_format='%.4f')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user