init
lib/models/losses/recon_anno_loss.py (new file, 89 lines added)
@@ -0,0 +1,89 @@
# Copyright (c) Ant Group and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F

from antmmf.common.registry import registry

@registry.register_loss("RecLoss")
class RecLoss(nn.Module):
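    """Reconstruction loss over annotation maps.

    Computes per-pixel cross entropy between predicted class logits and an
    integer target map, reduced either over all pixels or only over masked
    patches, optionally balancing positive (target > 0) and negative
    (target == 0) pixels, and optionally adding a vocabulary-token
    similarity regularizer.
    """
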
    def __init__(self, **params):
        super().__init__()
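        # All of the following entries must be present in the loss configuration.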
        self.weight = params.pop("weight")
        self.patch_size = params.pop("patch_size")
        self.eps = torch.finfo(torch.bfloat16).eps
        self.pred_key = params.pop("pred_key")
        self.vocabulary_size = params.pop("vocabulary_size") + 1
        self.mask_key = params.pop("mask_key")
        self.target_key = params.pop("target_key")
        self.feature_merged = params.pop("feature_merged")
        self.cnt_train = 0
        self.cnt_val = 0
        self.use_bg = params.pop("use_bg")
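
        # Optional flags; each defaults to False when absent from the config.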
        self.use_all_patch = params.pop("use_all_patch", False)
        self.balance = params.pop("balance", False)
        self.sim_regularization = params.pop("sim_regularization", False)

    def unpatchify(self, x):
        """
        Expand per-patch values back to a dense map.
        x: (N, L, patch_size**2), with L = h * w patches and h = 2 * w
        returns: (N, h * patch_size, w * patch_size)
        """
        p = self.patch_size
        w = int((x.shape[1] * 0.5) ** .5)
        h = w * 2
        assert h * w == x.shape[1]
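
        # (N, L, p*p) -> (N, h, w, p, p) -> (N, h, p, w, p) -> (N, h*p, w*p)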
        x = x.reshape(shape=(x.shape[0], h, w, p, p))
        x = torch.einsum('nhwpq->nhpwq', x)
        imgs = x.reshape(shape=(x.shape[0], h * p, w * p))
        return imgs

    def forward(self, sample_list, output, *args, **kwargs):
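        # `output` carries the model tensors addressed by the configured keys;
        # `sample_list` is only consulted for its 'valid' map when use_bg is False.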
        pred = output[self.pred_key]      # (B, C, H, W) class logits
        target = output[self.target_key]  # (B, H, W) class indices
        mask = output[self.mask_key]      # (B, h, w) patch-level mask
        b_mask, h_mask, w_mask = mask.shape
        mask = mask.reshape((b_mask, h_mask * w_mask))
        mask = mask[:, :, None].repeat(1, 1, self.patch_size ** 2)
        mask = self.unpatchify(mask)      # (B, H, W) pixel-level mask

        if not self.use_bg:
            # When background is excluded, keep only pixels marked valid in the batch.
            valid = sample_list['valid']
            mask = mask * valid

        # Per-pixel cross entropy; the reduction is chosen below.
        loss = F.cross_entropy(pred, target, reduction="none")
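
        # Two reduction modes: with `balance`, positive (target > 0) and negative
        # (target == 0) pixels are averaged separately so the more frequent class
        # does not dominate; `use_all_patch` selects between reducing over every
        # pixel or only over pixels covered by the patch mask.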
        if self.balance:
            if self.use_all_patch:
                loss_pos = loss[target > 0].sum() / ((target > 0).sum() + 1e-6)
                loss_neg = loss[target == 0].sum() / ((target == 0).sum() + 1e-6)
                loss = (loss_pos + loss_neg) * 0.5
            else:
                pos = (target > 0) & (mask == 1)
                neg = (target == 0) & (mask == 1)
                loss_pos = loss[pos].sum() / (pos.sum() + 1e-6)
                loss_neg = loss[neg].sum() / (neg.sum() + 1e-6)
                loss = (loss_pos + loss_neg) * 0.5
        else:
            if self.use_all_patch:
                loss = loss.mean()
            else:
                # mean loss on masked (removed) patches only
                loss = (loss * mask).sum() / mask.sum()
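
        # Optional regularizer that spreads the vocabulary token embeddings apart:
        # it penalizes the mean pairwise cosine similarity (shifted to [0, 2]) over
        # the upper triangle of the token-by-token similarity matrix.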
        if self.sim_regularization:
            vocabulary_token = output['vocabulary_token']
            voca_normed = F.normalize(vocabulary_token, 2, 1)
            similarity_matrix = 1 + torch.einsum('nd,md->nm', voca_normed, voca_normed)
            num = voca_normed.shape[0]
            index = torch.triu(voca_normed.new_ones(num, num), diagonal=1).type(torch.bool)
            loss_reg = similarity_matrix[index].mean()
            return loss * self.weight + loss_reg * 0.05
        return loss * self.weight
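
Below is a minimal smoke-test sketch (not part of the commit) showing how the loss might be driven. The key names, shapes, and parameter values are illustrative assumptions; use_bg=True is chosen so that sample_list is never consulted.

# Hypothetical usage sketch; key names, shapes, and values are assumptions.
import torch

loss_fn = RecLoss(
    weight=1.0,
    patch_size=4,
    pred_key="logits",
    target_key="target_map",
    mask_key="patch_mask",
    vocabulary_size=10,      # model predicts vocabulary_size + 1 classes
    feature_merged=False,
    use_bg=True,             # True: sample_list['valid'] is not needed
    use_all_patch=False,
    balance=True,
)

B, C, p = 2, 11, 4           # batch, classes (vocabulary_size + 1), patch size
h, w = 2, 1                  # patch grid; unpatchify assumes h == 2 * w
output = {
    "logits": torch.randn(B, C, h * p, w * p),
    "target_map": torch.randint(0, C, (B, h * p, w * p)),
    "patch_mask": torch.ones(B, h, w),
}
print(loss_fn({}, output))   # scalar tensor: weighted, class-balanced CE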