This commit is contained in:
esenke
2025-12-08 22:16:31 +08:00
commit 01adcfdf60
305 changed files with 50879 additions and 0 deletions

View File

@@ -0,0 +1,4 @@
from .modality_vae_loss import ModalityVAELoss
from .recon_anno_loss import RecLoss
__all__ = [ "ModalityVAELoss", "RecLoss" ]

View File

@@ -0,0 +1,46 @@
# Copyright (c) Ant Group and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F
from antmmf.common.registry import registry
@registry.register_loss("ModalityVAELoss")
class ModalityVAELoss(nn.Module):
def __init__(self, **params):
super().__init__()
self.weight = params.pop("weight")
def compute_rec_loss(self, x_in, x_out, modal_flag):
loss_per_pixel = F.mse_loss(x_in, x_out, reduction='none')
loss_b = torch.mean(loss_per_pixel, dim=[1, 2, 3])
return torch.sum(loss_b * modal_flag)/ (modal_flag.sum() + 1e-6)
def forward(self, sample_list, output, *args, **kwargs):
vae_out = output["vae_out"]
feat_hr = vae_out['input_hr']
feat_s2 = vae_out['input_s2']
feat_s1 = vae_out['input_s1']
g_hr = vae_out['g_hr']
g_s2 = vae_out['g_s2']
g_s1 = vae_out['g_s1']
# process modality flags
modality_info = vae_out['modality_info']
B_M, L_M = modality_info.shape
modality_hr = modality_info[:,0]
modality_s2 = modality_info[:,1]
modality_s1 = modality_info[:,2]
######## rec losses ########
loss_xent = self.compute_rec_loss(g_hr, feat_hr, modality_hr) \
+ self.compute_rec_loss(g_s2, feat_s2, modality_s2) \
+ self.compute_rec_loss(g_s1, feat_s1, modality_s1)
loss_quant = vae_out["loss_quant"]
total_loss = loss_xent / 3 + loss_quant
return total_loss * self.weight

View File

@@ -0,0 +1,89 @@
# Copyright (c) Ant Group and its affiliates.
import torch
import torch.nn as nn
from antmmf.common.registry import registry
import torch.nn.functional as F
@registry.register_loss("RecLoss")
class RecLoss(nn.Module):
def __init__(self, **params):
super().__init__()
self.weight = params.pop("weight")
self.patch_size = params.pop("patch_size")
self.eps = torch.finfo(torch.bfloat16).eps
self.pred_key = params.pop("pred_key")
self.vocabulary_size = params.pop("vocabulary_size") + 1
self.mask_key = params.pop("mask_key")
self.target_key = params.pop("target_key")
self.feature_merged = params.pop("feature_merged")
self.cnt_train = 0
self.cnt_val = 0
self.use_bg = params.pop("use_bg")
if "use_all_patch" in params:
self.use_all_patch = params.pop("use_all_patch")
else:
self.use_all_patch = False
if "balance" in params:
self.balance = params.pop("balance")
else:
self.balance = False
if "sim_regularization" in params:
self.sim_regularization = params.pop("sim_regularization")
else:
self.sim_regularization = False
def unpatchify(self, x):
"""
x: (N, L, patch_size**2 *3)
imgs: (N, 3, H, W)
"""
p = self.patch_size
w = int((x.shape[1]*0.5)**.5)
h = w * 2
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p))
x = torch.einsum('nhwpq->nhpwq', x)
imgs = x.reshape(shape=(x.shape[0], h * p, w * p))
return imgs
def forward(self, sample_list, output, *args, **kwargs):
pred = output[self.pred_key] # B, C, H, W
target = output[self.target_key] # B, H, W
mask = output[self.mask_key]
b_mask, h_mask, w_mask = mask.shape
mask = mask.reshape((b_mask, h_mask*w_mask))
mask = mask[:, :, None].repeat(1, 1, self.patch_size**2)
mask = self.unpatchify(mask)
if not self.use_bg:
valid = sample_list['valid']
mask = mask * valid
loss = F.cross_entropy(pred, target, reduction="none")
if self.balance:
if self.use_all_patch:
loss_pos = loss[target > 0].sum() / ((target > 0).sum() + 1e-6)
loss_neg = loss[target == 0].sum() / ((target == 0).sum() + 1e-6)
loss = (loss_pos + loss_neg) * 0.5
else:
loss_pos = loss[(target > 0) & (mask == 1)].sum() / (((target > 0) & (mask == 1)).sum() + 1e-6)
loss_neg = loss[(target == 0) & (mask == 1)].sum() / (((target == 0) & (mask == 1)).sum() + 1e-6)
loss = (loss_pos + loss_neg) * 0.5
else:
if self.use_all_patch:
loss = loss.mean()
else:
loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
if self.sim_regularization:
vocabulary_token = output['vocabulary_token']
voca_normed = F.normalize(vocabulary_token, 2, 1)
similarity_matrix = 1 + torch.einsum('nd,md->nm', voca_normed, voca_normed)
num = voca_normed.shape[0]
index = torch.triu(voca_normed.new_ones(num, num), diagonal=1).type(torch.bool)
loss_reg = similarity_matrix[index].mean()
return loss * self.weight + loss_reg * 0.05
return loss * self.weight