This commit is contained in:
esenke
2025-12-08 22:16:31 +08:00
commit 01adcfdf60
305 changed files with 50879 additions and 0 deletions

View File

@@ -0,0 +1,46 @@
# Copyright (c) Ant Group and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F
from antmmf.common.registry import registry
@registry.register_loss("ModalityVAELoss")
class ModalityVAELoss(nn.Module):
def __init__(self, **params):
super().__init__()
self.weight = params.pop("weight")
def compute_rec_loss(self, x_in, x_out, modal_flag):
loss_per_pixel = F.mse_loss(x_in, x_out, reduction='none')
loss_b = torch.mean(loss_per_pixel, dim=[1, 2, 3])
return torch.sum(loss_b * modal_flag)/ (modal_flag.sum() + 1e-6)
def forward(self, sample_list, output, *args, **kwargs):
vae_out = output["vae_out"]
feat_hr = vae_out['input_hr']
feat_s2 = vae_out['input_s2']
feat_s1 = vae_out['input_s1']
g_hr = vae_out['g_hr']
g_s2 = vae_out['g_s2']
g_s1 = vae_out['g_s1']
# process modality flags
modality_info = vae_out['modality_info']
B_M, L_M = modality_info.shape
modality_hr = modality_info[:,0]
modality_s2 = modality_info[:,1]
modality_s1 = modality_info[:,2]
######## rec losses ########
loss_xent = self.compute_rec_loss(g_hr, feat_hr, modality_hr) \
+ self.compute_rec_loss(g_s2, feat_s2, modality_s2) \
+ self.compute_rec_loss(g_s1, feat_s1, modality_s1)
loss_quant = vae_out["loss_quant"]
total_loss = loss_xent / 3 + loss_quant
return total_loss * self.weight