init
3
lib/datasets/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .builder import PretrainingBuilder

__all__ = ["PretrainingBuilder"]
18
lib/datasets/builder.py
Normal file
@@ -0,0 +1,18 @@
from antmmf.common.registry import registry
from antmmf.datasets.base_dataset_builder import BaseDatasetBuilder
from .loader.pretraining_loader import PretrainingLoader


@registry.register_builder("pretraining_loader")
class PretrainingBuilder(BaseDatasetBuilder):
    def __init__(self):
        super().__init__("pretraining_loader")

    def _build(self, dataset_type, config, *args, **kwargs):
        return None

    def _load(self, dataset_type, config, *args, **kwargs):
        self.dataset = PretrainingLoader(dataset_type, config)
        return self.dataset

    def update_registry_for_model(self, config):
        pass
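A minimal usage sketch for the builder above. This is illustrative only: the `config` object and its fields are assumptions, and in practice antmmf resolves the builder through the "pretraining_loader" key it is registered under rather than by direct calls.

builder = PretrainingBuilder()
train_dataset = builder._load("train", config)  # returns the PretrainingLoader built from config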
0
lib/datasets/loader/__init__.py
Normal file
289
lib/datasets/loader/few_shot_flood3i_loader.py
Normal file
@@ -0,0 +1,289 @@
import os
import json
import datetime
import random
import itertools
import time

import numpy as np
import torch
import torch.nn.functional as F

from antmmf.structures import Sample
from antmmf.datasets.base_dataset import BaseDataset
from antmmf.common import Configuration

from lib.datasets.utils.transforms import Compose, MSNormalize
from lib.datasets.utils.formatting import ToTensor
import lib.datasets.utils.pair_trainsforms as pair_transforms

from skimage import io
from osgeo import gdal
from PIL import Image


class FewShotFloodLoader(BaseDataset):
    DATASET_NAME = "few_shot_flood_loader"

    def __init__(self, dataset_type, config):
        super().__init__(self.__class__.DATASET_NAME, dataset_type, config)
        if dataset_type == 'train':
            raise ValueError('train mode is not supported!')

        self.root = config.data_root_dir
        self.dataset_type = dataset_type
        self.img_dir = config.img_dir
        self.tgt_dir = config.tgt_dir
        with open(config.data_txt, 'r') as f:
            test_list = f.readlines()
        self.test_pairs = []
        self.cls2path = {}
        for i in test_list:
            i = i.strip()
            if i == '':
                continue
            img_path = i[:-3]
            cls = int(i[-2:])
            self.test_pairs.append(
                {'hr_path': img_path,
                 'class': cls,
                 'tgt_path': img_path.replace('_', '_lab_', 1).replace('.jpg', '.png')
                 })
            if cls in self.cls2path.keys():
                self.cls2path[cls].append({'hr_path': img_path, 'tgt_path': img_path.replace('_', '_lab_', 1).replace('.jpg', '.png'), 'class': cls})
            else:
                self.cls2path[cls] = [{'hr_path': img_path, 'tgt_path': img_path.replace('_', '_lab_', 1).replace('.jpg', '.png'), 'class': cls}]

        self.seq_len = config.seq_len  # ts
        self.hr_size = config.image_size.hr
        self.s2_size = config.image_size.s2
        self.s1_size = config.image_size.s1
        self.anno_size = config.image_size.anno
        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406])
        self.imagenet_std = torch.tensor([0.229, 0.224, 0.225])  # leave as-is for now
        self.config = config
        self.pipeline = self._get_pipline()
        # self.crop_resize = pair_transforms.RandomResizedCropComb(512, scale=(0.99, 1.0), interpolation=3)

    def __len__(self) -> int:
        return len(self.test_pairs)

    def _combine_two_images(self, image, image2):
        dst = torch.cat([image, image2], dim=-2)
        return dst

    def _get_pipline(self):
        if self.dataset_type == 'val' or self.dataset_type == 'test':
            pipeline = [
                pair_transforms.ToTensor(),
                pair_transforms.RandomResizedCrop(512, scale=(0.9999, 1.0), interpolation=3),
                pair_transforms.Normalize(),
            ]
        else:
            raise ValueError('dataset_type is not supported')
        return pair_transforms.Compose(pipeline)

    def _load_data(self, data_path):
        file_name, file_extension = os.path.splitext(data_path)
        if file_extension == '.npz' or file_extension == '.npy':
            npz_key = self.config.get('npz_key', 'image')
            data = np.load(data_path)[npz_key]
        elif file_extension == '.png' or file_extension == '.jpg':
            data = io.imread(data_path)
            if len(data.shape) == 3:
                data = data.transpose(2, 0, 1)
        elif file_extension == '.tiff' or file_extension == '.tif':
            dataset = gdal.Open(data_path)
            if dataset is None:
                raise IOError(f'cannot open file: {data_path}')
            data = dataset.ReadAsArray()
            dataset = None
        else:
            raise ValueError(f'unsupported file type: {data_path}')
        # check nan
        if np.isnan(data).any():
            print(f'{data_path} contains NaN, replacing it with 0!')
            data[np.isnan(data)] = 0
        return data

    def load_s2(self, pair):
        if 'l8_path' in pair.keys():
            pair['s2_path'] = pair['l8_path']

        if 's2_path' in pair.keys() and not self.config.get('masking_s2', False):
            with_s2 = True
            if isinstance(pair['s2_path'], list):
                if True:  # len(pair['s2_path']) > self.seq_len:
                    s2_path_list = np.random.choice(pair['s2_path'], self.seq_len)
                    s2_path_list = sorted(s2_path_list)
                else:
                    s2_path_list = pair['s2_path']
                s2_list = []
                s2_ct_1 = []
                for s2_path in s2_path_list:
                    s2 = self._load_data(os.path.join(self.root, s2_path))  # [:10]
                    s2_list.append(s2)
                    ct = os.path.splitext(s2_path)[0].split('_')
                    ct = ct[3]  # + ct[-3] + '01'
                    try:
                        ct = datetime.datetime.strptime(ct, '%Y%m%d')
                    except ValueError:
                        ct = datetime.datetime.strptime(ct, '%Y-%m-%d')
                    ct = ct.timetuple()
                    ct = ct.tm_yday - 1
                    s2_ct_1.append(ct)
                s2_1 = np.stack(s2_list, axis=1)

            else:
                s2 = np.load(os.path.join(self.root, pair['s2_path']))['image']
                date = np.load(os.path.join(self.root, pair['s2_path']))['date']
                if True:  # s2.shape[0] > self.seq_len:
                    selected_indices = np.random.choice(s2.shape[0], size=self.seq_len, replace=False)
                    selected_indices = sorted(selected_indices)
                    s2 = s2[selected_indices, :, :, :]
                    date = date[selected_indices]
                s2_1 = s2.transpose(1, 0, 2, 3)  # ts, c, h, w -> c, ts, h, w
                s2_ct_1 = []
                for ct in date:
                    try:
                        ct = datetime.datetime.strptime(ct, '%Y%m%d')
                    except ValueError:
                        ct = datetime.datetime.strptime(ct, '%Y-%m-%d')
                    ct = ct.timetuple()
                    ct = ct.tm_yday - 1
                    s2_ct_1.append(ct)

        else:
            with_s2 = False
            s2_1 = np.zeros((10, self.seq_len, self.s2_size[0], self.s2_size[1]),
                            dtype=np.int16)
            s2_ct_1 = [0] * self.seq_len

        return with_s2, s2_1, s2_ct_1

    def load_s1(self, pair):
        if 's1_path' in pair.keys():
            with_s1 = True
            if isinstance(pair['s1_path'], list):
                if True:  # len(pair['s1_path']) > self.seq_len:
                    s1_path_list = np.random.choice(pair['s1_path'], self.seq_len)
                    s1_path_list = sorted(s1_path_list)
                else:
                    s1_path_list = pair['s1_path']
                s1_list = []
                for s1_path in s1_path_list:
                    s1 = self._load_data(os.path.join(self.root, s1_path))
                    s1_list.append(s1)
                s1_1 = np.stack(s1_list, axis=1)
            else:
                s1 = self._load_data(os.path.join(self.root, pair['s1_path']))
                if True:  # s1.shape[0] > self.seq_len:
                    selected_indices = np.random.choice(s1.shape[0], size=self.seq_len, replace=False)
                    selected_indices = sorted(selected_indices)
                    s1 = s1[selected_indices, :, :, :]
                s1_1 = s1.transpose(1, 0, 2, 3)  # ts, c, h, w -> c, ts, h, w
        else:
            with_s1 = False
            s1_1 = np.zeros((2, self.seq_len, self.s1_size[0], self.s1_size[1]),
                            dtype=np.float32)
        return with_s1, s1_1

    def load_hr(self, pair):
        if 'hr_path' in pair.keys():
            with_hr = True
            hr = self._load_data(os.path.join(self.root, pair['hr_path']))
        else:
            with_hr = False
            hr = np.zeros((3, self.hr_size[0], self.hr_size[1]),
                          dtype=np.uint8)
        return with_hr, hr

    def load_tgt(self, pair):
        targets = self._load_data(os.path.join(self.root, pair['target_path']))
        return targets

    def get_item(self, idx):
        pair = self.test_pairs[idx]
        test_class = pair['class']

        current_dataset = 'flood3i'
        with_hr = True
        with_s2 = False
        with_s1 = False

        input_hr = io.imread(os.path.join(self.img_dir, pair['hr_path']))
        input_hr = input_hr.transpose(2, 0, 1)
        _, input_s2, _ = self.load_s2(pair)
        _, input_s1 = self.load_s1(pair)
        input_tgt = io.imread(os.path.join(self.tgt_dir, pair['tgt_path']))
        modality_dict = {
            's2': with_s2,
            's1': with_s1,
            'hr': with_hr
        }

        input_tgt[input_tgt != test_class] = 0
        input_tgt[input_tgt == test_class] = 255
        input_tgt = np.concatenate((input_tgt[None, :, :],) * 3, axis=0)
        input_hr, input_s2, input_s1, input_tgt = self.pipeline(current_dataset, input_hr, input_s2, input_s1,
                                                                input_tgt)

        while True:
            sel_prompt = random.choice(self.cls2path[test_class])
            if sel_prompt['hr_path'] != pair['hr_path']:
                break
        prompt_hr = io.imread(os.path.join(self.img_dir, sel_prompt['hr_path']))
        prompt_hr = prompt_hr.transpose(2, 0, 1)
        _, prompt_s2, _ = self.load_s2(pair)
        _, prompt_s1 = self.load_s1(pair)
        prompt_tgt = io.imread(os.path.join(self.tgt_dir, sel_prompt['tgt_path']))

        prompt_tgt[prompt_tgt != test_class] = 0
        prompt_tgt[prompt_tgt == test_class] = 255
        prompt_tgt = np.concatenate((prompt_tgt[None, :, :],) * 3, axis=0)

        prompt_hr, prompt_s2, prompt_s1, prompt_tgt = self.pipeline(current_dataset, prompt_hr, prompt_s2, prompt_s1, prompt_tgt)

        targets_comb = self._combine_two_images(prompt_tgt, input_tgt)
        hr_comb = self._combine_two_images(prompt_hr, input_hr)
        s2_comb = self._combine_two_images(prompt_s2, input_s2)
        s1_comb = self._combine_two_images(prompt_s1, input_s1)

        valid = torch.ones_like(targets_comb)
        thres = torch.ones(3) * 1e-5  # ignore black
        thres = (thres - self.imagenet_mean) / self.imagenet_std
        valid[targets_comb < thres[:, None, None]] = 0

        mask_shape = (int(self.config.mim.input_size[0] / self.config.mim.patch_size),
                      int(self.config.mim.input_size[1] / self.config.mim.patch_size))
        mask = np.zeros(mask_shape, dtype=np.int32)
        mask[mask.shape[0] // 2:, :] = 1

        geo_location = pair["location"] if "location" in pair.keys() else None

        modality_idx = 2 ** 0 * modality_dict['s2'] + 2 ** 1 * modality_dict['s1'] + 2 ** 2 * modality_dict['hr']
        modality_flag_s2 = modality_dict['s2']
        modality_flag_s1 = modality_dict['s1']
        modality_flag_hr = modality_dict['hr']

        current_sample = Sample()
        current_sample.img_name = pair["tgt_path"].split('/')[-1].split('.')[0] + '-' + str(test_class)
        current_sample.hr_img = hr_comb
        current_sample.dataset_name = 'flood3i'
        current_sample.targets = targets_comb
        current_sample.s2_img = s2_comb
        current_sample.s2_ct = -1
        current_sample.s2_ct2 = -1
        current_sample.s1_img = s1_comb
        current_sample.anno_mask = torch.from_numpy(mask)
        current_sample.valid = valid
        current_sample.location = geo_location
        current_sample.modality_idx = modality_idx
        current_sample.modality_flag_s2 = modality_flag_s2
        current_sample.modality_flag_s1 = modality_flag_s1
        current_sample.modality_flag_hr = modality_flag_hr
        current_sample.task_type = self.dataset_type
        return current_sample
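A shape sketch for the sample produced by get_item above. This is illustrative only: it assumes mim.input_size = [1024, 512], mim.patch_size = 16, and a 512-pixel high-resolution crop, which are config values, not facts fixed by this commit.

loader = FewShotFloodLoader('test', config)
sample = loader.get_item(0)
# sample.hr_img / sample.targets: (3, 1024, 512), the prompt tile stacked on top of the query tile
# sample.anno_mask: (64, 32), with the bottom half (the query rows) set to 1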
494
lib/datasets/loader/pretraining_loader.py
Normal file
@@ -0,0 +1,494 @@
import os
import json
import datetime
import random

import torch
import numpy as np
from osgeo import gdal
from skimage import io
from skimage.transform import resize

from antmmf.structures import Sample
from antmmf.datasets.base_dataset import BaseDataset

import lib.datasets.utils.pair_trainsforms as pair_transforms
from lib.datasets.utils.masking_generator import MaskingGenerator
from lib.datasets.utils.dataset_colors import dataset_color_dict, get_painter_color_map_list, get_real_random_color_list


class PretrainingLoader(BaseDataset):
    DATASET_NAME = "pretraining_loader"

    def __init__(self, dataset_type, config):
        super().__init__(self.__class__.DATASET_NAME, dataset_type, config)
        self.root = config.data_root_dir
        if dataset_type == 'train':
            self.json_path_list = config.train_json_path_list
        if dataset_type == 'val':
            self.json_path_list = config.val_json_path_list
        if dataset_type == 'test':
            self.json_path_list = config.val_json_path_list
        self.dataset_type = dataset_type
        self.pairs = []
        self.cls_repeat_cnt = config.cls_repeat_cnt
        num_datasets = len(self.json_path_list)
        for idx, json_path in enumerate(self.json_path_list):
            print(os.path.join(config.data_root_dir, json_path))
            cur_pairs = json.load(open(os.path.join(config.data_root_dir, json_path)))
            self.pairs.extend(cur_pairs)
            cur_num = len(cur_pairs)

        if dataset_type == 'test' and config.prompt_json:
            cur_pairs = json.load(open(config.prompt_json))
            self.prompt = cur_pairs[0]
            print(f'prompt:{self.prompt}')

        self.use_multi_pairs = config.use_multi_pairs

        if self.use_multi_pairs:
            self.pair_type_dict = {}
            if dataset_type == 'train' or dataset_type == 'val':
                for idx, pair in enumerate(self.pairs):
                    if pair["type"] not in self.pair_type_dict:
                        new_subset = {}
                        classes = pair["classes"]
                        for cls in classes:
                            if cls not in new_subset.keys():
                                new_subset[cls] = [idx]
                            else:
                                new_subset[cls].append(idx)
                        self.pair_type_dict[pair["type"]] = new_subset
                    else:
                        classes = pair["classes"]
                        for cls in classes:
                            if cls not in self.pair_type_dict[pair["type"]].keys():
                                self.pair_type_dict[pair["type"]][cls] = [idx]
                            else:
                                self.pair_type_dict[pair["type"]][cls].append(idx)

            cnt = 0
            self.idx_to_cls = {}
            for k, v in self.pair_type_dict.items():
                for vv in v:
                    self.idx_to_cls[cnt] = {
                        'type': k,
                        'classes_id': vv
                    }
                    cnt = cnt + 1

        print(self.idx_to_cls)
        self.idx_to_cls_list = []
        for i in self.idx_to_cls.keys():
            self.idx_to_cls_list.append(self.idx_to_cls[i])
        print(self.idx_to_cls_list)
        if self.dataset_type == 'train':
            self.idx_to_cls_list = self.idx_to_cls_list * self.cls_repeat_cnt
        self.masked_position_generator = MaskingGenerator(
            input_size=config.mim.input_size,
            patch_size=config.mim.patch_size,
            mask_ratio=config.mim.mask_ratio
        )
        if dataset_type == 'train':
            self.half_mask_ratio = config.half_mask_ratio
        else:
            self.half_mask_ratio = 1.

        self.seq_len = config.seq_len  # ts
        self.hr_size = config.image_size.hr
        self.s2_size = config.image_size.s2
        self.s1_size = config.image_size.s1
        self.anno_size = config.image_size.anno
        self.min_random_scale = config.min_random_scale
        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406])
        self.imagenet_std = torch.tensor([0.229, 0.224, 0.225])

        self.pipeline = self._get_pipline()
        self.crop_resize = pair_transforms.RandomResizedCropComb(512, scale=(0.3, 1.0), interpolation=3)
        self.num_samples = 8

    def __len__(self) -> int:
        return len(self.idx_to_cls_list)

    def _convert_colors_pairs(self, images, original_colors, new_colors, current_color):
        if len(original_colors) != len(new_colors):
            raise ValueError("The length of original_colors and new_colors must be the same.")
        unique_colors_list = []
        for image in images:
            if len(image.shape) == 3:
                image_hwc = image.transpose(1, 2, 0)  # chw -> hwc
            elif len(image.shape) == 2:
                image_hwc = image[:, :, None]
            else:
                raise ValueError(f'image shape {image.shape} is not supported for color conversion!')

            image_2d = image_hwc.reshape(-1, image_hwc.shape[-1])
            unique_colors = np.unique(image_2d, axis=0)
            unique_colors_list.append(unique_colors)
        unique_colors_list.append(original_colors)

        sets_of_tuples = [set(map(tuple, a)) for a in unique_colors_list]
        common_tuples = set.intersection(*sets_of_tuples)
        unique_old_colors = np.array(list(common_tuples), dtype=np.uint8)
        if len(unique_old_colors) == 0:
            unique_old_colors = [current_color]
        new_colors_coverted = new_colors[:len(unique_old_colors)]
        images_converted_list = []

        for image in images:
            image_convered = self._convert_colors(image, unique_old_colors, new_colors_coverted)
            images_converted_list.append(image_convered)

        return images_converted_list

    def _convert_colors(self, image, original_colors, new_colors):
        """
        Remap colors in an image to new colors.

        Parameters:
        image (numpy.ndarray): The image as a numpy array (channel x height x width).
        original_colors (list of tuples): The list of original colors to be replaced.
        new_colors (list of tuples): The list of new colors to replace the original colors.

        Returns:
        numpy.ndarray: The image with remapped colors. (channel x height x width)
        """

        if len(original_colors) != len(new_colors):
            raise ValueError("The length of original_colors and new_colors must be the same.")

        # Convert lists of tuples to numpy arrays for faster processing
        original_colors = np.array(original_colors)
        new_colors = np.array(new_colors)
        if len(original_colors.shape) == 1:
            original_colors = original_colors[:, None]

        # check image shape
        if len(image.shape) == 3:
            remapped_image = image.transpose(1, 2, 0)  # chw -> hwc
        elif len(image.shape) == 2:
            remapped_image = image[:, :, None]
        else:
            raise ValueError(f'image shape {image.shape} is not supported for color conversion!')

        # generate new image for return
        new_image = np.zeros((remapped_image.shape[0], remapped_image.shape[1], 3), dtype=np.uint8)

        for orig_color, new_color in zip(original_colors, new_colors):
            mask = np.all(remapped_image == orig_color, axis=-1)
            new_image[mask] = new_color

        new_image = new_image.transpose(2, 0, 1)  # hwc -> chw
        return new_image

    def _combine_images(self, images, interpolation='bicubic'):
        # images 8, c, h, w -> c, 4h, 2w
        group1 = images[:4]
        group2 = images[4:]
        stacked1 = torch.cat(group1, dim=-2)
        stacked2 = torch.cat(group2, dim=-2)
        result = torch.cat((stacked1, stacked2), dim=-1)

        return result

    def _get_pipline(self):
        if self.dataset_type == 'train':
            pipeline = [
                pair_transforms.ToTensor(),
                pair_transforms.RandomResizedCrop(512, scale=(0.8, 1.0), interpolation=3),  # 3 is bicubic
                pair_transforms.RandomHorizontalFlip(),
                pair_transforms.Normalize(),
            ]
        elif self.dataset_type == 'val' or self.dataset_type == 'test':
            pipeline = [
                pair_transforms.ToTensor(),
                pair_transforms.RandomResizedCrop(512, scale=(0.9999, 1.0), interpolation=3),  # 3 is bicubic
                pair_transforms.Normalize(),
            ]
        else:
            raise ValueError('dataset_type is not supported')
        return pair_transforms.Compose(pipeline)

    def _load_data(self, data_path):
        file_name, file_extension = os.path.splitext(data_path)
        if file_extension == '.npz' or file_extension == '.npy':
            data = np.load(data_path)['image']
        elif file_extension == '.png' or file_extension == '.jpg':
            data = io.imread(data_path)
            if len(data.shape) == 3:
                data = data.transpose(2, 0, 1)
        elif file_extension == '.tiff' or file_extension == '.tif':
            dataset = gdal.Open(data_path)
            if dataset is None:
                raise IOError(f'cannot open file: {data_path}')
            data = dataset.ReadAsArray()
            dataset = None
        else:
            raise ValueError(f'unsupported file type: {data_path}')
        if np.isnan(data).any():
            print(f'{data_path} contains NaN, replacing it with 0!')
            data[np.isnan(data)] = 0
        return data

    def load_s2(self, pair):
        if pair['type'] == 'flair-mm' and 's2_path' in pair.keys():
            with_s2 = True
            s2 = np.load(os.path.join(self.root, pair['s2_path']))
            idx_centroid = pair['s2_cut_points']
            s2_patch_size = 40
            subset_sp = s2[:, :,
                           idx_centroid[0] - int(s2_patch_size / 2):idx_centroid[0] + int(s2_patch_size / 2),
                           idx_centroid[1] - int(s2_patch_size / 2):idx_centroid[1] + int(s2_patch_size / 2)]
            ts, c, h, w = subset_sp.shape
            subset_sp = subset_sp.reshape(-1, h, w).transpose(1, 2, 0)
            s2 = resize(subset_sp, (16, 16), anti_aliasing=True).transpose(2, 0, 1)
            s2 = s2.reshape(ts, c, 16, 16)
            if True:
                selected_indices = np.random.choice(s2.shape[0], size=self.seq_len, replace=False)
                selected_indices = sorted(selected_indices)
                s2 = s2[selected_indices, :, :, :]

            s2_1 = s2.transpose(1, 0, 2, 3)  # ts, c, h, w -> c, ts, h, w
            s2_ct_1 = [0] * self.seq_len

        elif 's2_path' in pair.keys():
            with_s2 = True
            if isinstance(pair['s2_path'], list):
                if True:
                    s2_path_list = np.random.choice(pair['s2_path'], self.seq_len)
                    s2_path_list = sorted(s2_path_list)
                else:
                    s2_path_list = pair['s2_path']
                s2_list = []
                s2_ct_1 = []
                for s2_path in s2_path_list:
                    s2 = self._load_data(os.path.join(self.root, s2_path))  # [:10]
                    s2_list.append(s2)
                    ct = os.path.splitext(s2_path)[0].split('_')
                    ct = ct[-4] + ct[-3] + '01'
                    try:
                        ct = datetime.datetime.strptime(ct, '%Y%m%d')
                    except ValueError:
                        ct = datetime.datetime.strptime(ct, '%Y-%m-%d')
                    ct = ct.timetuple()
                    ct = ct.tm_yday - 1
                    s2_ct_1.append(ct)
                s2_1 = np.stack(s2_list, axis=1)

            else:
                s2 = np.load(os.path.join(self.root, pair['s2_path']))['image']
                date = np.load(os.path.join(self.root, pair['s2_path']))['date']
                if True:
                    selected_indices = np.random.choice(s2.shape[0], size=self.seq_len, replace=False)
                    selected_indices = sorted(selected_indices)
                    s2 = s2[selected_indices, :, :, :]
                    date = date[selected_indices]
                s2_1 = s2.transpose(1, 0, 2, 3)  # ts, c, h, w -> c, ts, h, w
                s2_ct_1 = []
                for ct in date:
                    try:
                        ct = datetime.datetime.strptime(ct, '%Y%m%d')
                    except ValueError:
                        ct = datetime.datetime.strptime(ct, '%Y-%m-%d')
                    ct = ct.timetuple()
                    ct = ct.tm_yday - 1
                    s2_ct_1.append(ct)
        else:
            with_s2 = False
            s2_1 = np.zeros((10, self.seq_len, self.s2_size[0], self.s2_size[1]),
                            dtype=np.int16)
            s2_ct_1 = [0] * self.seq_len

        return with_s2, s2_1, s2_ct_1

    def load_s1(self, pair):
        if 's1_path' in pair.keys():
            with_s1 = True
            if isinstance(pair['s1_path'], list):
                if True:
                    s1_path_list = np.random.choice(pair['s1_path'], self.seq_len)
                    s1_path_list = sorted(s1_path_list)
                else:
                    s1_path_list = pair['s1_path']
                s1_list = []
                for s1_path in s1_path_list:
                    s1 = self._load_data(os.path.join(self.root, s1_path))
                    s1_list.append(s1)
                s1_1 = np.stack(s1_list, axis=1)
            else:
                s1 = self._load_data(os.path.join(self.root, pair['s1_path']))
                if True:
                    selected_indices = np.random.choice(s1.shape[0], size=self.seq_len, replace=False)
                    selected_indices = sorted(selected_indices)
                    s1 = s1[selected_indices, :, :, :]
                s1_1 = s1.transpose(1, 0, 2, 3)  # ts, c, h, w -> c, ts, h, w
        else:
            with_s1 = False
            s1_1 = np.zeros((2, self.seq_len, self.s1_size[0], self.s1_size[1]),
                            dtype=np.float32)
        return with_s1, s1_1

    def load_hr(self, pair):
        if 'hr_path' in pair.keys():
            if pair['type'] == 'flair-mm':
                with_hr = True
                hr = self._load_data(os.path.join(self.root, pair['hr_path']))[:3, :, :]
            else:
                with_hr = True
                hr = self._load_data(os.path.join(self.root, pair['hr_path']))
        else:
            with_hr = False
            hr = np.zeros((3, self.hr_size[0], self.hr_size[1]),
                          dtype=np.uint8)
        return with_hr, hr

    def load_tgt(self, pair):
        if self.dataset_type == 'test':
            targets = np.zeros((3, self.anno_size[0], self.anno_size[1]),
                               dtype=np.uint8)
        else:
            targets = self._load_data(os.path.join(self.root, pair['target_path']))
        return targets

    def find_random_position(self, matrix, current_color):
        if matrix.ndim == 2:
            matrix = matrix[None, :, :]
        current_color = np.array(current_color)
        C, H, W = matrix.shape

        if len(current_color) != C:
            raise ValueError("current_color does not match the channel count of matrix!")

        matches = np.where(np.all(matrix == current_color[:, None, None], axis=0))

        if len(matches[0]) > 0:
            index = np.random.choice(range(len(matches[0])))
            return (matches[0][index], matches[1][index])
        else:
            return None

    def get_item(self, idx):
        dataset_cls_infos = self.idx_to_cls_list[idx]
        current_dataset = dataset_cls_infos['type']
        current_classes_id = dataset_cls_infos['classes_id']
        pair_idx_list = self.pair_type_dict[current_dataset][current_classes_id]

        old_colors = dataset_color_dict[current_dataset]
        current_color = old_colors[current_classes_id]
        class_num = len(old_colors)
        if self.dataset_type == 'train':
            new_colors = get_real_random_color_list(class_num)
        else:
            new_colors = get_painter_color_map_list(class_num)  # fixed color mapping when testing

        num_samples = self.num_samples
        if len(pair_idx_list) < num_samples:
            selected_samples = [random.choice(pair_idx_list) for _ in range(num_samples)]
        else:
            selected_samples = random.sample(pair_idx_list, num_samples)
        hr_imgs = []
        tgt_imgs = []
        s2_imgs = []
        s1_imgs = []
        s2_cts = []
        for sample_idx in selected_samples:
            pair = self.pairs[sample_idx]
            with_hr, hr = self.load_hr(pair)
            with_s2, s2, s2_ct_1 = self.load_s2(pair)
            with_s1, s1 = self.load_s1(pair)
            tgt = self.load_tgt(pair)
            modality_dict = {
                's2': with_s2,
                's1': with_s1,
                'hr': with_hr
            }

            if (hr.shape[-2:] != tuple(self.hr_size)) and (hr.shape[-2:] == tgt.shape[-2:]) and (self.hr_size == self.anno_size):
                point_pos = self.find_random_position(tgt, current_color)
                upper_left_raw = [point_pos[0] - self.hr_size[0] // 2, point_pos[1] - self.hr_size[1] // 2]
                upper_left = [i - i % 32 + 16 for i in upper_left_raw]
                upper_left_sentinel = [i // 32 for i in upper_left_raw]
                upper_left[0] = np.clip(np.array(upper_left[0]), 0, hr.shape[-2] - self.hr_size[0])
                upper_left[1] = np.clip(np.array(upper_left[1]), 0, hr.shape[-1] - self.hr_size[1])

                upper_left_sentinel[0] = np.clip(np.array(upper_left_sentinel[0]), 0, s1.shape[-2] - self.s1_size[0])
                upper_left_sentinel[1] = np.clip(np.array(upper_left_sentinel[1]), 0, s1.shape[-1] - self.s1_size[1])
                hr = hr[:, upper_left[0]:upper_left[0] + self.hr_size[0], upper_left[1]:upper_left[1] + self.hr_size[1]]
                if with_s1:
                    s1 = s1[:, :, upper_left_sentinel[0]:upper_left_sentinel[0] + self.s1_size[0], upper_left_sentinel[1]:upper_left_sentinel[1] + self.s1_size[1]]
                if with_s2:
                    s2 = s2[:, :, upper_left_sentinel[0]:upper_left_sentinel[0] + self.s2_size[0], upper_left_sentinel[1]:upper_left_sentinel[1] + self.s2_size[1]]
                if tgt.ndim == 3:
                    tgt = tgt[:, upper_left[0]:upper_left[0] + self.hr_size[0], upper_left[1]:upper_left[1] + self.hr_size[1]]
                elif tgt.ndim == 2:
                    tgt = tgt[upper_left[0]:upper_left[0] + self.hr_size[0], upper_left[1]:upper_left[1] + self.hr_size[1]]
                else:
                    raise ValueError("unsupported tgt ndim!")
            hr_imgs.append(hr)
            tgt_imgs.append(tgt)
            s2_imgs.append(s2)
            s1_imgs.append(s1)
            s2_cts.append(s2_ct_1)

        cvt_hr_imgs = []
        cvt_tgt_imgs = []
        cvt_s2_imgs = []
        cvt_s1_imgs = []

        tgt_imgs = self._convert_colors_pairs(tgt_imgs, old_colors, new_colors, current_color)
        for i in range(len(tgt_imgs)):
            hr, s2, s1, tgt = self.pipeline(current_dataset, hr_imgs[i], s2_imgs[i], s1_imgs[i], tgt_imgs[i])
            cvt_hr_imgs.append(hr)
            cvt_s2_imgs.append(s2)
            cvt_s1_imgs.append(s1)
            cvt_tgt_imgs.append(tgt)

        targets_comb = self._combine_images(cvt_tgt_imgs)
        hr_comb = self._combine_images(cvt_hr_imgs)
        s2_comb = self._combine_images(cvt_s2_imgs)
        s1_comb = self._combine_images(cvt_s1_imgs)
        hr_comb, s2_comb, s1_comb, targets_comb = self.crop_resize(current_dataset, hr_comb, s2_comb, s1_comb, targets_comb)
        use_half_mask = torch.rand(1)[0] < self.half_mask_ratio
        valid = torch.ones_like(targets_comb)

        thres = torch.ones(3) * (1e-5)  # ignore black
        thres = (thres - self.imagenet_mean) / self.imagenet_std
        valid[targets_comb < thres[:, None, None]] = 0

        if use_half_mask:
            num_patches = self.masked_position_generator.num_patches
            mask = np.zeros(self.masked_position_generator.get_shape(), dtype=np.int32)
            mask[mask.shape[0] // 2:, :] = 1
        else:
            mask = self.masked_position_generator()

        # location
        geo_location = pair["location"] if "location" in pair.keys() else None

        # get modality index
        modality_idx = 2 ** 0 * modality_dict['s2'] + 2 ** 1 * modality_dict['s1'] + 2 ** 2 * modality_dict['hr']
        modality_flag_s2 = modality_dict['s2']
        modality_flag_s1 = modality_dict['s1']
        modality_flag_hr = modality_dict['hr']

        current_sample = Sample()
        current_sample.img_name = pair["hr_path"].split('/')[-1].split('.')[0]
        current_sample.hr_img = hr_comb
        current_sample.dataset_name = pair["type"]
        current_sample.targets = targets_comb
        current_sample.s2_img = s2_comb
        current_sample.s2_ct = s2_cts[0]
        current_sample.s2_ct2 = s2_cts[4]
        current_sample.s1_img = s1_comb
        current_sample.anno_mask = torch.from_numpy(mask)
        current_sample.valid = valid
        current_sample.location = geo_location
        current_sample.modality_idx = modality_idx
        current_sample.modality_flag_s2 = modality_flag_s2
        current_sample.modality_flag_s1 = modality_flag_s1
        current_sample.modality_flag_hr = modality_flag_hr
        current_sample.task_type = self.dataset_type

        return current_sample
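A shape sketch for one pretraining sample. This is illustrative only: it assumes mim.input_size = [1024, 512], mim.patch_size = 16, 512x512 per-tile crops, and the 10-channel Sentinel-2 / 2-channel Sentinel-1 layouts used by the zero-fill defaults above; actual values come from the config.

loader = PretrainingLoader('train', config)
sample = loader.get_item(0)
# eight 512x512 tiles are stacked into a 4x2 grid (3, 2048, 1024) by _combine_images,
# then cropped and resized by RandomResizedCropComb, so:
# sample.hr_img / sample.targets: (3, 1024, 512)
# sample.s2_img: (10, seq_len, 32, 16), sample.s1_img: (2, seq_len, 32, 16)
# sample.anno_mask: (64, 32), either the fixed bottom-half mask or a random MaskingGenerator mask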
77
lib/datasets/utils/dataset_colors.py
Normal file
@@ -0,0 +1,77 @@
import random
from functools import lru_cache
import numpy as np

dataset_color_dict = {
    "potsdam": [[1], [2], [3], [4], [5]],
    "vaihingen": [[255, 255, 0], [0, 255, 0], [0, 255, 255], [0, 0, 255], [255, 255, 255]],
    "deepglobe": [[255, 255, 255], [0, 0, 255], [0, 255, 0], [255, 0, 255], [255, 255, 0], [0, 255, 255]],
    "fbp": [[i + 1] for i in range(24)],
    "loveda": [[i + 2, i + 2, i + 2] for i in range(6)],
    "isaid": [[i + 1] for i in range(15)],
    "pastis-mm": [[i + 1] for i in range(18)],
    "dynamic-mm": [[i] for i in range(7)],
    "c2seg-ab": [[i + 1] for i in range(13)],
    "flood3i": [[i + 1] for i in range(9)],
    "jl16-mm": [[i] for i in range(16)],
    "flair-mm": [[i + 1] for i in range(18)],
    "dfc20": [[i + 1] for i in range(10)]
}


modal_norm_dict = {
    'hr': {
        'div': 255.,
        'mean': [0.485, 0.456, 0.406],
        'std': [0.229, 0.224, 0.225]
    },
    'anno': {
        'div': 255.,
        'mean': [0.485, 0.456, 0.406],
        'std': [0.229, 0.224, 0.225]
    },
    's2': {
        'div': 1.,
        'mean': [884.29673756, 1144.16202635, 1297.47289228, 1624.90992062, 2194.6423161, 2422.21248945, 2517.76053101, 2581.64687018, 2368.51236873, 1805.06846033],
        'std': [1155.15170768, 1183.6292542, 1368.11351514, 1370.265037, 1355.55390699, 1416.51487101, 1474.78900051, 1439.3086061, 1455.52084939, 1343.48379601]
    },
    's1': {
        'div': 1.,
        'mean': [-12.54847273, -20.19237134],
        'std': [5.25697717, 5.91150917]
    },
}


@lru_cache()
def get_painter_color_map_list(num_locations=300):

    num_sep_per_channel = int(num_locations ** (1 / 3)) + 1  # 7 when num_locations == 300
    separation_per_channel = 256 // num_sep_per_channel

    color_list = []
    for location in range(num_locations):
        num_seq_r = location // num_sep_per_channel ** 2
        num_seq_g = (location % num_sep_per_channel ** 2) // num_sep_per_channel
        num_seq_b = location % num_sep_per_channel
        assert (num_seq_r <= num_sep_per_channel) and (num_seq_g <= num_sep_per_channel) \
            and (num_seq_b <= num_sep_per_channel)

        R = 255 - num_seq_r * separation_per_channel
        G = 255 - num_seq_g * separation_per_channel
        B = 255 - num_seq_b * separation_per_channel
        assert (R < 256) and (G < 256) and (B < 256)
        assert (R >= 0) and (G >= 0) and (B >= 0)
        assert (R, G, B) not in color_list

        color_list.append((R, G, B))

    return color_list


def get_real_random_color_list(num_locations):
    random_color_list = np.random.randint(0, 256, (num_locations, 3))
    while np.sum(random_color_list) == 0:
        print('random_color_list is all zeros, resampling!')
        random_color_list = np.random.randint(0, 256, (num_locations, 3))
    random_color_list = random_color_list.tolist()
    return random_color_list  # [:num_locations]
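A quick worked example of the fixed test-time palette above (illustrative): with num_locations = 300, num_sep_per_channel = int(300 ** (1 / 3)) + 1 = 7 and separation_per_channel = 256 // 7 = 36, so the palette starts at white and steps down in the blue channel first.

colors = get_painter_color_map_list(300)
assert colors[0] == (255, 255, 255)
assert colors[1] == (255, 255, 219)  # 255 - 1 * 36 in the blue channel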
65
lib/datasets/utils/formatting.py
Normal file
@@ -0,0 +1,65 @@
from collections.abc import Sequence
import mmcv
import numpy as np
import torch


def to_tensor(data):
    """Convert objects of various python types to :obj:`torch.Tensor`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`.

    Args:
        data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
            be converted.
    """

    if isinstance(data, torch.Tensor):
        return data
    elif isinstance(data, np.ndarray):
        return torch.from_numpy(data)
    elif isinstance(data, Sequence) and not mmcv.is_str(data):
        return torch.tensor(data)
    elif isinstance(data, int):
        return torch.LongTensor([data])
    elif isinstance(data, float):
        return torch.FloatTensor([data])
    else:
        raise TypeError(f'type {type(data)} cannot be converted to tensor.')


class ToTensor(object):
    """Convert some sample to :obj:`torch.Tensor` by given keys.

    Args:
        keys (Sequence[str]): Keys that need to be converted to Tensor.
    """

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, sample):
        """Call function to convert data in sample to :obj:`torch.Tensor`.

        Args:
            sample (Sample): sample data contains the data to convert.

        Returns:
            dict: The result dict contains the data converted
                to :obj:`torch.Tensor`.
        """

        for key in self.keys:
            if isinstance(sample[key], list):
                for i in range(len(sample[key])):
                    sample[key][i] = to_tensor(sample[key][i])
            else:
                sample[key] = to_tensor(sample[key])
        return sample

    def __repr__(self):
        return self.__class__.__name__ + f'(keys={self.keys})'
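A small usage sketch for the key-based ToTensor wrapper above (illustrative; a plain dict stands in for the sample object, and the key names are made up for the example):

transform = ToTensor(keys=['img', 'label'])
sample = {'img': np.zeros((3, 4, 4), dtype=np.float32), 'label': 1}
sample = transform(sample)
# sample['img'] is now a float32 torch.Tensor of shape (3, 4, 4); sample['label'] is LongTensor([1])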
84
lib/datasets/utils/masking_generator.py
Normal file
@@ -0,0 +1,84 @@
import random
import math
import numpy as np


class MaskingGenerator:
    def __init__(
            self, input_size, patch_size, mask_ratio=0.5, min_num_patches=4, max_num_patches=None,
            min_aspect=0.3, max_aspect=None):
        if not isinstance(input_size, list):
            input_size = [input_size, ] * 2
        self.height = input_size[0] // patch_size
        self.width = input_size[1] // patch_size

        self.num_patches = self.height * self.width
        self.num_masking_patches = int(self.num_patches * mask_ratio)

        self.min_num_patches = min_num_patches
        self.max_num_patches = self.num_masking_patches if max_num_patches is None else max_num_patches

        max_aspect = max_aspect or 1 / min_aspect
        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))

    def __repr__(self):
        repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
            self.height, self.width, self.min_num_patches, self.max_num_patches,
            self.num_masking_patches, self.log_aspect_ratio[0], self.log_aspect_ratio[1])
        return repr_str

    def get_shape(self):
        return self.height, self.width

    def _mask(self, mask, max_mask_patches):
        delta = 0
        for attempt in range(10):
            target_area = random.uniform(self.min_num_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < self.width and h < self.height:
                top = random.randint(0, self.height - h)
                left = random.randint(0, self.width - w)

                num_masked = mask[top: top + h, left: left + w].sum()
                # Overlap
                if 0 < h * w - num_masked <= max_mask_patches:
                    for i in range(top, top + h):
                        for j in range(left, left + w):
                            if mask[i, j] == 0:
                                mask[i, j] = 1
                                delta += 1

                if delta > 0:
                    break
        return delta

    def __call__(self):
        mask = np.zeros(shape=self.get_shape(), dtype=np.int32)
        mask_count = 0
        while mask_count < self.num_masking_patches:
            max_mask_patches = self.num_masking_patches - mask_count
            max_mask_patches = min(max_mask_patches, self.max_num_patches)

            delta = self._mask(mask, max_mask_patches)
            if delta == 0:
                break
            else:
                mask_count += delta

        # maintain a fixed number of masked patches: {self.num_masking_patches}
        if mask_count > self.num_masking_patches:
            delta = mask_count - self.num_masking_patches
            mask_x, mask_y = mask.nonzero()
            to_vis = np.random.choice(mask_x.shape[0], delta, replace=False)
            mask[mask_x[to_vis], mask_y[to_vis]] = 0

        elif mask_count < self.num_masking_patches:
            delta = self.num_masking_patches - mask_count
            mask_x, mask_y = (mask == 0).nonzero()
            to_mask = np.random.choice(mask_x.shape[0], delta, replace=False)
            mask[mask_x[to_mask], mask_y[to_mask]] = 1

        assert mask.sum() == self.num_masking_patches, f"mask: {mask}, mask count {mask.sum()}"

        return mask
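A minimal usage sketch for the block-wise masking generator above (illustrative): for a 512x512 input with 16-pixel patches and mask_ratio=0.5, the patch grid is 32x32 and exactly half of its 1024 patches end up masked.

gen = MaskingGenerator(input_size=[512, 512], patch_size=16, mask_ratio=0.5)
mask = gen()
assert mask.shape == (32, 32) and mask.sum() == 512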
532
lib/datasets/utils/pair_trainsforms.py
Normal file
@@ -0,0 +1,532 @@
|
||||
import math
|
||||
import numbers
|
||||
import random
|
||||
import warnings
|
||||
from collections.abc import Sequence
|
||||
from typing import List, Optional, Tuple
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torchvision.transforms as transforms
|
||||
from skimage import io
|
||||
|
||||
|
||||
try:
|
||||
import accimage
|
||||
except ImportError:
|
||||
accimage = None
|
||||
|
||||
import torchvision.transforms.functional as F
|
||||
from torchvision.transforms.functional import _interpolation_modes_from_int, InterpolationMode
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
|
||||
from .dataset_colors import modal_norm_dict
|
||||
|
||||
__all__ = [
|
||||
"Compose",
|
||||
"ToTensor",
|
||||
"Normalize",
|
||||
"RandomHorizontalFlip",
|
||||
"RandomResizedCrop",
|
||||
]
|
||||
|
||||
|
||||
|
||||
class Compose(transforms.Compose):
|
||||
"""Composes several transforms together. This transform does not support torchscript.
|
||||
Please, see the note below.
|
||||
Args:
|
||||
transforms (list of ``Transform`` objects): list of transforms to compose.
|
||||
"""
|
||||
|
||||
def __init__(self, transforms):
|
||||
super().__init__(transforms)
|
||||
|
||||
def __call__(self, dataset_name, hr_img, s2_img, s1_img, tgt, interpolation1=None, interpolation2=None):
|
||||
# i = 0
|
||||
for t in self.transforms:
|
||||
# i = i+1
|
||||
# print(f'dataset_name:{dataset_name}')
|
||||
# print(f'step:{i}')
|
||||
# print(f'hr_img shape:{hr_img.shape}')
|
||||
# print(f's2_img shape:{s2_img.shape}')
|
||||
# print(f's1_img shape:{s1_img.shape}')
|
||||
# print(f'tgt shape:{tgt.shape}')
|
||||
|
||||
hr_img, s2_img, s1_img, tgt = t(dataset_name, hr_img, s2_img, s1_img, tgt, interpolation1=interpolation1, interpolation2=interpolation2)
|
||||
return hr_img, s2_img, s1_img, tgt
|
||||
|
||||
|
||||
class ToTensor(transforms.ToTensor):
|
||||
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript.
|
||||
Converts a PIL Image or numpy.ndarray (H x W x C) in the range
|
||||
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
|
||||
if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
|
||||
or if the numpy.ndarray has dtype = np.uint8
|
||||
In the other cases, tensors are returned without scaling.
|
||||
.. note::
|
||||
Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
|
||||
transforming target image masks. See the `references`_ for implementing the transforms for image masks.
|
||||
.. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, dataset_name, hr_img, s2_img, s1_img, tgt, interpolation1=None, interpolation2=None):
|
||||
"""
|
||||
Args:
|
||||
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
|
||||
Returns:
|
||||
Tensor: Converted image.
|
||||
"""
|
||||
|
||||
|
||||
# print(f'hr dtype:{hr_img.dtype}')
|
||||
# print(f's2_img dtype:{s2_img.dtype}')
|
||||
# print(f's1_img dtype:{s1_img.dtype}')
|
||||
# print(f'tgt dtype:{tgt.dtype}')
|
||||
if dataset_name == 'dynamic-mm' or dataset_name == 'guizhou-mm':
|
||||
hr_img = hr_img.astype(np.int32)[:3,:,:]
|
||||
hr_img = hr_img[::-1,:,:].copy()
|
||||
else:
|
||||
hr_img = hr_img.astype(np.int32)
|
||||
tgt = tgt.astype(np.uint8)
|
||||
s1_img = s1_img.astype(np.float32)
|
||||
s2_img = s2_img.astype(np.int16)
|
||||
|
||||
return torch.tensor(hr_img), torch.tensor(s2_img), torch.tensor(s1_img),torch.tensor(tgt)
|
||||
|
||||
|
||||
class Normalize(transforms.Normalize):
|
||||
"""Normalize a tensor image with mean and standard deviation.
|
||||
This transform does not support PIL Image.
|
||||
Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
|
||||
channels, this transform will normalize each channel of the input
|
||||
``torch.*Tensor`` i.e.,
|
||||
``output[channel] = (input[channel] - mean[channel]) / std[channel]``
|
||||
.. note::
|
||||
This transform acts out of place, i.e., it does not mutate the input tensor.
|
||||
Args:
|
||||
mean (sequence): Sequence of means for each channel.
|
||||
std (sequence): Sequence of standard deviations for each channel.
|
||||
inplace(bool,optional): Bool to make this operation in-place.
|
||||
"""
|
||||
|
||||
def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False):
|
||||
super().__init__(mean, std, inplace)
|
||||
|
||||
def forward(self, dataset_name, hr_img, s2_img, s1_img, tgt, interpolation1=None, interpolation2=None):
|
||||
"""
|
||||
Args:
|
||||
tensor (Tensor): Tensor image to be normalized.
|
||||
Returns:
|
||||
Tensor: Normalized Tensor image.
|
||||
"""
|
||||
# TODO 查询对应的mean和std
|
||||
|
||||
# 处理一些mean和std
|
||||
if dataset_name == 'dynamic-mm':
|
||||
hr_std = [1008.4052, 760.9586, 631.4754]
|
||||
hr_mean = [1085.2941, 944.2718, 689.2493]
|
||||
hr_div = 1.
|
||||
else:
|
||||
hr_mean = modal_norm_dict['hr']['mean']
|
||||
hr_std = modal_norm_dict['hr']['std']
|
||||
hr_div = modal_norm_dict['hr']['div']
|
||||
|
||||
if dataset_name == 'l8activefire':
|
||||
# if False:
|
||||
s2_mean = modal_norm_dict['l8']['mean']
|
||||
s2_std = modal_norm_dict['l8']['std']
|
||||
s2_div = modal_norm_dict['l8']['div']
|
||||
else:
|
||||
s2_mean = modal_norm_dict['s2']['mean']
|
||||
s2_std = modal_norm_dict['s2']['std']
|
||||
s2_div = modal_norm_dict['s2']['div']
|
||||
|
||||
s1_mean = modal_norm_dict['s1']['mean']
|
||||
s1_std = modal_norm_dict['s1']['std']
|
||||
s1_div = modal_norm_dict['s1']['div']
|
||||
|
||||
anno_mean = [0.485, 0.456, 0.406]
|
||||
anno_std = [0.229, 0.224, 0.225]
|
||||
ann_div = 255.
|
||||
|
||||
# 存在问题:时间序列这样处理是否会出错
|
||||
|
||||
#mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
|
||||
#std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
|
||||
# print(s2_img.shape)
|
||||
# import pdb; pdb.set_trace()
|
||||
# print(s2_img)
|
||||
try:
|
||||
ch, ts, h, w = s2_img.shape
|
||||
except:
|
||||
print(f's2: {s2_img.shape}, s1: {s1_img.shape}')
|
||||
s2_img = s2_img.view(ch, ts*h, w)
|
||||
s2_img = self.normalize(s2_img.type(torch.float32), s2_mean, s2_std, self.inplace)
|
||||
s2_img = s2_img.view(ch, ts, h, w)
|
||||
|
||||
ch, ts, h, w = s1_img.shape
|
||||
s1_img = s1_img.view(ch, ts*h, w)
|
||||
s1_img = self.normalize(s1_img.type(torch.float32), s1_mean, s1_std, self.inplace)
|
||||
s1_img = s1_img.view(ch, ts, h, w)
|
||||
|
||||
# import pdb; pdb.set_trace()
|
||||
# print(s2_img.shape, s2_img[:,0,:,:])
|
||||
# print(s1_img.shape, s1_img[:,0,:,:])
|
||||
# print(hr_mean, hr_std, hr_div)
|
||||
return self.normalize(hr_img.type(torch.float32).div_(hr_div), hr_mean, hr_std, self.inplace), \
|
||||
s2_img, \
|
||||
s1_img, \
|
||||
self.normalize(tgt.type(torch.float32).div_(ann_div) , anno_mean, anno_std, self.inplace)
|
||||
|
||||
def normalize(self, tensor, mean, std, inplace):
|
||||
dtype = tensor.dtype
|
||||
mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
|
||||
std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
|
||||
if (std == 0).any():
|
||||
raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
|
||||
mean = mean.view(-1, 1, 1)
|
||||
std = std.view(-1, 1, 1)
|
||||
# print(f'tensor shape: {tensor.shape}')
|
||||
# print(f'mean shape: {mean.shape}')
|
||||
# print(f'std shape: {std.shape}')
|
||||
return tensor.sub_(mean).div_(std)
|
||||
|
||||
|
||||
class RandomResizedCrop(transforms.RandomResizedCrop):
|
||||
"""Crop a random portion of image and resize it to a given size.
|
||||
If the image is torch Tensor, it is expected
|
||||
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
|
||||
A crop of the original image is made: the crop has a random area (H * W)
|
||||
and a random aspect ratio. This crop is finally resized to the given
|
||||
size. This is popularly used to train the Inception networks.
|
||||
Args:
|
||||
size (int or sequence): expected output size of the crop, for each edge. If size is an
|
||||
int instead of sequence like (h, w), a square output size ``(size, size)`` is
|
||||
made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
|
||||
.. note::
|
||||
In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
|
||||
scale (tuple of float): Specifies the lower and upper bounds for the random area of the crop,
|
||||
before resizing. The scale is defined with respect to the area of the original image.
|
||||
ratio (tuple of float): lower and upper bounds for the random aspect ratio of the crop, before
|
||||
resizing.
|
||||
interpolation (InterpolationMode): Desired interpolation enum defined by
|
||||
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
|
||||
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and
|
||||
``InterpolationMode.BICUBIC`` are supported.
|
||||
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
|
||||
but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
scale=(0.08, 1.0),
|
||||
ratio=(3.0 / 4.0, 4.0 / 3.0),
|
||||
interpolation=InterpolationMode.BILINEAR,
|
||||
mode='small'
|
||||
):
|
||||
super().__init__(size, scale=scale, ratio=ratio, interpolation=interpolation)
|
||||
self.cnt=0
|
||||
self.mode = mode
|
||||
|
||||
def forward(self, dataset_name, hr_img, s2_img, s1_img, tgt, interpolation1=None, interpolation2=None, mode='small'):
|
||||
"""
|
||||
Args:
|
||||
img (PIL Image or Tensor): Image to be cropped and resized.
|
||||
Returns:
|
||||
PIL Image or Tensor: Randomly cropped and resized image.
|
||||
"""
|
||||
i, j, h, w = self.get_params(s2_img, self.scale, self.ratio)
|
||||
size_hr = hr_img.shape[-1]
|
||||
size_s2 = s2_img.shape[-1]
|
||||
size_anno = tgt.shape[-1]
|
||||
# 映射到其他模态
|
||||
ratio_s2_hr = size_s2 / size_hr
|
||||
i_hr = int(i / ratio_s2_hr)
|
||||
j_hr = int(j / ratio_s2_hr)
|
||||
h_hr = int(h / ratio_s2_hr)
|
||||
w_hr = int(w / ratio_s2_hr)
|
||||
|
||||
ratio_s2_anno = size_s2 / size_anno
|
||||
i_anno = int(i / ratio_s2_anno)
|
||||
j_anno = int(j / ratio_s2_anno)
|
||||
h_anno = int(h / ratio_s2_anno)
|
||||
w_anno = int(w / ratio_s2_anno)
|
||||
|
||||
if interpolation1 == 'nearest':
|
||||
interpolation1 = InterpolationMode.NEAREST
|
||||
else:
|
||||
interpolation1 = InterpolationMode.BICUBIC
|
||||
if interpolation2 == 'nearest':
|
||||
interpolation2 = InterpolationMode.NEAREST
|
||||
else:
|
||||
interpolation2 = InterpolationMode.BICUBIC
|
||||
# import pdb;pdb.set_trace()
|
||||
if self.scale[0]>0.99 and self.scale[0]<1.0:
|
||||
if self.mode=='small':
|
||||
resized_s2_img = F.resize(s2_img, (16,16), interpolation=InterpolationMode.BICUBIC)
|
||||
resized_hr_img = F.resize(hr_img, (512, 512), interpolation=InterpolationMode.BICUBIC)
|
||||
resized_s1_img = F.resize(s1_img, (16,16), interpolation=InterpolationMode.BICUBIC)
|
||||
resized_tgt = F.resize(tgt, (512,512), interpolation=InterpolationMode.NEAREST)
|
||||
else:
|
||||
resized_s2_img = F.resize(s2_img, (64,64), interpolation=InterpolationMode.BICUBIC)
|
||||
resized_hr_img = F.resize(hr_img, (2048, 2048), interpolation=InterpolationMode.BICUBIC)
|
||||
resized_s1_img = F.resize(s1_img, (64,64), interpolation=InterpolationMode.BICUBIC)
|
||||
resized_tgt = F.resize(tgt, (2048,2048), interpolation=InterpolationMode.NEAREST)
|
||||
return resized_hr_img, resized_s2_img, resized_s1_img, resized_tgt
|
||||
|
||||
if self.mode=='small':
|
||||
resized_s2_img = F.resized_crop(s2_img, i, j, h, w, (16, 16), InterpolationMode.BICUBIC)
|
||||
resized_hr_img = F.resized_crop(hr_img, i_hr, j_hr, h_hr, w_hr, (512, 512), InterpolationMode.BICUBIC)
|
||||
resized_s1_img = F.resized_crop(s1_img, i, j, h, w, (16, 16), InterpolationMode.BICUBIC)
|
||||
resized_tgt = F.resized_crop(tgt, i_anno, j_anno, h_anno, w_anno, (512, 512), InterpolationMode.NEAREST)
|
||||
else:
|
||||
resized_s2_img = F.resized_crop(s2_img, i, j, h, w, (512, 512), InterpolationMode.BICUBIC)
|
||||
resized_hr_img = F.resized_crop(hr_img, i_hr, j_hr, h_hr, w_hr, (2048,2048), InterpolationMode.BICUBIC)
|
||||
resized_s1_img = F.resized_crop(s1_img, i, j, h, w, (512, 512), InterpolationMode.BICUBIC)
|
||||
resized_tgt = F.resized_crop(tgt, i_anno, j_anno, h_anno, w_anno, (2048, 2048), InterpolationMode.NEAREST)
|
||||
|
||||
# import pdb; pdb.set_trace()
|
||||
# 将resize后的结果保存为concat的img
|
||||
# self.cnt = self.cnt+1
|
||||
# from torchvision.utils import save_image
|
||||
# save_hr = resized_hr_img[:3, :, :] / resized_hr_img[:3, :, :].max()
|
||||
# save_s2 = resized_s2_img[:3,0,:,:] / resized_s2_img[:3,0,:,:].max()
|
||||
# print(f'{save_hr.shape}, {save_s2.shape}')
|
||||
# save_image(save_s2, f'FoundationModel/debug/output2/resized_s2_{self.cnt}.png')
|
||||
# save_image(save_hr, f'FoundationModel/debug/output2/resized_hr_{self.cnt}.png')
|
||||
|
||||
return resized_hr_img, resized_s2_img, resized_s1_img, resized_tgt
|
||||
|
||||
class RandomResizedCropComb(transforms.RandomResizedCrop):
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
scale=(0.08, 1.0),
|
||||
ratio=(3.0 / 4.0, 4.0 / 3.0),
|
||||
interpolation=InterpolationMode.BILINEAR,
|
||||
):
|
||||
super().__init__(size, scale=scale, ratio=ratio, interpolation=interpolation)
|
||||
self.cnt=0
|
||||
|
||||
def forward(self, dataset_name, hr_img, s2_img, s1_img, tgt, interpolation1=None, interpolation2=None):
|
||||
"""
|
||||
Args:
|
||||
img (PIL Image or Tensor): Image to be cropped and resized.
|
||||
Returns:
|
||||
PIL Image or Tensor: Randomly cropped and resized image.
|
||||
"""
|
||||
i, j, h, w = self.get_params(s2_img, self.scale, self.ratio)
|
||||
# print(f'i, j, h, w: {i, j, h, w}')
|
||||
# print(f's2_img shape: {s2_img.shape}')
|
||||
size_hr = hr_img.shape[-1]
|
||||
size_s2 = s2_img.shape[-1]
|
||||
size_anno = tgt.shape[-1]
|
||||
# 映射到其他模态
|
||||
ratio_s2_hr = size_s2 / size_hr
|
||||
i_hr = int(i / ratio_s2_hr)
|
||||
j_hr = int(j / ratio_s2_hr)
|
||||
h_hr = int(h / ratio_s2_hr)
|
||||
w_hr = int(w / ratio_s2_hr)
|
||||
|
||||
ratio_s2_anno = size_s2 / size_anno
|
||||
i_anno = int(i / ratio_s2_anno)
|
||||
j_anno = int(j / ratio_s2_anno)
|
||||
h_anno = int(h / ratio_s2_anno)
|
||||
w_anno = int(w / ratio_s2_anno)
|
||||
|
||||
if interpolation1 == 'nearest':
|
||||
interpolation1 = InterpolationMode.NEAREST
|
||||
else:
|
||||
interpolation1 = InterpolationMode.BICUBIC
|
||||
if interpolation2 == 'nearest':
|
||||
interpolation2 = InterpolationMode.NEAREST
|
||||
else:
|
||||
interpolation2 = InterpolationMode.BICUBIC
|
||||
|
||||
resized_s2_img = F.resized_crop(s2_img, i, j, h, w, (32, 16), InterpolationMode.BICUBIC)
|
||||
resized_hr_img = F.resized_crop(hr_img, i_hr, j_hr, h_hr, w_hr, (1024, 512), InterpolationMode.BICUBIC)
|
||||
resized_s1_img = F.resized_crop(s1_img, i, j, h, w, (32, 16), InterpolationMode.BICUBIC)
|
||||
resized_tgt = F.resized_crop(tgt, i_anno, j_anno, h_anno, w_anno, (1024, 512), InterpolationMode.NEAREST)
|
||||
|
||||
return resized_hr_img, resized_s2_img, resized_s1_img, resized_tgt
|
||||
|
||||
|
||||
class RandomHorizontalFlip(transforms.RandomHorizontalFlip):
    """Horizontally flip the given image randomly with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__(p=p)

    def forward(self, dataset_name, hr_img, s2_img, s1_img, tgt, interpolation1=None, interpolation2=None):
        """
        Args:
            hr_img, s2_img, s1_img, tgt (PIL Image or Tensor): images to be flipped jointly.
        Returns:
            PIL Images or Tensors: randomly flipped images.
        """
        if torch.rand(1) < self.p:
            return F.hflip(hr_img), F.hflip(s2_img), F.hflip(s1_img), F.hflip(tgt)
        return hr_img, s2_img, s1_img, tgt


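# Illustrative note (not in the original file): a single coin flip is shared across all
# modalities, so the annotation stays aligned with every image. A hypothetical call,
# with "debug" standing in for a real dataset name:
#
#   flip = RandomHorizontalFlip(p=0.5)
#   hr, s2, s1, tgt = flip("debug", hr, s2, s1, tgt)
#
# Flipping each modality with independent draws would silently decouple images and labels.

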
class RandomApply(transforms.RandomApply):
    """Apply randomly a list of transformations with a given probability.

    .. note::
        In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of
        transforms as shown below:

        >>> transforms = transforms.RandomApply(torch.nn.ModuleList([
        >>>     transforms.ColorJitter(),
        >>> ]), p=0.3)
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
        `lambda` functions or ``PIL.Image``.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
        p (float): probability
    """

    def __init__(self, transforms, p=0.5):
        super().__init__(transforms, p=p)

    def forward(self, img, tgt, interpolation1=None, interpolation2=None):
        if self.p < torch.rand(1):
            return img, tgt
        for t in self.transforms:
            img, tgt = t(img, tgt)
        return img, tgt


class ColorJitter(transforms.ColorJitter):
    """Randomly change the brightness, contrast, saturation and hue of an image.
    If the image is torch Tensor, it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.

    Args:
        brightness (float or tuple of float (min, max)): How much to jitter brightness.
            brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
            or the given [min, max]. Should be non negative numbers.
        contrast (float or tuple of float (min, max)): How much to jitter contrast.
            contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
            or the given [min, max]. Should be non negative numbers.
        saturation (float or tuple of float (min, max)): How much to jitter saturation.
            saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
            or the given [min, max]. Should be non negative numbers.
        hue (float or tuple of float (min, max)): How much to jitter hue.
            hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
            Should have 0 <= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
            To jitter hue, the pixel values of the input image has to be non-negative for conversion to HSV space;
            thus it does not work if you normalize your image to an interval with negative values,
            or use an interpolation that generates negative values before using this function.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        super().__init__(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)

    def forward(self, img, tgt, interpolation1=None, interpolation2=None):
        """
        Args:
            img (PIL Image or Tensor): Input image.
        Returns:
            PIL Image or Tensor: Color jittered image.
        """
        fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
            self.brightness, self.contrast, self.saturation, self.hue
        )

        for fn_id in fn_idx:
            if fn_id == 0 and brightness_factor is not None:
                img = F.adjust_brightness(img, brightness_factor)
            elif fn_id == 1 and contrast_factor is not None:
                img = F.adjust_contrast(img, contrast_factor)
            elif fn_id == 2 and saturation_factor is not None:
                img = F.adjust_saturation(img, saturation_factor)
            elif fn_id == 3 and hue_factor is not None:
                img = F.adjust_hue(img, hue_factor)
        return img, tgt


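# Illustrative note (not in the original file): photometric jitter is applied to the
# image only, while tgt is passed through untouched, so annotation values are never
# perturbed. A hypothetical wiring through the paired RandomApply defined above:
#
#   jitter = RandomApply(torch.nn.ModuleList([ColorJitter(0.4, 0.4, 0.4, 0.1)]), p=0.8)
#   img, tgt = jitter(img, tgt)

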
class RandomErasing(transforms.RandomErasing):
    """Randomly selects a rectangle region in a torch.Tensor image and erases its pixels.
    This transform does not support PIL Image.
    'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896

    Args:
        p: probability that the random erasing operation will be performed.
        scale: range of proportion of erased area against input image.
        ratio: range of aspect ratio of erased area.
        value: erasing value. Default is 0. If a single int, it is used to
            erase all pixels. If a tuple of length 3, it is used to erase
            R, G, B channels respectively.
            If a str of 'random', erasing each pixel with random values.
        inplace: boolean to make this transform inplace. Default set to False.

    Returns:
        Erased Image.

    Example:
        >>> transform = transforms.Compose([
        >>>     transforms.RandomHorizontalFlip(),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>>     transforms.RandomErasing(),
        >>> ])
    """

    def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
        super().__init__(p=p, scale=scale, ratio=ratio, value=value, inplace=inplace)

    def forward(self, img, tgt, interpolation1=None, interpolation2=None):
        """
        Args:
            img (Tensor): Tensor image to be erased.
        Returns:
            img (Tensor): Erased Tensor image.
        """
        if torch.rand(1) < self.p:
            # cast self.value to script acceptable type
            if isinstance(self.value, (int, float)):
                value = [self.value]
            elif isinstance(self.value, str):
                value = None
            elif isinstance(self.value, tuple):
                value = list(self.value)
            else:
                value = self.value

            if value is not None and not (len(value) in (1, img.shape[-3])):
                raise ValueError(
                    "If value is a sequence, it should have either a single value or "
                    f"{img.shape[-3]} (number of input channels)"
                )

            x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value)
            return F.erase(img, x, y, h, w, v, self.inplace), tgt
        return img, tgt


class GaussianBlur(object):
    """Gaussian blur augmentation from SimCLR: https://arxiv.org/abs/2002.05709"""

    def __init__(self, sigma=[.1, 2.]):
        self.sigma = sigma

    def __call__(self, img, tgt, interpolation1=None, interpolation2=None):
        # Expects a PIL image; requires `from PIL import ImageFilter` at module level.
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        img = img.filter(ImageFilter.GaussianBlur(radius=sigma))
        return img, tgt

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}(sigma={self.sigma})"
        return s

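
# Illustrative smoke test (not part of the original commit); the tensor shapes below are
# hypothetical stand-ins for the hr / s2 / s1 / annotation tiles handled by the loader.
if __name__ == "__main__":
    _hr = torch.rand(3, 2048, 2048)            # hypothetical high-resolution RGB tile
    _s2 = torch.rand(10, 4, 128, 128)          # hypothetical s2 bands x time x H x W
    _s1 = torch.rand(2, 4, 128, 128)           # hypothetical s1 channels x time x H x W
    _tgt = torch.randint(0, 2, (1, 2048, 2048))  # hypothetical binary annotation
    _flip = RandomHorizontalFlip(p=1.0)
    _hr, _s2, _s1, _tgt = _flip("debug", _hr, _s2, _s1, _tgt)
    print(_hr.shape, _s2.shape, _s1.shape, _tgt.shape)
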
1558
lib/datasets/utils/transforms.py
Normal file
1558
lib/datasets/utils/transforms.py
Normal file
File diff suppressed because it is too large