Source code for neodroidvision.utilities.torch_utilities.patches.shuffle

import torch
from torch import nn
from torch.nn.functional import fold, unfold


[docs]class ShufflePatches(nn.Module): """description"""
[docs] def __init__(self, patch_size): super().__init__() self.patch_size = patch_size
def __call__(self, x): unfolded = unfold( x, kernel_size=self.patch_size, stride=self.patch_size, padding=0 ) permuted = torch.cat( [b_[:, torch.randperm(b_.shape[-1])][None, ...] for b_ in unfolded], dim=0 ) folded = fold( permuted, x.shape[-2:], kernel_size=self.patch_size, stride=self.patch_size, padding=0, ) return folded
if __name__ == "__main__": def asidj(): """description""" from cv2 import circle import numpy shuffle = ShufflePatches(16) x_ = torch.randn(100, 100, 3).numpy() * 255 # batch, c, h, w, d x_ = circle(x_, (50, 50), 40, (200, 160, 120), -1).astype(numpy.uint8) from matplotlib import pyplot pyplot.imshow(x_) pyplot.show() x_ = torch.FloatTensor(x_).permute(2, 0, 1).contiguous().unsqueeze(0) shuffled = shuffle(x_) pyplot.imshow(shuffled.squeeze(0).permute(1, 2, 0).to(dtype=torch.int)) pyplot.show() asidj()