Source code for neodroidvision.utilities.visualisation.encoder_utilities

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from pathlib import Path
from typing import Optional, Tuple

import numpy
import torch
from PIL import Image
from draugr.torch_utilities import global_torch_device
from torch.nn.functional import one_hot

__author__ = "Christian Heider Nielsen"
__doc__ = r""" description """

from numpy import ndarray

from warg import Number

__all__ = [
    "compile_encoding_image",
    "sample_2d_latent_vectors",
    "plot_conditioned_manifold",
    "plot_manifold",
]


[docs]def compile_encoding_image( images: ndarray, size: Tuple, resize_factor: Number = 1.0 ) -> ndarray: """ :param images: :param size: :param resize_factor: :return:""" h, w = images.shape[1], images.shape[2] h_ = int(h * resize_factor) w_ = int(w * resize_factor) r = [] if len(images.shape) > 3: r = images.shape[3:] img = numpy.zeros((h_ * size[0], w_ * size[1], *r)) for idx, image in enumerate(images): i = int(idx % size[1]) j = int(idx / size[1]) image_ = numpy.array( Image.fromarray(image).resize((w_, h_), resample=Image.BICUBIC) ) img[j * h_ : j * h_ + h_, i * w_ : i * w_ + w_] = image_ return img
[docs]def sample_2d_latent_vectors( encoding_space: Number, n_img_x: int, n_img_y: int ) -> torch.FloatTensor: """ :param encoding_space: :param n_img_x: :param n_img_y: :return:""" return torch.FloatTensor( [ numpy.rollaxis( numpy.mgrid[ encoding_space : -encoding_space : n_img_y * 1j, encoding_space : -encoding_space : n_img_x * 1j, ], 0, 3, ).reshape([-1, 2]) ] )
[docs]def plot_conditioned_manifold( model: torch.nn.Module, condition: torch.Tensor, *, out_path: Path = None, n_img_x: int = 20, n_img_y: int = 20, img_h: int = 28, img_w: int = 28, sample_range: Number = 1, device: Optional[torch.device] = global_torch_device(), ) -> None: """ :param model: :type model: :param condition: :type condition: :param out_path: :type out_path: :param n_img_x: :type n_img_x: :param n_img_y: :type n_img_y: :param img_h: :type img_h: :param img_w: :type img_w: :param sample_range: :type sample_range: :param device: :type device: """ condition_vector = torch.arange(0, 10, device=device).long().unsqueeze(1) sample = model.sample( one_hot(condition_vector, 10).to(device=device), num=condition_vector.size(0), )
# TODO: FINISH
[docs]def plot_manifold( model: torch.nn.Module, *, out_path: Path = None, n_img_x: int = 20, n_img_y: int = 20, img_h: int = 28, img_w: int = 28, sample_range: Number = 1, device: Optional[torch.device] = global_torch_device(), ) -> None: """ :param device: :type device: :param model: :param out_path: :param n_img_x: :param n_img_y: :param img_h: :param img_w: :param sample_range: :return:""" vectors = sample_2d_latent_vectors(sample_range, n_img_x, n_img_y).to(device) encodings = torch.sigmoid(model(vectors)).to("cpu") images = encodings.reshape(n_img_x * n_img_y, img_h, img_w, -1).numpy() images *= 255 images = numpy.uint8(images) compiled = compile_encoding_image(images, (n_img_y, n_img_x)) if out_path: from imageio import imwrite imwrite(str(out_path), compiled) return compiled