"""
Matplotlib plotting backend for MPT.

Implements the PlotterProtocol.
"""
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from typing import Optional
import numpy as np
from mpt4py.base import Matrix, Vector
from mpt4py.tolerances import tolerances
from .plot import PolyhedralPlottingData, PlotterProtocol
class MatplotlibPlotter(PlotterProtocol):
    """Matplotlib implementation of ``PlotterProtocol``.

    Draws 2D/3D convex hulls and ellipsoids into a single matplotlib ``Axes``.

    Attributes:
        fig: The figure that owns ``ax``.
        ax: The axes drawn into. When no axes is passed to the constructor, a
            2D axes is created eagerly. The first time 3D data is plotted, that
            axes is swapped for a 3D one, provided it is still empty.
        is_3d: Whether the most recent plot call drew 3D geometry.
        polyhedron_connectivity: Face-connectivity list of the last hull passed
            to ``plot_convexhull`` (stored only; not used for drawing here).
    """

    def __init__(self, ax: Optional[Axes] = None):
        if isinstance(ax, Axes):
            self.ax = ax
            self.fig = ax.figure
            # The caller owns this axes, so it is never replaced behind their back.
            self._owns_ax = False
        else:
            self.fig = plt.figure()
            self.ax = self.fig.add_subplot(111)
            # Created here; may be swapped for a 3D axes while still empty.
            self._owns_ax = True
        self.is_3d = False  # to be determined from plotting data

    def _ensure_axes(self, is_3d: bool, kind: str = 'polyhedron') -> Axes:
        """Return an axes able to draw 2D or 3D geometry, creating one if needed.

        Args:
            is_3d: Whether the geometry about to be drawn is 3D.
            kind: Name of the geometry, used in the error message.

        Returns:
            The axes to draw into. This is also stored in ``self.ax``.

        Raises:
            RuntimeError: If 3D drawing is requested on a 2D axes that was
                supplied by the caller or already contains data.
        """
        if self.ax is None:
            if self.fig is None:
                self.fig = plt.figure()
            if is_3d:
                self.ax = self.fig.add_subplot(111, projection='3d')
            else:
                self.ax = self.fig.add_subplot(111)
            self._owns_ax = True
        elif is_3d and not isinstance(self.ax, Axes3D):
            if getattr(self, '_owns_ax', False) and not self.ax.has_data():
                # The default 2D axes made in __init__ holds nothing yet, so it
                # can be replaced by a 3D axes without losing any content.
                fig = self.ax.figure
                fig.delaxes(self.ax)
                self.ax = fig.add_subplot(111, projection='3d')
                self.fig = fig
            else:
                raise RuntimeError(f"The {kind} is 3D but the axes is 2D. Consider using a 2D axes or projecting the {kind} to 2D.")
        return self.ax

    def plot_convexhull(self, points: Matrix, rays: Optional[Matrix] = None,
                        color: Optional[str] = 'blue',
                        opacity: Optional[float] = 1.,
                        show_edges: Optional[bool] = True,
                        edge_width: Optional[float] = 1.,
                        edge_color: Optional[str] = 'black',
                        show_vertices: Optional[bool] = False,
                        vertex_size: Optional[float] = 20.,
                        vertex_color: Optional[str] = 'black',
                        **kwargs):
        """Plot the convex hull of ``points``.

        Args:
            points: Point coordinates, one point per row. The hull is treated as
                3D when there are 3 columns and as 2D otherwise (presumably 2
                columns; validation is left to ``PolyhedralPlottingData``).
            rays: Recession directions. Only ``None`` or an empty array is supported.
            color: Facet fill color.
            opacity: Facet alpha. It applies to facets only; edges stay opaque.
            show_edges: Whether to draw facet edges.
            edge_width: Edge line width.
            edge_color: Edge color.
            show_vertices: Whether to scatter the hull vertices.
            vertex_size: Vertex marker size.
            vertex_color: Vertex marker color.
            **kwargs: Extra facet-artist properties, forwarded to ``plt.Polygon``
                in 2D or to ``Poly3DCollection`` in 3D. In 3D, the keys ``zs``,
                ``zdir`` and ``autolim`` go to ``Axes3D.add_collection3d`` instead.

        Raises:
            NotImplementedError: If ``rays`` is non-empty.
            RuntimeError: If the data is 3D but the axes is a 2D axes that was
                supplied by the caller or already holds data.
        """
        import matplotlib.colors as mcolors

        points = np.asarray(points)
        if rays is not None and np.size(rays):
            raise NotImplementedError("Rays are not supported in matplotlib plotting.")
        # Normalize the (necessarily empty) rays to a well-shaped 2D array so the
        # 2D -> 3D lift below can stack a zero column onto it.
        rays = np.empty((0, points.shape[1]))
        data = PolyhedralPlottingData(points, rays)

        connectivity = []
        for facet in data.facets:
            connectivity.extend([len(facet), *facet])
        self.polyhedron_connectivity = [len(connectivity) + 1, len(data.facets), *connectivity]

        self.is_3d = data.points.shape[1] == 3
        ax = self._ensure_axes(self.is_3d)
        if not self.is_3d and isinstance(ax, Axes3D):
            # The polyhedron is 2D but the axes is 3D: lift it into the z = 0 plane.
            self.is_3d = True
            data = PolyhedralPlottingData(np.hstack([points, np.zeros((points.shape[0], 1))]),
                                          np.hstack([rays, np.zeros((rays.shape[0], 1))]))

        # Vertices are drawn before facets so that opaque facets cover them.
        if show_vertices:
            columns = data.vertices.T[:3 if self.is_3d else 2]
            ax.scatter(*columns, s=vertex_size, marker='o', color=vertex_color)

        # The opacity lives in the face RGBA only. No collection-wide ``alpha`` is
        # passed, because it would also override the (opaque) edge color.
        face_rgba = mcolors.to_rgba(color, alpha=opacity)
        edge_rgba = mcolors.to_rgba(edge_color, alpha=1.0) if show_edges else 'none'

        if self.is_3d:
            add_kwargs = {key: kwargs.pop(key) for key in ('zs', 'zdir', 'autolim') if key in kwargs}
            had_data = ax.has_data()
            facets3d = [data.points[facet] for facet in data.facets]
            collection = Poly3DCollection(facets3d, facecolor=face_rgba, edgecolor=edge_rgba,
                                          linewidth=edge_width, **kwargs)
            ax.add_collection3d(collection, **add_kwargs)
            if add_kwargs.get('autolim', True):
                # Not every matplotlib version autoscales in add_collection3d();
                # grow the limits explicitly (union with existing data, if any).
                pts = data.points
                ax.auto_scale_xyz(pts[:, 0], pts[:, 1], pts[:, 2], had_data)
        else:
            polygon = plt.Polygon(data.vertices, facecolor=face_rgba, edgecolor=edge_rgba,
                                  linewidth=edge_width, **kwargs)
            ax.add_patch(polygon)
            # add_patch() updates the data limits but not the view, so rescale the
            # view. relim() is deliberately not called: it ignores collections
            # (e.g. scatter plots) and would drop the limits they contributed.
            ax.autoscale()

    def plot_ellipsoid(self,
                       P: Matrix,
                       c: Optional[Vector] = None,
                       r: float = 1.0,
                       color: Optional[str] = 'lightblue',
                       opacity: Optional[float] = 1.,
                       show_edges: Optional[bool] = False,
                       line_width: Optional[float] = 2,
                       edge_color: Optional[str] = 'black'):
        """Plot the ellipsoid whose boundary is ``(x - c)^T P (x - c) = r^2``.

        Eigen-directions of ``P`` whose eigenvalue magnitude is at most
        ``tolerances['zero']`` get radius 0, which yields a flattened
        (lower-dimensional) ellipsoid.

        Args:
            P: Symmetric positive semidefinite 2x2 or 3x3 shape matrix.
            c: Center. Defaults to the origin.
            r: Level-set radius.
            color: Fill/surface color.
            opacity: Alpha of the fill/surface.
            show_edges: Whether to draw the outline (2D) or mesh lines (3D).
            line_width: Line width used when ``show_edges`` is True.
            edge_color: Line color used when ``show_edges`` is True.

        Raises:
            AssertionError: If ``P`` is not square 2x2/3x3, or if ``c`` has the
                wrong dimension.
            ValueError: If ``P`` has a clearly negative eigenvalue.
            RuntimeError: If the ellipsoid is 3D but the axes is a 2D axes that
                was supplied by the caller or already holds data.
        """
        P = np.asarray(P)
        assert P.ndim == 2 and P.shape[0] == P.shape[1], "P must be a square matrix."
        assert P.shape[0] in [2, 3], "P must be a 2x2 or 3x3 matrix."
        # Flatten so that a column-shaped center such as (n, 1) broadcasts correctly below.
        c = np.zeros(P.shape[0]) if c is None else np.asarray(c, dtype=float).reshape(-1)
        assert c.shape[0] == P.shape[0], "c must have the same dimension as P."
        self.is_3d = P.shape[0] == 3

        eigvals, eigvecs = np.linalg.eigh(P)  # P = Q * Lambda * Q^T
        zero_tol = tolerances['zero']
        if np.any(eigvals < -zero_tol):
            # A negative eigenvalue would make r / sqrt(lambda) NaN.
            raise ValueError("P must be positive semidefinite.")
        # Near-zero eigenvalues get radius 0, which handles lower-dimensional ellipsoids.
        radii = np.array([r / np.sqrt(lam) if lam > zero_tol else 0.0 for lam in eigvals])

        num = 100
        if self.is_3d:
            u = np.linspace(0, 2 * np.pi, num)
            v = np.linspace(0, np.pi, num)
            unit_sphere = np.stack([np.outer(np.cos(u), np.sin(v)),
                                    np.outer(np.sin(u), np.sin(v)),
                                    np.outer(np.ones_like(u), np.cos(v))], axis=0)
            # Scale by the radii, rotate into the eigenbasis, translate to c.
            xyz = eigvecs @ np.diag(radii) @ unit_sphere.reshape(3, -1) + c[:, np.newaxis]
            xyz = xyz.reshape(3, num, num)
            ax = self._ensure_axes(True, kind='ellipsoid')
            plt.sca(ax)
            ax.plot_surface(xyz[0], xyz[1], xyz[2], rstride=4, cstride=4, color=color, alpha=opacity,
                            linewidth=line_width if show_edges else 0,
                            edgecolor=edge_color if show_edges else 'none')
        else:
            u = np.linspace(0, 2 * np.pi, num)
            unit_circle = np.stack([np.cos(u), np.sin(u)], axis=0)
            xy = eigvecs @ np.diag(radii) @ unit_circle + c[:, np.newaxis]
            ax = self._ensure_axes(False, kind='ellipsoid')
            if show_edges:
                ax.fill(xy[0], xy[1], facecolor=color, alpha=opacity, linewidth=line_width, edgecolor=edge_color)
            else:
                ax.fill(xy[0], xy[1], facecolor=color, alpha=opacity, edgecolor='none')

    def show(self):
        """Display all open figures via ``plt.show()``.

        Raises:
            ValueError: If there is no axes to show.
        """
        if self.ax is None:
            raise ValueError("No plot to show.")
        plt.show()