Source code for neodroidvision.utilities.torch_utilities.patches.multiply.ratio_masking

import numpy
import torch

__all__ = ["RatioMaskGenerator"]

from torch import nn
from torch.nn.functional import fold


[docs]class RatioMaskGenerator(nn.Module): """description"""
[docs] def __init__(self, patch_size, mask_ratio): super().__init__() if not isinstance(patch_size, tuple): patch_size = (patch_size,) * 2 self.patch_size = patch_size self.mask_ratio = mask_ratio
def __repr__(self): return f"Maks: total patches {self.num_patches}, mask patches {self.num_mask}" def __call__(self, x): height, width = x.shape[-2:] num_patches = height // self.patch_size[0] * width // self.patch_size[1] num_mask = int(self.mask_ratio * num_patches) mask = numpy.vstack( [ numpy.zeros((num_patches - num_mask, *self.patch_size)), numpy.ones((num_mask, *self.patch_size)), ] ) numpy.random.shuffle(mask) mask = torch.from_numpy(mask) print(mask.shape, x.shape) resized = fold( mask, x.shape[-2:], self.patch_size, stride=self.patch_size, padding=0 ) return x * resized
if __name__ == "__main__": def asidj(): """description""" from cv2 import circle import numpy shuffle = RatioMaskGenerator(20, 0.8) x_ = torch.randn(100, 100, 3).numpy() * 255 # batch, c, h, w, d x_ = circle(x_, (50, 50), 40, (200, 160, 120), -1).astype(numpy.uint8) from matplotlib import pyplot pyplot.imshow(x_) pyplot.show() x_ = torch.FloatTensor(x_).permute(2, 0, 1).contiguous().unsqueeze(0) shuffled = shuffle(x_) pyplot.imshow(shuffled.squeeze(0).permute(1, 2, 0).to(dtype=torch.int)) pyplot.show() asidj()