init
finetune/tools/model_converters/beit2mmseg.py (new file, 56 lines)
@@ -0,0 +1,56 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def convert_beit(ckpt):
    new_ckpt = OrderedDict()

    for k, v in ckpt.items():
        if k.startswith('patch_embed'):
            new_key = k.replace('patch_embed.proj', 'patch_embed.projection')
            new_ckpt[new_key] = v
        if k.startswith('blocks'):
            new_key = k.replace('blocks', 'layers')
            if 'norm' in new_key:
                new_key = new_key.replace('norm', 'ln')
            elif 'mlp.fc1' in new_key:
                new_key = new_key.replace('mlp.fc1', 'ffn.layers.0.0')
            elif 'mlp.fc2' in new_key:
                new_key = new_key.replace('mlp.fc2', 'ffn.layers.1')
            new_ckpt[new_key] = v
        else:
            new_key = k
            new_ckpt[new_key] = v

    return new_ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in official pretrained beit models to '
        'MMSegmentation style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    weight = convert_beit(state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)


if __name__ == '__main__':
    main()
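Usage sketch (inferred from the argparse definition above; paths are placeholders): python finetune/tools/model_converters/beit2mmseg.py /path/to/beit.pth /path/to/beit_mmseg.pth. For example, a key such as blocks.0.mlp.fc1.weight is renamed to layers.0.ffn.layers.0.0.weight.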
finetune/tools/model_converters/clip2mmseg.py (new file, 163 lines)
@@ -0,0 +1,163 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def convert_vitlayer(paras):
    new_para_name = ''
    if paras[0] == 'ln_1':
        new_para_name = '.'.join(['ln1'] + paras[1:])
    elif paras[0] == 'attn':
        new_para_name = '.'.join(['attn.attn'] + paras[1:])
    elif paras[0] == 'ln_2':
        new_para_name = '.'.join(['ln2'] + paras[1:])
    elif paras[0] == 'mlp':
        if paras[1] == 'c_fc':
            new_para_name = '.'.join(['ffn.layers.0.0'] + paras[-1:])
        else:
            new_para_name = '.'.join(['ffn.layers.1'] + paras[-1:])
    else:
        print(f'Wrong for {paras}')
    return new_para_name


def convert_translayer(paras):
    new_para_name = ''
    if paras[0] == 'attn':
        new_para_name = '.'.join(['attentions.0.attn'] + paras[1:])
    elif paras[0] == 'ln_1':
        new_para_name = '.'.join(['norms.0'] + paras[1:])
    elif paras[0] == 'ln_2':
        new_para_name = '.'.join(['norms.1'] + paras[1:])
    elif paras[0] == 'mlp':
        if paras[1] == 'c_fc':
            new_para_name = '.'.join(['ffns.0.layers.0.0'] + paras[2:])
        elif paras[1] == 'c_proj':
            new_para_name = '.'.join(['ffns.0.layers.1'] + paras[2:])
        else:
            print(f'Wrong for {paras}')
    else:
        print(f'Wrong for {paras}')
    return new_para_name


def convert_key_name(ckpt, visual_split):
    new_ckpt = OrderedDict()
    for k, v in ckpt.items():
        key_list = k.split('.')
        if key_list[0] == 'visual':
            new_transform_name = 'image_encoder'
            if key_list[1] == 'class_embedding':
                new_name = '.'.join([new_transform_name, 'cls_token'])
            elif key_list[1] == 'positional_embedding':
                new_name = '.'.join([new_transform_name, 'pos_embed'])
            elif key_list[1] == 'conv1':
                new_name = '.'.join([
                    new_transform_name, 'patch_embed.projection', key_list[2]
                ])
            elif key_list[1] == 'ln_pre':
                new_name = '.'.join(
                    [new_transform_name, key_list[1], key_list[2]])
            elif key_list[1] == 'transformer':
                new_layer_name = 'layers'
                layer_index = key_list[3]
                paras = key_list[4:]
                if int(layer_index) < visual_split:
                    new_para_name = convert_vitlayer(paras)
                    new_name = '.'.join([
                        new_transform_name, new_layer_name, layer_index,
                        new_para_name
                    ])
                else:
                    new_para_name = convert_translayer(paras)
                    new_transform_name = 'decode_head.rec_with_attnbias'
                    new_layer_name = 'layers'
                    layer_index = str(int(layer_index) - visual_split)
                    new_name = '.'.join([
                        new_transform_name, new_layer_name, layer_index,
                        new_para_name
                    ])
            elif key_list[1] == 'proj':
                new_name = 'decode_head.rec_with_attnbias.proj.weight'
            elif key_list[1] == 'ln_post':
                new_name = k.replace('visual', 'decode_head.rec_with_attnbias')
            else:
                print(f'pop parameter: {k}')
                continue
        else:
            text_encoder_name = 'text_encoder'
            if key_list[0] == 'transformer':
                layer_name = 'transformer'
                layer_index = key_list[2]
                paras = key_list[3:]
                new_para_name = convert_translayer(paras)
                new_name = '.'.join([
                    text_encoder_name, layer_name, layer_index, new_para_name
                ])
            elif key_list[0] in [
                    'positional_embedding', 'text_projection', 'bg_embed',
                    'attn_mask', 'logit_scale', 'token_embedding', 'ln_final'
            ]:
                new_name = 'text_encoder.' + k
            else:
                print(f'pop parameter: {k}')
                continue
        new_ckpt[new_name] = v

    return new_ckpt


def convert_tensor(ckpt):
    cls_token = ckpt['image_encoder.cls_token']
    new_cls_token = cls_token.unsqueeze(0).unsqueeze(0)
    ckpt['image_encoder.cls_token'] = new_cls_token
    pos_embed = ckpt['image_encoder.pos_embed']
    new_pos_embed = pos_embed.unsqueeze(0)
    ckpt['image_encoder.pos_embed'] = new_pos_embed
    proj_weight = ckpt['decode_head.rec_with_attnbias.proj.weight']
    new_proj_weight = proj_weight.transpose(1, 0)
    ckpt['decode_head.rec_with_attnbias.proj.weight'] = new_proj_weight
    return ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in pretrained CLIP (ViT-B/16 or ViT-L/14) '
        'models to MMSegmentation style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    if any([s in args.src for s in ['B-16', 'b16', 'base_patch16']]):
        visual_split = 9
    elif any([s in args.src for s in ['L-14', 'l14', 'large_patch14']]):
        visual_split = 18
    else:
        print('Make sure the clip model is ViT-B/16 or ViT-L/14!')
        visual_split = -1
    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    if isinstance(checkpoint, torch.jit.RecursiveScriptModule):
        state_dict = checkpoint.state_dict()
    else:
        if 'state_dict' in checkpoint:
            # timm checkpoint
            state_dict = checkpoint['state_dict']
        elif 'model' in checkpoint:
            # deit checkpoint
            state_dict = checkpoint['model']
        else:
            state_dict = checkpoint
    weight = convert_key_name(state_dict, visual_split)
    weight = convert_tensor(weight)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)


if __name__ == '__main__':
    main()
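Usage sketch (arguments per the parser above; file names are placeholders, but the src name must contain a pattern such as B-16 or L-14 so the split is detected): python finetune/tools/model_converters/clip2mmseg.py ViT-B-16.pt clip_vit_b16_mmseg.pth. The visual_split value (9 for ViT-B/16, 18 for ViT-L/14) determines how many visual transformer layers map to image_encoder.layers; the remaining layers are routed to decode_head.rec_with_attnbias.layers.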
finetune/tools/model_converters/mit2mmseg.py (new file, 82 lines)
@@ -0,0 +1,82 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def convert_mit(ckpt):
    new_ckpt = OrderedDict()
    # Process the concat between q linear weights and kv linear weights
    for k, v in ckpt.items():
        if k.startswith('head'):
            continue
        # patch embedding conversion
        elif k.startswith('patch_embed'):
            stage_i = int(k.split('.')[0].replace('patch_embed', ''))
            new_k = k.replace(f'patch_embed{stage_i}', f'layers.{stage_i-1}.0')
            new_v = v
            if 'proj.' in new_k:
                new_k = new_k.replace('proj.', 'projection.')
        # transformer encoder layer conversion
        elif k.startswith('block'):
            stage_i = int(k.split('.')[0].replace('block', ''))
            new_k = k.replace(f'block{stage_i}', f'layers.{stage_i-1}.1')
            new_v = v
            if 'attn.q.' in new_k:
                sub_item_k = k.replace('q.', 'kv.')
                new_k = new_k.replace('q.', 'attn.in_proj_')
                new_v = torch.cat([v, ckpt[sub_item_k]], dim=0)
            elif 'attn.kv.' in new_k:
                continue
            elif 'attn.proj.' in new_k:
                new_k = new_k.replace('proj.', 'attn.out_proj.')
            elif 'attn.sr.' in new_k:
                new_k = new_k.replace('sr.', 'sr.')
            elif 'mlp.' in new_k:
                string = f'{new_k}-'
                new_k = new_k.replace('mlp.', 'ffn.layers.')
                if 'fc1.weight' in new_k or 'fc2.weight' in new_k:
                    new_v = v.reshape((*v.shape, 1, 1))
                new_k = new_k.replace('fc1.', '0.')
                new_k = new_k.replace('dwconv.dwconv.', '1.')
                new_k = new_k.replace('fc2.', '4.')
                string += f'{new_k} {v.shape}-{new_v.shape}'
        # norm layer conversion
        elif k.startswith('norm'):
            stage_i = int(k.split('.')[0].replace('norm', ''))
            new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i-1}.2')
            new_v = v
        else:
            new_k = k
            new_v = v
        new_ckpt[new_k] = new_v
    return new_ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in official pretrained segformer to '
        'MMSegmentation style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    weight = convert_mit(state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)


if __name__ == '__main__':
    main()
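Usage sketch (paths are placeholders): python finetune/tools/model_converters/mit2mmseg.py mit_b0.pth mit_b0_mmseg.pth. As the loop shows, each block's separate q and kv projection weights are concatenated into a single attn.in_proj_ tensor, and the MLP fc1/fc2 weights are reshaped with trailing (1, 1) dimensions so they load as 1x1 convolutions.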
finetune/tools/model_converters/san2mmseg.py (new file, 220 lines)
@@ -0,0 +1,220 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def convert_key_name(ckpt):
    new_ckpt = OrderedDict()

    for k, v in ckpt.items():
        key_list = k.split('.')
        if key_list[0] == 'clip_visual_extractor':
            new_transform_name = 'image_encoder'
            if key_list[1] == 'class_embedding':
                new_name = '.'.join([new_transform_name, 'cls_token'])
            elif key_list[1] == 'positional_embedding':
                new_name = '.'.join([new_transform_name, 'pos_embed'])
            elif key_list[1] == 'conv1':
                new_name = '.'.join([
                    new_transform_name, 'patch_embed.projection', key_list[2]
                ])
            elif key_list[1] == 'ln_pre':
                new_name = '.'.join(
                    [new_transform_name, key_list[1], key_list[2]])
            elif key_list[1] == 'resblocks':
                new_layer_name = 'layers'
                layer_index = key_list[2]
                paras = key_list[3:]
                if paras[0] == 'ln_1':
                    new_para_name = '.'.join(['ln1'] + key_list[4:])
                elif paras[0] == 'attn':
                    new_para_name = '.'.join(['attn.attn'] + key_list[4:])
                elif paras[0] == 'ln_2':
                    new_para_name = '.'.join(['ln2'] + key_list[4:])
                elif paras[0] == 'mlp':
                    if paras[1] == 'c_fc':
                        new_para_name = '.'.join(['ffn.layers.0.0'] +
                                                 key_list[-1:])
                    else:
                        new_para_name = '.'.join(['ffn.layers.1'] +
                                                 key_list[-1:])
                new_name = '.'.join([
                    new_transform_name, new_layer_name, layer_index,
                    new_para_name
                ])
        elif key_list[0] == 'side_adapter_network':
            decode_head_name = 'decode_head'
            module_name = 'side_adapter_network'
            if key_list[1] == 'vit_model':
                if key_list[2] == 'blocks':
                    layer_name = 'encode_layers'
                    layer_index = key_list[3]
                    paras = key_list[4:]
                    if paras[0] == 'norm1':
                        new_para_name = '.'.join(['ln1'] + key_list[5:])
                    elif paras[0] == 'attn':
                        new_para_name = '.'.join(key_list[4:])
                        new_para_name = new_para_name.replace(
                            'attn.qkv.', 'attn.attn.in_proj_')
                        new_para_name = new_para_name.replace(
                            'attn.proj', 'attn.attn.out_proj')
                    elif paras[0] == 'norm2':
                        new_para_name = '.'.join(['ln2'] + key_list[5:])
                    elif paras[0] == 'mlp':
                        new_para_name = '.'.join(['ffn'] + key_list[5:])
                        new_para_name = new_para_name.replace(
                            'fc1', 'layers.0.0')
                        new_para_name = new_para_name.replace(
                            'fc2', 'layers.1')
                    else:
                        print(f'Wrong for {k}')
                    new_name = '.'.join([
                        decode_head_name, module_name, layer_name, layer_index,
                        new_para_name
                    ])
                elif key_list[2] == 'pos_embed':
                    new_name = '.'.join(
                        [decode_head_name, module_name, 'pos_embed'])
                elif key_list[2] == 'patch_embed':
                    new_name = '.'.join([
                        decode_head_name, module_name, 'patch_embed',
                        'projection', key_list[4]
                    ])
                else:
                    print(f'Wrong for {k}')
            elif key_list[1] == 'query_embed' or key_list[
                    1] == 'query_pos_embed':
                new_name = '.'.join(
                    [decode_head_name, module_name, key_list[1]])
            elif key_list[1] == 'fusion_layers':
                layer_name = 'conv_clips'
                layer_index = key_list[2][-1]
                paras = '.'.join(key_list[3:])
                new_para_name = paras.replace('input_proj.0', '0')
                new_para_name = new_para_name.replace('input_proj.1', '1.conv')
                new_name = '.'.join([
                    decode_head_name, module_name, layer_name, layer_index,
                    new_para_name
                ])
            elif key_list[1] == 'mask_decoder':
                new_name = 'decode_head.' + k
            else:
                print(f'Wrong for {k}')
        elif key_list[0] == 'clip_rec_head':
            module_name = 'rec_with_attnbias'
            if key_list[1] == 'proj':
                new_name = '.'.join(
                    [decode_head_name, module_name, 'proj.weight'])
            elif key_list[1] == 'ln_post':
                new_name = '.'.join(
                    [decode_head_name, module_name, 'ln_post', key_list[2]])
            elif key_list[1] == 'resblocks':
                new_layer_name = 'layers'
                layer_index = key_list[2]
                paras = key_list[3:]
                if paras[0] == 'ln_1':
                    new_para_name = '.'.join(['norms.0'] + paras[1:])
                elif paras[0] == 'attn':
                    new_para_name = '.'.join(['attentions.0.attn'] + paras[1:])
                elif paras[0] == 'ln_2':
                    new_para_name = '.'.join(['norms.1'] + paras[1:])
                elif paras[0] == 'mlp':
                    if paras[1] == 'c_fc':
                        new_para_name = '.'.join(['ffns.0.layers.0.0'] +
                                                 paras[2:])
                    elif paras[1] == 'c_proj':
                        new_para_name = '.'.join(['ffns.0.layers.1'] +
                                                 paras[2:])
                    else:
                        print(f'Wrong for {k}')
                new_name = '.'.join([
                    decode_head_name, module_name, new_layer_name, layer_index,
                    new_para_name
                ])
            else:
                print(f'Wrong for {k}')
        elif key_list[0] == 'ov_classifier':
            text_encoder_name = 'text_encoder'
            if key_list[1] == 'transformer':
                layer_name = 'transformer'
                layer_index = key_list[3]
                paras = key_list[4:]
                if paras[0] == 'attn':
                    new_para_name = '.'.join(['attentions.0.attn'] + paras[1:])
                elif paras[0] == 'ln_1':
                    new_para_name = '.'.join(['norms.0'] + paras[1:])
                elif paras[0] == 'ln_2':
                    new_para_name = '.'.join(['norms.1'] + paras[1:])
                elif paras[0] == 'mlp':
                    if paras[1] == 'c_fc':
                        new_para_name = '.'.join(['ffns.0.layers.0.0'] +
                                                 paras[2:])
                    elif paras[1] == 'c_proj':
                        new_para_name = '.'.join(['ffns.0.layers.1'] +
                                                 paras[2:])
                    else:
                        print(f'Wrong for {k}')
                else:
                    print(f'Wrong for {k}')
                new_name = '.'.join([
                    text_encoder_name, layer_name, layer_index, new_para_name
                ])
            elif key_list[1] in [
                    'positional_embedding', 'text_projection', 'bg_embed',
                    'attn_mask', 'logit_scale', 'token_embedding', 'ln_final'
            ]:
                new_name = k.replace('ov_classifier', 'text_encoder')
            else:
                print(f'Wrong for {k}')
        elif key_list[0] == 'criterion':
            new_name = k
        else:
            print(f'Wrong for {k}')
        new_ckpt[new_name] = v
    return new_ckpt


def convert_tensor(ckpt):
    cls_token = ckpt['image_encoder.cls_token']
    new_cls_token = cls_token.unsqueeze(0).unsqueeze(0)
    ckpt['image_encoder.cls_token'] = new_cls_token
    pos_embed = ckpt['image_encoder.pos_embed']
    new_pos_embed = pos_embed.unsqueeze(0)
    ckpt['image_encoder.pos_embed'] = new_pos_embed
    proj_weight = ckpt['decode_head.rec_with_attnbias.proj.weight']
    new_proj_weight = proj_weight.transpose(1, 0)
    ckpt['decode_head.rec_with_attnbias.proj.weight'] = new_proj_weight
    return ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in official pretrained SAN models to '
        'MMSegmentation style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    if 'state_dict' in checkpoint:
        # timm checkpoint
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        # deit checkpoint
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    weight = convert_key_name(state_dict)
    weight = convert_tensor(weight)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)


if __name__ == '__main__':
    main()
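Usage sketch (paths are placeholders): python finetune/tools/model_converters/san2mmseg.py san_vit_b16.pth san_vit_b16_mmseg.pth. The mapping follows the branches above: clip_visual_extractor keys go to image_encoder, side_adapter_network keys to decode_head.side_adapter_network, clip_rec_head keys to decode_head.rec_with_attnbias, and ov_classifier keys to text_encoder.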
finetune/tools/model_converters/stdc2mmseg.py (new file, 71 lines)
@@ -0,0 +1,71 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def convert_stdc(ckpt, stdc_type):
    new_state_dict = {}
    if stdc_type == 'STDC1':
        stage_lst = ['0', '1', '2.0', '2.1', '3.0', '3.1', '4.0', '4.1']
    else:
        stage_lst = [
            '0', '1', '2.0', '2.1', '2.2', '2.3', '3.0', '3.1', '3.2', '3.3',
            '3.4', '4.0', '4.1', '4.2'
        ]
    for k, v in ckpt.items():
        ori_k = k
        flag = False
        if 'cp.' in k:
            k = k.replace('cp.', '')
        if 'features.' in k:
            num_layer = int(k.split('.')[1])
            feature_key_lst = 'features.' + str(num_layer) + '.'
            stages_key_lst = 'stages.' + stage_lst[num_layer] + '.'
            k = k.replace(feature_key_lst, stages_key_lst)
            flag = True
        if 'conv_list' in k:
            k = k.replace('conv_list', 'layers')
            flag = True
        if 'avd_layer.' in k:
            if 'avd_layer.0' in k:
                k = k.replace('avd_layer.0', 'downsample.conv')
            elif 'avd_layer.1' in k:
                k = k.replace('avd_layer.1', 'downsample.bn')
            flag = True
        if flag:
            new_state_dict[k] = ckpt[ori_k]

    return new_state_dict


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in official pretrained STDC1/2 to '
        'MMSegmentation style.')
    parser.add_argument('src', help='src model path')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    parser.add_argument('type', help='model type: STDC1 or STDC2')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    assert args.type in ['STDC1',
                         'STDC2'], 'STDC type should be STDC1 or STDC2!'
    weight = convert_stdc(state_dict, args.type)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)


if __name__ == '__main__':
    main()
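Usage sketch (paths are placeholders): python finetune/tools/model_converters/stdc2mmseg.py stdc1.pth stdc1_mmseg.pth STDC1. The third positional argument must be STDC1 or STDC2; it selects the stage_lst used to renumber the backbone's features.N keys into stages.X keys.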
finetune/tools/model_converters/swin2mmseg.py (new file, 87 lines)
@@ -0,0 +1,87 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def convert_swin(ckpt):
    new_ckpt = OrderedDict()

    def correct_unfold_reduction_order(x):
        out_channel, in_channel = x.shape
        x = x.reshape(out_channel, 4, in_channel // 4)
        x = x[:, [0, 2, 1, 3], :].transpose(1,
                                            2).reshape(out_channel, in_channel)
        return x

    def correct_unfold_norm_order(x):
        in_channel = x.shape[0]
        x = x.reshape(4, in_channel // 4)
        x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
        return x

    for k, v in ckpt.items():
        if k.startswith('head'):
            continue
        elif k.startswith('layers'):
            new_v = v
            if 'attn.' in k:
                new_k = k.replace('attn.', 'attn.w_msa.')
            elif 'mlp.' in k:
                if 'mlp.fc1.' in k:
                    new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')
                elif 'mlp.fc2.' in k:
                    new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')
                else:
                    new_k = k.replace('mlp.', 'ffn.')
            elif 'downsample' in k:
                new_k = k
                if 'reduction.' in k:
                    new_v = correct_unfold_reduction_order(v)
                elif 'norm.' in k:
                    new_v = correct_unfold_norm_order(v)
            else:
                new_k = k
            new_k = new_k.replace('layers', 'stages', 1)
        elif k.startswith('patch_embed'):
            new_v = v
            if 'proj' in k:
                new_k = k.replace('proj', 'projection')
            else:
                new_k = k
        else:
            new_v = v
            new_k = k

        new_ckpt[new_k] = new_v

    return new_ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in official pretrained swin models to '
        'MMSegmentation style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    weight = convert_swin(state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)


if __name__ == '__main__':
    main()
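Usage sketch (paths are placeholders): python finetune/tools/model_converters/swin2mmseg.py swin_tiny.pth swin_tiny_mmseg.pth. Besides the key renaming, correct_unfold_reduction_order and correct_unfold_norm_order regroup the patch-merging (downsample) weights by reordering their channel groups as [0, 2, 1, 3] for the converted layout.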
finetune/tools/model_converters/twins2mmseg.py (new file, 87 lines)
@@ -0,0 +1,87 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def convert_twins(args, ckpt):

    new_ckpt = OrderedDict()

    for k, v in list(ckpt.items()):
        new_v = v
        if k.startswith('head'):
            continue
        elif k.startswith('patch_embeds'):
            if 'proj.' in k:
                new_k = k.replace('proj.', 'projection.')
            else:
                new_k = k
        elif k.startswith('blocks'):
            # Union
            if 'attn.q.' in k:
                new_k = k.replace('q.', 'attn.in_proj_')
                new_v = torch.cat([v, ckpt[k.replace('attn.q.', 'attn.kv.')]],
                                  dim=0)
            elif 'mlp.fc1' in k:
                new_k = k.replace('mlp.fc1', 'ffn.layers.0.0')
            elif 'mlp.fc2' in k:
                new_k = k.replace('mlp.fc2', 'ffn.layers.1')
            # Only pcpvt
            elif args.model == 'pcpvt':
                if 'attn.proj.' in k:
                    new_k = k.replace('proj.', 'attn.out_proj.')
                else:
                    new_k = k

            # Only svt
            else:
                if 'attn.proj.' in k:
                    k_lst = k.split('.')
                    if int(k_lst[2]) % 2 == 1:
                        new_k = k.replace('proj.', 'attn.out_proj.')
                    else:
                        new_k = k
                else:
                    new_k = k
            new_k = new_k.replace('blocks.', 'layers.')
        elif k.startswith('pos_block'):
            new_k = k.replace('pos_block', 'position_encodings')
            if 'proj.0.' in new_k:
                new_k = new_k.replace('proj.0.', 'proj.')
        else:
            new_k = k
        if 'attn.kv.' not in k:
            new_ckpt[new_k] = new_v
    return new_ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in official pretrained Twins (pcpvt/svt) '
        'models to MMSegmentation style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    parser.add_argument('model', help='model: pcpvt or svt')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')

    if 'state_dict' in checkpoint:
        # timm checkpoint
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    weight = convert_twins(args, state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)


if __name__ == '__main__':
    main()
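Usage sketch (paths are placeholders): python finetune/tools/model_converters/twins2mmseg.py pcpvt_small.pth pcpvt_small_mmseg.pth pcpvt. The third positional argument selects pcpvt or svt; per the branch above, svt renames attn.proj to attn.out_proj only for odd-indexed blocks within each stage.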
finetune/tools/model_converters/vit2mmseg.py (new file, 70 lines)
@@ -0,0 +1,70 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def convert_vit(ckpt):

    new_ckpt = OrderedDict()

    for k, v in ckpt.items():
        if k.startswith('head'):
            continue
        if k.startswith('norm'):
            new_k = k.replace('norm.', 'ln1.')
        elif k.startswith('patch_embed'):
            if 'proj' in k:
                new_k = k.replace('proj', 'projection')
            else:
                new_k = k
        elif k.startswith('blocks'):
            if 'norm' in k:
                new_k = k.replace('norm', 'ln')
            elif 'mlp.fc1' in k:
                new_k = k.replace('mlp.fc1', 'ffn.layers.0.0')
            elif 'mlp.fc2' in k:
                new_k = k.replace('mlp.fc2', 'ffn.layers.1')
            elif 'attn.qkv' in k:
                new_k = k.replace('attn.qkv.', 'attn.attn.in_proj_')
            elif 'attn.proj' in k:
                new_k = k.replace('attn.proj', 'attn.attn.out_proj')
            else:
                new_k = k
            new_k = new_k.replace('blocks.', 'layers.')
        else:
            new_k = k
        new_ckpt[new_k] = v

    return new_ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in timm pretrained vit models to '
        'MMSegmentation style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    if 'state_dict' in checkpoint:
        # timm checkpoint
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        # deit checkpoint
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    weight = convert_vit(state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)


if __name__ == '__main__':
    main()
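Usage sketch (paths are placeholders): python finetune/tools/model_converters/vit2mmseg.py vit_base_p16.pth vit_base_p16_mmseg.pth. For example, blocks.0.attn.qkv.weight becomes layers.0.attn.attn.in_proj_weight, and the final norm.weight becomes ln1.weight.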
finetune/tools/model_converters/vitjax2mmseg.py (new file, 123 lines)
@@ -0,0 +1,123 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp

import mmengine
import numpy as np
import torch


def vit_jax_to_torch(jax_weights, num_layer=12):
    torch_weights = dict()

    # patch embedding
    conv_filters = jax_weights['embedding/kernel']
    conv_filters = conv_filters.permute(3, 2, 0, 1)
    torch_weights['patch_embed.projection.weight'] = conv_filters
    torch_weights['patch_embed.projection.bias'] = jax_weights[
        'embedding/bias']

    # pos embedding
    torch_weights['pos_embed'] = jax_weights[
        'Transformer/posembed_input/pos_embedding']

    # cls token
    torch_weights['cls_token'] = jax_weights['cls']

    # head
    torch_weights['ln1.weight'] = jax_weights['Transformer/encoder_norm/scale']
    torch_weights['ln1.bias'] = jax_weights['Transformer/encoder_norm/bias']

    # transformer blocks
    for i in range(num_layer):
        jax_block = f'Transformer/encoderblock_{i}'
        torch_block = f'layers.{i}'

        # attention norm
        torch_weights[f'{torch_block}.ln1.weight'] = jax_weights[
            f'{jax_block}/LayerNorm_0/scale']
        torch_weights[f'{torch_block}.ln1.bias'] = jax_weights[
            f'{jax_block}/LayerNorm_0/bias']

        # attention
        query_weight = jax_weights[
            f'{jax_block}/MultiHeadDotProductAttention_1/query/kernel']
        query_bias = jax_weights[
            f'{jax_block}/MultiHeadDotProductAttention_1/query/bias']
        key_weight = jax_weights[
            f'{jax_block}/MultiHeadDotProductAttention_1/key/kernel']
        key_bias = jax_weights[
            f'{jax_block}/MultiHeadDotProductAttention_1/key/bias']
        value_weight = jax_weights[
            f'{jax_block}/MultiHeadDotProductAttention_1/value/kernel']
        value_bias = jax_weights[
            f'{jax_block}/MultiHeadDotProductAttention_1/value/bias']

        qkv_weight = torch.from_numpy(
            np.stack((query_weight, key_weight, value_weight), 1))
        qkv_weight = torch.flatten(qkv_weight, start_dim=1)
        qkv_bias = torch.from_numpy(
            np.stack((query_bias, key_bias, value_bias), 0))
        qkv_bias = torch.flatten(qkv_bias, start_dim=0)

        torch_weights[f'{torch_block}.attn.attn.in_proj_weight'] = qkv_weight
        torch_weights[f'{torch_block}.attn.attn.in_proj_bias'] = qkv_bias
        to_out_weight = jax_weights[
            f'{jax_block}/MultiHeadDotProductAttention_1/out/kernel']
        to_out_weight = torch.flatten(to_out_weight, start_dim=0, end_dim=1)
        torch_weights[
            f'{torch_block}.attn.attn.out_proj.weight'] = to_out_weight
        torch_weights[f'{torch_block}.attn.attn.out_proj.bias'] = jax_weights[
            f'{jax_block}/MultiHeadDotProductAttention_1/out/bias']

        # mlp norm
        torch_weights[f'{torch_block}.ln2.weight'] = jax_weights[
            f'{jax_block}/LayerNorm_2/scale']
        torch_weights[f'{torch_block}.ln2.bias'] = jax_weights[
            f'{jax_block}/LayerNorm_2/bias']

        # mlp
        torch_weights[f'{torch_block}.ffn.layers.0.0.weight'] = jax_weights[
            f'{jax_block}/MlpBlock_3/Dense_0/kernel']
        torch_weights[f'{torch_block}.ffn.layers.0.0.bias'] = jax_weights[
            f'{jax_block}/MlpBlock_3/Dense_0/bias']
        torch_weights[f'{torch_block}.ffn.layers.1.weight'] = jax_weights[
            f'{jax_block}/MlpBlock_3/Dense_1/kernel']
        torch_weights[f'{torch_block}.ffn.layers.1.bias'] = jax_weights[
            f'{jax_block}/MlpBlock_3/Dense_1/bias']

    # transpose weights
    for k, v in torch_weights.items():
        if 'weight' in k and 'patch_embed' not in k and 'ln' not in k:
            v = v.permute(1, 0)
        torch_weights[k] = v

    return torch_weights


def main():
    # stole refactoring code from Robin Strudel, thanks
    parser = argparse.ArgumentParser(
        description='Convert keys from jax official pretrained vit models to '
        'MMSegmentation style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    jax_weights = np.load(args.src)
    jax_weights_tensor = {}
    for key in jax_weights.files:
        value = torch.from_numpy(jax_weights[key])
        jax_weights_tensor[key] = value
    if 'L_16-i21k' in args.src:
        num_layer = 24
    else:
        num_layer = 12
    torch_weights = vit_jax_to_torch(jax_weights_tensor, num_layer)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(torch_weights, args.dst)


if __name__ == '__main__':
    main()
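Usage sketch (src is a .npz of JAX ViT weights; paths are placeholders): python finetune/tools/model_converters/vitjax2mmseg.py ViT-B_16.npz vit_b16_mmseg.pth. Per block, the separate query/key/value kernels are stacked and flattened into a single in_proj_weight, and the final loop transposes the remaining weight matrices (excluding patch_embed and layer norms) from JAX's (in, out) layout to PyTorch's (out, in); 24 layers are assumed only when 'L_16-i21k' appears in the src filename.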