Source code for neodroidvision.utilities.torch_utilities.patches.ndim.sampling

import torch

__all__ = ["mask_patches"]


[docs]def mask_patches(x, prob): """ :param x: :type x: :param prob: :type prob: :return: :rtype: """ prob = torch.randn(x.shape[:2]) < prob x[prob] = torch.zeros(x.shape[2:], dtype=torch.int) return x