Source code for neodroidvision.utilities.torch_utilities.patches.masking

import torch

__all__ = ["StochasticMaskGenerator", "RatioMaskGenerator"]

from torch import nn
from torch.nn.functional import unfold, fold


[docs]class StochasticMaskGenerator(nn.Module): """description"""
[docs] def __init__(self, patch_size, prob): super().__init__() self.patch_size = patch_size self.prob = prob
def __call__(self, x): unfolded = unfold( x, kernel_size=self.patch_size, stride=self.patch_size, padding=0 ) unfolded[..., torch.randn(unfolded.shape[-1]) < self.prob] = 0 folded = fold( unfolded, x.shape[-2:], kernel_size=self.patch_size, stride=self.patch_size, padding=0, ) return folded
[docs]class RatioMaskGenerator(nn.Module): """description"""
[docs] def __init__(self, patch_size, ratio): super().__init__() self.patch_size = patch_size self.ratio = ratio
def __call__(self, x): unfolded = unfold( x, kernel_size=self.patch_size, stride=self.patch_size, padding=0 ) masked_num = int(unfolded.shape[-1] * self.ratio) mask = torch.hstack( (torch.ones(masked_num), torch.zeros(unfolded.shape[-1] - masked_num)) ).to(dtype=torch.bool) index = torch.randperm(mask.shape[0]) unfolded[..., mask[index]] = 0 folded = fold( unfolded, x.shape[-2:], kernel_size=self.patch_size, stride=self.patch_size, padding=0, ) return folded
if __name__ == "__main__": def asidj(): """description""" from cv2 import circle import numpy shuffle = StochasticMaskGenerator(16, 0.8) x_ = torch.randn(100, 100, 3).numpy() * 255 # batch, c, h, w, d x_ = circle(x_, (50, 50), 40, (200, 160, 120), -1).astype(numpy.uint8) from matplotlib import pyplot pyplot.imshow(x_) pyplot.show() x_ = torch.FloatTensor(x_).permute(2, 0, 1).contiguous().unsqueeze(0) shuffled = shuffle(x_) pyplot.imshow(shuffled.squeeze(0).permute(1, 2, 0).to(dtype=torch.int)) pyplot.show() def asid22j(): """description""" from cv2 import circle import numpy shuffle = RatioMaskGenerator(16, 0.8) x_ = torch.randn(100, 100, 3).numpy() * 255 # batch, c, h, w, d x_ = circle(x_, (50, 50), 40, (200, 160, 120), -1).astype(numpy.uint8) from matplotlib import pyplot pyplot.imshow(x_) pyplot.show() x_ = torch.FloatTensor(x_).permute(2, 0, 1).contiguous().unsqueeze(0) shuffled = shuffle(x_) pyplot.imshow(shuffled.squeeze(0).permute(1, 2, 0).to(dtype=torch.int)) pyplot.show() # asidj() asid22j()