47 lines
1.5 KiB
Python
47 lines
1.5 KiB
Python
# Copyright (c) Ant Group and its affiliates.
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from antmmf.common.registry import registry
|
|
|
|
|
|
@registry.register_loss("ModalityVAELoss")
class ModalityVAELoss(nn.Module):
    """Masked multi-modal reconstruction loss plus the VAE quantization loss.

    For each modality (``hr``, ``s2``, ``s1``), the per-sample mean squared
    error between ``vae_out["g_<m>"]`` and ``vae_out["input_<m>"]`` is
    averaged over the batch, weighted by that modality's column of
    ``vae_out["modality_info"]``. The per-modality losses are averaged, then
    ``vae_out["loss_quant"]`` is added and the sum is scaled by ``weight``.

    Args:
        weight: Multiplier applied to the final loss. Required; constructing
            without it raises ``KeyError``.
    """

    # Modality key suffixes, in the column order of ``vae_out["modality_info"]``.
    _MODALITIES = ("hr", "s2", "s1")

    def __init__(self, **params):
        super().__init__()
        self.weight = params.pop("weight")

    def compute_rec_loss(self, x_in, x_out, modal_flag):
        """Return the flag-weighted batch mean of the per-sample MSE.

        Args:
            x_in: Tensor with a leading batch dimension and at least one more
                dimension.
            x_out: Tensor of the same shape as ``x_in``. MSE is symmetric, so
                the argument order does not affect the result.
            modal_flag: Per-sample weights of length ``batch`` (e.g. 0/1
                presence flags); samples with weight 0 do not contribute.

        Returns:
            Scalar tensor ``sum(mse_b * flag_b) / (sum(flag_b) + 1e-6)``.
        """
        loss_per_elem = F.mse_loss(x_in, x_out, reduction='none')
        # Average over every non-batch dimension so inputs of any rank >= 2 are
        # supported; for 4-D inputs this equals a mean over dims [1, 2, 3].
        loss_per_sample = loss_per_elem.flatten(start_dim=1).mean(dim=1)
        # The epsilon keeps the division finite when no sample in the batch
        # has this modality (the numerator is then 0 as well).
        return torch.sum(loss_per_sample * modal_flag) / (modal_flag.sum() + 1e-6)

    def forward(self, sample_list, output, *args, **kwargs):
        """Compute the weighted total loss.

        Args:
            sample_list: Not read by this loss; kept for the loss interface.
            output: Model output mapping. ``output["vae_out"]`` must provide
                ``input_<m>`` and ``g_<m>`` for each m in hr/s2/s1,
                ``modality_info`` (2-D; column j holds the per-sample flag for
                the j-th modality in ``_MODALITIES``) and ``loss_quant``.

        Returns:
            ``(mean of per-modality reconstruction losses + loss_quant) * weight``.

        Raises:
            ValueError: If ``modality_info`` is not 2-D.
        """
        vae_out = output["vae_out"]

        modality_info = vae_out['modality_info']
        if modality_info.dim() != 2:
            raise ValueError(
                "modality_info must be 2-D (batch, num_modalities), got shape "
                f"{tuple(modality_info.shape)}"
            )

        rec_losses = [
            self.compute_rec_loss(
                vae_out[f"g_{name}"],
                vae_out[f"input_{name}"],
                modality_info[:, col],
            )
            for col, name in enumerate(self._MODALITIES)
        ]
        # Mean-squared-error reconstruction term, averaged over modalities.
        loss_rec = sum(rec_losses) / len(rec_losses)

        loss_quant = vae_out["loss_quant"]
        total_loss = loss_rec + loss_quant
        return total_loss * self.weight
|