init
This commit is contained in:
47
finetune/mmseg/models/segmentors/seg_tta.py
Normal file
47
finetune/mmseg/models/segmentors/seg_tta.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from mmengine.model import BaseTTAModel
|
||||
from mmengine.structures import PixelData
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import SampleList
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class SegTTAModel(BaseTTAModel):
|
||||
|
||||
def merge_preds(self, data_samples_list: List[SampleList]) -> SampleList:
|
||||
"""Merge predictions of enhanced data to one prediction.
|
||||
|
||||
Args:
|
||||
data_samples_list (List[SampleList]): List of predictions
|
||||
of all enhanced data.
|
||||
|
||||
Returns:
|
||||
SampleList: Merged prediction.
|
||||
"""
|
||||
predictions = []
|
||||
for data_samples in data_samples_list:
|
||||
seg_logits = data_samples[0].seg_logits.data
|
||||
logits = torch.zeros(seg_logits.shape).to(seg_logits)
|
||||
for data_sample in data_samples:
|
||||
seg_logit = data_sample.seg_logits.data
|
||||
if self.module.out_channels > 1:
|
||||
logits += seg_logit.softmax(dim=0)
|
||||
else:
|
||||
logits += seg_logit.sigmoid()
|
||||
logits /= len(data_samples)
|
||||
if self.module.out_channels == 1:
|
||||
seg_pred = (logits > self.module.decode_head.threshold
|
||||
).to(logits).squeeze(1)
|
||||
else:
|
||||
seg_pred = logits.argmax(dim=0)
|
||||
data_sample.set_data({'pred_sem_seg': PixelData(data=seg_pred)})
|
||||
if hasattr(data_samples[0], 'gt_sem_seg'):
|
||||
data_sample.set_data(
|
||||
{'gt_sem_seg': data_samples[0].gt_sem_seg})
|
||||
data_sample.set_metainfo({'img_path': data_samples[0].img_path})
|
||||
predictions.append(data_sample)
|
||||
return predictions
|
||||
Reference in New Issue
Block a user