# Source code for neodroidvision.utilities.visualisation.encoding_space
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = "Christian Heider Nielsen"
# Module docstring (assigned explicitly because __author__ precedes it).
__doc__ = """
Visualisation helpers for encoding spaces: a builder for N-bin discrete
colormaps and a scatter plot of 2-D encoding means coloured by label.
"""
from pathlib import Path
from typing import Sequence, Union
import numpy
from matplotlib import pyplot
from matplotlib.colors import Colormap, LinearSegmentedColormap
from numpy import ndarray
from warg import Number
__all__ = ["discrete_cmap", "scatter_plot_encoding_space"]
def discrete_cmap(
    N: int, base_cmap: Union[Colormap, str, None] = None
) -> LinearSegmentedColormap:
    """Create an N-bin discrete colormap sampled from ``base_cmap``.

    :param N: Number of discrete colour bins; must be at least 1.
    :param base_cmap: A ``Colormap`` instance, the name of a registered
        colormap, or ``None`` for matplotlib's default colormap.
    :return: A ``LinearSegmentedColormap`` named ``<base name><N>`` holding
        ``N`` colours sampled evenly across the base colormap.
    :raises ValueError: If ``N`` is smaller than 1.
    """
    if N < 1:
        raise ValueError(f"N must be a positive integer, got {N!r}")

    # pyplot.get_cmap accepts a Colormap, a name or None. The previously used
    # matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9.
    base = pyplot.get_cmap(base_cmap)
    # Sample N colours evenly from one end of the base colormap to the other.
    color_list = base(numpy.linspace(0, 1, N))
    cmap_name = f"{base.name}{N}"
    # Call from_list on the class rather than on `base`: ListedColormap
    # instances (e.g. "viridis", "tab10") have no from_list method.
    return LinearSegmentedColormap.from_list(cmap_name, color_list, N)
def scatter_plot_encoding_space(
    out_path: Path,
    mean: ndarray,
    log_var: ndarray,
    labels: Sequence,
    encoding_space_range: Number = 1,
    min_size_constant: Number = 2,
    N: int = 10,
):
    """Scatter-plot the first two encoding dimensions coloured by label, then save it.

    :param out_path: Destination handed to ``Figure.savefig``.
    :param mean: Array indexed as ``mean[:, 0]`` / ``mean[:, 1]`` for the x/y
        position of each point, so it needs at least two columns.
    :param log_var: Array averaged over its last axis; the absolute value of
        that average plus ``min_size_constant`` sets each marker's size.
    :param labels: Per-point values mapped through the discrete colormap;
        presumably integer class ids in ``[0, N)`` -- TODO confirm with callers.
    :param encoding_space_range: Both axes are limited to
        ``[-encoding_space_range, encoding_space_range]``.
    :param min_size_constant: Offset added to every marker size so that points
        with a near-zero mean log-variance remain visible.
    :param N: Number of discrete colours and colorbar ticks.
    :return: The created figure. It is not closed here, so the caller owns it.
    """
    sizes = numpy.abs(log_var.mean(-1)) + min_size_constant

    fig = pyplot.figure(figsize=(8, 6))
    axes = fig.add_subplot(1, 1, 1)
    points = axes.scatter(
        mean[:, 0],
        mean[:, 1],
        s=sizes,
        c=labels,
        marker="o",
        edgecolor="none",
        cmap=discrete_cmap(N, "jet"),
        # Pin the colour normalisation to the N label bins. Otherwise matplotlib
        # normalises to min(labels)..max(labels), so a batch lacking some
        # classes gets shifted colours that disagree with the colorbar ticks.
        vmin=-0.5,
        vmax=N - 0.5,
    )
    # With the limits above, tick k sits at the centre of the bin for label k.
    fig.colorbar(points, ax=axes, ticks=range(N))
    axes.set_xlim([-encoding_space_range, encoding_space_range])
    axes.set_ylim([-encoding_space_range, encoding_space_range])
    axes.grid(True)
    fig.savefig(out_path)
    return fig