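"""Evaluate binary semantic segmentation predictions against ground-truth masks.

For each (image, class) pair listed in gt_list_path, the script binarizes the
GT mask, loads the matching prediction, accumulates per-class confusion
matrices in parallel, and reports per-class and mean IoU / accuracy.
"""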
import argparse
import datetime
import logging
import os
from functools import partial
from multiprocessing import Pool

import numpy as np
import pandas as pd
from skimage import io
from skimage.transform import resize

def get_confusion_matrix(pred, gt, num_class):
    """Build a (num_class, num_class) confusion matrix; rows are GT, columns are predictions."""
    assert pred.shape == gt.shape, f"pred.shape: {pred.shape} != gt.shape: {gt.shape}"
    mask = (gt >= 0) & (gt < num_class)  # keep only pixels whose GT label is a valid class index
    label = num_class * gt[mask] + pred[mask]
    count = np.bincount(label, minlength=num_class**2)
    confusion_matrix = count.reshape(num_class, num_class)
    return confusion_matrix

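# For the binary case used below (num_class == 2), the matrix layout is
#   [[TN, FP],
#    [FN, TP]]
# with rows indexed by ground truth and columns by prediction.
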
def get_miou(confusion_matrix):
    """Mean IoU over classes; absent classes give 0/0 = NaN, which nanmean skips."""
    diagonal_elements = np.diag(confusion_matrix)
    column_sums = np.sum(confusion_matrix, axis=0)
    row_sums = np.sum(confusion_matrix, axis=1)
    ious = diagonal_elements / (column_sums + row_sums - diagonal_elements)
    m_iou = np.nanmean(ious)
    return m_iou

def get_mprecision(confusion_matrix):
    """Mean precision over classes: TP / (TP + FP), i.e. diagonal over column sums."""
    diagonal_elements = np.diag(confusion_matrix)
    column_sums = np.sum(confusion_matrix, axis=0)
    precisions = diagonal_elements / (column_sums + 1e-06)
    m_precision = np.nanmean(precisions)
    return m_precision

def get_mrecall(confusion_matrix):
    """Mean recall over classes: TP / (TP + FN), i.e. diagonal over row sums."""
    diagonal_elements = np.diag(confusion_matrix)
    row_sums = np.sum(confusion_matrix, axis=1)
    recalls = diagonal_elements / (row_sums + 1e-06)
    m_recall = np.nanmean(recalls)
    return m_recall

def get_macc(confusion_matrix):
    """Mean accuracy. Here acc = TP / (TP + FN), which is exactly recall."""
    m_recall = get_mrecall(confusion_matrix)
    return m_recall

def get_per_class_iou(confusion_matrix):
    """Per-class IoU: intersection (diagonal) over union (row sum + column sum - diagonal)."""
    intersection = np.diag(confusion_matrix)
    union = np.sum(confusion_matrix, axis=0) + np.sum(confusion_matrix, axis=1) - intersection
    iou = intersection / (union.astype(np.float32) + 1e-6)
    return iou

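# Worked micro-example (illustrative numbers): for
#   confusion_matrix = np.array([[8, 2],
#                                [1, 9]])
# class 1 has TP=9, FP=2, FN=1, so IoU_1 = 9 / (9 + 2 + 1) = 0.75,
# class 0 has IoU_0 = 8 / (8 + 1 + 2) ~= 0.727, and mIoU ~= 0.739.
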
def get_per_class_acc(confusion_matrix):
    """Per-class accuracy: diagonal over row sums (GT pixel count per class)."""
    total_acc = np.diag(confusion_matrix) / (np.sum(confusion_matrix, axis=1).astype(np.float32) + 1e-6)
    return total_acc

def post_process_segm_output(segm, colors, dist_type='abs'):
    """
    Post-process an output RGB segmentation image into a class-index map using NumPy.

    Args:
        segm: (H, W, 3) RGB image.
        colors: list of per-class RGB triples (the palette).

    Returns:
        class_map: (H, W) array of class indices.
    """
    palette = np.array(colors)
    segm = segm.astype(np.float32)  # (h, w, 3)
    h, w, k = segm.shape[0], segm.shape[1], palette.shape[0]
    if dist_type == 'abs':
        dist = np.abs(segm.reshape(h, w, 1, 3) - palette.reshape(1, 1, k, 3))  # (h, w, k, 3)
    elif dist_type == 'square':
        dist = np.power(segm.reshape(h, w, 1, 3) - palette.reshape(1, 1, k, 3), 2)  # (h, w, k, 3)
    elif dist_type == 'mean':
        dist_abs = np.abs(segm.reshape(h, w, 1, 3) - palette.reshape(1, 1, k, 3))  # (h, w, k, 3)
        dist_square = np.power(segm.reshape(h, w, 1, 3) - palette.reshape(1, 1, k, 3), 2)  # (h, w, k, 3)
        dist = (dist_abs + dist_square) / 2.
    else:
        raise NotImplementedError(f'unknown dist_type: {dist_type}')

    dist = np.sum(dist, axis=-1)  # (h, w, k): distance to each palette color
    pred = np.argmin(dist, axis=-1).astype(np.int64)  # np.int was removed in NumPy >= 1.24
    return pred

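# Minimal usage sketch for post_process_segm_output (the palette below is an
# illustrative assumption, not taken from this repo):
#   colors = [[0, 0, 0], [255, 255, 255]]  # class 0 = background, class 1 = foreground
#   class_map = post_process_segm_output(rgb_pred, colors, dist_type='abs')
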
def get_args_parser():
    parser = argparse.ArgumentParser('semantic segmentation evaluation', add_help=False)
    parser.add_argument('--pred_dir', type=str, help='dir to pred', required=True)
    parser.add_argument('--gt_dir', type=str, help='dir to gt', required=True)
    parser.add_argument('--gt_list_path', type=str, help='path to gt list file', required=True)
    parser.add_argument('--gt_suffix', type=str, help='suffix to gt', required=True)
    parser.add_argument('--dataset_name', type=str, help='dataset name', required=True)
    parser.add_argument('--model_name', type=str, help='model name', required=True)
    parser.add_argument('--dist_type', type=str, help='dist type',
                        default='abs', choices=['abs', 'square', 'mean'])
    return parser.parse_args()

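# Example invocation (the script name and all paths/names are illustrative assumptions):
#   python eval_seg.py \
#       --pred_dir outputs/preds --gt_dir data/annotations \
#       --gt_list_path data/gt_list.txt --gt_suffix .png \
#       --dataset_name my_dataset --model_name my_model --dist_type abs
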
def process_file(file_dict, pred_dir, gt_dir, args, num_class):
    filename = file_dict['file_name']
    file_cls = file_dict['file_cls']

    # Binarize the GT mask: 1 where the pixel belongs to the queried class, 0 elsewhere.
    gt = io.imread(os.path.join(gt_dir, filename))
    gt_index = gt.copy()
    gt_index[gt_index != file_cls] = 0
    gt_index[gt_index == file_cls] = 1

    pred_name = filename.replace('.png', f'-{file_cls}.png')
    try:
        pred = io.imread(os.path.join(pred_dir, pred_name))
        # preserve_range=True keeps the original 0-255 scale; by default skimage's
        # resize rescales to [0, 1], which would break the 127 threshold below.
        pred = resize(pred, gt.shape[-2:], anti_aliasing=False, mode='reflect',
                      order=0, preserve_range=True)

        if len(pred.shape) == 3:
            pred_index = pred[:, :, 0].copy()
        else:
            pred_index = pred.copy()
        pred_index[pred_index <= 127] = 0
        pred_index[pred_index > 127] = 1
        # bincount in get_confusion_matrix needs integer labels; resize returns floats
        pred_index = pred_index.astype(np.int64)
    except Exception:
        logging.info(f'{pred_name} not found!')
        # fall back to the GT mask, which scores this image as a perfect prediction
        pred_index = gt_index.copy()

    pred_index = pred_index.flatten()
    gt_index = gt_index.flatten()
    confusion_matrix = get_confusion_matrix(pred_index, gt_index, num_class)
    return file_cls, confusion_matrix

if __name__ == '__main__':
    args = get_args_parser()
    dataset_name = args.dataset_name
    pred_dir = args.pred_dir
    gt_dir = args.gt_dir
    gt_list_path = args.gt_list_path
    dist_type = args.dist_type
    model_name = args.model_name

    current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

    os.makedirs('logs/eval', exist_ok=True)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(f'logs/eval/eval_{model_name}_{dataset_name}_{current_time}.log'),
            logging.StreamHandler()
        ]
    )

    output_folder = os.path.join(pred_dir, f'eval_{dataset_name}')
    os.makedirs(output_folder, exist_ok=True)

    num_class = 2  # binary evaluation: background vs. the queried class

    with open(gt_list_path, 'r') as f:
        file_cls_list = f.readlines()

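    # Each line of gt_list_path is assumed to look like '<file_name> <2-digit class id>',
    # e.g. 'img_0001.png 12'; the fixed slicing below depends on that exact layout.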
    file_list = []
    for i in file_cls_list:
        i = i.strip()
        file_name = i[:-3]
        file_cls = i[-2:]
        file_list.append({'file_name': file_name, 'file_cls': int(file_cls)})

    # Process all (image, class) pairs in parallel; each worker returns
    # (file_cls, confusion_matrix) for one image.
    process_file_partial = partial(process_file, pred_dir=pred_dir, gt_dir=gt_dir,
                                   args=args, num_class=num_class)

    pool = Pool()
    outputs = pool.map(process_file_partial, file_list)
    pool.close()
    pool.join()
    logging.info(f'len outputs: {len(outputs)}')

    # Sum the per-image confusion matrices per class.
    confusion_matrix_dict = {}
    for cls, confusion_matrix in outputs:
        if cls in confusion_matrix_dict:
            confusion_matrix_dict[cls] += confusion_matrix
        else:
            confusion_matrix_dict[cls] = confusion_matrix

    class_list = []
    iou_list = []
    acc_list = []
    for cls, confusion_matrix in confusion_matrix_dict.items():
        ious = get_per_class_iou(confusion_matrix)
        accs = get_per_class_acc(confusion_matrix)
        logging.info(f'cls: {cls}, ious: {ious}, accs: {accs}')
        class_list.append(cls)
        iou_list.append(ious[1])  # index 1 is the foreground (queried) class
        acc_list.append(accs[1])

    miou = np.mean(iou_list)
    macc = np.mean(acc_list)

    df_metrics = pd.DataFrame({
        'Class': class_list + ['Mean'],
        'IoU': iou_list + [miou],
        'Accuracy': acc_list + [macc],
    })
    pd.set_option('display.float_format', '{:.4f}'.format)  # values are fractions, not percentages
    logging.info(df_metrics)
    pd.reset_option('display.float_format')
    df_metrics.to_csv(os.path.join(output_folder, 'eval.csv'), index=False, float_format='%.4f')