import torch.nn as nn
from collections import OrderedDict
from mmcv.cnn.utils.weight_init import (kaiming_init, trunc_normal_)
from mmcv.runner import (CheckpointLoader, load_state_dict)
from mmseg.utils import get_root_logger


class UPHead(nn.Module):
    """Upsampling head: a 1x1 conv followed by PixelShuffle."""

    def __init__(self, in_dim, out_dim, up_scale, init_cfg=None):
        super().__init__()
        # The 1x1 conv expands channels by up_scale**2 so that PixelShuffle
        # can rearrange them into an up_scale x up_scale spatial upsampling.
        self.decoder = nn.Sequential(
            nn.Conv2d(in_channels=in_dim,
                      out_channels=up_scale**2 * out_dim,
                      kernel_size=1),
            nn.PixelShuffle(up_scale),
        )
        self.init_cfg = init_cfg
        self.apply(self._init_weights)
    def _init_weights(self, m):
        if (isinstance(self.init_cfg, dict)
                and self.init_cfg.get('type') == 'Pretrained'):
            # `self.apply` calls this hook once per submodule; load the
            # checkpoint only on the call for the head itself, instead of
            # re-reading it for every child module.
            if m is not self:
                return
            logger = get_root_logger()
            checkpoint = CheckpointLoader.load_checkpoint(
                self.init_cfg['checkpoint'], logger=logger,
                map_location='cpu')
            if 'state_dict' in checkpoint:
                _state_dict = checkpoint['state_dict']
            else:
                _state_dict = checkpoint
            # Strip the 'backbone.' prefix so checkpoint keys match this
            # module's parameter names.
            state_dict = OrderedDict()
            for k, v in _state_dict.items():
                if k.startswith('backbone.'):
                    state_dict[k[len('backbone.'):]] = v
                else:
                    state_dict[k] = v
            logger.info(f'loading weight: {self.init_cfg["checkpoint"]}')
            load_state_dict(self, state_dict, strict=False, logger=logger)
        else:
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                kaiming_init(m, mode='fan_in', bias=0.)

    def forward(self, x):
        x = self.decoder(x)
        return x
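

if __name__ == '__main__':
    # Minimal usage sketch (the shapes and dimensions below are illustrative
    # assumptions, not taken from the original file). The 1x1 conv +
    # PixelShuffle pair implements sub-pixel upsampling: channels are
    # expanded by up_scale**2, then rearranged into an up_scale x up_scale
    # larger spatial grid.
    import torch

    head = UPHead(in_dim=768, out_dim=256, up_scale=4)  # random init
    x = torch.randn(2, 768, 16, 16)  # (batch, in_dim, H, W)
    y = head(x)
    assert y.shape == (2, 256, 64, 64)  # H and W each scaled by up_scale

    # To load pretrained weights instead, pass an mmseg-style init_cfg
    # ('path/to/ckpt.pth' is a hypothetical placeholder):
    # head = UPHead(768, 256, 4,
    #               init_cfg=dict(type='Pretrained',
    #                             checkpoint='path/to/ckpt.pth'))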