# Source code for neodroidvision.utilities.visualisation.plot_kernel

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

__author__ = "Christian Heider Nielsen"
__doc__ = r"""

           Created on 29/07/2020
           """

__all__ = ["plot_kernels"]

from enum import Enum

from matplotlib import pyplot
from sorcery import assigned_names
from torch import Tensor


class CmapEnum(Enum):  # TODO: Add more
    """Colormap names; each member's value is its own name as a string."""

    gray = "gray"
    binary = "binary"
    viridis = "viridis"


class InterpolationEnum(Enum):
    """Interpolation method names; each member's value is its own name as a string."""

    none = "none"
    antialiased = "antialiased"
    nearest = "nearest"
    bilinear = "bilinear"
    bicubic = "bicubic"
    spline16 = "spline16"
    spline36 = "spline36"
    hanning = "hanning"
    hamming = "hamming"
    hermite = "hermite"
    kaiser = "kaiser"
    quadric = "quadric"
    catrom = "catrom"
    gaussian = "gaussian"
    bessel = "bessel"
    mitchell = "mitchell"
    sinc = "sinc"
    lanczos = "lanczos"
    blackman = "blackman"


def plot_kernels(
    tensor: Tensor,
    number_cols: int = 5,
    m_interpolation: InterpolationEnum = InterpolationEnum.bilinear,
) -> None:
    """
    Function to visualize the kernels.

    Draws the first channel of every kernel in ``tensor`` as a gray-scale image on a
    grid of subplots in a newly created pyplot figure. The figure is neither shown
    nor returned.

    Arguments:
        tensor: kernels to draw. Entry ``i`` is indexed as ``tensor[i][0, :, :]``, so
            every entry along the first axis must be 3-dimensional (presumably
            (out_channels, in_channels, height, width) - confirm against callers).
        number_cols: number of columns to be displayed
        m_interpolation: interpolation methods matplotlib. See in:
            https://matplotlib.org/gallery/images_contours_and_fields/interpolation_methods.html
    """
    number_kernels = tensor.shape[0]

    # Ceiling division: exactly as many rows as are needed. The previous
    # `1 + n // cols` produced a blank extra row whenever n was a multiple of cols.
    # At least one row is kept so that an empty tensor still gives a valid figsize.
    number_rows = max(1, (number_kernels + number_cols - 1) // number_cols)

    fig = pyplot.figure(figsize=(number_cols, number_rows))

    # matplotlib converts its input through numpy, which fails for tensors that
    # require grad or that live on an accelerator; detach and move to CPU first.
    # Non-Tensor array-likes are passed through untouched.
    kernels = tensor.detach().cpu() if isinstance(tensor, Tensor) else tensor

    for i in range(number_kernels):
        ax = fig.add_subplot(number_rows, number_cols, i + 1)  # subplot indices are 1-based
        ax.imshow(
            kernels[i][0, :, :],
            interpolation=m_interpolation.value,
            cmap=CmapEnum.gray.value,
        )
        ax.axis("off")