import numpy as np
import cvxpy as cp
from typing import List, Tuple, Optional, Union
import warnings
from mpt4py.geometry import Polyhedron
from mpt4py.systems import LTISystem
from mpt4py.controllers.base_controller import ControllerBase
from mpt4py.base import Matrix, Vector
from mpt4py.util import is_positive_definite, is_positive_semidefinite
class MPCController(ControllerBase):
    """
    Model Predictive Controller for a discrete-time LTI system.

    Each call to :meth:`evaluate` solves the finite-horizon problem

        min   sum_{k=0}^{N-1} [ (x_k - x_ref)' Q (x_k - x_ref) + (u_k - u_ref)' R (u_k - u_ref) ]
              + (x_N - x_ref)' Qf (x_N - x_ref)
        s.t.  x_0 = x0,   x_{k+1} = A x_k + B u_k,
              A_x x_k <= b_x   for k = 0..N      (state constraints, if set)
              A_u u_k <= b_u   for k = 0..N-1    (input constraints, if set)
              A_f x_N <= b_f                     (terminal constraints, if set)

    and returns the first optimal input u_0. Only the inequality description
    (``A``, ``b``) of the constraint polyhedra is imposed.

    Weights, references and constraints are stored on the wrapped system.
    The CVXPY problem is built lazily on the first :meth:`evaluate` and is
    rebuilt after any of this controller's setters changes the problem data.
    Changes made directly on the system object (bypassing these setters) are
    not detected once the problem has been built.
    """

    def __init__(self,
                 sys: LTISystem,
                 # prediction horizon
                 N: Optional[int] = None,
                 T: Optional[float] = None,
                 dt: Optional[float] = None
                 ) -> None:
        """
        Create an MPC controller for ``sys``.

        :param sys: system to control; must be discrete-time.
        :param N: prediction horizon in steps; required, positive integer.
        :param T: prediction horizon in time units; ignored (with a warning)
            because only discrete-time systems are supported.
        :param dt: sampling time; ignored (with a warning) when ``sys``
            already has an explicit sampling time.
        :raises NotImplementedError: if ``sys`` is not discrete-time.
        :raises ValueError: if ``N`` is missing or not a positive integer.
        """
        super().__init__(sys=sys)
        self._sys = sys
        # TODO: support continuous-time systems by sampling them with dt.
        #  At least two of T, N and dt would then be required (with T == N * dt
        #  when all three are given), and Q, R and Qf must be scaled by dt.
        if not sys.isdtime():
            raise NotImplementedError("In current version, the system must be in discrete time.")
        if sys.isdtime(strict=True):
            # The system carries its own sampling time; a user-supplied dt is ignored.
            if dt is not None:
                warnings.warn(
                    "The system is already in discrete-time. The specified sampling time dt={} will be omitted".format(
                        dt))
            self._dt = self._sys.dt
        else:
            # Discrete but not strictly so (presumably an unspecified timebase):
            # keep whatever the user gave (possibly None) so that self._dt
            # always exists.
            self._dt = dt
        if T is not None:
            warnings.warn(
                "The system is in discrete-time. The specified prediction horizon T={} will be omitted".format(
                    T))
        if N is None:
            raise ValueError("For a discrete-time system, N is required.")
        if not self._is_positive_int(N):
            raise ValueError("N must be an positive integer.")
        self._N = N
        # Cached CVXPY problem; (re)built lazily by _setup_ocp().
        self._ocp = None

    # ------------------------------------------------------------------
    # internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _is_positive_int(value) -> bool:
        """Return True if ``value`` is a positive integer (bools excluded)."""
        return (isinstance(value, (int, np.integer))
                and not isinstance(value, bool)
                and value > 0)

    @staticmethod
    def _as_vector(v, n: int) -> np.ndarray:
        """
        Return the reference ``v`` as a flat float array of length ``n``.

        ``None`` is treated as the origin; a scalar is broadcast to all
        entries; (n, 1) / (1, n) inputs are flattened.
        """
        if v is None:
            return np.zeros(n)
        flat = np.ravel(np.asarray(v, dtype=float))
        return np.broadcast_to(flat, (n,)).copy()

    def _invalidate_ocp(self) -> None:
        """Drop the cached problem so that it is rebuilt on the next evaluate()."""
        self._ocp = None

    def _check_weight(self, name: str, W, n: int) -> None:
        """
        Validate a quadratic weight matrix.

        :raises ValueError: if ``W`` is not a 2-D ``n x n`` positive
            semi-definite matrix.
        """
        if np.ndim(W) != 2:
            raise ValueError("{} must be a matrix.".format(name))
        if np.shape(W) != (n, n):
            raise ValueError("{} must be a {}x{} matrix.".format(name, n, n))
        if not is_positive_semidefinite(W):
            raise ValueError("{} must be a positive semi-definite matrix.".format(name))

    def _setup_ocp(self) -> None:
        """
        Set up the optimization problem of the MPC controller.

        Creates the CVXPY variables ``self._x`` (nx x N+1) and ``self._u``
        (nu x N), the parameter ``self._x0`` (nx,) and the problem
        ``self._ocp`` from the current weights, references and constraints.
        """
        # Only an explicitly continuous-time system needs sampling. __init__
        # rejects those, so this is a safeguard. A system with an unspecified
        # timebase must not be resampled: self._dt may be None for it.
        if self._sys.isctime(strict=True):
            self._sys = self._sys.sample(self._dt)

        N = self._N
        x_ref = self._as_vector(self.x_ref, self.nx)
        u_ref = self._as_vector(self.u_ref, self.nu)
        self._x = cp.Variable((self.nx, N + 1), name='x')
        self._u = cp.Variable((self.nu, N), name='u')
        self._x0 = cp.Parameter((self.nx,), name='x0')

        X = self.state_constraints
        U = self.input_constraints
        Xf = self.terminal_constraints

        cost = 0
        constraints = [self._x[:, 0] == self._x0]
        for k in range(N):
            xk = self._x[:, k]
            uk = self._u[:, k]
            cost += cp.quad_form(xk - x_ref, self.Q)
            cost += cp.quad_form(uk - u_ref, self.R)
            constraints.append(self._x[:, k + 1] == self.A @ xk + self.B @ uk)
            # TODO: shall we also consider the equality constraints?
            if X is not None:
                constraints.append(X.A @ xk <= X.b)
            if U is not None:
                constraints.append(U.A @ uk <= U.b)

        xN = self._x[:, N]
        # Terminal cost: penalize the deviation of the last predicted state.
        if self.Qf is not None:
            cost += cp.quad_form(xN - x_ref, self.Qf)
        # State constraints must hold along the whole prediction, x_N included;
        # a terminal set is imposed in addition to them, not instead of them.
        if X is not None:
            constraints.append(X.A @ xN <= X.b)
        if Xf is not None:
            constraints.append(Xf.A @ xN <= Xf.b)

        self._ocp = cp.Problem(cp.Minimize(cost), constraints)

    # ------------------------------------------------------------------
    # public API
    # ------------------------------------------------------------------
    def evaluate(self, x0: Vector) -> Vector:
        """
        Solve the MPC problem for a given initial state ``x0``.

        :param x0: current state; any array-like with ``nx`` entries,
            e.g. of shape (nx,) or (nx, 1).
        :return: the first optimal input u_0, an array of shape (nu,).
        :raises ValueError: if ``x0`` does not have ``nx`` entries.
        :raises RuntimeError: if the solver status is not optimal.
        """
        if self._ocp is None:
            self._setup_ocp()
        x0 = np.ravel(np.asarray(x0, dtype=float))
        if x0.shape != (self.nx,):
            raise ValueError("x0 must have {} entries, got {}.".format(self.nx, x0.size))
        self._x0.value = x0
        self._ocp.solve()
        if self._ocp.status != cp.OPTIMAL:
            raise RuntimeError('The solver returns status {}'.format(self._ocp.status))
        return self._u[:, 0].value

    @property
    def nx(self) -> int:
        """Number of states of the wrapped system."""
        return self._sys.nx

    @property
    def nu(self) -> int:
        """Number of inputs of the wrapped system."""
        return self._sys.nu

    @property
    def A(self) -> Matrix:
        """State matrix of the wrapped system."""
        return self._sys.A

    @property
    def B(self) -> Matrix:
        """Input matrix of the wrapped system."""
        return self._sys.B

    @property
    def N(self) -> int:
        """Prediction horizon (number of steps)."""
        return self._N

    @N.setter
    def N(self, N):
        if not isinstance(N, (int, np.integer)) or isinstance(N, bool):
            raise TypeError("N must be a positive integer.")
        if N <= 0:
            raise ValueError("N must be a positive integer.")
        self._N = N
        self._invalidate_ocp()

    @property
    def Q(self) -> Matrix:
        """State weight (nx x nx), stored on the wrapped system."""
        return self._sys.Q

    @Q.setter
    def Q(self, Q):
        self._check_weight("Q", Q, self.nx)
        self._sys.Q = Q
        self._invalidate_ocp()

    @property
    def R(self) -> Matrix:
        """Input weight (nu x nu), stored on the wrapped system."""
        return self._sys.R

    @R.setter
    def R(self, R):
        self._check_weight("R", R, self.nu)
        self._sys.R = R
        self._invalidate_ocp()

    @property
    def Qf(self) -> Matrix:
        """Terminal state weight (nx x nx), stored on the wrapped system."""
        return self._sys.Qf

    @Qf.setter
    def Qf(self, Qf):
        self._check_weight("Qf", Qf, self.nx)
        self._sys.Qf = Qf
        self._invalidate_ocp()

    @property
    def x_ref(self) -> Vector:
        """State reference tracked by the controller."""
        return self._sys.x_ref

    @x_ref.setter
    def x_ref(self, x_ref):
        self._sys.x_ref = x_ref
        self._invalidate_ocp()

    @property
    def u_ref(self) -> Vector:
        """Input reference tracked by the controller."""
        return self._sys.u_ref

    @u_ref.setter
    def u_ref(self, u_ref):
        self._sys.u_ref = u_ref
        self._invalidate_ocp()

    @property
    def state_constraints(self) -> Polyhedron:
        """Polyhedron imposed on every predicted state x_0..x_N."""
        return self._sys.state_constraints

    @state_constraints.setter
    def state_constraints(self, state_set: Polyhedron):
        if not isinstance(state_set, Polyhedron):
            raise TypeError("state_constraints must be a Polyhedron.")
        # TODO: shall we require X to be non-empty or full-dimensional?
        if state_set.is_empty:
            warnings.warn("The state constraints results in an empty set.")
        self._sys.state_constraints = state_set
        self._invalidate_ocp()

    @property
    def input_constraints(self) -> Polyhedron:
        """Polyhedron imposed on every predicted input u_0..u_{N-1}."""
        return self._sys.input_constraints

    @input_constraints.setter
    def input_constraints(self, input_set: Polyhedron):
        if not isinstance(input_set, Polyhedron):
            raise TypeError("input_constraints must be a Polyhedron.")
        if input_set.is_empty:
            warnings.warn("The input constraints results in an empty set.")
        self._sys.input_constraints = input_set
        self._invalidate_ocp()

    @property
    def terminal_constraints(self) -> Polyhedron:
        """Polyhedron imposed on the last predicted state x_N."""
        return self._sys.terminal_constraints

    @terminal_constraints.setter
    def terminal_constraints(self, terminal_set: Polyhedron):
        if not isinstance(terminal_set, Polyhedron):
            raise TypeError("terminal_constraints must be a Polyhedron.")
        if terminal_set.is_empty:
            warnings.warn("The terminal constraints results in an empty set.")
        self._sys.terminal_constraints = terminal_set
        self._invalidate_ocp()

    @property
    def x_opt_traj(self):
        """
        Predicted state trajectory (nx x N+1) of the last optimal solve.

        Returns None if the current problem has not been solved to optimality.
        """
        if self._ocp is not None and self._ocp.status == cp.OPTIMAL:
            return self._x.value
        return None

    @property
    def u_opt_traj(self):
        """
        Predicted input trajectory (nu x N) of the last optimal solve.

        Returns None if the current problem has not been solved to optimality.
        """
        if self._ocp is not None and self._ocp.status == cp.OPTIMAL:
            return self._u.value
        return None

    def to_explicit(self):
        """
        Convert the MPC problem to an explicit form.

        :raises NotImplementedError: always; not implemented yet.
        """
        raise NotImplementedError