Source code for neodroidvision.utilities.torch_utilities.layers.torch_layers

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Tuple

import torch

__author__ = "Christian Heider Nielsen"
__doc__ = r"""

           Created on 08/10/2019
           """

from torch.nn import Module

__all__ = ["MinMaxNorm", "Reshape"]


[docs]class MinMaxNorm(Module): """description"""
[docs] def __init__(self, min_value: float = 0, max_value: float = 1): """ :param min_value: :param max_value:""" super().__init__() self.min_value = min_value self.max_value = max_value
def __call__(self, tensor: torch.Tensor) -> torch.Tensor: """ :param tensor: :return:""" min_tensor = tensor.min() tensor -= min_tensor max_tensor = tensor.max() tensor /= max_tensor return tensor * (self.max_value - self.min_value) + self.min_value
[docs]class Reshape(Module): """ Reshaping Layer"""
[docs] def __init__(self, new_size: Tuple[int, ...]): """ :param new_size:""" super().__init__() self.new_size = new_size
def __call__(self, img: torch.Tensor) -> torch.Tensor: """ :param img: :return:""" return torch.reshape(img, self.new_size)