init
finetune/mmseg/models/decode_heads/ham_head.py (new file, +255 lines)
@@ -0,0 +1,255 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Originally from https://github.com/visual-attention-network/segnext
# Licensed under the Apache License, Version 2.0 (the "License")
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.device import get_device

from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead


class Matrix_Decomposition_2D_Base(nn.Module):
    """Base class of 2D Matrix Decomposition.

    Args:
        MD_S (int): The number of spatial slices in Matrix
            Decomposition; the latent dimension D is computed
            from it as D = C // MD_S. Defaults: 1.
        MD_R (int): The number of latent dimensions R in
            Matrix Decomposition. Defaults: 64.
        train_steps (int): The number of iteration steps of the
            Multiplicative Update (MU) rule used to solve Non-negative
            Matrix Factorization (NMF) in training. Defaults: 6.
        eval_steps (int): The number of iteration steps of the
            Multiplicative Update (MU) rule used to solve Non-negative
            Matrix Factorization (NMF) in evaluation. Defaults: 7.
        inv_t (int): Inverse temperature multiplying the coefficients
            before the softmax; larger values sharpen the distribution
            over the R bases. Defaults: 100.
        rand_init (bool): Whether to re-initialize the bases randomly
            on every forward pass; if False, a fixed set of bases is
            built on the first call and registered as a buffer.
            Defaults: True.
    """

    def __init__(self,
                 MD_S=1,
                 MD_R=64,
                 train_steps=6,
                 eval_steps=7,
                 inv_t=100,
                 rand_init=True):
        super().__init__()

        self.S = MD_S
        self.R = MD_R

        self.train_steps = train_steps
        self.eval_steps = eval_steps

        self.inv_t = inv_t

        self.rand_init = rand_init

    def _build_bases(self, B, S, D, R, device=None):
        raise NotImplementedError

    def local_step(self, x, bases, coef):
        raise NotImplementedError

    def local_inference(self, x, bases):
        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
        coef = torch.bmm(x.transpose(1, 2), bases)
        coef = F.softmax(self.inv_t * coef, dim=-1)

        steps = self.train_steps if self.training else self.eval_steps
        for _ in range(steps):
            bases, coef = self.local_step(x, bases, coef)

        return bases, coef

    def compute_coef(self, x, bases, coef):
        raise NotImplementedError

    def forward(self, x, return_bases=False):
        """Forward Function."""
        B, C, H, W = x.shape

        # (B, C, H, W) -> (B * S, D, N)
        D = C // self.S
        N = H * W
        x = x.view(B * self.S, D, N)
        if not self.rand_init and not hasattr(self, 'bases'):
            bases = self._build_bases(1, self.S, D, self.R, device=x.device)
            self.register_buffer('bases', bases)

        # (S, D, R) -> (B * S, D, R)
        if self.rand_init:
            bases = self._build_bases(B, self.S, D, self.R, device=x.device)
        else:
            bases = self.bases.repeat(B, 1, 1)

        bases, coef = self.local_inference(x, bases)

        # (B * S, N, R)
        coef = self.compute_coef(x, bases, coef)

        # (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N)
        x = torch.bmm(bases, coef.transpose(1, 2))

        # (B * S, D, N) -> (B, C, H, W)
        x = x.view(B, C, H, W)

        return x
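

# Editor's sketch (not part of the original file): the base class flattens a
# (B, C, H, W) map into B * S matrices of shape (D, N) with D = C // S and
# N = H * W, approximates each as ``bases @ coef^T``, and restores the layout.
# A quick round-trip check of that shape algebra, assuming B=2, C=64, S=1:
def _shape_flow_sketch():
    B, C, H, W = 2, 64, 32, 32
    S, R = 1, 16
    D, N = C // S, H * W
    x = torch.rand(B, C, H, W).view(B * S, D, N)    # flatten spatial dims
    bases = torch.rand(B * S, D, R)                 # dictionary of R atoms
    coef = torch.rand(B * S, N, R)                  # per-pixel coefficients
    recon = torch.bmm(bases, coef.transpose(1, 2))  # low-rank (B*S, D, N)
    assert recon.view(B, C, H, W).shape == (B, C, H, W)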


class NMF2D(Matrix_Decomposition_2D_Base):
    """Non-negative Matrix Factorization (NMF) module.

    It is inherited from the ``Matrix_Decomposition_2D_Base`` module.
    """

    def __init__(self, args=dict()):
        super().__init__(**args)

        self.inv_t = 1

    def _build_bases(self, B, S, D, R, device=None):
        """Build bases in initialization."""
        if device is None:
            device = get_device()
        bases = torch.rand((B * S, D, R)).to(device)
        bases = F.normalize(bases, dim=1)

        return bases

    def local_step(self, x, bases, coef):
        """Local step in iteration to renew bases and coefficient."""
        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
        numerator = torch.bmm(x.transpose(1, 2), bases)
        # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R)
        denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
        # Multiplicative Update
        coef = coef * numerator / (denominator + 1e-6)

        # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)
        numerator = torch.bmm(x, coef)
        # (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R)
        denominator = bases.bmm(coef.transpose(1, 2).bmm(coef))
        # Multiplicative Update
        bases = bases * numerator / (denominator + 1e-6)

        return bases, coef

    def compute_coef(self, x, bases, coef):
        """Compute coefficient."""
        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
        numerator = torch.bmm(x.transpose(1, 2), bases)
        # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R)
        denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
        # Multiplicative Update
        coef = coef * numerator / (denominator + 1e-6)

        return coef
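

# Editor's sketch (not part of the original file): each ``local_step`` above is
# one Multiplicative Update, which decreases ||x - bases @ coef^T||_F while
# keeping both factors non-negative:
#   coef  <- coef  * (x^T @ bases) / (coef @ (bases^T @ bases))
#   bases <- bases * (x @ coef)   / (bases @ (coef^T @ coef))
# Minimal usage, assuming a non-negative input such as a post-ReLU feature map:
def _nmf2d_sketch():
    ham = NMF2D(dict(MD_R=16, train_steps=6, eval_steps=7))
    ham.eval()                      # run eval_steps MU iterations
    x = torch.rand(2, 64, 32, 32)   # non-negative (B, C, H, W) input
    out = ham(x)                    # rank-R reconstruction, same shape
    assert out.shape == x.shape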


class Hamburger(nn.Module):
    """Hamburger Module. It consists of one slice of "ham" (matrix
    decomposition) and two slices of "bread" (linear transformation).

    Args:
        ham_channels (int): Input and output channels of feature.
        ham_kwargs (dict): Config of the matrix decomposition module.
        norm_cfg (dict | None): Config of norm layers.
    """

    def __init__(self,
                 ham_channels=512,
                 ham_kwargs=dict(),
                 norm_cfg=None,
                 **kwargs):
        super().__init__()

        # lower "bread": 1x1 conv without norm or activation
        self.ham_in = ConvModule(
            ham_channels, ham_channels, 1, norm_cfg=None, act_cfg=None)

        # the "ham": non-negative matrix factorization
        self.ham = NMF2D(ham_kwargs)

        # upper "bread": 1x1 conv with norm but no activation
        self.ham_out = ConvModule(
            ham_channels, ham_channels, 1, norm_cfg=norm_cfg, act_cfg=None)

    def forward(self, x):
        enjoy = self.ham_in(x)
        enjoy = F.relu(enjoy, inplace=True)
        enjoy = self.ham(enjoy)
        enjoy = self.ham_out(enjoy)
        # residual connection around the hamburger
        ham = F.relu(x + enjoy, inplace=True)

        return ham
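

# Editor's usage sketch (not part of the original file): ``Hamburger`` is
# shape-preserving, so it can stand in for an attention-style block on any
# ``ham_channels`` feature map. The GN config below is an assumption; any
# mmcv-style norm config works.
def _hamburger_sketch():
    burger = Hamburger(
        ham_channels=512,
        ham_kwargs=dict(MD_R=16),
        norm_cfg=dict(type='GN', num_groups=32, requires_grad=True))
    x = torch.rand(1, 512, 16, 16)
    assert burger(x).shape == x.shape  # residual block keeps the shape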


@MODELS.register_module()
class LightHamHead(BaseDecodeHead):
    """SegNeXt decode head.

    This decode head is the implementation of `SegNeXt: Rethinking
    Convolutional Attention Design for Semantic
    Segmentation <https://arxiv.org/abs/2209.08575>`_.
    Inspiration from https://github.com/visual-attention-network/segnext.

    Specifically, LightHamHead is inspired by HamNet from
    `Is Attention Better Than Matrix Decomposition?
    <https://arxiv.org/abs/2109.04553>`_.

    Args:
        ham_channels (int): Input channels for Hamburger.
            Defaults: 512.
        ham_kwargs (dict): kwargs for Ham. Defaults: dict().
    """

    def __init__(self, ham_channels=512, ham_kwargs=dict(), **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)
        self.ham_channels = ham_channels

        # fuse the concatenated multi-level features into ham_channels
        self.squeeze = ConvModule(
            sum(self.in_channels),
            self.ham_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.hamburger = Hamburger(ham_channels, ham_kwargs, **kwargs)

        # project back to the head's output channels
        self.align = ConvModule(
            self.ham_channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        inputs = self._transform_inputs(inputs)

        # resize every selected level to the first level's spatial size
        inputs = [
            resize(
                level,
                size=inputs[0].shape[2:],
                mode='bilinear',
                align_corners=self.align_corners) for level in inputs
        ]

        inputs = torch.cat(inputs, dim=1)
        # apply a conv block to squeeze feature map
        x = self.squeeze(inputs)
        # apply hamburger module
        x = self.hamburger(x)

        # apply a conv block to align feature map
        output = self.align(x)
        output = self.cls_seg(output)
        return output
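

# Editor's config sketch (not part of the original file): since
# ``LightHamHead`` is registered in ``MODELS``, it is normally built from a
# config dict rather than instantiated directly. The values below are
# assumptions in the style of the SegNeXt configs, fusing three backbone
# stages of 128/320/512 channels:
def _light_ham_head_cfg_sketch():
    return dict(
        type='LightHamHead',
        in_channels=[128, 320, 512],
        in_index=[1, 2, 3],
        channels=256,
        ham_channels=512,
        ham_kwargs=dict(MD_R=16),
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))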