init
This commit is contained in:
94
finetune/mmseg/models/losses/ohem_cross_entropy_loss.py
Normal file
94
finetune/mmseg/models/losses/ohem_cross_entropy_loss.py
Normal file
@@ -0,0 +1,94 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class OhemCrossEntropy(nn.Module):
|
||||
"""OhemCrossEntropy loss.
|
||||
|
||||
This func is modified from
|
||||
`PIDNet <https://github.com/XuJiacong/PIDNet/blob/main/utils/criterion.py#L43>`_. # noqa
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Args:
|
||||
ignore_label (int): Labels to ignore when computing the loss.
|
||||
Default: 255
|
||||
thresh (float, optional): The threshold for hard example selection.
|
||||
Below which, are prediction with low confidence. If not
|
||||
specified, the hard examples will be pixels of top ``min_kept``
|
||||
loss. Default: 0.7.
|
||||
min_kept (int, optional): The minimum number of predictions to keep.
|
||||
Default: 100000.
|
||||
loss_weight (float): Weight of the loss. Defaults to 1.0.
|
||||
class_weight (list[float] | str, optional): Weight of each class. If in
|
||||
str format, read them from a file. Defaults to None.
|
||||
loss_name (str): Name of the loss item. If you want this loss
|
||||
item to be included into the backward graph, `loss_` must be the
|
||||
prefix of the name. Defaults to 'loss_boundary'.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
ignore_label: int = 255,
|
||||
thres: float = 0.7,
|
||||
min_kept: int = 100000,
|
||||
loss_weight: float = 1.0,
|
||||
class_weight: Optional[Union[List[float], str]] = None,
|
||||
loss_name: str = 'loss_ohem'):
|
||||
super().__init__()
|
||||
self.thresh = thres
|
||||
self.min_kept = max(1, min_kept)
|
||||
self.ignore_label = ignore_label
|
||||
self.loss_weight = loss_weight
|
||||
self.loss_name_ = loss_name
|
||||
self.class_weight = class_weight
|
||||
|
||||
def forward(self, score: Tensor, target: Tensor) -> Tensor:
|
||||
"""Forward function.
|
||||
Args:
|
||||
score (Tensor): Predictions of the segmentation head.
|
||||
target (Tensor): Ground truth of the image.
|
||||
|
||||
Returns:
|
||||
Tensor: Loss tensor.
|
||||
"""
|
||||
# score: (N, C, H, W)
|
||||
pred = F.softmax(score, dim=1)
|
||||
if self.class_weight is not None:
|
||||
class_weight = score.new_tensor(self.class_weight)
|
||||
else:
|
||||
class_weight = None
|
||||
|
||||
pixel_losses = F.cross_entropy(
|
||||
score,
|
||||
target,
|
||||
weight=class_weight,
|
||||
ignore_index=self.ignore_label,
|
||||
reduction='none').contiguous().view(-1) # (N*H*W)
|
||||
mask = target.contiguous().view(-1) != self.ignore_label # (N*H*W)
|
||||
|
||||
tmp_target = target.clone() # (N, H, W)
|
||||
tmp_target[tmp_target == self.ignore_label] = 0
|
||||
# pred: (N, C, H, W) -> (N*H*W, C)
|
||||
pred = pred.gather(1, tmp_target.unsqueeze(1))
|
||||
# pred: (N*H*W, C) -> (N*H*W), ind: (N*H*W)
|
||||
pred, ind = pred.contiguous().view(-1, )[mask].contiguous().sort()
|
||||
if pred.numel() > 0:
|
||||
min_value = pred[min(self.min_kept, pred.numel() - 1)]
|
||||
else:
|
||||
return score.new_tensor(0.0)
|
||||
threshold = max(min_value, self.thresh)
|
||||
|
||||
pixel_losses = pixel_losses[mask][ind]
|
||||
pixel_losses = pixel_losses[pred < threshold]
|
||||
return self.loss_weight * pixel_losses.mean()
|
||||
|
||||
@property
|
||||
def loss_name(self):
|
||||
return self.loss_name_
|
||||
Reference in New Issue
Block a user