commit 01adcfdf60
esenke, 2025-12-08 22:16:31 +08:00
305 changed files with 50879 additions and 0 deletions

lib/__init__.py (new file, +5 lines)

@@ -0,0 +1,5 @@
from .datasets import *
from .models import *
from .predictors import *
from .trainer import *
from .task import *

lib/datasets/__init__.py (new file, +3 lines)

@@ -0,0 +1,3 @@
from .builder import PretrainingBuilder
__all__ = ["PretrainingBuilder"]

lib/datasets/builder.py (new file, +18 lines)

@@ -0,0 +1,18 @@
from antmmf.common.registry import registry
from antmmf.datasets.base_dataset_builder import BaseDatasetBuilder
from .loader.pretraining_loader import PretrainingLoader
@registry.register_builder("pretraining_loader")
class PretrainingBuilder(BaseDatasetBuilder):
def __init__(self):
super().__init__("pretraining_loader")
def _build(self, dataset_type, config, *args, **kwargs):
return None
def _load(self, dataset_type, config, *args, **kwargs):
self.dataset = PretrainingLoader(dataset_type, config)
return self.dataset
def update_registry_for_model(self, config):
pass
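
For readers unfamiliar with antmmf, the @registry.register_builder decorator above records the builder class under the string key "pretraining_loader" so the framework can instantiate it from configuration. A minimal, self-contained sketch of that registration pattern (an illustrative stand-in, not antmmf's actual implementation):

_BUILDERS = {}  # stand-in registry: maps a string key to a builder class

def register_builder(name):
    def _wrap(cls):
        _BUILDERS[name] = cls
        return cls
    return _wrap

@register_builder("pretraining_loader")
class DummyBuilder:
    def load(self, dataset_type, config):
        return f"loading '{dataset_type}' split"

builder = _BUILDERS["pretraining_loader"]()  # looked up purely by its config string
print(builder.load("train", config=None))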


@@ -0,0 +1,289 @@
import os
import json
import datetime
import random
import itertools
import time
import numpy as np
import torch
import torch.nn.functional as F
from antmmf.structures import Sample
from antmmf.datasets.base_dataset import BaseDataset
from antmmf.common import Configuration
from lib.datasets.utils.transforms import Compose, MSNormalize
from lib.datasets.utils.formatting import ToTensor
import lib.datasets.utils.pair_trainsforms as pair_transforms
from skimage import io
from osgeo import gdal
from PIL import Image
class FewShotFloodLoader(BaseDataset):
DATASET_NAME = "few_shot_flood_loader"
def __init__(self, dataset_type, config):
super().__init__(self.__class__.DATASET_NAME, dataset_type, config)
if dataset_type == 'train':
raise ValueError('train mode is not supported')
self.root = config.data_root_dir
self.dataset_type = dataset_type
self.img_dir = config.img_dir
self.tgt_dir = config.tgt_dir
with open(config.data_txt, 'r') as f:
test_list = f.readlines()
self.test_pairs = []
self.cls2path = {}
for i in test_list:
i = i.strip()
if i == '':
continue
img_path = i[:-3]
cls = int(i[-2:])
self.test_pairs.append(
{'hr_path': img_path,
'class': cls,
'tgt_path': img_path.replace('_', '_lab_', 1).replace('.jpg', '.png')
})
if cls in self.cls2path.keys():
self.cls2path[cls].append({'hr_path': img_path, 'tgt_path': img_path.replace('_', '_lab_', 1).replace('.jpg', '.png'), 'class': cls})
else:
self.cls2path[cls] = [{'hr_path': img_path, 'tgt_path': img_path.replace('_', '_lab_', 1).replace('.jpg', '.png'), 'class': cls}]
self.seq_len = config.seq_len # ts
self.hr_size = config.image_size.hr
self.s2_size = config.image_size.s2
self.s1_size = config.image_size.s1
self.anno_size = config.image_size.anno
self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406])
self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]) # keep ImageNet stats for now
self.config = config
self.pipeline = self._get_pipline()
# self.crop_resize = pair_transforms.RandomResizedCropComb(512, scale=(0.99, 1.0), interpolation=3)
def __len__(self) -> int:
return len(self.test_pairs)
def _combine_two_images(self, image, image2):
dst = torch.cat([image, image2], dim=-2)
return dst
def _get_pipline(self):
if self.dataset_type == 'val' or self.dataset_type == 'test':
pipeline = [
pair_transforms.ToTensor(),
pair_transforms.RandomResizedCrop(512, scale=(0.9999, 1.0), interpolation=3),
pair_transforms.Normalize(),
]
else:
raise ValueError('dataset_type not support')
return pair_transforms.Compose(pipeline)
def _load_data(self, data_path):
file_name, file_extension = os.path.splitext(data_path)
if file_extension == '.npz' or file_extension == '.npy':
npz_key = self.config.get('npz_key', 'image')
data = np.load(data_path)[npz_key]
elif file_extension == '.png' or file_extension == '.jpg':
data = io.imread(data_path)
if len(data.shape) == 3:
data = data.transpose(2, 0, 1)
elif file_extension == '.tiff' or file_extension == '.tif':
dataset = gdal.Open(data_path)
if dataset is None:
raise IOError(f'can not open file: {data_path}')
data = dataset.ReadAsArray()
dataset = None
else:
raise ValueError(f'file type {data_path} not support')
# check nan
if np.isnan(data).any():
print(f'{data_path} contains NaN, replacing with 0!')
data[np.isnan(data)] = 0
return data
def load_s2(self, pair):
if 'l8_path' in pair.keys():
pair['s2_path'] = pair['l8_path']
if 's2_path' in pair.keys() and not self.config.get('masking_s2', False):
with_s2 = True
if isinstance(pair['s2_path'], list):
if True: # len(pair['s2_path']) > self.seq_len:
s2_path_list = np.random.choice(pair['s2_path'], self.seq_len)
s2_path_list = sorted(s2_path_list)
else:
s2_path_list = pair['s2_path']
s2_list = []
s2_ct_1 = []
for s2_path in s2_path_list:
s2 = self._load_data(os.path.join(self.root, s2_path)) # [:10]
s2_list.append(s2)
ct = os.path.splitext(s2_path)[0].split('_')
ct = ct[3] # + ct[-3] + '01'
try:
ct = datetime.datetime.strptime(ct, '%Y%m%d')
except:
ct = datetime.datetime.strptime(ct, '%Y-%m-%d')
ct = ct.timetuple()
ct = ct.tm_yday - 1
s2_ct_1.append(ct)
s2_1 = np.stack(s2_list, axis=1)
else:
s2 = np.load(os.path.join(self.root, pair['s2_path']))['image']
date = np.load(os.path.join(self.root, pair['s2_path']))['date']
if True: # s2.shape[0] > self.seq_len:
selected_indices = np.random.choice(s2.shape[0], size=self.seq_len, replace=False)
selected_indices = sorted(selected_indices)
s2 = s2[selected_indices, :, :, :]
date = date[selected_indices]
s2_1 = s2.transpose(1, 0, 2, 3) # ts, c, h, w -> c, ts, h, w
s2_ct_1 = []
for ct in date:
try:
ct = datetime.datetime.strptime(ct, '%Y%m%d')
except:
ct = datetime.datetime.strptime(ct, '%Y-%m-%d')
ct = ct.timetuple()
ct = ct.tm_yday - 1
s2_ct_1.append(ct)
else:
with_s2 = False
s2_1 = np.zeros((10, self.seq_len, self.s2_size[0], self.s2_size[1]),
dtype=np.int16)
s2_ct_1 = [0] * self.seq_len
return with_s2, s2_1, s2_ct_1
def load_s1(self, pair):
if 's1_path' in pair.keys():
with_s1 = True
if isinstance(pair['s1_path'], list):
if True: # len(pair['s1_path']) > self.seq_len:
s1_path_list = np.random.choice(pair['s1_path'], self.seq_len)
s1_path_list = sorted(s1_path_list)
else:
s1_path_list = pair['s1_path']
s1_list = []
for s1_path in s1_path_list:
s1 = self._load_data(os.path.join(self.root, s1_path))
s1_list.append(s1)
s1_1 = np.stack(s1_list, axis=1)
else:
s1 = self._load_data(os.path.join(self.root, pair['s1_path']))
if True: # s1.shape[0] > self.seq_len:
selected_indices = np.random.choice(s1.shape[0], size=self.seq_len, replace=False)
selected_indices = sorted(selected_indices)
s1 = s1[selected_indices, :, :, :]
s1_1 = s1.transpose(1, 0, 2, 3) # ts, c, h, w -> c, ts, h, w
else:
with_s1 = False
s1_1 = np.zeros((2, self.seq_len, self.s1_size[0], self.s1_size[1]),
dtype=np.float32)
return with_s1, s1_1
def load_hr(self, pair):
if 'hr_path' in pair.keys():
with_hr = True
hr = self._load_data(os.path.join(self.root, pair['hr_path']))
else:
with_hr = False
hr = np.zeros((3, self.hr_size[0], self.hr_size[1]),
dtype=np.uint8)
return with_hr, hr
def load_tgt(self, pair):
targets = self._load_data(os.path.join(self.root, pair['target_path']))
return targets
def get_item(self, idx):
pair = self.test_pairs[idx]
test_class = pair['class']
current_dataset = 'flood3i'
with_hr = True
with_s2 = False
with_s1 = False
input_hr = io.imread(os.path.join(self.img_dir, pair['hr_path']))
input_hr = input_hr.transpose(2,0,1)
_, input_s2,_ = self.load_s2(pair)
_, input_s1 = self.load_s1(pair)
input_tgt = io.imread(os.path.join(self.tgt_dir, pair['tgt_path']))
modality_dict = {
's2': with_s2,
's1': with_s1,
'hr': with_hr
}
input_tgt[input_tgt != test_class] = 0
input_tgt[input_tgt == test_class] = 255
input_tgt = np.concatenate((input_tgt[None, :,:],)*3, axis=0)
input_hr, input_s2, input_s1, input_tgt = self.pipeline(current_dataset, input_hr, input_s2, input_s1,
input_tgt)
while True:
sel_prompt = random.choice(self.cls2path[test_class])
if sel_prompt['hr_path'] != pair['hr_path']:
break
prompt_hr = io.imread(os.path.join(self.img_dir, sel_prompt['hr_path']))
prompt_hr = prompt_hr.transpose(2,0,1)
_, prompt_s2,_ = self.load_s2(pair)
_, prompt_s1 = self.load_s1(pair)
prompt_tgt = io.imread(os.path.join(self.tgt_dir, sel_prompt['tgt_path']))
prompt_tgt[prompt_tgt != test_class] = 0
prompt_tgt[prompt_tgt == test_class] = 255
prompt_tgt = np.concatenate((prompt_tgt[None, :,:],)*3, axis=0)
prompt_hr, prompt_s2, prompt_s1, prompt_tgt = self.pipeline(current_dataset, prompt_hr, prompt_s2, prompt_s1, prompt_tgt)
targets_comb = self._combine_two_images(prompt_tgt, input_tgt)
hr_comb = self._combine_two_images(prompt_hr, input_hr)
s2_comb = self._combine_two_images(prompt_s2, input_s2)
s1_comb = self._combine_two_images(prompt_s1, input_s1)
valid = torch.ones_like(targets_comb)
thres = torch.ones(3) * 1e-5 # ignore black
thres = (thres - self.imagenet_mean) / self.imagenet_std
valid[targets_comb < thres[:, None, None]] = 0
mask_shape = (int(self.config.mim.input_size[0] / self.config.mim.patch_size),
int(self.config.mim.input_size[1] / self.config.mim.patch_size))
mask = np.zeros(mask_shape, dtype=np.int32)
mask[mask.shape[0] // 2:, :] = 1
geo_location = pair["location"] if "location" in pair.keys() else None
modality_idx = 2 ** 0 * modality_dict['s2'] + 2 ** 1 * modality_dict['s1'] + 2 ** 2 * modality_dict['hr']
modality_flag_s2 = modality_dict['s2']
modality_flag_s1 = modality_dict['s1']
modality_flag_hr = modality_dict['hr']
current_sample = Sample()
current_sample.img_name = pair["tgt_path"].split('/')[-1].split('.')[0] + '-' +str(test_class)
current_sample.hr_img = hr_comb
current_sample.dataset_name = 'flood3i'
current_sample.targets = targets_comb
current_sample.s2_img = s2_comb
current_sample.s2_ct = -1
current_sample.s2_ct2 = -1
current_sample.s1_img = s1_comb
current_sample.anno_mask = torch.from_numpy(mask)
current_sample.valid = valid
current_sample.location = geo_location
current_sample.modality_idx = modality_idx
current_sample.modality_flag_s2 = modality_flag_s2
current_sample.modality_flag_s1 = modality_flag_s1
current_sample.modality_flag_hr = modality_flag_hr
current_sample.task_type = self.dataset_type
return current_sample
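
The modality_idx above packs the three availability flags into one integer (bit 0 for S2, bit 1 for S1, bit 2 for HR). A small sketch of the encoding and a hypothetical decode helper (the helper is illustrative and not part of this commit):

def encode_modality(with_s2: bool, with_s1: bool, with_hr: bool) -> int:
    # mirrors: 2**0 * s2 + 2**1 * s1 + 2**2 * hr
    return (1 << 0) * with_s2 + (1 << 1) * with_s1 + (1 << 2) * with_hr

def decode_modality(idx: int):
    # returns (with_s2, with_s1, with_hr)
    return bool(idx & 1), bool(idx & 2), bool(idx & 4)

assert encode_modality(False, False, True) == 4   # HR-only, as in this flood loader
assert decode_modality(4) == (False, False, True)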


@@ -0,0 +1,494 @@
import os
import json
import datetime
import random
import torch
import numpy as np
from osgeo import gdal
from skimage import io
from skimage.transform import resize
from antmmf.structures import Sample
from antmmf.datasets.base_dataset import BaseDataset
import lib.datasets.utils.pair_trainsforms as pair_transforms
from lib.datasets.utils.masking_generator import MaskingGenerator
from lib.datasets.utils.dataset_colors import dataset_color_dict, get_painter_color_map_list, get_real_random_color_list
class PretrainingLoader(BaseDataset):
DATASET_NAME = "pretraining_loader"
def __init__(self, dataset_type, config):
super().__init__(self.__class__.DATASET_NAME, dataset_type, config)
self.root = config.data_root_dir
if dataset_type == 'train':
self.json_path_list = config.train_json_path_list
if dataset_type == 'val':
self.json_path_list = config.val_json_path_list
if dataset_type == 'test':
self.json_path_list = config.val_json_path_list
self.dataset_type = dataset_type
self.pairs = []
self.cls_repeat_cnt = config.cls_repeat_cnt
num_datasets = len(self.json_path_list)
for idx, json_path in enumerate(self.json_path_list):
print(os.path.join(config.data_root_dir, json_path))
cur_pairs = json.load(open(os.path.join(config.data_root_dir, json_path)))
self.pairs.extend(cur_pairs)
cur_num = len(cur_pairs)
if dataset_type == 'test' and config.prompt_json:
cur_pairs = json.load(open(config.prompt_json))
self.prompt = cur_pairs[0]
print(f'prompt:{self.prompt}')
self.use_multi_pairs = config.use_multi_pairs
if self.use_multi_pairs:
self.pair_type_dict = {}
if dataset_type == 'train' or dataset_type == 'val':
for idx, pair in enumerate(self.pairs):
if pair["type"] not in self.pair_type_dict:
new_subset = {}
classes = pair["classes"]
for cls in classes:
if cls not in new_subset.keys():
new_subset[cls] = [idx]
else:
new_subset[cls].append(idx)
self.pair_type_dict[pair["type"]] = new_subset
else:
classes = pair["classes"]
for cls in classes:
if cls not in self.pair_type_dict[pair["type"]].keys():
self.pair_type_dict[pair["type"]][cls] = [idx]
else:
self.pair_type_dict[pair["type"]][cls].append(idx)
cnt = 0
self.idx_to_cls = {}
for k, v in self.pair_type_dict.items():
for vv in v:
self.idx_to_cls[cnt] = {
'type': k,
'classes_id': vv
}
cnt = cnt + 1
print(self.idx_to_cls)
self.idx_to_cls_list = []
for i in self.idx_to_cls.keys():
self.idx_to_cls_list.append(self.idx_to_cls[i])
print(self.idx_to_cls_list)
if self.dataset_type == 'train':
self.idx_to_cls_list = self.idx_to_cls_list * self.cls_repeat_cnt
self.masked_position_generator = MaskingGenerator(
input_size=config.mim.input_size,
patch_size=config.mim.patch_size,
mask_ratio=config.mim.mask_ratio
)
if dataset_type == 'train':
self.half_mask_ratio = config.half_mask_ratio
else:
self.half_mask_ratio = 1.
self.seq_len = config.seq_len # ts
self.hr_size = config.image_size.hr
self.s2_size = config.image_size.s2
self.s1_size = config.image_size.s1
self.anno_size = config.image_size.anno
self.min_random_scale = config.min_random_scale
self.imagenet_mean=torch.tensor([0.485, 0.456, 0.406])
self.imagenet_std=torch.tensor([0.229, 0.224, 0.225])
self.pipeline = self._get_pipline()
self.crop_resize = pair_transforms.RandomResizedCropComb(512, scale=(0.3, 1.0), interpolation=3)
self.num_samples = 8
def __len__(self) -> int:
return len(self.idx_to_cls_list)
def _convert_colors_pairs(self, images, original_colors, new_colors, current_color):
if len(original_colors) != len(new_colors):
raise ValueError("The length of original_colors and new_colors must be the same.")
unique_colors_list = []
for image in images:
if len(image.shape) == 3:
image_hwc = image.transpose(1,2,0) # chw -> hwc
elif len(image.shape) == 2:
image_hwc = image[:,:,None]
else:
raise ValueError(f'image shape is {image.shape}, which is not supported for color conversion!')
image_2d = image_hwc.reshape(-1, image_hwc.shape[-1])
unique_colors = np.unique(image_2d, axis=0)
unique_colors_list.append(unique_colors)
unique_colors_list.append(original_colors)
sets_of_tuples = [set(map(tuple, a)) for a in unique_colors_list]
common_tuples = set.intersection(*sets_of_tuples)
unique_old_colors = np.array(list(common_tuples), dtype=np.uint8)
if len(unique_old_colors) == 0:
unique_old_colors = [current_color]
new_colors_converted = new_colors[:len(unique_old_colors)]
images_converted_list = []
for image in images:
image_converted = self._convert_colors(image, unique_old_colors, new_colors_converted)
images_converted_list.append(image_converted)
return images_converted_list
def _convert_colors(self, image, original_colors, new_colors):
"""
Remap colors in an image to new colors.
Parameters:
image (numpy.ndarray): The image as a numpy array (channel x height x width).
original_colors (list of tuples): The list of original colors to be replaced.
new_colors (list of tuples): The list of new colors to replace the original colors.
Returns:
numpy.ndarray: The image with remapped colors. (channel x height x width)
"""
if len(original_colors) != len(new_colors):
raise ValueError("The length of original_colors and new_colors must be the same.")
# Convert lists of tuples to numpy arrays for faster processing
original_colors = np.array(original_colors)
new_colors = np.array(new_colors)
if len(original_colors.shape) == 1:
original_colors = original_colors[:,None]
# check image shape
if len(image.shape) == 3:
remapped_image = image.transpose(1,2,0) # chw -> hwc
elif len(image.shape) == 2:
remapped_image = image[:,:,None]
else:
raise ValueError(f'image shape is {image.shape}, which is not supported for color conversion!')
# generate new image for return
new_image = np.zeros((remapped_image.shape[0], remapped_image.shape[1], 3), dtype=np.uint8)
for orig_color, new_color in zip(original_colors, new_colors):
mask = np.all(remapped_image == orig_color, axis=-1)
new_image[mask] = new_color
new_image = new_image.transpose(2,0,1) # hwc -> chw
return new_image
def _combine_images(self, images, interpolation='bicubic'):
# images 8, c, h, w -> c, 4h, 2w
group1 = images[:4]
group2 = images[4:]
stacked1 = torch.cat(group1, dim=-2)
stacked2 = torch.cat(group2, dim=-2)
result = torch.cat((stacked1, stacked2), dim=-1)
return result
def _get_pipline(self):
if self.dataset_type == 'train':
pipeline = [
pair_transforms.ToTensor(),
pair_transforms.RandomResizedCrop(512, scale=(0.8, 1.0), interpolation=3), # 3 is bicubic
pair_transforms.RandomHorizontalFlip(),
pair_transforms.Normalize(),
]
elif self.dataset_type == 'val' or self.dataset_type == 'test':
pipeline = [
pair_transforms.ToTensor(),
pair_transforms.RandomResizedCrop(512, scale=(0.9999, 1.0), interpolation=3), # 3 is bicubic
pair_transforms.Normalize(),
]
else:
raise ValueError('dataset_type not support')
return pair_transforms.Compose(pipeline)
def _load_data(self, data_path):
file_name, file_extension = os.path.splitext(data_path)
if file_extension == '.npz' or file_extension == '.npy':
data = np.load(data_path)['image']
elif file_extension == '.png' or file_extension == '.jpg':
data = io.imread(data_path)
if len(data.shape) == 3:
data = data.transpose(2,0,1)
elif file_extension == '.tiff' or file_extension == '.tif':
dataset = gdal.Open(data_path)
if dataset is None:
raise IOError(f'cannot open file: {data_path}')
data = dataset.ReadAsArray()
dataset = None
else:
raise ValueError(f'file type {data_path} not support')
if np.isnan(data).any():
print(f'{data_path} contains NaN, replacing with 0!')
data[np.isnan(data)] = 0
return data
def load_s2(self, pair):
if pair['type'] == 'flair-mm' and 's2_path' in pair.keys():
with_s2 =True
s2 = np.load(os.path.join(self.root, pair['s2_path']))
idx_centroid = pair['s2_cut_points']
s2_patch_size = 40
subset_sp = s2[:,:,idx_centroid[0]-int(s2_patch_size/2):idx_centroid[0] + \
int(s2_patch_size/2),idx_centroid[1] - int(s2_patch_size/2):idx_centroid[1] + \
int(s2_patch_size/2)]
ts, c, h, w = subset_sp.shape
subset_sp = subset_sp.reshape(-1, h, w).transpose(1,2,0)
s2 = resize(subset_sp, (16, 16), anti_aliasing=True).transpose(2,0,1)
s2 = s2.reshape(ts, c, 16, 16)
if True:
selected_indices = np.random.choice(s2.shape[0], size=self.seq_len, replace=False)
selected_indices = sorted(selected_indices)
s2 = s2[selected_indices, :, :, :]
s2_1 = s2.transpose(1,0,2,3) # ts, c, h, w -> c, ts, h, w
s2_ct_1 = [0] * self.seq_len
elif 's2_path' in pair.keys():
with_s2 =True
if isinstance(pair['s2_path'], list):
if True:
s2_path_list = np.random.choice(pair['s2_path'], self.seq_len)
s2_path_list = sorted(s2_path_list)
else:
s2_path_list = pair['s2_path']
s2_list = []
s2_ct_1 = []
for s2_path in s2_path_list:
s2 = self._load_data(os.path.join(self.root, s2_path))#[:10]
s2_list.append(s2)
ct = os.path.splitext(s2_path)[0].split('_')
ct = ct[-4] + ct[-3] + '01'
try:
ct = datetime.datetime.strptime(ct, '%Y%m%d')
except:
ct = datetime.datetime.strptime(ct, '%Y-%m-%d')
ct = ct.timetuple()
ct = ct.tm_yday - 1
s2_ct_1.append(ct)
s2_1 = np.stack(s2_list, axis=1)
else:
s2 = np.load(os.path.join(self.root, pair['s2_path']))['image']
date = np.load(os.path.join(self.root, pair['s2_path']))['date']
if True:
selected_indices = np.random.choice(s2.shape[0], size=self.seq_len, replace=False)
selected_indices = sorted(selected_indices)
s2 = s2[selected_indices, :, :, :]
date = date[selected_indices]
s2_1 = s2.transpose(1,0,2,3) # ts, c, h, w -> c, ts, h, w
s2_ct_1 = []
for ct in date:
try:
ct = datetime.datetime.strptime(ct, '%Y%m%d')
except:
ct = datetime.datetime.strptime(ct, '%Y-%m-%d')
ct = ct.timetuple()
ct = ct.tm_yday - 1
s2_ct_1.append(ct)
else:
with_s2 = False
s2_1 = np.zeros((10, self.seq_len, self.s2_size[0], self.s2_size[1]),
dtype=np.int16)
s2_ct_1 = [0] * self.seq_len
return with_s2, s2_1, s2_ct_1
def load_s1(self, pair):
if 's1_path' in pair.keys():
with_s1 = True
if isinstance(pair['s1_path'], list):
if True:
s1_path_list = np.random.choice(pair['s1_path'], self.seq_len)
s1_path_list = sorted(s1_path_list)
else:
s1_path_list = pair['s1_path']
s1_list = []
for s1_path in s1_path_list:
s1 = self._load_data(os.path.join(self.root, s1_path))
s1_list.append(s1)
s1_1 = np.stack(s1_list, axis=1)
else:
s1 = self._load_data(os.path.join(self.root, pair['s1_path']))
if True:
selected_indices = np.random.choice(s1.shape[0], size=self.seq_len, replace=False)
selected_indices = sorted(selected_indices)
s1 = s1[selected_indices, :, :, :]
s1_1 = s1.transpose(1,0,2,3) # ts, c, h, w -> c, ts, h, w
else:
with_s1 = False
s1_1 = np.zeros((2, self.seq_len, self.s1_size[0], self.s1_size[1]),
dtype=np.float32)
return with_s1, s1_1
def load_hr(self, pair):
if 'hr_path' in pair.keys():
if pair['type'] == 'flair-mm':
with_hr = True
hr = self._load_data(os.path.join(self.root, pair['hr_path']))[:3,:,:]
else:
with_hr = True
hr = self._load_data(os.path.join(self.root, pair['hr_path']))
else:
with_hr = False
hr = np.zeros((3, self.hr_size[0], self.hr_size[1]),
dtype=np.uint8)
return with_hr, hr
def load_tgt(self, pair):
if self.dataset_type == 'test':
targets = np.zeros((3, self.anno_size[0], self.anno_size[1]),
dtype=np.uint8)
else:
targets = self._load_data(os.path.join(self.root, pair['target_path']))
return targets
def find_random_position(self, matrix, current_color):
if matrix.ndim == 2:
matrix = matrix[None, :, :]
current_color = np.array(current_color)
C, H, W = matrix.shape
if len(current_color) != C:
raise ValueError("current_color unmatch with matrix!")
matches = np.where(np.all(matrix == current_color[:, None, None], axis=0))
if len(matches[0]) > 0:
index = np.random.choice(range(len(matches[0])))
return (matches[0][index], matches[1][index])
else:
return None
def get_item(self, idx):
dataset_cls_infos = self.idx_to_cls_list[idx]
current_dataset = dataset_cls_infos['type']
current_classes_id = dataset_cls_infos['classes_id']
pair_idx_list = self.pair_type_dict[current_dataset][current_classes_id]
old_colors = dataset_color_dict[current_dataset]
current_color = old_colors[current_classes_id]
class_num = len(old_colors)
if self.dataset_type == 'train':
new_colors = get_real_random_color_list(class_num)
else:
new_colors = get_painter_color_map_list(class_num) # fix colors mapping when testing
num_samples = self.num_samples
if len(pair_idx_list) < num_samples:
selected_samples = [random.choice(pair_idx_list) for _ in range(num_samples)]
else:
selected_samples = random.sample(pair_idx_list, num_samples)
hr_imgs = []
tgt_imgs = []
s2_imgs = []
s1_imgs = []
s2_cts = []
for sample_idx in selected_samples:
pair = self.pairs[sample_idx]
with_hr, hr = self.load_hr(pair)
with_s2, s2, s2_ct_1 = self.load_s2(pair)
with_s1, s1 = self.load_s1(pair)
tgt = self.load_tgt(pair)
modality_dict = {
's2' : with_s2,
's1' : with_s1,
'hr' : with_hr
}
if (hr.shape[-2:] != tuple(self.hr_size)) and (hr.shape[-2:] == tgt.shape[-2:]) and (self.hr_size == self.anno_size):
point_pos = self.find_random_position(tgt, current_color)
upper_left_raw = [point_pos[0] - self.hr_size[0] // 2, point_pos[1] - self.hr_size[1] // 2]
upper_left = [i - i%32 + 16 for i in upper_left_raw]
upper_left_sentinel = [i // 32 for i in upper_left_raw]
upper_left[0] = np.clip(np.array(upper_left[0]), 0, hr.shape[-2] - self.hr_size[0])
upper_left[1] = np.clip(np.array(upper_left[1]), 0, hr.shape[-1] - self.hr_size[1])
upper_left_sentinel[0] = np.clip(np.array(upper_left_sentinel[0]), 0, s1.shape[-2] - self.s1_size[0])
upper_left_sentinel[1] = np.clip(np.array(upper_left_sentinel[1]), 0, s1.shape[-1] - self.s1_size[1])
hr = hr[:, upper_left[0]:upper_left[0]+self.hr_size[0], upper_left[1]:upper_left[1]+self.hr_size[1]]
if with_s1:
s1 = s1[:, :, upper_left_sentinel[0]:upper_left_sentinel[0]+self.s1_size[0], upper_left_sentinel[1]:upper_left_sentinel[1]+self.s1_size[1]]
if with_s2:
s2 = s2[:, :, upper_left_sentinel[0]:upper_left_sentinel[0]+self.s2_size[0], upper_left_sentinel[1]:upper_left_sentinel[1]+self.s2_size[1]]
if tgt.ndim == 3:
tgt = tgt[:, upper_left[0]:upper_left[0]+self.hr_size[0], upper_left[1]:upper_left[1]+self.hr_size[1]]
elif tgt.ndim == 2:
tgt = tgt[upper_left[0]:upper_left[0]+self.hr_size[0], upper_left[1]:upper_left[1]+self.hr_size[1]]
else:
raise ValueError("tgt dim unsupport!")
hr_imgs.append(hr)
tgt_imgs.append(tgt)
s2_imgs.append(s2)
s1_imgs.append(s1)
s2_cts.append(s2_ct_1)
cvt_hr_imgs = []
cvt_tgt_imgs = []
cvt_s2_imgs = []
cvt_s1_imgs = []
tgt_imgs = self._convert_colors_pairs(tgt_imgs, old_colors, new_colors, current_color)
for i in range(len(tgt_imgs)):
hr, s2, s1, tgt = self.pipeline(current_dataset, hr_imgs[i], s2_imgs[i], s1_imgs[i], tgt_imgs[i])
cvt_hr_imgs.append(hr)
cvt_s2_imgs.append(s2)
cvt_s1_imgs.append(s1)
cvt_tgt_imgs.append(tgt)
targets_comb = self._combine_images(cvt_tgt_imgs)
hr_comb = self._combine_images(cvt_hr_imgs)
s2_comb = self._combine_images(cvt_s2_imgs)
s1_comb = self._combine_images(cvt_s1_imgs)
hr_comb, s2_comb, s1_comb, targets_comb = self.crop_resize(current_dataset, hr_comb, s2_comb, s1_comb, targets_comb)
use_half_mask = torch.rand(1)[0] < self.half_mask_ratio
valid = torch.ones_like(targets_comb)
thres = torch.ones(3) * (1e-5) # ignore black
thres = (thres - self.imagenet_mean) / self.imagenet_std
valid[targets_comb < thres[:, None, None]] = 0
if use_half_mask:
num_patches = self.masked_position_generator.num_patches
mask = np.zeros(self.masked_position_generator.get_shape(), dtype=np.int32)
mask[mask.shape[0]//2:, :] = 1
else:
mask = self.masked_position_generator()
# location
geo_location = pair["location"] if "location" in pair.keys() else None
# get modality index
modality_idx = 2**0 * modality_dict['s2'] + 2**1 * modality_dict['s1'] + 2**2 * modality_dict['hr']
modality_flag_s2 = modality_dict['s2']
modality_flag_s1 = modality_dict['s1']
modality_flag_hr = modality_dict['hr']
current_sample = Sample()
current_sample.img_name = pair["hr_path"].split('/')[-1].split('.')[0]
current_sample.hr_img = hr_comb
current_sample.dataset_name = pair["type"]
current_sample.targets = targets_comb
current_sample.s2_img = s2_comb
current_sample.s2_ct = s2_cts[0]
current_sample.s2_ct2 = s2_cts[4]
current_sample.s1_img = s1_comb
current_sample.anno_mask = torch.from_numpy(mask)
current_sample.valid = valid
current_sample.location = geo_location
current_sample.modality_idx = modality_idx
current_sample.modality_flag_s2 = modality_flag_s2
current_sample.modality_flag_s1 = modality_flag_s1
current_sample.modality_flag_hr = modality_flag_hr
current_sample.task_type = self.dataset_type
return current_sample
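
_combine_images above tiles the eight sampled (image, target) views into a 4x2 grid along the spatial axes before the final RandomResizedCropComb. A standalone shape check with dummy tensors (sizes are illustrative):

import torch

images = [torch.zeros(3, 512, 512) for _ in range(8)]   # eight (C, H, W) views
col_left = torch.cat(images[:4], dim=-2)                 # (3, 2048, 512)
col_right = torch.cat(images[4:], dim=-2)                # (3, 2048, 512)
grid = torch.cat((col_left, col_right), dim=-1)          # (3, 2048, 1024), i.e. (C, 4H, 2W)
assert grid.shape == (3, 2048, 1024)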


@@ -0,0 +1,77 @@
import random
from functools import lru_cache
import numpy as np
dataset_color_dict = {
"potsdam" : [[1], [2], [3], [4], [5]],
"vaihingen" : [[255, 255, 0], [0, 255, 0], [0, 255, 255], [0, 0, 255], [255, 255, 255]],
"deepglobe" : [[255,255,255], [0,0,255], [0,255,0],[255,0,255], [255,255,0], [0,255,255]],
"fbp" : [[i+1] for i in range(24)],
"loveda" : [[i+2, i+2, i+2] for i in range(6)],
"isaid" : [[i+1] for i in range(15)],
"pastis-mm" : [[i+1] for i in range(18)],
"dynamic-mm" : [[i] for i in range(7)],
"c2seg-ab" : [[i+1] for i in range(13)],
"flood3i": [[i+1] for i in range(9)],
"jl16-mm": [[i] for i in range(16)],
"flair-mm": [[i+1] for i in range(18)],
"dfc20": [[i+1] for i in range(10)]
}
modal_norm_dict = {
'hr' : {
'div' : 255.,
'mean' : [0.485, 0.456, 0.406],
'std' : [0.229, 0.224, 0.225]
},
'anno' : {
'div' : 255.,
'mean' : [0.485, 0.456, 0.406],
'std' : [0.229, 0.224, 0.225]
},
's2' : {
'div' : 1.,
'mean' : [884.29673756, 1144.16202635, 1297.47289228, 1624.90992062, 2194.6423161, 2422.21248945, 2517.76053101, 2581.64687018, 2368.51236873, 1805.06846033],
'std' : [1155.15170768, 1183.6292542, 1368.11351514, 1370.265037, 1355.55390699, 1416.51487101, 1474.78900051, 1439.3086061, 1455.52084939, 1343.48379601]
},
's1' : {
'div' : 1.,
'mean' : [-12.54847273, -20.19237134],
'std' : [5.25697717, 5.91150917]
},
}
@lru_cache()
def get_painter_color_map_list(num_locations = 300):
num_sep_per_channel = int(num_locations ** (1 / 3)) + 1 # 7 for the default num_locations of 300
separation_per_channel = 256 // num_sep_per_channel
color_list = []
for location in range(num_locations):
num_seq_r = location // num_sep_per_channel ** 2
num_seq_g = (location % num_sep_per_channel ** 2) // num_sep_per_channel
num_seq_b = location % num_sep_per_channel
assert (num_seq_r <= num_sep_per_channel) and (num_seq_g <= num_sep_per_channel) \
and (num_seq_b <= num_sep_per_channel)
R = 255 - num_seq_r * separation_per_channel
G = 255 - num_seq_g * separation_per_channel
B = 255 - num_seq_b * separation_per_channel
assert (R < 256) and (G < 256) and (B < 256)
assert (R >= 0) and (G >= 0) and (B >= 0)
assert (R, G, B) not in color_list
color_list.append((R, G, B))
return color_list
def get_real_random_color_list(num_locations):
random_color_list = np.random.randint(0, 256, (num_locations, 3))
while np.sum(random_color_list) == 0:
print('random_color_list is 0!')
random_color_list = np.random.randint(0, 256, (num_locations, 3))
random_color_list = random_color_list.tolist()
return random_color_list # [:num_locations]
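
get_painter_color_map_list spreads class indices over a fixed RGB grid, so the same class is always painted with the same colour at evaluation time, while get_real_random_color_list redraws a random palette for each training sample. A brief usage sketch (the import path is taken from the loader above; the listed colours follow from the function's arithmetic for num_locations=5):

from lib.datasets.utils.dataset_colors import (
    get_painter_color_map_list,
    get_real_random_color_list,
)

fixed_palette = get_painter_color_map_list(5)
# num_sep_per_channel = int(5 ** (1/3)) + 1 = 2, separation = 128, so:
# fixed_palette[:3] == [(255, 255, 255), (255, 255, 127), (255, 127, 255)]
print(fixed_palette)

random_palette = get_real_random_color_list(5)   # different colours on every call
print(random_palette)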


@@ -0,0 +1,65 @@
from collections.abc import Sequence
import mmcv
import numpy as np
import torch
def to_tensor(data):
"""Convert objects of various python types to :obj:`torch.Tensor`.
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
:class:`Sequence`, :class:`int` and :class:`float`.
Args:
data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
be converted.
"""
if isinstance(data, torch.Tensor):
return data
elif isinstance(data, np.ndarray):
return torch.from_numpy(data)
elif isinstance(data, Sequence) and not mmcv.is_str(data):
return torch.tensor(data)
elif isinstance(data, int):
return torch.LongTensor([data])
elif isinstance(data, float):
return torch.FloatTensor([data])
else:
raise TypeError(f'type {type(data)} cannot be converted to tensor.')
class ToTensor(object):
"""Convert some sample to :obj:`torch.Tensor` by given keys.
Args:
keys (Sequence[str]): Keys that need to be converted to Tensor.
"""
def __init__(self, keys):
self.keys = keys
def __call__(self, sample):
"""Call function to convert data in sample to :obj:`torch.Tensor`.
Args:
sample (Sample): sample data contains the data to convert.
Returns:
dict: The result dict contains the data converted
to :obj:`torch.Tensor`.
"""
for key in self.keys:
if isinstance(sample[key], list):
for i in range(len(sample[key])):
sample[key][i] = to_tensor(sample[key][i])
else:
sample[key] = to_tensor(sample[key])
return sample
def __repr__(self):
return self.__class__.__name__ + f'(keys={self.keys})'
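
As a quick reference for how to_tensor dispatches on input type (ndarray, Sequence, int, float), here is a hedged usage sketch; the import path assumes this file is lib/datasets/utils/formatting.py, consistent with the import in the flood loader above:

import numpy as np
import torch
from lib.datasets.utils.formatting import to_tensor  # assumed module path

arr = np.ones((2, 2), dtype=np.float32)
assert torch.equal(to_tensor(arr), torch.from_numpy(arr))  # ndarray -> torch.from_numpy
assert to_tensor(3).dtype == torch.int64                   # int -> LongTensor([3])
assert to_tensor(0.5).dtype == torch.float32               # float -> FloatTensor([0.5])
assert to_tensor([1, 2, 3]).tolist() == [1, 2, 3]          # Sequence -> torch.tensor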


@@ -0,0 +1,84 @@
import random
import math
import numpy as np
class MaskingGenerator:
def __init__(
self, input_size, patch_size, mask_ratio=0.5, min_num_patches=4, max_num_patches=None,
min_aspect=0.3, max_aspect=None):
if not isinstance(input_size, list):
input_size = [input_size,] * 2
self.height = input_size[0] // patch_size
self.width = input_size[1] // patch_size
self.num_patches = self.height * self.width
self.num_masking_patches = int(self.num_patches * mask_ratio)
self.min_num_patches = min_num_patches
self.max_num_patches = self.num_masking_patches if max_num_patches is None else max_num_patches
max_aspect = max_aspect or 1 / min_aspect
self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
def __repr__(self):
repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
self.height, self.width, self.min_num_patches, self.max_num_patches,
self.num_masking_patches, self.log_aspect_ratio[0], self.log_aspect_ratio[1])
return repr_str
def get_shape(self):
return self.height, self.width
def _mask(self, mask, max_mask_patches):
delta = 0
for attempt in range(10):
target_area = random.uniform(self.min_num_patches, max_mask_patches)
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio)))
if w < self.width and h < self.height:
top = random.randint(0, self.height - h)
left = random.randint(0, self.width - w)
num_masked = mask[top: top + h, left: left + w].sum()
# Overlap
if 0 < h * w - num_masked <= max_mask_patches:
for i in range(top, top + h):
for j in range(left, left + w):
if mask[i, j] == 0:
mask[i, j] = 1
delta += 1
if delta > 0:
break
return delta
def __call__(self):
mask = np.zeros(shape=self.get_shape(), dtype=np.int32)
mask_count = 0
while mask_count < self.num_masking_patches:
max_mask_patches = self.num_masking_patches - mask_count
max_mask_patches = min(max_mask_patches, self.max_num_patches)
delta = self._mask(mask, max_mask_patches)
if delta == 0:
break
else:
mask_count += delta
# maintain a fixed number of masked patches ({self.num_masking_patches})
if mask_count > self.num_masking_patches:
delta = mask_count - self.num_masking_patches
mask_x, mask_y = mask.nonzero()
to_vis = np.random.choice(mask_x.shape[0], delta, replace=False)
mask[mask_x[to_vis], mask_y[to_vis]] = 0
elif mask_count < self.num_masking_patches:
delta = self.num_masking_patches - mask_count
mask_x, mask_y = (mask == 0).nonzero()
to_mask = np.random.choice(mask_x.shape[0], delta, replace=False)
mask[mask_x[to_mask], mask_y[to_mask]] = 1
assert mask.sum() == self.num_masking_patches, f"mask: {mask}, mask count {mask.sum()}"
return mask
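
MaskingGenerator operates on the patch grid: with input_size [1024, 512] and patch_size 16 it yields a 64x32 binary map whose number of ones is forced to exactly num_patches * mask_ratio. A hedged usage sketch (the config values are illustrative; the import path matches the one used by PretrainingLoader):

from lib.datasets.utils.masking_generator import MaskingGenerator

gen = MaskingGenerator(input_size=[1024, 512], patch_size=16, mask_ratio=0.5)
mask = gen()                                   # np.int32 array on the patch grid
assert mask.shape == gen.get_shape() == (64, 32)
assert mask.sum() == gen.num_masking_patches   # exactly half of the 2048 patches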


@@ -0,0 +1,532 @@
import math
import numbers
import random
import warnings
from collections.abc import Sequence
from typing import List, Optional, Tuple
import numpy as np
import torch
from torch import Tensor
import torchvision.transforms as transforms
from skimage import io
try:
import accimage
except ImportError:
accimage = None
import torchvision.transforms.functional as F
from torchvision.transforms.functional import _interpolation_modes_from_int, InterpolationMode
from PIL import Image, ImageFilter, ImageOps
from .dataset_colors import modal_norm_dict
__all__ = [
"Compose",
"ToTensor",
"Normalize",
"RandomHorizontalFlip",
"RandomResizedCrop",
]
class Compose(transforms.Compose):
"""Composes several transforms together. This transform does not support torchscript.
Please, see the note below.
Args:
transforms (list of ``Transform`` objects): list of transforms to compose.
"""
def __init__(self, transforms):
super().__init__(transforms)
def __call__(self, dataset_name, hr_img, s2_img, s1_img, tgt, interpolation1=None, interpolation2=None):
# i = 0
for t in self.transforms:
# i = i+1
# print(f'dataset_name:{dataset_name}')
# print(f'step:{i}')
# print(f'hr_img shape:{hr_img.shape}')
# print(f's2_img shape:{s2_img.shape}')
# print(f's1_img shape:{s1_img.shape}')
# print(f'tgt shape:{tgt.shape}')
hr_img, s2_img, s1_img, tgt = t(dataset_name, hr_img, s2_img, s1_img, tgt, interpolation1=interpolation1, interpolation2=interpolation2)
return hr_img, s2_img, s1_img, tgt
class ToTensor(transforms.ToTensor):
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript.
Converts a PIL Image or numpy.ndarray (H x W x C) in the range
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
or if the numpy.ndarray has dtype = np.uint8
In the other cases, tensors are returned without scaling.
.. note::
Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
transforming target image masks. See the `references`_ for implementing the transforms for image masks.
.. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
"""
def __init__(self) -> None:
super().__init__()
def __call__(self, dataset_name, hr_img, s2_img, s1_img, tgt, interpolation1=None, interpolation2=None):
"""
Args:
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
Returns:
Tensor: Converted image.
"""
# print(f'hr dtype:{hr_img.dtype}')
# print(f's2_img dtype:{s2_img.dtype}')
# print(f's1_img dtype:{s1_img.dtype}')
# print(f'tgt dtype:{tgt.dtype}')
if dataset_name == 'dynamic-mm' or dataset_name == 'guizhou-mm':
hr_img = hr_img.astype(np.int32)[:3,:,:]
hr_img = hr_img[::-1,:,:].copy()
else:
hr_img = hr_img.astype(np.int32)
tgt = tgt.astype(np.uint8)
s1_img = s1_img.astype(np.float32)
s2_img = s2_img.astype(np.int16)
return torch.tensor(hr_img), torch.tensor(s2_img), torch.tensor(s1_img),torch.tensor(tgt)
class Normalize(transforms.Normalize):
"""Normalize a tensor image with mean and standard deviation.
This transform does not support PIL Image.
Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
channels, this transform will normalize each channel of the input
``torch.*Tensor`` i.e.,
``output[channel] = (input[channel] - mean[channel]) / std[channel]``
.. note::
This transform acts out of place, i.e., it does not mutate the input tensor.
Args:
mean (sequence): Sequence of means for each channel.
std (sequence): Sequence of standard deviations for each channel.
inplace(bool,optional): Bool to make this operation in-place.
"""
def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False):
super().__init__(mean, std, inplace)
def forward(self, dataset_name, hr_img, s2_img, s1_img, tgt, interpolation1=None, interpolation2=None):
"""
Args:
tensor (Tensor): Tensor image to be normalized.
Returns:
Tensor: Normalized Tensor image.
"""
# TODO: look up the mean and std for the corresponding dataset
# handle a few dataset-specific mean/std values
if dataset_name == 'dynamic-mm':
hr_std = [1008.4052, 760.9586, 631.4754]
hr_mean = [1085.2941, 944.2718, 689.2493]
hr_div = 1.
else:
hr_mean = modal_norm_dict['hr']['mean']
hr_std = modal_norm_dict['hr']['std']
hr_div = modal_norm_dict['hr']['div']
if dataset_name == 'l8activefire':
# if False:
s2_mean = modal_norm_dict['l8']['mean']
s2_std = modal_norm_dict['l8']['std']
s2_div = modal_norm_dict['l8']['div']
else:
s2_mean = modal_norm_dict['s2']['mean']
s2_std = modal_norm_dict['s2']['std']
s2_div = modal_norm_dict['s2']['div']
s1_mean = modal_norm_dict['s1']['mean']
s1_std = modal_norm_dict['s1']['std']
s1_div = modal_norm_dict['s1']['div']
anno_mean = [0.485, 0.456, 0.406]
anno_std = [0.229, 0.224, 0.225]
ann_div = 255.
# open question: is it correct to normalize the time series this way?
#mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
#std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
# print(s2_img.shape)
# import pdb; pdb.set_trace()
# print(s2_img)
try:
ch, ts, h, w = s2_img.shape
except:
print(f's2: {s2_img.shape}, s1: {s1_img.shape}')
s2_img = s2_img.view(ch, ts*h, w)
s2_img = self.normalize(s2_img.type(torch.float32), s2_mean, s2_std, self.inplace)
s2_img = s2_img.view(ch, ts, h, w)
ch, ts, h, w = s1_img.shape
s1_img = s1_img.view(ch, ts*h, w)
s1_img = self.normalize(s1_img.type(torch.float32), s1_mean, s1_std, self.inplace)
s1_img = s1_img.view(ch, ts, h, w)
# import pdb; pdb.set_trace()
# print(s2_img.shape, s2_img[:,0,:,:])
# print(s1_img.shape, s1_img[:,0,:,:])
# print(hr_mean, hr_std, hr_div)
return self.normalize(hr_img.type(torch.float32).div_(hr_div), hr_mean, hr_std, self.inplace), \
s2_img, \
s1_img, \
self.normalize(tgt.type(torch.float32).div_(ann_div) , anno_mean, anno_std, self.inplace)
def normalize(self, tensor, mean, std, inplace):
dtype = tensor.dtype
mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
if (std == 0).any():
raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
mean = mean.view(-1, 1, 1)
std = std.view(-1, 1, 1)
# print(f'tensor shape: {tensor.shape}')
# print(f'mean shape: {mean.shape}')
# print(f'std shape: {std.shape}')
return tensor.sub_(mean).div_(std)
class RandomResizedCrop(transforms.RandomResizedCrop):
"""Crop a random portion of image and resize it to a given size.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
A crop of the original image is made: the crop has a random area (H * W)
and a random aspect ratio. This crop is finally resized to the given
size. This is popularly used to train the Inception networks.
Args:
size (int or sequence): expected output size of the crop, for each edge. If size is an
int instead of sequence like (h, w), a square output size ``(size, size)`` is
made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
.. note::
In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
scale (tuple of float): Specifies the lower and upper bounds for the random area of the crop,
before resizing. The scale is defined with respect to the area of the original image.
ratio (tuple of float): lower and upper bounds for the random aspect ratio of the crop, before
resizing.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and
``InterpolationMode.BICUBIC`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
"""
def __init__(
self,
size,
scale=(0.08, 1.0),
ratio=(3.0 / 4.0, 4.0 / 3.0),
interpolation=InterpolationMode.BILINEAR,
mode='small'
):
super().__init__(size, scale=scale, ratio=ratio, interpolation=interpolation)
self.cnt=0
self.mode = mode
def forward(self, dataset_name, hr_img, s2_img, s1_img, tgt, interpolation1=None, interpolation2=None, mode='small'):
"""
Args:
img (PIL Image or Tensor): Image to be cropped and resized.
Returns:
PIL Image or Tensor: Randomly cropped and resized image.
"""
i, j, h, w = self.get_params(s2_img, self.scale, self.ratio)
size_hr = hr_img.shape[-1]
size_s2 = s2_img.shape[-1]
size_anno = tgt.shape[-1]
# map the crop window onto the other modalities
ratio_s2_hr = size_s2 / size_hr
i_hr = int(i / ratio_s2_hr)
j_hr = int(j / ratio_s2_hr)
h_hr = int(h / ratio_s2_hr)
w_hr = int(w / ratio_s2_hr)
ratio_s2_anno = size_s2 / size_anno
i_anno = int(i / ratio_s2_anno)
j_anno = int(j / ratio_s2_anno)
h_anno = int(h / ratio_s2_anno)
w_anno = int(w / ratio_s2_anno)
if interpolation1 == 'nearest':
interpolation1 = InterpolationMode.NEAREST
else:
interpolation1 = InterpolationMode.BICUBIC
if interpolation2 == 'nearest':
interpolation2 = InterpolationMode.NEAREST
else:
interpolation2 = InterpolationMode.BICUBIC
# import pdb;pdb.set_trace()
if self.scale[0]>0.99 and self.scale[0]<1.0:
if self.mode=='small':
resized_s2_img = F.resize(s2_img, (16,16), interpolation=InterpolationMode.BICUBIC)
resized_hr_img = F.resize(hr_img, (512, 512), interpolation=InterpolationMode.BICUBIC)
resized_s1_img = F.resize(s1_img, (16,16), interpolation=InterpolationMode.BICUBIC)
resized_tgt = F.resize(tgt, (512,512), interpolation=InterpolationMode.NEAREST)
else:
resized_s2_img = F.resize(s2_img, (64,64), interpolation=InterpolationMode.BICUBIC)
resized_hr_img = F.resize(hr_img, (2048, 2048), interpolation=InterpolationMode.BICUBIC)
resized_s1_img = F.resize(s1_img, (64,64), interpolation=InterpolationMode.BICUBIC)
resized_tgt = F.resize(tgt, (2048,2048), interpolation=InterpolationMode.NEAREST)
return resized_hr_img, resized_s2_img, resized_s1_img, resized_tgt
if self.mode=='small':
resized_s2_img = F.resized_crop(s2_img, i, j, h, w, (16, 16), InterpolationMode.BICUBIC)
resized_hr_img = F.resized_crop(hr_img, i_hr, j_hr, h_hr, w_hr, (512, 512), InterpolationMode.BICUBIC)
resized_s1_img = F.resized_crop(s1_img, i, j, h, w, (16, 16), InterpolationMode.BICUBIC)
resized_tgt = F.resized_crop(tgt, i_anno, j_anno, h_anno, w_anno, (512, 512), InterpolationMode.NEAREST)
else:
resized_s2_img = F.resized_crop(s2_img, i, j, h, w, (512, 512), InterpolationMode.BICUBIC)
resized_hr_img = F.resized_crop(hr_img, i_hr, j_hr, h_hr, w_hr, (2048,2048), InterpolationMode.BICUBIC)
resized_s1_img = F.resized_crop(s1_img, i, j, h, w, (512, 512), InterpolationMode.BICUBIC)
resized_tgt = F.resized_crop(tgt, i_anno, j_anno, h_anno, w_anno, (2048, 2048), InterpolationMode.NEAREST)
# import pdb; pdb.set_trace()
# (debug) save the resized results as a concatenated image
# self.cnt = self.cnt+1
# from torchvision.utils import save_image
# save_hr = resized_hr_img[:3, :, :] / resized_hr_img[:3, :, :].max()
# save_s2 = resized_s2_img[:3,0,:,:] / resized_s2_img[:3,0,:,:].max()
# print(f'{save_hr.shape}, {save_s2.shape}')
# save_image(save_s2, f'FoundationModel/debug/output2/resized_s2_{self.cnt}.png')
# save_image(save_hr, f'FoundationModel/debug/output2/resized_hr_{self.cnt}.png')
return resized_hr_img, resized_s2_img, resized_s1_img, resized_tgt
class RandomResizedCropComb(transforms.RandomResizedCrop):
def __init__(
self,
size,
scale=(0.08, 1.0),
ratio=(3.0 / 4.0, 4.0 / 3.0),
interpolation=InterpolationMode.BILINEAR,
):
super().__init__(size, scale=scale, ratio=ratio, interpolation=interpolation)
self.cnt=0
def forward(self, dataset_name, hr_img, s2_img, s1_img, tgt, interpolation1=None, interpolation2=None):
"""
Args:
img (PIL Image or Tensor): Image to be cropped and resized.
Returns:
PIL Image or Tensor: Randomly cropped and resized image.
"""
i, j, h, w = self.get_params(s2_img, self.scale, self.ratio)
# print(f'i, j, h, w: {i, j, h, w}')
# print(f's2_img shape: {s2_img.shape}')
size_hr = hr_img.shape[-1]
size_s2 = s2_img.shape[-1]
size_anno = tgt.shape[-1]
# map the crop window onto the other modalities
ratio_s2_hr = size_s2 / size_hr
i_hr = int(i / ratio_s2_hr)
j_hr = int(j / ratio_s2_hr)
h_hr = int(h / ratio_s2_hr)
w_hr = int(w / ratio_s2_hr)
ratio_s2_anno = size_s2 / size_anno
i_anno = int(i / ratio_s2_anno)
j_anno = int(j / ratio_s2_anno)
h_anno = int(h / ratio_s2_anno)
w_anno = int(w / ratio_s2_anno)
if interpolation1 == 'nearest':
interpolation1 = InterpolationMode.NEAREST
else:
interpolation1 = InterpolationMode.BICUBIC
if interpolation2 == 'nearest':
interpolation2 = InterpolationMode.NEAREST
else:
interpolation2 = InterpolationMode.BICUBIC
resized_s2_img = F.resized_crop(s2_img, i, j, h, w, (32, 16), InterpolationMode.BICUBIC)
resized_hr_img = F.resized_crop(hr_img, i_hr, j_hr, h_hr, w_hr, (1024, 512), InterpolationMode.BICUBIC)
resized_s1_img = F.resized_crop(s1_img, i, j, h, w, (32, 16), InterpolationMode.BICUBIC)
resized_tgt = F.resized_crop(tgt, i_anno, j_anno, h_anno, w_anno, (1024, 512), InterpolationMode.NEAREST)
return resized_hr_img, resized_s2_img, resized_s1_img, resized_tgt
class RandomHorizontalFlip(transforms.RandomHorizontalFlip):
"""Horizontally flip the given image randomly with a given probability.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading
dimensions
Args:
p (float): probability of the image being flipped. Default value is 0.5
"""
def __init__(self, p=0.5):
super().__init__(p=p)
def forward(self, dataset_name, hr_img, s2_img, s1_img, tgt, interpolation1=None, interpolation2=None):
"""
Args:
img (PIL Image or Tensor): Image to be flipped.
Returns:
PIL Image or Tensor: Randomly flipped image.
"""
if torch.rand(1) < self.p:
return F.hflip(hr_img), F.hflip(s2_img), F.hflip(s1_img), F.hflip(tgt)
return hr_img, s2_img, s1_img, tgt
class RandomApply(transforms.RandomApply):
"""Apply randomly a list of transformations with a given probability.
.. note::
In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of
transforms as shown below:
>>> transforms = transforms.RandomApply(torch.nn.ModuleList([
>>> transforms.ColorJitter(),
>>> ]), p=0.3)
>>> scripted_transforms = torch.jit.script(transforms)
Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
`lambda` functions or ``PIL.Image``.
Args:
transforms (sequence or torch.nn.Module): list of transformations
p (float): probability
"""
def __init__(self, transforms, p=0.5):
super().__init__(transforms, p=p)
def forward(self, img, tgt, interpolation1=None, interpolation2=None):
if self.p < torch.rand(1):
return img, tgt
for t in self.transforms:
img, tgt = t(img, tgt)
return img, tgt
class ColorJitter(transforms.ColorJitter):
"""Randomly change the brightness, contrast, saturation and hue of an image.
If the image is torch Tensor, it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.
Args:
brightness (float or tuple of float (min, max)): How much to jitter brightness.
brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
or the given [min, max]. Should be non negative numbers.
contrast (float or tuple of float (min, max)): How much to jitter contrast.
contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
or the given [min, max]. Should be non negative numbers.
saturation (float or tuple of float (min, max)): How much to jitter saturation.
saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
or the given [min, max]. Should be non negative numbers.
hue (float or tuple of float (min, max)): How much to jitter hue.
hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
To jitter hue, the pixel values of the input image has to be non-negative for conversion to HSV space;
thus it does not work if you normalize your image to an interval with negative values,
or use an interpolation that generates negative values before using this function.
"""
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
super().__init__(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
def forward(self, img, tgt, interpolation1=None, interpolation2=None):
"""
Args:
img (PIL Image or Tensor): Input image.
Returns:
PIL Image or Tensor: Color jittered image.
"""
fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
self.brightness, self.contrast, self.saturation, self.hue
)
for fn_id in fn_idx:
if fn_id == 0 and brightness_factor is not None:
img = F.adjust_brightness(img, brightness_factor)
elif fn_id == 1 and contrast_factor is not None:
img = F.adjust_contrast(img, contrast_factor)
elif fn_id == 2 and saturation_factor is not None:
img = F.adjust_saturation(img, saturation_factor)
elif fn_id == 3 and hue_factor is not None:
img = F.adjust_hue(img, hue_factor)
return img, tgt
class RandomErasing(transforms.RandomErasing):
"""Randomly selects a rectangle region in a torch.Tensor image and erases its pixels.
This transform does not support PIL Image.
'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896
Args:
p: probability that the random erasing operation will be performed.
scale: range of proportion of erased area against input image.
ratio: range of aspect ratio of erased area.
value: erasing value. Default is 0. If a single int, it is used to
erase all pixels. If a tuple of length 3, it is used to erase
R, G, B channels respectively.
If a str of 'random', erasing each pixel with random values.
inplace: boolean to make this transform inplace. Default set to False.
Returns:
Erased Image.
Example:
>>> transform = transforms.Compose([
>>> transforms.RandomHorizontalFlip(),
>>> transforms.PILToTensor(),
>>> transforms.ConvertImageDtype(torch.float),
>>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
>>> transforms.RandomErasing(),
>>> ])
"""
def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
super().__init__(p=p, scale=scale, ratio=ratio, value=value, inplace=inplace)
def forward(self, img, tgt, interpolation1=None, interpolation2=None):
"""
Args:
img (Tensor): Tensor image to be erased.
Returns:
img (Tensor): Erased Tensor image.
"""
if torch.rand(1) < self.p:
# cast self.value to script acceptable type
if isinstance(self.value, (int, float)):
value = [self.value]
elif isinstance(self.value, str):
value = None
elif isinstance(self.value, tuple):
value = list(self.value)
else:
value = self.value
if value is not None and not (len(value) in (1, img.shape[-3])):
raise ValueError(
"If value is a sequence, it should have either a single value or "
f"{img.shape[-3]} (number of input channels)"
)
x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value)
return F.erase(img, x, y, h, w, v, self.inplace), tgt
return img, tgt
class GaussianBlur(object):
"""Gaussian blur augmentation from SimCLR: https://arxiv.org/abs/2002.05709"""
def __init__(self, sigma=[.1, 2.]):
self.sigma = sigma
def __call__(self, img, tgt, interpolation1=None, interpolation2=None):
sigma = random.uniform(self.sigma[0], self.sigma[1])
img = img.filter(ImageFilter.GaussianBlur(radius=sigma))
return img, tgt
def __repr__(self) -> str:
s = f"{self.__class__.__name__}( sigma={self.sigma})"
return s
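
These transforms are always applied jointly to the four modalities: the crop window is sampled on the Sentinel-2 grid and rescaled to the HR and annotation grids via ratio_s2_hr and ratio_s2_anno. A hedged end-to-end sketch with dummy inputs (the 'small' mode targets 512x512 for HR/target and 16x16 for S2/S1; the shapes and dataset name below are illustrative):

import numpy as np
import lib.datasets.utils.pair_trainsforms as pair_transforms  # module name as imported by the loaders

pipeline = pair_transforms.Compose([
    pair_transforms.ToTensor(),
    pair_transforms.RandomResizedCrop(512, scale=(0.8, 1.0), interpolation=3),  # 3 is bicubic
    pair_transforms.RandomHorizontalFlip(),
    pair_transforms.Normalize(),
])

hr = np.zeros((3, 512, 512), dtype=np.uint8)     # high-resolution optical image
s2 = np.zeros((10, 4, 16, 16), dtype=np.int16)   # Sentinel-2: channels x time x H x W
s1 = np.zeros((2, 4, 16, 16), dtype=np.float32)  # Sentinel-1: channels x time x H x W
tgt = np.zeros((3, 512, 512), dtype=np.uint8)    # colour-coded annotation

hr_t, s2_t, s1_t, tgt_t = pipeline('potsdam', hr, s2, s1, tgt)
# expected shapes: (3, 512, 512), (10, 4, 16, 16), (2, 4, 16, 16), (3, 512, 512)
print(hr_t.shape, s2_t.shape, s1_t.shape, tgt_t.shape)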

File diff suppressed because it is too large.


@@ -0,0 +1,208 @@
import os
import argparse
import numpy as np
import pandas as pd
from skimage.transform import resize
from skimage import io
from multiprocessing import Pool
from functools import partial
import logging
import datetime
def get_confusion_matrix(pred, gt, num_class):
assert pred.shape == gt.shape, f"pred.shape: {pred.shape} != gt.shape: {gt.shape}"
mask = (gt >= 0) & (gt < num_class) # keep only pixels whose labels fall in [0, num_class)
label = num_class * gt[mask] + pred[mask]
count = np.bincount(label, minlength=num_class**2)
confusion_matrix = count.reshape(num_class, num_class)
return confusion_matrix
def get_miou(confusion_matrix):
diagonal_elements = np.diag(confusion_matrix)
column_sums = np.sum(confusion_matrix, axis=0)
row_sums = np.sum(confusion_matrix, axis=1)
ious = diagonal_elements/(column_sums + row_sums - diagonal_elements)
m_iou = np.nanmean(ious)
return m_iou
def get_mprecison(confusion_matrix):
diagonal_elements = np.diag(confusion_matrix)
column_sums = np.sum(confusion_matrix, axis=0)
precisions = diagonal_elements / (column_sums + 1e-06)
m_precision = np.nanmean(precisions)
return m_precision
def get_mrecall(confusion_matrix):
diagonal_elements = np.diag(confusion_matrix)
row_sums = np.sum(confusion_matrix, axis=1)
recalls= diagonal_elements / (row_sums + 1e-06)
m_recall = np.nanmean(recalls)
return m_recall
def get_macc(confusion_matrix):
'''
acc = tp / (tp + fn), i.e. the same as recall
'''
m_recall = get_mrecall(confusion_matrix)
return m_recall
def get_per_class_iou(confusion_matrix):
intersection = np.diag(confusion_matrix)
union = np.sum(confusion_matrix, axis=0) + np.sum(confusion_matrix, axis=1) - intersection
iou = intersection / (union.astype(np.float32) + 1e-6)
return iou
def get_per_class_acc(confusion_matrix):
total_acc = np.diag(confusion_matrix) / (np.sum(confusion_matrix, axis=1).astype(np.float32) + 1e-6)
return total_acc
def post_process_segm_output(segm, colors, dist_type='abs'):
"""
Post-processing to turn output segm image to class index map using NumPy
Args:
segm: (H, W, 3)
Returns:
class_map: (H, W)
"""
palette = np.array(colors)
segm = segm.astype(np.float32) # (h, w, 3)
h, w, k = segm.shape[0], segm.shape[1], palette.shape[0]
if dist_type == 'abs':
dist = np.abs(segm.reshape(h, w, 1, 3) - palette.reshape(1, 1, k, 3)) # (h, w, k)
elif dist_type == 'square':
dist = np.power(segm.reshape(h, w, 1, 3) - palette.reshape(1, 1, k, 3), 2) # (h, w, k)
elif dist_type == 'mean':
dist_abs = np.abs(segm.reshape(h, w, 1, 3) - palette.reshape(1, 1, k, 3)) # (h, w, k)
dist_square = np.power(segm.reshape(h, w, 1, 3) - palette.reshape(1, 1, k, 3), 2) # (h, w, k)
dist = (dist_abs + dist_square) / 2.
else:
raise NotImplementedError
dist = np.sum(dist, axis=-1)
pred = np.argmin(dist, axis=-1).astype(np.int64)
return pred
def get_args_parser():
parser = argparse.ArgumentParser('semantic segmentation evaluation', add_help=False)
parser.add_argument('--pred_dir', type=str, help='dir to pred', required=True)
parser.add_argument('--gt_dir', type=str, help='dir to gt', required=True)
parser.add_argument('--gt_list_path', type=str, help='dir to gt_list_path', required=True)
parser.add_argument('--gt_suffix', type=str, help='suffix to gt', required=True)
parser.add_argument('--dataset_name', type=str, help='dataset name', required=True)
parser.add_argument('--model_name', type=str, help='model name', required=True)
parser.add_argument('--dist_type', type=str, help='dist type',
default='abs', choices=['abs', 'square', 'mean'])
return parser.parse_args()
def process_file(file_dict, pred_dir, gt_dir, args, num_class):
filename = file_dict['file_name']
file_cls = file_dict['file_cls']
gt = io.imread(os.path.join(gt_dir, filename))
gt_index = gt.copy()
gt_index[gt_index != file_cls] = 0
gt_index[gt_index == file_cls] = 1
try:
pred = io.imread(os.path.join(pred_dir, filename.replace('.png', f'-{file_cls}.png')))
pred = resize(pred, gt.shape[-2:], anti_aliasing=False, mode='reflect', order=0)
if len(pred.shape) == 3:
pred_index = pred[:,:,0].copy()
else:
pred_index = pred.copy()
pred_index[pred_index<=127] = 0
pred_index[pred_index>127] = 1
except Exception:
pred_name = filename.replace('.png', f'-{file_cls}.png')
logging.info(f'{pred_name} not found, using the ground truth as a placeholder prediction')
pred_index = gt_index.copy()
pred_index = pred_index.flatten()
gt_index = gt_index.flatten()
confusion_matrix = get_confusion_matrix(pred_index, gt_index, num_class)
return file_cls, confusion_matrix
if __name__ == '__main__':
args = get_args_parser()
dataset_name = args.dataset_name
pred_dir = args.pred_dir
gt_dir = args.gt_dir
gt_list_path = args.gt_list_path
dist_type = args.dist_type
model_name = args.model_name
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
os.makedirs('logs/eval', exist_ok=True)
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(f'logs/eval/eval_{model_name}_{dataset_name}_{current_time}.log'),
logging.StreamHandler()
]
)
output_folder = os.path.join(pred_dir, f'eval_{dataset_name}')
os.makedirs(output_folder, exist_ok=True)
num_class = 2
with open(gt_list_path, 'r') as f:
file_cls_list = f.readlines()
file_list = []
for i in file_cls_list:
i = i.strip()
file_name = i[:-3]
file_cls = i[-2:]
file_list.append({'file_name': file_name, 'file_cls': int(file_cls)})
all_pred_labels = []
all_gt_labels = []
process_file_partial = partial(process_file, pred_dir=pred_dir, gt_dir=gt_dir, args=args, num_class=num_class)
pool = Pool()
outputs = pool.map(process_file_partial, file_list)
pool.close()
pool.join()
logging.info(f'len outputs: {len(outputs)}')
confusion_matrix_dict = {}
for cls, confusion_matrix in outputs:
if cls in confusion_matrix_dict.keys():
confusion_matrix_dict[cls] += confusion_matrix
else:
confusion_matrix_dict[cls] = confusion_matrix
class_list = []
iou_list = []
acc_list = []
for cls, confusion_matrix in confusion_matrix_dict.items():
ious = get_per_class_iou(confusion_matrix)
accs = get_per_class_acc(confusion_matrix)
logging.info(f'cls: {cls}, ious: {ious}, accs: {accs}')
class_list.append(cls)
iou_list.append(ious[1])
acc_list.append(accs[1])
miou = np.mean(iou_list)
macc = np.mean(acc_list)
df_metrics = pd.DataFrame({
'Class': class_list + ['Mean'],
'IoU': iou_list + [miou],
'Accuracy': acc_list + [macc],
})
pd.set_option('display.float_format', '{:.4f}'.format)
logging.info(df_metrics)
pd.reset_option('display.float_format')
df_metrics.to_csv(os.path.join(output_folder, 'eval.csv'), index=False, float_format='%.4f')

7
lib/models/__init__.py Normal file
View File

@@ -0,0 +1,7 @@
from .segmentors import SkySensePP
from .losses import (ModalityVAELoss, RecLoss)
from .metrics import (SemMetric)
__all__ = [
'SkySensePP', 'ModalityVAELoss', 'RecLoss', 'SemMetric'
]

View File

@@ -0,0 +1,14 @@
from .swin_v2 import SwinTransformerV2MSL
from .vit import VisionTransformerMSL
__all__ = [
'SwinTransformerV2MSL', 'VisionTransformerMSL'
]
type_mapping = {
'SwinTransformerV2MSL': SwinTransformerV2MSL,
'VisionTransformerMSL': VisionTransformerMSL
}
def build_backbone(type, **kwargs):
return type_mapping[type](**kwargs)
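A minimal sketch of driving this factory from a config dict; the arch/img_size values below are placeholders, not the project's actual settings, and mmcv/mmcls must be installed:

cfg = dict(type='SwinTransformerV2MSL', arch='tiny', img_size=256,
           patch_size=4, vocabulary_size=128)
backbone = build_backbone(cfg.pop('type'), **cfg)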

View File

@@ -0,0 +1,702 @@
from copy import deepcopy
from typing import Sequence
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmcls.models.utils import (PatchMerging, ShiftWindowMSA, WindowMSAV2,
resize_pos_embed, to_2tuple)
from mmcls.models.backbones.base_backbone import BaseBackbone
from mmcv.runner import (CheckpointLoader,
load_state_dict)
from mmcv.cnn.bricks.transformer import MultiheadAttention
class SwinBlockV2(BaseModule):
"""Swin Transformer V2 block. Use post normalization.
Args:
embed_dims (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): The height and width of the window. Defaults to 7.
shift (bool): Shift the attention window or not. Defaults to False.
extra_norm (bool): Whether add extra norm at the end of main branch.
ffn_ratio (float): The expansion ratio of feedforward network hidden
layer channels. Defaults to 4.
drop_path (float): The drop path rate after attention and ffn.
Defaults to 0.
pad_small_map (bool): If True, pad the small feature map to the window
size, which is common used in detection and segmentation. If False,
avoid shifting window and shrink the window size to the size of
feature map, which is common used in classification.
Defaults to False.
attn_cfgs (dict): The extra config of Shift Window-MSA.
Defaults to empty dict.
ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
norm_cfg (dict): The config of norm layers.
Defaults to ``dict(type='LN')``.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
pretrained_window_size (int): Window size in pretrained.
init_cfg (dict, optional): The extra config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads,
window_size=8,
shift=False,
extra_norm=False,
ffn_ratio=4.,
drop_path=0.,
pad_small_map=False,
attn_cfgs=dict(),
ffn_cfgs=dict(),
norm_cfg=dict(type='LN'),
with_cp=False,
pretrained_window_size=0,
init_cfg=None):
super(SwinBlockV2, self).__init__(init_cfg)
self.with_cp = with_cp
self.extra_norm = extra_norm
_attn_cfgs = {
'embed_dims': embed_dims,
'num_heads': num_heads,
'shift_size': window_size // 2 if shift else 0,
'window_size': window_size,
'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
'pad_small_map': pad_small_map,
**attn_cfgs
}
# use V2 attention implementation
_attn_cfgs.update(
window_msa=WindowMSAV2,
msa_cfg=dict(
pretrained_window_size=to_2tuple(pretrained_window_size)))
self.attn = ShiftWindowMSA(**_attn_cfgs)
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
_ffn_cfgs = {
'embed_dims': embed_dims,
'feedforward_channels': int(embed_dims * ffn_ratio),
'num_fcs': 2,
'ffn_drop': 0,
'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
'act_cfg': dict(type='GELU'),
'add_identity': False,
**ffn_cfgs
}
self.ffn = FFN(**_ffn_cfgs)
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
# add extra norm for every n blocks in huge and giant model
if self.extra_norm:
self.norm3 = build_norm_layer(norm_cfg, embed_dims)[1]
def forward(self, x, hw_shape):
def _inner_forward(x):
# Use post normalization
identity = x
x = self.attn(x, hw_shape)
x = self.norm1(x)
x = x + identity
identity = x
x = self.ffn(x)
x = self.norm2(x)
x = x + identity
if self.extra_norm:
x = self.norm3(x)
return x
if self.with_cp and x.requires_grad:
x = cp.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x
class SwinBlockV2Sequence(BaseModule):
"""Module with successive Swin Transformer blocks and downsample layer.
Args:
embed_dims (int): Number of input channels.
depth (int): Number of successive swin transformer blocks.
num_heads (int): Number of attention heads.
window_size (int): The height and width of the window. Defaults to 7.
downsample (bool): Downsample the output of blocks by patch merging.
Defaults to False.
downsample_cfg (dict): The extra config of the patch merging layer.
Defaults to empty dict.
drop_paths (Sequence[float] | float): The drop path rate in each block.
Defaults to 0.
block_cfgs (Sequence[dict] | dict): The extra config of each block.
Defaults to empty dicts.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
pad_small_map (bool): If True, pad the small feature map to the window
size, which is common used in detection and segmentation. If False,
avoid shifting window and shrink the window size to the size of
feature map, which is common used in classification.
Defaults to False.
extra_norm_every_n_blocks (int): Add extra norm at the end of main
branch every n blocks. Defaults to 0, which means no needs for
extra norm layer.
pretrained_window_size (int): Window size in pretrained.
init_cfg (dict, optional): The extra config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
depth,
num_heads,
window_size=8,
downsample=False,
downsample_cfg=dict(),
drop_paths=0.,
block_cfgs=dict(),
with_cp=False,
pad_small_map=False,
extra_norm_every_n_blocks=0,
pretrained_window_size=0,
init_cfg=None):
super().__init__(init_cfg)
if not isinstance(drop_paths, Sequence):
drop_paths = [drop_paths] * depth
if not isinstance(block_cfgs, Sequence):
block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)]
if downsample:
self.out_channels = 2 * embed_dims
_downsample_cfg = {
'in_channels': embed_dims,
'out_channels': self.out_channels,
'norm_cfg': dict(type='LN'),
**downsample_cfg
}
self.downsample = PatchMerging(**_downsample_cfg)
else:
self.out_channels = embed_dims
self.downsample = None
self.blocks = ModuleList()
for i in range(depth):
extra_norm = True if extra_norm_every_n_blocks and \
(i + 1) % extra_norm_every_n_blocks == 0 else False
_block_cfg = {
'embed_dims': self.out_channels,
'num_heads': num_heads,
'window_size': window_size,
'shift': False if i % 2 == 0 else True,
'extra_norm': extra_norm,
'drop_path': drop_paths[i],
'with_cp': with_cp,
'pad_small_map': pad_small_map,
'pretrained_window_size': pretrained_window_size,
**block_cfgs[i]
}
block = SwinBlockV2(**_block_cfg)
self.blocks.append(block)
def forward(self, x, in_shape):
if self.downsample:
x, out_shape = self.downsample(x, in_shape)
else:
out_shape = in_shape
for block in self.blocks:
x = block(x, out_shape)
return x, out_shape
class SwinTransformerV2(BaseBackbone):
"""Swin Transformer V2.
A PyTorch implement of : `Swin Transformer V2:
Scaling Up Capacity and Resolution
<https://arxiv.org/abs/2111.09883>`_
Inspiration from
https://github.com/microsoft/Swin-Transformer
Args:
arch (str | dict): Swin Transformer architecture. If use string, choose
from 'tiny', 'small', 'base' and 'large'. If use dict, it should
have below keys:
- **embed_dims** (int): The dimensions of embedding.
- **depths** (List[int]): The number of blocks in each stage.
- **num_heads** (List[int]): The number of heads in attention
modules of each stage.
- **extra_norm_every_n_blocks** (int): Add extra norm at the end
of main branch every n blocks.
Defaults to 'tiny'.
img_size (int | tuple): The expected input image shape. Because we
support dynamic input shape, just set the argument to the most
common input image shape. Defaults to 224.
patch_size (int | tuple): The patch size in patch embedding.
Defaults to 4.
in_channels (int): The num of input channels. Defaults to 3.
window_size (int | Sequence): The height and width of the window.
Defaults to 7.
drop_rate (float): Dropout rate after embedding. Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
use_abs_pos_embed (bool): If True, add absolute position embedding to
the patch embedding. Defaults to False.
interpolate_mode (str): Select the interpolate mode for absolute
position embeding vector resize. Defaults to "bicubic".
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
pad_small_map (bool): If True, pad the small feature map to the window
size, which is common used in detection and segmentation. If False,
avoid shifting window and shrink the window size to the size of
feature map, which is common used in classification.
Defaults to False.
norm_cfg (dict): Config dict for normalization layer for all output
features. Defaults to ``dict(type='LN')``
stage_cfgs (Sequence[dict] | dict): Extra config dict for each
stage. Defaults to an empty dict.
patch_cfg (dict): Extra config dict for patch embedding.
Defaults to an empty dict.
pretrained_window_sizes (tuple(int)): Pretrained window sizes of
each layer.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
Examples:
>>> from mmcls.models import SwinTransformerV2
>>> import torch
>>> extra_config = dict(
>>> arch='tiny',
>>> stage_cfgs=dict(downsample_cfg={'kernel_size': 3,
>>> 'padding': 'same'}))
>>> self = SwinTransformerV2(**extra_config)
>>> inputs = torch.rand(1, 3, 224, 224)
>>> output = self.forward(inputs)
>>> print(output.shape)
(1, 2592, 4)
"""
arch_zoo = {
**dict.fromkeys(['t', 'tiny'],
{'embed_dims': 96,
'depths': [2, 2, 6, 2],
'num_heads': [3, 6, 12, 24],
'extra_norm_every_n_blocks': 0}),
**dict.fromkeys(['s', 'small'],
{'embed_dims': 96,
'depths': [2, 2, 18, 2],
'num_heads': [3, 6, 12, 24],
'extra_norm_every_n_blocks': 0}),
**dict.fromkeys(['b', 'base'],
{'embed_dims': 128,
'depths': [2, 2, 18, 2],
'num_heads': [4, 8, 16, 32],
'extra_norm_every_n_blocks': 0}),
**dict.fromkeys(['l', 'large'],
{'embed_dims': 192,
'depths': [2, 2, 18, 2],
'num_heads': [6, 12, 24, 48],
'extra_norm_every_n_blocks': 0}),
# The 'huge' configuration here is non-standard (custom head counts);
# it is used for a parallel study on self-supervised learning.
**dict.fromkeys(['h', 'huge'],
{'embed_dims': 352,
'depths': [2, 2, 18, 2],
'num_heads': [8, 16, 32, 64],
'extra_norm_every_n_blocks': 6}),
**dict.fromkeys(['g', 'giant'],
{'embed_dims': 512,
'depths': [2, 2, 42, 4],
'num_heads': [16, 32, 64, 128],
'extra_norm_every_n_blocks': 6}),
} # yapf: disable
_version = 1
num_extra_tokens = 0
def __init__(self,
arch='tiny',
img_size=256,
patch_size=4,
in_channels=3,
vocabulary_size=128,
window_size=8,
drop_rate=0.,
drop_path_rate=0.1,
out_indices=(3, ),
use_abs_pos_embed=False,
interpolate_mode='bicubic',
with_cp=False,
frozen_stages=-1,
norm_eval=False,
pad_small_map=False,
norm_cfg=dict(type='LN'),
stage_cfgs=dict(downsample_cfg=dict(is_post_norm=True)),
patch_cfg=dict(),
pretrained_window_sizes=[0, 0, 0, 0],
init_cfg=None):
super(SwinTransformerV2, self).__init__(init_cfg=init_cfg)
if isinstance(arch, str):
arch = arch.lower()
assert arch in set(self.arch_zoo), \
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
self.arch_settings = self.arch_zoo[arch]
else:
essential_keys = {
'embed_dims', 'depths', 'num_heads',
'extra_norm_every_n_blocks'
}
assert isinstance(arch, dict) and set(arch) == essential_keys, \
f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = arch
self.vocabulary_size = vocabulary_size + 1  # add one extra entry for the ignore class
self.embed_dims = self.arch_settings['embed_dims']
self.depths = self.arch_settings['depths']
self.num_heads = self.arch_settings['num_heads']
self.extra_norm_every_n_blocks = self.arch_settings[
'extra_norm_every_n_blocks']
self.num_layers = len(self.depths)
self.out_indices = out_indices
self.use_abs_pos_embed = use_abs_pos_embed
self.interpolate_mode = interpolate_mode
self.frozen_stages = frozen_stages
if isinstance(window_size, int):
self.window_sizes = [window_size for _ in range(self.num_layers)]
elif isinstance(window_size, Sequence):
assert len(window_size) == self.num_layers, \
f'Length of window_sizes {len(window_size)} is not equal to '\
f'length of stages {self.num_layers}.'
self.window_sizes = window_size
else:
raise TypeError('window_size should be a Sequence or int.')
_patch_cfg = dict(
in_channels=in_channels,
input_size=img_size,
embed_dims=self.embed_dims,
conv_type='Conv2d',
kernel_size=patch_size,
stride=patch_size,
norm_cfg=dict(type='LN'),
)
_patch_cfg.update(patch_cfg)
self.patch_embed = PatchEmbed(**_patch_cfg)
self.patch_resolution = self.patch_embed.init_out_size
self.patch_size = patch_size
if self.use_abs_pos_embed:
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
self.absolute_pos_embed = nn.Parameter(
torch.zeros(1, num_patches, self.embed_dims))
self._register_load_state_dict_pre_hook(
self._prepare_abs_pos_embed)
self._register_load_state_dict_pre_hook(self._delete_reinit_params)
self.drop_after_pos = nn.Dropout(p=drop_rate)
self.norm_eval = norm_eval
# stochastic depth
total_depth = sum(self.depths)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
] # stochastic depth decay rule
self.stages = ModuleList()
embed_dims = [self.embed_dims]
for i, (depth, num_heads) in enumerate(zip(self.depths,
self.num_heads)):
if isinstance(stage_cfgs, Sequence):
stage_cfg = stage_cfgs[i]
else:
stage_cfg = deepcopy(stage_cfgs)
downsample = True if i > 0 else False
_stage_cfg = {
'embed_dims': embed_dims[-1],
'depth': depth,
'num_heads': num_heads,
'window_size': self.window_sizes[i],
'downsample': downsample,
'drop_paths': dpr[:depth],
'with_cp': with_cp,
'pad_small_map': pad_small_map,
'extra_norm_every_n_blocks': self.extra_norm_every_n_blocks,
'pretrained_window_size': pretrained_window_sizes[i],
**stage_cfg
}
stage = SwinBlockV2Sequence(**_stage_cfg)
self.stages.append(stage)
dpr = dpr[depth:]
embed_dims.append(stage.out_channels)
for i in out_indices:
if norm_cfg is not None:
norm_layer = build_norm_layer(norm_cfg, embed_dims[i + 1])[1]
else:
norm_layer = nn.Identity()
self.add_module(f'norm{i}', norm_layer)
def init_weights(self):
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
# Suppress default init if use pretrained model.
from mmcls.utils import get_root_logger
logger = get_root_logger()
checkpoint = CheckpointLoader.load_checkpoint(
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
load_state_dict(self, state_dict, strict=False, logger=logger)
return
else:
super(SwinTransformerV2, self).init_weights()
if self.use_abs_pos_embed:
trunc_normal_(self.absolute_pos_embed, std=0.02)
def forward(self, x):
x, hw_shape = self.patch_embed(x)
if self.use_abs_pos_embed:
x = x + resize_pos_embed(
self.absolute_pos_embed, self.patch_resolution, hw_shape,
self.interpolate_mode, self.num_extra_tokens)
x = self.drop_after_pos(x)
outs = []
for i, stage in enumerate(self.stages):
x, hw_shape = stage(x, hw_shape)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
out = norm_layer(x)
out = out.view(-1, *hw_shape,
stage.out_channels).permute(0, 3, 1,
2).contiguous()
outs.append(out)
return outs
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
for i in range(0, self.frozen_stages + 1):
m = self.stages[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
for i in self.out_indices:
if i <= self.frozen_stages:
for param in getattr(self, f'norm{i}').parameters():
param.requires_grad = False
def train(self, mode=True):
super(SwinTransformerV2, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
# trick: eval have effect on BatchNorm only
if isinstance(m, _BatchNorm):
m.eval()
def _prepare_abs_pos_embed(self, state_dict, prefix, *args, **kwargs):
name = prefix + 'absolute_pos_embed'
if name not in state_dict.keys():
return
ckpt_pos_embed_shape = state_dict[name].shape
if self.absolute_pos_embed.shape != ckpt_pos_embed_shape:
from mmcls.utils import get_root_logger
logger = get_root_logger()
logger.info(
'Resize the absolute_pos_embed shape from '
f'{ckpt_pos_embed_shape} to {self.absolute_pos_embed.shape}.')
ckpt_pos_embed_shape = to_2tuple(
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
pos_embed_shape = self.patch_embed.init_out_size
state_dict[name] = resize_pos_embed(state_dict[name],
ckpt_pos_embed_shape,
pos_embed_shape,
self.interpolate_mode,
self.num_extra_tokens)
def _delete_reinit_params(self, state_dict, prefix, *args, **kwargs):
# delete relative_position_index since we always re-init it
relative_position_index_keys = [
k for k in state_dict.keys() if 'relative_position_index' in k
]
for k in relative_position_index_keys:
del state_dict[k]
# delete relative_coords_table since we always re-init it
relative_position_index_keys = [
k for k in state_dict.keys() if 'relative_coords_table' in k
]
for k in relative_position_index_keys:
del state_dict[k]
class Proj_MHSA(nn.Module):
def __init__(
self,
embed_dims,
proj_dims,
num_heads=16,
batch_first=True,
bias=True
):
super().__init__()
self.proj_in = nn.Linear(in_features=embed_dims, out_features=proj_dims)
self.attn = MultiheadAttention(
embed_dims=proj_dims,
num_heads=num_heads,
batch_first=batch_first,
bias=bias
)
self.proj_out = nn.Linear(in_features=proj_dims, out_features=embed_dims)
def forward(self, x):
x = self.proj_in(x)
x = self.attn(x, x, x)
x = self.proj_out(x)
return x
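A shape-only sketch of Proj_MHSA, which projects tokens down before attention and back up afterwards (dimensions are placeholders; mmcv must be installed):

import torch

blk = Proj_MHSA(embed_dims=352, proj_dims=256, num_heads=16)
tokens = torch.rand(2, 4096, 352)
print(blk(tokens).shape)   # torch.Size([2, 4096, 352])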
class SwinTransformerV2MSL(SwinTransformerV2):
def __init__(self, **kwargs):
if 'use_attn' in kwargs:
self.use_attn = kwargs.pop('use_attn')
else:
self.use_attn = False
if 'merge_stage' in kwargs:
self.merge_stage = kwargs.pop('merge_stage')
else:
self.merge_stage = 0
if 'with_cls_pos' in kwargs:
self.with_cls_pos = kwargs.pop('with_cls_pos')
else:
self.with_cls_pos = False
super().__init__(**kwargs)
self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
#self.vocabulary_token = nn.Parameter(torch.zeros(1, 1, 1, self.vocabulary_size, self.embed_dims))
self.vocabulary_token = nn.Parameter(torch.zeros(self.vocabulary_size, self.embed_dims))
self.vocabulary_weight = nn.Parameter(torch.zeros(1, self.patch_size * self.patch_size))
trunc_normal_(self.mask_token, mean=0., std=.02)
trunc_normal_(self.vocabulary_token, mean=0., std=.02)
if self.use_attn:
self.attn1 = Proj_MHSA(embed_dims=352, proj_dims=256, num_heads=16, batch_first=True, bias=True)
self.attn2 = Proj_MHSA(embed_dims=704, proj_dims=512, num_heads=16, batch_first=True, bias=True)
self.attn3 = Proj_MHSA(embed_dims=1408, proj_dims=1024, num_heads=16, batch_first=True, bias=True)
self.attention_blocks = [self.attn1, self.attn2, self.attn3]
self.norm_attn = build_norm_layer(dict(type='LN'), 1408)[1]
def create_ann_token(self, anno_img):
B, H, W = anno_img.shape
ann_token = torch.index_select(self.vocabulary_token, 0, anno_img.reshape(-1)).reshape(B, H, W, -1)
assert H % self.patch_size == 0 and W % self.patch_size == 0
nph, npw = H // self.patch_size, W // self.patch_size
weight = F.softmax(self.vocabulary_weight, dim=1) * self.patch_size * self.patch_size
weight = weight.reshape(1, 1, self.patch_size, 1, self.patch_size).repeat(1, nph, 1, npw, 1).reshape(1, H, W, 1)
ann_token = ann_token * weight
ann_token = F.avg_pool2d(torch.einsum('BHWC->BCHW', ann_token), self.patch_size, self.patch_size)
ann_token = torch.einsum('BCHW->BHWC', ann_token).reshape(B, nph * npw, self.embed_dims) # shape B, L, C
return ann_token
def forward(self, hr_img, anno_img, mask=None):
x, hw_shape = self.patch_embed(hr_img)
y = self.create_ann_token(anno_img)
assert x.shape == y.shape
B, L, C = y.shape
if mask is not None:
mask_tokens = self.mask_token.expand(B, L, -1)
w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)
y = y * (1. - w) + mask_tokens * w
if self.merge_stage == 0:
x = (x + y) * 0.5
else:
x = x.reshape(B, *hw_shape, C)
y = y.reshape(B, *hw_shape, C)
x = torch.cat((x, y), dim=2)
hw_shape = (hw_shape[0], hw_shape[1] * 2)
x = x.reshape(B, -1, C)
if self.use_abs_pos_embed:
x = x + resize_pos_embed(
self.absolute_pos_embed, self.patch_resolution, hw_shape,
self.interpolate_mode, self.num_extra_tokens)
if self.with_cls_pos:
hw_shape_half = [hw_shape[0], hw_shape[1] // 2]
x = x.reshape(B, *hw_shape, C)
x1 = x[:, :, :x.shape[2]//2, :].reshape(B, -1, C)
x2 = x[:, :, x.shape[2]//2:, :].reshape(B, -1, C)
x1 = x1 + resize_pos_embed(
self.absolute_pos_embed, self.patch_resolution, hw_shape_half,
self.interpolate_mode, self.num_extra_tokens)
x2 = x2 + resize_pos_embed(
self.absolute_pos_embed, self.patch_resolution, hw_shape_half,
self.interpolate_mode, self.num_extra_tokens)
x1 = x1.reshape(B, *hw_shape_half, C)
x2 = x2.reshape(B, *hw_shape_half, C)
x = torch.cat((x1, x2), dim=2).reshape(B, -1, C)
x = self.drop_after_pos(x)
outs = []
merge_idx = self.merge_stage - 1
for i, stage in enumerate(self.stages):
x, hw_shape = stage(x, hw_shape)
if i == merge_idx:
x = x.reshape(x.shape[0], *hw_shape, x.shape[-1]) # b,l,c -> b, h, w, c
x = (x[:, :, :x.shape[2]//2] + x[:, :, x.shape[2]//2:]) * 0.5
x = x.reshape(x.shape[0], -1, x.shape[-1])
hw_shape = (hw_shape[0], hw_shape[1] // 2)
if self.use_attn:
if i <= len(self.attention_blocks) - 1:
x = x + self.attention_blocks[i](x)
if i == len(self.attention_blocks) - 1:
x = self.norm_attn(x)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
out = norm_layer(x)
out = out.view(-1, *hw_shape, stage.out_channels).permute(0, 3, 1, 2).contiguous()
outs.append(out)
return outs
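A minimal forward sketch for this masked-annotation variant, assuming mmcv/mmcls are installed and using the small 'tiny' arch purely for illustration (the real configs likely use larger settings):

import torch

model = SwinTransformerV2MSL(arch='tiny', img_size=256, patch_size=4,
                             vocabulary_size=128, merge_stage=0)
hr_img = torch.rand(1, 3, 256, 256)
anno_img = torch.randint(0, 129, (1, 256, 256))   # 128 classes + 1 ignore index
mask = torch.randint(0, 2, (1, 64, 64))           # one flag per 4x4 patch
feats = model(hr_img, anno_img, mask)
print(feats[-1].shape)                            # torch.Size([1, 768, 8, 8]) with out_indices=(3,)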

611
lib/models/backbones/vit.py Normal file
View File

@@ -0,0 +1,611 @@
# Copyright (c) Ant Group. All rights reserved.
import math
import warnings
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
trunc_normal_)
from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList,
load_state_dict)
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.utils import _pair as to_2tuple
from mmseg.ops import resize
from mmseg.utils import get_root_logger
from mmseg.models.utils.embed import PatchEmbed
import torch.nn.functional as F
import numpy as np
class TransformerEncoderLayer(BaseModule):
"""Implements one encoder layer in Vision Transformer.
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Default: 0.0.
attn_drop_rate (float): The drop out rate for attention layer.
Default: 0.0.
drop_path_rate (float): stochastic depth rate. Default 0.0.
num_fcs (int): The number of fully-connected layers for FFNs.
Default: 2.
qkv_bias (bool): enable bias for qkv if True. Default: True
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
batch_first (bool): Key, Query and Value are shape of
(batch, n, embed_dim)
or (n, batch, embed_dim). Default: True.
with_cp (bool): Use checkpoint or not. Using checkpoint will save
some memory while slowing down the training speed. Default: False.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
num_fcs=2,
qkv_bias=True,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
batch_first=True,
attn_cfg=dict(),
ffn_cfg=dict(),
with_cp=False):
super(TransformerEncoderLayer, self).__init__()
self.norm1_name, norm1 = build_norm_layer(norm_cfg,
embed_dims,
postfix=1)
self.add_module(self.norm1_name, norm1)
attn_cfg.update(
dict(embed_dims=embed_dims,
num_heads=num_heads,
attn_drop=attn_drop_rate,
proj_drop=drop_rate,
batch_first=batch_first,
bias=qkv_bias))
self.build_attn(attn_cfg)
self.norm2_name, norm2 = build_norm_layer(norm_cfg,
embed_dims,
postfix=2)
self.add_module(self.norm2_name, norm2)
ffn_cfg.update(
dict(embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=num_fcs,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate)
if drop_path_rate > 0 else None,
act_cfg=act_cfg))
self.build_ffn(ffn_cfg)
self.with_cp = with_cp
def build_attn(self, attn_cfg):
self.attn = MultiheadAttention(**attn_cfg)
def build_ffn(self, ffn_cfg):
self.ffn = FFN(**ffn_cfg)
@property
def norm1(self):
return getattr(self, self.norm1_name)
@property
def norm2(self):
return getattr(self, self.norm2_name)
def forward(self, x):
def _inner_forward(x):
x = self.attn(self.norm1(x), identity=x)
x = self.ffn(self.norm2(x), identity=x)
return x
if self.with_cp and x.requires_grad:
x = cp.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x
class VisionTransformer(BaseModule):
"""Vision Transformer.
This backbone is the implementation of `An Image is Worth 16x16 Words:
Transformers for Image Recognition at
Scale <https://arxiv.org/abs/2010.11929>`_.
Args:
img_size (int | tuple): Input image size. Default: 224.
patch_size (int): The patch size. Default: 16.
in_channels (int): Number of input channels. Default: 3.
embed_dims (int): embedding dimension. Default: 768.
num_layers (int): depth of transformer. Default: 12.
num_heads (int): number of attention heads. Default: 12.
mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
Default: 4.
out_indices (list | tuple | int): Output from which stages.
Default: -1.
qkv_bias (bool): enable bias for qkv if True. Default: True.
drop_rate (float): Probability of an element to be zeroed.
Default 0.0
attn_drop_rate (float): The drop out rate for attention layer.
Default 0.0
drop_path_rate (float): stochastic depth rate. Default 0.0
with_cls_token (bool): Whether concatenating class token into image
tokens as transformer input. Default: True.
output_cls_token (bool): Whether output the cls_token. If set True,
`with_cls_token` must be True. Default: False.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN')
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
patch_norm (bool): Whether to add a norm in PatchEmbed Block.
Default: False.
final_norm (bool): Whether to add a additional layer to normalize
final feature map. Default: False.
interpolate_mode (str): Select the interpolate mode for position
embeding vector resize. Default: bicubic.
num_fcs (int): The number of fully-connected layers for FFNs.
Default: 2.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save
some memory while slowing down the training speed. Default: False.
pretrained (str, optional): model pretrained path. Default: None.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(self,
img_size=224,
patch_size=16,
in_channels=3,
embed_dims=768,
num_layers=12,
num_heads=12,
mlp_ratio=4,
out_indices=-1,
qkv_bias=True,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
with_cls_token=True,
output_cls_token=False,
norm_cfg=dict(type='LN'),
act_cfg=dict(type='GELU'),
patch_norm=False,
final_norm=False,
interpolate_mode='bicubic',
num_fcs=2,
norm_eval=False,
with_cp=False,
use_ccd=False,
ccd_num=0,
pretrained=None,
init_cfg=None):
super(VisionTransformer, self).__init__(init_cfg=init_cfg)
if isinstance(img_size, int):
img_size = to_2tuple(img_size)
elif isinstance(img_size, tuple):
if len(img_size) == 1:
img_size = to_2tuple(img_size[0])
assert len(img_size) == 2, \
f'The size of image should have length 1 or 2, ' \
f'but got {len(img_size)}'
if output_cls_token:
assert with_cls_token is True, 'with_cls_token must be True if ' \
f'output_cls_token is set to True, but got {with_cls_token}'
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be set at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is not None:
raise TypeError('pretrained must be a str or None')
self.img_size = img_size
self.patch_size = patch_size
self.interpolate_mode = interpolate_mode
self.norm_eval = norm_eval
self.with_cp = with_cp
self.pretrained = pretrained
self.embed_dims = embed_dims
self.patch_embed = PatchEmbed(
in_channels=in_channels,
embed_dims=embed_dims,
conv_type='Conv2d',
kernel_size=patch_size,
stride=patch_size,
padding='corner',
norm_cfg=norm_cfg if patch_norm else None,
init_cfg=None,
)
self.use_ccd = use_ccd
self.ccd_num = ccd_num
if self.use_ccd:
self.ccd_embed = nn.Parameter(
torch.rand(1, self.ccd_num, embed_dims))
num_patches = (img_size[0] // patch_size) * \
(img_size[1] // patch_size)
self.with_cls_token = with_cls_token
self.output_cls_token = output_cls_token
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
# originally: self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dims))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dims))  # +1 slot for the cls token
self.drop_after_pos = nn.Dropout(p=drop_rate)
if isinstance(out_indices, int):
if out_indices == -1:
out_indices = num_layers - 1
self.out_indices = [out_indices]
elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
self.out_indices = out_indices
else:
raise TypeError('out_indices must be type of int, list or tuple')
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, num_layers)
] # stochastic depth decay rule
self.layers = ModuleList()
for i in range(num_layers):
self.layers.append(
TransformerEncoderLayer(embed_dims=embed_dims,
num_heads=num_heads,
feedforward_channels=mlp_ratio *
embed_dims,
attn_drop_rate=attn_drop_rate,
drop_rate=drop_rate,
drop_path_rate=dpr[i],
num_fcs=num_fcs,
qkv_bias=qkv_bias,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
with_cp=with_cp,
batch_first=True))
self.final_norm = final_norm
if final_norm:
self.norm1_name, norm1 = build_norm_layer(norm_cfg,
embed_dims,
postfix=1)
self.add_module(self.norm1_name, norm1)
@property
def norm1(self):
return getattr(self, self.norm1_name)
def init_weights(self):
if (isinstance(self.init_cfg, dict)
and self.init_cfg.get('type') == 'Pretrained'):
logger = get_root_logger()
checkpoint = CheckpointLoader.load_checkpoint(
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
if 'pos_embed' in state_dict.keys():
if self.pos_embed.shape != state_dict['pos_embed'].shape:
logger.info(msg=f'Resize the pos_embed shape from '
f'{state_dict["pos_embed"].shape} to '
f'{self.pos_embed.shape}')
h, w = self.img_size
pos_size = int(
math.sqrt(state_dict['pos_embed'].shape[1] - 1))
state_dict['pos_embed'] = self.resize_pos_embed(
state_dict['pos_embed'],
(h // self.patch_size, w // self.patch_size),
(pos_size, pos_size), self.interpolate_mode)
load_state_dict(self, state_dict, strict=False, logger=logger)
elif self.init_cfg is not None:
super(VisionTransformer, self).init_weights()
else:
# We only implement the 'jax_impl' initialization implemented at
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
if self.use_ccd:
trunc_normal_(self.ccd_embed, std=0.02)
for n, m in self.named_modules():
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
if 'ffn' in n:
nn.init.normal_(m.bias, mean=0., std=1e-6)
else:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv2d):
kaiming_init(m, mode='fan_in', bias=0.)
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
constant_init(m, val=1.0, bias=0.)
def _pos_embeding(self, patched_img, hw_shape, pos_embed):
"""Positioning embeding method.
Resize the pos_embed, if the input image size doesn't match
the training size.
Args:
patched_img (torch.Tensor): The patched image, it should be
shape of [B, L1, C].
hw_shape (tuple): The downsampled image resolution.
pos_embed (torch.Tensor): The pos_embed weights, it should be
shape of [B, L2, c].
Return:
torch.Tensor: The pos encoded image feature.
"""
assert patched_img.ndim == 3 and pos_embed.ndim == 3, \
'the shapes of patched_img and pos_embed must be [B, L, C]'
x_len, pos_len = patched_img.shape[1], pos_embed.shape[1]
if x_len != pos_len:
if pos_len == (self.img_size[0] // self.patch_size) * (
self.img_size[1] // self.patch_size) + 1:
pos_h = self.img_size[0] // self.patch_size
pos_w = self.img_size[1] // self.patch_size
else:
raise ValueError(
'Unexpected shape of pos_embed, got {}.'.format(
pos_embed.shape))
pos_embed = self.resize_pos_embed(pos_embed, hw_shape,
(pos_h, pos_w),
self.interpolate_mode)
return self.drop_after_pos(patched_img + pos_embed)
@staticmethod
def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode):
"""Resize pos_embed weights.
Resize pos_embed using bicubic interpolate method.
Args:
pos_embed (torch.Tensor): Position embedding weights.
input_shpae (tuple): Tuple for (downsampled input image height,
downsampled input image width).
pos_shape (tuple): The resolution of downsampled origin training
image.
mode (str): Algorithm used for upsampling:
``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
``'trilinear'``. Default: ``'nearest'``
Return:
torch.Tensor: The resized pos_embed of shape [B, L_new, C]
"""
assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
pos_h, pos_w = pos_shape
# keep dim for easy deployment
cls_token_weight = pos_embed[:, 0:1]
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
pos_embed_weight = pos_embed_weight.reshape(
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
pos_embed_weight = resize(pos_embed_weight,
size=input_shpae,
align_corners=False,
mode=mode)
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
return pos_embed
def forward(self, inputs, ccd_index=None):
B = inputs.shape[0]
x, hw_shape = self.patch_embed(inputs)
if self.use_ccd:
_ccd_idx = np.concatenate(ccd_index, axis=0)
_ccd_embed = self.ccd_embed[:, _ccd_idx, :].permute(1, 0, 2)
x = x + _ccd_embed
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = self._pos_embeding(x, hw_shape, self.pos_embed)
if not self.with_cls_token:
# Remove class token for transformer encoder input
x = x[:, 1:]
outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
if i == len(self.layers) - 1:
if self.final_norm:
x = self.norm1(x)
if i in self.out_indices:
if self.with_cls_token:
# Remove class token and reshape token for decoder heads
out = x[:, 1:]
else:
out = x
B, _, C = out.shape
out = out.reshape(B, hw_shape[0], hw_shape[1],
C).permute(0, 3, 1, 2).contiguous()
if self.output_cls_token:
out = [out, x[:, 0]]
outs.append(out)
return tuple(outs)
def train(self, mode=True):
super(VisionTransformer, self).train(mode)
if mode and self.norm_eval:
for m in self.modules():
if isinstance(m, nn.LayerNorm):
m.eval()
class VisionTransformerMSL(VisionTransformer):
def __init__(self, **kwargs):
if 'use_attn' in kwargs:
self.use_attn = kwargs.pop('use_attn')
else:
self.use_attn = False
if 'merge_stage' in kwargs:
self.merge_stage = kwargs.pop('merge_stage')
else:
self.merge_stage = 0
if 'with_cls_pos' in kwargs:
self.with_cls_pos = kwargs.pop('with_cls_pos')
else:
self.with_cls_pos = False
self.vocabulary_size = kwargs.pop('vocabulary_size') + 1  # add one extra entry for the ignore class
super().__init__(**kwargs)
self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
img_size = kwargs.pop('img_size')
patch_size = kwargs.pop('patch_size')
num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dims))
self.vocabulary_token = nn.Parameter(torch.zeros(self.vocabulary_size, self.embed_dims))
self.vocabulary_weight = nn.Parameter(torch.zeros(1, self.patch_size * self.patch_size))
trunc_normal_(self.mask_token, mean=0., std=.02)
trunc_normal_(self.vocabulary_token, mean=0., std=.02)
if self.use_attn:
self.attn1 = MultiheadAttention(embed_dims=1024, num_heads=16, batch_first=True, bias=True)
self.attn2 = MultiheadAttention(embed_dims=1024, num_heads=16, batch_first=True, bias=True)
self.attn3 = MultiheadAttention(embed_dims=1024, num_heads=16, batch_first=True, bias=True)
self.attention_blocks = [self.attn1, self.attn2, self.attn3]
self.norm_attn = build_norm_layer(dict(type='LN'), 1024)[1]
def create_ann_token(self, anno_img):
B, H, W = anno_img.shape
ann_token = torch.index_select(self.vocabulary_token, 0, anno_img.reshape(-1)).reshape(B, H, W, -1)
assert H % self.patch_size == 0 and W % self.patch_size == 0
nph, npw = H // self.patch_size, W // self.patch_size
weight = F.softmax(self.vocabulary_weight, dim=1) * self.patch_size * self.patch_size
weight = weight.reshape(1, 1, self.patch_size, 1, self.patch_size).repeat(1, nph, 1, npw, 1).reshape(1, H, W, 1)
ann_token = ann_token * weight
ann_token = F.avg_pool2d(torch.einsum('BHWC->BCHW', ann_token), self.patch_size, self.patch_size)
ann_token = torch.einsum('BCHW->BHWC', ann_token).reshape(B, nph * npw, self.embed_dims) # shape B, L, C
return ann_token
def _pos_embeding(self, patched_img, hw_shape, pos_embed):
"""Positioning embeding method.
Resize the pos_embed, if the input image size doesn't match
the training size.
Args:
patched_img (torch.Tensor): The patched image, it should be
shape of [B, L1, C].
hw_shape (tuple): The downsampled image resolution.
pos_embed (torch.Tensor): The pos_embed weights, it should be
shape of [B, L2, c].
Return:
torch.Tensor: The pos encoded image feature.
"""
assert patched_img.ndim == 3 and pos_embed.ndim == 3, \
'the shapes of patched_img and pos_embed must be [B, L, C]'
x_len, pos_len = patched_img.shape[1], pos_embed.shape[1]
if x_len != pos_len:
if pos_len == (self.img_size[0] // self.patch_size) * (
self.img_size[1] // self.patch_size):
pos_h = self.img_size[0] // self.patch_size
pos_w = self.img_size[1] // self.patch_size
else:
raise ValueError(
'Unexpected shape of pos_embed, got {}.'.format(
pos_embed.shape))
pos_embed = self.resize_pos_embed(pos_embed, hw_shape,
(pos_h, pos_w),
self.interpolate_mode)
return self.drop_after_pos(patched_img + pos_embed)
@staticmethod
def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode):
"""Resize pos_embed weights.
Resize pos_embed using bicubic interpolate method.
Args:
pos_embed (torch.Tensor): Position embedding weights.
input_shpae (tuple): Tuple for (downsampled input image height,
downsampled input image width).
pos_shape (tuple): The resolution of downsampled origin training
image.
mode (str): Algorithm used for upsampling:
``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
``'trilinear'``. Default: ``'nearest'``
Return:
torch.Tensor: The resized pos_embed of shape [B, L_new, C]
"""
assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
pos_h, pos_w = pos_shape
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
pos_embed_weight = pos_embed_weight.reshape(
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
pos_embed_weight = resize(pos_embed_weight,
size=input_shpae,
align_corners=False,
mode=mode)
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
pos_embed = pos_embed_weight # torch.cat((cls_token_weight, pos_embed_weight), dim=1)
return pos_embed
def forward(self, x, y, mask=None):
x, hw_shape = self.patch_embed(x)
y = self.create_ann_token(y)
assert x.shape == y.shape
B, L, C = y.shape
if mask is not None:
mask_tokens = self.mask_token.expand(B, L, -1)
w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)
y = y * (1. - w) + mask_tokens * w
if self.merge_stage == 0:
x = (x + y) * 0.5
else:
x = x.reshape(B, *hw_shape, C)
y = y.reshape(B, *hw_shape, C)
x = torch.cat((x, y), dim=2)
hw_shape = (hw_shape[0], hw_shape[1] * 2)
x = x.reshape(B, -1, C)
x = self._pos_embeding(x, hw_shape, self.pos_embed)
merge_idx = self.merge_stage - 1
outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
if i == merge_idx:
x = x.reshape(x.shape[0], *hw_shape, x.shape[-1]) # b,l,c -> b, h, w, c
x = (x[:, :, :x.shape[2]//2] + x[:, :, x.shape[2]//2:]) * 0.5
x = x.reshape(x.shape[0], -1, x.shape[-1])
hw_shape = (hw_shape[0], hw_shape[1] // 2)
if self.use_attn:
if i <= len(self.attention_blocks) - 1:
x = x + self.attention_blocks[i](x)
if i == len(self.attention_blocks) - 1:
x = self.norm_attn(x)  # note: check whether this conflicts with final_norm below
if (not self.use_attn) and (i == len(self.layers) - 1):
if self.final_norm:
x = self.norm1(x)  # note: check whether this conflicts with norm_attn above
if i in self.out_indices:
if self.with_cls_token:
# Remove class token and reshape token for decoder heads
out = x[:, 1:]
else:
out = x
B, _, C = out.shape
out = out.reshape(B, hw_shape[0], hw_shape[1],
C).permute(0, 3, 1, 2).contiguous()
if self.output_cls_token:
out = [out, x[:, 0]]
outs.append(out)
return tuple(outs)
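A comparable sketch for the ViT variant, assuming mmcv/mmseg are installed; sizes are placeholders, and with_cls_token is disabled since this forward path operates on patch tokens only:

import torch

vit = VisionTransformerMSL(img_size=(512, 512), patch_size=16, vocabulary_size=128,
                           with_cls_token=False, merge_stage=0)
img = torch.rand(1, 3, 512, 512)
anno = torch.randint(0, 129, (1, 512, 512))
mask = torch.randint(0, 2, (1, 32, 32))
feats = vit(img, anno, mask)
print(feats[0].shape)   # torch.Size([1, 768, 32, 32])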

View File

@@ -0,0 +1,15 @@
from .uper_head import UPerHead
from .up_head import UPHead
__all__ = [
'UPerHead', 'UPHead'
]
type_mapping = {
'UPerHead': UPerHead,
'UPHead': UPHead
}
def build_head(type, **kwargs):
return type_mapping[type](**kwargs)
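A minimal sketch of the head factory with the lightweight UPHead defined below (dimensions are placeholders; mmcv/mmseg must be installed):

import torch

head = build_head('UPHead', in_dim=768, out_dim=1, up_scale=16)
print(head(torch.rand(1, 768, 32, 32)).shape)   # torch.Size([1, 1, 512, 512])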

View File

@@ -0,0 +1,201 @@
import warnings
from abc import ABCMeta, abstractmethod
import torch
import torch.nn as nn
from mmcv.runner import BaseModule, auto_fp16
from mmseg.core import build_pixel_sampler
from mmseg.ops import resize
class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
"""Base class for BaseDecodeHead.
Args:
in_channels (int|Sequence[int]): Input channels.
channels (int): Channels after modules, before conv_seg.
num_classes (int): Number of classes.
out_channels (int): Output channels of conv_seg.
threshold (float): Threshold for binary segmentation in the case of
`out_channels==1`. Default: None.
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
conv_cfg (dict|None): Config of conv layers. Default: None.
norm_cfg (dict|None): Config of norm layers. Default: None.
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU')
in_index (int|Sequence[int]): Input feature index. Default: -1
input_transform (str|None): Transformation type of input features.
Options: 'resize_concat', 'multiple_select', None.
'resize_concat': Multiple feature maps will be resized to the
same size as the first one and then concatenated together.
Usually used in FCN head of HRNet.
'multiple_select': Multiple feature maps will be bundled into
a list and passed into decode head.
None: Only a single feature map is selected.
Default: None.
loss_decode (dict | Sequence[dict]): Config of decode loss.
The `loss_name` is property of corresponding loss function which
could be shown in training log. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_ce'.
e.g. dict(type='CrossEntropyLoss'),
[dict(type='CrossEntropyLoss', loss_name='loss_ce'),
dict(type='DiceLoss', loss_name='loss_dice')]
Default: dict(type='CrossEntropyLoss').
ignore_index (int | None): The label index to be ignored. When using
masked BCE loss, ignore_index should be set to None. Default: 255.
sampler (dict|None): The config of segmentation map sampler.
Default: None.
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
init_cfg (dict or list[dict], optional): Initialization config dict.
"""
def __init__(self,
in_channels,
channels,
*,
num_classes,
out_channels=None,
threshold=None,
dropout_ratio=0.1,
conv_cfg=None,
norm_cfg=None,
act_cfg=dict(type='ReLU'),
in_index=-1,
input_transform=None,
sampler=None,
align_corners=False,
init_cfg=dict(type='Normal',
std=0.01,
override=dict(name='conv_seg'))):
super(BaseDecodeHead, self).__init__(init_cfg)
self._init_inputs(in_channels, in_index, input_transform)
self.channels = channels
self.dropout_ratio = dropout_ratio
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.in_index = in_index
self.align_corners = align_corners
if out_channels is None:
if num_classes == 2:
warnings.warn('For binary segmentation, we suggest using '
'`out_channels = 1` to define the output '
'channels of the segmentor, and using `threshold` '
'to convert seg_logits into a prediction by '
'applying a threshold')
out_channels = num_classes
if out_channels != num_classes and out_channels != 1:
raise ValueError(
'out_channels should be equal to num_classes, '
'except for binary segmentation where out_channels == 1 and '
f'num_classes == 2, but got out_channels={out_channels} '
f'and num_classes={num_classes}')
if out_channels == 1 and threshold is None:
threshold = 0.3
warnings.warn('threshold is not defined for binary segmentation '
'and defaults to 0.3')
self.num_classes = num_classes
self.out_channels = out_channels
self.threshold = threshold
if sampler is not None:
self.sampler = build_pixel_sampler(sampler, context=self)
else:
self.sampler = None
self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
if dropout_ratio > 0:
self.dropout = nn.Dropout2d(dropout_ratio)
else:
self.dropout = None
self.fp16_enabled = False
def extra_repr(self):
"""Extra repr."""
s = f'input_transform={self.input_transform}, ' \
f'align_corners={self.align_corners}'
return s
def _init_inputs(self, in_channels, in_index, input_transform):
"""Check and initialize input transforms.
The in_channels, in_index and input_transform must match.
Specifically, when input_transform is None, only single feature map
will be selected. So in_channels and in_index must be of type int.
When input_transform is not None, in_channels and in_index must be sequences of the same length.
Args:
in_channels (int|Sequence[int]): Input channels.
in_index (int|Sequence[int]): Input feature index.
input_transform (str|None): Transformation type of input features.
Options: 'resize_concat', 'multiple_select', None.
'resize_concat': Multiple feature maps will be resized to the
same size as the first one and then concatenated together.
Usually used in FCN head of HRNet.
'multiple_select': Multiple feature maps will be bundled into
a list and passed into decode head.
None: Only a single feature map is selected.
"""
if input_transform is not None:
assert input_transform in ['resize_concat', 'multiple_select']
self.input_transform = input_transform
self.in_index = in_index
if input_transform is not None:
assert isinstance(in_channels, (list, tuple))
assert isinstance(in_index, (list, tuple))
assert len(in_channels) == len(in_index)
if input_transform == 'resize_concat':
self.in_channels = sum(in_channels)
else:
self.in_channels = in_channels
else:
assert isinstance(in_channels, int)
assert isinstance(in_index, int)
self.in_channels = in_channels
def _transform_inputs(self, inputs):
"""Transform inputs for decoder.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
Tensor: The transformed inputs
"""
if self.input_transform == 'resize_concat':
inputs = [inputs[i] for i in self.in_index]
upsampled_inputs = [
resize(input=x,
size=inputs[0].shape[2:],
mode='bilinear',
align_corners=self.align_corners) for x in inputs
]
inputs = torch.cat(upsampled_inputs, dim=1)
elif self.input_transform == 'multiple_select':
inputs = [inputs[i] for i in self.in_index]
else:
inputs = inputs[self.in_index]
return inputs
@auto_fp16()
@abstractmethod
def forward(self, inputs):
"""Placeholder of forward function."""
pass
def cls_seg(self, feat):
"""Classify each pixel."""
if self.dropout is not None:
feat = self.dropout(feat)
output = self.conv_seg(feat)
return output

View File

@@ -0,0 +1,60 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
class PPM(nn.ModuleList):
"""Pooling Pyramid Module used in PSPNet.
Args:
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
Module.
in_channels (int): Input channels.
channels (int): Channels after modules, before conv_seg.
conv_cfg (dict|None): Config of conv layers.
norm_cfg (dict|None): Config of norm layers.
act_cfg (dict): Config of activation layers.
align_corners (bool): align_corners argument of F.interpolate.
"""
def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
act_cfg, align_corners, **kwargs):
super(PPM, self).__init__()
self.pool_scales = pool_scales
self.align_corners = align_corners
self.in_channels = in_channels
self.channels = channels
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
for pool_scale in pool_scales:
self.append(
nn.Sequential(
nn.AdaptiveAvgPool2d(pool_scale),
ConvModule(
self.in_channels,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
**kwargs)))
def forward(self, x):
"""Forward function."""
ppm_outs = []
for ppm in self:
ppm_out = ppm(x)
ppm_out = ppm_out.to(torch.float32)
upsampled_ppm_out = resize(
ppm_out,
size=x.size()[2:],
mode='bilinear',
align_corners=self.align_corners)
upsampled_ppm_out = upsampled_ppm_out.to(torch.bfloat16)
ppm_outs.append(upsampled_ppm_out)
return ppm_outs
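The float32/bfloat16 casts around resize appear to work around interpolation kernels that lack bfloat16 support; a shape-only sketch with placeholder channel sizes, assuming mmcv/mmseg are installed:

import torch

ppm = PPM(pool_scales=(1, 2, 3, 6), in_channels=1408, channels=512,
          conv_cfg=None, norm_cfg=None, act_cfg=dict(type='ReLU'),
          align_corners=False)
outs = ppm(torch.rand(2, 1408, 16, 16))
print(len(outs), outs[0].shape, outs[0].dtype)   # 4 torch.Size([2, 512, 16, 16]) torch.bfloat16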

View File

@@ -0,0 +1,52 @@
import torch.nn as nn
from collections import OrderedDict
from mmcv.cnn.utils.weight_init import (kaiming_init, trunc_normal_)
from mmcv.runner import (CheckpointLoader, load_state_dict)
from mmseg.utils import get_root_logger
class UPHead(nn.Module):
def __init__(self, in_dim, out_dim, up_scale, init_cfg=None):
super().__init__()
self.decoder = nn.Sequential(
nn.Conv2d(in_channels=in_dim,
out_channels=up_scale**2 * out_dim,
kernel_size=1),
nn.PixelShuffle(up_scale),
)
self.init_cfg = init_cfg
self.apply(self._init_weights)
def _init_weights(self, m):
if (isinstance(self.init_cfg, dict)
and self.init_cfg.get('type') == 'Pretrained'):
logger = get_root_logger()
checkpoint = CheckpointLoader.load_checkpoint(
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
if 'state_dict' in checkpoint:
_state_dict = checkpoint['state_dict']
else:
_state_dict = checkpoint
state_dict = OrderedDict()
for k, v in _state_dict.items():
if k.startswith('backbone.'):
state_dict[k[9:]] = v
else:
state_dict[k] = v
logger.info(f'loading weights from {self.init_cfg["checkpoint"]}')
load_state_dict(self, state_dict, strict=False, logger=logger)
else:
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv2d):
kaiming_init(m, mode='fan_in', bias=0.)
def forward(self, x):
x = self.decoder(x)
return x

View File

@@ -0,0 +1,130 @@
# coding: utf-8
# Copyright (c) Ant Group. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from .psp_head import PPM
class UPerHead(BaseDecodeHead):
"""Unified Perceptual Parsing for Scene Understanding.
This head is the implementation of `UPerNet
<https://arxiv.org/abs/1807.10221>`_.
Args:
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
Module applied on the last feature. Default: (1, 2, 3, 6).
"""
def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
super(UPerHead, self).__init__(
input_transform='multiple_select', **kwargs)
# PSP Module
self.psp_modules = PPM(
pool_scales,
self.in_channels[-1],
self.channels,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
align_corners=self.align_corners)
self.bottleneck = ConvModule(
self.in_channels[-1] + len(pool_scales) * self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
# FPN Module
self.lateral_convs = nn.ModuleList()
self.fpn_convs = nn.ModuleList()
for in_channels in self.in_channels[:-1]: # skip the top layer
l_conv = ConvModule(
in_channels,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
inplace=False)
fpn_conv = ConvModule(
self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
inplace=False)
self.lateral_convs.append(l_conv)
self.fpn_convs.append(fpn_conv)
self.fpn_bottleneck = ConvModule(
len(self.in_channels) * self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def psp_forward(self, inputs):
"""Forward function of PSP module."""
x = inputs[-1]
psp_outs = [x]
psp_outs.extend(self.psp_modules(x))
psp_outs = torch.cat(psp_outs, dim=1)
output = self.bottleneck(psp_outs)
return output
def forward(self, inputs):
"""Forward function."""
inputs = self._transform_inputs(inputs)
# build laterals
laterals = [
lateral_conv(inputs[i])
for i, lateral_conv in enumerate(self.lateral_convs)
]
laterals.append(self.psp_forward(inputs))
# build top-down path
used_backbone_levels = len(laterals)
for i in range(used_backbone_levels - 1, 0, -1):
prev_shape = laterals[i - 1].shape[2:]
laterals[i] = laterals[i].type(torch.float32)
laterals[i - 1] = laterals[i - 1] + resize(
laterals[i],
size=prev_shape,
mode='bilinear',
align_corners=self.align_corners)
# build outputs
fpn_outs = [
self.fpn_convs[i](laterals[i])
for i in range(used_backbone_levels - 1)
]
# append psp feature
fpn_outs.append(laterals[-1])
for i in range(used_backbone_levels - 1, 0, -1):
fpn_outs[i] = fpn_outs[i].type(torch.float32)
fpn_outs[i] = resize(
fpn_outs[i],
size=fpn_outs[0].shape[2:],
mode='bilinear',
align_corners=self.align_corners)
fpn_outs = torch.cat(fpn_outs, dim=1)
output = self.fpn_bottleneck(fpn_outs)
output = self.cls_seg(output)
return output
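# Illustrative usage sketch (channel sizes below are hypothetical, not taken from a config in this commit):
# head = UPerHead(in_channels=[256, 512, 1024, 2048], in_index=[0, 1, 2, 3],
#                 channels=512, num_classes=19, norm_cfg=dict(type='BN'))
# seg_logits = head([c2, c3, c4, c5])  # c_i: (B, in_channels[i], H_i, W_i); output: (B, num_classes, H_0, W_0)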

View File

@@ -0,0 +1,4 @@
from .modality_vae_loss import ModalityVAELoss
from .recon_anno_loss import RecLoss
__all__ = [ "ModalityVAELoss", "RecLoss" ]

View File

@@ -0,0 +1,46 @@
# Copyright (c) Ant Group and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F
from antmmf.common.registry import registry
@registry.register_loss("ModalityVAELoss")
class ModalityVAELoss(nn.Module):
def __init__(self, **params):
super().__init__()
self.weight = params.pop("weight")
def compute_rec_loss(self, x_in, x_out, modal_flag):
loss_per_pixel = F.mse_loss(x_in, x_out, reduction='none')
loss_b = torch.mean(loss_per_pixel, dim=[1, 2, 3])
return torch.sum(loss_b * modal_flag) / (modal_flag.sum() + 1e-6)
def forward(self, sample_list, output, *args, **kwargs):
vae_out = output["vae_out"]
feat_hr = vae_out['input_hr']
feat_s2 = vae_out['input_s2']
feat_s1 = vae_out['input_s1']
g_hr = vae_out['g_hr']
g_s2 = vae_out['g_s2']
g_s1 = vae_out['g_s1']
# process modality flags
modality_info = vae_out['modality_info']
B_M, L_M = modality_info.shape
modality_hr = modality_info[:,0]
modality_s2 = modality_info[:,1]
modality_s1 = modality_info[:,2]
######## rec losses ########
loss_xent = self.compute_rec_loss(g_hr, feat_hr, modality_hr) \
+ self.compute_rec_loss(g_s2, feat_s2, modality_s2) \
+ self.compute_rec_loss(g_s1, feat_s1, modality_s1)
loss_quant = vae_out["loss_quant"]
total_loss = loss_xent / 3 + loss_quant
return total_loss * self.weight
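# Note: compute_rec_loss averages the per-sample MSE only over samples whose modality flag
# is 1, so reconstructions of artificially completed (missing) modalities do not contribute
# to the reconstruction term; loss_quant is the KL consistency term returned by the
# modality-completion neck (its kl_loss output).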

View File

@@ -0,0 +1,89 @@
# Copyright (c) Ant Group and its affiliates.
import torch
import torch.nn as nn
from antmmf.common.registry import registry
import torch.nn.functional as F
@registry.register_loss("RecLoss")
class RecLoss(nn.Module):
def __init__(self, **params):
super().__init__()
self.weight = params.pop("weight")
self.patch_size = params.pop("patch_size")
self.eps = torch.finfo(torch.bfloat16).eps
self.pred_key = params.pop("pred_key")
self.vocabulary_size = params.pop("vocabulary_size") + 1
self.mask_key = params.pop("mask_key")
self.target_key = params.pop("target_key")
self.feature_merged = params.pop("feature_merged")
self.cnt_train = 0
self.cnt_val = 0
self.use_bg = params.pop("use_bg")
if "use_all_patch" in params:
self.use_all_patch = params.pop("use_all_patch")
else:
self.use_all_patch = False
if "balance" in params:
self.balance = params.pop("balance")
else:
self.balance = False
if "sim_regularization" in params:
self.sim_regularization = params.pop("sim_regularization")
else:
self.sim_regularization = False
def unpatchify(self, x):
"""
x: (N, L, patch_size**2 *3)
imgs: (N, 3, H, W)
"""
p = self.patch_size
w = int((x.shape[1]*0.5)**.5)
h = w * 2
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p))
x = torch.einsum('nhwpq->nhpwq', x)
imgs = x.reshape(shape=(x.shape[0], h * p, w * p))
return imgs
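# Shape example (assuming patch_size=32 and L=512 mask patches): w = int((512*0.5)**0.5) = 16,
# h = 32, so the returned mask has shape (N, 32*32, 16*32) = (N, 1024, 512), i.e. two
# vertically stacked 512x512 tiles, matching the combined annotation layout used elsewhere.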
def forward(self, sample_list, output, *args, **kwargs):
pred = output[self.pred_key] # B, C, H, W
target = output[self.target_key] # B, H, W
mask = output[self.mask_key]
b_mask, h_mask, w_mask = mask.shape
mask = mask.reshape((b_mask, h_mask*w_mask))
mask = mask[:, :, None].repeat(1, 1, self.patch_size**2)
mask = self.unpatchify(mask)
if not self.use_bg:
valid = sample_list['valid']
mask = mask * valid
loss = F.cross_entropy(pred, target, reduction="none")
if self.balance:
if self.use_all_patch:
loss_pos = loss[target > 0].sum() / ((target > 0).sum() + 1e-6)
loss_neg = loss[target == 0].sum() / ((target == 0).sum() + 1e-6)
loss = (loss_pos + loss_neg) * 0.5
else:
loss_pos = loss[(target > 0) & (mask == 1)].sum() / (((target > 0) & (mask == 1)).sum() + 1e-6)
loss_neg = loss[(target == 0) & (mask == 1)].sum() / (((target == 0) & (mask == 1)).sum() + 1e-6)
loss = (loss_pos + loss_neg) * 0.5
else:
if self.use_all_patch:
loss = loss.mean()
else:
loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
if self.sim_regularization:
vocabulary_token = output['vocabulary_token']
voca_normed = F.normalize(vocabulary_token, 2, 1)
similarity_matrix = 1 + torch.einsum('nd,md->nm', voca_normed, voca_normed)
num = voca_normed.shape[0]
index = torch.triu(voca_normed.new_ones(num, num), diagonal=1).type(torch.bool)
loss_reg = similarity_matrix[index].mean()
return loss * self.weight + loss_reg * 0.05
return loss * self.weight

View File

@@ -0,0 +1,4 @@
from .sem_metrics import SemMetric
__all__ = ["SemMetric"]

View File

@@ -0,0 +1,93 @@
# coding: utf-8
# Copyright (c) Ant Group. All rights reserved.
import torch
from torch.distributed import all_reduce, ReduceOp
from antmmf.common.registry import registry
from antmmf.modules.metrics.base_metric import BaseMetric
@registry.register_metric("sem_metric")
class SemMetric(BaseMetric):
"""Segmentation metrics used in evaluation phase.
Args:
name (str): Name of the metric.
eval_type(str): 3 types are supported: 'mIoU', 'mDice', 'mFscore'
result_field(str): key of predicted results in output dict
target_field(str): key of ground truth in output dict
ignore_index(int): class value will be ignored in evaluation
num_cls(int): total number of categories in evaluation
"""
def __init__(self,
name="dummy_metric", **kwargs
):
super().__init__(name)
self.reset()
def calculate(self, sample_list, model_output, *args, **kwargs):
"""Calculate Intersection and Union for a batch.
Args:
sample_list (Sample_List): data which contains ground truth segmentation maps
model_output (dict): data which contains prediction segmentation maps
Returns:
torch.Tensor: The intersection of prediction and ground truth histogram
on all classes.
torch.Tensor: The union of prediction and ground truth histogram on all
classes.
torch.Tensor: The prediction histogram on all classes.
torch.Tensor: The ground truth histogram on all classes.
"""
return torch.tensor(0).float()
def reset(self):
""" initialized all attributes value before evaluation
"""
self.total_mask_mae = 0
self.total_num = torch.tensor(0)
def collect(self, sample_list, model_output, *args, **kwargs):
"""
Args:
sample_list(Sample_List): data which contains ground truth segmentation maps
model_output (Dict): Dict returned by model, that contains two modalities
Returns:
torch.FloatTensor: Accuracy
"""
batch_mask_mae = \
self.calculate(sample_list, model_output, *args, **kwargs)
self.total_mask_mae += batch_mask_mae
self.total_num += 1
def format(self, *args):
""" Format evaluated metrics for profile.
Returns:
dict: dict of all evaluated metrics.
"""
output_metric = dict()
# if self.eval_type == 'mae':
mae = args[0]
output_metric['mae'] = mae.item()
return output_metric
def summarize(self, *args, **kwargs):
"""This method is used to calculate the overall metric.
Returns:
dict: dict of all evaluated metrics.
"""
# if self.eval_type == 'mae':
mae = self.total_mask_mae / (self.total_num)
return self.format(mae)
def all_reduce(self):
total_number = torch.stack([
self.total_mask_mae, self.total_num
]).cuda()
all_reduce(total_number, op=ReduceOp.SUM)
self.total_mask_mae = total_number[0].cpu()
self.total_num = total_number[1].cpu()

View File

@@ -0,0 +1,13 @@
from .transformer_encoder import TransformerEncoder
from .modality_completion import ModalityCompletion
__all__ = ['TransformerEncoder', 'ModalityCompletion']
type_mapping = {
'TransformerEncoder': TransformerEncoder,
'ModalityCompletion': ModalityCompletion
}
def build_neck(type, **kwargs):
return type_mapping[type](**kwargs)
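# Usage sketch (keyword values here are illustrative, not taken from a shipped config):
# fusion = build_neck('TransformerEncoder', input_dims=2816, embed_dims=768,
#                     num_layers=4, num_heads=12, with_cls_token=True, output_cls_token=True)
# vae = build_neck('ModalityCompletion', conv_dim=256, z_dim=256, n_codebook=8192)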

View File

@@ -0,0 +1,212 @@
# Copyright (c) AntGroup. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
class BFloat16UpsampleNearest2d(nn.Module):
def __init__(self, scale_factor, mode='bilinear'):
super().__init__()
self.scale_factor = scale_factor
self.mode = mode
def forward(self, x):
x_float = x.float()
upsampled_x = F.interpolate(x_float, scale_factor=self.scale_factor, mode=self.mode)
return upsampled_x.to(x.dtype)
class ConvVQVAEv2(nn.Module):
def __init__(self, input_shape, conv_dim, z_dim, num_tokens=8192, temp=0.9):
super().__init__()
self.z_dim = z_dim
self.conv_dim = conv_dim # 256
self.input_shape = input_shape # 256
self.temp = temp
# code book
self.codebook = nn.Embedding(num_tokens, z_dim)
# encoder
self.relu = nn.LeakyReLU()
self.pool = nn.AvgPool2d(2)
self.conv1 = nn.Conv2d(input_shape[0], conv_dim, 5, stride=1, padding=2)
self.enc_block1 = nn.Sequential(
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
nn.LeakyReLU(),
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
nn.LeakyReLU(),
)
self.gamma_1 = nn.Parameter(0.001 * torch.ones((1, conv_dim, 1, 1)))
self.enc_block2 = nn.Sequential(
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
nn.LeakyReLU(),
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
nn.LeakyReLU(),
)
self.gamma_2 = nn.Parameter(0.001 * torch.ones((1, conv_dim, 1, 1)))
self.logit_conv = nn.Conv2d(conv_dim, num_tokens, 1)
# decoder
self.unpool = BFloat16UpsampleNearest2d(scale_factor=2)
self.conv2 = nn.Conv2d(z_dim, conv_dim, 3, stride=1, padding=1)
self.dec_block1 = nn.Sequential(
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
nn.LeakyReLU(),
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
nn.LeakyReLU(),
)
self.gamma_3 = nn.Parameter(0.001 * torch.ones((1, conv_dim, 1, 1)))
self.dec_block2 = nn.Sequential(
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
nn.LeakyReLU(),
nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1),
nn.LeakyReLU(),
)
self.gamma_4 = nn.Parameter(0.001 * torch.ones((1, conv_dim, 1, 1)))
self.rec_conv = nn.Conv2d(conv_dim, input_shape[0], 3, stride=1, padding=1)
def forward_encoder(self, x):
x = self.relu(self.conv1(x))
x = x + self.gamma_1 * self.enc_block1(x)
x = self.pool(x)
x = x + self.gamma_2 * self.enc_block2(x)
x = self.pool(x)
logits = self.logit_conv(x)
return logits
def forward_decoder(self, logits):
soft_one_hot = F.softmax(logits * (self.temp*10), dim=1)
sampled = torch.einsum('bnhw,nd->bdhw', soft_one_hot, self.codebook.weight)
x = self.relu(self.conv2(sampled))
x = self.unpool(x)
x = x + self.gamma_3 * self.dec_block1(x)
x = self.unpool(x)
x = x + self.gamma_4 * self.dec_block2(x)
rec_feats = self.rec_conv(x)
return rec_feats, soft_one_hot
def forward(self, x):
print(x.shape)
logits = self.forward_encoder(x)
images_p, soft_one_hot = self.forward_decoder(logits)
return [logits, images_p]
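# Shape flow sketch (assuming the default input_shape=(2816, 16, 16)): the encoder applies two
# AvgPool2d(2) stages, so logits have shape (B, num_tokens, 4, 4); forward_decoder soft-assigns
# each of the 4x4 positions to codebook entries and upsamples back to a (B, 2816, 16, 16)
# reconstruction.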
class ModalityCompletion(nn.Module):
def __init__(self,
input_shape_hr=(2816, 16, 16),
input_shape_s2=(2816, 16, 16),
input_shape_s1=(2816, 16, 16),
conv_dim=256,
z_dim=256,
n_codebook=8192,
init_cfg=None
):
super(ModalityCompletion, self).__init__()
self.vae_hr = ConvVQVAEv2(input_shape=input_shape_hr, conv_dim=conv_dim, z_dim=z_dim, num_tokens=n_codebook)
self.vae_s2 = ConvVQVAEv2(input_shape=input_shape_s2, conv_dim=conv_dim, z_dim=z_dim, num_tokens=n_codebook)
self.vae_s1 = ConvVQVAEv2(input_shape=input_shape_s1, conv_dim=conv_dim, z_dim=z_dim, num_tokens=n_codebook)
self.kl_div_loss = torch.nn.KLDivLoss(reduction="none", log_target=True)
self.init_cfg=init_cfg
def init_weights(self):
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
# Suppress default init if use pretrained model.
from mmcls.utils import get_root_logger
from mmcv.runner import CheckpointLoader, load_state_dict
logger = get_root_logger()
checkpoint = CheckpointLoader.load_checkpoint(
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
load_state_dict(self, state_dict, strict=False, logger=logger)
else:
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def kl_loss(self, logits_hr, logits_s2, logits_s1, modality_info):
prob_hr = F.log_softmax(logits_hr, dim=1)
prob_s2 = F.log_softmax(logits_s2, dim=1)
prob_s1 = F.log_softmax(logits_s1, dim=1)
flag_hr = modality_info[:,0][:, None, None, None]
flag_s2 = modality_info[:,1][:, None, None, None]
flag_s1 = modality_info[:,2][:, None, None, None]
loss_hr_s2 = self.kl_div_loss(prob_hr, prob_s2) + self.kl_div_loss(prob_s2, prob_hr)
loss_hr_s2 = (loss_hr_s2 * flag_hr * flag_s2).sum((1, 2, 3)).mean()
loss_hr_s1 = self.kl_div_loss(prob_hr, prob_s1) + self.kl_div_loss(prob_s1, prob_hr)
loss_hr_s1 = (loss_hr_s1 * flag_hr * flag_s1).sum((1, 2, 3)).mean()
loss_s2_s1 = self.kl_div_loss(prob_s2, prob_s1) + self.kl_div_loss(prob_s1, prob_s2)
loss_s2_s1 = (loss_s2_s1 * flag_s2 * flag_s1).sum((1, 2, 3)).mean()
loss = (loss_hr_s2 + loss_hr_s1 + loss_s2_s1) / 6.0
return loss
def forward(self, feat_hr, feat_s2, feat_s1, modality_info):
# encode each modality independently into codebook-token logits
# per-modality feature flow: (B, 2816, 16, 16) => conv encoder + two AvgPool2d(2) => (B, conv_dim, 4, 4) => 1x1 conv => token logits (B, num_tokens, 4, 4)
B, C, H, W = feat_hr.shape
B_M, L_M = modality_info.shape
assert B == B_M, f'feat_hr batch: {B}, modality_info batch: {B_M}'
# quant, emb_loss, info
# hr input flow
logits_hr = self.vae_hr.forward_encoder(feat_hr)
logits_s2 = self.vae_s2.forward_encoder(feat_s2)
logits_s1 = self.vae_s1.forward_encoder(feat_s1)
modality_hr = modality_info[:,0]
modality_s2 = modality_info[:,1]
modality_s1 = modality_info[:,2]
flag_hr = modality_hr[:, None, None, None] # B => B, C, H, W
flag_s2 = modality_s2[:, None, None, None]
flag_s1 = modality_s1[:, None, None, None]
mean_logits_hr_s2 = logits_hr * flag_hr + logits_s2 * flag_s2
mean_logits_hr_s1 = logits_hr * flag_hr + logits_s1 * flag_s1
mean_logits_s1_s2 = logits_s1 * flag_s1 + logits_s2 * flag_s2
logits_hr_rec = logits_hr * flag_hr + mean_logits_s1_s2 * (~flag_hr)
logits_s2_rec = logits_s2 * flag_s2 + mean_logits_hr_s1 * (~flag_s2)
logits_s1_rec = logits_s1 * flag_s1 + mean_logits_hr_s2 * (~flag_s1)
g_hr, soft_one_hot_hr = self.vae_hr.forward_decoder(logits_hr_rec)
g_s2, soft_one_s2 = self.vae_s2.forward_decoder(logits_s2_rec)
g_s1, soft_one_s1 = self.vae_s1.forward_decoder(logits_s1_rec)
hr_out = feat_hr * flag_hr + g_hr * (~flag_hr)
s2_out = feat_s2 * flag_s2 + g_s2 * (~flag_s2)
s1_out = feat_s1 * flag_s1 + g_s1 * (~flag_s1)
output = {}
output['hr_out'] = hr_out
output['s2_out'] = s2_out
output['s1_out'] = s1_out
output['modality_info'] = modality_info
output['input_hr'] = feat_hr
output['input_s2'] = feat_s2
output['input_s1'] = feat_s1
output['logits_hr'] = logits_hr
output['logits_s2'] = logits_s2
output['logits_s1'] = logits_s1
output['soft_one_hot_hr'] = soft_one_hot_hr
output['soft_one_hot_s2'] = soft_one_s2
output['soft_one_hot_s1'] = soft_one_s1
output['g_hr'] = g_hr
output['g_s2'] = g_s2
output['g_s1'] = g_s1
output['loss_quant'] = self.kl_loss(logits_hr, logits_s2, logits_s1, modality_info)
return output
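# Completion logic summary: modality_info is expected to be a (B, 3) boolean mask over
# (hr, s2, s1) -- the bitwise-not (~) operations assume a bool dtype. For a missing modality,
# its encoder logits are replaced by the sum of the logits of the modalities that are present
# (the mean_logits_* terms), decoded into a surrogate feature g_*, and only that surrogate is
# used in the corresponding *_out; present modalities keep their original features. kl_loss
# only compares modality pairs where both are available.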

View File

@@ -0,0 +1,144 @@
# Copyright (c) Ant Group. All rights reserved.
from collections import OrderedDict
import torch
import torch.nn as nn
from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
trunc_normal_)
from mmcv.runner import (CheckpointLoader, load_state_dict)
from torch.nn.modules.batchnorm import _BatchNorm
from mmseg.models.backbones.vit import TransformerEncoderLayer
from mmseg.utils import get_root_logger
class TransformerEncoder(nn.Module):
def __init__(self,
input_dims=768,
embed_dims=768,
num_layers=4,
num_heads=12,
mlp_ratio=4,
qkv_bias=True,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
with_cls_token=True,
output_cls_token=False,
norm_cfg=dict(type='LN'),
act_cfg=dict(type='GELU'),
num_fcs=2,
norm_eval=False,
with_cp=False,
init_cfg=None,
*args,
**kwargs):
super(TransformerEncoder, self).__init__()
self.porj_linear = nn.Linear(input_dims, embed_dims)
if output_cls_token:
assert with_cls_token is True, f'with_cls_token must be True if ' \
f'output_cls_token is set to True, but got {with_cls_token}'
self.init_cfg = init_cfg
self.norm_eval = norm_eval
self.with_cp = with_cp
self.with_cls_token = with_cls_token
self.output_cls_token = output_cls_token
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
self.drop_after_pos = nn.Dropout(p=drop_rate)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, num_layers)
] # stochastic depth decay rule
self.layers = nn.ModuleList()
for i in range(num_layers):
self.layers.append(
TransformerEncoderLayer(embed_dims=embed_dims,
num_heads=num_heads,
feedforward_channels=mlp_ratio *
embed_dims,
attn_drop_rate=attn_drop_rate,
drop_rate=drop_rate,
drop_path_rate=dpr[i],
num_fcs=num_fcs,
qkv_bias=qkv_bias,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
with_cp=with_cp,
batch_first=True))
def init_weights(self):
if (isinstance(self.init_cfg, dict)
and self.init_cfg.get('type') == 'Pretrained'):
logger = get_root_logger()
checkpoint = CheckpointLoader.load_checkpoint(
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
if 'state_dict' in checkpoint:
_state_dict = checkpoint['state_dict']
else:
_state_dict = checkpoint
state_dict = OrderedDict()
for k, v in _state_dict.items():
if k.startswith('backbone.'):
state_dict[k[9:]] = v
else:
state_dict[k] = v
load_state_dict(self, state_dict, strict=False, logger=logger)
else:
# We only implement the 'jax_impl' initialization implemented at
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
trunc_normal_(self.cls_token, std=.02)
for n, m in self.named_modules():
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
if 'ffn' in n:
nn.init.normal_(m.bias, mean=0., std=1e-6)
else:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv2d):
kaiming_init(m, mode='fan_in', bias=0.)
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
constant_init(m, val=1.0, bias=0.)
def forward(self, inputs, require_feat: bool = False, require_two: bool = False):
inputs = self.porj_linear(inputs)
B, N, C = inputs.shape
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, inputs), dim=1)
if not self.with_cls_token:
# Remove class token for transformer encoder input
x = x[:, 1:]
# add hidden and atten state
block_outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
if require_feat:
block_outs.append(x)
if self.output_cls_token:
if require_two:
x = x[:, :2]
else:
x = x[:, 0]
elif not self.output_cls_token and self.with_cls_token:
x = x # [:, :]
if require_feat:
return x, block_outs
else:
return x
def train(self, mode=True):
super(TransformerEncoder, self).train(mode)
if mode and self.norm_eval:
for m in self.modules():
if isinstance(m, nn.LayerNorm):
m.eval()
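# Usage sketch (hypothetical shapes): with input_dims=2816 and embed_dims=768, an input of
# shape (B*H*W, L, 2816) -- one token sequence per spatial location, where L counts the modal
# and temporal tokens -- is projected to 768-d, a cls token is prepended, and with
# output_cls_token=True the forward pass returns the fused (B*H*W, 768) cls embedding
# (or the first two tokens when require_two=True).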

View File

@@ -0,0 +1,4 @@
# Copyright (c) Ant Financial Service Group and its affiliates.
from .skysense_pp_pipeline import SkySensePP
__all__ = ['SkySensePP']

View File

@@ -0,0 +1,458 @@
# coding: utf-8
# Copyright (c) Ant Group. All rights reserved.
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import math
import random
from antmmf.common.registry import registry
from antmmf.models.base_model import BaseModel
from lib.models.backbones import build_backbone
from lib.models.necks import build_neck
from lib.models.heads import build_head
from lib.utils.utils import LayerDecayValueAssigner
@registry.register_model("SkySensePP")
class SkySensePP(BaseModel):
def __init__(self, config):
super().__init__(config)
self.config = config
self.sources = config.sources
assert len(self.sources) > 0, 'at least one data source is required'
if 's2' in self.sources:
self.use_ctpe = config.use_ctpe
self.use_modal_vae = config.use_modal_vae
self.use_cls_token_uper_head = config.use_cls_token_uper_head
self.target_mean=[0.485, 0.456, 0.406]
self.target_std=[0.229, 0.224, 0.225]
self.vocabulary_size = config.vocabulary_size
self.vocabulary = list(range(1, config.vocabulary_size + 1)) # 0 for ignore
def build(self):
if 'hr' in self.sources:
self.backbone_hr = self._build_backbone('hr')
if 's2' in self.sources:
self.backbone_s2 = self._build_backbone('s2')
if self.use_ctpe:
self.ctpe = nn.Parameter(
torch.zeros(1, self.config.calendar_time,
self.config.necks.input_dims))
if 'head_s2' in self.config.keys():
self.head_s2 = self._build_head('head_s2')
self.fusion = self._build_neck('necks')
if 's1' in self.sources:
self.backbone_s1 = self._build_backbone('s1')
if 'head_s1' in self.config.keys():
self.head_s1 = self._build_head('head_s1')
self.head_rec_hr = self._build_head('rec_head_hr')
self.with_aux_head = False
if self.use_modal_vae:
self.modality_vae = self._build_neck('modality_vae')
if 'auxiliary_head' in self.config.keys():
self.with_aux_head = True
self.aux_head = self._build_head('auxiliary_head')
if 'init_cfg' in self.config.keys(
) and self.config.init_cfg is not None and self.config.init_cfg.checkpoint is not None and self.config.init_cfg.key is not None:
self.load_pretrained(self.config.init_cfg.checkpoint,
self.config.init_cfg.key)
def _build_backbone(self, key):
config_dict = self.config[f'backbone_{key}'].to_dict()
backbone_type = config_dict.pop('type')
backbone = build_backbone(backbone_type, **config_dict)
backbone.init_weights()
return backbone
def _build_neck(self, key):
config_dict = self.config[key].to_dict()
neck_type = config_dict.pop('type')
neck = build_neck(neck_type, **config_dict)
neck.init_weights()
return neck
def _build_head(self, key):
head_config = self.config[key].to_dict()
head_type = head_config.pop('type')
head = build_head(head_type, **head_config)
return head
def get_optimizer_parameters(self, config):
optimizer_grouped_parameters = [
{
"params": [],
"lr": config.optimizer_attributes.params.lr,
"weight_decay": config.optimizer_attributes.params.weight_decay,
},
{
"params": [],
"lr": config.optimizer_attributes.params.lr,
"weight_decay": 0.0,
},
]
layer_decay_value_assigner_hr = LayerDecayValueAssigner(
config.lr_parameters.layer_decay, None,
config.optimizer_attributes.params.lr, 'swin',
config.model_attributes.SkySensePP.backbone_hr.arch
)
layer_decay_value_assigner_s2 = LayerDecayValueAssigner(
config.lr_parameters.layer_decay, 24,
config.optimizer_attributes.params.lr, 'vit',
)
layer_decay_value_assigner_s1 = LayerDecayValueAssigner(
config.lr_parameters.layer_decay, 24,
config.optimizer_attributes.params.lr, 'vit',
)
layer_decay_value_assigner_fusion = LayerDecayValueAssigner(
config.lr_parameters.layer_decay, 24,
config.optimizer_attributes.params.lr, 'vit',
)
num_frozen_params = 0
if 'hr' in self.sources:
print('hr'.center(60, '-'))
num_frozen_params += layer_decay_value_assigner_hr.fix_param(
self.backbone_hr,
config.lr_parameters.frozen_blocks,
)
optimizer_grouped_parameters.extend(
layer_decay_value_assigner_hr.get_parameter_groups(
self.backbone_hr, config.optimizer_attributes.params.weight_decay
)
)
if 's2' in self.sources:
print('s2'.center(60, '-'))
num_frozen_params += layer_decay_value_assigner_s2.fix_param(
self.backbone_s2,
config.lr_parameters.frozen_blocks,
)
optimizer_grouped_parameters.extend(
layer_decay_value_assigner_s2.get_parameter_groups(
self.backbone_s2, config.optimizer_attributes.params.weight_decay
)
)
no_decay = [".bn.", "bias"]
optimizer_grouped_parameters[0]["params"] += [
p for n, p in self.head_s2.named_parameters()
if not any(nd in n for nd in no_decay)
]
optimizer_grouped_parameters[1]["params"] += [
p for n, p in self.head_s2.named_parameters()
if any(nd in n for nd in no_decay)
]
if self.use_ctpe:
optimizer_grouped_parameters[1]["params"] += [self.ctpe]
if 's1' in self.sources:
print('s1'.center(60, '-'))
num_frozen_params += layer_decay_value_assigner_s1.fix_param(
self.backbone_s1,
config.lr_parameters.frozen_blocks,
)
optimizer_grouped_parameters.extend(
layer_decay_value_assigner_s1.get_parameter_groups(
self.backbone_s1, config.optimizer_attributes.params.weight_decay
)
)
no_decay = [".bn.", "bias"]
optimizer_grouped_parameters[0]["params"] += [
p for n, p in self.head_s1.named_parameters()
if not any(nd in n for nd in no_decay)
]
optimizer_grouped_parameters[1]["params"] += [
p for n, p in self.head_s1.named_parameters()
if any(nd in n for nd in no_decay)
]
if len(self.sources) > 1:
print('fusion'.center(60, '-'))
num_frozen_params += layer_decay_value_assigner_fusion.fix_param_deeper(
self.fusion,
config.lr_parameters.frozen_fusion_blocks_start, # freeze all fusion stages from this index onward
)
optimizer_grouped_parameters.extend(
layer_decay_value_assigner_fusion.get_parameter_groups(
self.fusion, config.optimizer_attributes.params.weight_decay
)
)
if self.use_modal_vae:
no_decay = [".bn.", "bias"]
optimizer_grouped_parameters[0]["params"] += [
p for n, p in self.modality_vae.named_parameters()
if not any(nd in n for nd in no_decay)
]
optimizer_grouped_parameters[1]["params"] += [
p for n, p in self.modality_vae.named_parameters()
if any(nd in n for nd in no_decay)
]
no_decay = [".bn.", "bias"]
optimizer_grouped_parameters[0]["params"] += [
p for n, p in self.head_rec_hr.named_parameters()
if not any(nd in n for nd in no_decay)
]
optimizer_grouped_parameters[1]["params"] += [
p for n, p in self.head_rec_hr.named_parameters()
if any(nd in n for nd in no_decay)
]
if self.with_aux_head:
no_decay = [".bn.", "bias"]
optimizer_grouped_parameters[0]["params"] += [
p for n, p in self.aux_head.named_parameters()
if not any(nd in n for nd in no_decay)
]
optimizer_grouped_parameters[1]["params"] += [
p for n, p in self.aux_head.named_parameters()
if any(nd in n for nd in no_decay)
]
num_params = [len(x['params']) for x in optimizer_grouped_parameters]
print(len(list(self.parameters())), sum(num_params), num_frozen_params)
assert len(list(self.parameters())) == sum(num_params) + num_frozen_params
return optimizer_grouped_parameters
def get_custom_scheduler(self, trainer):
optimizer = trainer.optimizer
num_training_steps = trainer.config.training_parameters.max_iterations
num_warmup_steps = trainer.config.training_parameters.num_warmup_steps
if "train" in trainer.run_type:
if num_training_steps == math.inf:
epoches = trainer.config.training_parameters.max_epochs
assert epoches != math.inf
num_training_steps = trainer.config.training_parameters.max_epochs * trainer.epoch_iterations
def linear_with_warm_up(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(
1, num_warmup_steps))
return max(
0.0,
float(num_training_steps - current_step) /
float(max(1, num_training_steps - num_warmup_steps)),
)
def cos_with_warm_up(current_step):
num_cycles = 0.5
if current_step < num_warmup_steps:
return float(current_step) / float(max(
1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(
max(1, num_training_steps - num_warmup_steps))
return max(
0.0,
0.5 *
(1.0 +
math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
)
lr_lambda = cos_with_warm_up if trainer.config.training_parameters.cos_lr else linear_with_warm_up
else:
def lr_lambda(current_step):
return 0.0 # noqa
return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, -1)
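# Worked example (hypothetical values): with num_warmup_steps=1000 and num_training_steps=10000,
# both lambdas ramp linearly from 0 to 1 over the first 1000 steps; afterwards
# linear_with_warm_up decays linearly to 0 at step 10000, while cos_with_warm_up follows
# 0.5 * (1 + cos(pi * progress)) with progress = (step - 1000) / 9000, e.g. ~0.5 at the
# midpoint and 0 at the end.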
def convert_target(self, target):
mean = target.new_tensor(self.target_mean).reshape(1, 3, 1, 1)
std = target.new_tensor(self.target_std).reshape(1, 3, 1, 1)
target = ((target * std + mean)*255).to(torch.long)
target[:, 0] = target[:, 0] * 256 * 256
target[:, 1] = target[:, 1] * 256
target = target.sum(1).type(torch.long)
unique_target = target.unique()
target_index = torch.searchsorted(unique_target, target)
no_bg = False
if unique_target[0].item() > 0:
target_index += 1
no_bg = True
target_index_unique = target_index.unique().tolist()
random.shuffle(self.vocabulary)
value = target.new_tensor([0] + self.vocabulary)
mapped_target = target_index.clone()
idx_2_color = {}
for v in target_index_unique:
mapped_target[target_index == v] = value[v]
idx_2_color[value[v].item()] = unique_target[v - 1 if no_bg else v].item()
return mapped_target, idx_2_color
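# Worked example of the color-index encoding above (values are illustrative): a de-normalized
# RGB pixel (r, g, b) = (0, 128, 255) is packed as 0*65536 + 128*256 + 255 = 33023; every
# distinct packed color in the batch is then remapped to a randomly shuffled vocabulary id
# (0 reserved for background), and idx_2_color keeps the inverse mapping so predictions can be
# converted back to RGB for visualization (see cvt_colors in the predictor script later in this
# commit).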
def forward(self, sample_list):
output = dict()
modality_flag_hr = sample_list["modality_flag_hr"]
modality_flag_s2 = sample_list["modality_flag_s2"]
modality_flag_s1 = sample_list["modality_flag_s1"]
modalities = [modality_flag_hr, modality_flag_s2, modality_flag_s1]
modalities = torch.tensor(modalities).permute(1,0).contiguous() # L, B => B, L
anno_img = sample_list["targets"]
anno_img, idx_2_color = self.convert_target(anno_img)
output["mapped_targets"] = anno_img
output["idx_2_color"] = idx_2_color
anno_mask = sample_list["anno_mask"]
anno_s2 = anno_img[:, 15::32, 15::32]
anno_s1 = anno_s2
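# anno_s2/anno_s1 subsample the annotation map at the centre pixel of every 32x32 block
# (indices 15, 47, 79, ...), yielding a label grid at 1/32 of the hr resolution (e.g. a
# 1024x512 annotation of two stacked tiles becomes 32x16), aligned with the S2/S1 feature
# resolution.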
output["anno_hr"] = anno_img
output["anno_s2"] = anno_s2
### 1. backbone
if 'hr' in self.sources:
hr_img = sample_list["hr_img"]
B_MASK, H_MASK, W_MASK = anno_mask.shape
block_size = 32
anno_mask_hr = anno_mask.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 1, block_size, block_size)
anno_mask_hr = anno_mask_hr.permute(0, 1, 3, 2, 4).reshape(B_MASK, H_MASK*block_size, W_MASK*block_size).contiguous()
B, C_G, H_G, W_G = hr_img.shape
hr_features = self.backbone_hr(hr_img, anno_img, anno_mask_hr)
output['mask_hr'] = anno_mask_hr
output['target_hr'] = anno_img
if 's2' in self.sources:
s2_img = sample_list["s2_img"]
B, C_S2, S_S2, H_S2, W_S2 = s2_img.shape
s2_img = s2_img.permute(0, 2, 1, 3,
4).reshape(B * S_S2, C_S2, H_S2, W_S2).contiguous() # ts time to batch
anno_mask_s2 = anno_mask
s2_features = self.backbone_s2(s2_img, anno_s2, anno_mask_s2)
if 'head_s2' in self.config.keys():
s2_features = self.head_s2(s2_features[-1])
s2_features = [s2_features]
if 's1' in self.sources:
s1_img = sample_list["s1_img"]
B, C_S1, S_S1, H_S1, W_S1 = s1_img.shape
s1_img = s1_img.permute(0, 2, 1, 3,
4).reshape(B * S_S1, C_S1, H_S1, W_S1).contiguous()
anno_mask_s1 = anno_mask
s1_features = self.backbone_s1(s1_img, anno_s1, anno_mask_s1)
if 'head_s1' in self.config.keys():
s1_features = self.head_s1(s1_features[-1])
s1_features = [s1_features]
### 2. prepare features for fusion
hr_features_stage3 = hr_features[-1]
s2_features_stage3 = s2_features[-1]
s1_features_stage3 = s1_features[-1]
modalities = modalities.to(hr_features_stage3.device)
if self.use_modal_vae:
vae_out = self.modality_vae(hr_features_stage3, s2_features_stage3, s1_features_stage3, modalities)
hr_features_stage3 = vae_out['hr_out']
s2_features_stage3 = vae_out['s2_out']
s1_features_stage3 = vae_out['s1_out']
output['vae_out'] = vae_out
features_stage3 = []
if 'hr' in self.sources:
B, C3_G, H3_G, W3_G = hr_features_stage3.shape
hr_features_stage3 = hr_features_stage3.permute(
0, 2, 3, 1).reshape(B * H3_G * W3_G, C3_G).unsqueeze(1).contiguous() # B * H3_G * W3_G, 1, C3_G
features_stage3 = hr_features_stage3
if 's2' in self.sources:
# s2_features_stage3 = s2_features[-1]
_, C3_S2, H3_S2, W3_S2 = s2_features_stage3.shape
s2_features_stage3 = s2_features_stage3.reshape(
B, S_S2, C3_S2, H3_S2,
W3_S2).permute(0, 3, 4, 1, 2).reshape(B, H3_S2 * W3_S2, S_S2,
C3_S2).contiguous()
if self.use_ctpe:
ct_index = sample_list["s2_ct"]
ctpe = self.ctpe[:, ct_index, :].contiguous().permute(1, 0, 2, 3).contiguous()
ctpe = ctpe.expand(-1, 256, -1, -1)
ct_index_2 = sample_list["s2_ct2"]
ctpe2 = self.ctpe[:, ct_index_2, :].contiguous().permute(1, 0, 2, 3).contiguous()
ctpe2 = ctpe2.expand(-1, 256, -1, -1)
ctpe_comb = torch.cat([ctpe, ctpe2], 1)
# import pdb;pdb.set_trace()
s2_features_stage3 = (s2_features_stage3 + ctpe_comb).reshape(
B * H3_S2 * W3_S2, S_S2, C3_S2).contiguous()
else:
s2_features_stage3 = s2_features_stage3.reshape(
B * H3_S2 * W3_S2, S_S2, C3_S2).contiguous()
if len(features_stage3) > 0:
assert H3_G == H3_S2 and W3_G == W3_S2 and C3_G == C3_S2
features_stage3 = torch.cat((features_stage3, s2_features_stage3), dim=1)
else:
features_stage3 = s2_features_stage3
if 's1' in self.sources:
# s1_features_stage3 = s1_features[-1]
_, C3_S1, H3_S1, W3_S1 = s1_features_stage3.shape
s1_features_stage3 = s1_features_stage3.reshape(
B, S_S1, C3_S1, H3_S1,
W3_S1).permute(0, 3, 4, 1, 2).reshape(B, H3_S1 * W3_S1, S_S1,
C3_S1).contiguous()
s1_features_stage3 = s1_features_stage3.reshape(
B * H3_S1 * W3_S1, S_S1, C3_S1).contiguous()
if len(features_stage3) > 0:
assert H3_S1 == H3_S2 and W3_S1 == W3_S2 and C3_S1 == C3_S2
features_stage3 = torch.cat((features_stage3, s1_features_stage3),
dim=1)
else:
features_stage3 = s1_features_stage3
### 3. fusion
if self.config.necks.output_cls_token:
if self.config.necks.get('require_feat', False):
cls_token, block_outs = self.fusion(features_stage3 , True)
else:
cls_token = self.fusion(features_stage3)
_, C3_G = cls_token.shape
cls_token = cls_token.reshape(B, H3_G, W3_G,
C3_G).contiguous().permute(0, 3, 1, 2).contiguous() # b, c, h, w
else:
assert self.config.necks.with_cls_token is False
if self.config.necks.get('require_feat', False):
features_stage3, block_outs = self.fusion(features_stage3, True)
else:
features_stage3 = self.fusion(features_stage3)
features_stage3 = features_stage3.reshape(
B, H3_S2, W3_S2, S_S2,
C3_S2).permute(0, 3, 4, 1, 2).reshape(B * S_S2, C3_S2, H3_S2,
W3_S2).contiguous()
### 4. decoder for rec
hr_rec_inputs = hr_features
feat_stage1 = hr_rec_inputs[0]
if feat_stage1.shape[-1] == feat_stage1.shape[-2]:
feat_stage1_left, feat_stage1_right = torch.split(feat_stage1, feat_stage1.shape[-1] // 2, dim=-1)
feat_stage1 = torch.cat((feat_stage1_left, feat_stage1_right), dim=1)
hr_rec_inputs = list(hr_features)
hr_rec_inputs[0] = feat_stage1
rec_feats = [*hr_rec_inputs, cls_token]
logits_hr = self.head_rec_hr(rec_feats)
if self.config.get('upsacle_results', True):
logits_hr = logits_hr.to(torch.float32)
logits_hr = F.interpolate(logits_hr, scale_factor=4, mode='bilinear', align_corners=True)
output["logits_hr"] = logits_hr
return output
def load_pretrained(self, ckpt_path, key):
pretrained_dict = torch.load(ckpt_path, map_location={'cuda:0': 'cpu'})
pretrained_dict = pretrained_dict[key]
for k, v in pretrained_dict.items():
if k == 'backbone_s2.patch_embed.projection.weight':
pretrained_in_channels = v.shape[1]
if self.config.backbone_s2.in_channels == 4:
new_weight = v[:, [0, 1, 2, 6]]
new_weight = new_weight * (
pretrained_in_channels /
self.config.backbone_s2.in_channels)
pretrained_dict[k] = new_weight
missing_keys, unexpected_keys = self.load_state_dict(pretrained_dict,
strict=False)
print('missing_keys:', missing_keys)
print('unexpected_keys:', unexpected_keys)
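# Note on the channel adaptation above: when the s2 backbone is configured for 4 input bands,
# the pretrained patch-embed weight (with pretrained_in_channels bands) is subset to channels
# [0, 1, 2, 6] -- presumably the RGB + NIR bands, depending on the checkpoint's band ordering --
# and rescaled by pretrained_in_channels / 4 so the expected magnitude of the patch embedding
# stays roughly unchanged.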

View File

@@ -0,0 +1,244 @@
import os
import glob
import numpy as np
import yaml
import argparse
import oss2
import torch
from PIL import Image
from tqdm import tqdm
import concurrent.futures
from torchvision.transforms import functional as F
import random
from antmmf.common.registry import registry
from antmmf.common.report import Report, default_result_formater
from antmmf.structures import Sample, SampleList
from antmmf.predictors.base_predictor import BasePredictor
from antmmf.utils.timer import Timer
from antmmf.predictors.build import build_predictor
from antmmf.common.task_loader import build_collate_fn
from antmmf.datasets.samplers import SequentialSampler
from antmmf.common.build import build_config
from lib.utils.checkpoint import SegCheckpoint
from lib.datasets.loader.few_shot_flood3i_loader import FewShotFloodLoader
def seed_everything(seed=0):
# ensure deterministic cuDNN/CUDA convolution behavior
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
@registry.register_predictor("OneshotPredictor")
class OneshotPredictor(BasePredictor, FewShotFloodLoader):
def __init__(self, config):
self.config = config
self.predictor_parameters = self.config.predictor_parameters
def _predict(self, sample_list):
# with torch.no_grad():
sample_list = sample_list.to(self.device)
report = self._forward_pass(sample_list)
return report, sample_list
def _forward_pass(self, samplelist):
autocast_dtype = torch.bfloat16
with torch.cuda.amp.autocast(enabled=True, dtype=autocast_dtype):
model_output = self.model(samplelist)
report = Report(samplelist, model_output)
return report
def predict(self, data=None):
if data is None:
data = self.dummy_request()
sample = self._build_sample(data)
if not isinstance(sample, Sample):
raise Exception(
f"Method _build_sample is expected to return a instance of antmmf.structures.sample.Sample,"
f"but got type {type(sample)} instead.")
result, sample_list = self._predict(SampleList([sample]))
np_result = default_result_formater(result)
result = self.format_result(np_result)
assert isinstance(
result, dict
), f"Result should be instance of Dict,but got f{type(result)} instead"
return result, sample_list
def load_checkpoint(self):
self.resume_file = self.config.predictor_parameters.model_dir
self.checkpoint = SegCheckpoint(self, load_only=True)
self.checkpoint.load_model_weights(self.resume_file, force=True)
def covert_speedup_op(self):
if self.config.predictor_parameters.replace_speedup_op:
from lib.utils.optim_utils import replace_speedup_op
self.model = replace_speedup_op(self.model)
def save_image(output_path, image_np_array):
image = Image.fromarray(image_np_array)
image.save(output_path)
def build_predictor_from_args(args, *rest, **kwargs):
config = build_config(
args.config,
config_override=args.config_override,
opts_override=args.opts,
specific_override=args,
)
predictor_obj = build_predictor(config)
setattr(predictor_obj, "args", args)
return predictor_obj
def build_online_predictor(model_dir=None, config_yaml=None):
assert model_dir or config_yaml
from antmmf.utils.flags import flags
# if config_yaml not indicated, there must be a `config.yaml` file under `model_dir`
config_path = config_yaml if config_yaml else os.path.join(model_dir, "config.yaml")
input_args = ["--config", config_path]
if model_dir is not None:
input_args += ["predictor_parameters.model_dir", model_dir]
parser = flags.get_parser()
args = parser.parse_args(input_args)
predictor = build_predictor_from_args(args)
return predictor
def profile(profiler, text):
print(f'{text}: {profiler.get_time_since_start()}')
profiler.reset()
def cvt_colors(img_2d, idx_2_color_rgb):
img_rgb = np.zeros((img_2d.shape[0], img_2d.shape[1], 3), dtype=np.uint8)
for idx, color in idx_2_color_rgb.items():
img_rgb[img_2d==idx] = color
return img_rgb
def process_results(preds, targets, input_imgs, img_names, save_dir, save_dir_vis, idx_2_color):
imagenet_std = np.array([0.229, 0.224, 0.225])
imagenet_mean = np.array([0.485, 0.456, 0.406])
idx_2_color_rgb = {}
for idx, color in idx_2_color.items():
r = color // (256 * 256)
g = (color % (256 * 256)) // 256
b = color % 256
idx_2_color_rgb[idx] = (r, g, b)
for i in range(preds.size(0)):
output1 = preds[i].argmax(0) # h, w
output1_total = output1.clone()
output1 = output1[output1.shape[0]//2:, :]
output1 = output1.numpy().astype(np.uint8)
output1 = cvt_colors(output1, idx_2_color_rgb)
output1_total = output1_total.numpy().astype(np.uint8)
output1_total = cvt_colors(output1_total, idx_2_color_rgb)
# for visualization
output2 = targets[i]
output2 = output2.numpy().astype(np.uint8)
output2 = cvt_colors(output2, idx_2_color_rgb)
input_img = torch.einsum('chw->hwc', input_imgs[i])
input_img = torch.clip((input_img * imagenet_std + imagenet_mean) * 255, 0, 255)
input_img = input_img.numpy().astype(np.uint8)
output_comb = np.concatenate((input_img, output1_total, output2), axis=1)
# save result
save_path = os.path.join(save_dir, f'{img_names[i]}.png')
save_image(save_path, output1)
save_path_vis = os.path.join(save_dir_vis, f'{img_names[i]}.png')
save_image(save_path_vis, output_comb)
def test(args):
model_path = args.model_path
config_path = args.config
global_seed = args.seed
predictor = build_online_predictor(model_path, config_path)
seed_everything(global_seed)
dataset = FewShotFloodLoader(
"test", predictor.config.task_attributes.segmentation.dataset_attributes.few_shot_flood_segmentation)
loader = torch.utils.data.DataLoader(
dataset,
batch_size=8,
shuffle=False,
sampler=SequentialSampler(dataset),
collate_fn=build_collate_fn(dataset),
num_workers=16,
pin_memory=True,
drop_last=False,
)
print(len(loader))
predictor.load(with_ckpt=True)
predictor.covert_speedup_op()
save_dir = args.save_dir
if not os.path.exists(save_dir):
os.makedirs(save_dir)
save_dir_vis = os.path.join(args.save_dir, 'vis_full')
if not os.path.exists(save_dir_vis):
os.makedirs(save_dir_vis)
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
profiler2 = Timer()
profiler2.reset()
for sample_batched in tqdm(loader):
profile(profiler2, "Build sample time")
result, sample_list = predictor._predict(SampleList(sample_batched))
profile(profiler2, "Infer time")
preds = result["logits_hr"].to(torch.float32).detach().cpu()
targets = result['mapped_targets'].to(torch.float32).detach().cpu()
idx_2_color = result['idx_2_color']
input_imgs = sample_list['hr_img'].to(torch.float32).detach().cpu()
img_names = sample_list["img_name"]
executor.submit(process_results, preds, targets, input_imgs, img_names, save_dir, save_dir_vis, idx_2_color)
profile(profiler2, "Save results time")
try:
del predictor.model
except Exception as e:
print('delete model error: ', e)
def parse_args():
desc = '1-shot predictor'
parser = argparse.ArgumentParser(description=desc)
parser.add_argument('--model_path',
required=True,
type=str,
help='model directory')
parser.add_argument('--seed',
default=0,
type=int,
help='seed')
parser.add_argument('--config',
required=True,
type=str,
help='config path')
parser.add_argument('--save_dir',
required=False,
type=str,
help='save directory')
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
test(args)

3
lib/task/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .segmentation import SegmentationTask
__all__ = ['SegmentationTask']

18
lib/task/segmentation.py Normal file
View File

@@ -0,0 +1,18 @@
# coding: utf-8
# Copyright (c) Ant Group. All rights reserved.
from antmmf.common.registry import registry
from antmmf.tasks import BaseTask
@registry.register_task("segmentation")
class SegmentationTask(BaseTask):
def __init__(self):
super(SegmentationTask, self).__init__("segmentation")
def _get_available_datasets(self):
return ["pretraining_loader"]
def _preprocess_item(self, item):
return item

4
lib/trainer/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
from .seg_trainer import SEGTrainer
__all__ = ['SEGTrainer']

399
lib/trainer/seg_trainer.py Normal file
View File

@@ -0,0 +1,399 @@
# Copyright (c) Ant Group and its affiliates.
import gc
import math
from itertools import chain
import torch
from torch import nn
from tqdm import tqdm
from antmmf.common.registry import registry
from antmmf.common.report import Report
from antmmf.common.meter import Meter
from antmmf.modules.metrics import Metrics
from antmmf.optimizer.combine_optimizers import CombinedOptimizer
from antmmf.utils.distributed_utils import (broadcast_scalar, is_main_process)
from antmmf.utils.early_stopping import EarlyStopping
from antmmf.utils.general import clip_gradients, count_parameters, nullcontext
from antmmf.utils.timer import Timer
from antmmf.trainers.base_trainer import BaseTrainer
from lib.utils.utils import cancel_gradients_backbone, EMA
from lib.utils.checkpoint import SegCheckpoint
try:
import atorch
from atorch import amp
except ImportError:
pass
@registry.register_trainer("seg_trainer")
class SEGTrainer(BaseTrainer):
def __init__(self, config):
super().__init__(config)
self.enable_torch_amp=True
self.enable_atorch_amp=False
def load(self, has_check_point=True):
super().load(has_check_point)
torch.backends.cuda.matmul.allow_tf32 = self.config.training_parameters.get(
"enable_tf32", False)
if hasattr(
self.config.training_parameters, "freeze_backbone"
) and self.config.training_parameters.freeze_backbone is True:
for n, p in self.model.named_parameters():
if "backbone_hr." in n or 'backbone_s2.' in n or 'head_s2.' in n or 'backbone_s1.' in n or 'head_s1.' in n or 'fusion.' in n or 'ctpe' in n or 'glbank.' in n:
p.requires_grad = False
else:
print(n, '-->', p.requires_grad)
if hasattr(self.config.training_parameters,
"ema") and self.config.training_parameters.ema is True:
self.ema = EMA(self.model, 0.96)
self.ema.register()
def load_extras(self, has_check_point=True):
self.checkpoint = None if has_check_point is False else SegCheckpoint(
self)
self.meter = Meter()
self.training_parameters = self.config.training_parameters
monitored_metric = self.training_parameters.monitored_metric
metric_minimize = self.training_parameters.metric_minimize
should_early_stop = self.training_parameters.should_early_stop
patience = self.training_parameters.patience
self.log_interval = self.training_parameters.log_interval
self.snapshot_interval = self.training_parameters.snapshot_interval
self.max_iterations = self.training_parameters.max_iterations
self.should_clip_gradients = self.training_parameters.clip_gradients
self.max_epochs = self.training_parameters.max_epochs
self.gradient_accumulation_steps = int(
self.training_parameters.gradient_accumulation_steps)
assert self.gradient_accumulation_steps >= 1
for t_type in self.task_loader.task_type:
if t_type == "train":
self.dataset_train_order = self.training_parameters.get(
"dataset_train_order", self.train_task.datasets_name)
if t_type == "val":
self.dataset_val_order = self.training_parameters.get(
"dataset_val_order", self.val_task.datasets_name)
if t_type == "test":
self.dataset_test_order = self.training_parameters.get(
"dataset_test_order", self.test_task.datasets_name)
if t_type == "interpret":
self.dataset_interpret_order = self.training_parameters.get(
"dataset_interpret_order",
self.interpret_task.datasets_name)
self.early_stopping = EarlyStopping(
self.model,
self.checkpoint,
monitored_metric,
patience=patience,
minimize=metric_minimize,
should_stop=should_early_stop,
)
self.current_epoch = 1
self.current_iteration = 0
self.not_debug = self.training_parameters.logger_level != "debug"
self.lr_scheduler = None
self.setup_lr_scheduler()
if self.checkpoint is not None:
self.checkpoint.load_state_dict()
if "overall_metrics" in self.training_parameters:
self.overall_metric_evaluator = Metrics(
self.config.training_parameters.get("overall_metrics", []))
self.synchronized_loss = self.config.training_parameters.synchronized_loss
def train(self):
self.writer.write("===== Model =====")
self.writer.write(self.model)
self.writer.write(
"Model Params: Trainable {Trainable:.3f}M Total {Total:.3f}M".
format(**count_parameters(self.model)))
if "train" not in self.run_type:
self.inference()
return
should_break = False
if self.max_epochs is None:
self.max_epochs = math.inf
else:
self.max_iterations = min(self.max_iterations,
self.max_epochs * self.epoch_iterations)
self.model.train()
self.train_timer = Timer()
self.profile("Setup Time")
if self.enable_torch_amp:
self.writer.write("Using Automatic mixed precision training")
if hasattr(self.config, "amp_attributes") and hasattr(
self.config.amp_attributes, "growth_interval"):
growth_interval = self.config.amp_attributes.growth_interval
else:
growth_interval = 2000
self.scaler = torch.cuda.amp.GradScaler(
init_scale=self.config.amp_attributes.init_scale,
enabled=False,
growth_interval=growth_interval)
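# Note: the GradScaler is constructed with enabled=False, so scale()/unscale_()/step() act as
# pass-throughs; this is consistent with the bfloat16 autocast used in _forward_pass, which
# does not require loss scaling.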
self.writer.write("Using Init scale:%s" %
self.config.amp_attributes.init_scale)
self.optimizer.zero_grad()
self.writer.write("Starting training...")
while self.current_iteration < self.max_iterations and not should_break:
registry.register("current_epoch", self.current_epoch)
self.task_loader.seed_sampler("train", self.current_epoch)
if self.current_epoch > self.max_epochs:
break
for batch in tqdm(
chain(*self.train_loader_list),
total=self._len_of_loader_list(self.train_loader_list),
disable=self.disable_tqdm or (not is_main_process()),
):
self.profile("Batch load time")
report, _, _ = self._forward_pass(
batch, enable_amp=self.enable_torch_amp)
if report is None:
continue
self._update_meter(report, self.meter)
loss = self._extract_loss(report)
self._backward(loss)
if hasattr(
self.config.training_parameters,
"ema") and self.config.training_parameters.ema is True:
self.ema.update()
should_break = self._logistics()
self._run_scheduler()
self.current_iteration += 1
self.writer.write(self.current_iteration, "debug")
registry.register("current_iteration", self.current_iteration)
if self.current_iteration >= self.max_iterations:
break
if should_break:
break
self.current_epoch += 1
self.finalize()
def _forward_pass(self, batch, enable_amp=False):
if not batch: # Samplelist might be empty dict
return None, None, None
prepared_batch = self.task_loader.prepare_batch(batch)
self.profile("Batch prepare time")
forward_context = torch.cuda.amp.autocast(
enabled=True,
dtype=torch.bfloat16) if enable_amp else nullcontext()
with forward_context:
# Arguments should be a dict at this point
model_output = self.model(prepared_batch)
if self.synchronized_loss:
is_parallel = isinstance(
self.model, nn.DataParallel) or isinstance(
self.model, nn.parallel.DistributedDataParallel)
if "losses" not in model_output:
loss_func = getattr(
self.model.module if is_parallel else self.model,
"losses")
model_output["losses"] = loss_func(
prepared_batch,
model_output,
iteration=self.current_iteration)
if "metrics" not in model_output:
metric_func = getattr(
self.model.module if is_parallel else self.model,
"metrics")
model_output["metrics"] = metric_func(
prepared_batch, model_output)
report = Report(prepared_batch, model_output)
self.profile("Forward time")
return report, model_output, prepared_batch
def _backward(self, loss):
loss = loss / self.gradient_accumulation_steps
if self.enable_torch_amp:
self.scaler.scale(loss).backward()
# Unscales the gradients of optimizer's assigned params in-place, this should
# be called first so that clip_gradients can take effect as usual.
self.scaler.unscale_(self.optimizer)
elif self.enable_atorch_amp:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
self.profile("Backward time")
if self.current_iteration % self.gradient_accumulation_steps != 0:
return
if self.should_clip_gradients:
if self.enable_atorch_amp:
clip_gradients(amp.master_params(self.optimizer),
self.current_iteration, self.writer,
self.config)
else:
clip_gradients(self.model, self.current_iteration, self.writer,
self.config)
if hasattr(
self.config.training_parameters, "freeze_backbone_steps"
) and self.config.training_parameters.freeze_backbone_steps is not None:
cancel_gradients_backbone(
self.current_iteration, self.model,
self.config.training_parameters.freeze_backbone_steps)
if self.enable_torch_amp:
self.scaler.step(self.optimizer)
self.scaler.update()
else:
self.optimizer.step()
self.optimizer.zero_grad()
self.profile("Optimizer time")
def _logistics(self):
should_print = self.current_iteration and self.current_iteration % self.log_interval == 0
extra = {}
prefix = ""
if should_print is True:
if "cuda" in str(self.device):
extra["max mem"] = torch.cuda.max_memory_allocated() / 1024
extra["max mem"] //= 1024
# display lr
if isinstance(self.optimizer, CombinedOptimizer):
extra["lr"] = self.optimizer.get_optimizers_lr_str()
else:
extra["lr"] = "|".join([
"{:.8f}".format(x["lr"]).rstrip("0")
for x in self.optimizer.param_groups
])
extra.update({
"time": self.train_timer.get_time_since_start(),
"eta": self._calculate_time_left(),
})
self.train_timer.reset()
self._summarize_meter(
self.meter,
prefix=prefix,
extra=extra,
should_print=should_print,
)
should_break = self._try_full_validation()
return should_break
def _try_full_validation(self, force=False):
should_break = False
if self.current_iteration and self.current_iteration % self.snapshot_interval == 0 or force:
self.writer.write(
"Evaluation time. Running on full validation set...")
validation_timer = Timer()
dataset_name, meter = self.evaluate_set(self.val_loader_list)
extra = {
"validation time": validation_timer.get_time_since_start()
}
overall_metric = self.overall_metric_evaluator.summarize()
stop = self.early_stopping(self.current_iteration, overall_metric,
meter)
if hasattr(self.config.training_parameters,
"ema") and self.config.training_parameters.ema is True:
self.ema.restore()
stop = bool(broadcast_scalar(stop, src=0, device=self.device))
extra.update(self.early_stopping.get_info())
prefix = "{}: full val".format(dataset_name)
self._summarize_overall(overall_metric,
meter,
prefix=prefix,
extra=extra)
gc.collect()
if "cuda" in str(self.device):
with torch.cuda.device(self.device):
torch.cuda.empty_cache()
if stop > 0: # `stop` is now `int`, NCCL does not support `boolean` type's broadcasting
self.writer.write("Early stopping activated")
should_break = True
return should_break
def evaluate_set(self, loader_list):
from antmmf.structures import SampleList
meter = Meter()
torch.cuda.empty_cache()
with torch.no_grad():
self.model.eval()
if hasattr(self.config.training_parameters,
"ema") and self.config.training_parameters.ema is True:
self.ema.apply_shadow()
if self.config.training_parameters.get('fp16', False):
self.model.half()
self.overall_metric_evaluator.reset()
for idx, batch in tqdm(
enumerate(chain(*loader_list)),
total=self._len_of_loader_list(loader_list),
disable=not is_main_process() or self.disable_tqdm,
):
# report, model_output, prepared_batch = self._forward_pass(
# batch, enable_amp=self.enable_torch_amp)
if idx >= self.config.training_parameters.get('num_eval', 1e7):
break
if self.config.training_parameters.get('fp16', False):
input_dict = SampleList()
for k, v in batch.items():
if isinstance(v, torch.cuda.FloatTensor):
input_dict[k] = v.half()
else:
input_dict[k] = v
report, model_output, prepared_batch = self._forward_pass(
input_dict, enable_amp=self.enable_torch_amp)
else:
report, model_output, prepared_batch = self._forward_pass(
batch, enable_amp=self.enable_torch_amp)
self._update_meter(report, meter)
self.overall_metric_evaluator.collect(prepared_batch,
model_output)
for _, metric_object in self.overall_metric_evaluator.metrics.items(
):
metric_object.all_reduce()
self.model.train()
return report.dataset_name, meter

135
lib/utils/checkpoint.py Normal file
View File

@@ -0,0 +1,135 @@
# Copyright (c) Ant Financial Service Group and its affiliates.
import os
import warnings
import torch
from antmmf.common import constants
from antmmf.common.registry import registry
from antmmf.common.checkpoint import Checkpoint
from antmmf.utils.distributed_utils import is_main_process
class SegCheckpoint(Checkpoint):
def __init__(self, trainer, load_only=False):
super().__init__(trainer, load_only=load_only)
def load_model_weights(self, file, force=False):
self.trainer.writer.write("Loading checkpoint")
ckpt = self._torch_load(file)
if registry.get(constants.STATE) is constants.STATE_ONLINE_SERVING:
data_parallel = False
else:
data_parallel = registry.get("data_parallel") or registry.get(
"distributed")
if "model" in ckpt:
ckpt_model = ckpt["model"]
else:
ckpt_model = ckpt
ckpt = {"model": ckpt}
new_dict = {}
# TODO: Move to separate function
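# Key remapping rules applied below: legacy "fa_history" buffers are renamed to "fa_context";
# the "module." prefix is stripped or added depending on whether the current model is wrapped
# in (Distributed)DataParallel; and ".Wqkv." parameter names (flash-attention style fused QKV,
# presumably from atorch's MultiheadAttentionFA) are renamed to the ".in_proj_" convention used
# by torch.nn.MultiheadAttention.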
for attr in ckpt_model:
if "fa_history" in attr:
new_dict[attr.replace("fa_history",
"fa_context")] = ckpt_model[attr]
elif data_parallel is False and attr.startswith("module."):
new_k = attr.replace("module.", "", 1)
if '.Wqkv.' in new_k:
new_k = new_k.replace('.Wqkv.', '.in_proj_')
new_dict[new_k] = ckpt_model[attr]
elif data_parallel is not False and not attr.startswith("module."):
new_dict["module." + attr] = ckpt_model[attr]
elif data_parallel is False and not attr.startswith("module."):
print('data_parallel is False and not attr!!!')
new_k = attr
if '.Wqkv.' in new_k:
new_k = new_k.replace('.Wqkv.', '.in_proj_')
new_dict[new_k] = ckpt_model[attr]
else:
new_dict[attr] = ckpt_model[attr]
print(new_dict.keys())
self._load_state_dict(new_dict)
self._load_model_weights_with_mapping(new_dict, force=force)
print(f'load weight: {file} done!')
return ckpt
def _load(self, file, force=False, resume_state=False):
ckpt = self.load_model_weights(file, force=force)
# skip loading training state
if resume_state is False:
return
if "optimizer" in ckpt:
try:
self.trainer.optimizer.load_state_dict(ckpt["optimizer"])
# work around an optimizer checkpoint issue (the 'capturable' flag) in PyTorch versions higher than 1.11
if "capturable" in self.trainer.optimizer.param_groups[0]:
self.trainer.optimizer.param_groups[0]["capturable"] = True
except Exception as e:
print(e)
else:
warnings.warn(
"'optimizer' key is not present in the checkpoint asked to be loaded. Skipping."
)
if "lr_scheduler" in ckpt:
self.trainer.lr_scheduler.load_state_dict(ckpt["lr_scheduler"])
else:
warnings.warn(
"'lr_scheduler' key is not present in the checkpoint asked to be loaded. Skipping."
)
self.trainer.early_stopping.init_from_checkpoint(ckpt)
self.trainer.writer.write("Checkpoint {} loaded".format(file))
if "current_iteration" in ckpt:
self.trainer.current_iteration = ckpt["current_iteration"]
registry.register("current_iteration",
self.trainer.current_iteration)
if "current_epoch" in ckpt:
self.trainer.current_epoch = ckpt["current_epoch"]
registry.register("current_epoch", self.trainer.current_epoch)
def save(self, iteration, update_best=False):
if not is_main_process():
return
ckpt_filepath = os.path.join(self.models_foldername,
"model_%d.ckpt" % iteration)
best_ckpt_filepath = os.path.join(self.ckpt_foldername,
self.ckpt_prefix + "best.ckpt")
best_iteration = self.trainer.early_stopping.best_monitored_iteration
best_metric = self.trainer.early_stopping.best_monitored_value
current_iteration = self.trainer.current_iteration
current_epoch = self.trainer.current_epoch
model = self.trainer.model
data_parallel = registry.get("data_parallel") or registry.get(
"distributed")
if data_parallel is True:
model = model.module
ckpt = {
"model": model.state_dict(),
"optimizer": self.trainer.optimizer.state_dict(),
"lr_scheduler": self.trainer.lr_scheduler.state_dict(),
"current_iteration": current_iteration,
"current_epoch": current_epoch,
"best_iteration": best_iteration,
"best_metric_value": best_metric,
}
torch.save(ckpt, ckpt_filepath)
self.remove_redundant_ckpts()
if update_best:
torch.save(ckpt, best_ckpt_filepath)
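# Usage sketch (assumes a configured antmmf trainer exposing the model, optimizer,
# lr_scheduler, early_stopping and writer attributes referenced above):
#   ckpt = SegCheckpoint(trainer)
#   ckpt._load("/path/to/model.ckpt", resume_state=True)   # restore weights + training state
#   ckpt.save(iteration=trainer.current_iteration, update_best=True)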

122
lib/utils/optim_utils.py Normal file
View File

@@ -0,0 +1,122 @@
import torch
from torch.nn import LayerNorm, Linear, GELU
from torch.nn import MultiheadAttention, Sequential
import warnings
try:
from atorch.normalization import LayerNorm as FastLayerNorm
from atorch.modules.transformer.inject import replace_module
from atorch.modules.transformer.layers import MultiheadAttentionFA, BertAttentionFA
except (ImportError, ModuleNotFoundError) as e:
warnings.warn("Using replace_speedup_op but no atorch/apex installed:%s" % e)
try:
from transformers.models.bert.modeling_bert import BertAttention
replace_transformer_bert = True
except ImportError:
replace_transformer_bert = False
class DefaultStrategy:
replace_mha = True
replace_layernorm = True
replace_linear_gelu = False # TODO: numerical consistency
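# A custom strategy can toggle individual replacements, e.g. (illustrative only):
#   class LayerNormOnlyStrategy(DefaultStrategy):
#       replace_mha = False
#       replace_layernorm = True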
def replace_layer_norm(module: torch.nn.Module, cur_name: str):
for name, child in module.named_children():
child_name = cur_name + "." + name
if isinstance(child, LayerNorm):
new_module = FastLayerNorm(child.normalized_shape, eps=child.eps)
new_module.load_state_dict(child.state_dict())
setattr(module, name, new_module)
else:
replace_layer_norm(child, child_name)
def is_atorch_available(raise_error=True, log=None):
try:
import atorch # noqa: F401
return True
except ImportError as e:
if raise_error is True:
raise ImportError(e, log)
else:
return False
def _cast_if_autocast_enabled(*args):
if not torch.is_autocast_enabled():
return args
else:
return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype())
def _fused_dense_gelu_dense(input, weight1, bias1, weight2, bias2):
batch, seq_length, hidden_size = input.size()
input = input.view(batch * seq_length, hidden_size)
args = _cast_if_autocast_enabled(input, weight1, bias1, weight2, bias2)
from apex.fused_dense import FusedDenseGeluDenseFunc # with cast
with torch.cuda.amp.autocast(enabled=False):
out = FusedDenseGeluDenseFunc.apply(*args)
out = out.view(batch, seq_length, -1)
return out
def linear_gelu_forward(input_, weight1, bias1, weight2, bias2):
return _fused_dense_gelu_dense(input_, weight1, bias1, weight2, bias2)
def replace_linear_gelu(module, cur_name: str):
"""
(layers): Sequential(
(0): Sequential(
(0): Linear(in_features=1536, out_features=6144, bias=True)
(1): GELU()
(2): Dropout(p=0.0, inplace=False)
)
(1): Linear(in_features=6144, out_features=1536, bias=True)
(2): Dropout(p=0.0, inplace=False)
)
"""
for name, child in module.named_children():
child_name = cur_name + "." + name
if isinstance(child, Sequential):
# match a Sequential(Sequential(Linear, GELU, ...), Linear, ...) block
if len(child) >= 2 and isinstance(child[0], Sequential) and isinstance(child[1], Linear) \
and len(child[0]) >= 2 and isinstance(child[0][0], Linear) and isinstance(child[0][1], GELU):
if getattr(child, "replace_linear_gelu", False):
continue
linear0 = child[0][0]
linear1 = child[1]
# bind the weights as lambda defaults so each replaced child keeps its own
# parameters instead of late-binding to the loop variables
child.forward = lambda x, w1=linear0.weight, b1=linear0.bias, \
w2=linear1.weight, b2=linear1.bias: linear_gelu_forward(x, w1, b1, w2, b2)
child.replace_linear_gelu = True
print("REPLACE linear+gelu:%s" % child_name)
# setattr(module, name, new_module)
else:
replace_linear_gelu(child, child_name)
def replace_speedup_op(model, strategy=DefaultStrategy):
if not is_atorch_available(raise_error=False):
raise ImportError("Install Atorch/apex before using speedup op")
if strategy.replace_mha:
model = replace_module(model, MultiheadAttention, MultiheadAttentionFA, need_scr_module=True)
if replace_transformer_bert:
model = replace_module(model, BertAttention, BertAttentionFA, need_scr_module=True)
root_name = model.__class__.__name__
if strategy.replace_layernorm:
replace_layer_norm(model, root_name) # inplace
if strategy.replace_linear_gelu:
replace_linear_gelu(model, root_name)
return model
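# Usage sketch (requires atorch, plus apex if replace_linear_gelu is enabled;
# build_model is a placeholder for any nn.Module built from MultiheadAttention / LayerNorm):
#   model = build_model(...)
#   model = replace_speedup_op(model)   # swaps in fused implementations in place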
# TODO:
# 1. SyncBatchNorm

243
lib/utils/utils.py Normal file
View File

@@ -0,0 +1,243 @@
import numpy as np
def cosine_scheduler(base_value,
final_value,
all_iters,
warmup_iters=0,
start_warmup_value=0):
warmup_schedule = np.array([])
if warmup_iters > 0:
warmup_schedule = np.linspace(start_warmup_value, base_value,
warmup_iters)
iters = np.arange(all_iters - warmup_iters)
schedule = final_value + 0.5 * (base_value - final_value) * (
1 + np.cos(np.pi * iters / len(iters)))
schedule = np.concatenate((warmup_schedule, schedule))
assert len(schedule) == all_iters
return schedule
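# Example (hypothetical values): one learning-rate value per iteration, linear
# warmup for 500 steps, then cosine decay from 1e-4 down to 1e-6:
#   lr_schedule = cosine_scheduler(1e-4, 1e-6, all_iters=10000, warmup_iters=500)
#   for it in range(10000):
#       for group in optimizer.param_groups:
#           group["lr"] = lr_schedule[it]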
def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
if epoch >= freeze_last_layer:
return
for n, p in model.named_parameters():
if "last_layer" in n:
p.grad = None
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=float)
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
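# Illustrative shapes: for embed_dim=8 and pos=np.arange(4), the output is (4, 8),
# with sin components in the first 4 columns and cos components in the last 4.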
def cancel_gradients_backbone(iteration, model, freeze_backbone_steps):
if iteration >= freeze_backbone_steps:
return
for n, p in model.named_parameters():
if "backbon_hr" in n or 'backbon_s2' in n or 'head_s2' in n or 'fusion' in n or 'ctpe' in n:
p.grad = None
class EMA():
def __init__(self, model, decay):
self.model = model
self.decay = decay
self.shadow = {}
self.backup = {}
def register(self):
for name, param in self.model.named_parameters():
if param.requires_grad:
self.shadow[name] = param.data.clone()
def update(self):
for name, param in self.model.named_parameters():
if param.requires_grad:
assert name in self.shadow
new_average = (1.0 - self.decay
) * param.data + self.decay * self.shadow[name]
self.shadow[name] = new_average.clone()
def apply_shadow(self):
for name, param in self.model.named_parameters():
if param.requires_grad:
assert name in self.shadow
self.backup[name] = param.data
param.data = self.shadow[name]
def restore(self):
for name, param in self.model.named_parameters():
if param.requires_grad:
assert name in self.backup
param.data = self.backup[name]
self.backup = {}
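# Typical EMA workflow (sketch; method names match the class above, evaluate() is a placeholder):
#   ema = EMA(model, decay=0.999)
#   ema.register()                       # snapshot the initial weights
#   ...                                  # after every optimizer.step(): ema.update()
#   ema.apply_shadow(); evaluate(model); ema.restore()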
class LayerDecayValueAssigner(object):
def __init__(self, layer_decay, num_layers, base_lr, net_type, arch='huge'):
assert net_type in ['swin', 'vit']
assert 0 < layer_decay <= 1
depths_dict = {
'tiny': [2, 2, 6, 2],
'small': [2, 2, 18, 2],
'base': [2, 2, 18, 2],
'large': [2, 2, 18, 2],
'huge': [2, 2, 18, 2],
'giant': [2, 2, 42, 4],
}
num_layers = num_layers if net_type == 'vit' else sum(depths_dict[arch])
self.layer_decay = layer_decay
self.values = list(layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2))
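# values[i] = layer_decay ** (num_layers + 1 - i): earlier blocks get geometrically
# smaller learning-rate multipliers, while the final blocks and head stay close to 1.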
self.depths = depths_dict[arch]
self.base_lr = base_lr
self.net_type = net_type
def get_num_layer_for_vit(self, var_name):
if var_name in ("cls_token", "mask_token", "pos_embed"):
return 0
elif var_name.startswith("patch_embed"):
return 0
elif var_name.startswith("layers"):
layer_id = int(var_name.split('.')[1])
return layer_id + 1
else:
return len(self.values) - 1
def get_num_layer_for_swin(self, var_name):
if var_name in ("mask_token", "pos_embed"):
return 0
elif var_name.startswith("patch_embed"):
return 0
elif var_name.startswith("stages"):
layer_id = int(var_name.split('.')[1])
if 'blocks' in var_name:
block_id = int(var_name.split('.')[3])
else:
block_id = self.depths[layer_id] - 1
layer_id = sum(self.depths[:layer_id]) + block_id
return layer_id + 1
else:
return len(self.values) - 1
def get_layer_id(self, var_name):
if self.net_type == 'swin':
return self.get_num_layer_for_swin(var_name)
if self.net_type == 'vit':
return self.get_num_layer_for_vit(var_name)
def fix_param(self, model, num_block=4):
if num_block < 1:
return 0
frozen_num = 0
if self.net_type == 'swin':
for name, param in model.named_parameters():
if name.startswith("patch_embed"):
param.requires_grad = False
frozen_num += 1
if name.startswith("stages") and self.get_layer_id(name) <= num_block:
param.requires_grad = False
frozen_num += 1
if self.net_type == 'vit':
for name, param in model.named_parameters():
if name.startswith("patch_embed"):
param.requires_grad = False
frozen_num += 1
if name.startswith("layers") and self.get_layer_id(name) <= num_block:
param.requires_grad = False
frozen_num += 1
return frozen_num
def fix_param_deeper(self, model, num_block=4):
if num_block < 1:
return 0
frozen_num = 0
if self.net_type == 'swin':
raise ValueError('Not Support')
if self.net_type == 'vit':
for name, param in model.named_parameters():
if name.startswith("patch_embed"):
param.requires_grad = False
frozen_num += 1
if name.startswith("layers") and self.get_layer_id(name) >= num_block:
param.requires_grad = False
frozen_num += 1
return frozen_num
def get_parameter_groups(self, model, weight_decay):
parameter_groups_with_wd, parameter_groups_without_wd = [], []
print_info_with_wd, print_info_without_wd = [], []
no_decay = [
"absolute_pos_embed", "relative_position_bias_table", "norm", "bias"
]
if self.layer_decay == 1:
parameter_groups_with_wd.append(
{"params": [], "weight_decay": weight_decay, "lr": self.base_lr}
)
print_info_with_wd.append(
{"params": [], "weight_decay": weight_decay, "lr": self.base_lr}
)
parameter_groups_without_wd.append(
{"params": [], "weight_decay": 0, "lr": self.base_lr}
)
print_info_without_wd.append(
{"params": [], "weight_decay": 0, "lr": self.base_lr}
)
else:
for scale in self.values:
parameter_groups_with_wd.append(
{"params": [], "weight_decay": weight_decay, "lr": scale * self.base_lr}
)
print_info_with_wd.append(
{"params": [], "weight_decay": weight_decay, "lr": scale * self.base_lr}
)
parameter_groups_without_wd.append(
{"params": [], "weight_decay": 0, "lr": scale * self.base_lr}
)
print_info_without_wd.append(
{"params": [], "weight_decay": 0, "lr": scale * self.base_lr}
)
for name, param in model.named_parameters():
if not param.requires_grad:
print(f'frozen param: {name}')
continue # frozen weights
layer_id = self.get_layer_id(name) if self.layer_decay < 1 else 0
if any(nd in name for nd in no_decay):
parameter_groups_without_wd[layer_id]['params'].append(param)
print_info_without_wd[layer_id]['params'].append(name)
else:
parameter_groups_with_wd[layer_id]['params'].append(param)
print_info_with_wd[layer_id]['params'].append(name)
parameter_groups_with_wd = [x for x in parameter_groups_with_wd if len(x['params']) > 0]
parameter_groups_without_wd = [x for x in parameter_groups_without_wd if len(x['params']) > 0]
print_info_with_wd = [x for x in print_info_with_wd if len(x['params']) > 0]
print_info_without_wd = [x for x in print_info_without_wd if len(x['params']) > 0]
if self.layer_decay < 1:
for wd, nwd in zip(print_info_with_wd, print_info_without_wd):
print(wd)
print(nwd)
parameter_groups = []
parameter_groups.extend(parameter_groups_with_wd)
parameter_groups.extend(parameter_groups_without_wd)
return parameter_groups
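# Usage sketch (hypothetical hyper-parameters):
#   assigner = LayerDecayValueAssigner(layer_decay=0.9, num_layers=24,
#                                      base_lr=1e-4, net_type='vit')
#   param_groups = assigner.get_parameter_groups(model, weight_decay=0.05)
#   optimizer = torch.optim.AdamW(param_groups)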