init

finetune/mmseg/models/utils/self_attention_block.py (new file, 161 lines)

@@ -0,0 +1,161 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn import ConvModule
from mmengine.model.weight_init import constant_init
from torch import nn as nn
from torch.nn import functional as F


class SelfAttentionBlock(nn.Module):
    """General self-attention block/non-local block.

    Please refer to https://arxiv.org/abs/1706.03762 for details about key,
    query and value.

    Args:
        key_in_channels (int): Input channels of key feature.
        query_in_channels (int): Input channels of query feature.
        channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
        share_key_query (bool): Whether to share projection weight between
            key and query projection.
        query_downsample (nn.Module): Query downsample module.
        key_downsample (nn.Module): Key downsample module.
        key_query_num_convs (int): Number of convs for key/query projection.
        value_out_num_convs (int): Number of convs for value/out projection.
        key_query_norm (bool): Whether to use ConvModule (with norm/act) for
            the key/query projection.
        value_out_norm (bool): Whether to use ConvModule (with norm/act) for
            the value/out projection.
        matmul_norm (bool): Whether to normalize the attention map by
            sqrt(channels).
        with_out (bool): Whether to use an output projection.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict|None): Config of activation layers.
    """

    def __init__(self, key_in_channels, query_in_channels, channels,
                 out_channels, share_key_query, query_downsample,
                 key_downsample, key_query_num_convs, value_out_num_convs,
                 key_query_norm, value_out_norm, matmul_norm, with_out,
                 conv_cfg, norm_cfg, act_cfg):
        super().__init__()
        if share_key_query:
            assert key_in_channels == query_in_channels
        self.key_in_channels = key_in_channels
        self.query_in_channels = query_in_channels
        self.out_channels = out_channels
        self.channels = channels
        self.share_key_query = share_key_query
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.key_project = self.build_project(
            key_in_channels,
            channels,
            num_convs=key_query_num_convs,
            use_conv_module=key_query_norm,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        if share_key_query:
            self.query_project = self.key_project
        else:
            self.query_project = self.build_project(
                query_in_channels,
                channels,
                num_convs=key_query_num_convs,
                use_conv_module=key_query_norm,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
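        # The value projection outputs `channels` when an output projection
        # follows (with_out=True); otherwise it maps directly to out_channels.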
        self.value_project = self.build_project(
            key_in_channels,
            channels if with_out else out_channels,
            num_convs=value_out_num_convs,
            use_conv_module=value_out_norm,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        if with_out:
            self.out_project = self.build_project(
                channels,
                out_channels,
                num_convs=value_out_num_convs,
                use_conv_module=value_out_norm,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        else:
            self.out_project = None

        self.query_downsample = query_downsample
        self.key_downsample = key_downsample
        self.matmul_norm = matmul_norm

        self.init_weights()

    def init_weights(self):
        """Initialize the weights of the output projection layer."""
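        # Only a plain Conv2d out projection is zero-initialized here;
        # a ConvModule out projection is left untouched by this method.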
        if self.out_project is not None:
            if not isinstance(self.out_project, ConvModule):
                constant_init(self.out_project, 0)

    def build_project(self, in_channels, channels, num_convs, use_conv_module,
                      conv_cfg, norm_cfg, act_cfg):
        """Build projection layer for key/query/value/out."""
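        # With use_conv_module, each 1x1 projection is a ConvModule
        # (conv + optional norm/activation); otherwise plain 1x1 Conv2d
        # layers are stacked.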
        if use_conv_module:
            convs = [
                ConvModule(
                    in_channels,
                    channels,
                    1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg)
            ]
            for _ in range(num_convs - 1):
                convs.append(
                    ConvModule(
                        channels,
                        channels,
                        1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg))
        else:
            convs = [nn.Conv2d(in_channels, channels, 1)]
            for _ in range(num_convs - 1):
                convs.append(nn.Conv2d(channels, channels, 1))
        if len(convs) > 1:
            convs = nn.Sequential(*convs)
        else:
            convs = convs[0]
        return convs

    def forward(self, query_feats, key_feats):
        """Forward function."""
        batch_size = query_feats.size(0)
        query = self.query_project(query_feats)
        if self.query_downsample is not None:
            query = self.query_downsample(query)
        query = query.reshape(*query.shape[:2], -1)
        query = query.permute(0, 2, 1).contiguous()
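        # query: [batch_size, num_query_positions, channels]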

        key = self.key_project(key_feats)
        value = self.value_project(key_feats)
        if self.key_downsample is not None:
            key = self.key_downsample(key)
            value = self.key_downsample(value)
        key = key.reshape(*key.shape[:2], -1)
        value = value.reshape(*value.shape[:2], -1)
        value = value.permute(0, 2, 1).contiguous()

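        # Dot-product attention over spatial positions, optionally scaled by
        # 1/sqrt(channels); sim_map: [batch, num_query_pos, num_key_pos].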
        sim_map = torch.matmul(query, key)
        if self.matmul_norm:
            sim_map = (self.channels**-.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)

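        # Aggregate value vectors with the attention weights and restore the
        # query's spatial layout before the optional output projection.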
        context = torch.matmul(sim_map, value)
        context = context.permute(0, 2, 1).contiguous()
        context = context.reshape(batch_size, -1, *query_feats.shape[2:])
        if self.out_project is not None:
            context = self.out_project(context)
        return context
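For reference, a minimal usage sketch of the block added by this commit (a hypothetical example, not part of the diff). It assumes the file is importable as mmseg.models.utils.self_attention_block, which depends on how finetune/ is placed on the Python path, and it configures plain 1x1 Conv2d projections with no norm or activation:

import torch
from mmseg.models.utils.self_attention_block import SelfAttentionBlock

# Hypothetical configuration: 64-channel features, one 1x1 conv per
# projection, no downsampling, scaled attention, with output projection.
block = SelfAttentionBlock(
    key_in_channels=64,
    query_in_channels=64,
    channels=32,
    out_channels=64,
    share_key_query=False,
    query_downsample=None,
    key_downsample=None,
    key_query_num_convs=1,
    value_out_num_convs=1,
    key_query_norm=False,
    value_out_norm=False,
    matmul_norm=True,
    with_out=True,
    conv_cfg=None,
    norm_cfg=None,
    act_cfg=None)

query_feats = torch.rand(2, 64, 32, 32)
key_feats = torch.rand(2, 64, 16, 16)
# Output keeps the query's spatial size: [2, 64, 32, 32]. Right after
# construction it is all zeros, because the plain Conv2d out projection
# is zero-initialized by init_weights().
context = block(query_feats, key_feats)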