init
5
lib/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .datasets import *
|
||||
from .models import *
|
||||
from .predictors import *
|
||||
from .trainer import *
|
||||
from .task import *
|
||||
3
lib/datasets/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .builder import PretrainingBuilder
|
||||
|
||||
__all__ = ["PretrainingBuilder"]
|
||||
18
lib/datasets/builder.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from antmmf.common.registry import registry
|
||||
from antmmf.datasets.base_dataset_builder import BaseDatasetBuilder
|
||||
from .loader.pretraining_loader import PretrainingLoader
|
||||
|
||||
@registry.register_builder("pretraining_loader")
|
||||
class PretrainingBuilder(BaseDatasetBuilder):
|
||||
def __init__(self):
|
||||
super().__init__("pretraining_loader")
|
||||
|
||||
def _build(self, dataset_type, config, *args, **kwargs):
|
||||
return None
|
||||
|
||||
def _load(self, dataset_type, config, *args, **kwargs):
|
||||
self.dataset = PretrainingLoader(dataset_type, config)
|
||||
return self.dataset
|
||||
|
||||
def update_registry_for_model(self, config):
|
||||
pass
|
||||
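PretrainingBuilder is made discoverable through antmmf's registry via the @registry.register_builder decorator. The sketch below uses a minimal stand-in registry, not antmmf's actual API, purely to illustrate the registration pattern: the decorator stores the class under a string key so that a trainer can later construct the builder from a config value.

# Minimal stand-in registry illustrating the decorator-registration pattern.
# This is NOT antmmf's real API; all names here are hypothetical.
class ToyRegistry:
    def __init__(self):
        self._builders = {}

    def register_builder(self, name):
        def wrap(cls):
            self._builders[name] = cls  # store the class under its string key
            return cls                  # return the class unchanged
        return wrap

    def get_builder(self, name):
        return self._builders[name]

registry = ToyRegistry()

@registry.register_builder("pretraining_loader")
class PretrainingBuilderSketch:
    def __init__(self):
        self.dataset_name = "pretraining_loader"

# A trainer can now look the builder up by the key found in its config.
builder_cls = registry.get_builder("pretraining_loader")
print(builder_cls().dataset_name)  # -> pretraining_loader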
0
lib/datasets/loader/__init__.py
Normal file
289
lib/datasets/loader/few_shot_flood3i_loader.py
Normal file
@@ -0,0 +1,289 @@
|
||||
import os
|
||||
import json
|
||||
import datetime
|
||||
import random
|
||||
import itertools
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from antmmf.structures import Sample
|
||||
from antmmf.datasets.base_dataset import BaseDataset
|
||||
from antmmf.common import Configuration
|
||||
|
||||
from lib.datasets.utils.transforms import Compose, MSNormalize
|
||||
from lib.datasets.utils.formatting import ToTensor
|
||||
import lib.datasets.utils.pair_trainsforms as pair_transforms
|
||||
|
||||
from skimage import io
|
||||
from osgeo import gdal
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class FewShotFloodLoader(BaseDataset):
|
||||
DATASET_NAME = "few_shot_flood_loader"
|
||||
|
||||
def __init__(self, dataset_type, config):
|
||||
super().__init__(self.__class__.DATASET_NAME, dataset_type, config)
|
||||
if dataset_type == 'train':
|
||||
raise ValueError('train mode is not supported!')
|
||||
|
||||
self.root = config.data_root_dir
|
||||
self.dataset_type = dataset_type
|
||||
self.img_dir = config.img_dir
|
||||
self.tgt_dir = config.tgt_dir
|
||||
with open(config.data_txt, 'r') as f:
|
||||
test_list = f.readlines()
|
||||
self.test_pairs = []
|
||||
self.cls2path = {}
|
||||
for i in test_list:
|
||||
i = i.strip()
|
||||
if i == '':
|
||||
continue
|
||||
img_path = i[:-3]
|
||||
cls = int(i[-2:])
|
||||
self.test_pairs.append(
|
||||
{'hr_path': img_path,
|
||||
'class': cls,
|
||||
'tgt_path': img_path.replace('_', '_lab_', 1).replace('.jpg', '.png')
|
||||
})
|
||||
if cls in self.cls2path.keys():
|
||||
self.cls2path[cls].append({'hr_path': img_path, 'tgt_path': img_path.replace('_', '_lab_', 1).replace('.jpg', '.png'), 'class': cls})
|
||||
else:
|
||||
self.cls2path[cls] = [{'hr_path': img_path, 'tgt_path': img_path.replace('_', '_lab_', 1).replace('.jpg', '.png'), 'class': cls}]
|
||||
|
||||
self.seq_len = config.seq_len # ts
|
||||
self.hr_size = config.image_size.hr
|
||||
self.s2_size = config.image_size.s2
|
||||
self.s1_size = config.image_size.s1
|
||||
self.anno_size = config.image_size.anno
|
||||
self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406])
|
||||
self.imagenet_std = torch.tensor([0.229, 0.224, 0.225])  # keep the ImageNet stats for now
|
||||
self.config = config
|
||||
self.pipeline = self._get_pipeline()
|
||||
# self.crop_resize = pair_transforms.RandomResizedCropComb(512, scale=(0.99, 1.0), interpolation=3)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.test_pairs)
|
||||
|
||||
def _combine_two_images(self, image, image2):
|
||||
dst = torch.cat([image, image2], dim=-2)
|
||||
return dst
|
||||
|
||||
def _get_pipeline(self):
|
||||
if self.dataset_type == 'val' or self.dataset_type == 'test':
|
||||
pipeline = [
|
||||
pair_transforms.ToTensor(),
|
||||
pair_transforms.RandomResizedCrop(512, scale=(0.9999, 1.0), interpolation=3),
|
||||
pair_transforms.Normalize(),
|
||||
]
|
||||
else:
|
||||
raise ValueError('dataset_type not supported')
|
||||
return pair_transforms.Compose(pipeline)
|
||||
|
||||
def _load_data(self, data_path):
|
||||
file_name, file_extension = os.path.splitext(data_path)
|
||||
if file_extension == '.npz' or file_extension == '.npy':
|
||||
npz_key = self.config.get('npz_key', 'image')
|
||||
data = np.load(data_path)[npz_key]
|
||||
elif file_extension == '.png' or file_extension == '.jpg':
|
||||
data = io.imread(data_path)
|
||||
if len(data.shape) == 3:
|
||||
data = data.transpose(2, 0, 1)
|
||||
elif file_extension == '.tiff' or file_extension == '.tif':
|
||||
dataset = gdal.Open(data_path)
|
||||
if dataset is None:
|
||||
raise IOError(f'cannot open file: {data_path}')
|
||||
data = dataset.ReadAsArray()
|
||||
dataset = None
|
||||
else:
|
||||
raise ValueError(f'file type of {data_path} is not supported')
|
||||
# check nan
|
||||
if np.isnan(data).any():
|
||||
print(f'{data_path} contains NaN, replacing with 0!')
|
||||
data[np.isnan(data)] = 0
|
||||
return data
|
||||
|
||||
def load_s2(self, pair):
|
||||
if 'l8_path' in pair.keys():
|
||||
pair['s2_path'] = pair['l8_path']
|
||||
|
||||
if 's2_path' in pair.keys() and not self.config.get('masking_s2', False):
|
||||
with_s2 = True
|
||||
if isinstance(pair['s2_path'], list):
|
||||
if True: # len(pair['s2_path']) > self.seq_len:
|
||||
s2_path_list = np.random.choice(pair['s2_path'], self.seq_len)
|
||||
s2_path_list = sorted(s2_path_list)
|
||||
else:
|
||||
s2_path_list = pair['s2_path']
|
||||
s2_list = []
|
||||
s2_ct_1 = []
|
||||
for s2_path in s2_path_list:
|
||||
s2 = self._load_data(os.path.join(self.root, s2_path)) # [:10]
|
||||
s2_list.append(s2)
|
||||
ct = os.path.splitext(s2_path)[0].split('_')
|
||||
ct = ct[3] # + ct[-3] + '01'
|
||||
try:
|
||||
ct = datetime.datetime.strptime(ct, '%Y%m%d')
|
||||
except:
|
||||
ct = datetime.datetime.strptime(ct, '%Y-%m-%d')
|
||||
ct = ct.timetuple()
|
||||
ct = ct.tm_yday - 1
|
||||
s2_ct_1.append(ct)
|
||||
s2_1 = np.stack(s2_list, axis=1)
|
||||
|
||||
else:
|
||||
|
||||
s2 = np.load(os.path.join(self.root, pair['s2_path']))['image']
|
||||
date = np.load(os.path.join(self.root, pair['s2_path']))['date']
|
||||
if True: # s2.shape[0] > self.seq_len:
|
||||
selected_indices = np.random.choice(s2.shape[0], size=self.seq_len, replace=False)
|
||||
selected_indices = sorted(selected_indices)
|
||||
s2 = s2[selected_indices, :, :, :]
|
||||
date = date[selected_indices]
|
||||
s2_1 = s2.transpose(1, 0, 2, 3) # ts, c, h, w -> c, ts, h, w
|
||||
s2_ct_1 = []
|
||||
for ct in date:
|
||||
try:
|
||||
ct = datetime.datetime.strptime(ct, '%Y%m%d')
|
||||
except:
|
||||
ct = datetime.datetime.strptime(ct, '%Y-%m-%d')
|
||||
ct = ct.timetuple()
|
||||
ct = ct.tm_yday - 1
|
||||
s2_ct_1.append(ct)
|
||||
|
||||
else:
|
||||
with_s2 = False
|
||||
s2_1 = np.zeros((10, self.seq_len, self.s2_size[0], self.s2_size[1]),
|
||||
dtype=np.int16)
|
||||
s2_ct_1 = [0] * self.seq_len
|
||||
|
||||
return with_s2, s2_1, s2_ct_1
|
||||
|
||||
def load_s1(self, pair):
|
||||
if 's1_path' in pair.keys():
|
||||
with_s1 = True
|
||||
if isinstance(pair['s1_path'], list):
|
||||
if True: # len(pair['s1_path']) > self.seq_len:
|
||||
s1_path_list = np.random.choice(pair['s1_path'], self.seq_len)
|
||||
s1_path_list = sorted(s1_path_list)
|
||||
else:
|
||||
s1_path_list = pair['s1_path']
|
||||
s1_list = []
|
||||
for s1_path in s1_path_list:
|
||||
s1 = self._load_data(os.path.join(self.root, s1_path))
|
||||
s1_list.append(s1)
|
||||
s1_1 = np.stack(s1_list, axis=1)
|
||||
else:
|
||||
s1 = self._load_data(os.path.join(self.root, pair['s1_path']))
|
||||
if True: # s1.shape[0] > self.seq_len:
|
||||
selected_indices = np.random.choice(s1.shape[0], size=self.seq_len, replace=False)
|
||||
selected_indices = sorted(selected_indices)
|
||||
s1 = s1[selected_indices, :, :, :]
|
||||
s1_1 = s1.transpose(1, 0, 2, 3) # ts, c, h, w -> c, ts, h, w
|
||||
else:
|
||||
with_s1 = False
|
||||
s1_1 = np.zeros((2, self.seq_len, self.s1_size[0], self.s1_size[1]),
|
||||
dtype=np.float32)
|
||||
return with_s1, s1_1
|
||||
|
||||
def load_hr(self, pair):
|
||||
if 'hr_path' in pair.keys():
|
||||
with_hr = True
|
||||
hr = self._load_data(os.path.join(self.root, pair['hr_path']))
|
||||
else:
|
||||
with_hr = False
|
||||
hr = np.zeros((3, self.hr_size[0], self.hr_size[1]),
|
||||
dtype=np.uint8)
|
||||
return with_hr, hr
|
||||
|
||||
def load_tgt(self, pair):
|
||||
targets = self._load_data(os.path.join(self.root, pair['target_path']))
|
||||
return targets
|
||||
|
||||
def get_item(self, idx):
|
||||
pair = self.test_pairs[idx]
|
||||
test_class = pair['class']
|
||||
|
||||
current_dataset = 'flood3i'
|
||||
with_hr = True
|
||||
with_s2 = False
|
||||
with_s1 = False
|
||||
|
||||
input_hr = io.imread(os.path.join(self.img_dir, pair['hr_path']))
|
||||
input_hr = input_hr.transpose(2,0,1)
|
||||
_, input_s2,_ = self.load_s2(pair)
|
||||
_, input_s1 = self.load_s1(pair)
|
||||
input_tgt = io.imread(os.path.join(self.tgt_dir, pair['tgt_path']))
|
||||
modality_dict = {
|
||||
's2': with_s2,
|
||||
's1': with_s1,
|
||||
'hr': with_hr
|
||||
}
|
||||
|
||||
|
||||
input_tgt[input_tgt != test_class] = 0
|
||||
input_tgt[input_tgt == test_class] = 255
|
||||
input_tgt = np.concatenate((input_tgt[None, :,:],)*3, axis=0)
|
||||
input_hr, input_s2, input_s1, input_tgt = self.pipeline(current_dataset, input_hr, input_s2, input_s1,
|
||||
input_tgt)
|
||||
|
||||
while True:
|
||||
sel_prompt = random.choice(self.cls2path[test_class])
|
||||
if sel_prompt['hr_path'] != pair['hr_path']:
|
||||
break
|
||||
prompt_hr = io.imread(os.path.join(self.img_dir, sel_prompt['hr_path']))
|
||||
prompt_hr = prompt_hr.transpose(2,0,1)
|
||||
_, prompt_s2,_ = self.load_s2(pair)
|
||||
_, prompt_s1 = self.load_s1(pair)
|
||||
prompt_tgt = io.imread(os.path.join(self.tgt_dir, sel_prompt['tgt_path']))
|
||||
|
||||
prompt_tgt[prompt_tgt != test_class] = 0
|
||||
prompt_tgt[prompt_tgt == test_class] = 255
|
||||
prompt_tgt = np.concatenate((prompt_tgt[None, :,:],)*3, axis=0)
|
||||
|
||||
prompt_hr, prompt_s2, prompt_s1, prompt_tgt = self.pipeline(current_dataset, prompt_hr, prompt_s2, prompt_s1, prompt_tgt)
|
||||
|
||||
targets_comb = self._combine_two_images(prompt_tgt, input_tgt)
|
||||
hr_comb = self._combine_two_images(prompt_hr, input_hr)
|
||||
s2_comb = self._combine_two_images(prompt_s2, input_s2)
|
||||
s1_comb = self._combine_two_images(prompt_s1, input_s1)
|
||||
|
||||
valid = torch.ones_like(targets_comb)
|
||||
thres = torch.ones(3) * 1e-5 # ignore black
|
||||
thres = (thres - self.imagenet_mean) / self.imagenet_std
|
||||
valid[targets_comb < thres[:, None, None]] = 0
|
||||
|
||||
mask_shape = (int(self.config.mim.input_size[0] / self.config.mim.patch_size),
|
||||
int(self.config.mim.input_size[1] / self.config.mim.patch_size))
|
||||
mask = np.zeros(mask_shape, dtype=np.int32)
|
||||
mask[mask.shape[0] // 2:, :] = 1
|
||||
|
||||
geo_location = pair["location"] if "location" in pair.keys() else None
|
||||
|
||||
modality_idx = 2 ** 0 * modality_dict['s2'] + 2 ** 1 * modality_dict['s1'] + 2 ** 2 * modality_dict['hr']
|
||||
modality_flag_s2 = modality_dict['s2']
|
||||
modality_flag_s1 = modality_dict['s1']
|
||||
modality_flag_hr = modality_dict['hr']
|
||||
|
||||
current_sample = Sample()
|
||||
current_sample.img_name = pair["tgt_path"].split('/')[-1].split('.')[0] + '-' +str(test_class)
|
||||
current_sample.hr_img = hr_comb
|
||||
current_sample.dataset_name = 'flood3i'
|
||||
current_sample.targets = targets_comb
|
||||
current_sample.s2_img = s2_comb
|
||||
current_sample.s2_ct = -1
|
||||
current_sample.s2_ct2 = -1
|
||||
current_sample.s1_img = s1_comb
|
||||
current_sample.anno_mask = torch.from_numpy(mask)
|
||||
current_sample.valid = valid
|
||||
current_sample.location = geo_location
|
||||
current_sample.modality_idx = modality_idx
|
||||
current_sample.modality_flag_s2 = modality_flag_s2
|
||||
current_sample.modality_flag_s1 = modality_flag_s1
|
||||
current_sample.modality_flag_hr = modality_flag_hr
|
||||
current_sample.task_type = self.dataset_type
|
||||
return current_sample
|
||||
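Each line of config.data_txt is parsed by slicing: everything except the last three characters is the image path, the last two characters are the class id, and the label path is derived by inserting "lab" after the first underscore and switching the extension. A small self-contained illustration of that parsing; the sample line below is made up:

# Illustration of the data_txt line parsing used above; the file name is hypothetical.
line = "area_0487.jpg 05"

img_path = line[:-3]        # "area_0487.jpg"  (drop the trailing " 05")
cls = int(line[-2:])        # 5                (two-digit class id)
tgt_path = img_path.replace('_', '_lab_', 1).replace('.jpg', '.png')

print(img_path)  # area_0487.jpg
print(cls)       # 5
print(tgt_path)  # area_lab_0487.png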
494
lib/datasets/loader/pretraining_loader.py
Normal file
@@ -0,0 +1,494 @@
|
||||
import os
|
||||
import json
|
||||
import datetime
|
||||
import random
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from osgeo import gdal
|
||||
from skimage import io
|
||||
from skimage.transform import resize
|
||||
|
||||
from antmmf.structures import Sample
|
||||
from antmmf.datasets.base_dataset import BaseDataset
|
||||
|
||||
import lib.datasets.utils.pair_trainsforms as pair_transforms
|
||||
from lib.datasets.utils.masking_generator import MaskingGenerator
|
||||
from lib.datasets.utils.dataset_colors import dataset_color_dict, get_painter_color_map_list, get_real_random_color_list
|
||||
|
||||
|
||||
class PretrainingLoader(BaseDataset):
|
||||
DATASET_NAME = "pretraining_loader"
|
||||
|
||||
def __init__(self, dataset_type, config):
|
||||
super().__init__(self.__class__.DATASET_NAME, dataset_type, config)
|
||||
self.root = config.data_root_dir
|
||||
if dataset_type == 'train':
|
||||
self.json_path_list = config.train_json_path_list
|
||||
if dataset_type == 'val':
|
||||
self.json_path_list = config.val_json_path_list
|
||||
if dataset_type == 'test':
|
||||
self.json_path_list = config.val_json_path_list
|
||||
self.dataset_type = dataset_type
|
||||
self.pairs = []
|
||||
self.cls_repeat_cnt = config.cls_repeat_cnt
|
||||
num_datasets = len(self.json_path_list)
|
||||
for idx, json_path in enumerate(self.json_path_list):
|
||||
print(os.path.join(config.data_root_dir, json_path))
|
||||
cur_pairs = json.load(open(os.path.join(config.data_root_dir, json_path)))
|
||||
self.pairs.extend(cur_pairs)
|
||||
cur_num = len(cur_pairs)
|
||||
|
||||
if dataset_type == 'test' and config.prompt_json:
|
||||
cur_pairs = json.load(open(config.prompt_json))
|
||||
self.prompt = cur_pairs[0]
|
||||
print(f'prompt:{self.prompt}')
|
||||
|
||||
self.use_multi_pairs = config.use_multi_pairs
|
||||
|
||||
if self.use_multi_pairs:
|
||||
self.pair_type_dict = {}
|
||||
if dataset_type == 'train' or dataset_type == 'val':
|
||||
for idx, pair in enumerate(self.pairs):
|
||||
if pair["type"] not in self.pair_type_dict:
|
||||
new_subset = {}
|
||||
classes = pair["classes"]
|
||||
for cls in classes:
|
||||
if cls not in new_subset.keys():
|
||||
new_subset[cls] = [idx]
|
||||
else:
|
||||
new_subset[cls].append(idx)
|
||||
self.pair_type_dict[pair["type"]] = new_subset
|
||||
else:
|
||||
classes = pair["classes"]
|
||||
for cls in classes:
|
||||
if cls not in self.pair_type_dict[pair["type"]].keys():
|
||||
self.pair_type_dict[pair["type"]][cls] = [idx]
|
||||
else:
|
||||
self.pair_type_dict[pair["type"]][cls].append(idx)
|
||||
|
||||
cnt = 0
|
||||
self.idx_to_cls = {}
|
||||
for k, v in self.pair_type_dict.items():
|
||||
for vv in v:
|
||||
self.idx_to_cls[cnt] = {
|
||||
'type': k,
|
||||
'classes_id': vv
|
||||
}
|
||||
cnt = cnt + 1
|
||||
|
||||
print(self.idx_to_cls)
|
||||
self.idx_to_cls_list = []
|
||||
for i in self.idx_to_cls.keys():
|
||||
self.idx_to_cls_list.append(self.idx_to_cls[i])
|
||||
print(self.idx_to_cls_list)
|
||||
if self.dataset_type == 'train':
|
||||
self.idx_to_cls_list = self.idx_to_cls_list * self.cls_repeat_cnt
|
||||
self.masked_position_generator = MaskingGenerator(
|
||||
input_size=config.mim.input_size,
|
||||
patch_size=config.mim.patch_size,
|
||||
mask_ratio=config.mim.mask_ratio
|
||||
)
|
||||
if dataset_type == 'train':
|
||||
self.half_mask_ratio = config.half_mask_ratio
|
||||
else:
|
||||
self.half_mask_ratio = 1.
|
||||
|
||||
self.seq_len = config.seq_len # ts
|
||||
self.hr_size = config.image_size.hr
|
||||
self.s2_size = config.image_size.s2
|
||||
self.s1_size = config.image_size.s1
|
||||
self.anno_size = config.image_size.anno
|
||||
self.min_random_scale = config.min_random_scale
|
||||
self.imagenet_mean=torch.tensor([0.485, 0.456, 0.406])
|
||||
self.imagenet_std=torch.tensor([0.229, 0.224, 0.225])
|
||||
|
||||
self.pipeline = self._get_pipeline()
|
||||
self.crop_resize = pair_transforms.RandomResizedCropComb(512, scale=(0.3, 1.0), interpolation=3)
|
||||
self.num_samples = 8
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.idx_to_cls_list)
|
||||
|
||||
def _convert_colors_pairs(self, images, original_colors, new_colors, current_color):
|
||||
if len(original_colors) != len(new_colors):
|
||||
raise ValueError("The length of original_colors and new_colors must be the same.")
|
||||
unique_colors_list = []
|
||||
for image in images:
|
||||
if len(image.shape) == 3:
|
||||
image_hwc = image.transpose(1,2,0) # chw -> hwc
|
||||
elif len(image.shape) == 2:
|
||||
image_hwc = image[:,:,None]
|
||||
else:
|
||||
raise ValueError(f'image shape {image_hwc.shape} is not supported for color conversion!')
|
||||
|
||||
image_2d = image_hwc.reshape(-1, image_hwc.shape[-1])
|
||||
unique_colors = np.unique(image_2d, axis=0)
|
||||
unique_colors_list.append(unique_colors)
|
||||
unique_colors_list.append(original_colors)
|
||||
|
||||
sets_of_tuples = [set(map(tuple, a)) for a in unique_colors_list]
|
||||
common_tuples = set.intersection(*sets_of_tuples)
|
||||
unique_old_colors = np.array(list(common_tuples), dtype=np.uint8)
|
||||
if len(unique_old_colors) == 0:
|
||||
unique_old_colors = [current_color]
|
||||
new_colors_coverted = new_colors[:len(unique_old_colors)]
|
||||
images_converted_list = []
|
||||
|
||||
for image in images:
|
||||
image_convered = self._convert_colors(image, unique_old_colors, new_colors_coverted)
|
||||
images_converted_list.append(image_convered)
|
||||
|
||||
return images_converted_list
|
||||
|
||||
def _convert_colors(self, image, original_colors, new_colors):
|
||||
"""
|
||||
Remap colors in an image to new colors.
|
||||
|
||||
Parameters:
|
||||
image (numpy.ndarray): The image as a numpy array (channel x height x width).
|
||||
original_colors (list of tuples): The list of original colors to be replaced.
|
||||
new_colors (list of tuples): The list of new colors to replace the original colors.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: The image with remapped colors. (channel x height x width)
|
||||
"""
|
||||
|
||||
if len(original_colors) != len(new_colors):
|
||||
raise ValueError("The length of original_colors and new_colors must be the same.")
|
||||
|
||||
# Convert lists of tuples to numpy arrays for faster processing
|
||||
original_colors = np.array(original_colors)
|
||||
new_colors = np.array(new_colors)
|
||||
if len(original_colors.shape) == 1:
|
||||
original_colors = original_colors[:,None]
|
||||
|
||||
# check image shape
|
||||
if len(image.shape) == 3:
|
||||
remapped_image = image.transpose(1,2,0) # chw -> hwc
|
||||
elif len(image.shape) == 2:
|
||||
remapped_image = image[:,:,None]
|
||||
else:
|
||||
raise ValueError(f'image shape {image.shape} is not supported for color conversion!')
|
||||
|
||||
# generate new image for return
|
||||
new_image = np.zeros((remapped_image.shape[0], remapped_image.shape[1], 3), dtype=np.uint8)
|
||||
|
||||
for orig_color, new_color in zip(original_colors, new_colors):
|
||||
mask = np.all(remapped_image == orig_color, axis=-1)
|
||||
new_image[mask] = new_color
|
||||
|
||||
new_image = new_image.transpose(2,0,1) # hwc -> chw
|
||||
return new_image
|
||||
|
||||
def _combine_images(self, images, interpolation='bicubic'):
|
||||
# images 8, c, h, w -> c, 4h, 2w
|
||||
group1 = images[:4]
|
||||
group2 = images[4:]
|
||||
stacked1 = torch.cat(group1, dim=-2)
|
||||
stacked2 = torch.cat(group2, dim=-2)
|
||||
result = torch.cat((stacked1, stacked2), dim=-1)
|
||||
|
||||
return result
|
||||
|
||||
def _get_pipeline(self):
|
||||
if self.dataset_type == 'train':
|
||||
pipeline = [
|
||||
pair_transforms.ToTensor(),
|
||||
pair_transforms.RandomResizedCrop(512, scale=(0.8, 1.0), interpolation=3), # 3 is bicubic
|
||||
pair_transforms.RandomHorizontalFlip(),
|
||||
pair_transforms.Normalize(),
|
||||
]
|
||||
elif self.dataset_type == 'val' or self.dataset_type == 'test':
|
||||
pipeline = [
|
||||
pair_transforms.ToTensor(),
|
||||
pair_transforms.RandomResizedCrop(512, scale=(0.9999, 1.0), interpolation=3), # 3 is bicubic
|
||||
pair_transforms.Normalize(),
|
||||
]
|
||||
else:
|
||||
raise ValueError('dataset_type not supported')
|
||||
return pair_transforms.Compose(pipeline)
|
||||
|
||||
def _load_data(self, data_path):
|
||||
file_name, file_extension = os.path.splitext(data_path)
|
||||
if file_extension == '.npz' or file_extension == '.npy':
|
||||
data = np.load(data_path)['image']
|
||||
elif file_extension == '.png' or file_extension == '.jpg':
|
||||
data = io.imread(data_path)
|
||||
if len(data.shape) == 3:
|
||||
data = data.transpose(2,0,1)
|
||||
elif file_extension == '.tiff' or file_extension == '.tif':
|
||||
dataset = gdal.Open(data_path)
|
||||
if dataset is None:
|
||||
raise IOError(f'cannot open file: {data_path}')
|
||||
data = dataset.ReadAsArray()
|
||||
dataset = None
|
||||
else:
|
||||
raise ValueError(f'file type of {data_path} is not supported')
|
||||
if np.isnan(data).any():
|
||||
print(f'{data_path} contains NaN, replacing with 0!')
|
||||
data[np.isnan(data)] = 0
|
||||
return data
|
||||
|
||||
def load_s2(self, pair):
|
||||
if pair['type'] == 'flair-mm' and 's2_path' in pair.keys():
|
||||
with_s2 = True
|
||||
s2 = np.load(os.path.join(self.root, pair['s2_path']))
|
||||
idx_centroid = pair['s2_cut_points']
|
||||
s2_patch_size = 40
|
||||
subset_sp = s2[:,:,idx_centroid[0]-int(s2_patch_size/2):idx_centroid[0] + \
|
||||
int(s2_patch_size/2),idx_centroid[1] - int(s2_patch_size/2):idx_centroid[1] + \
|
||||
int(s2_patch_size/2)]
|
||||
ts, c, h, w = subset_sp.shape
|
||||
subset_sp = subset_sp.reshape(-1, h, w).transpose(1,2,0)
|
||||
s2 = resize(subset_sp, (16, 16), anti_aliasing=True).transpose(2,0,1)
|
||||
s2 = s2.reshape(ts, c, 16, 16)
|
||||
if True:
|
||||
selected_indices = np.random.choice(s2.shape[0], size=self.seq_len, replace=False)
|
||||
selected_indices = sorted(selected_indices)
|
||||
s2 = s2[selected_indices, :, :, :]
|
||||
|
||||
s2_1 = s2.transpose(1,0,2,3) # ts, c, h, w -> c, ts, h, w
|
||||
s2_ct_1 = [0] * self.seq_len
|
||||
|
||||
elif 's2_path' in pair.keys():
|
||||
with_s2 = True
|
||||
if isinstance(pair['s2_path'], list):
|
||||
if True:
|
||||
s2_path_list = np.random.choice(pair['s2_path'], self.seq_len)
|
||||
s2_path_list = sorted(s2_path_list)
|
||||
else:
|
||||
s2_path_list = pair['s2_path']
|
||||
s2_list = []
|
||||
s2_ct_1 = []
|
||||
for s2_path in s2_path_list:
|
||||
s2 = self._load_data(os.path.join(self.root, s2_path))#[:10]
|
||||
s2_list.append(s2)
|
||||
ct = os.path.splitext(s2_path)[0].split('_')
|
||||
ct = ct[-4] + ct[-3] + '01'
|
||||
try:
|
||||
ct = datetime.datetime.strptime(ct, '%Y%m%d')
|
||||
except:
|
||||
ct = datetime.datetime.strptime(ct, '%Y-%m-%d')
|
||||
ct = ct.timetuple()
|
||||
ct = ct.tm_yday - 1
|
||||
s2_ct_1.append(ct)
|
||||
s2_1 = np.stack(s2_list, axis=1)
|
||||
|
||||
else:
|
||||
s2 = np.load(os.path.join(self.root, pair['s2_path']))['image']
|
||||
date = np.load(os.path.join(self.root, pair['s2_path']))['date']
|
||||
if True:
|
||||
selected_indices = np.random.choice(s2.shape[0], size=self.seq_len, replace=False)
|
||||
selected_indices = sorted(selected_indices)
|
||||
s2 = s2[selected_indices, :, :, :]
|
||||
date = date[selected_indices]
|
||||
s2_1 = s2.transpose(1,0,2,3) # ts, c, h, w -> c, ts, h, w
|
||||
s2_ct_1 = []
|
||||
for ct in date:
|
||||
try:
|
||||
ct = datetime.datetime.strptime(ct, '%Y%m%d')
|
||||
except:
|
||||
ct = datetime.datetime.strptime(ct, '%Y-%m-%d')
|
||||
ct = ct.timetuple()
|
||||
ct = ct.tm_yday - 1
|
||||
s2_ct_1.append(ct)
|
||||
else:
|
||||
with_s2 = False
|
||||
s2_1 = np.zeros((10, self.seq_len, self.s2_size[0], self.s2_size[1]),
|
||||
dtype=np.int16)
|
||||
s2_ct_1 = [0] * self.seq_len
|
||||
|
||||
return with_s2, s2_1, s2_ct_1
|
||||
|
||||
def load_s1(self, pair):
|
||||
if 's1_path' in pair.keys():
|
||||
with_s1 = True
|
||||
if isinstance(pair['s1_path'], list):
|
||||
if True:
|
||||
s1_path_list = np.random.choice(pair['s1_path'], self.seq_len)
|
||||
s1_path_list = sorted(s1_path_list)
|
||||
else:
|
||||
s1_path_list = pair['s1_path']
|
||||
s1_list = []
|
||||
for s1_path in s1_path_list:
|
||||
s1 = self._load_data(os.path.join(self.root, s1_path))
|
||||
s1_list.append(s1)
|
||||
s1_1 = np.stack(s1_list, axis=1)
|
||||
else:
|
||||
s1 = self._load_data(os.path.join(self.root, pair['s1_path']))
|
||||
if True:
|
||||
selected_indices = np.random.choice(s1.shape[0], size=self.seq_len, replace=False)
|
||||
selected_indices = sorted(selected_indices)
|
||||
s1 = s1[selected_indices, :, :, :]
|
||||
s1_1 = s1.transpose(1,0,2,3) # ts, c, h, w -> c, ts, h, w
|
||||
else:
|
||||
with_s1 = False
|
||||
s1_1 = np.zeros((2, self.seq_len, self.s1_size[0], self.s1_size[1]),
|
||||
dtype=np.float32)
|
||||
return with_s1, s1_1
|
||||
|
||||
def load_hr(self, pair):
|
||||
if 'hr_path' in pair.keys():
|
||||
if pair['type'] == 'flair-mm':
|
||||
with_hr = True
|
||||
hr = self._load_data(os.path.join(self.root, pair['hr_path']))[:3,:,:]
|
||||
else:
|
||||
with_hr = True
|
||||
hr = self._load_data(os.path.join(self.root, pair['hr_path']))
|
||||
else:
|
||||
with_hr = False
|
||||
hr = np.zeros((3, self.hr_size[0], self.hr_size[1]),
|
||||
dtype=np.uint8)
|
||||
return with_hr, hr
|
||||
|
||||
def load_tgt(self, pair):
|
||||
if self.dataset_type == 'test':
|
||||
targets = np.zeros((3, self.anno_size[0], self.anno_size[1]),
|
||||
dtype=np.uint8)
|
||||
else:
|
||||
targets = self._load_data(os.path.join(self.root, pair['target_path']))
|
||||
return targets
|
||||
|
||||
def find_random_position(self, matrix, current_color):
|
||||
if matrix.ndim == 2:
|
||||
matrix = matrix[None, :, :]
|
||||
current_color = np.array(current_color)
|
||||
C, H, W = matrix.shape
|
||||
|
||||
if len(current_color) != C:
|
||||
raise ValueError("current_color unmatch with matrix!")
|
||||
|
||||
matches = np.where(np.all(matrix == current_color[:, None, None], axis=0))
|
||||
|
||||
if len(matches[0]) > 0:
|
||||
index = np.random.choice(range(len(matches[0])))
|
||||
return (matches[0][index], matches[1][index])
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_item(self, idx):
|
||||
dataset_cls_infos = self.idx_to_cls_list[idx]
|
||||
current_dataset = dataset_cls_infos['type']
|
||||
current_classes_id = dataset_cls_infos['classes_id']
|
||||
pair_idx_list = self.pair_type_dict[current_dataset][current_classes_id]
|
||||
|
||||
old_colors = dataset_color_dict[current_dataset]
|
||||
current_color = old_colors[current_classes_id]
|
||||
class_num = len(old_colors)
|
||||
if self.dataset_type == 'train':
|
||||
new_colors = get_real_random_color_list(class_num)
|
||||
else:
|
||||
new_colors = get_painter_color_map_list(class_num) # fix colors mapping when testing
|
||||
|
||||
num_samples = self.num_samples
|
||||
if len(pair_idx_list) < num_samples:
|
||||
selected_samples = [random.choice(pair_idx_list) for _ in range(num_samples)]
|
||||
else:
|
||||
selected_samples = random.sample(pair_idx_list, num_samples)
|
||||
hr_imgs = []
|
||||
tgt_imgs = []
|
||||
s2_imgs = []
|
||||
s1_imgs = []
|
||||
s2_cts = []
|
||||
for sample_idx in selected_samples:
|
||||
pair = self.pairs[sample_idx]
|
||||
with_hr, hr = self.load_hr(pair)
|
||||
with_s2, s2, s2_ct_1 = self.load_s2(pair)
|
||||
with_s1, s1 = self.load_s1(pair)
|
||||
tgt = self.load_tgt(pair)
|
||||
modality_dict = {
|
||||
's2' : with_s2,
|
||||
's1' : with_s1,
|
||||
'hr' : with_hr
|
||||
}
|
||||
|
||||
if (hr.shape[-2:] != tuple(self.hr_size)) and (hr.shape[-2:] == tgt.shape[-2:]) and (self.hr_size == self.anno_size):
|
||||
point_pos = self.find_random_position(tgt, current_color)
|
||||
upper_left_raw = [point_pos[0] - self.hr_size[0] // 2, point_pos[1] - self.hr_size[1] // 2]
|
||||
upper_left = [i - i%32 + 16 for i in upper_left_raw]
|
||||
upper_left_sentinel = [i // 32 for i in upper_left_raw]
|
||||
upper_left[0] = np.clip(np.array(upper_left[0]), 0, hr.shape[-2] - self.hr_size[0])
|
||||
upper_left[1] = np.clip(np.array(upper_left[1]), 0, hr.shape[-1] - self.hr_size[1])
|
||||
|
||||
upper_left_sentinel[0] = np.clip(np.array(upper_left_sentinel[0]), 0, s1.shape[-2] - self.s1_size[0])
|
||||
upper_left_sentinel[1] = np.clip(np.array(upper_left_sentinel[1]), 0, s1.shape[-1] - self.s1_size[1])
|
||||
hr = hr[:, upper_left[0]:upper_left[0]+self.hr_size[0], upper_left[1]:upper_left[1]+self.hr_size[1]]
|
||||
if with_s1:
|
||||
s1 = s1[:, :, upper_left_sentinel[0]:upper_left_sentinel[0]+self.s1_size[0], upper_left_sentinel[1]:upper_left_sentinel[1]+self.s1_size[1]]
|
||||
if with_s2:
|
||||
s2 = s2[:, :, upper_left_sentinel[0]:upper_left_sentinel[0]+self.s2_size[0], upper_left_sentinel[1]:upper_left_sentinel[1]+self.s2_size[1]]
|
||||
if tgt.ndim == 3:
|
||||
tgt = tgt[:, upper_left[0]:upper_left[0]+self.hr_size[0], upper_left[1]:upper_left[1]+self.hr_size[1]]
|
||||
elif tgt.ndim == 2:
|
||||
tgt = tgt[upper_left[0]:upper_left[0]+self.hr_size[0], upper_left[1]:upper_left[1]+self.hr_size[1]]
|
||||
else:
|
||||
raise ValueError("tgt dim unsupport!")
|
||||
hr_imgs.append(hr)
|
||||
tgt_imgs.append(tgt)
|
||||
s2_imgs.append(s2)
|
||||
s1_imgs.append(s1)
|
||||
s2_cts.append(s2_ct_1)
|
||||
|
||||
|
||||
cvt_hr_imgs = []
|
||||
cvt_tgt_imgs = []
|
||||
cvt_s2_imgs = []
|
||||
cvt_s1_imgs = []
|
||||
|
||||
tgt_imgs = self._convert_colors_pairs(tgt_imgs, old_colors, new_colors, current_color)
|
||||
for i in range(len(tgt_imgs)):
|
||||
hr, s2, s1, tgt = self.pipeline(current_dataset, hr_imgs[i], s2_imgs[i], s1_imgs[i], tgt_imgs[i])
|
||||
cvt_hr_imgs.append(hr)
|
||||
cvt_s2_imgs.append(s2)
|
||||
cvt_s1_imgs.append(s1)
|
||||
cvt_tgt_imgs.append(tgt)
|
||||
|
||||
targets_comb = self._combine_images(cvt_tgt_imgs)
|
||||
hr_comb = self._combine_images(cvt_hr_imgs)
|
||||
s2_comb = self._combine_images(cvt_s2_imgs)
|
||||
s1_comb = self._combine_images(cvt_s1_imgs)
|
||||
hr_comb, s2_comb, s1_comb, targets_comb = self.crop_resize(current_dataset, hr_comb, s2_comb, s1_comb, targets_comb)
|
||||
use_half_mask = torch.rand(1)[0] < self.half_mask_ratio
|
||||
valid = torch.ones_like(targets_comb)
|
||||
|
||||
thres = torch.ones(3) * (1e-5) # ignore black
|
||||
thres = (thres - self.imagenet_mean) / self.imagenet_std
|
||||
valid[targets_comb < thres[:, None, None]] = 0
|
||||
|
||||
if use_half_mask:
|
||||
num_patches = self.masked_position_generator.num_patches
|
||||
mask = np.zeros(self.masked_position_generator.get_shape(), dtype=np.int32)
|
||||
mask[mask.shape[0]//2:, :] = 1
|
||||
else:
|
||||
mask = self.masked_position_generator()
|
||||
|
||||
# location
|
||||
geo_location = pair["location"] if "location" in pair.keys() else None
|
||||
|
||||
# get modality index
|
||||
modality_idx = 2**0 * modality_dict['s2'] + 2**1 * modality_dict['s1'] + 2**2 * modality_dict['hr']
|
||||
modality_flag_s2 = modality_dict['s2']
|
||||
modality_flag_s1 = modality_dict['s1']
|
||||
modality_flag_hr = modality_dict['hr']
|
||||
|
||||
|
||||
current_sample = Sample()
|
||||
current_sample.img_name = pair["hr_path"].split('/')[-1].split('.')[0]
|
||||
current_sample.hr_img = hr_comb
|
||||
current_sample.dataset_name = pair["type"]
|
||||
current_sample.targets = targets_comb
|
||||
current_sample.s2_img = s2_comb
|
||||
current_sample.s2_ct = s2_cts[0]
|
||||
current_sample.s2_ct2 = s2_cts[4]
|
||||
current_sample.s1_img = s1_comb
|
||||
current_sample.anno_mask = torch.from_numpy(mask)
|
||||
current_sample.valid = valid
|
||||
current_sample.location = geo_location
|
||||
current_sample.modality_idx = modality_idx
|
||||
current_sample.modality_flag_s2 = modality_flag_s2
|
||||
current_sample.modality_flag_s1 = modality_flag_s1
|
||||
current_sample.modality_flag_hr = modality_flag_hr
|
||||
current_sample.task_type = self.dataset_type
|
||||
|
||||
return current_sample
|
||||
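The modality_idx written into each Sample above is a 3-bit flag: bit 0 marks Sentinel-2 availability, bit 1 Sentinel-1, and bit 2 the high-resolution image. A short self-contained check of that encoding, mirroring the expression in get_item:

# modality_idx packs the three availability flags into one integer,
# mirroring 2**0 * s2 + 2**1 * s1 + 2**2 * hr in get_item above.
def modality_index(with_s2: bool, with_s1: bool, with_hr: bool) -> int:
    return 2**0 * with_s2 + 2**1 * with_s1 + 2**2 * with_hr

print(modality_index(True,  False, True))   # 5 -> S2 + HR
print(modality_index(False, False, True))   # 4 -> HR only
print(modality_index(True,  True,  True))   # 7 -> all three modalities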
77
lib/datasets/utils/dataset_colors.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import random
|
||||
from functools import lru_cache
|
||||
import numpy as np
|
||||
|
||||
dataset_color_dict = {
|
||||
"potsdam" : [[1], [2], [3], [4], [5]],
|
||||
"vaihingen" : [[255, 255, 0], [0, 255, 0], [0, 255, 255], [0, 0, 255], [255, 255, 255]],
|
||||
"deepglobe" : [[255,255,255], [0,0,255], [0,255,0],[255,0,255], [255,255,0], [0,255,255]],
|
||||
"fbp" : [[i+1] for i in range(24)],
|
||||
"loveda" : [[i+2, i+2, i+2] for i in range(6)],
|
||||
"isaid" : [[i+1] for i in range(15)],
|
||||
"pastis-mm" : [[i+1] for i in range(18)],
|
||||
"dynamic-mm" : [[i] for i in range(7)],
|
||||
"c2seg-ab" : [[i+1] for i in range(13)],
|
||||
"flood3i": [[i+1] for i in range(9)],
|
||||
"jl16-mm": [[i] for i in range(16)],
|
||||
"flair-mm": [[i+1] for i in range(18)],
|
||||
"dfc20": [[i+1] for i in range(10)]
|
||||
}
|
||||
|
||||
|
||||
modal_norm_dict = {
|
||||
'hr' : {
|
||||
'div' : 255.,
|
||||
'mean' : [0.485, 0.456, 0.406],
|
||||
'std' : [0.229, 0.224, 0.225]
|
||||
},
|
||||
'anno' : {
|
||||
'div' : 255.,
|
||||
'mean' : [0.485, 0.456, 0.406],
|
||||
'std' : [0.229, 0.224, 0.225]
|
||||
},
|
||||
's2' : {
|
||||
'div' : 1.,
|
||||
'mean' : [884.29673756, 1144.16202635, 1297.47289228, 1624.90992062, 2194.6423161, 2422.21248945, 2517.76053101, 2581.64687018, 2368.51236873, 1805.06846033],
|
||||
'std' : [1155.15170768, 1183.6292542, 1368.11351514, 1370.265037, 1355.55390699, 1416.51487101, 1474.78900051, 1439.3086061, 1455.52084939, 1343.48379601]
|
||||
},
|
||||
's1' : {
|
||||
'div' : 1.,
|
||||
'mean' : [-12.54847273, -20.19237134],
|
||||
'std' : [5.25697717, 5.91150917]
|
||||
},
|
||||
}
|
||||
|
||||
@lru_cache()
|
||||
def get_painter_color_map_list(num_locations = 300):
|
||||
|
||||
num_sep_per_channel = int(num_locations ** (1 / 3)) + 1 # 19
|
||||
separation_per_channel = 256 // num_sep_per_channel
|
||||
|
||||
color_list = []
|
||||
for location in range(num_locations):
|
||||
num_seq_r = location // num_sep_per_channel ** 2
|
||||
num_seq_g = (location % num_sep_per_channel ** 2) // num_sep_per_channel
|
||||
num_seq_b = location % num_sep_per_channel
|
||||
assert (num_seq_r <= num_sep_per_channel) and (num_seq_g <= num_sep_per_channel) \
|
||||
and (num_seq_b <= num_sep_per_channel)
|
||||
|
||||
R = 255 - num_seq_r * separation_per_channel
|
||||
G = 255 - num_seq_g * separation_per_channel
|
||||
B = 255 - num_seq_b * separation_per_channel
|
||||
assert (R < 256) and (G < 256) and (B < 256)
|
||||
assert (R >= 0) and (G >= 0) and (B >= 0)
|
||||
assert (R, G, B) not in color_list
|
||||
|
||||
color_list.append((R, G, B))
|
||||
|
||||
return color_list
|
||||
|
||||
|
||||
def get_real_random_color_list(num_locations):
|
||||
random_color_list = np.random.randint(0, 256, (num_locations, 3))
|
||||
while np.sum(random_color_list) == 0:
|
||||
print('random_color_list is all zeros, resampling!')
|
||||
random_color_list = np.random.randint(0, 256, (num_locations, 3))
|
||||
random_color_list = random_color_list.tolist()
|
||||
return random_color_list # [:num_locations]
|
||||
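get_painter_color_map_list assigns each class a fixed, well-separated RGB triple by treating the class index as a base-N number with N = int(num_locations ** (1/3)) + 1. A small self-contained re-derivation of the first few colors, following the same arithmetic as above:

# Re-derivation of the deterministic color mapping for a small class count.
num_locations = 10
num_sep_per_channel = int(num_locations ** (1 / 3)) + 1   # 3 separations per channel
separation_per_channel = 256 // num_sep_per_channel        # 85

colors = []
for location in range(num_locations):
    num_seq_r = location // num_sep_per_channel ** 2
    num_seq_g = (location % num_sep_per_channel ** 2) // num_sep_per_channel
    num_seq_b = location % num_sep_per_channel
    colors.append((255 - num_seq_r * separation_per_channel,
                   255 - num_seq_g * separation_per_channel,
                   255 - num_seq_b * separation_per_channel))

print(colors[0])   # (255, 255, 255) -> class 0
print(colors[1])   # (255, 255, 170) -> class 1
print(colors[9])   # (170, 255, 255) -> class 9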
65
lib/datasets/utils/formatting.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from collections.abc import Sequence
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def to_tensor(data):
|
||||
"""Convert objects of various python types to :obj:`torch.Tensor`.
|
||||
|
||||
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
|
||||
:class:`Sequence`, :class:`int` and :class:`float`.
|
||||
|
||||
Args:
|
||||
data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
|
||||
be converted.
|
||||
"""
|
||||
|
||||
if isinstance(data, torch.Tensor):
|
||||
return data
|
||||
elif isinstance(data, np.ndarray):
|
||||
return torch.from_numpy(data)
|
||||
elif isinstance(data, Sequence) and not mmcv.is_str(data):
|
||||
return torch.tensor(data)
|
||||
elif isinstance(data, int):
|
||||
return torch.LongTensor([data])
|
||||
elif isinstance(data, float):
|
||||
return torch.FloatTensor([data])
|
||||
else:
|
||||
raise TypeError(f'type {type(data)} cannot be converted to tensor.')
|
||||
|
||||
|
||||
class ToTensor(object):
|
||||
"""Convert some sample to :obj:`torch.Tensor` by given keys.
|
||||
|
||||
Args:
|
||||
keys (Sequence[str]): Keys that need to be converted to Tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, keys):
|
||||
self.keys = keys
|
||||
|
||||
def __call__(self, sample):
|
||||
"""Call function to convert data in sample to :obj:`torch.Tensor`.
|
||||
|
||||
Args:
|
||||
sample (Sample): sample data contains the data to convert.
|
||||
|
||||
Returns:
|
||||
dict: The result dict contains the data converted
|
||||
to :obj:`torch.Tensor`.
|
||||
"""
|
||||
|
||||
for key in self.keys:
|
||||
if isinstance(sample[key], list):
|
||||
for i in range(len(sample[key])):
|
||||
sample[key][i] = to_tensor(sample[key][i])
|
||||
else:
|
||||
sample[key] = to_tensor(sample[key])
|
||||
return sample
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + f'(keys={self.keys})'
|
||||
|
||||
|
||||
|
||||
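to_tensor and ToTensor are mmcv-style formatting helpers: to_tensor converts arrays, sequences and scalars to tensors, and ToTensor applies it to the chosen keys of a sample, which only needs to support item access. A minimal self-contained sketch; the plain dict standing in for an antmmf Sample is an assumption made only for this example:

import numpy as np
import torch

# Self-contained sketch of the to_tensor conversion rules used by ToTensor(keys=[...]).
def to_tensor_sketch(data):
    if isinstance(data, torch.Tensor):
        return data
    if isinstance(data, np.ndarray):
        return torch.from_numpy(data)
    if isinstance(data, (list, tuple)):
        return torch.tensor(data)
    if isinstance(data, int):
        return torch.LongTensor([data])
    if isinstance(data, float):
        return torch.FloatTensor([data])
    raise TypeError(f'type {type(data)} cannot be converted to tensor.')

sample = {"hr_img": np.zeros((3, 4, 4), dtype=np.float32), "label": 2}
for key in ["hr_img", "label"]:            # mirrors ToTensor.__call__
    sample[key] = to_tensor_sketch(sample[key])

print(sample["hr_img"].shape)  # torch.Size([3, 4, 4])
print(sample["label"])         # tensor([2])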
84
lib/datasets/utils/masking_generator.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import random
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
class MaskingGenerator:
|
||||
def __init__(
|
||||
self, input_size, patch_size, mask_ratio=0.5, min_num_patches=4, max_num_patches=None,
|
||||
min_aspect=0.3, max_aspect=None):
|
||||
if not isinstance(input_size, list):
|
||||
input_size = [input_size,] * 2
|
||||
self.height = input_size[0] // patch_size
|
||||
self.width = input_size[1] // patch_size
|
||||
|
||||
self.num_patches = self.height * self.width
|
||||
self.num_masking_patches = int(self.num_patches * mask_ratio)
|
||||
|
||||
self.min_num_patches = min_num_patches
|
||||
self.max_num_patches = self.num_masking_patches if max_num_patches is None else max_num_patches
|
||||
|
||||
max_aspect = max_aspect or 1 / min_aspect
|
||||
self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
|
||||
self.height, self.width, self.min_num_patches, self.max_num_patches,
|
||||
self.num_masking_patches, self.log_aspect_ratio[0], self.log_aspect_ratio[1])
|
||||
return repr_str
|
||||
|
||||
def get_shape(self):
|
||||
return self.height, self.width
|
||||
|
||||
def _mask(self, mask, max_mask_patches):
|
||||
delta = 0
|
||||
for attempt in range(10):
|
||||
target_area = random.uniform(self.min_num_patches, max_mask_patches)
|
||||
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
|
||||
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
if w < self.width and h < self.height:
|
||||
top = random.randint(0, self.height - h)
|
||||
left = random.randint(0, self.width - w)
|
||||
|
||||
num_masked = mask[top: top + h, left: left + w].sum()
|
||||
# Overlap
|
||||
if 0 < h * w - num_masked <= max_mask_patches:
|
||||
for i in range(top, top + h):
|
||||
for j in range(left, left + w):
|
||||
if mask[i, j] == 0:
|
||||
mask[i, j] = 1
|
||||
delta += 1
|
||||
|
||||
if delta > 0:
|
||||
break
|
||||
return delta
|
||||
|
||||
def __call__(self):
|
||||
mask = np.zeros(shape=self.get_shape(), dtype=np.int32)
|
||||
mask_count = 0
|
||||
while mask_count < self.num_masking_patches:
|
||||
max_mask_patches = self.num_masking_patches - mask_count
|
||||
max_mask_patches = min(max_mask_patches, self.max_num_patches)
|
||||
|
||||
delta = self._mask(mask, max_mask_patches)
|
||||
if delta == 0:
|
||||
break
|
||||
else:
|
||||
mask_count += delta
|
||||
|
||||
# maintain a fixed number of masked patches: self.num_masking_patches
|
||||
if mask_count > self.num_masking_patches:
|
||||
delta = mask_count - self.num_masking_patches
|
||||
mask_x, mask_y = mask.nonzero()
|
||||
to_vis = np.random.choice(mask_x.shape[0], delta, replace=False)
|
||||
mask[mask_x[to_vis], mask_y[to_vis]] = 0
|
||||
|
||||
elif mask_count < self.num_masking_patches:
|
||||
delta = self.num_masking_patches - mask_count
|
||||
mask_x, mask_y = (mask == 0).nonzero()
|
||||
to_mask = np.random.choice(mask_x.shape[0], delta, replace=False)
|
||||
mask[mask_x[to_mask], mask_y[to_mask]] = 1
|
||||
|
||||
assert mask.sum() == self.num_masking_patches, f"mask: {mask}, mask count {mask.sum()}"
|
||||
|
||||
return mask
|
||||
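MaskingGenerator produces a patch-level 0/1 mask whose number of ones is forced to exactly num_patches * mask_ratio: block-wise masking first, then a random top-up or trim at the end. A short usage example; it assumes this commit's package layout is importable (lib/ on PYTHONPATH):

# Assumes the lib/ package from this commit is on PYTHONPATH.
from lib.datasets.utils.masking_generator import MaskingGenerator

gen = MaskingGenerator(input_size=[1024, 512], patch_size=16, mask_ratio=0.5)
mask = gen()                        # numpy int32 array of patch-level flags

print(mask.shape)                   # (64, 32) -> (1024/16, 512/16) patches
print(mask.sum())                   # 1024 == 64 * 32 * 0.5, enforced exactly
print(gen.get_shape())              # (64, 32)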
532
lib/datasets/utils/pair_trainsforms.py
Normal file
@@ -0,0 +1,532 @@
|
||||
import math
|
||||
import numbers
|
||||
import random
|
||||
import warnings
|
||||
from collections.abc import Sequence
|
||||
from typing import List, Optional, Tuple
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torchvision.transforms as transforms
|
||||
from skimage import io
|
||||
|
||||
|
||||
try:
|
||||
import accimage
|
||||
except ImportError:
|
||||
accimage = None
|
||||
|
||||
import torchvision.transforms.functional as F
|
||||
from torchvision.transforms.functional import _interpolation_modes_from_int, InterpolationMode
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
|
||||
from .dataset_colors import modal_norm_dict
|
||||
|
||||
__all__ = [
|
||||
"Compose",
|
||||
"ToTensor",
|
||||
"Normalize",
|
||||
"RandomHorizontalFlip",
|
||||
"RandomResizedCrop",
|
||||
]
|
||||
|
||||
|
||||
|
||||
class Compose(transforms.Compose):
|
||||
"""Composes several transforms together. This transform does not support torchscript.
|
||||
Please, see the note below.
|
||||
Args:
|
||||
transforms (list of ``Transform`` objects): list of transforms to compose.
|
||||
"""
|
||||
|
||||
def __init__(self, transforms):
|
||||
super().__init__(transforms)
|
||||
|
||||
def __call__(self, dataset_name, hr_img, s2_img, s1_img, tgt, interpolation1=None, interpolation2=None):
|
||||
# i = 0
|
||||
for t in self.transforms:
|
||||
# i = i+1
|
||||
# print(f'dataset_name:{dataset_name}')
|
||||
# print(f'step:{i}')
|
||||
# print(f'hr_img shape:{hr_img.shape}')
|
||||
# print(f's2_img shape:{s2_img.shape}')
|
||||
# print(f's1_img shape:{s1_img.shape}')
|
||||
# print(f'tgt shape:{tgt.shape}')
|
||||
|
||||
hr_img, s2_img, s1_img, tgt = t(dataset_name, hr_img, s2_img, s1_img, tgt, interpolation1=interpolation1, interpolation2=interpolation2)
|
||||
return hr_img, s2_img, s1_img, tgt
|
||||
|
||||
|
||||
class ToTensor(transforms.ToTensor):
|
||||
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript.
|
||||
Converts a PIL Image or numpy.ndarray (H x W x C) in the range
|
||||
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
|
||||
if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
|
||||
or if the numpy.ndarray has dtype = np.uint8
|
||||
In the other cases, tensors are returned without scaling.
|
||||
.. note::
|
||||
Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
|
||||
transforming target image masks. See the `references`_ for implementing the transforms for image masks.
|
||||
.. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, dataset_name, hr_img, s2_img, s1_img, tgt, interpolation1=None, interpolation2=None):
|
||||
"""
|
||||
Args:
|
||||
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
|
||||
Returns:
|
||||
Tensor: Converted image.
|
||||
"""
|
||||
|
||||
|
||||
# print(f'hr dtype:{hr_img.dtype}')
|
||||
# print(f's2_img dtype:{s2_img.dtype}')
|
||||
# print(f's1_img dtype:{s1_img.dtype}')
|
||||
# print(f'tgt dtype:{tgt.dtype}')
|
||||
if dataset_name == 'dynamic-mm' or dataset_name == 'guizhou-mm':
|
||||
hr_img = hr_img.astype(np.int32)[:3,:,:]
|
||||
hr_img = hr_img[::-1,:,:].copy()
|
||||
else:
|
||||
hr_img = hr_img.astype(np.int32)
|
||||
tgt = tgt.astype(np.uint8)
|
||||
s1_img = s1_img.astype(np.float32)
|
||||
s2_img = s2_img.astype(np.int16)
|
||||
|
||||
return torch.tensor(hr_img), torch.tensor(s2_img), torch.tensor(s1_img),torch.tensor(tgt)
|
||||
|
||||
|
||||
class Normalize(transforms.Normalize):
|
||||
"""Normalize a tensor image with mean and standard deviation.
|
||||
This transform does not support PIL Image.
|
||||
Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
|
||||
channels, this transform will normalize each channel of the input
|
||||
``torch.*Tensor`` i.e.,
|
||||
``output[channel] = (input[channel] - mean[channel]) / std[channel]``
|
||||
.. note::
|
||||
This transform acts out of place, i.e., it does not mutate the input tensor.
|
||||
Args:
|
||||
mean (sequence): Sequence of means for each channel.
|
||||
std (sequence): Sequence of standard deviations for each channel.
|
||||
inplace(bool,optional): Bool to make this operation in-place.
|
||||
"""
|
||||
|
||||
def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False):
|
||||
super().__init__(mean, std, inplace)
|
||||
|
||||
def forward(self, dataset_name, hr_img, s2_img, s1_img, tgt, interpolation1=None, interpolation2=None):
|
||||
"""
|
||||
Args:
|
||||
tensor (Tensor): Tensor image to be normalized.
|
||||
Returns:
|
||||
Tensor: Normalized Tensor image.
|
||||
"""
|
||||
# TODO: look up the mean and std for the corresponding dataset
|
||||
|
||||
# handle dataset-specific mean and std
|
||||
if dataset_name == 'dynamic-mm':
|
||||
hr_std = [1008.4052, 760.9586, 631.4754]
|
||||
hr_mean = [1085.2941, 944.2718, 689.2493]
|
||||
hr_div = 1.
|
||||
else:
|
||||
hr_mean = modal_norm_dict['hr']['mean']
|
||||
hr_std = modal_norm_dict['hr']['std']
|
||||
hr_div = modal_norm_dict['hr']['div']
|
||||
|
||||
if dataset_name == 'l8activefire':
|
||||
# if False:
|
||||
s2_mean = modal_norm_dict['l8']['mean']
|
||||
s2_std = modal_norm_dict['l8']['std']
|
||||
s2_div = modal_norm_dict['l8']['div']
|
||||
else:
|
||||
s2_mean = modal_norm_dict['s2']['mean']
|
||||
s2_std = modal_norm_dict['s2']['std']
|
||||
s2_div = modal_norm_dict['s2']['div']
|
||||
|
||||
s1_mean = modal_norm_dict['s1']['mean']
|
||||
s1_std = modal_norm_dict['s1']['std']
|
||||
s1_div = modal_norm_dict['s1']['div']
|
||||
|
||||
anno_mean = [0.485, 0.456, 0.406]
|
||||
anno_std = [0.229, 0.224, 0.225]
|
||||
ann_div = 255.
|
||||
|
||||
# Open question: could normalizing the time series this way cause errors?
|
||||
|
||||
#mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
|
||||
#std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
|
||||
# print(s2_img.shape)
|
||||
# import pdb; pdb.set_trace()
|
||||
# print(s2_img)
|
||||
try:
|
||||
ch, ts, h, w = s2_img.shape
|
||||
except:
|
||||
print(f's2: {s2_img.shape}, s1: {s1_img.shape}')
|
||||
s2_img = s2_img.view(ch, ts*h, w)
|
||||
s2_img = self.normalize(s2_img.type(torch.float32), s2_mean, s2_std, self.inplace)
|
||||
s2_img = s2_img.view(ch, ts, h, w)
|
||||
|
||||
ch, ts, h, w = s1_img.shape
|
||||
s1_img = s1_img.view(ch, ts*h, w)
|
||||
s1_img = self.normalize(s1_img.type(torch.float32), s1_mean, s1_std, self.inplace)
|
||||
s1_img = s1_img.view(ch, ts, h, w)
|
||||
|
||||
# import pdb; pdb.set_trace()
|
||||
# print(s2_img.shape, s2_img[:,0,:,:])
|
||||
# print(s1_img.shape, s1_img[:,0,:,:])
|
||||
# print(hr_mean, hr_std, hr_div)
|
||||
return self.normalize(hr_img.type(torch.float32).div_(hr_div), hr_mean, hr_std, self.inplace), \
|
||||
s2_img, \
|
||||
s1_img, \
|
||||
self.normalize(tgt.type(torch.float32).div_(ann_div) , anno_mean, anno_std, self.inplace)
|
||||
|
||||
def normalize(self, tensor, mean, std, inplace):
|
||||
dtype = tensor.dtype
|
||||
mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
|
||||
std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
|
||||
if (std == 0).any():
|
||||
raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
|
||||
mean = mean.view(-1, 1, 1)
|
||||
std = std.view(-1, 1, 1)
|
||||
# print(f'tensor shape: {tensor.shape}')
|
||||
# print(f'mean shape: {mean.shape}')
|
||||
# print(f'std shape: {std.shape}')
|
||||
return tensor.sub_(mean).div_(std)
|
||||
|
||||
|
||||
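Normalize above handles the (channels, timesteps, H, W) Sentinel stacks by folding the time dimension into the height before the channel-wise normalization and unfolding it afterwards, so every timestep shares the same per-band statistics. A self-contained check of that reshape trick:

import torch

# (c, ts, h, w) -> (c, ts*h, w): fold time into height so per-channel
# normalization applies the same band statistics to every timestep.
s2 = torch.randint(0, 3000, (10, 4, 16, 16)).float()   # 10 bands, 4 timesteps
mean = torch.arange(10).float().view(-1, 1, 1)
std = torch.full((10, 1, 1), 100.0)

ch, ts, h, w = s2.shape
flat = s2.view(ch, ts * h, w)
flat = (flat - mean) / std
out = flat.view(ch, ts, h, w)

# Same result as normalizing each timestep separately.
ref = (s2 - mean.unsqueeze(1)) / std.unsqueeze(1)
print(torch.allclose(out, ref))    # True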
class RandomResizedCrop(transforms.RandomResizedCrop):
|
||||
"""Crop a random portion of image and resize it to a given size.
|
||||
If the image is torch Tensor, it is expected
|
||||
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
|
||||
A crop of the original image is made: the crop has a random area (H * W)
|
||||
and a random aspect ratio. This crop is finally resized to the given
|
||||
size. This is popularly used to train the Inception networks.
|
||||
Args:
|
||||
size (int or sequence): expected output size of the crop, for each edge. If size is an
|
||||
int instead of sequence like (h, w), a square output size ``(size, size)`` is
|
||||
made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
|
||||
.. note::
|
||||
In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
|
||||
scale (tuple of float): Specifies the lower and upper bounds for the random area of the crop,
|
||||
before resizing. The scale is defined with respect to the area of the original image.
|
||||
ratio (tuple of float): lower and upper bounds for the random aspect ratio of the crop, before
|
||||
resizing.
|
||||
interpolation (InterpolationMode): Desired interpolation enum defined by
|
||||
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
|
||||
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and
|
||||
``InterpolationMode.BICUBIC`` are supported.
|
||||
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
|
||||
but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
scale=(0.08, 1.0),
|
||||
ratio=(3.0 / 4.0, 4.0 / 3.0),
|
||||
interpolation=InterpolationMode.BILINEAR,
|
||||
mode='small'
|
||||
):
|
||||
super().__init__(size, scale=scale, ratio=ratio, interpolation=interpolation)
|
||||
self.cnt=0
|
||||
self.mode = mode
|
||||
|
||||
def forward(self, dataset_name, hr_img, s2_img, s1_img, tgt, interpolation1=None, interpolation2=None, mode='small'):
|
||||
"""
|
||||
Args:
|
||||
img (PIL Image or Tensor): Image to be cropped and resized.
|
||||
Returns:
|
||||
PIL Image or Tensor: Randomly cropped and resized image.
|
||||
"""
|
||||
i, j, h, w = self.get_params(s2_img, self.scale, self.ratio)
|
||||
size_hr = hr_img.shape[-1]
|
||||
size_s2 = s2_img.shape[-1]
|
||||
size_anno = tgt.shape[-1]
|
||||
# map the crop coordinates to the other modalities
|
||||
ratio_s2_hr = size_s2 / size_hr
|
||||
i_hr = int(i / ratio_s2_hr)
|
||||
j_hr = int(j / ratio_s2_hr)
|
||||
h_hr = int(h / ratio_s2_hr)
|
||||
w_hr = int(w / ratio_s2_hr)
|
||||
|
||||
ratio_s2_anno = size_s2 / size_anno
|
||||
i_anno = int(i / ratio_s2_anno)
|
||||
j_anno = int(j / ratio_s2_anno)
|
||||
h_anno = int(h / ratio_s2_anno)
|
||||
w_anno = int(w / ratio_s2_anno)
|
||||
|
||||
if interpolation1 == 'nearest':
|
||||
interpolation1 = InterpolationMode.NEAREST
|
||||
else:
|
||||
interpolation1 = InterpolationMode.BICUBIC
|
||||
if interpolation2 == 'nearest':
|
||||
interpolation2 = InterpolationMode.NEAREST
|
||||
else:
|
||||
interpolation2 = InterpolationMode.BICUBIC
|
||||
# import pdb;pdb.set_trace()
|
||||
if self.scale[0]>0.99 and self.scale[0]<1.0:
|
||||
if self.mode=='small':
|
||||
resized_s2_img = F.resize(s2_img, (16,16), interpolation=InterpolationMode.BICUBIC)
|
||||
resized_hr_img = F.resize(hr_img, (512, 512), interpolation=InterpolationMode.BICUBIC)
|
||||
resized_s1_img = F.resize(s1_img, (16,16), interpolation=InterpolationMode.BICUBIC)
|
||||
resized_tgt = F.resize(tgt, (512,512), interpolation=InterpolationMode.NEAREST)
|
||||
else:
|
||||
resized_s2_img = F.resize(s2_img, (64,64), interpolation=InterpolationMode.BICUBIC)
|
||||
resized_hr_img = F.resize(hr_img, (2048, 2048), interpolation=InterpolationMode.BICUBIC)
|
||||
resized_s1_img = F.resize(s1_img, (64,64), interpolation=InterpolationMode.BICUBIC)
|
||||
resized_tgt = F.resize(tgt, (2048,2048), interpolation=InterpolationMode.NEAREST)
|
||||
return resized_hr_img, resized_s2_img, resized_s1_img, resized_tgt
|
||||
|
||||
if self.mode=='small':
|
||||
resized_s2_img = F.resized_crop(s2_img, i, j, h, w, (16, 16), InterpolationMode.BICUBIC)
|
||||
resized_hr_img = F.resized_crop(hr_img, i_hr, j_hr, h_hr, w_hr, (512, 512), InterpolationMode.BICUBIC)
|
||||
resized_s1_img = F.resized_crop(s1_img, i, j, h, w, (16, 16), InterpolationMode.BICUBIC)
|
||||
resized_tgt = F.resized_crop(tgt, i_anno, j_anno, h_anno, w_anno, (512, 512), InterpolationMode.NEAREST)
|
||||
else:
|
||||
resized_s2_img = F.resized_crop(s2_img, i, j, h, w, (512, 512), InterpolationMode.BICUBIC)
|
||||
resized_hr_img = F.resized_crop(hr_img, i_hr, j_hr, h_hr, w_hr, (2048,2048), InterpolationMode.BICUBIC)
|
||||
resized_s1_img = F.resized_crop(s1_img, i, j, h, w, (512, 512), InterpolationMode.BICUBIC)
|
||||
resized_tgt = F.resized_crop(tgt, i_anno, j_anno, h_anno, w_anno, (2048, 2048), InterpolationMode.NEAREST)
|
||||
|
||||
# import pdb; pdb.set_trace()
|
||||
# save the resized results as a concatenated image (debug)
|
||||
# self.cnt = self.cnt+1
|
||||
# from torchvision.utils import save_image
|
||||
# save_hr = resized_hr_img[:3, :, :] / resized_hr_img[:3, :, :].max()
|
||||
# save_s2 = resized_s2_img[:3,0,:,:] / resized_s2_img[:3,0,:,:].max()
|
||||
# print(f'{save_hr.shape}, {save_s2.shape}')
|
||||
# save_image(save_s2, f'FoundationModel/debug/output2/resized_s2_{self.cnt}.png')
|
||||
# save_image(save_hr, f'FoundationModel/debug/output2/resized_hr_{self.cnt}.png')
|
||||
|
||||
return resized_hr_img, resized_s2_img, resized_s1_img, resized_tgt
|
||||
|
||||
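Both crop transforms draw the crop box on the Sentinel-2 grid and then rescale its coordinates to the other modalities by the ratio of their spatial sizes, so the same geographic window is cut from every input. Worked numbers for a 16-pixel S2 tile paired with a 512-pixel HR tile, mirroring the arithmetic in forward above:

# Crop parameters are sampled on the S2 grid and mapped to HR/annotation space
# by the size ratio; the example crop box is arbitrary.
size_s2, size_hr = 16, 512
i, j, h, w = 2, 5, 8, 8             # example crop box on the 16x16 S2 image

ratio_s2_hr = size_s2 / size_hr     # 0.03125
i_hr = int(i / ratio_s2_hr)         # 64
j_hr = int(j / ratio_s2_hr)         # 160
h_hr = int(h / ratio_s2_hr)         # 256
w_hr = int(w / ratio_s2_hr)         # 256

print((i_hr, j_hr, h_hr, w_hr))     # (64, 160, 256, 256): same window on a 32x finer grid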
class RandomResizedCropComb(transforms.RandomResizedCrop):
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
scale=(0.08, 1.0),
|
||||
ratio=(3.0 / 4.0, 4.0 / 3.0),
|
||||
interpolation=InterpolationMode.BILINEAR,
|
||||
):
|
||||
super().__init__(size, scale=scale, ratio=ratio, interpolation=interpolation)
|
||||
self.cnt=0
|
||||
|
||||
def forward(self, dataset_name, hr_img, s2_img, s1_img, tgt, interpolation1=None, interpolation2=None):
|
||||
"""
|
||||
Args:
|
||||
img (PIL Image or Tensor): Image to be cropped and resized.
|
||||
Returns:
|
||||
PIL Image or Tensor: Randomly cropped and resized image.
|
||||
"""
|
||||
i, j, h, w = self.get_params(s2_img, self.scale, self.ratio)
|
||||
# print(f'i, j, h, w: {i, j, h, w}')
|
||||
# print(f's2_img shape: {s2_img.shape}')
|
||||
size_hr = hr_img.shape[-1]
|
||||
size_s2 = s2_img.shape[-1]
|
||||
size_anno = tgt.shape[-1]
|
||||
# map the crop coordinates to the other modalities
|
||||
ratio_s2_hr = size_s2 / size_hr
|
||||
i_hr = int(i / ratio_s2_hr)
|
||||
j_hr = int(j / ratio_s2_hr)
|
||||
h_hr = int(h / ratio_s2_hr)
|
||||
w_hr = int(w / ratio_s2_hr)
|
||||
|
||||
ratio_s2_anno = size_s2 / size_anno
|
||||
i_anno = int(i / ratio_s2_anno)
|
||||
j_anno = int(j / ratio_s2_anno)
|
||||
h_anno = int(h / ratio_s2_anno)
|
||||
w_anno = int(w / ratio_s2_anno)
|
||||
|
||||
if interpolation1 == 'nearest':
|
||||
interpolation1 = InterpolationMode.NEAREST
|
||||
else:
|
||||
interpolation1 = InterpolationMode.BICUBIC
|
||||
if interpolation2 == 'nearest':
|
||||
interpolation2 = InterpolationMode.NEAREST
|
||||
else:
|
||||
interpolation2 = InterpolationMode.BICUBIC
|
||||
|
||||
resized_s2_img = F.resized_crop(s2_img, i, j, h, w, (32, 16), InterpolationMode.BICUBIC)
|
||||
resized_hr_img = F.resized_crop(hr_img, i_hr, j_hr, h_hr, w_hr, (1024, 512), InterpolationMode.BICUBIC)
|
||||
resized_s1_img = F.resized_crop(s1_img, i, j, h, w, (32, 16), InterpolationMode.BICUBIC)
|
||||
resized_tgt = F.resized_crop(tgt, i_anno, j_anno, h_anno, w_anno, (1024, 512), InterpolationMode.NEAREST)
|
||||
|
||||
return resized_hr_img, resized_s2_img, resized_s1_img, resized_tgt
|
||||
|
||||
|
||||
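# A worked example of the crop mapping above (illustrative numbers, not taken from this
# repo's configs): the crop box is sampled on the S2 image and rescaled to the other
# modalities by the ratio of their tile sizes. With a 64-px S2 tile and a 2048-px HR tile,
# ratio_s2_hr = 64 / 2048 = 0.03125, so an S2 crop of (i, j, h, w) = (8, 8, 48, 48) maps
# to (256, 256, 1536, 1536) on the HR image.
def _crop_mapping_example():
    size_s2, size_hr = 64, 2048
    i, j, h, w = 8, 8, 48, 48
    ratio_s2_hr = size_s2 / size_hr
    i_hr, j_hr = int(i / ratio_s2_hr), int(j / ratio_s2_hr)
    h_hr, w_hr = int(h / ratio_s2_hr), int(w / ratio_s2_hr)
    return (i_hr, j_hr, h_hr, w_hr)  # -> (256, 256, 1536, 1536)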
class RandomHorizontalFlip(transforms.RandomHorizontalFlip):
    """Horizontally flip the given image randomly with a given probability.

    If the image is torch Tensor, it is expected to have [..., H, W] shape,
    where ... means an arbitrary number of leading dimensions.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__(p=p)

    def forward(self, dataset_name, hr_img, s2_img, s1_img, tgt, interpolation1=None, interpolation2=None):
        """
        Args:
            img (PIL Image or Tensor): Image to be flipped.
        Returns:
            PIL Image or Tensor: Randomly flipped image.
        """
        if torch.rand(1) < self.p:
            return F.hflip(hr_img), F.hflip(s2_img), F.hflip(s1_img), F.hflip(tgt)
        return hr_img, s2_img, s1_img, tgt
class RandomApply(transforms.RandomApply):
    """Apply randomly a list of transformations with a given probability.

    .. note::
        In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of
        transforms as shown below:

        >>> transforms = transforms.RandomApply(torch.nn.ModuleList([
        >>>     transforms.ColorJitter(),
        >>> ]), p=0.3)
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, do not require
        `lambda` functions or ``PIL.Image``.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
        p (float): probability
    """

    def __init__(self, transforms, p=0.5):
        super().__init__(transforms, p=p)

    def forward(self, img, tgt, interpolation1=None, interpolation2=None):
        if self.p < torch.rand(1):
            return img, tgt
        for t in self.transforms:
            img, tgt = t(img, tgt)
        return img, tgt
class ColorJitter(transforms.ColorJitter):
|
||||
"""Randomly change the brightness, contrast, saturation and hue of an image.
|
||||
If the image is torch Tensor, it is expected
|
||||
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
|
||||
If img is PIL Image, mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.
|
||||
Args:
|
||||
brightness (float or tuple of float (min, max)): How much to jitter brightness.
|
||||
brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
|
||||
or the given [min, max]. Should be non negative numbers.
|
||||
contrast (float or tuple of float (min, max)): How much to jitter contrast.
|
||||
contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
|
||||
or the given [min, max]. Should be non negative numbers.
|
||||
saturation (float or tuple of float (min, max)): How much to jitter saturation.
|
||||
saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
|
||||
or the given [min, max]. Should be non negative numbers.
|
||||
hue (float or tuple of float (min, max)): How much to jitter hue.
|
||||
hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
|
||||
Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
|
||||
To jitter hue, the pixel values of the input image has to be non-negative for conversion to HSV space;
|
||||
thus it does not work if you normalize your image to an interval with negative values,
|
||||
or use an interpolation that generates negative values before using this function.
|
||||
"""
|
||||
|
||||
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
|
||||
super().__init__(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
|
||||
|
||||
def forward(self, img, tgt, interpolation1=None, interpolation2=None):
|
||||
"""
|
||||
Args:
|
||||
img (PIL Image or Tensor): Input image.
|
||||
Returns:
|
||||
PIL Image or Tensor: Color jittered image.
|
||||
"""
|
||||
fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
|
||||
self.brightness, self.contrast, self.saturation, self.hue
|
||||
)
|
||||
|
||||
for fn_id in fn_idx:
|
||||
if fn_id == 0 and brightness_factor is not None:
|
||||
img = F.adjust_brightness(img, brightness_factor)
|
||||
elif fn_id == 1 and contrast_factor is not None:
|
||||
img = F.adjust_contrast(img, contrast_factor)
|
||||
elif fn_id == 2 and saturation_factor is not None:
|
||||
img = F.adjust_saturation(img, saturation_factor)
|
||||
elif fn_id == 3 and hue_factor is not None:
|
||||
img = F.adjust_hue(img, hue_factor)
|
||||
return img, tgt
|
||||
|
||||
|
||||
class RandomErasing(transforms.RandomErasing):
|
||||
"""Randomly selects a rectangle region in a torch.Tensor image and erases its pixels.
|
||||
This transform does not support PIL Image.
|
||||
'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896
|
||||
Args:
|
||||
p: probability that the random erasing operation will be performed.
|
||||
scale: range of proportion of erased area against input image.
|
||||
ratio: range of aspect ratio of erased area.
|
||||
value: erasing value. Default is 0. If a single int, it is used to
|
||||
erase all pixels. If a tuple of length 3, it is used to erase
|
||||
R, G, B channels respectively.
|
||||
If a str of 'random', erasing each pixel with random values.
|
||||
inplace: boolean to make this transform inplace. Default set to False.
|
||||
Returns:
|
||||
Erased Image.
|
||||
Example:
|
||||
>>> transform = transforms.Compose([
|
||||
>>> transforms.RandomHorizontalFlip(),
|
||||
>>> transforms.PILToTensor(),
|
||||
>>> transforms.ConvertImageDtype(torch.float),
|
||||
>>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
||||
>>> transforms.RandomErasing(),
|
||||
>>> ])
|
||||
"""
|
||||
|
||||
def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
|
||||
super().__init__(p=p, scale=scale, ratio=ratio, value=value, inplace=inplace)
|
||||
|
||||
def forward(self, img, tgt, interpolation1=None, interpolation2=None):
|
||||
"""
|
||||
Args:
|
||||
img (Tensor): Tensor image to be erased.
|
||||
Returns:
|
||||
img (Tensor): Erased Tensor image.
|
||||
"""
|
||||
if torch.rand(1) < self.p:
|
||||
|
||||
# cast self.value to script acceptable type
|
||||
if isinstance(self.value, (int, float)):
|
||||
value = [self.value]
|
||||
elif isinstance(self.value, str):
|
||||
value = None
|
||||
elif isinstance(self.value, tuple):
|
||||
value = list(self.value)
|
||||
else:
|
||||
value = self.value
|
||||
|
||||
if value is not None and not (len(value) in (1, img.shape[-3])):
|
||||
raise ValueError(
|
||||
"If value is a sequence, it should have either a single value or "
|
||||
f"{img.shape[-3]} (number of input channels)"
|
||||
)
|
||||
|
||||
x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value)
|
||||
return F.erase(img, x, y, h, w, v, self.inplace), tgt
|
||||
return img, tgt
|
||||
|
||||
|
||||
|
||||
class GaussianBlur(object):
    """Gaussian blur augmentation from SimCLR: https://arxiv.org/abs/2002.05709"""

    def __init__(self, sigma=[.1, 2.]):
        self.sigma = sigma

    def __call__(self, img, tgt, interpolation1=None, interpolation2=None):
        # Expects a PIL image; relies on PIL's ImageFilter being imported at module level.
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        img = img.filter(ImageFilter.GaussianBlur(radius=sigma))
        return img, tgt

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}(sigma={self.sigma})"
        return s
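# Minimal usage sketch (illustrative only, not part of the pipeline): GaussianBlur blurs
# the image while leaving the paired target untouched. It assumes PIL's Image/ImageFilter
# are importable in this module.
def _gaussian_blur_demo():
    from PIL import Image
    img = Image.new('RGB', (64, 64), color=(128, 128, 128))
    tgt = Image.new('L', (64, 64), color=0)
    blurred, tgt_out = GaussianBlur(sigma=[0.1, 2.0])(img, tgt)
    return blurred.size, tgt_out.size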
1558
lib/datasets/utils/transforms.py
Normal file
1558
lib/datasets/utils/transforms.py
Normal file
File diff suppressed because it is too large
208
lib/evaluation/segm_eval_base.py
Normal file
208
lib/evaluation/segm_eval_base.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import os
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from skimage.transform import resize
|
||||
|
||||
from skimage import io
|
||||
from multiprocessing import Pool
|
||||
from functools import partial
|
||||
import logging
|
||||
import datetime
|
||||
|
||||
def get_confusion_matrix(pred, gt, num_class):
    assert pred.shape == gt.shape, f"pred.shape: {pred.shape} != gt.shape: {gt.shape}"
    mask = (gt >= 0) & (gt < num_class)  # keep only labels in [0, num_class); out-of-range (ignore/background) values are dropped
    label = num_class * gt[mask] + pred[mask]
    count = np.bincount(label, minlength=num_class**2)
    confusion_matrix = count.reshape(num_class, num_class)
    return confusion_matrix
|
||||
|
||||
def get_miou(confusion_matrix):
|
||||
diagonal_elements = np.diag(confusion_matrix)
|
||||
column_sums = np.sum(confusion_matrix, axis=0)
|
||||
row_sums = np.sum(confusion_matrix, axis=1)
|
||||
ious = diagonal_elements/(column_sums + row_sums - diagonal_elements)
|
||||
m_iou = np.nanmean(ious)
|
||||
return m_iou
|
||||
|
||||
def get_mprecison(confusion_matrix):
|
||||
diagonal_elements = np.diag(confusion_matrix)
|
||||
column_sums = np.sum(confusion_matrix, axis=0)
|
||||
precisions = diagonal_elements / (column_sums + 1e-06)
|
||||
m_precision = np.nanmean(precisions)
|
||||
return m_precision
|
||||
|
||||
def get_mrecall(confusion_matrix):
|
||||
diagonal_elements = np.diag(confusion_matrix)
|
||||
row_sums = np.sum(confusion_matrix, axis=1)
|
||||
recalls= diagonal_elements / (row_sums + 1e-06)
|
||||
m_recall = np.nanmean(recalls)
|
||||
return m_recall
|
||||
|
||||
def get_macc(confusion_matrix):
    '''
    acc = tp / (tp + fn), i.e. the same as recall.
    '''
    m_recall = get_mrecall(confusion_matrix)
    return m_recall
|
||||
|
||||
def get_per_class_iou(confusion_matrix):
|
||||
intersection = np.diag(confusion_matrix)
|
||||
union = np.sum(confusion_matrix, axis=0) + np.sum(confusion_matrix, axis=1) - intersection
|
||||
iou = intersection / (union.astype(np.float32) + 1e-6)
|
||||
return iou
|
||||
|
||||
def get_per_class_acc(confusion_matrix):
|
||||
total_acc = np.diag(confusion_matrix) / (np.sum(confusion_matrix, axis=1).astype(np.float32) + 1e-6)
|
||||
return total_acc
|
||||
|
||||
|
||||
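# A minimal sketch of how the helpers above compose (toy 2-class example, made-up data):
def _metrics_demo():
    pred = np.array([[0, 1], [1, 1]])
    gt = np.array([[0, 1], [0, 1]])
    cm = get_confusion_matrix(pred.flatten(), gt.flatten(), num_class=2)
    # cm = [[1, 1], [0, 2]] (rows = gt, cols = pred); IoU = diag / (col_sum + row_sum - diag)
    return get_miou(cm), get_per_class_iou(cm)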
def post_process_segm_output(segm, colors, dist_type='abs'):
|
||||
"""
|
||||
Post-processing to turn output segm image to class index map using NumPy
|
||||
Args:
|
||||
segm: (H, W, 3)
|
||||
Returns:
|
||||
class_map: (H, W)
|
||||
"""
|
||||
palette = np.array(colors)
|
||||
segm = segm.astype(np.float32) # (h, w, 3)
|
||||
h, w, k = segm.shape[0], segm.shape[1], palette.shape[0]
|
||||
if dist_type == 'abs':
|
||||
dist = np.abs(segm.reshape(h, w, 1, 3) - palette.reshape(1, 1, k, 3)) # (h, w, k)
|
||||
elif dist_type == 'square':
|
||||
dist = np.power(segm.reshape(h, w, 1, 3) - palette.reshape(1, 1, k, 3), 2) # (h, w, k)
|
||||
elif dist_type == 'mean':
|
||||
dist_abs = np.abs(segm.reshape(h, w, 1, 3) - palette.reshape(1, 1, k, 3)) # (h, w, k)
|
||||
dist_square = np.power(segm.reshape(h, w, 1, 3) - palette.reshape(1, 1, k, 3), 2) # (h, w, k)
|
||||
dist = (dist_abs + dist_square) / 2.
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
dist = np.sum(dist, axis=-1)
|
||||
    pred = np.argmin(dist, axis=-1).astype(np.int64)  # np.int was removed from NumPy; use a concrete dtype
|
||||
return pred
|
||||
|
||||
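# Hedged usage sketch for post_process_segm_output: map an RGB segmentation rendering back
# to class indices by nearest palette colour. The two-colour palette is illustrative only.
def _post_process_demo():
    colors = [[0, 0, 0], [255, 255, 255]]  # class 0 = black, class 1 = white
    segm = np.zeros((4, 4, 3), dtype=np.uint8)
    segm[:2] = 250                         # near-white pixels -> class 1
    return post_process_segm_output(segm, colors, dist_type='abs')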
def get_args_parser():
|
||||
parser = argparse.ArgumentParser('semantic segmentation evaluation', add_help=False)
|
||||
parser.add_argument('--pred_dir', type=str, help='dir to pred', required=True)
|
||||
parser.add_argument('--gt_dir', type=str, help='dir to gt', required=True)
|
||||
parser.add_argument('--gt_list_path', type=str, help='dir to gt_list_path', required=True)
|
||||
parser.add_argument('--gt_suffix', type=str, help='suffix to gt', required=True)
|
||||
parser.add_argument('--dataset_name', type=str, help='dataset name', required=True)
|
||||
parser.add_argument('--model_name', type=str, help='model name', required=True)
|
||||
parser.add_argument('--dist_type', type=str, help='dist type',
|
||||
default='abs', choices=['abs', 'square', 'mean'])
|
||||
return parser.parse_args()
|
||||
|
||||
def process_file(file_dict, pred_dir, gt_dir, args, num_class):
|
||||
filename = file_dict['file_name']
|
||||
file_cls = file_dict['file_cls']
|
||||
|
||||
gt = io.imread(os.path.join(gt_dir, filename))
|
||||
gt_index = gt.copy()
|
||||
gt_index[gt_index != file_cls] = 0
|
||||
gt_index[gt_index == file_cls] = 1
|
||||
|
||||
    try:
        pred = io.imread(os.path.join(pred_dir, filename.replace('.png', f'-{file_cls}.png')))
        pred = resize(pred, gt.shape[-2:], anti_aliasing=False, mode='reflect', order=0)

        if len(pred.shape) == 3:
            pred_index = pred[:, :, 0].copy()
        else:
            pred_index = pred.copy()
        pred_index[pred_index <= 127] = 0
        pred_index[pred_index > 127] = 1
    except Exception:
        missing_name = filename.replace('.png', f'-{file_cls}.png')
        logging.info(f'{missing_name} not found!')
        pred_index = gt_index.copy()
|
||||
|
||||
pred_index = pred_index.flatten()
|
||||
gt_index = gt_index.flatten()
|
||||
confusion_matrix = get_confusion_matrix(pred_index, gt_index, num_class)
|
||||
return file_cls, confusion_matrix
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = get_args_parser()
|
||||
dataset_name = args.dataset_name
|
||||
pred_dir = args.pred_dir
|
||||
gt_dir = args.gt_dir
|
||||
gt_list_path = args.gt_list_path
|
||||
dist_type = args.dist_type
|
||||
model_name = args.model_name
|
||||
|
||||
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
os.makedirs('logs/eval', exist_ok=True)
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler(f'logs/eval/eval_{model_name}_{dataset_name}_{current_time}.log'),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
|
||||
output_folder = os.path.join(pred_dir, f'eval_{dataset_name}')
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
|
||||
|
||||
num_class = 2
|
||||
|
||||
with open(gt_list_path, 'r') as f:
|
||||
file_cls_list = f.readlines()
|
||||
|
||||
file_list = []
|
||||
for i in file_cls_list:
|
||||
i = i.strip()
|
||||
file_name = i[:-3]
|
||||
file_cls = i[-2:]
|
||||
file_list.append({'file_name': file_name, 'file_cls': int(file_cls)})
|
||||
|
||||
all_pred_labels = []
|
||||
all_gt_labels = []
|
||||
|
||||
process_file_partial = partial(process_file, pred_dir=pred_dir,gt_dir=gt_dir, args=args, num_class=num_class)
|
||||
|
||||
pool = Pool()
|
||||
|
||||
outputs = pool.map(process_file_partial, file_list)
|
||||
|
||||
pool.close()
|
||||
pool.join()
|
||||
logging.info(f'len outputs: {len(outputs)}')
|
||||
confusion_matrix_dict = {}
|
||||
for cls, confusion_matrix in outputs:
|
||||
if cls in confusion_matrix_dict.keys():
|
||||
confusion_matrix_dict[cls] += confusion_matrix
|
||||
else:
|
||||
confusion_matrix_dict[cls] = confusion_matrix
|
||||
|
||||
class_list = []
|
||||
iou_list = []
|
||||
acc_list = []
|
||||
for cls, confusion_matrix in confusion_matrix_dict.items():
|
||||
ious = get_per_class_iou(confusion_matrix)
|
||||
accs = get_per_class_acc(confusion_matrix)
|
||||
logging.info(f'cls: {cls}, ious: {ious}, accs: {accs}')
|
||||
class_list.append(cls)
|
||||
iou_list.append(ious[1])
|
||||
acc_list.append(accs[1])
|
||||
|
||||
miou = np.mean(iou_list)
|
||||
macc = np.mean(acc_list)
|
||||
|
||||
df_metrics = pd.DataFrame({
|
||||
'Class': class_list + ['Mean'],
|
||||
'IoU': iou_list + [miou],
|
||||
'Accuracy': acc_list + [macc],
|
||||
})
|
||||
pd.set_option('display.float_format', '{:.4f}%'.format)
|
||||
logging.info(df_metrics)
|
||||
pd.reset_option('display.float_format')
|
||||
df_metrics.to_csv(os.path.join(output_folder, 'eval.csv'), index=False, float_format='%.4f')
|
||||
|
||||
|
||||
7
lib/models/__init__.py
Normal file
7
lib/models/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .segmentors import SkySensePP
|
||||
from .losses import (ModalityVAELoss, RecLoss)
|
||||
from .metrics import (SemMetric)
|
||||
|
||||
__all__ = [
|
||||
'SkySensePP', 'ModalityVAELoss', 'RecLoss', 'SemMetric'
|
||||
]
|
||||
14
lib/models/backbones/__init__.py
Normal file
14
lib/models/backbones/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from .swin_v2 import SwinTransformerV2MSL
|
||||
from .vit import VisionTransformerMSL
|
||||
|
||||
__all__ = [
|
||||
'SwinTransformerV2MSL', 'VisionTransformerMSL'
|
||||
]
|
||||
|
||||
type_mapping = {
|
||||
'SwinTransformerV2MSL': SwinTransformerV2MSL,
|
||||
'VisionTransformerMSL': VisionTransformerMSL
|
||||
}
|
||||
|
||||
def build_backbone(type, **kwargs):
|
||||
return type_mapping[type](**kwargs)
|
||||
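# Hedged usage sketch: build_backbone dispatches on the class name registered in
# type_mapping. The keyword arguments below mirror the SwinTransformerV2 defaults defined
# elsewhere in this package and are illustrative; the real configs may pass a different set.
def _build_backbone_example():
    return build_backbone('SwinTransformerV2MSL', arch='tiny', img_size=256, window_size=8)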
702
lib/models/backbones/swin_v2.py
Normal file
702
lib/models/backbones/swin_v2.py
Normal file
@@ -0,0 +1,702 @@
|
||||
from copy import deepcopy
|
||||
from typing import Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
|
||||
from mmcv.cnn.utils.weight_init import trunc_normal_
|
||||
from mmcv.runner.base_module import BaseModule, ModuleList
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.models.utils import (PatchMerging, ShiftWindowMSA, WindowMSAV2,
|
||||
resize_pos_embed, to_2tuple)
|
||||
from mmcls.models.backbones.base_backbone import BaseBackbone
|
||||
from mmcv.runner import (CheckpointLoader,
|
||||
load_state_dict)
|
||||
from mmcv.cnn.bricks.transformer import MultiheadAttention
|
||||
|
||||
class SwinBlockV2(BaseModule):
|
||||
"""Swin Transformer V2 block. Use post normalization.
|
||||
|
||||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): The height and width of the window. Defaults to 7.
|
||||
shift (bool): Shift the attention window or not. Defaults to False.
|
||||
extra_norm (bool): Whether add extra norm at the end of main branch.
|
||||
ffn_ratio (float): The expansion ratio of feedforward network hidden
|
||||
layer channels. Defaults to 4.
|
||||
drop_path (float): The drop path rate after attention and ffn.
|
||||
Defaults to 0.
|
||||
pad_small_map (bool): If True, pad the small feature map to the window
|
||||
size, which is common used in detection and segmentation. If False,
|
||||
avoid shifting window and shrink the window size to the size of
|
||||
feature map, which is common used in classification.
|
||||
Defaults to False.
|
||||
attn_cfgs (dict): The extra config of Shift Window-MSA.
|
||||
Defaults to empty dict.
|
||||
ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
|
||||
norm_cfg (dict): The config of norm layers.
|
||||
Defaults to ``dict(type='LN')``.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Defaults to False.
|
||||
pretrained_window_size (int): Window size in pretrained.
|
||||
init_cfg (dict, optional): The extra config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
window_size=8,
|
||||
shift=False,
|
||||
extra_norm=False,
|
||||
ffn_ratio=4.,
|
||||
drop_path=0.,
|
||||
pad_small_map=False,
|
||||
attn_cfgs=dict(),
|
||||
ffn_cfgs=dict(),
|
||||
norm_cfg=dict(type='LN'),
|
||||
with_cp=False,
|
||||
pretrained_window_size=0,
|
||||
init_cfg=None):
|
||||
|
||||
super(SwinBlockV2, self).__init__(init_cfg)
|
||||
self.with_cp = with_cp
|
||||
self.extra_norm = extra_norm
|
||||
|
||||
_attn_cfgs = {
|
||||
'embed_dims': embed_dims,
|
||||
'num_heads': num_heads,
|
||||
'shift_size': window_size // 2 if shift else 0,
|
||||
'window_size': window_size,
|
||||
'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
|
||||
'pad_small_map': pad_small_map,
|
||||
**attn_cfgs
|
||||
}
|
||||
# use V2 attention implementation
|
||||
_attn_cfgs.update(
|
||||
window_msa=WindowMSAV2,
|
||||
msa_cfg=dict(
|
||||
pretrained_window_size=to_2tuple(pretrained_window_size)))
|
||||
self.attn = ShiftWindowMSA(**_attn_cfgs)
|
||||
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
|
||||
_ffn_cfgs = {
|
||||
'embed_dims': embed_dims,
|
||||
'feedforward_channels': int(embed_dims * ffn_ratio),
|
||||
'num_fcs': 2,
|
||||
'ffn_drop': 0,
|
||||
'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
|
||||
'act_cfg': dict(type='GELU'),
|
||||
'add_identity': False,
|
||||
**ffn_cfgs
|
||||
}
|
||||
self.ffn = FFN(**_ffn_cfgs)
|
||||
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
|
||||
# add extra norm for every n blocks in huge and giant model
|
||||
if self.extra_norm:
|
||||
self.norm3 = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
|
||||
def forward(self, x, hw_shape):
|
||||
|
||||
def _inner_forward(x):
|
||||
# Use post normalization
|
||||
identity = x
|
||||
x = self.attn(x, hw_shape)
|
||||
x = self.norm1(x)
|
||||
x = x + identity
|
||||
|
||||
identity = x
|
||||
x = self.ffn(x)
|
||||
x = self.norm2(x)
|
||||
x = x + identity
|
||||
|
||||
if self.extra_norm:
|
||||
x = self.norm3(x)
|
||||
|
||||
return x
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
x = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
x = _inner_forward(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SwinBlockV2Sequence(BaseModule):
|
||||
"""Module with successive Swin Transformer blocks and downsample layer.
|
||||
|
||||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
depth (int): Number of successive swin transformer blocks.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): The height and width of the window. Defaults to 7.
|
||||
downsample (bool): Downsample the output of blocks by patch merging.
|
||||
Defaults to False.
|
||||
downsample_cfg (dict): The extra config of the patch merging layer.
|
||||
Defaults to empty dict.
|
||||
drop_paths (Sequence[float] | float): The drop path rate in each block.
|
||||
Defaults to 0.
|
||||
block_cfgs (Sequence[dict] | dict): The extra config of each block.
|
||||
Defaults to empty dicts.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Defaults to False.
|
||||
pad_small_map (bool): If True, pad the small feature map to the window
|
||||
size, which is common used in detection and segmentation. If False,
|
||||
avoid shifting window and shrink the window size to the size of
|
||||
feature map, which is common used in classification.
|
||||
Defaults to False.
|
||||
extra_norm_every_n_blocks (int): Add extra norm at the end of main
|
||||
branch every n blocks. Defaults to 0, which means no needs for
|
||||
extra norm layer.
|
||||
pretrained_window_size (int): Window size in pretrained.
|
||||
init_cfg (dict, optional): The extra config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
depth,
|
||||
num_heads,
|
||||
window_size=8,
|
||||
downsample=False,
|
||||
downsample_cfg=dict(),
|
||||
drop_paths=0.,
|
||||
block_cfgs=dict(),
|
||||
with_cp=False,
|
||||
pad_small_map=False,
|
||||
extra_norm_every_n_blocks=0,
|
||||
pretrained_window_size=0,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
if not isinstance(drop_paths, Sequence):
|
||||
drop_paths = [drop_paths] * depth
|
||||
|
||||
if not isinstance(block_cfgs, Sequence):
|
||||
block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)]
|
||||
|
||||
if downsample:
|
||||
self.out_channels = 2 * embed_dims
|
||||
_downsample_cfg = {
|
||||
'in_channels': embed_dims,
|
||||
'out_channels': self.out_channels,
|
||||
'norm_cfg': dict(type='LN'),
|
||||
**downsample_cfg
|
||||
}
|
||||
self.downsample = PatchMerging(**_downsample_cfg)
|
||||
else:
|
||||
self.out_channels = embed_dims
|
||||
self.downsample = None
|
||||
|
||||
self.blocks = ModuleList()
|
||||
for i in range(depth):
|
||||
extra_norm = True if extra_norm_every_n_blocks and \
|
||||
(i + 1) % extra_norm_every_n_blocks == 0 else False
|
||||
_block_cfg = {
|
||||
'embed_dims': self.out_channels,
|
||||
'num_heads': num_heads,
|
||||
'window_size': window_size,
|
||||
'shift': False if i % 2 == 0 else True,
|
||||
'extra_norm': extra_norm,
|
||||
'drop_path': drop_paths[i],
|
||||
'with_cp': with_cp,
|
||||
'pad_small_map': pad_small_map,
|
||||
'pretrained_window_size': pretrained_window_size,
|
||||
**block_cfgs[i]
|
||||
}
|
||||
block = SwinBlockV2(**_block_cfg)
|
||||
self.blocks.append(block)
|
||||
|
||||
def forward(self, x, in_shape):
|
||||
if self.downsample:
|
||||
x, out_shape = self.downsample(x, in_shape)
|
||||
else:
|
||||
out_shape = in_shape
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, out_shape)
|
||||
|
||||
return x, out_shape
|
||||
|
||||
|
||||
class SwinTransformerV2(BaseBackbone):
|
||||
"""Swin Transformer V2.
|
||||
|
||||
A PyTorch implement of : `Swin Transformer V2:
|
||||
Scaling Up Capacity and Resolution
|
||||
<https://arxiv.org/abs/2111.09883>`_
|
||||
|
||||
Inspiration from
|
||||
https://github.com/microsoft/Swin-Transformer
|
||||
|
||||
Args:
|
||||
arch (str | dict): Swin Transformer architecture. If use string, choose
|
||||
from 'tiny', 'small', 'base' and 'large'. If use dict, it should
|
||||
have below keys:
|
||||
|
||||
- **embed_dims** (int): The dimensions of embedding.
|
||||
- **depths** (List[int]): The number of blocks in each stage.
|
||||
- **num_heads** (List[int]): The number of heads in attention
|
||||
modules of each stage.
|
||||
- **extra_norm_every_n_blocks** (int): Add extra norm at the end
|
||||
of main branch every n blocks.
|
||||
|
||||
Defaults to 'tiny'.
|
||||
img_size (int | tuple): The expected input image shape. Because we
|
||||
support dynamic input shape, just set the argument to the most
|
||||
common input image shape. Defaults to 224.
|
||||
patch_size (int | tuple): The patch size in patch embedding.
|
||||
Defaults to 4.
|
||||
in_channels (int): The num of input channels. Defaults to 3.
|
||||
window_size (int | Sequence): The height and width of the window.
|
||||
Defaults to 7.
|
||||
drop_rate (float): Dropout rate after embedding. Defaults to 0.
|
||||
drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
|
||||
use_abs_pos_embed (bool): If True, add absolute position embedding to
|
||||
the patch embedding. Defaults to False.
|
||||
interpolate_mode (str): Select the interpolate mode for absolute
|
||||
position embeding vector resize. Defaults to "bicubic".
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Defaults to False.
|
||||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
||||
-1 means not freezing any parameters. Defaults to -1.
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Defaults to False.
|
||||
pad_small_map (bool): If True, pad the small feature map to the window
|
||||
size, which is common used in detection and segmentation. If False,
|
||||
avoid shifting window and shrink the window size to the size of
|
||||
feature map, which is common used in classification.
|
||||
Defaults to False.
|
||||
norm_cfg (dict): Config dict for normalization layer for all output
|
||||
features. Defaults to ``dict(type='LN')``
|
||||
stage_cfgs (Sequence[dict] | dict): Extra config dict for each
|
||||
stage. Defaults to an empty dict.
|
||||
patch_cfg (dict): Extra config dict for patch embedding.
|
||||
Defaults to an empty dict.
|
||||
pretrained_window_sizes (tuple(int)): Pretrained window sizes of
|
||||
each layer.
|
||||
init_cfg (dict, optional): The Config for initialization.
|
||||
Defaults to None.
|
||||
|
||||
Examples:
|
||||
>>> from mmcls.models import SwinTransformerV2
|
||||
>>> import torch
|
||||
>>> extra_config = dict(
|
||||
>>> arch='tiny',
|
||||
>>> stage_cfgs=dict(downsample_cfg={'kernel_size': 3,
|
||||
>>> 'padding': 'same'}))
|
||||
>>> self = SwinTransformerV2(**extra_config)
|
||||
>>> inputs = torch.rand(1, 3, 224, 224)
|
||||
>>> output = self.forward(inputs)
|
||||
>>> print(output.shape)
|
||||
(1, 2592, 4)
|
||||
"""
|
||||
arch_zoo = {
|
||||
**dict.fromkeys(['t', 'tiny'],
|
||||
{'embed_dims': 96,
|
||||
'depths': [2, 2, 6, 2],
|
||||
'num_heads': [3, 6, 12, 24],
|
||||
'extra_norm_every_n_blocks': 0}),
|
||||
**dict.fromkeys(['s', 'small'],
|
||||
{'embed_dims': 96,
|
||||
'depths': [2, 2, 18, 2],
|
||||
'num_heads': [3, 6, 12, 24],
|
||||
'extra_norm_every_n_blocks': 0}),
|
||||
**dict.fromkeys(['b', 'base'],
|
||||
{'embed_dims': 128,
|
||||
'depths': [2, 2, 18, 2],
|
||||
'num_heads': [4, 8, 16, 32],
|
||||
'extra_norm_every_n_blocks': 0}),
|
||||
**dict.fromkeys(['l', 'large'],
|
||||
{'embed_dims': 192,
|
||||
'depths': [2, 2, 18, 2],
|
||||
'num_heads': [6, 12, 24, 48],
|
||||
'extra_norm_every_n_blocks': 0}),
|
||||
# head count not certain for huge, and is employed for another
|
||||
# parallel study about self-supervised learning.
|
||||
**dict.fromkeys(['h', 'huge'],
|
||||
{'embed_dims': 352,
|
||||
'depths': [2, 2, 18, 2],
|
||||
'num_heads': [8, 16, 32, 64],
|
||||
'extra_norm_every_n_blocks': 6}),
|
||||
**dict.fromkeys(['g', 'giant'],
|
||||
{'embed_dims': 512,
|
||||
'depths': [2, 2, 42, 4],
|
||||
'num_heads': [16, 32, 64, 128],
|
||||
'extra_norm_every_n_blocks': 6}),
|
||||
} # yapf: disable
|
||||
|
||||
_version = 1
|
||||
num_extra_tokens = 0
|
||||
|
||||
def __init__(self,
|
||||
arch='tiny',
|
||||
img_size=256,
|
||||
patch_size=4,
|
||||
in_channels=3,
|
||||
vocabulary_size=128,
|
||||
window_size=8,
|
||||
drop_rate=0.,
|
||||
drop_path_rate=0.1,
|
||||
out_indices=(3, ),
|
||||
use_abs_pos_embed=False,
|
||||
interpolate_mode='bicubic',
|
||||
with_cp=False,
|
||||
frozen_stages=-1,
|
||||
norm_eval=False,
|
||||
pad_small_map=False,
|
||||
norm_cfg=dict(type='LN'),
|
||||
stage_cfgs=dict(downsample_cfg=dict(is_post_norm=True)),
|
||||
patch_cfg=dict(),
|
||||
pretrained_window_sizes=[0, 0, 0, 0],
|
||||
init_cfg=None):
|
||||
super(SwinTransformerV2, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
if isinstance(arch, str):
|
||||
arch = arch.lower()
|
||||
assert arch in set(self.arch_zoo), \
|
||||
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
|
||||
self.arch_settings = self.arch_zoo[arch]
|
||||
else:
|
||||
essential_keys = {
|
||||
'embed_dims', 'depths', 'num_heads',
|
||||
'extra_norm_every_n_blocks'
|
||||
}
|
||||
assert isinstance(arch, dict) and set(arch) == essential_keys, \
|
||||
f'Custom arch needs a dict with keys {essential_keys}'
|
||||
self.arch_settings = arch
|
||||
|
||||
self.vocabulary_size = vocabulary_size + 1  # add one extra entry for the ignore class
|
||||
self.embed_dims = self.arch_settings['embed_dims']
|
||||
self.depths = self.arch_settings['depths']
|
||||
self.num_heads = self.arch_settings['num_heads']
|
||||
self.extra_norm_every_n_blocks = self.arch_settings[
|
||||
'extra_norm_every_n_blocks']
|
||||
self.num_layers = len(self.depths)
|
||||
self.out_indices = out_indices
|
||||
self.use_abs_pos_embed = use_abs_pos_embed
|
||||
self.interpolate_mode = interpolate_mode
|
||||
self.frozen_stages = frozen_stages
|
||||
|
||||
if isinstance(window_size, int):
|
||||
self.window_sizes = [window_size for _ in range(self.num_layers)]
|
||||
elif isinstance(window_size, Sequence):
|
||||
assert len(window_size) == self.num_layers, \
|
||||
f'Length of window_sizes {len(window_size)} is not equal to '\
|
||||
f'length of stages {self.num_layers}.'
|
||||
self.window_sizes = window_size
|
||||
else:
|
||||
raise TypeError('window_size should be a Sequence or int.')
|
||||
|
||||
_patch_cfg = dict(
|
||||
in_channels=in_channels,
|
||||
input_size=img_size,
|
||||
embed_dims=self.embed_dims,
|
||||
conv_type='Conv2d',
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
norm_cfg=dict(type='LN'),
|
||||
)
|
||||
_patch_cfg.update(patch_cfg)
|
||||
self.patch_embed = PatchEmbed(**_patch_cfg)
|
||||
self.patch_resolution = self.patch_embed.init_out_size
|
||||
self.patch_size = patch_size
|
||||
|
||||
if self.use_abs_pos_embed:
|
||||
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
|
||||
self.absolute_pos_embed = nn.Parameter(
|
||||
torch.zeros(1, num_patches, self.embed_dims))
|
||||
self._register_load_state_dict_pre_hook(
|
||||
self._prepare_abs_pos_embed)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self._delete_reinit_params)
|
||||
|
||||
self.drop_after_pos = nn.Dropout(p=drop_rate)
|
||||
self.norm_eval = norm_eval
|
||||
|
||||
# stochastic depth
|
||||
total_depth = sum(self.depths)
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
|
||||
] # stochastic depth decay rule
|
||||
|
||||
self.stages = ModuleList()
|
||||
embed_dims = [self.embed_dims]
|
||||
for i, (depth, num_heads) in enumerate(zip(self.depths,
|
||||
self.num_heads)):
|
||||
if isinstance(stage_cfgs, Sequence):
|
||||
stage_cfg = stage_cfgs[i]
|
||||
else:
|
||||
stage_cfg = deepcopy(stage_cfgs)
|
||||
downsample = True if i > 0 else False
|
||||
_stage_cfg = {
|
||||
'embed_dims': embed_dims[-1],
|
||||
'depth': depth,
|
||||
'num_heads': num_heads,
|
||||
'window_size': self.window_sizes[i],
|
||||
'downsample': downsample,
|
||||
'drop_paths': dpr[:depth],
|
||||
'with_cp': with_cp,
|
||||
'pad_small_map': pad_small_map,
|
||||
'extra_norm_every_n_blocks': self.extra_norm_every_n_blocks,
|
||||
'pretrained_window_size': pretrained_window_sizes[i],
|
||||
**stage_cfg
|
||||
}
|
||||
|
||||
stage = SwinBlockV2Sequence(**_stage_cfg)
|
||||
self.stages.append(stage)
|
||||
|
||||
dpr = dpr[depth:]
|
||||
embed_dims.append(stage.out_channels)
|
||||
|
||||
for i in out_indices:
|
||||
if norm_cfg is not None:
|
||||
norm_layer = build_norm_layer(norm_cfg, embed_dims[i + 1])[1]
|
||||
else:
|
||||
norm_layer = nn.Identity()
|
||||
|
||||
self.add_module(f'norm{i}', norm_layer)
|
||||
|
||||
def init_weights(self):
|
||||
if (isinstance(self.init_cfg, dict)
|
||||
and self.init_cfg['type'] == 'Pretrained'):
|
||||
# Suppress default init if use pretrained model.
|
||||
from mmcls.utils import get_root_logger
|
||||
logger = get_root_logger()
|
||||
checkpoint = CheckpointLoader.load_checkpoint(
|
||||
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
|
||||
|
||||
if 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
# print(self.state_dict().keys())
|
||||
# print('---')
|
||||
# print(state_dict.keys())
|
||||
# import pdb; pdb.set_trace()
|
||||
load_state_dict(self, state_dict, strict=False, logger=logger)
|
||||
return
|
||||
else:
|
||||
super(SwinTransformerV2, self).init_weights()
|
||||
if self.use_abs_pos_embed:
|
||||
trunc_normal_(self.absolute_pos_embed, std=0.02)
|
||||
|
||||
def forward(self, x):
|
||||
x, hw_shape = self.patch_embed(x)
|
||||
|
||||
if self.use_abs_pos_embed:
|
||||
x = x + resize_pos_embed(
|
||||
self.absolute_pos_embed, self.patch_resolution, hw_shape,
|
||||
self.interpolate_mode, self.num_extra_tokens)
|
||||
x = self.drop_after_pos(x)
|
||||
|
||||
outs = []
|
||||
for i, stage in enumerate(self.stages):
|
||||
x, hw_shape = stage(x, hw_shape)
|
||||
if i in self.out_indices:
|
||||
norm_layer = getattr(self, f'norm{i}')
|
||||
out = norm_layer(x)
|
||||
out = out.view(-1, *hw_shape,
|
||||
stage.out_channels).permute(0, 3, 1,
|
||||
2).contiguous()
|
||||
outs.append(out)
|
||||
|
||||
return outs
|
||||
|
||||
def _freeze_stages(self):
|
||||
if self.frozen_stages >= 0:
|
||||
self.patch_embed.eval()
|
||||
for param in self.patch_embed.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
for i in range(0, self.frozen_stages + 1):
|
||||
m = self.stages[i]
|
||||
m.eval()
|
||||
for param in m.parameters():
|
||||
param.requires_grad = False
|
||||
for i in self.out_indices:
|
||||
if i <= self.frozen_stages:
|
||||
for param in getattr(self, f'norm{i}').parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def train(self, mode=True):
|
||||
super(SwinTransformerV2, self).train(mode)
|
||||
self._freeze_stages()
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
# trick: eval have effect on BatchNorm only
|
||||
if isinstance(m, _BatchNorm):
|
||||
m.eval()
|
||||
|
||||
def _prepare_abs_pos_embed(self, state_dict, prefix, *args, **kwargs):
|
||||
name = prefix + 'absolute_pos_embed'
|
||||
if name not in state_dict.keys():
|
||||
return
|
||||
|
||||
ckpt_pos_embed_shape = state_dict[name].shape
|
||||
if self.absolute_pos_embed.shape != ckpt_pos_embed_shape:
|
||||
from mmcls.utils import get_root_logger
|
||||
logger = get_root_logger()
|
||||
logger.info(
|
||||
'Resize the absolute_pos_embed shape from '
|
||||
f'{ckpt_pos_embed_shape} to {self.absolute_pos_embed.shape}.')
|
||||
|
||||
ckpt_pos_embed_shape = to_2tuple(
|
||||
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
|
||||
pos_embed_shape = self.patch_embed.init_out_size
|
||||
|
||||
state_dict[name] = resize_pos_embed(state_dict[name],
|
||||
ckpt_pos_embed_shape,
|
||||
pos_embed_shape,
|
||||
self.interpolate_mode,
|
||||
self.num_extra_tokens)
|
||||
|
||||
def _delete_reinit_params(self, state_dict, prefix, *args, **kwargs):
|
||||
# delete relative_position_index since we always re-init it
|
||||
relative_position_index_keys = [
|
||||
k for k in state_dict.keys() if 'relative_position_index' in k
|
||||
]
|
||||
for k in relative_position_index_keys:
|
||||
del state_dict[k]
|
||||
|
||||
# delete relative_coords_table since we always re-init it
|
||||
relative_position_index_keys = [
|
||||
k for k in state_dict.keys() if 'relative_coords_table' in k
|
||||
]
|
||||
for k in relative_position_index_keys:
|
||||
del state_dict[k]
|
||||
|
||||
class Proj_MHSA(nn.Module):

    def __init__(
        self,
        embed_dims,
        proj_dims,
        num_heads=16,
        batch_first=True,
        bias=True
    ):
        super().__init__()
        self.proj_in = nn.Linear(in_features=embed_dims, out_features=proj_dims)
        self.attn = MultiheadAttention(
            embed_dims=proj_dims,
            num_heads=num_heads,
            batch_first=batch_first,
            bias=bias
        )
        self.proj_out = nn.Linear(in_features=proj_dims, out_features=embed_dims)

    def forward(self, x):
        x = self.proj_in(x)
        x = self.attn(x, x, x)
        x = self.proj_out(x)
        return x
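# Illustrative sketch (not part of the model definition): Proj_MHSA projects tokens down,
# runs multi-head self-attention, and projects back, so the output keeps the input shape
# (B, L, embed_dims). The dimensions below are made up for the example.
def _proj_mhsa_shape_demo():
    block = Proj_MHSA(embed_dims=352, proj_dims=256, num_heads=16)
    tokens = torch.rand(2, 64, 352)  # (batch, tokens, embed_dims)
    out = block(tokens)
    return out.shape                 # torch.Size([2, 64, 352])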
class SwinTransformerV2MSL(SwinTransformerV2):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
if 'use_attn' in kwargs:
|
||||
self.use_attn = kwargs.pop('use_attn')
|
||||
else:
|
||||
self.use_attn = False
|
||||
if 'merge_stage' in kwargs:
|
||||
self.merge_stage = kwargs.pop('merge_stage')
|
||||
else:
|
||||
self.merge_stage = 0
|
||||
if 'with_cls_pos' in kwargs:
|
||||
self.with_cls_pos = kwargs.pop('with_cls_pos')
|
||||
else:
|
||||
self.with_cls_pos = False
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
|
||||
#self.vocabulary_token = nn.Parameter(torch.zeros(1, 1, 1, self.vocabulary_size, self.embed_dims))
|
||||
self.vocabulary_token = nn.Parameter(torch.zeros(self.vocabulary_size, self.embed_dims))
|
||||
self.vocabulary_weight = nn.Parameter(torch.zeros(1, self.patch_size * self.patch_size))
|
||||
trunc_normal_(self.mask_token, mean=0., std=.02)
|
||||
trunc_normal_(self.vocabulary_token, mean=0., std=.02)
|
||||
|
||||
if self.use_attn:
|
||||
self.attn1 = Proj_MHSA(embed_dims=352, proj_dims=256, num_heads=16, batch_first=True, bias = True)
|
||||
self.attn2 = Proj_MHSA(embed_dims=704, proj_dims=512, num_heads=16, batch_first=True, bias = True)
|
||||
self.attn3 = Proj_MHSA( embed_dims=1408, proj_dims=1024, num_heads=16, batch_first=True, bias = True)
|
||||
self.attention_blocks = [self.attn1, self.attn2, self.attn3]
|
||||
self.norm_attn = build_norm_layer(dict(type='LN'), 1408)[1]
|
||||
|
||||
def create_ann_token(self, anno_img):
|
||||
B, H, W = anno_img.shape
|
||||
ann_token = torch.index_select(self.vocabulary_token, 0, anno_img.reshape(-1)).reshape(B, H, W, -1)
|
||||
assert H % self.patch_size == 0 and W % self.patch_size == 0
|
||||
nph, npw = H // self.patch_size, W // self.patch_size
|
||||
weight = F.softmax(self.vocabulary_weight, dim=1) * self.patch_size * self.patch_size
|
||||
weight = weight.reshape(1, 1, self.patch_size, 1, self.patch_size).repeat(1, nph, 1, npw, 1).reshape(1, H, W, 1)
|
||||
ann_token = ann_token * weight
|
||||
ann_token = F.avg_pool2d(torch.einsum('BHWC->BCHW', ann_token), self.patch_size, self.patch_size)
|
||||
ann_token = torch.einsum('BCHW->BHWC', ann_token).reshape(B, nph * npw, self.embed_dims) # shape B, L, C
|
||||
return ann_token
|
||||
|
||||
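    # Shape sketch for create_ann_token (numbers illustrative): with anno_img of (B, 512, 512)
    # and patch_size 4, each label id is looked up in vocabulary_token (C = embed_dims),
    # reweighted by a learned softmax over the 4x4 positions inside every patch, then
    # average-pooled to (B, 128*128, C) so it aligns with the patch tokens from patch_embed.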
def forward(self, hr_img, anno_img, mask=None):
|
||||
x, hw_shape = self.patch_embed(hr_img)
|
||||
y = self.create_ann_token(anno_img)
|
||||
assert x.shape == y.shape
|
||||
B, L, C = y.shape
|
||||
if mask is not None:
|
||||
mask_tokens = self.mask_token.expand(B, L, -1)
|
||||
w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)
|
||||
y = y * (1. - w) + mask_tokens * w
|
||||
|
||||
if self.merge_stage == 0:
|
||||
x = (x + y) * 0.5
|
||||
else:
|
||||
x = x.reshape(B, *hw_shape, C)
|
||||
y = y.reshape(B, *hw_shape, C)
|
||||
x = torch.cat((x, y), dim=2)
|
||||
hw_shape = (hw_shape[0], hw_shape[1] * 2)
|
||||
x = x.reshape(B, -1, C)
|
||||
if self.use_abs_pos_embed:
|
||||
x = x + resize_pos_embed(
|
||||
self.absolute_pos_embed, self.patch_resolution, hw_shape,
|
||||
self.interpolate_mode, self.num_extra_tokens)
|
||||
if self.with_cls_pos:
|
||||
hw_shape_half = [hw_shape[0], hw_shape[1] // 2]
|
||||
x = x.reshape(B, *hw_shape, C)
|
||||
x1 = x[:, :, :x.shape[2]//2, :].reshape(B, -1, C)
|
||||
x2 = x[:, :, x.shape[2]//2:, :].reshape(B, -1, C)
|
||||
x1 = x1 + resize_pos_embed(
|
||||
self.absolute_pos_embed, self.patch_resolution, hw_shape_half,
|
||||
self.interpolate_mode, self.num_extra_tokens)
|
||||
x2 = x2 + resize_pos_embed(
|
||||
self.absolute_pos_embed, self.patch_resolution, hw_shape_half,
|
||||
self.interpolate_mode, self.num_extra_tokens)
|
||||
x1 = x1.reshape(B, *hw_shape_half, C)
|
||||
x2 = x2.reshape(B, *hw_shape_half, C)
|
||||
x = torch.cat((x1, x2), dim=2).reshape(B, -1, C)
|
||||
x = self.drop_after_pos(x)
|
||||
outs = []
|
||||
merge_idx = self.merge_stage - 1
|
||||
for i, stage in enumerate(self.stages):
|
||||
x, hw_shape = stage(x, hw_shape)
|
||||
if i == merge_idx:
|
||||
x = x.reshape(x.shape[0], *hw_shape, x.shape[-1]) # b,l,c -> b, h, w, c
|
||||
x = (x[:, :, :x.shape[2]//2] + x[:, :, x.shape[2]//2:]) * 0.5
|
||||
x = x.reshape(x.shape[0], -1, x.shape[-1])
|
||||
hw_shape = (hw_shape[0], hw_shape[1] // 2)
|
||||
if self.use_attn:
|
||||
if i <= len(self.attention_blocks) - 1:
|
||||
x = x + self.attention_blocks[i](x)
|
||||
if i == len(self.attention_blocks) - 1:
|
||||
x = self.norm_attn(x)
|
||||
if i in self.out_indices:
|
||||
norm_layer = getattr(self, f'norm{i}')
|
||||
out = norm_layer(x)
|
||||
out = out.view(-1, *hw_shape, stage.out_channels).permute(0, 3, 1, 2).contiguous()
|
||||
outs.append(out)
|
||||
return outs
|
||||
611
lib/models/backbones/vit.py
Normal file
611
lib/models/backbones/vit.py
Normal file
@@ -0,0 +1,611 @@
|
||||
# Copyright (c) Ant Group. All rights reserved.
|
||||
import math
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
|
||||
from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
|
||||
trunc_normal_)
|
||||
from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList,
|
||||
load_state_dict)
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
from torch.nn.modules.utils import _pair as to_2tuple
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.utils import get_root_logger
|
||||
from mmseg.models.utils.embed import PatchEmbed
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
class TransformerEncoderLayer(BaseModule):
|
||||
"""Implements one encoder layer in Vision Transformer.
|
||||
Args:
|
||||
embed_dims (int): The feature dimension.
|
||||
num_heads (int): Parallel attention heads.
|
||||
feedforward_channels (int): The hidden dimension for FFNs.
|
||||
drop_rate (float): Probability of an element to be zeroed
|
||||
after the feed forward layer. Default: 0.0.
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Default: 0.0.
|
||||
drop_path_rate (float): stochastic depth rate. Default 0.0.
|
||||
num_fcs (int): The number of fully-connected layers for FFNs.
|
||||
Default: 2.
|
||||
qkv_bias (bool): enable bias for qkv if True. Default: True
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Default: dict(type='GELU').
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN').
|
||||
batch_first (bool): Key, Query and Value are shape of
|
||||
(batch, n, embed_dim)
|
||||
or (n, batch, embed_dim). Default: True.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save
|
||||
some memory while slowing down the training speed. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
feedforward_channels,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
num_fcs=2,
|
||||
qkv_bias=True,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN'),
|
||||
batch_first=True,
|
||||
attn_cfg=dict(),
|
||||
ffn_cfg=dict(),
|
||||
with_cp=False):
|
||||
super(TransformerEncoderLayer, self).__init__()
|
||||
|
||||
self.norm1_name, norm1 = build_norm_layer(norm_cfg,
|
||||
embed_dims,
|
||||
postfix=1)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
|
||||
attn_cfg.update(
|
||||
dict(embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
attn_drop=attn_drop_rate,
|
||||
proj_drop=drop_rate,
|
||||
batch_first=batch_first,
|
||||
bias=qkv_bias))
|
||||
|
||||
self.build_attn(attn_cfg)
|
||||
|
||||
self.norm2_name, norm2 = build_norm_layer(norm_cfg,
|
||||
embed_dims,
|
||||
postfix=2)
|
||||
self.add_module(self.norm2_name, norm2)
|
||||
|
||||
ffn_cfg.update(
|
||||
dict(embed_dims=embed_dims,
|
||||
feedforward_channels=feedforward_channels,
|
||||
num_fcs=num_fcs,
|
||||
ffn_drop=drop_rate,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate)
|
||||
if drop_path_rate > 0 else None,
|
||||
act_cfg=act_cfg))
|
||||
self.build_ffn(ffn_cfg)
|
||||
self.with_cp = with_cp
|
||||
|
||||
def build_attn(self, attn_cfg):
|
||||
self.attn = MultiheadAttention(**attn_cfg)
|
||||
|
||||
def build_ffn(self, ffn_cfg):
|
||||
self.ffn = FFN(**ffn_cfg)
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
return getattr(self, self.norm1_name)
|
||||
|
||||
@property
|
||||
def norm2(self):
|
||||
return getattr(self, self.norm2_name)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
def _inner_forward(x):
|
||||
x = self.attn(self.norm1(x), identity=x)
|
||||
x = self.ffn(self.norm2(x), identity=x)
|
||||
return x
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
x = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
x = _inner_forward(x)
|
||||
return x
|
||||
|
||||
|
||||
class VisionTransformer(BaseModule):
|
||||
"""Vision Transformer.
|
||||
This backbone is the implementation of `An Image is Worth 16x16 Words:
|
||||
Transformers for Image Recognition at
|
||||
Scale <https://arxiv.org/abs/2010.11929>`_.
|
||||
Args:
|
||||
img_size (int | tuple): Input image size. Default: 224.
|
||||
patch_size (int): The patch size. Default: 16.
|
||||
in_channels (int): Number of input channels. Default: 3.
|
||||
embed_dims (int): embedding dimension. Default: 768.
|
||||
num_layers (int): depth of transformer. Default: 12.
|
||||
num_heads (int): number of attention heads. Default: 12.
|
||||
mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
|
||||
Default: 4.
|
||||
out_indices (list | tuple | int): Output from which stages.
|
||||
Default: -1.
|
||||
qkv_bias (bool): enable bias for qkv if True. Default: True.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Default 0.0
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Default 0.0
|
||||
drop_path_rate (float): stochastic depth rate. Default 0.0
|
||||
with_cls_token (bool): Whether concatenating class token into image
|
||||
tokens as transformer input. Default: True.
|
||||
output_cls_token (bool): Whether output the cls_token. If set True,
|
||||
`with_cls_token` must be True. Default: False.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN')
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Default: dict(type='GELU').
|
||||
patch_norm (bool): Whether to add a norm in PatchEmbed Block.
|
||||
Default: False.
|
||||
final_norm (bool): Whether to add a additional layer to normalize
|
||||
final feature map. Default: False.
|
||||
interpolate_mode (str): Select the interpolate mode for position
|
||||
embeding vector resize. Default: bicubic.
|
||||
num_fcs (int): The number of fully-connected layers for FFNs.
|
||||
Default: 2.
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Default: False.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save
|
||||
some memory while slowing down the training speed. Default: False.
|
||||
pretrained (str, optional): model pretrained path. Default: None.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_channels=3,
|
||||
embed_dims=768,
|
||||
num_layers=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4,
|
||||
out_indices=-1,
|
||||
qkv_bias=True,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
with_cls_token=True,
|
||||
output_cls_token=False,
|
||||
norm_cfg=dict(type='LN'),
|
||||
act_cfg=dict(type='GELU'),
|
||||
patch_norm=False,
|
||||
final_norm=False,
|
||||
interpolate_mode='bicubic',
|
||||
num_fcs=2,
|
||||
norm_eval=False,
|
||||
with_cp=False,
|
||||
use_ccd=False,
|
||||
ccd_num=0,
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
super(VisionTransformer, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
if isinstance(img_size, int):
|
||||
img_size = to_2tuple(img_size)
|
||||
elif isinstance(img_size, tuple):
|
||||
if len(img_size) == 1:
|
||||
img_size = to_2tuple(img_size[0])
|
||||
assert len(img_size) == 2, \
|
||||
f'The size of image should have length 1 or 2, ' \
|
||||
f'but got {len(img_size)}'
|
||||
|
||||
if output_cls_token:
|
||||
assert with_cls_token is True, f'with_cls_token must be True if' \
|
||||
f'set output_cls_token to True, but got {with_cls_token}'
|
||||
|
||||
assert not (init_cfg and pretrained), \
|
||||
'init_cfg and pretrained cannot be set at the same time'
|
||||
if isinstance(pretrained, str):
|
||||
warnings.warn('DeprecationWarning: pretrained is deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
elif pretrained is not None:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.interpolate_mode = interpolate_mode
|
||||
self.norm_eval = norm_eval
|
||||
self.with_cp = with_cp
|
||||
self.pretrained = pretrained
|
||||
self.embed_dims = embed_dims
|
||||
self.patch_embed = PatchEmbed(
|
||||
in_channels=in_channels,
|
||||
embed_dims=embed_dims,
|
||||
conv_type='Conv2d',
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
padding='corner',
|
||||
norm_cfg=norm_cfg if patch_norm else None,
|
||||
init_cfg=None,
|
||||
)
|
||||
|
||||
self.use_ccd = use_ccd
|
||||
self.ccd_num = ccd_num
|
||||
if self.use_ccd:
|
||||
self.ccd_embed = nn.Parameter(
|
||||
torch.rand(1, self.ccd_num, embed_dims))
|
||||
|
||||
num_patches = (img_size[0] // patch_size) * \
|
||||
(img_size[1] // patch_size)
|
||||
|
||||
self.with_cls_token = with_cls_token
|
||||
self.output_cls_token = output_cls_token
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
|
||||
        # Originally: self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dims))
        # Now an extra position is reserved for the cls token:
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dims))
|
||||
self.drop_after_pos = nn.Dropout(p=drop_rate)
|
||||
|
||||
if isinstance(out_indices, int):
|
||||
if out_indices == -1:
|
||||
out_indices = num_layers - 1
|
||||
self.out_indices = [out_indices]
|
||||
elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
|
||||
self.out_indices = out_indices
|
||||
else:
|
||||
raise TypeError('out_indices must be type of int, list or tuple')
|
||||
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, num_layers)
|
||||
] # stochastic depth decay rule
|
||||
|
||||
self.layers = ModuleList()
|
||||
for i in range(num_layers):
|
||||
self.layers.append(
|
||||
TransformerEncoderLayer(embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
feedforward_channels=mlp_ratio *
|
||||
embed_dims,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_rate=drop_rate,
|
||||
drop_path_rate=dpr[i],
|
||||
num_fcs=num_fcs,
|
||||
qkv_bias=qkv_bias,
|
||||
act_cfg=act_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
with_cp=with_cp,
|
||||
batch_first=True))
|
||||
|
||||
self.final_norm = final_norm
|
||||
if final_norm:
|
||||
self.norm1_name, norm1 = build_norm_layer(norm_cfg,
|
||||
embed_dims,
|
||||
postfix=1)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
return getattr(self, self.norm1_name)
|
||||
|
||||
def init_weights(self):
|
||||
if (isinstance(self.init_cfg, dict)
|
||||
and self.init_cfg.get('type') == 'Pretrained'):
|
||||
logger = get_root_logger()
|
||||
checkpoint = CheckpointLoader.load_checkpoint(
|
||||
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
|
||||
|
||||
if 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
|
||||
if 'pos_embed' in state_dict.keys():
|
||||
if self.pos_embed.shape != state_dict['pos_embed'].shape:
|
||||
logger.info(msg=f'Resize the pos_embed shape from '
|
||||
f'{state_dict["pos_embed"].shape} to '
|
||||
f'{self.pos_embed.shape}')
|
||||
h, w = self.img_size
|
||||
pos_size = int(
|
||||
math.sqrt(state_dict['pos_embed'].shape[1] - 1))
|
||||
state_dict['pos_embed'] = self.resize_pos_embed(
|
||||
state_dict['pos_embed'],
|
||||
(h // self.patch_size, w // self.patch_size),
|
||||
(pos_size, pos_size), self.interpolate_mode)
|
||||
|
||||
load_state_dict(self, state_dict, strict=False, logger=logger)
|
||||
elif self.init_cfg is not None:
|
||||
super(VisionTransformer, self).init_weights()
|
||||
else:
|
||||
# We only implement the 'jax_impl' initialization implemented at
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
|
||||
trunc_normal_(self.pos_embed, std=.02)
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
if self.use_ccd:
|
||||
trunc_normal_(self.ccd_embed, std=0.02)
|
||||
for n, m in self.named_modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if m.bias is not None:
|
||||
if 'ffn' in n:
|
||||
nn.init.normal_(m.bias, mean=0., std=1e-6)
|
||||
else:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
kaiming_init(m, mode='fan_in', bias=0.)
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
|
||||
constant_init(m, val=1.0, bias=0.)
|
||||
|
||||
def _pos_embeding(self, patched_img, hw_shape, pos_embed):
|
||||
"""Positioning embeding method.
|
||||
Resize the pos_embed, if the input image size doesn't match
|
||||
the training size.
|
||||
Args:
|
||||
patched_img (torch.Tensor): The patched image, it should be
|
||||
shape of [B, L1, C].
|
||||
hw_shape (tuple): The downsampled image resolution.
|
||||
pos_embed (torch.Tensor): The pos_embed weighs, it should be
|
||||
shape of [B, L2, C].
|
||||
Return:
|
||||
torch.Tensor: The pos encoded image feature.
|
||||
"""
|
||||
assert patched_img.ndim == 3 and pos_embed.ndim == 3, \
|
||||
'the shapes of patched_img and pos_embed must be [B, L, C]'
|
||||
x_len, pos_len = patched_img.shape[1], pos_embed.shape[1]
|
||||
if x_len != pos_len:
|
||||
if pos_len == (self.img_size[0] // self.patch_size) * (
|
||||
self.img_size[1] // self.patch_size) + 1:
|
||||
pos_h = self.img_size[0] // self.patch_size
|
||||
pos_w = self.img_size[1] // self.patch_size
|
||||
else:
|
||||
raise ValueError(
|
||||
'Unexpected shape of pos_embed, got {}.'.format(
|
||||
pos_embed.shape))
|
||||
pos_embed = self.resize_pos_embed(pos_embed, hw_shape,
|
||||
(pos_h, pos_w),
|
||||
self.interpolate_mode)
|
||||
return self.drop_after_pos(patched_img + pos_embed)
|
||||
|
||||
@staticmethod
|
||||
def resize_pos_embed(pos_embed, input_shape, pos_shape, mode):
|
||||
"""Resize pos_embed weights.
|
||||
Resize pos_embed using bicubic interpolate method.
|
||||
Args:
|
||||
pos_embed (torch.Tensor): Position embedding weights.
|
||||
input_shape (tuple): Tuple for (downsampled input image height,
|
||||
downsampled input image width).
|
||||
pos_shape (tuple): The resolution of downsampled origin training
|
||||
image.
|
||||
mode (str): Algorithm used for upsampling:
|
||||
``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
|
||||
``'trilinear'``. Default: ``'nearest'``
|
||||
Return:
|
||||
torch.Tensor: The resized pos_embed of shape [B, L_new, C]
|
||||
"""
|
||||
assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
|
||||
pos_h, pos_w = pos_shape
|
||||
# keep dim for easy deployment
|
||||
cls_token_weight = pos_embed[:, 0:1]
|
||||
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
|
||||
pos_embed_weight = pos_embed_weight.reshape(
|
||||
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
|
||||
pos_embed_weight = resize(pos_embed_weight,
|
||||
size=input_shape,
|
||||
align_corners=False,
|
||||
mode=mode)
|
||||
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
|
||||
pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
|
||||
return pos_embed
|
||||
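As a side note, a minimal self-contained sketch of the interpolation performed by resize_pos_embed above, assuming a ViT-style embedding with a single leading cls token; the shapes are illustrative only.

import torch
import torch.nn.functional as F

def resize_pos_embed_sketch(pos_embed, new_hw, old_hw, mode='bicubic'):
    # pos_embed: [1, 1 + old_h * old_w, C]; the first token is the cls token
    cls_tok, grid = pos_embed[:, :1], pos_embed[:, 1:]
    old_h, old_w = old_hw
    grid = grid.reshape(1, old_h, old_w, -1).permute(0, 3, 1, 2)   # [1, C, old_h, old_w]
    grid = F.interpolate(grid, size=new_hw, mode=mode, align_corners=False)
    grid = grid.flatten(2).transpose(1, 2)                         # [1, new_h * new_w, C]
    return torch.cat((cls_tok, grid), dim=1)

pe = torch.randn(1, 1 + 14 * 14, 768)                              # a 14x14 grid (224 / 16)
print(resize_pos_embed_sketch(pe, (32, 32), (14, 14)).shape)       # torch.Size([1, 1025, 768])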
|
||||
def forward(self, inputs, ccd_index=None):
|
||||
B = inputs.shape[0]
|
||||
|
||||
x, hw_shape = self.patch_embed(inputs)
|
||||
|
||||
if self.use_ccd:
|
||||
_ccd_idx = np.concatenate(ccd_index, axis=0)
|
||||
_ccd_embed = self.ccd_embed[:, _ccd_idx, :].permute(1, 0, 2)
|
||||
x = x + _ccd_embed
|
||||
|
||||
# stole cls_tokens impl from Phil Wang, thanks
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x = self._pos_embeding(x, hw_shape, self.pos_embed)
|
||||
|
||||
if not self.with_cls_token:
|
||||
# Remove class token for transformer encoder input
|
||||
x = x[:, 1:]
|
||||
|
||||
outs = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
if i == len(self.layers) - 1:
|
||||
if self.final_norm:
|
||||
x = self.norm1(x)
|
||||
if i in self.out_indices:
|
||||
if self.with_cls_token:
|
||||
# Remove class token and reshape token for decoder heads
|
||||
out = x[:, 1:]
|
||||
else:
|
||||
out = x
|
||||
B, _, C = out.shape
|
||||
out = out.reshape(B, hw_shape[0], hw_shape[1],
|
||||
C).permute(0, 3, 1, 2).contiguous()
|
||||
if self.output_cls_token:
|
||||
out = [out, x[:, 0]]
|
||||
outs.append(out)
|
||||
|
||||
return tuple(outs)
|
||||
|
||||
def train(self, mode=True):
|
||||
super(VisionTransformer, self).train(mode)
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.LayerNorm):
|
||||
m.eval()
|
||||
|
||||
|
||||
class VisionTransformerMSL(VisionTransformer):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
if 'use_attn' in kwargs:
|
||||
self.use_attn = kwargs.pop('use_attn')
|
||||
else:
|
||||
self.use_attn = False
|
||||
if 'merge_stage' in kwargs:
|
||||
self.merge_stage = kwargs.pop('merge_stage')
|
||||
else:
|
||||
self.merge_stage = 0
|
||||
if 'with_cls_pos' in kwargs:
|
||||
self.with_cls_pos = kwargs.pop('with_cls_pos')
|
||||
else:
|
||||
self.with_cls_pos = False
|
||||
|
||||
self.vocabulary_size = kwargs.pop('vocabulary_size') + 1  # add an extra entry for the ignore class
|
||||
super().__init__(**kwargs)
|
||||
self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
|
||||
img_size = kwargs.pop('img_size')
|
||||
patch_size = kwargs.pop('patch_size')
|
||||
num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dims))
|
||||
|
||||
self.vocabulary_token = nn.Parameter(torch.zeros(self.vocabulary_size, self.embed_dims))
|
||||
self.vocabulary_weight = nn.Parameter(torch.zeros(1, self.patch_size * self.patch_size))
|
||||
trunc_normal_(self.mask_token, mean=0., std=.02)
|
||||
trunc_normal_(self.vocabulary_token, mean=0., std=.02)
|
||||
|
||||
if self.use_attn:
|
||||
self.attn1 = MultiheadAttention(embed_dims=1024, num_heads=16, batch_first=True, bias = True)
|
||||
self.attn2 = MultiheadAttention(embed_dims=1024, num_heads=16, batch_first=True, bias = True)
|
||||
self.attn3 = MultiheadAttention(embed_dims=1024, num_heads=16, batch_first=True, bias = True)
|
||||
self.attention_blocks = [self.attn1, self.attn2, self.attn3]
|
||||
self.norm_attn = build_norm_layer(dict(type='LN'), 1024)[1]
|
||||
|
||||
def create_ann_token(self, anno_img):
|
||||
B, H, W = anno_img.shape
|
||||
ann_token = torch.index_select(self.vocabulary_token, 0, anno_img.reshape(-1)).reshape(B, H, W, -1)
|
||||
assert H % self.patch_size == 0 and W % self.patch_size == 0
|
||||
nph, npw = H // self.patch_size, W // self.patch_size
|
||||
weight = F.softmax(self.vocabulary_weight, dim=1) * self.patch_size * self.patch_size
|
||||
weight = weight.reshape(1, 1, self.patch_size, 1, self.patch_size).repeat(1, nph, 1, npw, 1).reshape(1, H, W, 1)
|
||||
ann_token = ann_token * weight
|
||||
ann_token = F.avg_pool2d(torch.einsum('BHWC->BCHW', ann_token), self.patch_size, self.patch_size)
|
||||
ann_token = torch.einsum('BCHW->BHWC', ann_token).reshape(B, nph * npw, self.embed_dims) # shape B, L, C
|
||||
return ann_token
|
||||
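For intuition, a rough sketch of what create_ann_token does above, with a toy vocabulary and patch size: one embedding is looked up per pixel label and the embeddings are average-pooled into one token per patch (the learned per-pixel weighting step is omitted here).

import torch
import torch.nn.functional as F

vocab = torch.randn(10, 32)                                # 10 label ids, 32-dim embeddings
anno = torch.randint(0, 10, (2, 8, 8))                     # B=2 label maps, H=W=8
patch = 4
tok = vocab[anno.reshape(-1)].reshape(2, 8, 8, 32)         # per-pixel embeddings
tok = F.avg_pool2d(tok.permute(0, 3, 1, 2), patch)         # one token per 4x4 patch
tok = tok.permute(0, 2, 3, 1).reshape(2, -1, 32)           # [B, num_patches, C]
print(tok.shape)                                           # torch.Size([2, 4, 32])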
def _pos_embeding(self, patched_img, hw_shape, pos_embed):
|
||||
"""Positioning embeding method.
|
||||
Resize the pos_embed, if the input image size doesn't match
|
||||
the training size.
|
||||
Args:
|
||||
patched_img (torch.Tensor): The patched image, it should be
|
||||
shape of [B, L1, C].
|
||||
hw_shape (tuple): The downsampled image resolution.
|
||||
pos_embed (torch.Tensor): The pos_embed weights; it should be
|
||||
shape of [B, L2, C].
|
||||
Return:
|
||||
torch.Tensor: The pos encoded image feature.
|
||||
"""
|
||||
assert patched_img.ndim == 3 and pos_embed.ndim == 3, \
|
||||
'the shapes of patched_img and pos_embed must be [B, L, C]'
|
||||
x_len, pos_len = patched_img.shape[1], pos_embed.shape[1]
|
||||
if x_len != pos_len:
|
||||
if pos_len == (self.img_size[0] // self.patch_size) * (
|
||||
self.img_size[1] // self.patch_size):
|
||||
pos_h = self.img_size[0] // self.patch_size
|
||||
pos_w = self.img_size[1] // self.patch_size
|
||||
else:
|
||||
raise ValueError(
|
||||
'Unexpected shape of pos_embed, got {}.'.format(
|
||||
pos_embed.shape))
|
||||
pos_embed = self.resize_pos_embed(pos_embed, hw_shape,
|
||||
(pos_h, pos_w),
|
||||
self.interpolate_mode)
|
||||
return self.drop_after_pos(patched_img + pos_embed)
|
||||
|
||||
@staticmethod
|
||||
def resize_pos_embed(pos_embed, input_shape, pos_shape, mode):
|
||||
"""Resize pos_embed weights.
|
||||
Resize pos_embed using bicubic interpolate method.
|
||||
Args:
|
||||
pos_embed (torch.Tensor): Position embedding weights.
|
||||
input_shape (tuple): Tuple for (downsampled input image height,
|
||||
downsampled input image width).
|
||||
pos_shape (tuple): The resolution of downsampled origin training
|
||||
image.
|
||||
mode (str): Algorithm used for upsampling:
|
||||
``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
|
||||
``'trilinear'``. Default: ``'nearest'``
|
||||
Return:
|
||||
torch.Tensor: The resized pos_embed of shape [B, L_new, C]
|
||||
"""
|
||||
assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
|
||||
pos_h, pos_w = pos_shape
|
||||
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
|
||||
pos_embed_weight = pos_embed_weight.reshape(
|
||||
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
|
||||
pos_embed_weight = resize(pos_embed_weight,
|
||||
size=input_shape,
|
||||
align_corners=False,
|
||||
mode=mode)
|
||||
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
|
||||
pos_embed = pos_embed_weight # torch.cat((cls_token_weight, pos_embed_weight), dim=1)
|
||||
return pos_embed
|
||||
|
||||
def forward(self, x, y, mask=None):
|
||||
x, hw_shape = self.patch_embed(x)
|
||||
y = self.create_ann_token(y)
|
||||
assert x.shape == y.shape
|
||||
B, L, C = y.shape
|
||||
if mask is not None:
|
||||
mask_tokens = self.mask_token.expand(B, L, -1)
|
||||
w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)
|
||||
y = y * (1. - w) + mask_tokens * w
|
||||
if self.merge_stage == 0:
|
||||
x = (x + y) * 0.5
|
||||
else:
|
||||
x = x.reshape(B, *hw_shape, C)
|
||||
y = y.reshape(B, *hw_shape, C)
|
||||
x = torch.cat((x, y), dim=2)
|
||||
hw_shape = (hw_shape[0], hw_shape[1] * 2)
|
||||
x = x.reshape(B, -1, C)
|
||||
x = self._pos_embeding(x, hw_shape, self.pos_embed)
|
||||
merge_idx = self.merge_stage - 1
|
||||
outs = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
if i == merge_idx:
|
||||
x = x.reshape(x.shape[0], *hw_shape, x.shape[-1]) # b,l,c -> b, h, w, c
|
||||
x = (x[:, :, :x.shape[2]//2] + x[:, :, x.shape[2]//2:]) * 0.5
|
||||
x = x.reshape(x.shape[0], -1, x.shape[-1])
|
||||
hw_shape = (hw_shape[0], hw_shape[1] // 2)
|
||||
if self.use_attn:
|
||||
if i <= len(self.attention_blocks) - 1:
|
||||
x = x + self.attention_blocks[i](x)
|
||||
if i == len(self.attention_blocks) - 1:
|
||||
x = self.norm_attn(x)  # could this cause a conflict?
|
||||
if (not self.use_attn) and (i == len(self.layers) - 1):
|
||||
if self.final_norm:
|
||||
x = self.norm1(x)  # could this cause a conflict?
|
||||
if i in self.out_indices:
|
||||
if self.with_cls_token:
|
||||
# Remove class token and reshape token for decoder heads
|
||||
out = x[:, 1:]
|
||||
else:
|
||||
out = x
|
||||
B, _, C = out.shape
|
||||
out = out.reshape(B, hw_shape[0], hw_shape[1],
|
||||
C).permute(0, 3, 1, 2).contiguous()
|
||||
if self.output_cls_token:
|
||||
out = [out, x[:, 0]]
|
||||
outs.append(out)
|
||||
|
||||
return tuple(outs)
|
||||
15
lib/models/heads/__init__.py
Normal file
15
lib/models/heads/__init__.py
Normal file
@@ -0,0 +1,15 @@
from .uper_head import UPerHead
from .up_head import UPHead

__all__ = [
    'UPerHead', 'UPHead'
]

type_mapping = {
    'UPerHead': UPerHead,
    'UPHead': UPHead
}


def build_head(type, **kwargs):
    return type_mapping[type](**kwargs)
201
lib/models/heads/decode_head.py
Normal file
201
lib/models/heads/decode_head.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import warnings
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.runner import BaseModule, auto_fp16
|
||||
|
||||
from mmseg.core import build_pixel_sampler
|
||||
from mmseg.ops import resize
|
||||
|
||||
|
||||
class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
|
||||
"""Base class for BaseDecodeHead.
|
||||
|
||||
Args:
|
||||
in_channels (int|Sequence[int]): Input channels.
|
||||
channels (int): Channels after modules, before conv_seg.
|
||||
num_classes (int): Number of classes.
|
||||
out_channels (int): Output channels of conv_seg.
|
||||
threshold (float): Threshold for binary segmentation in the case of
|
||||
`out_channels==1`. Default: None.
|
||||
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
|
||||
conv_cfg (dict|None): Config of conv layers. Default: None.
|
||||
norm_cfg (dict|None): Config of norm layers. Default: None.
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU')
|
||||
in_index (int|Sequence[int]): Input feature index. Default: -1
|
||||
input_transform (str|None): Transformation type of input features.
|
||||
Options: 'resize_concat', 'multiple_select', None.
|
||||
'resize_concat': Multiple feature maps will be resized to the
|
||||
same size as the first one and then concatenated together.
|
||||
Usually used in FCN head of HRNet.
|
||||
'multiple_select': Multiple feature maps will be bundled into
|
||||
a list and passed into decode head.
|
||||
None: Only one select feature map is allowed.
|
||||
Default: None.
|
||||
loss_decode (dict | Sequence[dict]): Config of decode loss.
|
||||
The `loss_name` is property of corresponding loss function which
|
||||
could be shown in training log. If you want this loss
|
||||
item to be included into the backward graph, `loss_` must be the
|
||||
prefix of the name. Defaults to 'loss_ce'.
|
||||
e.g. dict(type='CrossEntropyLoss'),
|
||||
[dict(type='CrossEntropyLoss', loss_name='loss_ce'),
|
||||
dict(type='DiceLoss', loss_name='loss_dice')]
|
||||
Default: dict(type='CrossEntropyLoss').
|
||||
ignore_index (int | None): The label index to be ignored. When using
|
||||
masked BCE loss, ignore_index should be set to None. Default: 255.
|
||||
sampler (dict|None): The config of segmentation map sampler.
|
||||
Default: None.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
channels,
|
||||
*,
|
||||
num_classes,
|
||||
out_channels=None,
|
||||
threshold=None,
|
||||
dropout_ratio=0.1,
|
||||
conv_cfg=None,
|
||||
norm_cfg=None,
|
||||
act_cfg=dict(type='ReLU'),
|
||||
in_index=-1,
|
||||
input_transform=None,
|
||||
sampler=None,
|
||||
align_corners=False,
|
||||
init_cfg=dict(type='Normal',
|
||||
std=0.01,
|
||||
override=dict(name='conv_seg'))):
|
||||
super(BaseDecodeHead, self).__init__(init_cfg)
|
||||
self._init_inputs(in_channels, in_index, input_transform)
|
||||
self.channels = channels
|
||||
self.dropout_ratio = dropout_ratio
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.in_index = in_index
|
||||
|
||||
self.align_corners = align_corners
|
||||
|
||||
if out_channels is None:
|
||||
if num_classes == 2:
|
||||
warnings.warn('For binary segmentation, we suggest using '
'`out_channels = 1` to define the output '
'channels of the segmentor, and use `threshold` '
'to convert seg_logits into a prediction by '
'applying a threshold')
|
||||
out_channels = num_classes
|
||||
|
||||
if out_channels != num_classes and out_channels != 1:
|
||||
raise ValueError(
'out_channels should be equal to num_classes, '
'except for binary segmentation, where out_channels == 1 and '
f'num_classes == 2; but got out_channels={out_channels} '
f'and num_classes={num_classes}')
|
||||
|
||||
if out_channels == 1 and threshold is None:
|
||||
threshold = 0.3
|
||||
warnings.warn('threshold is not defined for binary segmentation, '
'and defaults to 0.3')
|
||||
self.num_classes = num_classes
|
||||
self.out_channels = out_channels
|
||||
self.threshold = threshold
|
||||
|
||||
if sampler is not None:
|
||||
self.sampler = build_pixel_sampler(sampler, context=self)
|
||||
else:
|
||||
self.sampler = None
|
||||
|
||||
self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
|
||||
if dropout_ratio > 0:
|
||||
self.dropout = nn.Dropout2d(dropout_ratio)
|
||||
else:
|
||||
self.dropout = None
|
||||
self.fp16_enabled = False
|
||||
|
||||
def extra_repr(self):
|
||||
"""Extra repr."""
|
||||
s = f'input_transform={self.input_transform}, ' \
|
||||
f'align_corners={self.align_corners}'
|
||||
return s
|
||||
|
||||
def _init_inputs(self, in_channels, in_index, input_transform):
|
||||
"""Check and initialize input transforms.
|
||||
|
||||
The in_channels, in_index and input_transform must match.
|
||||
Specifically, when input_transform is None, only single feature map
|
||||
will be selected. So in_channels and in_index must be of type int.
|
||||
When input_transform is not None, in_channels and in_index must be
list or tuple, with the same length.
|
||||
|
||||
Args:
|
||||
in_channels (int|Sequence[int]): Input channels.
|
||||
in_index (int|Sequence[int]): Input feature index.
|
||||
input_transform (str|None): Transformation type of input features.
|
||||
Options: 'resize_concat', 'multiple_select', None.
|
||||
'resize_concat': Multiple feature maps will be resized to the
|
||||
same size as the first one and then concatenated together.
|
||||
Usually used in FCN head of HRNet.
|
||||
'multiple_select': Multiple feature maps will be bundled into
|
||||
a list and passed into decode head.
|
||||
None: Only one select feature map is allowed.
|
||||
"""
|
||||
|
||||
if input_transform is not None:
|
||||
assert input_transform in ['resize_concat', 'multiple_select']
|
||||
self.input_transform = input_transform
|
||||
self.in_index = in_index
|
||||
if input_transform is not None:
|
||||
assert isinstance(in_channels, (list, tuple))
|
||||
assert isinstance(in_index, (list, tuple))
|
||||
assert len(in_channels) == len(in_index)
|
||||
if input_transform == 'resize_concat':
|
||||
self.in_channels = sum(in_channels)
|
||||
else:
|
||||
self.in_channels = in_channels
|
||||
else:
|
||||
assert isinstance(in_channels, int)
|
||||
assert isinstance(in_index, int)
|
||||
self.in_channels = in_channels
|
||||
|
||||
def _transform_inputs(self, inputs):
|
||||
"""Transform inputs for decoder.
|
||||
|
||||
Args:
|
||||
inputs (list[Tensor]): List of multi-level img features.
|
||||
|
||||
Returns:
|
||||
Tensor: The transformed inputs
|
||||
"""
|
||||
|
||||
if self.input_transform == 'resize_concat':
|
||||
inputs = [inputs[i] for i in self.in_index]
|
||||
upsampled_inputs = [
|
||||
resize(input=x,
|
||||
size=inputs[0].shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners) for x in inputs
|
||||
]
|
||||
inputs = torch.cat(upsampled_inputs, dim=1)
|
||||
elif self.input_transform == 'multiple_select':
|
||||
inputs = [inputs[i] for i in self.in_index]
|
||||
else:
|
||||
inputs = inputs[self.in_index]
|
||||
|
||||
return inputs
|
||||
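A stripped-down sketch of the 'resize_concat' branch above: the selected feature maps are resized to the first map's resolution and concatenated along the channel dimension (the shapes here are invented).

import torch
import torch.nn.functional as F

feats = [torch.randn(1, 8, 32, 32), torch.randn(1, 16, 16, 16), torch.randn(1, 32, 8, 8)]
in_index = [0, 1, 2]
picked = [feats[i] for i in in_index]
upsampled = [F.interpolate(f, size=picked[0].shape[2:], mode='bilinear', align_corners=False)
             for f in picked]
merged = torch.cat(upsampled, dim=1)
print(merged.shape)   # torch.Size([1, 56, 32, 32])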
|
||||
@auto_fp16()
|
||||
@abstractmethod
|
||||
def forward(self, inputs):
|
||||
"""Placeholder of forward function."""
|
||||
pass
|
||||
|
||||
def cls_seg(self, feat):
|
||||
"""Classify each pixel."""
|
||||
if self.dropout is not None:
|
||||
feat = self.dropout(feat)
|
||||
output = self.conv_seg(feat)
|
||||
return output
|
||||
60
lib/models/heads/psp_head.py
Normal file
60
lib/models/heads/psp_head.py
Normal file
@@ -0,0 +1,60 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class PPM(nn.ModuleList):
|
||||
"""Pooling Pyramid Module used in PSPNet.
|
||||
|
||||
Args:
|
||||
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
||||
Module.
|
||||
in_channels (int): Input channels.
|
||||
channels (int): Channels after modules, before conv_seg.
|
||||
conv_cfg (dict|None): Config of conv layers.
|
||||
norm_cfg (dict|None): Config of norm layers.
|
||||
act_cfg (dict): Config of activation layers.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
"""
|
||||
|
||||
def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
|
||||
act_cfg, align_corners, **kwargs):
|
||||
super(PPM, self).__init__()
|
||||
self.pool_scales = pool_scales
|
||||
self.align_corners = align_corners
|
||||
self.in_channels = in_channels
|
||||
self.channels = channels
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
for pool_scale in pool_scales:
|
||||
self.append(
|
||||
nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d(pool_scale),
|
||||
ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
**kwargs)))
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
ppm_outs = []
|
||||
for ppm in self:
|
||||
ppm_out = ppm(x)
|
||||
ppm_out = ppm_out.to(torch.float32)
|
||||
upsampled_ppm_out = resize(
|
||||
ppm_out,
|
||||
size=x.size()[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
upsampled_ppm_out = upsampled_ppm_out.to(torch.bfloat16)
|
||||
ppm_outs.append(upsampled_ppm_out)
|
||||
return ppm_outs
|
||||
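To illustrate what the PPM above produces, a bare-bones pyramid pooling pass without the ConvModule wrapper; the channel count and pooling scales are arbitrary.

import torch
import torch.nn.functional as F

x = torch.randn(1, 64, 32, 32)
outs = []
for scale in (1, 2, 3, 6):
    pooled = F.adaptive_avg_pool2d(x, scale)                     # 64 x scale x scale
    outs.append(F.interpolate(pooled, size=x.shape[2:], mode='bilinear', align_corners=False))
print([o.shape for o in outs])   # four tensors of shape [1, 64, 32, 32]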
52
lib/models/heads/up_head.py
Normal file
52
lib/models/heads/up_head.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import torch.nn as nn
|
||||
from collections import OrderedDict
|
||||
from mmcv.cnn.utils.weight_init import (kaiming_init, trunc_normal_)
|
||||
from mmcv.runner import (CheckpointLoader, load_state_dict)
|
||||
from mmseg.utils import get_root_logger
|
||||
|
||||
|
||||
class UPHead(nn.Module):
|
||||
|
||||
def __init__(self, in_dim, out_dim, up_scale, init_cfg=None):
|
||||
super().__init__()
|
||||
self.decoder = nn.Sequential(
|
||||
nn.Conv2d(in_channels=in_dim,
|
||||
out_channels=up_scale**2 * out_dim,
|
||||
kernel_size=1),
|
||||
nn.PixelShuffle(up_scale),
|
||||
)
|
||||
self.init_cfg = init_cfg
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if (isinstance(self.init_cfg, dict)
|
||||
and self.init_cfg.get('type') == 'Pretrained'):
|
||||
logger = get_root_logger()
|
||||
|
||||
checkpoint = CheckpointLoader.load_checkpoint(
|
||||
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
|
||||
|
||||
if 'state_dict' in checkpoint:
|
||||
_state_dict = checkpoint['state_dict']
|
||||
else:
|
||||
_state_dict = checkpoint
|
||||
|
||||
state_dict = OrderedDict()
|
||||
for k, v in _state_dict.items():
|
||||
if k.startswith('backbone.'):
|
||||
state_dict[k[9:]] = v
|
||||
else:
|
||||
state_dict[k] = v
|
||||
print(f'loading weight: {self.init_cfg["checkpoint"]}')
|
||||
load_state_dict(self, state_dict, strict=False, logger=logger)
|
||||
else:
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
kaiming_init(m, mode='fan_in', bias=0.)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.decoder(x)
|
||||
return x
|
||||
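A quick sketch of the Conv2d + PixelShuffle decoder above: a 1x1 conv expands the channels by up_scale**2, and PixelShuffle rearranges them into spatial resolution (all sizes here are examples).

import torch
import torch.nn as nn

up_scale, in_dim, out_dim = 4, 256, 3
decoder = nn.Sequential(
    nn.Conv2d(in_dim, up_scale ** 2 * out_dim, kernel_size=1),
    nn.PixelShuffle(up_scale),
)
x = torch.randn(1, in_dim, 16, 16)
print(decoder(x).shape)   # torch.Size([1, 3, 64, 64])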
130
lib/models/heads/uper_head.py
Normal file
130
lib/models/heads/uper_head.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# coding: utf-8
|
||||
# Copyright (c) Ant Group. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
from .psp_head import PPM
|
||||
|
||||
|
||||
class UPerHead(BaseDecodeHead):
|
||||
"""Unified Perceptual Parsing for Scene Understanding.
|
||||
|
||||
This head is the implementation of `UPerNet
|
||||
<https://arxiv.org/abs/1807.10221>`_.
|
||||
|
||||
Args:
|
||||
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
||||
Module applied on the last feature. Default: (1, 2, 3, 6).
|
||||
"""
|
||||
|
||||
def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
|
||||
super(UPerHead, self).__init__(
|
||||
input_transform='multiple_select', **kwargs)
|
||||
# PSP Module
|
||||
self.psp_modules = PPM(
|
||||
pool_scales,
|
||||
self.in_channels[-1],
|
||||
self.channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
align_corners=self.align_corners)
|
||||
self.bottleneck = ConvModule(
|
||||
self.in_channels[-1] + len(pool_scales) * self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
# FPN Module
|
||||
self.lateral_convs = nn.ModuleList()
|
||||
self.fpn_convs = nn.ModuleList()
|
||||
for in_channels in self.in_channels[:-1]: # skip the top layer
|
||||
l_conv = ConvModule(
|
||||
in_channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
inplace=False)
|
||||
fpn_conv = ConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
inplace=False)
|
||||
self.lateral_convs.append(l_conv)
|
||||
self.fpn_convs.append(fpn_conv)
|
||||
|
||||
self.fpn_bottleneck = ConvModule(
|
||||
len(self.in_channels) * self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def psp_forward(self, inputs):
|
||||
"""Forward function of PSP module."""
|
||||
x = inputs[-1]
|
||||
psp_outs = [x]
|
||||
# breakpoint()
|
||||
psp_outs.extend(self.psp_modules(x))
|
||||
psp_outs = torch.cat(psp_outs, dim=1)
|
||||
output = self.bottleneck(psp_outs)
|
||||
|
||||
return output
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
# breakpoint()
|
||||
inputs = self._transform_inputs(inputs)
|
||||
|
||||
# build laterals
|
||||
laterals = [
|
||||
lateral_conv(inputs[i])
|
||||
for i, lateral_conv in enumerate(self.lateral_convs)
|
||||
]
|
||||
|
||||
laterals.append(self.psp_forward(inputs))
|
||||
|
||||
# build top-down path
|
||||
used_backbone_levels = len(laterals)
|
||||
for i in range(used_backbone_levels - 1, 0, -1):
|
||||
prev_shape = laterals[i - 1].shape[2:]
|
||||
laterals[i] = laterals[i].type(torch.float32)
|
||||
laterals[i - 1] = laterals[i - 1] + resize(
|
||||
laterals[i],
|
||||
size=prev_shape,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
|
||||
# build outputs
|
||||
fpn_outs = [
|
||||
self.fpn_convs[i](laterals[i])
|
||||
for i in range(used_backbone_levels - 1)
|
||||
]
|
||||
# append psp feature
|
||||
fpn_outs.append(laterals[-1])
|
||||
|
||||
for i in range(used_backbone_levels - 1, 0, -1):
|
||||
fpn_outs[i] = fpn_outs[i].type(torch.float32)
|
||||
fpn_outs[i] = resize(
|
||||
fpn_outs[i],
|
||||
size=fpn_outs[0].shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
fpn_outs = torch.cat(fpn_outs, dim=1)
|
||||
output = self.fpn_bottleneck(fpn_outs)
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
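For clarity, the top-down fusion loop above in isolation: each lateral map receives the next-coarser map, upsampled to its own resolution (the channel count and sizes are placeholders).

import torch
import torch.nn.functional as F

laterals = [torch.randn(1, 16, 32, 32), torch.randn(1, 16, 16, 16), torch.randn(1, 16, 8, 8)]
for i in range(len(laterals) - 1, 0, -1):
    prev_shape = laterals[i - 1].shape[2:]
    laterals[i - 1] = laterals[i - 1] + F.interpolate(
        laterals[i], size=prev_shape, mode='bilinear', align_corners=False)
print([l.shape for l in laterals])   # resolutions unchanged; finer levels now carry coarser context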
4
lib/models/losses/__init__.py
Normal file
4
lib/models/losses/__init__.py
Normal file
@@ -0,0 +1,4 @@
from .modality_vae_loss import ModalityVAELoss
from .recon_anno_loss import RecLoss

__all__ = ["ModalityVAELoss", "RecLoss"]
46
lib/models/losses/modality_vae_loss.py
Normal file
46
lib/models/losses/modality_vae_loss.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# Copyright (c) Ant Group and its affiliates.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from antmmf.common.registry import registry
|
||||
|
||||
|
||||
@registry.register_loss("ModalityVAELoss")
|
||||
class ModalityVAELoss(nn.Module):
|
||||
def __init__(self, **params):
|
||||
super().__init__()
|
||||
self.weight = params.pop("weight")
|
||||
|
||||
def compute_rec_loss(self, x_in, x_out, modal_flag):
|
||||
loss_per_pixel = F.mse_loss(x_in, x_out, reduction='none')
|
||||
loss_b = torch.mean(loss_per_pixel, dim=[1, 2, 3])
|
||||
return torch.sum(loss_b * modal_flag)/ (modal_flag.sum() + 1e-6)
|
||||
|
||||
def forward(self, sample_list, output, *args, **kwargs):
|
||||
vae_out = output["vae_out"]
|
||||
feat_hr = vae_out['input_hr']
|
||||
feat_s2 = vae_out['input_s2']
|
||||
feat_s1 = vae_out['input_s1']
|
||||
|
||||
g_hr = vae_out['g_hr']
|
||||
g_s2 = vae_out['g_s2']
|
||||
g_s1 = vae_out['g_s1']
|
||||
|
||||
# process modality flags
|
||||
modality_info = vae_out['modality_info']
|
||||
B_M, L_M = modality_info.shape
|
||||
|
||||
modality_hr = modality_info[:,0]
|
||||
modality_s2 = modality_info[:,1]
|
||||
modality_s1 = modality_info[:,2]
|
||||
|
||||
######## rec losses ########
|
||||
loss_xent = self.compute_rec_loss(g_hr, feat_hr, modality_hr) \
|
||||
+ self.compute_rec_loss(g_s2, feat_s2, modality_s2) \
|
||||
+ self.compute_rec_loss(g_s1, feat_s1, modality_s1)
|
||||
|
||||
|
||||
loss_quant = vae_out["loss_quant"]
|
||||
total_loss = loss_xent / 3 + loss_quant
|
||||
return total_loss * self.weight
|
||||
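For intuition, a compact sketch of compute_rec_loss above: a per-sample MSE that only counts samples whose modality flag is set (the names and shapes here are made up).

import torch
import torch.nn.functional as F

x_in = torch.randn(4, 3, 8, 8)
x_out = torch.randn(4, 3, 8, 8)
flag = torch.tensor([1., 1., 0., 0.])            # only the first two samples have this modality

per_pixel = F.mse_loss(x_in, x_out, reduction='none')
per_sample = per_pixel.mean(dim=[1, 2, 3])       # one scalar per sample
loss = (per_sample * flag).sum() / (flag.sum() + 1e-6)
print(loss.item())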
89
lib/models/losses/recon_anno_loss.py
Normal file
89
lib/models/losses/recon_anno_loss.py
Normal file
@@ -0,0 +1,89 @@
|
||||
# Copyright (c) Ant Group and its affiliates.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from antmmf.common.registry import registry
|
||||
import torch.nn.functional as F
|
||||
|
||||
@registry.register_loss("RecLoss")
|
||||
class RecLoss(nn.Module):
|
||||
def __init__(self, **params):
|
||||
super().__init__()
|
||||
self.weight = params.pop("weight")
|
||||
self.patch_size = params.pop("patch_size")
|
||||
self.eps = torch.finfo(torch.bfloat16).eps
|
||||
self.pred_key = params.pop("pred_key")
|
||||
self.vocabulary_size = params.pop("vocabulary_size") + 1
|
||||
self.mask_key = params.pop("mask_key")
|
||||
self.target_key = params.pop("target_key")
|
||||
self.feature_merged = params.pop("feature_merged")
|
||||
self.cnt_train = 0
|
||||
self.cnt_val = 0
|
||||
self.use_bg = params.pop("use_bg")
|
||||
if "use_all_patch" in params:
|
||||
self.use_all_patch = params.pop("use_all_patch")
|
||||
else:
|
||||
self.use_all_patch = False
|
||||
if "balance" in params:
|
||||
self.balance = params.pop("balance")
|
||||
else:
|
||||
self.balance = False
|
||||
if "sim_regularization" in params:
|
||||
self.sim_regularization = params.pop("sim_regularization")
|
||||
else:
|
||||
self.sim_regularization = False
|
||||
|
||||
def unpatchify(self, x):
|
||||
"""
|
||||
x: (N, L, patch_size**2)
|
||||
imgs: (N, H, W)
|
||||
"""
|
||||
p = self.patch_size
|
||||
w = int((x.shape[1]*0.5)**.5)
|
||||
h = w * 2
|
||||
assert h * w == x.shape[1]
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, p))
|
||||
x = torch.einsum('nhwpq->nhpwq', x)
|
||||
imgs = x.reshape(shape=(x.shape[0], h * p, w * p))
|
||||
return imgs
|
||||
|
||||
|
||||
def forward(self, sample_list, output, *args, **kwargs):
|
||||
pred = output[self.pred_key] # B, C, H, W
|
||||
target = output[self.target_key] # B, H, W
|
||||
mask = output[self.mask_key]
|
||||
b_mask, h_mask, w_mask = mask.shape
|
||||
mask = mask.reshape((b_mask, h_mask*w_mask))
|
||||
mask = mask[:, :, None].repeat(1, 1, self.patch_size**2)
|
||||
mask = self.unpatchify(mask)
|
||||
|
||||
if not self.use_bg:
|
||||
valid = sample_list['valid']
|
||||
mask = mask * valid
|
||||
|
||||
loss = F.cross_entropy(pred, target, reduction="none")
|
||||
|
||||
if self.balance:
|
||||
if self.use_all_patch:
|
||||
loss_pos = loss[target > 0].sum() / ((target > 0).sum() + 1e-6)
|
||||
loss_neg = loss[target == 0].sum() / ((target == 0).sum() + 1e-6)
|
||||
loss = (loss_pos + loss_neg) * 0.5
|
||||
else:
|
||||
loss_pos = loss[(target > 0) & (mask == 1)].sum() / (((target > 0) & (mask == 1)).sum() + 1e-6)
|
||||
loss_neg = loss[(target == 0) & (mask == 1)].sum() / (((target == 0) & (mask == 1)).sum() + 1e-6)
|
||||
loss = (loss_pos + loss_neg) * 0.5
|
||||
else:
|
||||
if self.use_all_patch:
|
||||
loss = loss.mean()
|
||||
else:
|
||||
loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
|
||||
if self.sim_regularization:
|
||||
vocabulary_token = output['vocabulary_token']
|
||||
voca_normed = F.normalize(vocabulary_token, 2, 1)
|
||||
similarity_matrix = 1 + torch.einsum('nd,md->nm', voca_normed, voca_normed)
|
||||
num = voca_normed.shape[0]
|
||||
index = torch.triu(voca_normed.new_ones(num, num), diagonal=1).type(torch.bool)
|
||||
loss_reg = similarity_matrix[index].mean()
|
||||
return loss * self.weight + loss_reg * 0.05
|
||||
return loss * self.weight
|
||||
|
||||
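A small sketch of the masked cross-entropy used above: the per-pixel loss is averaged only over reconstructed (masked) positions; the shapes and mask are illustrative.

import torch
import torch.nn.functional as F

pred = torch.randn(2, 5, 16, 16)                 # B, num_classes, H, W
target = torch.randint(0, 5, (2, 16, 16))        # B, H, W
mask = (torch.rand(2, 16, 16) > 0.5).float()     # 1 where a patch was masked out

loss = F.cross_entropy(pred, target, reduction='none')   # B, H, W
loss = (loss * mask).sum() / mask.sum().clamp(min=1)     # mean loss on masked pixels only
print(loss.item())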
4
lib/models/metrics/__init__.py
Normal file
4
lib/models/metrics/__init__.py
Normal file
@@ -0,0 +1,4 @@
from .sem_metrics import SemMetric

__all__ = ["SemMetric"]
93
lib/models/metrics/sem_metrics.py
Normal file
93
lib/models/metrics/sem_metrics.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# coding: utf-8
|
||||
# Copyright (c) Ant Group. All rights reserved.
|
||||
import torch
|
||||
from torch.distributed import all_reduce, ReduceOp
|
||||
from antmmf.common.registry import registry
|
||||
from antmmf.modules.metrics.base_metric import BaseMetric
|
||||
|
||||
@registry.register_metric("sem_metric")
|
||||
class SemMetric(BaseMetric):
|
||||
"""Segmentation metrics used in evaluation phase.
|
||||
|
||||
Args:
|
||||
name (str): Name of the metric.
|
||||
eval_type(str): 3 types are supported: 'mIoU', 'mDice', 'mFscore'
|
||||
result_field(str): key of predicted results in output dict
|
||||
target_field(str): key of ground truth in output dict
|
||||
ignore_index(int): class value will be ignored in evaluation
|
||||
num_cls(int): total number of categories in evaluation
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
name="dummy_metric", **kwargs
|
||||
):
|
||||
super().__init__(name)
|
||||
self.reset()
|
||||
|
||||
def calculate(self, sample_list, model_output, *args, **kwargs):
|
||||
"""Calculate Intersection and Union for a batch.
|
||||
|
||||
Args:
|
||||
sample_list (Sample_List): data which contains ground truth segmentation maps
|
||||
model_output (dict): data which contains prediction segmentation maps
|
||||
Returns:
|
||||
torch.Tensor: The intersection of prediction and ground truth histogram
|
||||
on all classes.
|
||||
torch.Tensor: The union of prediction and ground truth histogram on all
|
||||
classes.
|
||||
torch.Tensor: The prediction histogram on all classes.
|
||||
torch.Tensor: The ground truth histogram on all classes.
|
||||
"""
|
||||
|
||||
return torch.tensor(0).float()
|
||||
|
||||
def reset(self):
|
||||
""" initialized all attributes value before evaluation
|
||||
|
||||
"""
|
||||
self.total_mask_mae = 0
|
||||
self.total_num = torch.tensor(0)
|
||||
|
||||
def collect(self, sample_list, model_output, *args, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
sample_list(Sample_List): data which contains ground truth segmentation maps
|
||||
model_output (Dict): Dict returned by model, that contains two modalities
|
||||
Returns:
|
||||
torch.FloatTensor: Accuracy
|
||||
"""
|
||||
batch_mask_mae = \
|
||||
self.calculate(sample_list, model_output, *args, **kwargs)
|
||||
self.total_mask_mae += batch_mask_mae
|
||||
self.total_num += 1
|
||||
|
||||
def format(self, *args):
|
||||
""" Format evaluated metrics for profile.
|
||||
|
||||
Returns:
|
||||
dict: dict of all evaluated metrics.
|
||||
"""
|
||||
output_metric = dict()
|
||||
# if self.eval_type == 'mae':
|
||||
mae = args[0]
|
||||
output_metric['mae'] = mae.item()
|
||||
return output_metric
|
||||
|
||||
def summarize(self, *args, **kwargs):
|
||||
"""This method is used to calculate the overall metric.
|
||||
|
||||
Returns:
|
||||
dict: dict of all evaluated metrics.
|
||||
|
||||
"""
|
||||
# if self.eval_type == 'mae':
|
||||
mae = self.total_mask_mae / (self.total_num)
|
||||
return self.format(mae)
|
||||
|
||||
def all_reduce(self):
|
||||
total_number = torch.stack([
|
||||
self.total_mask_mae, self.total_num
|
||||
]).cuda()
|
||||
all_reduce(total_number, op=ReduceOp.SUM)
|
||||
self.total_mask_mae = total_number[0].cpu()
|
||||
self.total_num = total_number[1].cpu()
|
||||
13
lib/models/necks/__init__.py
Normal file
13
lib/models/necks/__init__.py
Normal file
@@ -0,0 +1,13 @@
from .transformer_encoder import TransformerEncoder
from .modality_completion import ModalityCompletion

__all__ = ['TransformerEncoder', 'ModalityCompletion']

type_mapping = {
    'TransformerEncoder': TransformerEncoder,
    'ModalityCompletion': ModalityCompletion
}


def build_neck(type, **kwargs):
    return type_mapping[type](**kwargs)
212
lib/models/necks/modality_completion.py
Normal file
212
lib/models/necks/modality_completion.py
Normal file
@@ -0,0 +1,212 @@
|
||||
# Copyright (c) AntGroup. All rights reserved.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class BFloat16UpsampleNearest2d(nn.Module):
|
||||
def __init__(self, scale_factor, mode='bilinear'):
|
||||
super().__init__()
|
||||
self.scale_factor = scale_factor
|
||||
self.mode = mode
|
||||
|
||||
def forward(self, x):
|
||||
x_float = x.float()
|
||||
upsampled_x = F.interpolate(x_float, scale_factor=self.scale_factor, mode=self.mode)
|
||||
return upsampled_x.to(x.dtype)
|
||||
|
||||
class ConvVQVAEv2(nn.Module):
|
||||
def __init__(self, input_shape, conv_dim, z_dim, num_tokens=8192, temp=0.9):
|
||||
super().__init__()
|
||||
self.z_dim = z_dim
|
||||
self.conv_dim = conv_dim # 256
|
||||
self.input_shape = input_shape # 256
|
||||
self.temp = temp
|
||||
# code book
|
||||
self.codebook = nn.Embedding(num_tokens, z_dim)
|
||||
# encoder
|
||||
self.relu = nn.LeakyReLU()
|
||||
self.pool = nn.AvgPool2d(2)
|
||||
self.conv1 = nn.Conv2d(input_shape[0], conv_dim, 5, stride=1, padding=2)
|
||||
self.enc_block1 = nn.Sequential(
|
||||
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
|
||||
nn.LeakyReLU(),
|
||||
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
|
||||
nn.LeakyReLU(),
|
||||
)
|
||||
self.gamma_1 = nn.Parameter(0.001 * torch.ones((1, conv_dim, 1, 1)))
|
||||
self.enc_block2 = nn.Sequential(
|
||||
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
|
||||
nn.LeakyReLU(),
|
||||
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
|
||||
nn.LeakyReLU(),
|
||||
)
|
||||
self.gamma_2 = nn.Parameter(0.001 * torch.ones((1, conv_dim, 1, 1)))
|
||||
self.logit_conv = nn.Conv2d(conv_dim, num_tokens, 1)
|
||||
# decoder
|
||||
self.unpool = BFloat16UpsampleNearest2d(scale_factor=2)
|
||||
self.conv2 = nn.Conv2d(z_dim, conv_dim, 3, stride=1, padding=1)
|
||||
self.dec_block1 = nn.Sequential(
|
||||
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
|
||||
nn.LeakyReLU(),
|
||||
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
|
||||
nn.LeakyReLU(),
|
||||
)
|
||||
self.gamma_3 = nn.Parameter(0.001 * torch.ones((1, conv_dim, 1, 1)))
|
||||
self.dec_block2 = nn.Sequential(
|
||||
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
|
||||
nn.LeakyReLU(),
|
||||
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
|
||||
nn.LeakyReLU(),
|
||||
)
|
||||
self.gamma_4 = nn.Parameter(0.001 * torch.ones((1, conv_dim, 1, 1)))
|
||||
self.rec_conv = nn.Conv2d(conv_dim, input_shape[0], 3, stride=1, padding=1)
|
||||
|
||||
def forward_encoder(self, x):
|
||||
x = self.relu(self.conv1(x))
|
||||
x = x + self.gamma_1 * self.enc_block1(x)
|
||||
x = self.pool(x)
|
||||
x = x + self.gamma_2 * self.enc_block2(x)
|
||||
x = self.pool(x)
|
||||
logits = self.logit_conv(x)
|
||||
return logits
|
||||
|
||||
def forward_decoder(self, logits):
|
||||
soft_one_hot = F.softmax(logits * (self.temp*10), dim=1)
|
||||
sampled = torch.einsum('bnhw,nd->bdhw', soft_one_hot, self.codebook.weight)
|
||||
x = self.relu(self.conv2(sampled))
|
||||
x = self.unpool(x)
|
||||
x = x + self.gamma_3 * self.dec_block1(x)
|
||||
x = self.unpool(x)
|
||||
x = x + self.gamma_4 * self.dec_block2(x)
|
||||
rec_feats = self.rec_conv(x)
|
||||
return rec_feats, soft_one_hot
|
||||
|
||||
def forward(self, x):
|
||||
print(x.shape)
|
||||
logits = self.forward_encoder(x)
|
||||
images_p, soft_one_hot = self.forward_decoder(logits)
|
||||
return [logits, images_p]
|
||||
|
||||
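A small sketch of the soft codebook lookup used in forward_decoder above: the token logits are softened into a distribution and mixed against the embedding table with einsum (the dimensions are illustrative).

import torch
import torch.nn.functional as F

num_tokens, z_dim = 16, 8
codebook = torch.randn(num_tokens, z_dim)
logits = torch.randn(2, num_tokens, 4, 4)                 # B, num_tokens, H, W

soft_one_hot = F.softmax(logits * 9.0, dim=1)             # temperature-scaled token probabilities
sampled = torch.einsum('bnhw,nd->bdhw', soft_one_hot, codebook)
print(sampled.shape)                                      # torch.Size([2, 8, 4, 4])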
class ModalityCompletion(nn.Module):
|
||||
def __init__(self,
|
||||
input_shape_hr=(2816, 16, 16),
|
||||
input_shape_s2=(2816, 16, 16),
|
||||
input_shape_s1=(2816, 16, 16),
|
||||
conv_dim=256,
|
||||
z_dim=256,
|
||||
n_codebook=8192,
|
||||
init_cfg=None
|
||||
):
|
||||
super(ModalityCompletion, self).__init__()
|
||||
self.vae_hr = ConvVQVAEv2(input_shape=input_shape_hr, conv_dim=conv_dim, z_dim=z_dim, num_tokens=n_codebook)
|
||||
self.vae_s2 = ConvVQVAEv2(input_shape=input_shape_s2, conv_dim=conv_dim, z_dim=z_dim, num_tokens=n_codebook)
|
||||
self.vae_s1 = ConvVQVAEv2(input_shape=input_shape_s1, conv_dim=conv_dim, z_dim=z_dim, num_tokens=n_codebook)
|
||||
self.kl_div_loss = torch.nn.KLDivLoss(reduction="none", log_target=True)
|
||||
self.init_cfg=init_cfg
|
||||
|
||||
def init_weights(self):
|
||||
if (isinstance(self.init_cfg, dict)
|
||||
and self.init_cfg['type'] == 'Pretrained'):
|
||||
# Suppress default init if use pretrained model.
|
||||
from mmcls.utils import get_root_logger
|
||||
from mmcv.runner import CheckpointLoader, load_state_dict
|
||||
logger = get_root_logger()
|
||||
checkpoint = CheckpointLoader.load_checkpoint(
|
||||
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
|
||||
|
||||
if 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
|
||||
load_state_dict(self, state_dict, strict=False, logger=logger)
|
||||
else:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def kl_loss(self, logits_hr, logits_s2, logits_s1, modality_info):
|
||||
prob_hr = F.log_softmax(logits_hr, dim=1)
|
||||
prob_s2 = F.log_softmax(logits_s2, dim=1)
|
||||
prob_s1 = F.log_softmax(logits_s1, dim=1)
|
||||
flag_hr = modality_info[:,0][:, None, None, None]
|
||||
flag_s2 = modality_info[:,1][:, None, None, None]
|
||||
flag_s1 = modality_info[:,2][:, None, None, None]
|
||||
loss_hr_s2 = self.kl_div_loss(prob_hr, prob_s2) + self.kl_div_loss(prob_s2, prob_hr)
|
||||
loss_hr_s2 = (loss_hr_s2 * flag_hr * flag_s2).sum((1, 2, 3)).mean()
|
||||
loss_hr_s1 = self.kl_div_loss(prob_hr, prob_s1) + self.kl_div_loss(prob_s1, prob_hr)
|
||||
loss_hr_s1 = (loss_hr_s1 * flag_hr * flag_s1).sum((1, 2, 3)).mean()
|
||||
loss_s2_s1 = self.kl_div_loss(prob_s2, prob_s1) + self.kl_div_loss(prob_s1, prob_s2)
|
||||
loss_s2_s1 = (loss_s2_s1 * flag_s2 * flag_s1).sum((1, 2, 3)).mean()
|
||||
loss = (loss_hr_s2 + loss_hr_s1 + loss_s2_s1) / 6.0
|
||||
|
||||
return loss
|
||||
|
||||
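A minimal sketch of the symmetric KL term used in kl_loss above, with torch.nn.KLDivLoss(log_target=True) taking log-probabilities on both sides; the tensors are toy values.

import torch
import torch.nn.functional as F

kl = torch.nn.KLDivLoss(reduction='none', log_target=True)
logits_a = torch.randn(2, 8, 4, 4)
logits_b = torch.randn(2, 8, 4, 4)
log_p_a = F.log_softmax(logits_a, dim=1)
log_p_b = F.log_softmax(logits_b, dim=1)

# symmetric KL: KL(b || a) + KL(a || b), summed over codebook/spatial dims, then averaged over the batch
sym = kl(log_p_a, log_p_b) + kl(log_p_b, log_p_a)
print(sym.sum((1, 2, 3)).mean().item())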
def forward(self, feat_hr, feat_s2, feat_s1, modality_info):
|
||||
# encoders, add noise
|
||||
# each modality
|
||||
# 2816, 16, 16 => conv 256, 4, 4 => flatten 4096(256*4*4) => linear mu 256, log_var 256
|
||||
B, C, H, W = feat_hr.shape
|
||||
B_M, L_M = modality_info.shape
|
||||
assert B == B_M, f'feat_hr batch: {B}, modality_info batch: {B_M}'
|
||||
|
||||
# quant, emb_loss, info
|
||||
# hr input flow
|
||||
logits_hr = self.vae_hr.forward_encoder(feat_hr)
|
||||
logits_s2 = self.vae_s2.forward_encoder(feat_s2)
|
||||
logits_s1 = self.vae_s1.forward_encoder(feat_s1)
|
||||
modality_hr = modality_info[:,0]
|
||||
modality_s2 = modality_info[:,1]
|
||||
modality_s1 = modality_info[:,2]
|
||||
flag_hr = modality_hr[:, None, None, None] # B => B, C, H, W
|
||||
flag_s2 = modality_s2[:, None, None, None]
|
||||
flag_s1 = modality_s1[:, None, None, None]
|
||||
|
||||
mean_logits_hr_s2 = logits_hr * flag_hr + logits_s2 * flag_s2
|
||||
mean_logits_hr_s1 = logits_hr * flag_hr + logits_s1 * flag_s1
|
||||
mean_logits_s1_s2 = logits_s1 * flag_s1 + logits_s2 * flag_s2
|
||||
|
||||
logits_hr_rec = logits_hr * flag_hr + mean_logits_s1_s2 * (~flag_hr)
|
||||
logits_s2_rec = logits_s2 * flag_s2 + mean_logits_hr_s1 * (~flag_s2)
|
||||
logits_s1_rec = logits_s1 * flag_s1 + mean_logits_hr_s2 * (~flag_s1)
|
||||
g_hr, soft_one_hot_hr = self.vae_hr.forward_decoder(logits_hr_rec)
|
||||
g_s2, soft_one_s2 = self.vae_s2.forward_decoder(logits_s2_rec)
|
||||
g_s1, soft_one_s1 = self.vae_s1.forward_decoder(logits_s1_rec)
|
||||
|
||||
hr_out = feat_hr * flag_hr + g_hr * (~flag_hr)
|
||||
s2_out = feat_s2 * flag_s2 + g_s2 * (~flag_s2)
|
||||
s1_out = feat_s1 * flag_s1 + g_s1 * (~flag_s1)
|
||||
|
||||
output = {}
|
||||
|
||||
output['hr_out'] = hr_out
|
||||
output['s2_out'] = s2_out
|
||||
output['s1_out'] = s1_out
|
||||
|
||||
output['modality_info'] = modality_info
|
||||
|
||||
output['input_hr'] = feat_hr
|
||||
output['input_s2'] = feat_s2
|
||||
output['input_s1'] = feat_s1
|
||||
|
||||
output['logits_hr'] = logits_hr
|
||||
output['logits_s2'] = logits_s2
|
||||
output['logits_s1'] = logits_s1
|
||||
|
||||
output['soft_one_hot_hr'] = soft_one_hot_hr
|
||||
output['soft_one_hot_s2'] = soft_one_s2
|
||||
output['soft_one_hot_s1'] = soft_one_s1
|
||||
|
||||
output['g_hr'] = g_hr
|
||||
output['g_s2'] = g_s2
|
||||
output['g_s1'] = g_s1
|
||||
output['loss_quant'] = self.kl_loss(logits_hr, logits_s2, logits_s1, modality_info)
|
||||
|
||||
return output
|
||||
|
||||
144
lib/models/necks/transformer_encoder.py
Normal file
144
lib/models/necks/transformer_encoder.py
Normal file
@@ -0,0 +1,144 @@
|
||||
# Copyright (c) Ant Group. All rights reserved.
|
||||
from collections import OrderedDict
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
|
||||
trunc_normal_)
|
||||
from mmcv.runner import (CheckpointLoader, load_state_dict)
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
from mmseg.models.backbones.vit import TransformerEncoderLayer
|
||||
|
||||
from mmseg.utils import get_root_logger
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
input_dims=768,
|
||||
embed_dims=768,
|
||||
num_layers=4,
|
||||
num_heads=12,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
with_cls_token=True,
|
||||
output_cls_token=False,
|
||||
norm_cfg=dict(type='LN'),
|
||||
act_cfg=dict(type='GELU'),
|
||||
num_fcs=2,
|
||||
norm_eval=False,
|
||||
with_cp=False,
|
||||
init_cfg=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
super(TransformerEncoder, self).__init__()
|
||||
|
||||
self.porj_linear = nn.Linear(input_dims, embed_dims)
|
||||
if output_cls_token:
|
||||
assert with_cls_token is True, f'with_cls_token must be True if' \
|
||||
f'set output_cls_token to True, but got {with_cls_token}'
|
||||
|
||||
self.init_cfg = init_cfg
|
||||
self.norm_eval = norm_eval
|
||||
self.with_cp = with_cp
|
||||
self.with_cls_token = with_cls_token
|
||||
self.output_cls_token = output_cls_token
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
|
||||
self.drop_after_pos = nn.Dropout(p=drop_rate)
|
||||
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, num_layers)
|
||||
] # stochastic depth decay rule
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
for i in range(num_layers):
|
||||
self.layers.append(
|
||||
TransformerEncoderLayer(embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
feedforward_channels=mlp_ratio *
|
||||
embed_dims,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_rate=drop_rate,
|
||||
drop_path_rate=dpr[i],
|
||||
num_fcs=num_fcs,
|
||||
qkv_bias=qkv_bias,
|
||||
act_cfg=act_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
with_cp=with_cp,
|
||||
batch_first=True))
|
||||
|
||||
def init_weights(self):
|
||||
if (isinstance(self.init_cfg, dict)
|
||||
and self.init_cfg.get('type') == 'Pretrained'):
|
||||
logger = get_root_logger()
|
||||
checkpoint = CheckpointLoader.load_checkpoint(
|
||||
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
|
||||
|
||||
if 'state_dict' in checkpoint:
|
||||
_state_dict = checkpoint['state_dict']
|
||||
else:
|
||||
_state_dict = checkpoint
|
||||
|
||||
state_dict = OrderedDict()
|
||||
for k, v in _state_dict.items():
|
||||
if k.startswith('backbone.'):
|
||||
state_dict[k[9:]] = v
|
||||
else:
|
||||
state_dict[k] = v
|
||||
|
||||
load_state_dict(self, state_dict, strict=False, logger=logger)
|
||||
else:
|
||||
# We only implement the 'jax_impl' initialization implemented at
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
for n, m in self.named_modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if m.bias is not None:
|
||||
if 'ffn' in n:
|
||||
nn.init.normal_(m.bias, mean=0., std=1e-6)
|
||||
else:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
kaiming_init(m, mode='fan_in', bias=0.)
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
|
||||
constant_init(m, val=1.0, bias=0.)
|
||||
|
||||
def forward(self, inputs, require_feat: bool = False, require_two: bool = False):
|
||||
inputs = self.porj_linear(inputs)
|
||||
B, N, C = inputs.shape
|
||||
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, inputs), dim=1)
|
||||
if not self.with_cls_token:
|
||||
# Remove class token for transformer encoder input
|
||||
x = x[:, 1:]
|
||||
|
||||
# add hidden and atten state
|
||||
block_outs = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
if require_feat:
|
||||
block_outs.append(x)
|
||||
|
||||
if self.output_cls_token:
|
||||
if require_two:
|
||||
x = x[:, :2]
|
||||
else:
|
||||
x = x[:, 0]
|
||||
elif not self.output_cls_token and self.with_cls_token:
|
||||
x = x # [:, :]
|
||||
|
||||
if require_feat:
|
||||
return x, block_outs
|
||||
else:
|
||||
return x
|
||||
|
||||
def train(self, mode=True):
|
||||
super(TransformerEncoder, self).train(mode)
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.LayerNorm):
|
||||
m.eval()
|
||||
4
lib/models/segmentors/__init__.py
Normal file
4
lib/models/segmentors/__init__.py
Normal file
@@ -0,0 +1,4 @@
# Copyright (c) Ant Financial Service Group and its affiliates.
from .skysense_pp_pipeline import SkySensePP

__all__ = ['SkySensePP']
458
lib/models/segmentors/skysense_pp_pipeline.py
Normal file
458
lib/models/segmentors/skysense_pp_pipeline.py
Normal file
@@ -0,0 +1,458 @@
|
||||
# coding: utf-8
|
||||
# Copyright (c) Ant Group. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import optim
|
||||
import torch.nn.functional as F
|
||||
|
||||
import math
|
||||
import random
|
||||
from antmmf.common.registry import registry
|
||||
from antmmf.models.base_model import BaseModel
|
||||
from lib.models.backbones import build_backbone
|
||||
from lib.models.necks import build_neck
|
||||
from lib.models.heads import build_head
|
||||
from lib.utils.utils import LayerDecayValueAssigner
|
||||
|
||||
|
||||
@registry.register_model("SkySensePP")
|
||||
class SkySensePP(BaseModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.sources = config.sources
|
||||
assert len(self.sources) > 0, 'at least one data source is required'
|
||||
if 's2' in self.sources:
|
||||
self.use_ctpe = config.use_ctpe
|
||||
self.use_modal_vae = config.use_modal_vae
|
||||
self.use_cls_token_uper_head = config.use_cls_token_uper_head
|
||||
self.target_mean=[0.485, 0.456, 0.406]
|
||||
self.target_std=[0.229, 0.224, 0.225]
|
||||
self.vocabulary_size = config.vocabulary_size
|
||||
self.vocabulary = list(range(1, config.vocabulary_size + 1)) # 0 for ignore
|
||||
|
||||
|
||||
def build(self):
|
||||
if 'hr' in self.sources:
|
||||
self.backbone_hr = self._build_backbone('hr')
|
||||
if 's2' in self.sources:
|
||||
self.backbone_s2 = self._build_backbone('s2')
|
||||
if self.use_ctpe:
|
||||
self.ctpe = nn.Parameter(
|
||||
torch.zeros(1, self.config.calendar_time,
|
||||
self.config.necks.input_dims))
|
||||
if 'head_s2' in self.config.keys():
|
||||
self.head_s2 = self._build_head('head_s2')
|
||||
self.fusion = self._build_neck('necks')
|
||||
if 's1' in self.sources:
|
||||
self.backbone_s1 = self._build_backbone('s1')
|
||||
if 'head_s1' in self.config.keys():
|
||||
self.head_s1 = self._build_head('head_s1')
|
||||
self.head_rec_hr = self._build_head('rec_head_hr')
|
||||
|
||||
self.with_aux_head = False
|
||||
|
||||
if self.use_modal_vae:
|
||||
self.modality_vae = self._build_neck('modality_vae')
|
||||
if 'auxiliary_head' in self.config.keys():
|
||||
self.with_aux_head = True
|
||||
self.aux_head = self._build_head('auxiliary_head')
|
||||
if 'init_cfg' in self.config.keys(
|
||||
) and self.config.init_cfg is not None and self.config.init_cfg.checkpoint is not None and self.config.init_cfg.key is not None:
|
||||
self.load_pretrained(self.config.init_cfg.checkpoint,
|
||||
self.config.init_cfg.key)
|
||||
|
||||
def _build_backbone(self, key):
|
||||
config_dict = self.config[f'backbone_{key}'].to_dict()
|
||||
backbone_type = config_dict.pop('type')
|
||||
backbone = build_backbone(backbone_type, **config_dict)
|
||||
backbone.init_weights()
|
||||
return backbone
|
||||
|
||||
def _build_neck(self, key):
|
||||
config_dict = self.config[key].to_dict()
|
||||
neck_type = config_dict.pop('type')
|
||||
neck = build_neck(neck_type, **config_dict)
|
||||
neck.init_weights()
|
||||
return neck
|
||||
|
||||
def _build_head(self, key):
|
||||
head_config = self.config[key].to_dict()
|
||||
head_type = head_config.pop('type')
|
||||
head = build_head(head_type, **head_config)
|
||||
return head
|
||||
|
||||
def get_optimizer_parameters(self, config):
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [],
|
||||
"lr": config.optimizer_attributes.params.lr,
|
||||
"weight_decay": config.optimizer_attributes.params.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [],
|
||||
"lr": config.optimizer_attributes.params.lr,
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
layer_decay_value_assigner_hr = LayerDecayValueAssigner(
|
||||
config.lr_parameters.layer_decay, None,
|
||||
config.optimizer_attributes.params.lr, 'swin',
|
||||
config.model_attributes.SkySensePP.backbone_hr.arch
|
||||
)
|
||||
layer_decay_value_assigner_s2 = LayerDecayValueAssigner(
|
||||
config.lr_parameters.layer_decay, 24,
|
||||
config.optimizer_attributes.params.lr, 'vit',
|
||||
)
|
||||
layer_decay_value_assigner_s1 = LayerDecayValueAssigner(
|
||||
config.lr_parameters.layer_decay, 24,
|
||||
config.optimizer_attributes.params.lr, 'vit',
|
||||
)
|
||||
layer_decay_value_assigner_fusion = LayerDecayValueAssigner(
|
||||
config.lr_parameters.layer_decay, 24,
|
||||
config.optimizer_attributes.params.lr, 'vit',
|
||||
)
|
||||
num_frozen_params = 0
|
||||
if 'hr' in self.sources:
|
||||
print('hr'.center(60, '-'))
|
||||
num_frozen_params += layer_decay_value_assigner_hr.fix_param(
|
||||
self.backbone_hr,
|
||||
config.lr_parameters.frozen_blocks,
|
||||
)
|
||||
optimizer_grouped_parameters.extend(
|
||||
layer_decay_value_assigner_hr.get_parameter_groups(
|
||||
self.backbone_hr, config.optimizer_attributes.params.weight_decay
|
||||
)
|
||||
)
|
||||
if 's2' in self.sources:
|
||||
print('s2'.center(60, '-'))
|
||||
num_frozen_params += layer_decay_value_assigner_s2.fix_param(
|
||||
self.backbone_s2,
|
||||
config.lr_parameters.frozen_blocks,
|
||||
)
|
||||
optimizer_grouped_parameters.extend(
|
||||
layer_decay_value_assigner_s2.get_parameter_groups(
|
||||
self.backbone_s2, config.optimizer_attributes.params.weight_decay
|
||||
)
|
||||
)
|
||||
no_decay = [".bn.", "bias"]
|
||||
optimizer_grouped_parameters[0]["params"] += [
|
||||
p for n, p in self.head_s2.named_parameters()
|
||||
if not any(nd in n for nd in no_decay)
|
||||
]
|
||||
optimizer_grouped_parameters[1]["params"] += [
|
||||
p for n, p in self.head_s2.named_parameters()
|
||||
if any(nd in n for nd in no_decay)
|
||||
]
|
||||
if self.use_ctpe:
|
||||
optimizer_grouped_parameters[1]["params"] += [self.ctpe]
|
||||
|
||||
if 's1' in self.sources:
|
||||
print('s1'.center(60, '-'))
|
||||
num_frozen_params += layer_decay_value_assigner_s1.fix_param(
|
||||
self.backbone_s1,
|
||||
config.lr_parameters.frozen_blocks,
|
||||
)
|
||||
optimizer_grouped_parameters.extend(
|
||||
layer_decay_value_assigner_s1.get_parameter_groups(
|
||||
self.backbone_s1, config.optimizer_attributes.params.weight_decay
|
||||
)
|
||||
)
|
||||
no_decay = [".bn.", "bias"]
|
||||
optimizer_grouped_parameters[0]["params"] += [
|
||||
p for n, p in self.head_s1.named_parameters()
|
||||
if not any(nd in n for nd in no_decay)
|
||||
]
|
||||
optimizer_grouped_parameters[1]["params"] += [
|
||||
p for n, p in self.head_s1.named_parameters()
|
||||
if any(nd in n for nd in no_decay)
|
||||
]
|
||||
|
||||
if len(self.sources) > 1:
|
||||
print('fusion'.center(60, '-'))
|
||||
num_frozen_params += layer_decay_value_assigner_fusion.fix_param_deeper(
|
||||
self.fusion,
|
||||
                config.lr_parameters.frozen_fusion_blocks_start,  # freeze all later stages
|
||||
)
|
||||
optimizer_grouped_parameters.extend(
|
||||
layer_decay_value_assigner_fusion.get_parameter_groups(
|
||||
self.fusion, config.optimizer_attributes.params.weight_decay
|
||||
)
|
||||
)
|
||||
|
||||
if self.use_modal_vae:
|
||||
no_decay = [".bn.", "bias"]
|
||||
optimizer_grouped_parameters[0]["params"] += [
|
||||
p for n, p in self.modality_vae.named_parameters()
|
||||
if not any(nd in n for nd in no_decay)
|
||||
]
|
||||
optimizer_grouped_parameters[1]["params"] += [
|
||||
p for n, p in self.modality_vae.named_parameters()
|
||||
if any(nd in n for nd in no_decay)
|
||||
]
|
||||
|
||||
no_decay = [".bn.", "bias"]
|
||||
optimizer_grouped_parameters[0]["params"] += [
|
||||
p for n, p in self.head_rec_hr.named_parameters()
|
||||
if not any(nd in n for nd in no_decay)
|
||||
]
|
||||
optimizer_grouped_parameters[1]["params"] += [
|
||||
p for n, p in self.head_rec_hr.named_parameters()
|
||||
if any(nd in n for nd in no_decay)
|
||||
]
|
||||
|
||||
if self.with_aux_head:
|
||||
no_decay = [".bn.", "bias"]
|
||||
optimizer_grouped_parameters[0]["params"] += [
|
||||
p for n, p in self.aux_head.named_parameters()
|
||||
if not any(nd in n for nd in no_decay)
|
||||
]
|
||||
optimizer_grouped_parameters[1]["params"] += [
|
||||
p for n, p in self.aux_head.named_parameters()
|
||||
if any(nd in n for nd in no_decay)
|
||||
]
|
||||
num_params = [len(x['params']) for x in optimizer_grouped_parameters]
|
||||
print(len(list(self.parameters())), sum(num_params), num_frozen_params)
|
||||
assert len(list(self.parameters())) == sum(num_params) + num_frozen_params
|
||||
return optimizer_grouped_parameters
|
||||
|
||||
def get_custom_scheduler(self, trainer):
|
||||
optimizer = trainer.optimizer
|
||||
num_training_steps = trainer.config.training_parameters.max_iterations
|
||||
num_warmup_steps = trainer.config.training_parameters.num_warmup_steps
|
||||
|
||||
if "train" in trainer.run_type:
|
||||
if num_training_steps == math.inf:
|
||||
                epochs = trainer.config.training_parameters.max_epochs
|
||||
                assert epochs != math.inf
|
||||
num_training_steps = trainer.config.training_parameters.max_epochs * trainer.epoch_iterations
|
||||
|
||||
            def linear_with_warm_up(current_step):
|
||||
if current_step < num_warmup_steps:
|
||||
return float(current_step) / float(max(
|
||||
1, num_warmup_steps))
|
||||
return max(
|
||||
0.0,
|
||||
float(num_training_steps - current_step) /
|
||||
float(max(1, num_training_steps - num_warmup_steps)),
|
||||
)
|
||||
|
||||
            def cos_with_warm_up(current_step):
|
||||
num_cycles = 0.5
|
||||
if current_step < num_warmup_steps:
|
||||
return float(current_step) / float(max(
|
||||
1, num_warmup_steps))
|
||||
progress = float(current_step - num_warmup_steps) / float(
|
||||
max(1, num_training_steps - num_warmup_steps))
|
||||
return max(
|
||||
0.0,
|
||||
0.5 *
|
||||
(1.0 +
|
||||
math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
|
||||
)
|
||||
|
||||
            lr_lambda = cos_with_warm_up if trainer.config.training_parameters.cos_lr else linear_with_warm_up
|
||||
|
||||
else:
|
||||
def lr_lambda(current_step):
|
||||
return 0.0 # noqa
|
||||
|
||||
return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
def convert_target(self, target):
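        # De-normalize the RGB annotation and pack each pixel's colour into a
        # single integer as r * 256 * 256 + g * 256 + b (for example, pure red
        # (255, 0, 0) packs to 16711680). The unique packed colours are re-indexed
        # and mapped onto a shuffled class vocabulary; the returned idx_2_color
        # dict lets the packed colour be recovered from each mapped index later.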
|
||||
mean = target.new_tensor(self.target_mean).reshape(1, 3, 1, 1)
|
||||
std = target.new_tensor(self.target_std).reshape(1, 3, 1, 1)
|
||||
target = ((target * std + mean)*255).to(torch.long)
|
||||
target[:, 0] = target[:, 0] * 256 * 256
|
||||
target[:, 1] = target[:, 1] * 256
|
||||
target = target.sum(1).type(torch.long)
|
||||
unique_target = target.unique()
|
||||
target_index = torch.searchsorted(unique_target, target)
|
||||
no_bg = False
|
||||
if unique_target[0].item() > 0:
|
||||
target_index += 1
|
||||
no_bg = True
|
||||
target_index_unique = target_index.unique().tolist()
|
||||
random.shuffle(self.vocabulary)
|
||||
value = target.new_tensor([0] + self.vocabulary)
|
||||
mapped_target = target_index.clone()
|
||||
idx_2_color = {}
|
||||
for v in target_index_unique:
|
||||
mapped_target[target_index == v] = value[v]
|
||||
idx_2_color[value[v].item()] = unique_target[v - 1 if no_bg else v].item()
|
||||
return mapped_target, idx_2_color
|
||||
|
||||
def forward(self, sample_list):
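        # Forward pass: (1) run the per-modality backbones (HR, S2, S1) with the
        # annotation prompts, (2) flatten and concatenate their stage-3 features
        # per spatial location, (3) fuse them with the fusion module, and
        # (4) decode the fused features with the reconstruction head.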
|
||||
output = dict()
|
||||
modality_flag_hr = sample_list["modality_flag_hr"]
|
||||
modality_flag_s2 = sample_list["modality_flag_s2"]
|
||||
modality_flag_s1 = sample_list["modality_flag_s1"]
|
||||
modalities = [modality_flag_hr, modality_flag_s2, modality_flag_s1]
|
||||
modalities = torch.tensor(modalities).permute(1,0).contiguous() # L, B => B, L
|
||||
|
||||
anno_img = sample_list["targets"]
|
||||
anno_img, idx_2_color = self.convert_target(anno_img)
|
||||
output["mapped_targets"] = anno_img
|
||||
output["idx_2_color"] = idx_2_color
|
||||
anno_mask = sample_list["anno_mask"]
|
||||
anno_s2 = anno_img[:, 15::32, 15::32]
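        # The line above samples the HR annotation with stride 32 (offset 15),
        # i.e. one label per 32x32 block, presumably so it aligns with the
        # 32x-downsampled S2/S1 feature grid.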
|
||||
anno_s1 = anno_s2
|
||||
|
||||
output["anno_hr"] = anno_img
|
||||
output["anno_s2"] = anno_s2
|
||||
|
||||
### 1. backbone
|
||||
if 'hr' in self.sources:
|
||||
hr_img = sample_list["hr_img"]
|
||||
B_MASK, H_MASK, W_MASK = anno_mask.shape
|
||||
block_size = 32
|
||||
anno_mask_hr = anno_mask.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 1, block_size, block_size)
|
||||
anno_mask_hr = anno_mask_hr.permute(0, 1, 3, 2, 4).reshape(B_MASK, H_MASK*block_size, W_MASK*block_size).contiguous()
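            # The two lines above tile every mask entry over a 32x32 block,
            # turning the patch-level anno_mask into a pixel-level mask at the
            # HR image resolution.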
|
||||
B, C_G, H_G, W_G = hr_img.shape
|
||||
hr_features = self.backbone_hr(hr_img, anno_img, anno_mask_hr)
|
||||
output['mask_hr'] = anno_mask_hr
|
||||
output['target_hr'] = anno_img
|
||||
|
||||
if 's2' in self.sources:
|
||||
s2_img = sample_list["s2_img"]
|
||||
B, C_S2, S_S2, H_S2, W_S2 = s2_img.shape
|
||||
s2_img = s2_img.permute(0, 2, 1, 3,
|
||||
                4).reshape(B * S_S2, C_S2, H_S2, W_S2).contiguous()  # fold the time dimension into the batch dimension
|
||||
anno_mask_s2 = anno_mask
|
||||
s2_features = self.backbone_s2(s2_img, anno_s2, anno_mask_s2)
|
||||
if 'head_s2' in self.config.keys():
|
||||
s2_features = self.head_s2(s2_features[-1])
|
||||
s2_features = [s2_features]
|
||||
|
||||
if 's1' in self.sources:
|
||||
s1_img = sample_list["s1_img"]
|
||||
B, C_S1, S_S1, H_S1, W_S1 = s1_img.shape
|
||||
s1_img = s1_img.permute(0, 2, 1, 3,
|
||||
4).reshape(B * S_S1, C_S1, H_S1, W_S1).contiguous()
|
||||
|
||||
anno_mask_s1 = anno_mask
|
||||
s1_features = self.backbone_s1(s1_img, anno_s1, anno_mask_s1)
|
||||
if 'head_s1' in self.config.keys():
|
||||
s1_features = self.head_s1(s1_features[-1])
|
||||
s1_features = [s1_features]
|
||||
|
||||
### 2. prepare features for fusion
|
||||
hr_features_stage3 = hr_features[-1]
|
||||
s2_features_stage3 = s2_features[-1]
|
||||
s1_features_stage3 = s1_features[-1]
|
||||
modalities = modalities.to(hr_features_stage3.device)
|
||||
if self.use_modal_vae:
|
||||
vae_out = self.modality_vae(hr_features_stage3, s2_features_stage3, s1_features_stage3, modalities)
|
||||
hr_features_stage3 = vae_out['hr_out']
|
||||
s2_features_stage3 = vae_out['s2_out']
|
||||
s1_features_stage3 = vae_out['s1_out']
|
||||
output['vae_out'] = vae_out
|
||||
|
||||
features_stage3 = []
|
||||
if 'hr' in self.sources:
|
||||
B, C3_G, H3_G, W3_G = hr_features_stage3.shape
|
||||
hr_features_stage3 = hr_features_stage3.permute(
|
||||
0, 2, 3, 1).reshape(B * H3_G * W3_G, C3_G).unsqueeze(1).contiguous() # B * H3_G * W3_G, 1, C3_G
|
||||
features_stage3 = hr_features_stage3
|
||||
|
||||
if 's2' in self.sources:
|
||||
# s2_features_stage3 = s2_features[-1]
|
||||
_, C3_S2, H3_S2, W3_S2 = s2_features_stage3.shape
|
||||
s2_features_stage3 = s2_features_stage3.reshape(
|
||||
B, S_S2, C3_S2, H3_S2,
|
||||
W3_S2).permute(0, 3, 4, 1, 2).reshape(B, H3_S2 * W3_S2, S_S2,
|
||||
C3_S2).contiguous()
|
||||
if self.use_ctpe:
|
||||
ct_index = sample_list["s2_ct"]
|
||||
ctpe = self.ctpe[:, ct_index, :].contiguous().permute(1, 0, 2, 3).contiguous()
|
||||
ctpe = ctpe.expand(-1, 256, -1, -1)
|
||||
|
||||
ct_index_2 = sample_list["s2_ct2"]
|
||||
ctpe2 = self.ctpe[:, ct_index_2, :].contiguous().permute(1, 0, 2, 3).contiguous()
|
||||
ctpe2 = ctpe2.expand(-1, 256, -1, -1)
|
||||
|
||||
ctpe_comb = torch.cat([ctpe, ctpe2], 1)
|
||||
# import pdb;pdb.set_trace()
|
||||
s2_features_stage3 = (s2_features_stage3 + ctpe_comb).reshape(
|
||||
B * H3_S2 * W3_S2, S_S2, C3_S2).contiguous()
|
||||
else:
|
||||
s2_features_stage3 = s2_features_stage3.reshape(
|
||||
B * H3_S2 * W3_S2, S_S2, C3_S2).contiguous()
|
||||
|
||||
if len(features_stage3) > 0:
|
||||
assert H3_G == H3_S2 and W3_G == W3_S2 and C3_G == C3_S2
|
||||
features_stage3 = torch.cat((features_stage3, s2_features_stage3), dim=1)
|
||||
else:
|
||||
features_stage3 = s2_features_stage3
|
||||
|
||||
if 's1' in self.sources:
|
||||
# s1_features_stage3 = s1_features[-1]
|
||||
_, C3_S1, H3_S1, W3_S1 = s1_features_stage3.shape
|
||||
s1_features_stage3 = s1_features_stage3.reshape(
|
||||
B, S_S1, C3_S1, H3_S1,
|
||||
W3_S1).permute(0, 3, 4, 1, 2).reshape(B, H3_S1 * W3_S1, S_S1,
|
||||
C3_S1).contiguous()
|
||||
s1_features_stage3 = s1_features_stage3.reshape(
|
||||
B * H3_S1 * W3_S1, S_S1, C3_S1).contiguous()
|
||||
|
||||
if len(features_stage3) > 0:
|
||||
assert H3_S1 == H3_S2 and W3_S1 == W3_S2 and C3_S1 == C3_S2
|
||||
features_stage3 = torch.cat((features_stage3, s1_features_stage3),
|
||||
dim=1)
|
||||
else:
|
||||
features_stage3 = s1_features_stage3
|
||||
|
||||
### 3. fusion
|
||||
if self.config.necks.output_cls_token:
|
||||
if self.config.necks.get('require_feat', False):
|
||||
                cls_token, block_outs = self.fusion(features_stage3, True)
|
||||
else:
|
||||
cls_token = self.fusion(features_stage3)
|
||||
_, C3_G = cls_token.shape
|
||||
cls_token = cls_token.reshape(B, H3_G, W3_G,
|
||||
C3_G).contiguous().permute(0, 3, 1, 2).contiguous() # b, c, h, w
|
||||
else:
|
||||
assert self.config.necks.with_cls_token is False
|
||||
if self.config.necks.get('require_feat', False):
|
||||
features_stage3, block_outs = self.fusion(features_stage3, True)
|
||||
else:
|
||||
features_stage3 = self.fusion(features_stage3)
|
||||
features_stage3 = features_stage3.reshape(
|
||||
B, H3_S2, W3_S2, S_S2,
|
||||
C3_S2).permute(0, 3, 4, 1, 2).reshape(B * S_S2, C3_S2, H3_S2,
|
||||
W3_S2).contiguous()
|
||||
### 4. decoder for rec
|
||||
hr_rec_inputs = hr_features
|
||||
feat_stage1 = hr_rec_inputs[0]
|
||||
|
||||
if feat_stage1.shape[-1] == feat_stage1.shape[-2]:
|
||||
feat_stage1_left, feat_stage1_right = torch.split(feat_stage1, feat_stage1.shape[-1] // 2, dim=-1)
|
||||
feat_stage1 = torch.cat((feat_stage1_left, feat_stage1_right), dim=1)
|
||||
hr_rec_inputs = list(hr_features)
|
||||
hr_rec_inputs[0] = feat_stage1
|
||||
|
||||
rec_feats = [*hr_rec_inputs, cls_token]
|
||||
logits_hr = self.head_rec_hr(rec_feats)
|
||||
if self.config.get('upsacle_results', True):
|
||||
logits_hr = logits_hr.to(torch.float32)
|
||||
logits_hr = F.interpolate(logits_hr, scale_factor=4, mode='bilinear', align_corners=True)
|
||||
output["logits_hr"] = logits_hr
|
||||
return output
|
||||
|
||||
def load_pretrained(self, ckpt_path, key):
|
||||
pretrained_dict = torch.load(ckpt_path, map_location={'cuda:0': 'cpu'})
|
||||
pretrained_dict = pretrained_dict[key]
|
||||
for k, v in pretrained_dict.items():
|
||||
if k == 'backbone_s2.patch_embed.projection.weight':
|
||||
pretrained_in_channels = v.shape[1]
|
||||
if self.config.backbone_s2.in_channels == 4:
|
||||
new_weight = v[:, [0, 1, 2, 6]]
|
||||
new_weight = new_weight * (
|
||||
pretrained_in_channels /
|
||||
self.config.backbone_s2.in_channels)
|
||||
pretrained_dict[k] = new_weight
|
||||
missing_keys, unexpected_keys = self.load_state_dict(pretrained_dict,
|
||||
strict=False)
|
||||
print('missing_keys:', missing_keys)
|
||||
print('unexpected_keys:', unexpected_keys)
|
||||
|
||||
244
lib/predictors/flood3i_1shot.py
Normal file
@@ -0,0 +1,244 @@
|
||||
import os
|
||||
import glob
|
||||
import numpy as np
|
||||
import yaml
|
||||
import argparse
|
||||
import oss2
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import concurrent.futures
|
||||
from torchvision.transforms import functional as F
|
||||
import random
|
||||
|
||||
|
||||
from antmmf.common.registry import registry
|
||||
from antmmf.common.report import Report, default_result_formater
|
||||
from antmmf.structures import Sample, SampleList
|
||||
from antmmf.predictors.base_predictor import BasePredictor
|
||||
from antmmf.utils.timer import Timer
|
||||
from antmmf.predictors.build import build_predictor
|
||||
from antmmf.common.task_loader import build_collate_fn
|
||||
from antmmf.datasets.samplers import SequentialSampler
|
||||
from antmmf.common.build import build_config
|
||||
|
||||
from lib.utils.checkpoint import SegCheckpoint
|
||||
from lib.datasets.loader.few_shot_flood3i_loader import FewShotFloodLoader
|
||||
|
||||
|
||||
def seed_everything(seed=0):
|
||||
    # make CUDA convolutions deterministic
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||
|
||||
@registry.register_predictor("OneshotPredictor")
|
||||
class OneshotPredictor(BasePredictor, FewShotFloodLoader):
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.predictor_parameters = self.config.predictor_parameters
|
||||
|
||||
def _predict(self, sample_list):
|
||||
# with torch.no_grad():
|
||||
if True:
|
||||
sample_list = sample_list.to(self.device)
|
||||
report = self._forward_pass(sample_list)
|
||||
return report, sample_list
|
||||
|
||||
def _forward_pass(self, samplelist):
|
||||
autocast_dtype = torch.bfloat16
|
||||
with torch.cuda.amp.autocast(enabled=True, dtype=autocast_dtype):
|
||||
model_output = self.model(samplelist)
|
||||
report = Report(samplelist, model_output)
|
||||
return report
|
||||
|
||||
def predict(self, data=None):
|
||||
if data is None:
|
||||
data = self.dummy_request()
|
||||
sample = self._build_sample(data)
|
||||
if not isinstance(sample, Sample):
|
||||
raise Exception(
|
||||
f"Method _build_sample is expected to return a instance of antmmf.structures.sample.Sample,"
|
||||
f"but got type {type(sample)} instead.")
|
||||
result, sample_list = self._predict(SampleList([sample]))
|
||||
np_result = default_result_formater(result)
|
||||
result = self.format_result(np_result)
|
||||
assert isinstance(
|
||||
result, dict
|
||||
), f"Result should be instance of Dict,but got f{type(result)} instead"
|
||||
return result, sample_list
|
||||
|
||||
def load_checkpoint(self):
|
||||
self.resume_file = self.config.predictor_parameters.model_dir
|
||||
self.checkpoint = SegCheckpoint(self, load_only=True)
|
||||
self.checkpoint.load_model_weights(self.resume_file, force=True)
|
||||
|
||||
def covert_speedup_op(self):
|
||||
if self.config.predictor_parameters.replace_speedup_op:
|
||||
from lib.utils.optim_utils import replace_speedup_op
|
||||
self.model = replace_speedup_op(self.model)
|
||||
|
||||
def save_image(output_path, image_np_array):
|
||||
image = Image.fromarray(image_np_array)
|
||||
image.save(output_path)
|
||||
|
||||
def build_predictor_from_args(args, *rest, **kwargs):
|
||||
config = build_config(
|
||||
args.config,
|
||||
config_override=args.config_override,
|
||||
opts_override=args.opts,
|
||||
specific_override=args,
|
||||
)
|
||||
predictor_obj = build_predictor(config)
|
||||
setattr(predictor_obj, "args", args)
|
||||
return predictor_obj
|
||||
|
||||
|
||||
def build_online_predictor(model_dir=None, config_yaml=None):
|
||||
assert model_dir or config_yaml
|
||||
from antmmf.utils.flags import flags
|
||||
|
||||
    # if config_yaml is not given, a `config.yaml` file must exist under `model_dir`
|
||||
config_path = config_yaml if config_yaml else os.path.join(model_dir, "config.yaml")
|
||||
input_args = ["--config", config_path]
|
||||
if model_dir is not None:
|
||||
input_args += ["predictor_parameters.model_dir", model_dir]
|
||||
parser = flags.get_parser()
|
||||
args = parser.parse_args(input_args)
|
||||
predictor = build_predictor_from_args(args)
|
||||
return predictor
|
||||
|
||||
def profile(profiler, text):
|
||||
print(f'{text}: {profiler.get_time_since_start()}')
|
||||
profiler.reset()
|
||||
|
||||
def cvt_colors(img_2d, idx_2_color_rgb):
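    # Paint a 2-D index map into an RGB image using the idx -> (r, g, b) lookup
    # built in process_results; pixels with unmapped indices stay black.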
|
||||
img_rgb = np.zeros((img_2d.shape[0], img_2d.shape[1], 3), dtype=np.uint8)
|
||||
for idx, color in idx_2_color_rgb.items():
|
||||
img_rgb[img_2d==idx] = color
|
||||
return img_rgb
|
||||
|
||||
def process_results(preds, targets, input_imgs, img_names, save_dir, save_dir_vis, idx_2_color):
|
||||
imagenet_std = np.array([0.229, 0.224, 0.225])
|
||||
imagenet_mean = np.array([0.485, 0.456, 0.406])
|
||||
idx_2_color_rgb = {}
|
||||
|
||||
for idx, color in idx_2_color.items():
|
||||
r = color // (256 * 256)
|
||||
g = (color % (256 * 256)) // 256
|
||||
b = color % 256
|
||||
idx_2_color_rgb[idx] = (r, g, b)
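        # ^ invert the r * 256 * 256 + g * 256 + b packing used in convert_target
        # back into an (r, g, b) tuple for visualization.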
|
||||
|
||||
for i in range(preds.size(0)):
|
||||
output1 = preds[i].argmax(0) # h, w
|
||||
output1_total = output1.clone()
|
||||
output1 = output1[output1.shape[0]//2:, :]
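        # Keep only the bottom half of the predicted map; the assumption here is
        # that the input stacks the support (reference) image above the query
        # image, so the lower half is the query prediction that gets saved.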
|
||||
output1 = output1.numpy().astype(np.uint8)
|
||||
output1 = cvt_colors(output1, idx_2_color_rgb)
|
||||
|
||||
output1_total = output1_total.numpy().astype(np.uint8)
|
||||
output1_total = cvt_colors(output1_total, idx_2_color_rgb)
|
||||
|
||||
# for visualization
|
||||
output2 = targets[i]
|
||||
output2 = output2.numpy().astype(np.uint8)
|
||||
output2 = cvt_colors(output2, idx_2_color_rgb)
|
||||
|
||||
input_img = torch.einsum('chw->hwc', input_imgs[i])
|
||||
input_img = torch.clip((input_img * imagenet_std + imagenet_mean) * 255, 0, 255)
|
||||
input_img = input_img.numpy().astype(np.uint8)
|
||||
|
||||
output_comb = np.concatenate((input_img, output1_total, output2), axis=1)
|
||||
|
||||
# save result
|
||||
save_path = os.path.join(save_dir, f'{img_names[i]}.png')
|
||||
save_image(save_path, output1)
|
||||
save_path_vis = os.path.join(save_dir_vis, f'{img_names[i]}.png')
|
||||
save_image(save_path_vis, output_comb)
|
||||
|
||||
def test(args):
|
||||
model_path = args.model_path
|
||||
config_path = args.config
|
||||
global_seed = args.seed
|
||||
predictor = build_online_predictor(model_path, config_path)
|
||||
seed_everything(global_seed)
|
||||
|
||||
dataset = FewShotFloodLoader(
|
||||
"test", predictor.config.task_attributes.segmentation.dataset_attributes.few_shot_flood_segmentation)
|
||||
|
||||
loader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=8,
|
||||
shuffle=False,
|
||||
sampler=SequentialSampler(dataset),
|
||||
collate_fn=build_collate_fn(dataset),
|
||||
num_workers=16,
|
||||
pin_memory=True,
|
||||
drop_last=False,
|
||||
)
|
||||
|
||||
print(len(loader))
|
||||
|
||||
predictor.load(with_ckpt=True)
|
||||
|
||||
predictor.covert_speedup_op()
|
||||
|
||||
save_dir = args.save_dir
|
||||
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
|
||||
save_dir_vis = os.path.join(args.save_dir, 'vis_full')
|
||||
if not os.path.exists(save_dir_vis):
|
||||
os.makedirs(save_dir_vis)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
|
||||
profiler2 = Timer()
|
||||
profiler2.reset()
|
||||
for sample_batched in tqdm(loader):
|
||||
profile(profiler2, "Build sample time")
|
||||
result, sample_list = predictor._predict(SampleList(sample_batched))
|
||||
profile(profiler2, "Infer time")
|
||||
preds = result["logits_hr"].to(torch.float32).detach().cpu()
|
||||
targets = result['mapped_targets'].to(torch.float32).detach().cpu()
|
||||
idx_2_color = result['idx_2_color']
|
||||
input_imgs = sample_list['hr_img'].to(torch.float32).detach().cpu()
|
||||
img_names = sample_list["img_name"]
|
||||
|
||||
executor.submit(process_results, preds, targets, input_imgs, img_names, save_dir, save_dir_vis, idx_2_color)
|
||||
profile(profiler2, "Save results time")
|
||||
try:
|
||||
del predictor.model
|
||||
except Exception as e:
|
||||
print('delete model error: ', e)
|
||||
|
||||
def parse_args():
|
||||
desc = '1-shot predictor'
|
||||
parser = argparse.ArgumentParser(description=desc)
|
||||
parser.add_argument('--model_path',
|
||||
required=True,
|
||||
type=str,
|
||||
help='model directory')
|
||||
parser.add_argument('--seed',
|
||||
default=0,
|
||||
type=int,
|
||||
help='seed')
|
||||
parser.add_argument('--config',
|
||||
required=True,
|
||||
type=str,
|
||||
help='config path')
|
||||
parser.add_argument('--save_dir',
|
||||
required=False,
|
||||
type=str,
|
||||
help='save directory')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
test(args)
|
||||
3
lib/task/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .segmentation import SegmentationTask
|
||||
|
||||
__all__ = ['SegmentationTask']
|
||||
18
lib/task/segmentation.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# coding: utf-8
|
||||
# Copyright (c) Ant Group. All rights reserved.
|
||||
|
||||
from antmmf.common.registry import registry
|
||||
from antmmf.tasks import BaseTask
|
||||
|
||||
|
||||
@registry.register_task("segmentation")
|
||||
class SegmentationTask(BaseTask):
|
||||
|
||||
def __init__(self):
|
||||
super(SegmentationTask, self).__init__("segmentation")
|
||||
|
||||
def _get_available_datasets(self):
|
||||
return ["pretraining_loader"]
|
||||
|
||||
def _preprocess_item(self, item):
|
||||
return item
|
||||
4
lib/trainer/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .seg_trainer import SEGTrainer
|
||||
|
||||
__all__ = ['SEGTrainer']
|
||||
|
||||
399
lib/trainer/seg_trainer.py
Normal file
@@ -0,0 +1,399 @@
|
||||
# Copyright (c) Ant Group. and its affiliates.
|
||||
import gc
|
||||
import math
|
||||
from itertools import chain
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
|
||||
from antmmf.common.registry import registry
|
||||
from antmmf.common.report import Report
|
||||
from antmmf.common.meter import Meter
|
||||
from antmmf.modules.metrics import Metrics
|
||||
from antmmf.optimizer.combine_optimizers import CombinedOptimizer
|
||||
from antmmf.utils.distributed_utils import (broadcast_scalar, is_main_process)
|
||||
from antmmf.utils.early_stopping import EarlyStopping
|
||||
from antmmf.utils.general import clip_gradients, count_parameters, nullcontext
|
||||
from antmmf.utils.timer import Timer
|
||||
from antmmf.trainers.base_trainer import BaseTrainer
|
||||
|
||||
from lib.utils.utils import cancel_gradients_backbone, EMA
|
||||
from lib.utils.checkpoint import SegCheckpoint
|
||||
|
||||
try:
|
||||
import atorch
|
||||
from atorch import amp
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
@registry.register_trainer("seg_trainer")
|
||||
class SEGTrainer(BaseTrainer):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
        self.enable_torch_amp = True
|
||||
        self.enable_atorch_amp = False
|
||||
|
||||
def load(self, has_check_point=True):
|
||||
super().load(has_check_point)
|
||||
torch.backends.cuda.matmul.allow_tf32 = self.config.training_parameters.get(
|
||||
"enable_tf32", False)
|
||||
if hasattr(
|
||||
self.config.training_parameters, "freeze_backbone"
|
||||
) and self.config.training_parameters.freeze_backbone is True:
|
||||
for n, p in self.model.named_parameters():
|
||||
if "backbone_hr." in n or 'backbone_s2.' in n or 'head_s2.' in n or 'backbone_s1.' in n or 'head_s1.' in n or 'fusion.' in n or 'ctpe' in n or 'glbank.' in n:
|
||||
p.requires_grad = False
|
||||
else:
|
||||
print(n, '-->', p.requires_grad)
|
||||
if hasattr(self.config.training_parameters,
|
||||
"ema") and self.config.training_parameters.ema is True:
|
||||
self.ema = EMA(self.model, 0.96)
|
||||
self.ema.register()
|
||||
|
||||
def load_extras(self, has_check_point=True):
|
||||
self.checkpoint = None if has_check_point is False else SegCheckpoint(
|
||||
self)
|
||||
self.meter = Meter()
|
||||
|
||||
self.training_parameters = self.config.training_parameters
|
||||
|
||||
monitored_metric = self.training_parameters.monitored_metric
|
||||
metric_minimize = self.training_parameters.metric_minimize
|
||||
should_early_stop = self.training_parameters.should_early_stop
|
||||
patience = self.training_parameters.patience
|
||||
|
||||
self.log_interval = self.training_parameters.log_interval
|
||||
self.snapshot_interval = self.training_parameters.snapshot_interval
|
||||
self.max_iterations = self.training_parameters.max_iterations
|
||||
self.should_clip_gradients = self.training_parameters.clip_gradients
|
||||
self.max_epochs = self.training_parameters.max_epochs
|
||||
self.gradient_accumulation_steps = int(
|
||||
self.training_parameters.gradient_accumulation_steps)
|
||||
assert self.gradient_accumulation_steps >= 1
|
||||
for t_type in self.task_loader.task_type:
|
||||
if t_type == "train":
|
||||
self.dataset_train_order = self.training_parameters.get(
|
||||
"dataset_train_order", self.train_task.datasets_name)
|
||||
if t_type == "val":
|
||||
self.dataset_val_order = self.training_parameters.get(
|
||||
"dataset_val_order", self.val_task.datasets_name)
|
||||
if t_type == "test":
|
||||
self.dataset_test_order = self.training_parameters.get(
|
||||
"dataset_test_order", self.test_task.datasets_name)
|
||||
if t_type == "interpret":
|
||||
self.dataset_interpret_order = self.training_parameters.get(
|
||||
"dataset_interpret_order",
|
||||
self.interpret_task.datasets_name)
|
||||
|
||||
self.early_stopping = EarlyStopping(
|
||||
self.model,
|
||||
self.checkpoint,
|
||||
monitored_metric,
|
||||
patience=patience,
|
||||
minimize=metric_minimize,
|
||||
should_stop=should_early_stop,
|
||||
)
|
||||
self.current_epoch = 1
|
||||
self.current_iteration = 0
|
||||
|
||||
self.not_debug = self.training_parameters.logger_level != "debug"
|
||||
|
||||
self.lr_scheduler = None
|
||||
self.setup_lr_scheduler()
|
||||
|
||||
if self.checkpoint is not None:
|
||||
self.checkpoint.load_state_dict()
|
||||
|
||||
if "overall_metrics" in self.training_parameters:
|
||||
self.overall_metric_evaluator = Metrics(
|
||||
self.config.training_parameters.get("overall_metrics", []))
|
||||
self.synchronized_loss = self.config.training_parameters.synchronized_loss
|
||||
|
||||
def train(self):
|
||||
self.writer.write("===== Model =====")
|
||||
self.writer.write(self.model)
|
||||
self.writer.write(
|
||||
"Model Params: Trainable {Trainable:.3f}M Total {Total:.3f}M".
|
||||
format(**count_parameters(self.model)))
|
||||
|
||||
if "train" not in self.run_type:
|
||||
self.inference()
|
||||
return
|
||||
|
||||
should_break = False
|
||||
|
||||
if self.max_epochs is None:
|
||||
self.max_epochs = math.inf
|
||||
else:
|
||||
self.max_iterations = min(self.max_iterations,
|
||||
self.max_epochs * self.epoch_iterations)
|
||||
|
||||
self.model.train()
|
||||
self.train_timer = Timer()
|
||||
|
||||
self.profile("Setup Time")
|
||||
|
||||
if self.enable_torch_amp:
|
||||
self.writer.write("Using Automatic mixed precision training")
|
||||
if hasattr(self.config, "amp_attributes") and hasattr(
|
||||
self.config.amp_attributes, "growth_interval"):
|
||||
growth_interval = self.config.amp_attributes.growth_interval
|
||||
else:
|
||||
growth_interval = 2000
|
||||
self.scaler = torch.cuda.amp.GradScaler(
|
||||
init_scale=self.config.amp_attributes.init_scale,
|
||||
enabled=False,
|
||||
growth_interval=growth_interval)
|
||||
self.writer.write("Using Init scale:%s" %
|
||||
self.config.amp_attributes.init_scale)
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
self.writer.write("Starting training...")
|
||||
while self.current_iteration < self.max_iterations and not should_break:
|
||||
registry.register("current_epoch", self.current_epoch)
|
||||
self.task_loader.seed_sampler("train", self.current_epoch)
|
||||
|
||||
if self.current_epoch > self.max_epochs:
|
||||
break
|
||||
|
||||
for batch in tqdm(
|
||||
chain(*self.train_loader_list),
|
||||
total=self._len_of_loader_list(self.train_loader_list),
|
||||
disable=self.disable_tqdm or (not is_main_process()),
|
||||
):
|
||||
self.profile("Batch load time")
|
||||
report, _, _ = self._forward_pass(
|
||||
batch, enable_amp=self.enable_torch_amp)
|
||||
if report is None:
|
||||
continue
|
||||
|
||||
self._update_meter(report, self.meter)
|
||||
|
||||
loss = self._extract_loss(report)
|
||||
self._backward(loss)
|
||||
if hasattr(
|
||||
self.config.training_parameters,
|
||||
"ema") and self.config.training_parameters.ema is True:
|
||||
self.ema.update()
|
||||
should_break = self._logistics()
|
||||
|
||||
self._run_scheduler()
|
||||
|
||||
self.current_iteration += 1
|
||||
self.writer.write(self.current_iteration, "debug")
|
||||
registry.register("current_iteration", self.current_iteration)
|
||||
if self.current_iteration >= self.max_iterations:
|
||||
break
|
||||
if should_break:
|
||||
break
|
||||
|
||||
self.current_epoch += 1
|
||||
|
||||
self.finalize()
|
||||
|
||||
def _forward_pass(self, batch, enable_amp=False):
|
||||
if not batch: # Samplelist might be empty dict
|
||||
return None, None, None
|
||||
prepared_batch = self.task_loader.prepare_batch(batch)
|
||||
|
||||
self.profile("Batch prepare time")
|
||||
forward_context = torch.cuda.amp.autocast(
|
||||
enabled=True,
|
||||
dtype=torch.bfloat16) if enable_amp else nullcontext()
|
||||
|
||||
with forward_context:
|
||||
# Arguments should be a dict at this point
|
||||
model_output = self.model(prepared_batch)
|
||||
|
||||
if self.synchronized_loss:
|
||||
is_parallel = isinstance(
|
||||
self.model, nn.DataParallel) or isinstance(
|
||||
self.model, nn.parallel.DistributedDataParallel)
|
||||
if "losses" not in model_output:
|
||||
loss_func = getattr(
|
||||
self.model.module if is_parallel else self.model,
|
||||
"losses")
|
||||
model_output["losses"] = loss_func(
|
||||
prepared_batch,
|
||||
model_output,
|
||||
iteration=self.current_iteration)
|
||||
if "metrics" not in model_output:
|
||||
metric_func = getattr(
|
||||
self.model.module if is_parallel else self.model,
|
||||
"metrics")
|
||||
model_output["metrics"] = metric_func(
|
||||
prepared_batch, model_output)
|
||||
|
||||
report = Report(prepared_batch, model_output)
|
||||
self.profile("Forward time")
|
||||
|
||||
return report, model_output, prepared_batch
|
||||
|
||||
def _backward(self, loss):
|
||||
loss = loss / self.gradient_accumulation_steps
|
||||
|
||||
if self.enable_torch_amp:
|
||||
self.scaler.scale(loss).backward()
|
||||
|
||||
            # Unscales the gradients of the optimizer's assigned params in-place; this should
|
||||
# be called first so that clip_gradients can take effect as usual.
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
elif self.enable_atorch_amp:
|
||||
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
loss.backward()
|
||||
self.profile("Backward time")
|
||||
|
||||
if self.current_iteration % self.gradient_accumulation_steps != 0:
|
||||
return
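        # With gradient accumulation, the optimizer only steps every
        # gradient_accumulation_steps iterations; earlier iterations just
        # accumulate gradients and return here.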
|
||||
|
||||
if self.should_clip_gradients:
|
||||
if self.enable_atorch_amp:
|
||||
clip_gradients(amp.master_params(self.optimizer),
|
||||
self.current_iteration, self.writer,
|
||||
self.config)
|
||||
else:
|
||||
clip_gradients(self.model, self.current_iteration, self.writer,
|
||||
self.config)
|
||||
|
||||
if hasattr(
|
||||
self.config.training_parameters, "freeze_backbone_steps"
|
||||
) and self.config.training_parameters.freeze_backbone_steps is not None:
|
||||
cancel_gradients_backbone(
|
||||
self.current_iteration, self.model,
|
||||
self.config.training_parameters.freeze_backbone_steps)
|
||||
|
||||
if self.enable_torch_amp:
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
else:
|
||||
self.optimizer.step()
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
self.profile("Optimizer time")
|
||||
|
||||
def _logistics(self):
|
||||
should_print = self.current_iteration and self.current_iteration % self.log_interval == 0
|
||||
extra = {}
|
||||
prefix = ""
|
||||
|
||||
if should_print is True:
|
||||
if "cuda" in str(self.device):
|
||||
extra["max mem"] = torch.cuda.max_memory_allocated() / 1024
|
||||
extra["max mem"] //= 1024
|
||||
|
||||
# display lr
|
||||
if isinstance(self.optimizer, CombinedOptimizer):
|
||||
extra["lr"] = self.optimizer.get_optimizers_lr_str()
|
||||
else:
|
||||
extra["lr"] = "|".join([
|
||||
"{:.8f}".format(x["lr"]).rstrip("0")
|
||||
for x in self.optimizer.param_groups
|
||||
])
|
||||
|
||||
extra.update({
|
||||
"time": self.train_timer.get_time_since_start(),
|
||||
"eta": self._calculate_time_left(),
|
||||
})
|
||||
|
||||
self.train_timer.reset()
|
||||
|
||||
self._summarize_meter(
|
||||
self.meter,
|
||||
prefix=prefix,
|
||||
extra=extra,
|
||||
should_print=should_print,
|
||||
)
|
||||
|
||||
should_break = self._try_full_validation()
|
||||
|
||||
return should_break
|
||||
|
||||
def _try_full_validation(self, force=False):
|
||||
should_break = False
|
||||
|
||||
if self.current_iteration and self.current_iteration % self.snapshot_interval == 0 or force:
|
||||
self.writer.write(
|
||||
"Evaluation time. Running on full validation set...")
|
||||
|
||||
validation_timer = Timer()
|
||||
dataset_name, meter = self.evaluate_set(self.val_loader_list)
|
||||
extra = {
|
||||
"validation time": validation_timer.get_time_since_start()
|
||||
}
|
||||
|
||||
overall_metric = self.overall_metric_evaluator.summarize()
|
||||
stop = self.early_stopping(self.current_iteration, overall_metric,
|
||||
meter)
|
||||
if hasattr(self.config.training_parameters,
|
||||
"ema") and self.config.training_parameters.ema is True:
|
||||
self.ema.restore()
|
||||
stop = bool(broadcast_scalar(stop, src=0, device=self.device))
|
||||
|
||||
extra.update(self.early_stopping.get_info())
|
||||
|
||||
prefix = "{}: full val".format(dataset_name)
|
||||
self._summarize_overall(overall_metric,
|
||||
meter,
|
||||
prefix=prefix,
|
||||
extra=extra)
|
||||
gc.collect()
|
||||
|
||||
if "cuda" in str(self.device):
|
||||
with torch.cuda.device(self.device):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
            if stop > 0:  # `stop` is broadcast as an int since NCCL cannot broadcast booleans
|
||||
self.writer.write("Early stopping activated")
|
||||
should_break = True
|
||||
|
||||
return should_break
|
||||
|
||||
def evaluate_set(self, loader_list):
|
||||
from antmmf.structures import SampleList
|
||||
|
||||
meter = Meter()
|
||||
torch.cuda.empty_cache()
|
||||
with torch.no_grad():
|
||||
self.model.eval()
|
||||
if hasattr(self.config.training_parameters,
|
||||
"ema") and self.config.training_parameters.ema is True:
|
||||
self.ema.apply_shadow()
|
||||
if self.config.training_parameters.get('fp16', False):
|
||||
self.model.half()
|
||||
self.overall_metric_evaluator.reset()
|
||||
for idx, batch in tqdm(
|
||||
enumerate(chain(*loader_list)),
|
||||
total=self._len_of_loader_list(loader_list),
|
||||
disable=not is_main_process() or self.disable_tqdm,
|
||||
):
|
||||
# report, model_output, prepared_batch = self._forward_pass(
|
||||
# batch, enable_amp=self.enable_torch_amp)
|
||||
if idx >= self.config.training_parameters.get('num_eval', 1e7):
|
||||
break
|
||||
if self.config.training_parameters.get('fp16', False):
|
||||
input_dict = SampleList()
|
||||
for k, v in batch.items():
|
||||
if isinstance(v, torch.cuda.FloatTensor):
|
||||
input_dict[k] = v.half()
|
||||
else:
|
||||
input_dict[k] = v
|
||||
report, model_output, prepared_batch = self._forward_pass(
|
||||
input_dict, enable_amp=self.enable_torch_amp)
|
||||
else:
|
||||
report, model_output, prepared_batch = self._forward_pass(
|
||||
batch, enable_amp=self.enable_torch_amp)
|
||||
self._update_meter(report, meter)
|
||||
self.overall_metric_evaluator.collect(prepared_batch,
|
||||
model_output)
|
||||
for _, metric_object in self.overall_metric_evaluator.metrics.items(
|
||||
):
|
||||
metric_object.all_reduce()
|
||||
self.model.train()
|
||||
|
||||
return report.dataset_name, meter
|
||||
135
lib/utils/checkpoint.py
Normal file
@@ -0,0 +1,135 @@
|
||||
# Copyright (c) Ant Financial Service Group. and its affiliates.
|
||||
import os
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
|
||||
from antmmf.common import constants
|
||||
from antmmf.common.registry import registry
|
||||
from antmmf.common.checkpoint import Checkpoint
|
||||
from antmmf.utils.distributed_utils import is_main_process
|
||||
|
||||
class SegCheckpoint(Checkpoint):
|
||||
def __init__(self, trainer, load_only=False):
|
||||
        super().__init__(trainer, load_only=load_only)  # forward the flag instead of hard-coding False
|
||||
|
||||
def load_model_weights(self, file, force=False):
|
||||
self.trainer.writer.write("Loading checkpoint")
|
||||
ckpt = self._torch_load(file)
|
||||
if registry.get(constants.STATE) is constants.STATE_ONLINE_SERVING:
|
||||
data_parallel = False
|
||||
else:
|
||||
data_parallel = registry.get("data_parallel") or registry.get(
|
||||
"distributed")
|
||||
|
||||
if "model" in ckpt:
|
||||
ckpt_model = ckpt["model"]
|
||||
else:
|
||||
ckpt_model = ckpt
|
||||
ckpt = {"model": ckpt}
|
||||
|
||||
new_dict = {}
|
||||
|
||||
# TODO: Move to separate function
|
||||
for attr in ckpt_model:
|
||||
if "fa_history" in attr:
|
||||
new_dict[attr.replace("fa_history",
|
||||
"fa_context")] = ckpt_model[attr]
|
||||
elif data_parallel is False and attr.startswith("module."):
|
||||
new_k = attr.replace("module.", "", 1)
|
||||
if '.Wqkv.' in new_k:
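                # Rename fused QKV projection weights ('.Wqkv.', as used by
                # flash-attention style blocks) to torch MultiheadAttention's
                # '.in_proj_' naming (an assumption based on this mapping).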
|
||||
new_k = new_k.replace('.Wqkv.', '.in_proj_')
|
||||
|
||||
new_dict[new_k] = ckpt_model[attr]
|
||||
elif data_parallel is not False and not attr.startswith("module."):
|
||||
new_dict["module." + attr] = ckpt_model[attr]
|
||||
elif data_parallel is False and not attr.startswith("module."):
|
||||
print('data_parallel is False and not attr!!!')
|
||||
new_k = attr
|
||||
if '.Wqkv.' in new_k:
|
||||
new_k = new_k.replace('.Wqkv.', '.in_proj_')
|
||||
new_dict[new_k] = ckpt_model[attr]
|
||||
else:
|
||||
new_dict[attr] = ckpt_model[attr]
|
||||
print(new_dict.keys())
|
||||
self._load_state_dict(new_dict)
|
||||
self._load_model_weights_with_mapping(new_dict, force=force)
|
||||
print(f'load weight: {file} done!')
|
||||
return ckpt
|
||||
|
||||
def _load(self, file, force=False, resume_state=False):
|
||||
ckpt = self.load_model_weights(file, force=force)
|
||||
|
||||
# skip loading training state
|
||||
if resume_state is False:
|
||||
return
|
||||
|
||||
if "optimizer" in ckpt:
|
||||
try:
|
||||
self.trainer.optimizer.load_state_dict(ckpt["optimizer"])
|
||||
# fix the bug of checkpoint in the pytorch with version higher than 1.11
|
||||
if "capturable" in self.trainer.optimizer.param_groups[0]:
|
||||
self.trainer.optimizer.param_groups[0]["capturable"] = True
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
else:
|
||||
warnings.warn(
|
||||
"'optimizer' key is not present in the checkpoint asked to be loaded. Skipping."
|
||||
)
|
||||
|
||||
if "lr_scheduler" in ckpt:
|
||||
self.trainer.lr_scheduler.load_state_dict(ckpt["lr_scheduler"])
|
||||
else:
|
||||
warnings.warn(
|
||||
"'lr_scheduler' key is not present in the checkpoint asked to be loaded. Skipping."
|
||||
)
|
||||
|
||||
self.trainer.early_stopping.init_from_checkpoint(ckpt)
|
||||
|
||||
self.trainer.writer.write("Checkpoint {} loaded".format(file))
|
||||
|
||||
if "current_iteration" in ckpt:
|
||||
self.trainer.current_iteration = ckpt["current_iteration"]
|
||||
registry.register("current_iteration",
|
||||
self.trainer.current_iteration)
|
||||
|
||||
if "current_epoch" in ckpt:
|
||||
self.trainer.current_epoch = ckpt["current_epoch"]
|
||||
registry.register("current_epoch", self.trainer.current_epoch)
|
||||
|
||||
def save(self, iteration, update_best=False):
|
||||
if not is_main_process():
|
||||
return
|
||||
|
||||
ckpt_filepath = os.path.join(self.models_foldername,
|
||||
"model_%d.ckpt" % iteration)
|
||||
best_ckpt_filepath = os.path.join(self.ckpt_foldername,
|
||||
self.ckpt_prefix + "best.ckpt")
|
||||
|
||||
best_iteration = self.trainer.early_stopping.best_monitored_iteration
|
||||
best_metric = self.trainer.early_stopping.best_monitored_value
|
||||
current_iteration = self.trainer.current_iteration
|
||||
current_epoch = self.trainer.current_epoch
|
||||
model = self.trainer.model
|
||||
data_parallel = registry.get("data_parallel") or registry.get(
|
||||
"distributed")
|
||||
|
||||
if data_parallel is True:
|
||||
model = model.module
|
||||
|
||||
ckpt = {
|
||||
"model": model.state_dict(),
|
||||
"optimizer": self.trainer.optimizer.state_dict(),
|
||||
"lr_scheduler": self.trainer.lr_scheduler.state_dict(),
|
||||
"current_iteration": current_iteration,
|
||||
"current_epoch": current_epoch,
|
||||
"best_iteration": best_iteration,
|
||||
"best_metric_value": best_metric,
|
||||
}
|
||||
|
||||
torch.save(ckpt, ckpt_filepath)
|
||||
self.remove_redundant_ckpts()
|
||||
|
||||
if update_best:
|
||||
torch.save(ckpt, best_ckpt_filepath)
|
||||
122
lib/utils/optim_utils.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import torch
|
||||
from torch.nn import LayerNorm, Linear, GELU
|
||||
from torch.nn import MultiheadAttention, Sequential
|
||||
import warnings
|
||||
try:
|
||||
from atorch.normalization import LayerNorm as FastLayerNorm
|
||||
from atorch.modules.transformer.inject import replace_module
|
||||
from atorch.modules.transformer.layers import MultiheadAttentionFA, BertAttentionFA
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
warnings.warn("Using replace_speedup_op but no atorch/apex installed:%s" % e)
|
||||
try:
|
||||
from transformers.models.bert.modeling_bert import BertAttention
|
||||
replace_transformer_bert = True
|
||||
|
||||
except ImportError:
|
||||
replace_transformer_bert = False
|
||||
|
||||
|
||||
class DefaultStrategy:
|
||||
replace_mha = True
|
||||
replace_layernorm = True
|
||||
replace_linear_gelu = False # TODO: numerical consistency
|
||||
|
||||
|
||||
def replace_layer_norm(module: torch.nn.Module, cur_name: str):
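    # Recursively walk the module tree and swap every torch.nn.LayerNorm for
    # atorch's fused FastLayerNorm, copying the affine weights over.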
|
||||
|
||||
for name, child in module.named_children():
|
||||
child_name = cur_name + "." + name
|
||||
if isinstance(child, LayerNorm):
|
||||
new_module = FastLayerNorm(child.normalized_shape, eps=child.eps)
|
||||
new_module.load_state_dict(child.state_dict())
|
||||
setattr(module, name, new_module)
|
||||
else:
|
||||
replace_layer_norm(child, child_name)
|
||||
|
||||
def is_atorch_available(raise_error=True, log=None):
|
||||
try:
|
||||
import atorch # noqa: F401
|
||||
|
||||
return True
|
||||
except ImportError as e:
|
||||
if raise_error is True:
|
||||
raise ImportError(e, log)
|
||||
else:
|
||||
return False
|
||||
|
||||
def _cast_if_autocast_enabled(*args):
|
||||
if not torch.is_autocast_enabled():
|
||||
return args
|
||||
else:
|
||||
return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype())
|
||||
|
||||
|
||||
def _fused_dense_gelu_dense(input, weight1, bias1, weight2, bias2):
|
||||
batch, seq_length, hidden_size = input.size()
|
||||
input = input.view(batch * seq_length, hidden_size)
|
||||
args = _cast_if_autocast_enabled(input, weight1, bias1, weight2, bias2)
|
||||
from apex.fused_dense import FusedDenseGeluDenseFunc # with cast
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
out = FusedDenseGeluDenseFunc.apply(*args)
|
||||
out = out.view(batch, seq_length, -1)
|
||||
return out
|
||||
|
||||
|
||||
def linear_gelu_forward(input_, weight1, bias1, weight2, bias2):
|
||||
return _fused_dense_gelu_dense(input_, weight1, bias1, weight2, bias2)
|
||||
|
||||
|
||||
def replace_linear_gelu(module, cur_name: str):
|
||||
"""
|
||||
(layers): Sequential(
|
||||
(0): Sequential(
|
||||
(0): Linear(in_features=1536, out_features=6144, bias=True)
|
||||
(1): GELU()
|
||||
(2): Dropout(p=0.0, inplace=False)
|
||||
)
|
||||
(1): Linear(in_features=6144, out_features=1536, bias=True)
|
||||
(2): Dropout(p=0.0, inplace=False)
|
||||
)
|
||||
"""
|
||||
for name, child in module.named_children():
|
||||
child_name = cur_name + "." + name
|
||||
if isinstance(child, Sequential):
|
||||
if len(child) >= 2 and isinstance(
|
||||
child[0], Sequential
|
||||
) and isinstance(
|
||||
child[1], Linear
|
||||
) and len(child[0]
|
||||
) >= 2 and isinstance(
|
||||
child[0][0], Linear
|
||||
) and isinstance(
|
||||
child[0][1], GELU
|
||||
): # Sequential+Linear
|
||||
linear0 = child[0][0]
|
||||
linear1 = child[1]
|
||||
if getattr(child, "replace_linear_gelu", False):
|
||||
continue
|
||||
                # Bind linear0/linear1 as default args so each replaced child keeps
                # its own layers (avoids late binding of the loop variables).
                child.forward = lambda x, l0=linear0, l1=linear1: linear_gelu_forward(
                    x, l0.weight, l0.bias, l1.weight, l1.bias)
|
||||
child.replace_linear_gelu = True
|
||||
print("REPLACE linear+gelu:%s" % child_name)
|
||||
# setattr(module, name, new_module)
|
||||
else:
|
||||
replace_linear_gelu(child, child_name)
|
||||
|
||||
|
||||
def replace_speedup_op(model, strategy=DefaultStrategy):
|
||||
if not is_atorch_available(raise_error=False):
|
||||
raise ImportError("Install Atorch/apex before using speedup op")
|
||||
if strategy.replace_mha:
|
||||
model = replace_module(model, MultiheadAttention, MultiheadAttentionFA, need_scr_module=True)
|
||||
if replace_transformer_bert:
|
||||
model = replace_module(model, BertAttention, BertAttentionFA, need_scr_module=True)
|
||||
root_name = model.__class__.__name__
|
||||
if strategy.replace_layernorm:
|
||||
replace_layer_norm(model, root_name) # inplace
|
||||
if strategy.replace_linear_gelu:
|
||||
replace_linear_gelu(model, root_name)
|
||||
return model
|
||||
|
||||
# TODO:
|
||||
# 1. SyncBatchNorm
|
||||
243
lib/utils/utils.py
Normal file
@@ -0,0 +1,243 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def cosine_scheduler(base_value,
|
||||
final_value,
|
||||
all_iters,
|
||||
warmup_iters=0,
|
||||
start_warmup_value=0):
|
||||
warmup_schedule = np.array([])
|
||||
if warmup_iters > 0:
|
||||
warmup_schedule = np.linspace(start_warmup_value, base_value,
|
||||
warmup_iters)
|
||||
|
||||
iters = np.arange(all_iters - warmup_iters)
|
||||
schedule = final_value + 0.5 * (base_value - final_value) * (
|
||||
1 + np.cos(np.pi * iters / len(iters)))
|
||||
|
||||
schedule = np.concatenate((warmup_schedule, schedule))
|
||||
assert len(schedule) == all_iters
|
||||
return schedule
|
||||
|
||||
|
||||
def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
|
||||
if epoch >= freeze_last_layer:
|
||||
return
|
||||
for n, p in model.named_parameters():
|
||||
if "last_layer" in n:
|
||||
p.grad = None
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
"""
|
||||
embed_dim: output dimension for each position
|
||||
pos: a list of positions to be encoded: size (M,)
|
||||
out: (M, D)
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = np.arange(embed_dim // 2, dtype=float)
|
||||
omega /= embed_dim / 2.
|
||||
omega = 1. / 10000**omega # (D/2,)
|
||||
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
||||
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
|
||||
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
return emb
|
||||
|
||||
|
||||
def cancel_gradients_backbone(iteration, model, freeze_backbone_steps):
|
||||
if iteration >= freeze_backbone_steps:
|
||||
return
|
||||
for n, p in model.named_parameters():
|
||||
if "backbon_hr" in n or 'backbon_s2' in n or 'head_s2' in n or 'fusion' in n or 'ctpe' in n:
|
||||
p.grad = None
|
||||
|
||||
|
||||
class EMA():
|
||||
|
||||
def __init__(self, model, decay):
|
||||
self.model = model
|
||||
self.decay = decay
|
||||
self.shadow = {}
|
||||
self.backup = {}
|
||||
|
||||
def register(self):
|
||||
for name, param in self.model.named_parameters():
|
||||
if param.requires_grad:
|
||||
self.shadow[name] = param.data.clone()
|
||||
|
||||
def update(self):
|
||||
for name, param in self.model.named_parameters():
|
||||
if param.requires_grad:
|
||||
assert name in self.shadow
|
||||
new_average = (1.0 - self.decay
|
||||
) * param.data + self.decay * self.shadow[name]
|
||||
self.shadow[name] = new_average.clone()
|
||||
|
||||
def apply_shadow(self):
|
||||
for name, param in self.model.named_parameters():
|
||||
if param.requires_grad:
|
||||
assert name in self.shadow
|
||||
self.backup[name] = param.data
|
||||
param.data = self.shadow[name]
|
||||
|
||||
def restore(self):
|
||||
for name, param in self.model.named_parameters():
|
||||
if param.requires_grad:
|
||||
assert name in self.backup
|
||||
param.data = self.backup[name]
|
||||
self.backup = {}
|
||||
|
||||
|
||||
class LayerDecayValueAssigner(object):
|
||||
|
||||
def __init__(self, layer_decay, num_layers, base_lr, net_type, arch='huge'):
|
||||
assert net_type in ['swin', 'vit']
|
||||
assert 0 < layer_decay <= 1
|
||||
depths_dict = {
|
||||
'tiny': [2, 2, 6, 2],
|
||||
'small': [2, 2, 18, 2],
|
||||
'base': [2, 2, 18, 2],
|
||||
'large': [2, 2, 18, 2],
|
||||
'huge': [2, 2, 18, 2],
|
||||
'giant': [2, 2, 42, 4],
|
||||
}
|
||||
num_layers = num_layers if net_type == 'vit' else sum(depths_dict[arch])
|
||||
self.layer_decay = layer_decay
|
||||
self.values = list(layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2))
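        # Geometric layer-wise LR decay: layer i is scaled by
        # layer_decay ** (num_layers + 1 - i), so the earliest layers get the
        # smallest learning rate and the final group gets base_lr. For example,
        # with layer_decay = 0.9 and num_layers = 24, layer 0 is scaled by
        # 0.9 ** 25, roughly 0.072 of the base learning rate.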
|
||||
self.depths = depths_dict[arch]
|
||||
self.base_lr = base_lr
|
||||
self.net_type = net_type
|
||||
|
||||
def get_num_layer_for_vit(self, var_name):
|
||||
if var_name in ("cls_token", "mask_token", "pos_embed"):
|
||||
return 0
|
||||
elif var_name.startswith("patch_embed"):
|
||||
return 0
|
||||
elif var_name.startswith("layers"):
|
||||
layer_id = int(var_name.split('.')[1])
|
||||
return layer_id + 1
|
||||
else:
|
||||
return len(self.values) - 1
|
||||
|
||||
def get_num_layer_for_swin(self, var_name):
|
||||
if var_name in ("mask_token", "pos_embed"):
|
||||
return 0
|
||||
elif var_name.startswith("patch_embed"):
|
||||
return 0
|
||||
elif var_name.startswith("stages"):
|
||||
layer_id = int(var_name.split('.')[1])
|
||||
if 'blocks' in var_name:
|
||||
block_id = int(var_name.split('.')[3])
|
||||
else:
|
||||
block_id = self.depths[layer_id] - 1
|
||||
layer_id = sum(self.depths[:layer_id]) + block_id
|
||||
return layer_id + 1
|
||||
else:
|
||||
return len(self.values) - 1
|
||||
|
||||
def get_layer_id(self, var_name):
|
||||
if self.net_type == 'swin':
|
||||
return self.get_num_layer_for_swin(var_name)
|
||||
if self.net_type == 'vit':
|
||||
return self.get_num_layer_for_vit(var_name)
|
||||
|
||||
def fix_param(self, model, num_block=4):
|
||||
if num_block < 1:
|
||||
return 0
|
||||
frozen_num = 0
|
||||
if self.net_type == 'swin':
|
||||
for name, param in model.named_parameters():
|
||||
if name.startswith("patch_embed"):
|
||||
param.requires_grad = False
|
||||
frozen_num += 1
|
||||
if name.startswith("stages") and self.get_layer_id(name) <= num_block:
|
||||
param.requires_grad = False
|
||||
frozen_num += 1
|
||||
if self.net_type == 'vit':
|
||||
for name, param in model.named_parameters():
|
||||
if name.startswith("patch_embed"):
|
||||
param.requires_grad = False
|
||||
frozen_num += 1
|
||||
if name.startswith("layers") and self.get_layer_id(name) <= num_block:
|
||||
param.requires_grad = False
|
||||
frozen_num += 1
|
||||
return frozen_num
|
||||
|
||||
def fix_param_deeper(self, model, num_block=4):
|
||||
if num_block < 1:
|
||||
return 0
|
||||
frozen_num = 0
|
||||
if self.net_type == 'swin':
|
||||
            raise ValueError('fix_param_deeper is not supported for swin backbones')
|
||||
if self.net_type == 'vit':
|
||||
for name, param in model.named_parameters():
|
||||
if name.startswith("patch_embed"):
|
||||
param.requires_grad = False
|
||||
frozen_num += 1
|
||||
if name.startswith("layers") and self.get_layer_id(name) >= num_block:
|
||||
param.requires_grad = False
|
||||
frozen_num += 1
|
||||
return frozen_num
|
||||
|
||||
def get_parameter_groups(self, model, weight_decay):
|
||||
parameter_groups_with_wd, parameter_groups_without_wd = [], []
|
||||
print_info_with_wd, print_info_without_wd = [], []
|
||||
no_decay = [
|
||||
"absolute_pos_embed", "relative_position_bias_table", "norm", "bias"
|
||||
]
|
||||
if self.layer_decay == 1:
|
||||
parameter_groups_with_wd.append(
|
||||
{"params": [], "weight_decay": weight_decay, "lr": self.base_lr}
|
||||
)
|
||||
print_info_with_wd.append(
|
||||
{"params": [], "weight_decay": weight_decay, "lr": self.base_lr}
|
||||
)
|
||||
parameter_groups_without_wd.append(
|
||||
{"params": [], "weight_decay": 0, "lr": self.base_lr}
|
||||
)
|
||||
print_info_without_wd.append(
|
||||
{"params": [], "weight_decay": 0, "lr": self.base_lr}
|
||||
)
|
||||
else:
|
||||
for scale in self.values:
|
||||
parameter_groups_with_wd.append(
|
||||
{"params": [], "weight_decay": weight_decay, "lr": scale * self.base_lr}
|
||||
)
|
||||
print_info_with_wd.append(
|
||||
{"params": [], "weight_decay": weight_decay, "lr": scale * self.base_lr}
|
||||
)
|
||||
parameter_groups_without_wd.append(
|
||||
{"params": [], "weight_decay": 0, "lr": scale * self.base_lr}
|
||||
)
|
||||
print_info_without_wd.append(
|
||||
{"params": [], "weight_decay": 0, "lr": scale * self.base_lr}
|
||||
)
|
||||
for name, param in model.named_parameters():
|
||||
if not param.requires_grad:
|
||||
print(f'frozen param: {name}')
|
||||
continue # frozen weights
|
||||
layer_id = self.get_layer_id(name) if self.layer_decay < 1 else 0
|
||||
if any(nd in name for nd in no_decay):
|
||||
parameter_groups_without_wd[layer_id]['params'].append(param)
|
||||
print_info_without_wd[layer_id]['params'].append(name)
|
||||
else:
|
||||
parameter_groups_with_wd[layer_id]['params'].append(param)
|
||||
print_info_with_wd[layer_id]['params'].append(name)
|
||||
parameter_groups_with_wd = [x for x in parameter_groups_with_wd if len(x['params']) > 0]
|
||||
parameter_groups_without_wd = [x for x in parameter_groups_without_wd if len(x['params']) > 0]
|
||||
print_info_with_wd = [x for x in print_info_with_wd if len(x['params']) > 0]
|
||||
print_info_without_wd = [x for x in print_info_without_wd if len(x['params']) > 0]
|
||||
if self.layer_decay < 1:
|
||||
for wd, nwd in zip(print_info_with_wd, print_info_without_wd):
|
||||
print(wd)
|
||||
print(nwd)
|
||||
parameter_groups = []
|
||||
parameter_groups.extend(parameter_groups_with_wd)
|
||||
parameter_groups.extend(parameter_groups_without_wd)
|
||||
return parameter_groups
|
||||
|
||||