import os
import numpy
from matplotlib import cm, patches, pyplot
# Public plotting entry points of this module; plot_sphere is not exported.
__all__ = ["visualise_3d_gmm", "visualise_2D_gmm"]
def plot_sphere(
    w=0, c=(0, 0, 0), r=(1, 1, 1), sub_divisions=10, ax=None, sigma_multiplier=3
):
    """
    plot a (possibly axis-scaled) sphere surface, i.e. an axis-aligned ellipsoid
    Input:
        w: scalar, value that selects the surface colour from the "jet" colormap,
            mapped on a fixed [0, 1] scale (e.g. a gmm component weight)
        c: 3 elements list, sphere center
        r: 3 element list, sphere original scale in each axis (allowing to draw ellipsoids)
        sub_divisions: scalar, number of subdivisions (sub_divisions^2 points sampled on the surface)
        ax: optional pyplot 3D axis object to plot the sphere in; when omitted a new
            figure with a 3D subplot is created
        sigma_multiplier: sphere additional scale (choosing an std value when plotting gaussians)
    Output:
        ax: pyplot axis object
    """
    if ax is None:
        fig = pyplot.figure()
        ax = fig.add_subplot(111, projection="3d")

    # Spherical coordinates: polar angle phi in [0, pi] and azimuth theta in
    # [0, 2*pi], each sampled with `sub_divisions` points (endpoints included).
    phi, theta = numpy.mgrid[
        0.0 : numpy.pi : complex(0, sub_divisions),
        0.0 : 2.0 * numpy.pi : complex(0, sub_divisions),
    ]
    x = sigma_multiplier * r[0] * numpy.sin(phi) * numpy.cos(theta) + c[0]
    y = sigma_multiplier * r[1] * numpy.sin(phi) * numpy.sin(theta) + c[1]
    z = sigma_multiplier * r[2] * numpy.cos(phi) + c[2]

    # A fresh ScalarMappable autoscales its norm to the single value it is given,
    # collapsing vmin == vmax == w and mapping every w to the bottom of the
    # colormap. Pin the colour limits to [0, 1] so w actually selects the colour.
    mappable = cm.ScalarMappable(cmap="jet")
    mappable.set_clim(0.0, 1.0)
    color = mappable.to_rgba(w)

    ax.plot_surface(x, y, z, color=color, alpha=0.2, linewidth=1)
    return ax
def visualise_3d_gmm(points, w, mu, std_dev, export=False):
    """
    plots points and their corresponding gmm model in 3D
    Input:
        points: N X 3, sampled points. Assumed grouped by component: the first
            contiguous block of rows belongs to component 0, the next to
            component 1, and so on
        w: n_components, gmm weights (select the colour of each component's ellipsoid)
        mu: 3 X n_components, gmm means
        std_dev: 3 X n_components, gmm standard deviation (assuming diagonal covariance matrix)
        export: if True, also save the figure to images/3D_GMM_demonstration.png
    Output:
        None
    """
    n_components = mu.shape[1]
    n_points = points.shape[0]
    # Points are split into n_components contiguous groups of roughly equal size.
    per_component = int(numpy.round(n_points / n_components))

    # Visualize data
    fig = pyplot.figure(figsize=(8, 8))
    axes = fig.add_subplot(111, projection="3d")
    axes.set_xlim([-1, 1])
    axes.set_ylim([-1, 1])
    axes.set_zlim([-1, 1])
    colors = cm.Set1(numpy.linspace(0, 1, n_components))
    for i in range(n_components):
        start = i * per_component
        # Slicing stays in bounds when the rounded group size overshoots, and the
        # last group takes any remainder so no points are silently dropped.
        stop = n_points if i == n_components - 1 else (i + 1) * per_component
        group = points[start:stop]
        # `color=` (not `c=`): a single RGBA row passed as `c` is colour-mapped
        # as per-point values when a group happens to hold 3 or 4 points.
        axes.scatter(
            group[:, 0], group[:, 1], group[:, 2], alpha=0.3, color=colors[i]
        )
        plot_sphere(w=w[i], c=mu[:, i], r=std_dev[:, i], ax=axes)
    axes.set_title("3D GMM")
    axes.set_xlabel("X")
    axes.set_ylabel("Y")
    axes.set_zlabel("Z")
    axes.view_init(35.246, 45)
    if export:
        os.makedirs("images/", exist_ok=True)
        fig.savefig("images/3D_GMM_demonstration.png", dpi=100, format="png")
    pyplot.show()
def visualise_2D_gmm(points, w, mu, std_dev, export=False):
    """
    plots points and their corresponding gmm model in 2D
    Input:
        points: N X 2, sampled points. Assumed grouped by component: the first
            contiguous block of rows belongs to component 0, the next to
            component 1, and so on
        w: n_components, gmm weights (currently unused; kept for signature parity
            with visualise_3d_gmm)
        mu: 2 X n_components, gmm means
        std_dev: 2 X n_components, gmm standard deviation (assuming diagonal covariance matrix)
        export: if True, also save the figure to images/2D_GMM_demonstration.png
    Output:
        None
    """
    n_components = mu.shape[1]
    n_points = points.shape[0]
    # Points are split into n_components contiguous groups of roughly equal size.
    per_component = int(numpy.round(n_points / n_components))

    # Visualize data
    fig = pyplot.figure(figsize=(8, 8))
    axes = fig.gca()
    axes.set_xlim([-1, 1])
    axes.set_ylim([-1, 1])
    colors = cm.Set1(numpy.linspace(0, 1, n_components))
    for i in range(n_components):
        start = i * per_component
        # Slicing stays in bounds when the rounded group size overshoots, and the
        # last group takes any remainder so no points are silently dropped.
        stop = n_points if i == n_components - 1 else (i + 1) * per_component
        group = points[start:stop]
        # `color=` (not `c=`): a single RGBA row passed as `c` is colour-mapped
        # as per-point values when a group happens to hold 3 or 4 points.
        axes.scatter(group[:, 0], group[:, 1], alpha=0.3, color=colors[i])
        # Eight concentric, increasingly transparent blue rings around the mean.
        # Ellipse width/height are full axis lengths, so ring j reaches
        # (j + 1) / 2 standard deviations from the mean along each axis.
        for j in range(8):
            axes.add_patch(
                patches.Ellipse(
                    mu[:, i],
                    width=(j + 1) * std_dev[0, i],
                    height=(j + 1) * std_dev[1, i],
                    fill=False,
                    color=[0.0, 0.0, 1.0, 1.0 / (0.5 * j + 1)],
                )
            )
    axes.set_title("GMM")
    axes.set_xlabel("X")
    axes.set_ylabel("Y")
    if export:
        os.makedirs("images/", exist_ok=True)
        fig.savefig("images/2D_GMM_demonstration.png", dpi=100, format="png")
    pyplot.show()