Source code for neodroidvision.utilities.torch_utilities.transforms.blocky

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

__author__ = "heider"
__doc__ = r"""

           Created on 5/5/22
           """

import glob
import random
from collections.abc import Sequence

import torch
import torch.nn.functional
from PIL import Image
from torch import ceil
from torch.nn.functional import interpolate
from torchvision.transforms.functional import to_pil_image, to_tensor

from neodroidvision.utilities.torch_utilities.transforms.image_transforms import (
    BaseTorchTransform,
    tuplemulti,
)


class RandomBlock(BaseTorchTransform):
    """
    Randomly blocks out parts of an image

    Select mode from constant or rgb (random RGB values)
    Color is relevant in constant mode. Expects float (0..1) or a sequence of length C
    """
    def __init__(self, mode="constant", color=0, max_size=(1.0, 1.0)):
        super().__init__()
        self.mode = mode
        self.color = torch.tensor(color).unsqueeze(-1).unsqueeze(-1)

        # ensure max_size is a tuple/list
        if not isinstance(max_size, Sequence):
            max_size = (max_size, max_size)
        assert len(max_size) == 2, "max_size must have length 2 or be a scalar"
        assert (
            0.0 < max_size[0] <= 1.0 and 0.0 < max_size[1] <= 1.0
        ), f"max_size must be greater than zero and less than or equal to 1, got {max_size}"
        self.max_size = max_size
    def forward(self, img):
        """
        Blocks out a random rectangle of the input image

        :param img: input image, expects PIL.Image, numpy.ndarray, or torch.Tensor
        :return: the blocked-out image; a PIL.Image if the input was an image, otherwise a torch.Tensor
        """
        img, retmode = super().forward(img)

        # generate random slice indices
        bh = torch.randint(int(round(img.shape[-2] * self.max_size[0])), (1,)).item()
        bw = torch.randint(int(round(img.shape[-1] * self.max_size[1])), (1,)).item()
        y = torch.randint(img.shape[-2] - bh, (1,)).item()
        x = torch.randint(img.shape[-1] - bw, (1,)).item()

        # constant uses color argument
        if self.mode == "constant":
            img[:, y : y + bh, x : x + bw] = self.color
        elif self.mode == "rgb":
            # assumes colour channels are dimension 0
            c = torch.rand((img.shape[0], 1, 1))
            img[:, y : y + bh, x : x + bw] = c
        else:
            raise ValueError(
                f"self.mode is expected to be 'rgb' or 'constant', got '{self.mode}'"
            )

        if retmode == "image":
            return to_pil_image(img)
        return img

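
# Usage sketch (illustrative, not part of the original transform): a minimal example of
# applying RandomBlock to a PIL image. It assumes BaseTorchTransform is an
# nn.Module-style wrapper that converts PIL input to a CxHxW float tensor and reports
# retmode "image", so calling the transform on a PIL.Image returns a PIL.Image.
# The file name "example.png" is a hypothetical placeholder.
def _random_block_usage_example():
    rgb_block = RandomBlock(mode="rgb", max_size=(0.5, 0.5))
    grey_block = RandomBlock(mode="constant", color=0.5, max_size=(0.25, 0.25))

    image = Image.open("example.png").convert("RGB")  # hypothetical input image
    blocked = grey_block(rgb_block(image))  # chain the two transforms
    blocked.show()
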

class BlockyMask(BaseTorchTransform):
    """
    Randomly blocks out parts of an image

    Select mode from constant or rgb (random RGB values)
    Color is relevant in constant mode. Expects float (0..1) or a sequence of length C
    """
    def __init__(self, mode="constant", color=0, max_size=(1.0, 1.0)):
        super().__init__()
        self.mode = mode
        self.color = torch.tensor(color).unsqueeze(-1).unsqueeze(-1)

        # ensure max_size is a tuple/list
        if not isinstance(max_size, Sequence):
            max_size = (max_size, max_size)
        assert len(max_size) == 2, "max_size must have length 2 or be a scalar"
        assert (
            0.0 < max_size[0] <= 1.0 and 0.0 < max_size[1] <= 1.0
        ), f"max_size must be greater than zero and less than or equal to 1, got {max_size}"
        self.max_size = max_size
    def forward(self, img):
        """
        Blocks out a random rectangle of the input image

        :param img: input image, expects PIL.Image, numpy.ndarray, or torch.Tensor
        :return: the blocked-out image; a PIL.Image if the input was an image, otherwise a torch.Tensor
        """
        img, retmode = super().forward(img)

        # generate random slice indices
        bh = torch.randint(int(round(img.shape[-2] * self.max_size[0])), (1,)).item()
        bw = torch.randint(int(round(img.shape[-1] * self.max_size[1])), (1,)).item()
        y = torch.randint(img.shape[-2] - bh, (1,)).item()
        x = torch.randint(img.shape[-1] - bw, (1,)).item()

        # constant uses color argument
        if self.mode == "constant":
            img[:, y : y + bh, x : x + bw] = self.color
        elif self.mode == "rgb":
            # assumes colour channels are dimension 0
            c = torch.rand((img.shape[0], 1, 1))
            img[:, y : y + bh, x : x + bw] = c
        else:
            raise ValueError(
                f"self.mode is expected to be 'rgb' or 'constant', got '{self.mode}'"
            )

        if retmode == "image":
            return to_pil_image(img)
        return img


class RandomImageBlocking(BaseTorchTransform):
    """
    Inserts random .png images on top of the input image

    Expects RGBA images, but will use RGB images without warning
    """
    def __init__(self, root_dir, apply=1.0, max_size=(1.0, 1.0), rotate=True):
        """
        :param root_dir: path to the root directory of the .png blocking images
        :param apply: float between 0 and 1, the probability of applying the transform
        :param max_size: float or (float, float), the maximum size of the blocking image relative to the input image
        :param rotate: if True, rotate the blocking image by a random angle before pasting
        """
        super().__init__()
        self.images = glob.glob(root_dir + "**/*.png", recursive=True)
        self.apply = apply
        self.rotate = rotate

        # ensure max_size is a tuple/list
        if not isinstance(max_size, Sequence):
            max_size = (max_size, max_size)
        assert len(max_size) == 2, "max_size must have length 2 or be a scalar"
        assert (
            0.0 < max_size[0] <= 1.0 and 0.0 < max_size[1] <= 1.0
        ), f"max_size must be greater than zero and less than or equal to 1, got {max_size}"
        self.max_size = max_size
    def forward(self, img):
        """
        Randomly applies blocking images, likely PNGs, to the input image

        :param img: input image, expects PIL.Image, numpy.ndarray, or torch.Tensor
        :return: the occluded image; a PIL.Image if the input was an image, otherwise a torch.Tensor
        """
        # generates a single float value between 0..1 and evaluates against apply
        if torch.rand(1) > self.apply:
            return img

        img, retmode = super().forward(img)

        # load a random blocking image
        block = Image.open(random.choice(self.images))
        mode = block.mode

        # should it be randomly rotated?
        if self.rotate:
            block = block.rotate(torch.randint(360, (1,)).item())

        block = to_tensor(block).unsqueeze(0)

        # scale is max_size * random float between 0..1,
        # i.e. uniformly distributed between 0..max_size
        scale = tuplemulti(torch.rand(2).tolist(), self.max_size)
        block = interpolate(block, tuplemulti(img.shape[-2:], scale, op=ceil)).squeeze()

        # determine start positions
        y = torch.randint(img.shape[-2] - block.shape[-2], (1,)).item()
        x = torch.randint(img.shape[-1] - block.shape[-1], (1,)).item()

        # expected
        if mode == "RGBA":
            A = block[-1]
            block = block[:-1]
            img[:, y : y + block.shape[-2], x : x + block.shape[-1]] = (
                img[:, y : y + block.shape[-2], x : x + block.shape[-1]] * (1 - A)
                + block * A
            )
        # accepted
        elif mode == "RGB":
            img[:, y : y + block.shape[-2], x : x + block.shape[-1]] = block
        else:
            raise ValueError(
                f"Image.mode of blocking image is expected to be RGB(A), got {mode}"
            )

        if retmode == "image":
            return to_pil_image(img)
        return img

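
# Usage sketch (illustrative, not part of the original transform): composing
# RandomImageBlocking with torchvision transforms in an augmentation pipeline.
# "occluders/" is a hypothetical directory of RGBA .png cut-outs and "example.png"
# a hypothetical input image; the Compose assumes the transform returns a PIL.Image
# for PIL input, which ToTensor then converts back to a tensor.
def _random_image_blocking_usage_example():
    from torchvision import transforms

    augment = transforms.Compose(
        [
            RandomImageBlocking("occluders/", apply=0.5, max_size=(0.3, 0.3)),
            transforms.ToTensor(),
        ]
    )

    image = Image.open("example.png").convert("RGB")  # hypothetical input image
    augmented = augment(image)  # a CxHxW float tensor, occluded with probability 0.5
    return augmented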