This commit is contained in:
esenke
2025-12-08 22:16:31 +08:00
commit 01adcfdf60
305 changed files with 50879 additions and 0 deletions

View File

@@ -0,0 +1,56 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_beit(ckpt):
new_ckpt = OrderedDict()
for k, v in ckpt.items():
if k.startswith('patch_embed'):
new_key = k.replace('patch_embed.proj', 'patch_embed.projection')
new_ckpt[new_key] = v
if k.startswith('blocks'):
new_key = k.replace('blocks', 'layers')
if 'norm' in new_key:
new_key = new_key.replace('norm', 'ln')
elif 'mlp.fc1' in new_key:
new_key = new_key.replace('mlp.fc1', 'ffn.layers.0.0')
elif 'mlp.fc2' in new_key:
new_key = new_key.replace('mlp.fc2', 'ffn.layers.1')
new_ckpt[new_key] = v
else:
new_key = k
new_ckpt[new_key] = v
return new_ckpt
def main():
parser = argparse.ArgumentParser(
description='Convert keys in official pretrained beit models to'
'MMSegmentation style.')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
args = parser.parse_args()
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
else:
state_dict = checkpoint
weight = convert_beit(state_dict)
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,163 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_vitlayer(paras):
new_para_name = ''
if paras[0] == 'ln_1':
new_para_name = '.'.join(['ln1'] + paras[1:])
elif paras[0] == 'attn':
new_para_name = '.'.join(['attn.attn'] + paras[1:])
elif paras[0] == 'ln_2':
new_para_name = '.'.join(['ln2'] + paras[1:])
elif paras[0] == 'mlp':
if paras[1] == 'c_fc':
new_para_name = '.'.join(['ffn.layers.0.0'] + paras[-1:])
else:
new_para_name = '.'.join(['ffn.layers.1'] + paras[-1:])
else:
print(f'Wrong for {paras}')
return new_para_name
def convert_translayer(paras):
new_para_name = ''
if paras[0] == 'attn':
new_para_name = '.'.join(['attentions.0.attn'] + paras[1:])
elif paras[0] == 'ln_1':
new_para_name = '.'.join(['norms.0'] + paras[1:])
elif paras[0] == 'ln_2':
new_para_name = '.'.join(['norms.1'] + paras[1:])
elif paras[0] == 'mlp':
if paras[1] == 'c_fc':
new_para_name = '.'.join(['ffns.0.layers.0.0'] + paras[2:])
elif paras[1] == 'c_proj':
new_para_name = '.'.join(['ffns.0.layers.1'] + paras[2:])
else:
print(f'Wrong for {paras}')
else:
print(f'Wrong for {paras}')
return new_para_name
def convert_key_name(ckpt, visual_split):
new_ckpt = OrderedDict()
for k, v in ckpt.items():
key_list = k.split('.')
if key_list[0] == 'visual':
new_transform_name = 'image_encoder'
if key_list[1] == 'class_embedding':
new_name = '.'.join([new_transform_name, 'cls_token'])
elif key_list[1] == 'positional_embedding':
new_name = '.'.join([new_transform_name, 'pos_embed'])
elif key_list[1] == 'conv1':
new_name = '.'.join([
new_transform_name, 'patch_embed.projection', key_list[2]
])
elif key_list[1] == 'ln_pre':
new_name = '.'.join(
[new_transform_name, key_list[1], key_list[2]])
elif key_list[1] == 'transformer':
new_layer_name = 'layers'
layer_index = key_list[3]
paras = key_list[4:]
if int(layer_index) < visual_split:
new_para_name = convert_vitlayer(paras)
new_name = '.'.join([
new_transform_name, new_layer_name, layer_index,
new_para_name
])
else:
new_para_name = convert_translayer(paras)
new_transform_name = 'decode_head.rec_with_attnbias'
new_layer_name = 'layers'
layer_index = str(int(layer_index) - visual_split)
new_name = '.'.join([
new_transform_name, new_layer_name, layer_index,
new_para_name
])
elif key_list[1] == 'proj':
new_name = 'decode_head.rec_with_attnbias.proj.weight'
elif key_list[1] == 'ln_post':
new_name = k.replace('visual', 'decode_head.rec_with_attnbias')
else:
print(f'pop parameter: {k}')
continue
else:
text_encoder_name = 'text_encoder'
if key_list[0] == 'transformer':
layer_name = 'transformer'
layer_index = key_list[2]
paras = key_list[3:]
new_para_name = convert_translayer(paras)
new_name = '.'.join([
text_encoder_name, layer_name, layer_index, new_para_name
])
elif key_list[0] in [
'positional_embedding', 'text_projection', 'bg_embed',
'attn_mask', 'logit_scale', 'token_embedding', 'ln_final'
]:
new_name = 'text_encoder.' + k
else:
print(f'pop parameter: {k}')
continue
new_ckpt[new_name] = v
return new_ckpt
def convert_tensor(ckpt):
cls_token = ckpt['image_encoder.cls_token']
new_cls_token = cls_token.unsqueeze(0).unsqueeze(0)
ckpt['image_encoder.cls_token'] = new_cls_token
pos_embed = ckpt['image_encoder.pos_embed']
new_pos_embed = pos_embed.unsqueeze(0)
ckpt['image_encoder.pos_embed'] = new_pos_embed
proj_weight = ckpt['decode_head.rec_with_attnbias.proj.weight']
new_proj_weight = proj_weight.transpose(1, 0)
ckpt['decode_head.rec_with_attnbias.proj.weight'] = new_proj_weight
return ckpt
def main():
parser = argparse.ArgumentParser(
description='Convert keys in timm pretrained vit models to '
'MMSegmentation style.')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
args = parser.parse_args()
if any([s in args.src for s in ['B-16', 'b16', 'base_patch16']]):
visual_split = 9
elif any([s in args.src for s in ['L-14', 'l14', 'large_patch14']]):
visual_split = 18
else:
print('Make sure the clip model is ViT-B/16 or ViT-L/14!')
visual_split = -1
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
if isinstance(checkpoint, torch.jit.RecursiveScriptModule):
state_dict = checkpoint.state_dict()
else:
if 'state_dict' in checkpoint:
# timm checkpoint
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
# deit checkpoint
state_dict = checkpoint['model']
else:
state_dict = checkpoint
weight = convert_key_name(state_dict, visual_split)
weight = convert_tensor(weight)
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,82 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_mit(ckpt):
new_ckpt = OrderedDict()
# Process the concat between q linear weights and kv linear weights
for k, v in ckpt.items():
if k.startswith('head'):
continue
# patch embedding conversion
elif k.startswith('patch_embed'):
stage_i = int(k.split('.')[0].replace('patch_embed', ''))
new_k = k.replace(f'patch_embed{stage_i}', f'layers.{stage_i-1}.0')
new_v = v
if 'proj.' in new_k:
new_k = new_k.replace('proj.', 'projection.')
# transformer encoder layer conversion
elif k.startswith('block'):
stage_i = int(k.split('.')[0].replace('block', ''))
new_k = k.replace(f'block{stage_i}', f'layers.{stage_i-1}.1')
new_v = v
if 'attn.q.' in new_k:
sub_item_k = k.replace('q.', 'kv.')
new_k = new_k.replace('q.', 'attn.in_proj_')
new_v = torch.cat([v, ckpt[sub_item_k]], dim=0)
elif 'attn.kv.' in new_k:
continue
elif 'attn.proj.' in new_k:
new_k = new_k.replace('proj.', 'attn.out_proj.')
elif 'attn.sr.' in new_k:
new_k = new_k.replace('sr.', 'sr.')
elif 'mlp.' in new_k:
string = f'{new_k}-'
new_k = new_k.replace('mlp.', 'ffn.layers.')
if 'fc1.weight' in new_k or 'fc2.weight' in new_k:
new_v = v.reshape((*v.shape, 1, 1))
new_k = new_k.replace('fc1.', '0.')
new_k = new_k.replace('dwconv.dwconv.', '1.')
new_k = new_k.replace('fc2.', '4.')
string += f'{new_k} {v.shape}-{new_v.shape}'
# norm layer conversion
elif k.startswith('norm'):
stage_i = int(k.split('.')[0].replace('norm', ''))
new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i-1}.2')
new_v = v
else:
new_k = k
new_v = v
new_ckpt[new_k] = new_v
return new_ckpt
def main():
parser = argparse.ArgumentParser(
description='Convert keys in official pretrained segformer to '
'MMSegmentation style.')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
args = parser.parse_args()
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
else:
state_dict = checkpoint
weight = convert_mit(state_dict)
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,220 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_key_name(ckpt):
new_ckpt = OrderedDict()
for k, v in ckpt.items():
key_list = k.split('.')
if key_list[0] == 'clip_visual_extractor':
new_transform_name = 'image_encoder'
if key_list[1] == 'class_embedding':
new_name = '.'.join([new_transform_name, 'cls_token'])
elif key_list[1] == 'positional_embedding':
new_name = '.'.join([new_transform_name, 'pos_embed'])
elif key_list[1] == 'conv1':
new_name = '.'.join([
new_transform_name, 'patch_embed.projection', key_list[2]
])
elif key_list[1] == 'ln_pre':
new_name = '.'.join(
[new_transform_name, key_list[1], key_list[2]])
elif key_list[1] == 'resblocks':
new_layer_name = 'layers'
layer_index = key_list[2]
paras = key_list[3:]
if paras[0] == 'ln_1':
new_para_name = '.'.join(['ln1'] + key_list[4:])
elif paras[0] == 'attn':
new_para_name = '.'.join(['attn.attn'] + key_list[4:])
elif paras[0] == 'ln_2':
new_para_name = '.'.join(['ln2'] + key_list[4:])
elif paras[0] == 'mlp':
if paras[1] == 'c_fc':
new_para_name = '.'.join(['ffn.layers.0.0'] +
key_list[-1:])
else:
new_para_name = '.'.join(['ffn.layers.1'] +
key_list[-1:])
new_name = '.'.join([
new_transform_name, new_layer_name, layer_index,
new_para_name
])
elif key_list[0] == 'side_adapter_network':
decode_head_name = 'decode_head'
module_name = 'side_adapter_network'
if key_list[1] == 'vit_model':
if key_list[2] == 'blocks':
layer_name = 'encode_layers'
layer_index = key_list[3]
paras = key_list[4:]
if paras[0] == 'norm1':
new_para_name = '.'.join(['ln1'] + key_list[5:])
elif paras[0] == 'attn':
new_para_name = '.'.join(key_list[4:])
new_para_name = new_para_name.replace(
'attn.qkv.', 'attn.attn.in_proj_')
new_para_name = new_para_name.replace(
'attn.proj', 'attn.attn.out_proj')
elif paras[0] == 'norm2':
new_para_name = '.'.join(['ln2'] + key_list[5:])
elif paras[0] == 'mlp':
new_para_name = '.'.join(['ffn'] + key_list[5:])
new_para_name = new_para_name.replace(
'fc1', 'layers.0.0')
new_para_name = new_para_name.replace(
'fc2', 'layers.1')
else:
print(f'Wrong for {k}')
new_name = '.'.join([
decode_head_name, module_name, layer_name, layer_index,
new_para_name
])
elif key_list[2] == 'pos_embed':
new_name = '.'.join(
[decode_head_name, module_name, 'pos_embed'])
elif key_list[2] == 'patch_embed':
new_name = '.'.join([
decode_head_name, module_name, 'patch_embed',
'projection', key_list[4]
])
else:
print(f'Wrong for {k}')
elif key_list[1] == 'query_embed' or key_list[
1] == 'query_pos_embed':
new_name = '.'.join(
[decode_head_name, module_name, key_list[1]])
elif key_list[1] == 'fusion_layers':
layer_name = 'conv_clips'
layer_index = key_list[2][-1]
paras = '.'.join(key_list[3:])
new_para_name = paras.replace('input_proj.0', '0')
new_para_name = new_para_name.replace('input_proj.1', '1.conv')
new_name = '.'.join([
decode_head_name, module_name, layer_name, layer_index,
new_para_name
])
elif key_list[1] == 'mask_decoder':
new_name = 'decode_head.' + k
else:
print(f'Wrong for {k}')
elif key_list[0] == 'clip_rec_head':
module_name = 'rec_with_attnbias'
if key_list[1] == 'proj':
new_name = '.'.join(
[decode_head_name, module_name, 'proj.weight'])
elif key_list[1] == 'ln_post':
new_name = '.'.join(
[decode_head_name, module_name, 'ln_post', key_list[2]])
elif key_list[1] == 'resblocks':
new_layer_name = 'layers'
layer_index = key_list[2]
paras = key_list[3:]
if paras[0] == 'ln_1':
new_para_name = '.'.join(['norms.0'] + paras[1:])
elif paras[0] == 'attn':
new_para_name = '.'.join(['attentions.0.attn'] + paras[1:])
elif paras[0] == 'ln_2':
new_para_name = '.'.join(['norms.1'] + paras[1:])
elif paras[0] == 'mlp':
if paras[1] == 'c_fc':
new_para_name = '.'.join(['ffns.0.layers.0.0'] +
paras[2:])
elif paras[1] == 'c_proj':
new_para_name = '.'.join(['ffns.0.layers.1'] +
paras[2:])
else:
print(f'Wrong for {k}')
new_name = '.'.join([
decode_head_name, module_name, new_layer_name, layer_index,
new_para_name
])
else:
print(f'Wrong for {k}')
elif key_list[0] == 'ov_classifier':
text_encoder_name = 'text_encoder'
if key_list[1] == 'transformer':
layer_name = 'transformer'
layer_index = key_list[3]
paras = key_list[4:]
if paras[0] == 'attn':
new_para_name = '.'.join(['attentions.0.attn'] + paras[1:])
elif paras[0] == 'ln_1':
new_para_name = '.'.join(['norms.0'] + paras[1:])
elif paras[0] == 'ln_2':
new_para_name = '.'.join(['norms.1'] + paras[1:])
elif paras[0] == 'mlp':
if paras[1] == 'c_fc':
new_para_name = '.'.join(['ffns.0.layers.0.0'] +
paras[2:])
elif paras[1] == 'c_proj':
new_para_name = '.'.join(['ffns.0.layers.1'] +
paras[2:])
else:
print(f'Wrong for {k}')
else:
print(f'Wrong for {k}')
new_name = '.'.join([
text_encoder_name, layer_name, layer_index, new_para_name
])
elif key_list[1] in [
'positional_embedding', 'text_projection', 'bg_embed',
'attn_mask', 'logit_scale', 'token_embedding', 'ln_final'
]:
new_name = k.replace('ov_classifier', 'text_encoder')
else:
print(f'Wrong for {k}')
elif key_list[0] == 'criterion':
new_name = k
else:
print(f'Wrong for {k}')
new_ckpt[new_name] = v
return new_ckpt
def convert_tensor(ckpt):
cls_token = ckpt['image_encoder.cls_token']
new_cls_token = cls_token.unsqueeze(0).unsqueeze(0)
ckpt['image_encoder.cls_token'] = new_cls_token
pos_embed = ckpt['image_encoder.pos_embed']
new_pos_embed = pos_embed.unsqueeze(0)
ckpt['image_encoder.pos_embed'] = new_pos_embed
proj_weight = ckpt['decode_head.rec_with_attnbias.proj.weight']
new_proj_weight = proj_weight.transpose(1, 0)
ckpt['decode_head.rec_with_attnbias.proj.weight'] = new_proj_weight
return ckpt
def main():
parser = argparse.ArgumentParser(
description='Convert keys in timm pretrained vit models to '
'MMSegmentation style.')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
args = parser.parse_args()
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
if 'state_dict' in checkpoint:
# timm checkpoint
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
# deit checkpoint
state_dict = checkpoint['model']
else:
state_dict = checkpoint
weight = convert_key_name(state_dict)
weight = convert_tensor(weight)
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,71 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_stdc(ckpt, stdc_type):
new_state_dict = {}
if stdc_type == 'STDC1':
stage_lst = ['0', '1', '2.0', '2.1', '3.0', '3.1', '4.0', '4.1']
else:
stage_lst = [
'0', '1', '2.0', '2.1', '2.2', '2.3', '3.0', '3.1', '3.2', '3.3',
'3.4', '4.0', '4.1', '4.2'
]
for k, v in ckpt.items():
ori_k = k
flag = False
if 'cp.' in k:
k = k.replace('cp.', '')
if 'features.' in k:
num_layer = int(k.split('.')[1])
feature_key_lst = 'features.' + str(num_layer) + '.'
stages_key_lst = 'stages.' + stage_lst[num_layer] + '.'
k = k.replace(feature_key_lst, stages_key_lst)
flag = True
if 'conv_list' in k:
k = k.replace('conv_list', 'layers')
flag = True
if 'avd_layer.' in k:
if 'avd_layer.0' in k:
k = k.replace('avd_layer.0', 'downsample.conv')
elif 'avd_layer.1' in k:
k = k.replace('avd_layer.1', 'downsample.bn')
flag = True
if flag:
new_state_dict[k] = ckpt[ori_k]
return new_state_dict
def main():
parser = argparse.ArgumentParser(
description='Convert keys in official pretrained STDC1/2 to '
'MMSegmentation style.')
parser.add_argument('src', help='src model path')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
parser.add_argument('type', help='model type: STDC1 or STDC2')
args = parser.parse_args()
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
else:
state_dict = checkpoint
assert args.type in ['STDC1',
'STDC2'], 'STD type should be STDC1 or STDC2!'
weight = convert_stdc(state_dict, args.type)
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,87 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_swin(ckpt):
new_ckpt = OrderedDict()
def correct_unfold_reduction_order(x):
out_channel, in_channel = x.shape
x = x.reshape(out_channel, 4, in_channel // 4)
x = x[:, [0, 2, 1, 3], :].transpose(1,
2).reshape(out_channel, in_channel)
return x
def correct_unfold_norm_order(x):
in_channel = x.shape[0]
x = x.reshape(4, in_channel // 4)
x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
return x
for k, v in ckpt.items():
if k.startswith('head'):
continue
elif k.startswith('layers'):
new_v = v
if 'attn.' in k:
new_k = k.replace('attn.', 'attn.w_msa.')
elif 'mlp.' in k:
if 'mlp.fc1.' in k:
new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')
elif 'mlp.fc2.' in k:
new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')
else:
new_k = k.replace('mlp.', 'ffn.')
elif 'downsample' in k:
new_k = k
if 'reduction.' in k:
new_v = correct_unfold_reduction_order(v)
elif 'norm.' in k:
new_v = correct_unfold_norm_order(v)
else:
new_k = k
new_k = new_k.replace('layers', 'stages', 1)
elif k.startswith('patch_embed'):
new_v = v
if 'proj' in k:
new_k = k.replace('proj', 'projection')
else:
new_k = k
else:
new_v = v
new_k = k
new_ckpt[new_k] = new_v
return new_ckpt
def main():
parser = argparse.ArgumentParser(
description='Convert keys in official pretrained swin models to'
'MMSegmentation style.')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
args = parser.parse_args()
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
else:
state_dict = checkpoint
weight = convert_swin(state_dict)
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,87 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_twins(args, ckpt):
new_ckpt = OrderedDict()
for k, v in list(ckpt.items()):
new_v = v
if k.startswith('head'):
continue
elif k.startswith('patch_embeds'):
if 'proj.' in k:
new_k = k.replace('proj.', 'projection.')
else:
new_k = k
elif k.startswith('blocks'):
# Union
if 'attn.q.' in k:
new_k = k.replace('q.', 'attn.in_proj_')
new_v = torch.cat([v, ckpt[k.replace('attn.q.', 'attn.kv.')]],
dim=0)
elif 'mlp.fc1' in k:
new_k = k.replace('mlp.fc1', 'ffn.layers.0.0')
elif 'mlp.fc2' in k:
new_k = k.replace('mlp.fc2', 'ffn.layers.1')
# Only pcpvt
elif args.model == 'pcpvt':
if 'attn.proj.' in k:
new_k = k.replace('proj.', 'attn.out_proj.')
else:
new_k = k
# Only svt
else:
if 'attn.proj.' in k:
k_lst = k.split('.')
if int(k_lst[2]) % 2 == 1:
new_k = k.replace('proj.', 'attn.out_proj.')
else:
new_k = k
else:
new_k = k
new_k = new_k.replace('blocks.', 'layers.')
elif k.startswith('pos_block'):
new_k = k.replace('pos_block', 'position_encodings')
if 'proj.0.' in new_k:
new_k = new_k.replace('proj.0.', 'proj.')
else:
new_k = k
if 'attn.kv.' not in k:
new_ckpt[new_k] = new_v
return new_ckpt
def main():
parser = argparse.ArgumentParser(
description='Convert keys in timm pretrained vit models to '
'MMSegmentation style.')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
parser.add_argument('model', help='model: pcpvt or svt')
args = parser.parse_args()
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
if 'state_dict' in checkpoint:
# timm checkpoint
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
weight = convert_twins(args, state_dict)
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,70 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_vit(ckpt):
new_ckpt = OrderedDict()
for k, v in ckpt.items():
if k.startswith('head'):
continue
if k.startswith('norm'):
new_k = k.replace('norm.', 'ln1.')
elif k.startswith('patch_embed'):
if 'proj' in k:
new_k = k.replace('proj', 'projection')
else:
new_k = k
elif k.startswith('blocks'):
if 'norm' in k:
new_k = k.replace('norm', 'ln')
elif 'mlp.fc1' in k:
new_k = k.replace('mlp.fc1', 'ffn.layers.0.0')
elif 'mlp.fc2' in k:
new_k = k.replace('mlp.fc2', 'ffn.layers.1')
elif 'attn.qkv' in k:
new_k = k.replace('attn.qkv.', 'attn.attn.in_proj_')
elif 'attn.proj' in k:
new_k = k.replace('attn.proj', 'attn.attn.out_proj')
else:
new_k = k
new_k = new_k.replace('blocks.', 'layers.')
else:
new_k = k
new_ckpt[new_k] = v
return new_ckpt
def main():
parser = argparse.ArgumentParser(
description='Convert keys in timm pretrained vit models to '
'MMSegmentation style.')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
args = parser.parse_args()
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
if 'state_dict' in checkpoint:
# timm checkpoint
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
# deit checkpoint
state_dict = checkpoint['model']
else:
state_dict = checkpoint
weight = convert_vit(state_dict)
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,123 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import mmengine
import numpy as np
import torch
def vit_jax_to_torch(jax_weights, num_layer=12):
torch_weights = dict()
# patch embedding
conv_filters = jax_weights['embedding/kernel']
conv_filters = conv_filters.permute(3, 2, 0, 1)
torch_weights['patch_embed.projection.weight'] = conv_filters
torch_weights['patch_embed.projection.bias'] = jax_weights[
'embedding/bias']
# pos embedding
torch_weights['pos_embed'] = jax_weights[
'Transformer/posembed_input/pos_embedding']
# cls token
torch_weights['cls_token'] = jax_weights['cls']
# head
torch_weights['ln1.weight'] = jax_weights['Transformer/encoder_norm/scale']
torch_weights['ln1.bias'] = jax_weights['Transformer/encoder_norm/bias']
# transformer blocks
for i in range(num_layer):
jax_block = f'Transformer/encoderblock_{i}'
torch_block = f'layers.{i}'
# attention norm
torch_weights[f'{torch_block}.ln1.weight'] = jax_weights[
f'{jax_block}/LayerNorm_0/scale']
torch_weights[f'{torch_block}.ln1.bias'] = jax_weights[
f'{jax_block}/LayerNorm_0/bias']
# attention
query_weight = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/query/kernel']
query_bias = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/query/bias']
key_weight = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/key/kernel']
key_bias = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/key/bias']
value_weight = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/value/kernel']
value_bias = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/value/bias']
qkv_weight = torch.from_numpy(
np.stack((query_weight, key_weight, value_weight), 1))
qkv_weight = torch.flatten(qkv_weight, start_dim=1)
qkv_bias = torch.from_numpy(
np.stack((query_bias, key_bias, value_bias), 0))
qkv_bias = torch.flatten(qkv_bias, start_dim=0)
torch_weights[f'{torch_block}.attn.attn.in_proj_weight'] = qkv_weight
torch_weights[f'{torch_block}.attn.attn.in_proj_bias'] = qkv_bias
to_out_weight = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/out/kernel']
to_out_weight = torch.flatten(to_out_weight, start_dim=0, end_dim=1)
torch_weights[
f'{torch_block}.attn.attn.out_proj.weight'] = to_out_weight
torch_weights[f'{torch_block}.attn.attn.out_proj.bias'] = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/out/bias']
# mlp norm
torch_weights[f'{torch_block}.ln2.weight'] = jax_weights[
f'{jax_block}/LayerNorm_2/scale']
torch_weights[f'{torch_block}.ln2.bias'] = jax_weights[
f'{jax_block}/LayerNorm_2/bias']
# mlp
torch_weights[f'{torch_block}.ffn.layers.0.0.weight'] = jax_weights[
f'{jax_block}/MlpBlock_3/Dense_0/kernel']
torch_weights[f'{torch_block}.ffn.layers.0.0.bias'] = jax_weights[
f'{jax_block}/MlpBlock_3/Dense_0/bias']
torch_weights[f'{torch_block}.ffn.layers.1.weight'] = jax_weights[
f'{jax_block}/MlpBlock_3/Dense_1/kernel']
torch_weights[f'{torch_block}.ffn.layers.1.bias'] = jax_weights[
f'{jax_block}/MlpBlock_3/Dense_1/bias']
# transpose weights
for k, v in torch_weights.items():
if 'weight' in k and 'patch_embed' not in k and 'ln' not in k:
v = v.permute(1, 0)
torch_weights[k] = v
return torch_weights
def main():
# stole refactoring code from Robin Strudel, thanks
parser = argparse.ArgumentParser(
description='Convert keys from jax official pretrained vit models to '
'MMSegmentation style.')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
args = parser.parse_args()
jax_weights = np.load(args.src)
jax_weights_tensor = {}
for key in jax_weights.files:
value = torch.from_numpy(jax_weights[key])
jax_weights_tensor[key] = value
if 'L_16-i21k' in args.src:
num_layer = 24
else:
num_layer = 12
torch_weights = vit_jax_to_torch(jax_weights_tensor, num_layer)
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(torch_weights, args.dst)
if __name__ == '__main__':
main()