init
This commit is contained in:
107
finetune/mmseg/models/utils/shape_convert.py
Normal file
107
finetune/mmseg/models/utils/shape_convert.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
def nlc_to_nchw(x, hw_shape):
|
||||
"""Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor.
|
||||
|
||||
Args:
|
||||
x (Tensor): The input tensor of shape [N, L, C] before conversion.
|
||||
hw_shape (Sequence[int]): The height and width of output feature map.
|
||||
|
||||
Returns:
|
||||
Tensor: The output tensor of shape [N, C, H, W] after conversion.
|
||||
"""
|
||||
H, W = hw_shape
|
||||
assert len(x.shape) == 3
|
||||
B, L, C = x.shape
|
||||
assert L == H * W, 'The seq_len doesn\'t match H, W'
|
||||
return x.transpose(1, 2).reshape(B, C, H, W)
|
||||
|
||||
|
||||
def nchw_to_nlc(x):
|
||||
"""Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor.
|
||||
|
||||
Args:
|
||||
x (Tensor): The input tensor of shape [N, C, H, W] before conversion.
|
||||
|
||||
Returns:
|
||||
Tensor: The output tensor of shape [N, L, C] after conversion.
|
||||
"""
|
||||
assert len(x.shape) == 4
|
||||
return x.flatten(2).transpose(1, 2).contiguous()
|
||||
|
||||
|
||||
def nchw2nlc2nchw(module, x, contiguous=False, **kwargs):
|
||||
"""Flatten [N, C, H, W] shape tensor `x` to [N, L, C] shape tensor. Use the
|
||||
reshaped tensor as the input of `module`, and the convert the output of
|
||||
`module`, whose shape is.
|
||||
|
||||
[N, L, C], to [N, C, H, W].
|
||||
|
||||
Args:
|
||||
module (Callable): A callable object the takes a tensor
|
||||
with shape [N, L, C] as input.
|
||||
x (Tensor): The input tensor of shape [N, C, H, W].
|
||||
contiguous:
|
||||
contiguous (Bool): Whether to make the tensor contiguous
|
||||
after each shape transform.
|
||||
|
||||
Returns:
|
||||
Tensor: The output tensor of shape [N, C, H, W].
|
||||
|
||||
Example:
|
||||
>>> import torch
|
||||
>>> import torch.nn as nn
|
||||
>>> norm = nn.LayerNorm(4)
|
||||
>>> feature_map = torch.rand(4, 4, 5, 5)
|
||||
>>> output = nchw2nlc2nchw(norm, feature_map)
|
||||
"""
|
||||
B, C, H, W = x.shape
|
||||
if not contiguous:
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = module(x, **kwargs)
|
||||
x = x.transpose(1, 2).reshape(B, C, H, W)
|
||||
else:
|
||||
x = x.flatten(2).transpose(1, 2).contiguous()
|
||||
x = module(x, **kwargs)
|
||||
x = x.transpose(1, 2).reshape(B, C, H, W).contiguous()
|
||||
return x
|
||||
|
||||
|
||||
def nlc2nchw2nlc(module, x, hw_shape, contiguous=False, **kwargs):
|
||||
"""Convert [N, L, C] shape tensor `x` to [N, C, H, W] shape tensor. Use the
|
||||
reshaped tensor as the input of `module`, and convert the output of
|
||||
`module`, whose shape is.
|
||||
|
||||
[N, C, H, W], to [N, L, C].
|
||||
|
||||
Args:
|
||||
module (Callable): A callable object the takes a tensor
|
||||
with shape [N, C, H, W] as input.
|
||||
x (Tensor): The input tensor of shape [N, L, C].
|
||||
hw_shape: (Sequence[int]): The height and width of the
|
||||
feature map with shape [N, C, H, W].
|
||||
contiguous (Bool): Whether to make the tensor contiguous
|
||||
after each shape transform.
|
||||
|
||||
Returns:
|
||||
Tensor: The output tensor of shape [N, L, C].
|
||||
|
||||
Example:
|
||||
>>> import torch
|
||||
>>> import torch.nn as nn
|
||||
>>> conv = nn.Conv2d(16, 16, 3, 1, 1)
|
||||
>>> feature_map = torch.rand(4, 25, 16)
|
||||
>>> output = nlc2nchw2nlc(conv, feature_map, (5, 5))
|
||||
"""
|
||||
H, W = hw_shape
|
||||
assert len(x.shape) == 3
|
||||
B, L, C = x.shape
|
||||
assert L == H * W, 'The seq_len doesn\'t match H, W'
|
||||
if not contiguous:
|
||||
x = x.transpose(1, 2).reshape(B, C, H, W)
|
||||
x = module(x, **kwargs)
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
else:
|
||||
x = x.transpose(1, 2).reshape(B, C, H, W).contiguous()
|
||||
x = module(x, **kwargs)
|
||||
x = x.flatten(2).transpose(1, 2).contiguous()
|
||||
return x
|
||||
Reference in New Issue
Block a user