init
229
finetune/mmseg/models/text_encoder/clip_text_encoder.py
Normal file
@@ -0,0 +1,229 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import BaseTransformerLayer
from mmengine.model import BaseModule, ModuleList
from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
from torch.nn import functional as F

from mmseg.registry import MODELS
from mmseg.utils import get_classes, get_predefined_templates, tokenizer


@MODELS.register_module()
class CLIPTextEncoder(BaseModule):
    """A text encoder with transformer architecture to encode the label text.

    Modified from https://github.com/MendelXu/SAN/blob/main/san/model/clip_utils/classifier.py # noqa:E501
    Copyright (c) 2023 MendelXu.
    Licensed under the MIT License.

    Args:
        dataset_name (str|None): The name of the dataset to which the data
            belongs. Default: None.
        vocabulary (List[str]|None): The list of class names. Default: None.
        templates (List[str]|str): The prompt templates used for labels, or
            the name of a predefined template set. Default: 'vild'.
        total_vocab_size (int): Number of all words used by the pre-trained
            model. Default: 49408 (CLIP).
        context_length (int): The max length of prompt text.
            Default: 77 (CLIP).
        embed_dims (int): Width of the transformer model. Default: 512.
        num_layers (int): Depth of the transformer. Default: 12.
        num_heads (int): Number of attention heads in the transformer.
            Default: 8.
        mlp_ratio (int): Ratio of MLP hidden dim to embedding dim in the
            transformer. Default: 4.
        output_dims (int): Dim of output text embeddings. Default: 512.
        cache_feature (bool): Whether to save class embeddings in cache.
            Default: True.
        cat_bg (bool): Whether to add a background embedding. Default: True.
        norm_cfg (dict|None): Config for the norm layer.
            Default: dict(type='LN').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 dataset_name: str = None,
                 vocabulary: List[str] = None,
                 templates: str = 'vild',
                 total_vocab_size: int = 49408,
                 context_length: int = 77,
                 embed_dims: int = 512,
                 num_layers: int = 12,
                 num_heads: int = 8,
                 mlp_ratio: int = 4,
                 output_dims: int = 512,
                 cache_feature: bool = True,
                 cat_bg: bool = True,
                 norm_cfg: dict = dict(type='LN'),
                 init_cfg: dict = None):
        super().__init__(init_cfg)
        if isinstance(templates, List):
            self.templates = templates
        else:
            self.templates = get_predefined_templates(templates)

        assert dataset_name is not None or vocabulary is not None, \
            "text_encoder requires either 'dataset_name' or 'vocabulary'"
        assert dataset_name is None or vocabulary is None, \
            "there is a conflict between 'dataset_name' and 'vocabulary'"
        self.dataset_name = dataset_name
        self.vocabulary = vocabulary
        self.num_pos = context_length
        self.token_embedding = nn.Embedding(total_vocab_size, embed_dims)
        self.positional_embedding = nn.Parameter(
            torch.empty(context_length, embed_dims))
        self.text_projection = nn.Parameter(
            torch.empty(embed_dims, output_dims))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
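        # Note (not part of the original file): np.log(1 / 0.07) above is the
        # learnable-temperature initialisation used by the original CLIP;
        # `forward()` multiplies the class embeddings by `logit_scale.exp()`.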
        self.transformer = ModuleList()
        self.register_buffer(
            'attn_mask', self.build_attention_mask(), persistent=False)
        for i in range(num_layers):
            self.transformer.append(
                BaseTransformerLayer(
                    attn_cfgs=dict(
                        type='MultiheadAttention',
                        embed_dims=embed_dims,
                        num_heads=num_heads,
                        batch_first=False,
                        bias=True),
                    ffn_cfgs=dict(
                        type='FFN',
                        embed_dims=embed_dims,
                        feedforward_channels=mlp_ratio * embed_dims,
                        act_cfg=dict(type='QuickGELU')),
                    operation_order=('norm', 'self_attn', 'norm', 'ffn')))
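        # Each block is pre-norm ('norm' precedes 'self_attn' and 'ffn'),
        # mirroring the residual attention blocks of the CLIP text
        # transformer; QuickGELU is registered at the bottom of this file.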
        self.ln_final = build_norm_layer(
            norm_cfg, embed_dims, postfix='_final')[1]

        self.cache_feature = cache_feature
        if self.cache_feature:
            self.cache = {}

        self._freeze()

        self.cat_bg = cat_bg
        if self.cat_bg:
            self.bg_embed = nn.Parameter(
                torch.randn(1, self.text_projection.shape[1]))

    @property
    def ln_final(self):
        # Note: `__init__` assigns `self.ln_final` directly and `final_name`
        # is never set in this file, so this getter raises AttributeError and
        # attribute access falls back to `nn.Module.__getattr__`, which
        # returns the LayerNorm registered in `__init__`.
        return getattr(self, self.final_name)

    def build_attention_mask(self):
        """Lazily create the causal attention mask, with full attention
        between the tokens.

        PyTorch uses an additive attention mask; fill with -inf.
        """
        mask = torch.empty(self.num_pos, self.num_pos)
        mask.fill_(float('-inf'))
        mask.triu_(1)  # zero out the diagonal and lower triangle
        return mask
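    # Illustrative note (not in the original file): with num_pos == 3 the
    # additive mask built above is
    #     [[0., -inf, -inf],
    #      [0.,   0., -inf],
    #      [0.,   0.,   0.]]
    # so each token attends only to itself and to earlier positions.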

    def _freeze(self):
        for param in self.parameters():
            param.requires_grad = False

    def init_weights(self):
        if self.cat_bg:
            nn.init.normal_(
                self.bg_embed,
                std=self.bg_embed.shape[1]**-0.5,
            )
        if isinstance(self.init_cfg, dict) and \
                self.init_cfg.get('type') == 'Pretrained_Part':
            checkpoint = CheckpointLoader.load_checkpoint(
                self.init_cfg['checkpoint'], logger=None, map_location='cpu')

            state_dict = checkpoint.copy()
            para_prefix = 'text_encoder'
            prefix_len = len(para_prefix) + 1
            for k, v in checkpoint.items():
                state_dict.pop(k)
                if para_prefix in k:
                    state_dict[k[prefix_len:]] = v

            load_state_dict(self, state_dict, strict=False, logger=None)

        else:
            super().init_weights()

    @torch.no_grad()
    def encode_text(self, text, normalize=False):
        """Encode class tokens."""

        embed_device = self.token_embedding.weight.device
        x = self.token_embedding(
            text.to(embed_device))  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        for block in self.transformer:
            x = block(query=x, attn_masks=self.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding
        # (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]),
              text.argmax(dim=-1)] @ self.text_projection
        return F.normalize(x, dim=-1) if normalize else x

    def template_encode(self, vocabulary):
        """Prompt engineering."""
        text_embed_bucket = []
        for template in self.templates:
            text_inputs = tokenizer.tokenize(
                [template.format(noun) for noun in vocabulary])
            text_embed = self.encode_text(text_inputs, normalize=True)
            text_embed_bucket.append(text_embed)
        text_embed = torch.stack(text_embed_bucket).mean(dim=0)
        text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
        return text_embed
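    # Illustrative note (not part of the original file): with
    # vocabulary=['car', 'person'] and a template such as 'a photo of a {}.',
    # every template is filled with every class name, tokenized and encoded,
    # and the per-template embeddings are averaged and re-normalised, giving
    # a (len(vocabulary), output_dims) tensor.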

    def forward(self):
        """Forward function."""
        if self.dataset_name is None:  # encoding vocabulary directly
            class_names = self.vocabulary
            if self.cache_feature:
                new_classes = [
                    word for word in class_names if word not in self.cache
                ]
                if len(new_classes) > 0:
                    class_embeds = self.template_encode(new_classes)
                    self.cache.update(dict(zip(new_classes, class_embeds)))
                class_embeds = torch.stack(
                    [self.cache[word] for word in class_names])
            else:
                class_embeds = self.template_encode(class_names)

        else:  # encoding the classes of the dataset
            class_names = get_classes(self.dataset_name)
            if class_names[0] == 'background':
                class_names = class_names[1:]
            if self.cache_feature:
                if self.dataset_name not in self.cache:
                    class_embeds = self.template_encode(class_names)
                    self.cache[self.dataset_name] = class_embeds
                else:
                    class_embeds = self.cache[self.dataset_name]
            else:
                class_embeds = self.template_encode(class_names)

        if self.cat_bg:
            class_embeds = torch.cat([class_embeds, self.bg_embed])
        class_embeds = F.normalize(class_embeds, p=2, dim=-1)
        return self.logit_scale.exp() * class_embeds


@MODELS.register_module()
class QuickGELU(nn.Module):
    # From https://github.com/openai/CLIP/blob/main/clip/model.py
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)
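
# Note (illustrative, not part of the original file): QuickGELU is the
# x * sigmoid(1.702 * x) activation used by the original CLIP implementation;
# registering it with MODELS lets the `act_cfg=dict(type='QuickGELU')` entries
# in the transformer layers above be built from config.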