init
This commit is contained in:
84
lib/datasets/utils/masking_generator.py
Normal file
84
lib/datasets/utils/masking_generator.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import random
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
class MaskingGenerator:
    """Generate random block-wise binary masks over a 2-D grid of patches.

    The input is divided into ``height x width`` patches, where each grid
    dimension is the corresponding input dimension floor-divided by
    ``patch_size``.  Calling the generator returns an ``int32`` array of that
    shape in which exactly ``int(height * width * mask_ratio)`` entries are 1
    (masked) and the rest are 0 (visible).

    Args:
        input_size: Either a single number (square input) or a list/tuple
            whose first two items are the input height and width.
        patch_size: Side length of a patch, in the same unit as ``input_size``.
        mask_ratio: Fraction of all patches to mask.
        min_num_patches: Lower bound used when sampling the area of one
            rectangular block.
        max_num_patches: Upper bound on the number of patches a single block
            may add.  Defaults to the total number of patches to mask.
        min_aspect: Smallest aspect ratio sampled for a block.
        max_aspect: Largest aspect ratio sampled for a block.  Defaults to
            ``1 / min_aspect``.
    """

    def __init__(
            self, input_size, patch_size, mask_ratio=0.5, min_num_patches=4, max_num_patches=None,
            min_aspect=0.3, max_aspect=None):
        # A scalar means a square input.  Tuples must be accepted alongside
        # lists: wrapping a tuple in a 2-item list would make the floor
        # divisions below fail with a TypeError.
        if not isinstance(input_size, (list, tuple)):
            input_size = [input_size] * 2
        self.height = input_size[0] // patch_size
        self.width = input_size[1] // patch_size

        self.num_patches = self.height * self.width
        self.num_masking_patches = int(self.num_patches * mask_ratio)

        self.min_num_patches = min_num_patches
        self.max_num_patches = self.num_masking_patches if max_num_patches is None else max_num_patches

        max_aspect = max_aspect or 1 / min_aspect
        # Stored in log space so aspect ratios are sampled log-uniformly.
        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))

    def __repr__(self):
        repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
            self.height, self.width, self.min_num_patches, self.max_num_patches,
            self.num_masking_patches, self.log_aspect_ratio[0], self.log_aspect_ratio[1])
        return repr_str

    def get_shape(self):
        """Return the patch-grid shape as ``(height, width)``."""
        return self.height, self.width

    def _mask(self, mask, max_mask_patches):
        """Try to mask one random rectangular block of ``mask`` in place.

        Up to 10 candidate rectangles are sampled.  The first one that would
        add between 1 and ``max_mask_patches`` newly masked patches is
        applied.

        Args:
            mask: 2-D NumPy array of the grid shape; modified in place.
            max_mask_patches: Maximum number of new patches this call may mask.

        Returns:
            The number of entries switched from 0 to 1 (0 if every attempt
            was rejected).
        """
        delta = 0
        for _ in range(10):
            target_area = random.uniform(self.min_num_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < self.width and h < self.height:
                top = random.randint(0, self.height - h)
                left = random.randint(0, self.width - w)

                # Basic slicing yields a view, so writes go straight to `mask`.
                region = mask[top: top + h, left: left + w]
                num_masked = region.sum()
                # Accept only if the block adds something new without
                # exceeding the remaining budget (it may overlap masked cells).
                if 0 < h * w - num_masked <= max_mask_patches:
                    unmasked = region == 0
                    region[unmasked] = 1
                    delta += int(unmasked.sum())

            if delta > 0:
                break
        return delta

    def __call__(self):
        """Sample a new mask.

        Returns:
            An ``int32`` array of shape ``(height, width)`` containing exactly
            ``num_masking_patches`` ones.
        """
        mask = np.zeros(shape=self.get_shape(), dtype=np.int32)
        mask_count = 0
        while mask_count < self.num_masking_patches:
            max_mask_patches = self.num_masking_patches - mask_count
            max_mask_patches = min(max_mask_patches, self.max_num_patches)

            delta = self._mask(mask, max_mask_patches)
            if delta == 0:
                # No acceptable block found; fall through to the fix-up below.
                break
            mask_count += delta

        # Keep the number of masked patches fixed at self.num_masking_patches.
        if mask_count > self.num_masking_patches:
            # Defensive: un-mask randomly chosen patches if we overshot.
            delta = mask_count - self.num_masking_patches
            mask_x, mask_y = mask.nonzero()
            to_vis = np.random.choice(mask_x.shape[0], delta, replace=False)
            mask[mask_x[to_vis], mask_y[to_vis]] = 0

        elif mask_count < self.num_masking_patches:
            # Block sampling stopped early: top up with random single patches.
            delta = self.num_masking_patches - mask_count
            mask_x, mask_y = (mask == 0).nonzero()
            to_mask = np.random.choice(mask_x.shape[0], delta, replace=False)
            mask[mask_x[to_mask], mask_y[to_mask]] = 1

        assert mask.sum() == self.num_masking_patches, f"mask: {mask}, mask count {mask.sum()}"

        return mask
|
||||
Reference in New Issue
Block a user