init
This commit is contained in:
13
lib/models/necks/__init__.py
Normal file
13
lib/models/necks/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from .transformer_encoder import TransformerEncoder
|
||||
from .modality_completion import ModalityCompletion
|
||||
|
||||
__all__ = ['TransformerEncoder', 'ModalityCompletion']

# Registry mapping config type-name strings to neck classes.
type_mapping = {
    'TransformerEncoder': TransformerEncoder,
    'ModalityCompletion': ModalityCompletion
}


def build_neck(type, **kwargs):
    """Instantiate a neck module from its registered type name.

    Args:
        type (str): Key into ``type_mapping``. The parameter is named
            ``type`` (shadowing the builtin) for backward compatibility
            with existing keyword callers.
        **kwargs: Forwarded verbatim to the neck class constructor.

    Returns:
        The constructed neck instance.

    Raises:
        KeyError: If ``type`` is not a registered neck name.
    """
    if type not in type_mapping:
        # Keep KeyError (what the bare lookup raised) but make it actionable.
        raise KeyError(
            f"Unknown neck type '{type}'; available: {sorted(type_mapping)}")
    return type_mapping[type](**kwargs)
|
||||
212
lib/models/necks/modality_completion.py
Normal file
212
lib/models/necks/modality_completion.py
Normal file
@@ -0,0 +1,212 @@
|
||||
# Copyright (c) AntGroup. All rights reserved.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class BFloat16UpsampleNearest2d(nn.Module):
    """Upsampling wrapper that is safe for reduced-precision inputs.

    The input is promoted to float32 before ``F.interpolate`` and the result
    is cast back to the original dtype, so bfloat16 feature maps can be
    resized without dtype-support issues.

    NOTE(review): despite the ``Nearest2d`` in the class name, the default
    interpolation mode is ``'bilinear'`` -- kept as-is to preserve behavior.
    """

    def __init__(self, scale_factor, mode='bilinear'):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, x):
        """Resize ``x`` by ``self.scale_factor`` in float32, restore dtype."""
        original_dtype = x.dtype
        resized = F.interpolate(
            x.float(), scale_factor=self.scale_factor, mode=self.mode)
        return resized.to(original_dtype)
|
||||
|
||||
class ConvVQVAEv2(nn.Module):
    """Convolutional VQ-VAE with a soft (differentiable) codebook lookup.

    The encoder downsamples the input feature map 4x (two ``AvgPool2d``
    stages, each preceded by a residual conv block scaled by a small
    learnable ``gamma``) and emits per-position logits over ``num_tokens``
    codebook entries. The decoder soft-selects codebook embeddings from
    those logits, upsamples back to the input resolution, and reconstructs
    the feature map.

    Args:
        input_shape (tuple): ``(C, H, W)`` of the input feature map; only
            ``C`` is used, to size the first and last convolutions.
        conv_dim (int): Channel width of the convolutional trunk.
        z_dim (int): Dimension of each codebook embedding.
        num_tokens (int): Number of codebook entries. Defaults to 8192.
        temp (float): Softmax sharpness; logits are multiplied by
            ``temp * 10`` before softmax. Defaults to 0.9.
    """

    def __init__(self, input_shape, conv_dim, z_dim, num_tokens=8192, temp=0.9):
        super().__init__()
        self.z_dim = z_dim
        self.conv_dim = conv_dim  # e.g. 256
        self.input_shape = input_shape  # (C, H, W)
        self.temp = temp
        # Codebook of discrete latent embeddings.
        self.codebook = nn.Embedding(num_tokens, z_dim)
        # ---- encoder ----
        self.relu = nn.LeakyReLU()
        self.pool = nn.AvgPool2d(2)
        self.conv1 = nn.Conv2d(input_shape[0], conv_dim, 5, stride=1, padding=2)
        self.enc_block1 = nn.Sequential(
            nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
            nn.LeakyReLU(),
        )
        # Residual branches start near-identity via 0.001-initialized scales.
        self.gamma_1 = nn.Parameter(0.001 * torch.ones((1, conv_dim, 1, 1)))
        self.enc_block2 = nn.Sequential(
            nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
            nn.LeakyReLU(),
        )
        self.gamma_2 = nn.Parameter(0.001 * torch.ones((1, conv_dim, 1, 1)))
        self.logit_conv = nn.Conv2d(conv_dim, num_tokens, 1)
        # ---- decoder ----
        self.unpool = BFloat16UpsampleNearest2d(scale_factor=2)
        self.conv2 = nn.Conv2d(z_dim, conv_dim, 3, stride=1, padding=1)
        self.dec_block1 = nn.Sequential(
            nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
            nn.LeakyReLU(),
        )
        self.gamma_3 = nn.Parameter(0.001 * torch.ones((1, conv_dim, 1, 1)))
        self.dec_block2 = nn.Sequential(
            nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
            nn.LeakyReLU(),
        )
        self.gamma_4 = nn.Parameter(0.001 * torch.ones((1, conv_dim, 1, 1)))
        self.rec_conv = nn.Conv2d(conv_dim, input_shape[0], 3, stride=1, padding=1)

    def forward_encoder(self, x):
        """Encode ``x`` to codebook logits.

        Args:
            x (Tensor): Feature map of shape ``(B, C, H, W)``.

        Returns:
            Tensor: Logits of shape ``(B, num_tokens, H/4, W/4)``.
        """
        x = self.relu(self.conv1(x))
        x = x + self.gamma_1 * self.enc_block1(x)
        x = self.pool(x)
        x = x + self.gamma_2 * self.enc_block2(x)
        x = self.pool(x)
        logits = self.logit_conv(x)
        return logits

    def forward_decoder(self, logits):
        """Decode codebook logits back to a reconstructed feature map.

        Args:
            logits (Tensor): ``(B, num_tokens, h, w)`` encoder logits.

        Returns:
            tuple: ``(rec_feats, soft_one_hot)`` -- the reconstruction at
            input resolution and the softmax token distribution.
        """
        # Soft (differentiable) codebook selection instead of a hard argmax.
        soft_one_hot = F.softmax(logits * (self.temp*10), dim=1)
        sampled = torch.einsum('bnhw,nd->bdhw', soft_one_hot, self.codebook.weight)
        x = self.relu(self.conv2(sampled))
        x = self.unpool(x)
        x = x + self.gamma_3 * self.dec_block1(x)
        x = self.unpool(x)
        x = x + self.gamma_4 * self.dec_block2(x)
        rec_feats = self.rec_conv(x)
        return rec_feats, soft_one_hot

    def forward(self, x):
        """Full encode-decode pass.

        Returns:
            list: ``[logits, images_p]`` -- encoder logits and the
            reconstructed feature map.
        """
        # Fix: removed stray debug `print(x.shape)` left over from development.
        logits = self.forward_encoder(x)
        images_p, soft_one_hot = self.forward_decoder(logits)
        return [logits, images_p]
|
||||
|
||||
class ModalityCompletion(nn.Module):
    """Neck that completes missing modalities via per-modality VQ-VAEs.

    One ``ConvVQVAEv2`` is kept per modality (HR optical, Sentinel-2,
    Sentinel-1). For a sample where a modality is absent, its feature map is
    reconstructed by decoding the summed logits of the modalities that ARE
    present; where the modality is present, the input features pass through
    unchanged. A symmetric KL loss aligns the token distributions of
    co-present modalities.

    Args:
        input_shape_hr (tuple): ``(C, H, W)`` of the HR feature map.
        input_shape_s2 (tuple): ``(C, H, W)`` of the S2 feature map.
        input_shape_s1 (tuple): ``(C, H, W)`` of the S1 feature map.
        conv_dim (int): Conv trunk width of each VQ-VAE.
        z_dim (int): Codebook embedding dimension.
        n_codebook (int): Number of codebook tokens per VQ-VAE.
        init_cfg (dict | None): Optional init config; when
            ``{'type': 'Pretrained', 'checkpoint': ...}``, weights are
            loaded from the checkpoint in :meth:`init_weights`.
    """

    def __init__(self,
                 input_shape_hr=(2816, 16, 16),
                 input_shape_s2=(2816, 16, 16),
                 input_shape_s1=(2816, 16, 16),
                 conv_dim=256,
                 z_dim=256,
                 n_codebook=8192,
                 init_cfg=None
                 ):
        super(ModalityCompletion, self).__init__()
        self.vae_hr = ConvVQVAEv2(input_shape=input_shape_hr, conv_dim=conv_dim, z_dim=z_dim, num_tokens=n_codebook)
        self.vae_s2 = ConvVQVAEv2(input_shape=input_shape_s2, conv_dim=conv_dim, z_dim=z_dim, num_tokens=n_codebook)
        self.vae_s1 = ConvVQVAEv2(input_shape=input_shape_s1, conv_dim=conv_dim, z_dim=z_dim, num_tokens=n_codebook)
        # log_target=True: both arguments are log-probabilities.
        self.kl_div_loss = torch.nn.KLDivLoss(reduction="none", log_target=True)
        self.init_cfg = init_cfg

    def init_weights(self):
        """Load pretrained weights if configured, else apply default init."""
        # Fix: use .get('type') (consistent with TransformerEncoder.init_weights)
        # so a dict init_cfg without a 'type' key falls through to default init
        # instead of raising KeyError.
        if (isinstance(self.init_cfg, dict)
                and self.init_cfg.get('type') == 'Pretrained'):
            # Suppress default init if use pretrained model.
            from mmcls.utils import get_root_logger
            from mmcv.runner import CheckpointLoader, load_state_dict
            logger = get_root_logger()
            checkpoint = CheckpointLoader.load_checkpoint(
                self.init_cfg['checkpoint'], logger=logger, map_location='cpu')

            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                state_dict = checkpoint

            load_state_dict(self, state_dict, strict=False, logger=logger)
        else:
            for m in self.modules():
                if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def kl_loss(self, logits_hr, logits_s2, logits_s1, modality_info):
        """Symmetric pairwise KL between token distributions of present modalities.

        Each pair contributes only where both modalities are present
        (flag product); the three symmetric pairs are averaged with a
        fixed 1/6 factor (2 directions x 3 pairs).
        """
        prob_hr = F.log_softmax(logits_hr, dim=1)
        prob_s2 = F.log_softmax(logits_s2, dim=1)
        prob_s1 = F.log_softmax(logits_s1, dim=1)
        # Broadcast per-sample presence flags over (C, H, W).
        flag_hr = modality_info[:, 0][:, None, None, None]
        flag_s2 = modality_info[:, 1][:, None, None, None]
        flag_s1 = modality_info[:, 2][:, None, None, None]
        loss_hr_s2 = self.kl_div_loss(prob_hr, prob_s2) + self.kl_div_loss(prob_s2, prob_hr)
        loss_hr_s2 = (loss_hr_s2 * flag_hr * flag_s2).sum((1, 2, 3)).mean()
        loss_hr_s1 = self.kl_div_loss(prob_hr, prob_s1) + self.kl_div_loss(prob_s1, prob_hr)
        loss_hr_s1 = (loss_hr_s1 * flag_hr * flag_s1).sum((1, 2, 3)).mean()
        loss_s2_s1 = self.kl_div_loss(prob_s2, prob_s1) + self.kl_div_loss(prob_s1, prob_s2)
        loss_s2_s1 = (loss_s2_s1 * flag_s2 * flag_s1).sum((1, 2, 3)).mean()
        loss = (loss_hr_s2 + loss_hr_s1 + loss_s2_s1) / 6.0

        return loss

    def forward(self, feat_hr, feat_s2, feat_s1, modality_info):
        """Complete missing modalities and return all intermediate tensors.

        Args:
            feat_hr / feat_s2 / feat_s1 (Tensor): ``(B, C, H, W)`` feature
                maps for each modality.
            modality_info (Tensor): ``(B, 3)`` per-sample presence flags in
                the order (HR, S2, S1). NOTE(review): the ``~flag`` inversion
                below requires a boolean tensor -- confirm against callers.

        Returns:
            dict: Completed outputs (``hr_out``/``s2_out``/``s1_out``),
            inputs, logits, soft token distributions, reconstructions
            (``g_*``) and the quantization-alignment loss ``loss_quant``.
        """
        B, C, H, W = feat_hr.shape
        B_M, L_M = modality_info.shape
        assert B == B_M, f'feat_hr batch: {B}, modality_info batch: {B_M}'

        # Per-modality token logits from the respective encoders.
        logits_hr = self.vae_hr.forward_encoder(feat_hr)
        logits_s2 = self.vae_s2.forward_encoder(feat_s2)
        logits_s1 = self.vae_s1.forward_encoder(feat_s1)
        modality_hr = modality_info[:, 0]
        modality_s2 = modality_info[:, 1]
        modality_s1 = modality_info[:, 2]
        flag_hr = modality_hr[:, None, None, None]  # B => B, C, H, W broadcast
        flag_s2 = modality_s2[:, None, None, None]
        flag_s1 = modality_s1[:, None, None, None]

        # Sum of logits from whichever modalities of the pair are present.
        mean_logits_hr_s2 = logits_hr * flag_hr + logits_s2 * flag_s2
        mean_logits_hr_s1 = logits_hr * flag_hr + logits_s1 * flag_s1
        mean_logits_s1_s2 = logits_s1 * flag_s1 + logits_s2 * flag_s2

        # Present modality keeps its own logits; missing one borrows the rest.
        logits_hr_rec = logits_hr * flag_hr + mean_logits_s1_s2 * (~flag_hr)
        logits_s2_rec = logits_s2 * flag_s2 + mean_logits_hr_s1 * (~flag_s2)
        logits_s1_rec = logits_s1 * flag_s1 + mean_logits_hr_s2 * (~flag_s1)
        g_hr, soft_one_hot_hr = self.vae_hr.forward_decoder(logits_hr_rec)
        g_s2, soft_one_s2 = self.vae_s2.forward_decoder(logits_s2_rec)
        g_s1, soft_one_s1 = self.vae_s1.forward_decoder(logits_s1_rec)

        # Pass-through where present, reconstruction where missing.
        hr_out = feat_hr * flag_hr + g_hr * (~flag_hr)
        s2_out = feat_s2 * flag_s2 + g_s2 * (~flag_s2)
        s1_out = feat_s1 * flag_s1 + g_s1 * (~flag_s1)

        output = {}

        output['hr_out'] = hr_out
        output['s2_out'] = s2_out
        output['s1_out'] = s1_out

        output['modality_info'] = modality_info

        output['input_hr'] = feat_hr
        output['input_s2'] = feat_s2
        output['input_s1'] = feat_s1

        output['logits_hr'] = logits_hr
        output['logits_s2'] = logits_s2
        output['logits_s1'] = logits_s1

        output['soft_one_hot_hr'] = soft_one_hot_hr
        output['soft_one_hot_s2'] = soft_one_s2
        output['soft_one_hot_s1'] = soft_one_s1

        output['g_hr'] = g_hr
        output['g_s2'] = g_s2
        output['g_s1'] = g_s1
        output['loss_quant'] = self.kl_loss(logits_hr, logits_s2, logits_s1, modality_info)

        return output
|
||||
|
||||
144
lib/models/necks/transformer_encoder.py
Normal file
144
lib/models/necks/transformer_encoder.py
Normal file
@@ -0,0 +1,144 @@
|
||||
# Copyright (c) Ant Group. All rights reserved.
|
||||
from collections import OrderedDict
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
|
||||
trunc_normal_)
|
||||
from mmcv.runner import (CheckpointLoader, load_state_dict)
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
from mmseg.models.backbones.vit import TransformerEncoderLayer
|
||||
|
||||
from mmseg.utils import get_root_logger
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
    """Stack of ViT-style transformer encoder layers used as a neck.

    Projects input tokens from ``input_dims`` to ``embed_dims``, prepends a
    learnable class token, and runs ``num_layers`` of mmseg's
    ``TransformerEncoderLayer`` with a linear stochastic-depth schedule.

    Args:
        input_dims (int): Channel dim of the incoming tokens.
        embed_dims (int): Transformer embedding dim.
        num_layers (int): Number of encoder layers.
        num_heads (int): Attention heads per layer.
        mlp_ratio (int): FFN hidden dim = ``mlp_ratio * embed_dims``.
        qkv_bias (bool): Whether QKV projections have bias.
        drop_rate (float): Dropout after position handling and in layers.
        attn_drop_rate (float): Attention dropout.
        drop_path_rate (float): Max stochastic-depth rate (linearly spaced).
        with_cls_token (bool): Keep the class token in the sequence.
        output_cls_token (bool): Return only the class token(s); requires
            ``with_cls_token=True``.
        norm_cfg / act_cfg (dict): mmcv norm/activation configs.
        num_fcs (int): FFN layer count per block.
        norm_eval (bool): Freeze LayerNorm stats in train mode.
        with_cp (bool): Use checkpointing inside layers.
        init_cfg (dict | None): Optional ``Pretrained`` init config.
    """

    def __init__(self,
                 input_dims=768,
                 embed_dims=768,
                 num_layers=4,
                 num_heads=12,
                 mlp_ratio=4,
                 qkv_bias=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 with_cls_token=True,
                 output_cls_token=False,
                 norm_cfg=dict(type='LN'),
                 act_cfg=dict(type='GELU'),
                 num_fcs=2,
                 norm_eval=False,
                 with_cp=False,
                 init_cfg=None,
                 *args,
                 **kwargs):
        super(TransformerEncoder, self).__init__()

        # NOTE(review): 'porj_linear' is a typo for 'proj_linear', but the
        # attribute name is baked into checkpoints' state_dict keys, so it
        # must stay to keep pretrained weights loadable.
        self.porj_linear = nn.Linear(input_dims, embed_dims)
        if output_cls_token:
            # Fix: message previously rendered as "...True if'set..." (missing space).
            assert with_cls_token is True, f'with_cls_token must be True if ' \
                                           f'set output_cls_token to True, but got {with_cls_token}'

        self.init_cfg = init_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.with_cls_token = with_cls_token
        self.output_cls_token = output_cls_token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
        self.drop_after_pos = nn.Dropout(p=drop_rate)

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, num_layers)
        ]  # stochastic depth decay rule

        self.layers = nn.ModuleList()
        for i in range(num_layers):
            self.layers.append(
                TransformerEncoderLayer(embed_dims=embed_dims,
                                        num_heads=num_heads,
                                        feedforward_channels=mlp_ratio *
                                        embed_dims,
                                        attn_drop_rate=attn_drop_rate,
                                        drop_rate=drop_rate,
                                        drop_path_rate=dpr[i],
                                        num_fcs=num_fcs,
                                        qkv_bias=qkv_bias,
                                        act_cfg=act_cfg,
                                        norm_cfg=norm_cfg,
                                        with_cp=with_cp,
                                        batch_first=True))

    def init_weights(self):
        """Load pretrained weights if configured, else JAX-style default init."""
        if (isinstance(self.init_cfg, dict)
                and self.init_cfg.get('type') == 'Pretrained'):
            logger = get_root_logger()
            checkpoint = CheckpointLoader.load_checkpoint(
                self.init_cfg['checkpoint'], logger=logger, map_location='cpu')

            if 'state_dict' in checkpoint:
                _state_dict = checkpoint['state_dict']
            else:
                _state_dict = checkpoint

            # Strip a leading 'backbone.' prefix so backbone checkpoints load.
            state_dict = OrderedDict()
            for k, v in _state_dict.items():
                if k.startswith('backbone.'):
                    state_dict[k[9:]] = v
                else:
                    state_dict[k] = v

            load_state_dict(self, state_dict, strict=False, logger=logger)
        else:
            # We only implement the 'jax_impl' initialization implemented at
            # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
            trunc_normal_(self.cls_token, std=.02)
            for n, m in self.named_modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_(m.weight, std=.02)
                    if m.bias is not None:
                        if 'ffn' in n:
                            nn.init.normal_(m.bias, mean=0., std=1e-6)
                        else:
                            nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Conv2d):
                    kaiming_init(m, mode='fan_in', bias=0.)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                    constant_init(m, val=1.0, bias=0.)

    def forward(self, inputs, require_feat: bool = False, require_two: bool = False):
        """Run the encoder stack.

        Args:
            inputs (Tensor): Token sequence of shape ``(B, N, input_dims)``.
            require_feat (bool): Also return per-layer outputs.
            require_two (bool): With ``output_cls_token``, return the first
                two tokens instead of only the class token.

        Returns:
            Tensor, or ``(Tensor, list[Tensor])`` when ``require_feat``.
        """
        inputs = self.porj_linear(inputs)
        B, N, C = inputs.shape

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, inputs), dim=1)
        if not self.with_cls_token:
            # Remove class token for transformer encoder input
            x = x[:, 1:]

        # Collect per-layer hidden states when requested.
        block_outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if require_feat:
                block_outs.append(x)

        if self.output_cls_token:
            if require_two:
                x = x[:, :2]
            else:
                x = x[:, 0]
        # else: full sequence is returned as-is (dead `x = x` branch removed).

        if require_feat:
            return x, block_outs
        else:
            return x

    def train(self, mode=True):
        """Switch modes; optionally keep LayerNorms in eval when norm_eval."""
        super(TransformerEncoder, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, nn.LayerNorm):
                    m.eval()
|
||||
Reference in New Issue
Block a user