# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import BaseTransformerLayer
from mmengine.model import BaseModule, ModuleList
from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
from torch.nn import functional as F

from mmseg.registry import MODELS
from mmseg.utils import get_classes, get_predefined_templates, tokenizer


@MODELS.register_module()
class CLIPTextEncoder(BaseModule):
    """A text encoder with transformer architecture to encode the label text.

    Modified from https://github.com/MendelXu/SAN/blob/main/san/model/clip_utils/classifier.py # noqa:E501
    Copyright (c) 2023 MendelXu.
    Licensed under the MIT License

    Args:
        dataset_name (str|None): The name of the dataset to which
            the data belongs. Default: None.
        vocabulary (List[str]|None): The list of class names. Default: None.
        templates (List[str]|str): The prompt templates used for the labels,
            either a list of template strings or the name of a predefined
            template set. Default: 'vild'.
        total_vocab_size (int): Number of all words used by the pre-trained
            model. Default: 49408 (CLIP).
        context_length (int): The max length of prompt text.
            Default: 77 (CLIP).
        embed_dims (int): Width of transformer model. Default: 512.
        num_layers (int): Depth of transformer. Default: 12.
        num_heads (int): Number of attention heads in transformer.
            Default: 8.
        mlp_ratio (int): Ratio of mlp hidden dim to embedding dim in
            transformer. Default: 4.
        output_dims (int): Dim of output text embeddings. Default: 512.
        cache_feature (bool): Whether to save class embeddings in cache.
            Default: True.
        cat_bg (bool): Whether to add background embedding. Default: True.
        norm_cfg (dict|None): Config for norm layer. Default: dict(type='LN').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 dataset_name: str = None,
                 vocabulary: List[str] = None,
                 templates: str = 'vild',
                 total_vocab_size: int = 49408,
                 context_length: int = 77,
                 embed_dims: int = 512,
                 num_layers: int = 12,
                 num_heads: int = 8,
                 mlp_ratio: int = 4,
                 output_dims: int = 512,
                 cache_feature: bool = True,
                 cat_bg: bool = True,
                 norm_cfg: dict = dict(type='LN'),
                 init_cfg: dict = None):
        super().__init__(init_cfg)
        if isinstance(templates, List):
            self.templates = templates
        else:
            self.templates = get_predefined_templates(templates)

        assert dataset_name is not None or vocabulary is not None, \
            "text_encoder requires either 'dataset_name' or 'vocabulary'"
        assert dataset_name is None or vocabulary is None, \
            "there is a conflict between 'dataset_name' and 'vocabulary'"
        self.dataset_name = dataset_name
        self.vocabulary = vocabulary
        self.num_pos = context_length
        self.token_embedding = nn.Embedding(total_vocab_size, embed_dims)
        self.positional_embedding = nn.Parameter(
            torch.empty(context_length, embed_dims))
        self.text_projection = nn.Parameter(
            torch.empty(embed_dims, output_dims))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.transformer = ModuleList()
        self.register_buffer(
            'attn_mask', self.build_attention_mask(), persistent=False)
        for i in range(num_layers):
            self.transformer.append(
                BaseTransformerLayer(
                    attn_cfgs=dict(
                        type='MultiheadAttention',
                        embed_dims=embed_dims,
                        num_heads=num_heads,
                        batch_first=False,
                        bias=True),
                    ffn_cfgs=dict(
                        type='FFN',
                        embed_dims=embed_dims,
                        feedforward_channels=mlp_ratio * embed_dims,
                        act_cfg=dict(type='QuickGELU')),
                    operation_order=('norm', 'self_attn', 'norm', 'ffn')))
        self.ln_final = build_norm_layer(
            norm_cfg, embed_dims, postfix='_final')[1]

        self.cache_feature = cache_feature
        if self.cache_feature:
            self.cache = {}

        self._freeze()

        self.cat_bg = cat_bg
        if self.cat_bg:
            self.bg_embed = nn.Parameter(
                torch.randn(1, self.text_projection.shape[1]))

    @property
    def ln_final(self):
        return getattr(self, self.final_name)

    def build_attention_mask(self):
        """Lazily create a causal attention mask, with full attention between
        the tokens.

        PyTorch uses an additive attention mask; fill with -inf.
        """
        mask = torch.empty(self.num_pos, self.num_pos)
        mask.fill_(float('-inf'))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    def _freeze(self):
        for param in self.parameters():
            param.requires_grad = False

    def init_weights(self):
        if self.cat_bg:
            nn.init.normal_(
                self.bg_embed,
                std=self.bg_embed.shape[1]**-0.5,
            )
        if isinstance(self.init_cfg, dict) and \
                self.init_cfg.get('type') == 'Pretrained_Part':
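            # Illustrative init_cfg for this branch (editor's sketch; the
            # checkpoint path is a placeholder):
            #     init_cfg=dict(type='Pretrained_Part',
            #                   checkpoint='path/to/pretrained.pth')
            # Keys prefixed with 'text_encoder.' are kept with the prefix
            # stripped; all other keys are discarded before loading.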
            checkpoint = CheckpointLoader.load_checkpoint(
                self.init_cfg['checkpoint'], logger=None, map_location='cpu')

            state_dict = checkpoint.copy()
            para_prefix = 'text_encoder'
            prefix_len = len(para_prefix) + 1
            for k, v in checkpoint.items():
                state_dict.pop(k)
                if para_prefix in k:
                    state_dict[k[prefix_len:]] = v

            load_state_dict(self, state_dict, strict=False, logger=None)

        else:
            super().init_weights()

    @torch.no_grad()
    def encode_text(self, text, normalize=False):
        """Encode the class tokens."""
        embed_device = self.token_embedding.weight.device
        x = self.token_embedding(
            text.to(embed_device))  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        for block in self.transformer:
            x = block(query=x, attn_masks=self.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding
        # (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]),
              text.argmax(dim=-1)] @ self.text_projection
        return F.normalize(x, dim=-1) if normalize else x

    def template_encode(self, vocabulary):
        """Prompt engineering."""
        text_embed_bucket = []
        for template in self.templates:
            text_inputs = tokenizer.tokenize(
                [template.format(noun) for noun in vocabulary])
            text_embed = self.encode_text(text_inputs, normalize=True)
            text_embed_bucket.append(text_embed)
        text_embed = torch.stack(text_embed_bucket).mean(dim=0)
        text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
        return text_embed

    def forward(self):
        """Forward function."""
        if self.dataset_name is None:  # encoding vocabulary directly
            class_names = self.vocabulary
            if self.cache_feature:
                new_classes = [
                    word for word in class_names if word not in self.cache
                ]
                if len(new_classes) > 0:
                    class_embeds = self.template_encode(new_classes)
                    self.cache.update(dict(zip(new_classes, class_embeds)))
                class_embeds = torch.stack(
                    [self.cache[word] for word in class_names])
            else:
                class_embeds = self.template_encode(class_names)

        else:  # encoding the classes of the dataset
            class_names = get_classes(self.dataset_name)
            if class_names[0] == 'background':
                class_names = class_names[1:]
            if self.cache_feature:
                if self.dataset_name not in self.cache:
                    class_embeds = self.template_encode(class_names)
                    self.cache[self.dataset_name] = class_embeds
                else:
                    class_embeds = self.cache[self.dataset_name]
            else:
                class_embeds = self.template_encode(class_names)

        if self.cat_bg:
            class_embeds = torch.cat([class_embeds, self.bg_embed])
        class_embeds = F.normalize(class_embeds, p=2, dim=-1)
        return self.logit_scale.exp() * class_embeds


@MODELS.register_module()
class QuickGELU(nn.Module):
    # From https://github.com/openai/CLIP/blob/main/clip/model.py
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)
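# Editor's note: x * sigmoid(1.702 * x) is the sigmoid-based approximation of
# GELU used by CLIP. Registering it with MODELS is what lets the
# ``act_cfg=dict(type='QuickGELU')`` entry in the FFN config above be built
# from the registry.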