import random
import math

import numpy as np


class MaskingGenerator:
    """Sample random block-wise binary masks over a 2-D grid of patches.

    Each call returns an ``int32`` array of shape ``(height, width)`` with
    exactly ``num_masking_patches`` cells set to 1 and all others 0. Cells are
    set by repeatedly stamping rectangles of random area and aspect ratio.
    Any shortfall left once no further rectangle can be placed is filled with
    individually chosen cells.

    Randomness comes from Python's ``random`` module (rectangle sampling) and
    from ``np.random`` (the final fix-up step). Seed both for reproducible
    masks.
    """

    def __init__(
            self, input_size, patch_size, mask_ratio=0.5, min_num_patches=4, max_num_patches=None,
            min_aspect=0.3, max_aspect=None):
        """
        Args:
            input_size: Input extent as an int (square) or a ``(height, width)``
                list/tuple.
            patch_size: Patch extent as an int (square) or a ``(height, width)``
                list/tuple. The grid size per axis is
                ``input_size // patch_size``.
            mask_ratio: Fraction of grid cells to mask. The resulting count is
                truncated with ``int``.
            min_num_patches: Lower bound of the target area drawn for each
                rectangle.
            max_num_patches: Cap on how many new cells a single rectangle may
                mask. Defaults to the total number of cells to mask.
            min_aspect: Smallest aspect ratio (h / w) sampled for a rectangle.
            max_aspect: Largest aspect ratio. Defaults to ``1 / min_aspect``
                when not given (or falsy).

        Raises:
            ValueError: If ``mask_ratio`` yields a masked-cell count outside
                ``[0, num_patches]``, which ``__call__`` could never satisfy.
        """
        # Accept both lists and tuples; a bare int means a square extent.
        if not isinstance(input_size, (list, tuple)):
            input_size = [input_size] * 2
        if not isinstance(patch_size, (list, tuple)):
            patch_size = [patch_size] * 2
        self.height = input_size[0] // patch_size[0]
        self.width = input_size[1] // patch_size[1]

        self.num_patches = self.height * self.width
        self.num_masking_patches = int(self.num_patches * mask_ratio)
        if not 0 <= self.num_masking_patches <= self.num_patches:
            raise ValueError(
                f"mask_ratio={mask_ratio} gives {self.num_masking_patches} masked patches, "
                f"outside the valid range [0, {self.num_patches}]")

        self.min_num_patches = min_num_patches
        self.max_num_patches = self.num_masking_patches if max_num_patches is None else max_num_patches

        max_aspect = max_aspect or 1 / min_aspect
        # Aspect ratios are sampled uniformly in log space.
        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))

    def __repr__(self):
        repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
            self.height, self.width, self.min_num_patches, self.max_num_patches,
            self.num_masking_patches, self.log_aspect_ratio[0], self.log_aspect_ratio[1])
        return repr_str

    def get_shape(self):
        """Return the patch-grid shape as ``(height, width)``."""
        return self.height, self.width

    def _mask(self, mask, max_mask_patches):
        """Try (up to 10 times) to stamp one random rectangle into ``mask``.

        The rectangle is written in place. It is accepted only if it masks at
        least one new cell and at most ``max_mask_patches`` new cells.

        Args:
            mask: 0/1 integer array of shape ``(height, width)``, modified in
                place.
            max_mask_patches: Upper bound on newly masked cells; also the upper
                end of the target-area draw.

        Returns:
            The number of cells newly set to 1 (0 if every attempt failed).
        """
        delta = 0
        for _ in range(10):
            target_area = random.uniform(self.min_num_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            # Strict '<': rectangles spanning the full height or width are never placed.
            if w < self.width and h < self.height:
                top = random.randint(0, self.height - h)
                left = random.randint(0, self.width - w)

                # Basic slicing yields a view, so writes below land in `mask`.
                region = mask[top: top + h, left: left + w]
                # mask is 0/1, so the number of zero cells is h*w minus the sum.
                newly_masked = h * w - int(region.sum())
                # Require some new coverage, but never exceed the remaining budget.
                if 0 < newly_masked <= max_mask_patches:
                    region[...] = 1
                    delta += newly_masked

            if delta > 0:
                break
        return delta

    def __call__(self):
        """Generate one mask.

        Returns:
            ``np.int32`` array of shape ``(height, width)`` containing exactly
            ``num_masking_patches`` ones.
        """
        mask = np.zeros(shape=self.get_shape(), dtype=np.int32)
        mask_count = 0
        while mask_count < self.num_masking_patches:
            max_mask_patches = min(self.num_masking_patches - mask_count, self.max_num_patches)

            delta = self._mask(mask, max_mask_patches)
            if delta == 0:
                # No rectangle fits any more; the fix-up below fills the remainder.
                break
            mask_count += delta

        # maintain a fixed number {self.num_masking_patches}
        if mask_count > self.num_masking_patches:
            # Defensive only: _mask never adds more than the remaining budget.
            delta = mask_count - self.num_masking_patches
            mask_x, mask_y = mask.nonzero()
            to_vis = np.random.choice(mask_x.shape[0], delta, replace=False)
            mask[mask_x[to_vis], mask_y[to_vis]] = 0
        elif mask_count < self.num_masking_patches:
            # Top up with randomly chosen still-unmasked cells.
            delta = self.num_masking_patches - mask_count
            mask_x, mask_y = (mask == 0).nonzero()
            to_mask = np.random.choice(mask_x.shape[0], delta, replace=False)
            mask[mask_x[to_mask], mask_y[to_mask]] = 1

        assert mask.sum() == self.num_masking_patches, f"mask: {mask}, mask count {mask.sum()}"

        return mask