# Source code for mpt4py.geometry.union.polyunion

import scipy.sparse as sp
import numpy as np
from typing import Optional, Union
from matplotlib.axes import Axes
from pyvista import Plotter
from mpt4py.geometry.polyhedron import Polyhedron
from mpt4py.geometry.union.union import ConvexSetUnion
from itertools import combinations


class PolyUnion(ConvexSetUnion):
    """Represent a union of polyhedra (represented in the same dimension).

    Structural properties of the union can be supplied by the caller when they
    are known a priori. Supplied values are stored as-is, without verification,
    in ``self._internal``. A property left as ``None`` is computed lazily on
    first access, where a computation is implemented, and then cached.

    Args:
        *polyhedra: The polyhedra forming the union.
        convex: Whether the union is known to be convex (``None`` if unknown).
        overlaps: Whether the polyhedra are known to overlap (``None`` if unknown).
        connected: Whether the union is known to be connected (``None`` if unknown).
        bounded: Boundedness flag applied to every polyhedron (``None`` if unknown).
        fulldim: Full-dimensionality flag applied to every polyhedron
            (``None`` if unknown).

    Raises:
        TypeError: If any element of ``polyhedra`` is not a ``Polyhedron``.
        AssertionError: If a property flag is neither a ``bool`` nor ``None``.
    """

    def __init__(self, *polyhedra: Polyhedron, convex: Optional[bool] = None,
                 overlaps: Optional[bool] = None, connected: Optional[bool] = None,
                 bounded: Optional[bool] = None, fulldim: Optional[bool] = None):
        if not all(isinstance(polyhedron, Polyhedron) for polyhedron in polyhedra):
            raise TypeError("All elements must be instances of Polyhedron.")
        super().__init__(*polyhedra)
        # `assert` (rather than raising TypeError) is kept so existing callers
        # see the same exception type; note that it is stripped under `python -O`.
        for flag_name, flag in (("convex", convex), ("overlaps", overlaps),
                                ("connected", connected), ("bounded", bounded),
                                ("fulldim", fulldim)):
            assert isinstance(flag, bool) or flag is None, f"{flag_name} must be a boolean or None"
        n_sets = len(self)
        self._internal = {
            "is_overlapping": overlaps,
            "is_connected": connected,
            "is_convex": convex,
            # Per-polyhedron flags, kept index-aligned with self._cvxsets by add/remove.
            "is_full_dim": [fulldim] * n_sets,
            "is_bounded": [bounded] * n_sets,
        }

    @property
    def is_convex(self) -> bool:
        r"""Determine if the union of polyhedra is convex.

        A value supplied through the ``convex`` constructor argument is returned
        as-is. A union of at most one polyhedron is trivially convex.

        .. warning:: This method is very computationally demanding and is
            suitable for unions with small number of polyhedra.

        Raises:
            NotImplementedError: If convexity is unknown and the union has more
                than one polyhedron; the general check is not implemented yet.
        """
        if self._internal["is_convex"] is None:
            if len(self) > 1:
                raise NotImplementedError(
                    "Convexity check for unions of several polyhedra is not implemented.")
            self._internal["is_convex"] = True
        return self._internal["is_convex"]

    @property
    def is_bounded(self) -> bool:
        """Determine if the union is built from bounded polyhedra.

        Only the per-polyhedron flags that are still unknown (``None``) are
        computed; known flags are reused.
        """
        flags = self._internal["is_bounded"]
        if None in flags:
            self._internal["is_bounded"] = [
                cvxset.is_bounded if flag is None else flag
                for flag, cvxset in zip(flags, self._cvxsets)
            ]
        return all(self._internal["is_bounded"])

    @property
    def is_connected(self) -> bool:
        """Determine if the union of polyhedra form a connected union.

        A value supplied through the ``connected`` constructor argument is
        returned as-is. A union of at most one (convex) polyhedron is trivially
        connected.

        Raises:
            NotImplementedError: If connectivity is unknown and the union has
                more than one polyhedron; the general check is not implemented yet.
        """
        if self._internal["is_connected"] is None:
            if len(self) > 1:
                raise NotImplementedError(
                    "Connectivity check for unions of several polyhedra is not implemented.")
            self._internal["is_connected"] = True
        return self._internal["is_connected"]

    @property
    def is_full_dim(self) -> bool:
        """Determine if the union is built from full-dimensional polyhedra.

        Only the per-polyhedron flags that are still unknown (``None``) are
        computed; known flags are reused.
        """
        flags = self._internal["is_full_dim"]
        if None in flags:
            self._internal["is_full_dim"] = [
                cvxset.is_full_dim if flag is None else flag
                for flag, cvxset in zip(flags, self._cvxsets)
            ]
        return all(self._internal["is_full_dim"])

    @property
    def is_overlapping(self) -> bool:
        """Determine if the union of polyhedra is overlapping.

        .. note:: Every pair of polyhedra is tested, and the search stops at the
            first overlapping pair:

            1. If both polyhedra are full-dimensional, they overlap when their
               intersection is full-dimensional.
            2. Otherwise (at least one is low-dimensional), they overlap when
               their intersection is not empty.

            The result is cached until the union is modified via ``add`` or
            ``remove``.

        .. warning:: This method is computationally demanding and is suitable
            for unions with small number of polyhedra.
        """
        if self._internal["is_overlapping"] is None:
            # any() short-circuits on the first overlapping pair; with fewer than
            # two polyhedra there are no pairs and the result is False.
            self._internal["is_overlapping"] = any(
                self._pair_overlaps(set1, set2)
                for set1, set2 in combinations(self._cvxsets, 2)
            )
        return self._internal["is_overlapping"]

    @staticmethod
    def _pair_overlaps(set1: Polyhedron, set2: Polyhedron) -> bool:
        """Return True if two polyhedra overlap, using the rules documented in ``is_overlapping``."""
        if set1.is_full_dim and set2.is_full_dim:
            return bool(set1.intersect(set2).is_full_dim)
        return not set1.intersect(set2).is_empty

    def remove(self, index: int):
        """Remove the polyhedron at the specified index.

        Raises:
            IndexError: If ``index`` is negative or not smaller than the number
                of polyhedra (negative indexing is not supported).
        """
        if index < 0 or index >= len(self._cvxsets):
            raise IndexError("Index out of range.")
        del self._cvxsets[index]
        self._internal["is_bounded"].pop(index)
        self._internal["is_full_dim"].pop(index)
        # Removing a polyhedron cannot create an overlap, so a known
        # "no overlap" result stays valid; anything else must be recomputed.
        if self._internal["is_overlapping"] is not False:
            self._internal["is_overlapping"] = None
        self._internal["is_connected"] = None
        # Convexity may change either way, so the cached value is dropped.
        self._internal["is_convex"] = None

    def add(self, polyhedron: Polyhedron):
        """Add a polyhedron to the union.

        Per-polyhedron flags of the new element are left unknown and are
        computed on demand by ``is_bounded`` / ``is_full_dim``. Union-level
        cached properties are invalidated.

        Raises:
            TypeError: If ``polyhedron`` is not a ``Polyhedron``.
        """
        if not isinstance(polyhedron, Polyhedron):
            raise TypeError("The element must be an instance of Polyhedron.")
        # NOTE(review): no check that `polyhedron` has the same dimension as the
        # existing polyhedra; confirm whether ConvexSetUnion enforces this.
        self._cvxsets.append(polyhedron)
        # Defer the per-polyhedron checks until queried, as __init__ does when
        # no flag is supplied.
        self._internal["is_bounded"].append(None)
        self._internal["is_full_dim"].append(None)
        self._internal["is_overlapping"] = None
        self._internal["is_connected"] = None
        self._internal["is_convex"] = None

    def convex_hull(self) -> Polyhedron:
        """Compute the convex hull for union of polyhedra.

        The convex hull of the union of polyhedra is defined as the minimal
        convex set that contains all polyhedra.

        Returns:
            Polyhedron: the convex hull of the union of polyhedra.

        Raises:
            NotImplementedError: Always; not implemented yet.
        """
        raise NotImplementedError

    def merge(self):
        """Simplify the union of polyhedra by merging neighboring polyhedra whose union is convex.

        The intended algorithm cycles through the regions and checks whether any
        two regions form a convex union. If so, it combines them into one region
        and continues checking the remaining regions.

        Raises:
            NotImplementedError: Always; not implemented yet.
        """
        raise NotImplementedError

    @staticmethod
    def _per_polyhedron_kwargs(count: int, kwargs: dict) -> list:
        """Return one kwargs dict per polyhedron, adding a cycled matplotlib color unless ``color`` is given."""
        if 'color' in kwargs:
            return [kwargs] * count
        import matplotlib.pyplot as plt
        color_order = plt.rcParams['axes.prop_cycle'].by_key()['color']
        return [{**kwargs, 'color': color_order[i % len(color_order)]} for i in range(count)]

    def plot(self, ax: Union[Axes, Plotter], **kwargs):
        r"""Plot every polyhedron of the union.

        Each polyhedron is drawn through the backend's ``plot_convexhull`` from
        its ``V`` and ``R`` attributes. Without an explicit ``color`` keyword,
        polyhedra are colored by cycling through matplotlib's
        ``axes.prop_cycle`` colors.

        .. Note:: Require V-representation. Will be computed if necessary.

        Args:
            ax: A matplotlib ``Axes`` or a pyvista ``Plotter``.
            **kwargs: Forwarded to the backend's ``plot_convexhull``.

        Raises:
            ValueError: If the union is not 2- or 3-dimensional.
            NotImplementedError: If ``ax`` is neither a matplotlib ``Axes`` nor
                a pyvista ``Plotter``.
        """
        if self.dim not in [2, 3]:
            raise ValueError("Can only plot 2D or 3D polyhedra.")
        if isinstance(ax, Axes):
            from mpt4py.geometry.visualization.plot_matplotlib import MatplotlibPlotter
            plotter = MatplotlibPlotter(ax)
        elif isinstance(ax, Plotter):
            from mpt4py.geometry.visualization.plot_pyvista import PyvistaPlotter
            plotter = PyvistaPlotter(ax)
        else:
            raise NotImplementedError("Unsupported backend. Supported backends are matplotlib and pyvista.")
        styles = self._per_polyhedron_kwargs(len(self._cvxsets), kwargs)
        for polyhedron, style in zip(self._cvxsets, styles):
            plotter.plot_convexhull(polyhedron.V, polyhedron.R, **style)

    def fplot(self, ax: Union[Axes, Plotter], func_name: Optional[str] = None, **kwargs):
        """Plot the functions associated with the polyunion.

        Delegates to each polyhedron's ``fplot`` with ``func_name``. Without an
        explicit ``color`` keyword, polyhedra are colored by cycling through
        matplotlib's ``axes.prop_cycle`` colors.

        Raises:
            NotImplementedError: If the union is not 2-dimensional.
        """
        if self.dim != 2:
            raise NotImplementedError("Function plotting is only implemented for 2D polyunions.")
        # TODO: also need the output of the function to be dimension 1
        styles = self._per_polyhedron_kwargs(len(self._cvxsets), kwargs)
        for polyhedron, style in zip(self._cvxsets, styles):
            polyhedron.fplot(ax, func_name, **style)
# def toC(self, filename="polyunion_data.c", var_name="polyunion"): # """ # Generate C code for the union of polyhedra. # """ # string = "#define NUM_REGION {}\n".format(self.__len__()) # # string_A = '' # string_b = '' # string_m = '' # number of constraints # for poly in self._cvxsets: # assert isinstance(poly, Polyhedron) # string_A += poly.A.flatten() # # with open(filename, "w") as f: # f.write("#include <stdbool.h>\n") # f.write("#include <math.h>\n\n") # # f.write("// check if Ax <= b holds\n") # f.write("bool point_in_polyhedron(const double* x, const double* A, const double* b, int m, int n) {\n") # f.write(" for (int i = 0; i < m; ++i) {\n") # f.write(" double sum = 0;\n") # f.write(" for (int j = 0; j < n; ++j) {\n") # f.write(" sum += A[i*n + j] * x[j];\n") # f.write(" }\n") # f.write(" if (sum > b[i] + 1e-9) return false;\n") # f.write(" }\n") # f.write(" return true;\n") # f.write("}\n\n") # # poly_count = len(self._cvxsets) # # for i, poly in enumerate(zip(range(self.__len__()), self._cvxsets)): # assert isinstance(poly, Polyhedron) # A_flat = poly.A.flatten() # m = poly.A.shape[0] # f.write(f"// Polyhedron {i}\n") # f.write(f"const double {var_name}_A_{i}[] = {{ " + ", ".join(map(str, A_flat)) + " };\n") # f.write(f"const double {var_name}_b_{i}[] = {{ " + ", ".join(map(str, b)) + " };\n\n") # # f.write(f"int find_polyhedron(const double* x) {{\n") # for i, (A, b) in enumerate(polyunion): # m = A.shape[0] # f.write(f" if (point_in_polyhedron(x, {var_name}_A_{i}, {var_name}_b_{i}, {m}, {dim})) return {i};\n") # f.write(" return -1; // not found\n") # f.write("}\n") # # print(f"C code written to {filename}")