90 lines
3.4 KiB
Python
90 lines
3.4 KiB
Python
# Copyright (c) Ant Group and its affiliates.
|
|
import torch
|
|
import torch.nn as nn
|
|
from antmmf.common.registry import registry
|
|
import torch.nn.functional as F
|
|
|
|
@registry.register_loss("RecLoss")
class RecLoss(nn.Module):
    """Per-pixel cross-entropy reconstruction loss restricted to masked patches.

    ``forward`` reads predictions, targets and a patch-grid mask from ``output``,
    expands the mask to pixel resolution, and reduces the per-pixel
    cross-entropy in one of several ways (see ``forward``). Optionally adds a
    regularizer that penalizes similarity between vocabulary tokens.

    Required params: ``weight``, ``patch_size``, ``pred_key``,
    ``vocabulary_size``, ``mask_key``, ``target_key``, ``feature_merged``,
    ``use_bg``.
    Optional params: ``use_all_patch`` (False), ``balance`` (False),
    ``sim_regularization`` (False), ``sim_regularization_weight`` (0.05).
    """

    def __init__(self, **params):
        super().__init__()
        self.weight = params.pop("weight")
        self.patch_size = params.pop("patch_size")
        # Not referenced inside this class; kept for external users.
        self.eps = torch.finfo(torch.bfloat16).eps
        self.pred_key = params.pop("pred_key")
        # NOTE(review): +1 presumably reserves an extra class id — confirm.
        # Not referenced inside this class.
        self.vocabulary_size = params.pop("vocabulary_size") + 1
        self.mask_key = params.pop("mask_key")
        self.target_key = params.pop("target_key")
        # Not referenced inside this class.
        self.feature_merged = params.pop("feature_merged")
        # Counters are initialised here but not updated inside this class.
        self.cnt_train = 0
        self.cnt_val = 0
        # When False, the mask is multiplied by sample_list["valid"].
        self.use_bg = params.pop("use_bg")
        self.use_all_patch = params.pop("use_all_patch", False)
        self.balance = params.pop("balance", False)
        self.sim_regularization = params.pop("sim_regularization", False)
        # Weight of the vocabulary similarity regularizer (previously hard-coded).
        self.sim_regularization_weight = params.pop("sim_regularization_weight", 0.05)

    def unpatchify(self, x, h=None, w=None):
        """Fold per-patch values back into single-channel images.

        Args:
            x: tensor of shape (N, L, patch_size**2). Each row holds the p*p
                values of one patch; patches are laid out row-major over an
                h x w patch grid.
            h: patch-grid height. Optional.
            w: patch-grid width. Optional. When both ``h`` and ``w`` are
                omitted, the grid is assumed to be twice as tall as it is
                wide (h == 2 * w). When only one is given, the other is
                derived from L.

        Returns:
            Tensor of shape (N, h * patch_size, w * patch_size).
        """
        p = self.patch_size
        num_patches = x.shape[1]
        if h is None and w is None:
            # round() instead of int(): sqrt may land just below an integer.
            w = int(round((num_patches * 0.5) ** 0.5))
            h = w * 2
        elif h is None:
            h = num_patches // w
        elif w is None:
            w = num_patches // h
        assert h * w == num_patches, (
            f"patch grid {h}x{w} does not match {num_patches} patches"
        )

        x = x.reshape(x.shape[0], h, w, p, p)
        # Interleave grid rows with in-patch rows (and columns likewise).
        x = torch.einsum("nhwpq->nhpwq", x)
        return x.reshape(x.shape[0], h * p, w * p)

    def _pixel_mask(self, patch_mask):
        """Expand a (B, h, w) patch-grid mask into a (B, h*p, w*p) pixel mask."""
        batch, grid_h, grid_w = patch_mask.shape
        flat = patch_mask.reshape(batch, grid_h * grid_w)
        # Every pixel of a patch inherits that patch's mask value.
        flat = flat[:, :, None].repeat(1, 1, self.patch_size ** 2)
        # Pass the known grid dims so non-2:1 aspect ratios fold correctly.
        return self.unpatchify(flat, grid_h, grid_w)

    def _balanced_loss(self, loss, target, mask):
        """Return the mean of the foreground (target > 0) and background (target == 0) mean losses.

        Unless ``use_all_patch`` is set, only pixels with ``mask == 1`` count.
        """
        pos = target > 0
        neg = target == 0
        if not self.use_all_patch:
            selected = mask == 1
            pos = pos & selected
            neg = neg & selected
        # 1e-6 guards against an empty class selection.
        loss_pos = loss[pos].sum() / (pos.sum() + 1e-6)
        loss_neg = loss[neg].sum() / (neg.sum() + 1e-6)
        return (loss_pos + loss_neg) * 0.5

    @staticmethod
    def _similarity_regularization(vocabulary_token):
        """Return the mean of (1 + cosine similarity) over all distinct token pairs.

        ``vocabulary_token`` must be 2-D (num_tokens, dim). With fewer than two
        tokens there are no pairs, and 0 is returned instead of a NaN mean.
        """
        voca_normed = F.normalize(vocabulary_token, 2, 1)
        similarity_matrix = 1 + torch.einsum("nd,md->nm", voca_normed, voca_normed)
        num = voca_normed.shape[0]
        if num < 2:
            return similarity_matrix.new_zeros(())
        # Strict upper triangle: each unordered pair once, no self-similarity.
        index = torch.triu(voca_normed.new_ones(num, num), diagonal=1).type(torch.bool)
        return similarity_matrix[index].mean()

    def forward(self, sample_list, output, *args, **kwargs):
        """Compute the weighted reconstruction loss.

        Reduction of the per-pixel cross-entropy:
          * ``balance``: mean of foreground and background mean losses
            (restricted to the mask unless ``use_all_patch``);
          * ``use_all_patch``: plain mean over all pixels;
          * otherwise: mask-weighted mean (0 when the mask is empty).

        The result is scaled by ``weight``. If ``sim_regularization`` is set,
        ``sim_regularization_weight`` times the vocabulary similarity penalty
        (from ``output['vocabulary_token']``) is added.
        """
        pred = output[self.pred_key]  # assumed (B, C, H, W) logits
        target = output[self.target_key]  # assumed (B, H, W) class indices
        mask = self._pixel_mask(output[self.mask_key])

        if not self.use_bg:
            mask = mask * sample_list['valid']

        loss = F.cross_entropy(pred, target, reduction="none")

        if self.balance:
            loss = self._balanced_loss(loss, target, mask)
        elif self.use_all_patch:
            loss = loss.mean()
        else:
            # Mean loss on removed patches. An empty mask used to give 0/0 = NaN.
            # The numerator is then 0, so substituting 1 yields 0.
            denominator = mask.sum()
            denominator = torch.where(
                denominator > 0, denominator, torch.ones_like(denominator)
            )
            loss = (loss * mask).sum() / denominator

        if self.sim_regularization:
            loss_reg = self._similarity_regularization(output['vocabulary_token'])
            return loss * self.weight + loss_reg * self.sim_regularization_weight

        return loss * self.weight
|
|
|