Cleanup for Optimal Control Ops (#1045)
* Blockwise optimal linear control ops

* Add jax rewrite to eliminate `BilinearSolveDiscreteLyapunov`

* set `solve_discrete_lyapunov` method default to bilinear

* Appease mypy

* restore method dispatching

* Use `pt.vectorize` on base `solve_discrete_lyapunov` case

* Apply JAX rewrite before canonicalization

* Improve tests

* Remove useless warning filters

* Fix local_blockwise_alloc rewrite

The rewrite was squeezing too many dimensions of the alloced value, when this didn't have dummy expand dims to the left.

* Fix float32 tests

* Test against complex inputs

* Appease ViPy (Vieira-py type checking)

* Remove condition from `TensorLike` import

* Infer dtype from `node.outputs.type.dtype`

* Remove unused mypy ignore

* Don't manually set dtype of output

Revert change to `_solve_discrete_lyapunov`

* Set dtype of Op outputs

---------

Co-authored-by: ricardoV94 <[email protected]>
jessegrabowski and ricardoV94 authored Oct 24, 2024
1 parent dae731d commit fffb84c
Showing 5 changed files with 301 additions and 124 deletions.
4 changes: 2 additions & 2 deletions pytensor/tensor/rewriting/blockwise.py
@@ -127,8 +127,8 @@ def local_blockwise_alloc(fgraph, node):
value, *shape = inp.owner.inputs

# Check what to do with the value of the Alloc
- squeezed_value = _squeeze_left(value, batch_ndim)
- missing_ndim = len(shape) - value.type.ndim
+ missing_ndim = inp.type.ndim - value.type.ndim
+ squeezed_value = _squeeze_left(value, (batch_ndim - missing_ndim))
if (
(((1,) * missing_ndim + value.type.broadcastable)[batch_ndim:])
!= inp.type.broadcastable[batch_ndim:]
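
Editorial aside: a minimal numeric sketch of the bookkeeping this hunk fixes, with hypothetical shapes rather than anything from the commit. The old code squeezed `batch_ndim` leading dimensions of the Alloc value even when the Alloc added dimensions the value never had:

    # Blockwise batch_ndim = 2; the Alloc output `inp` has ndim 4 (2 batch + 2 core)
    # while the alloc-ed `value` has ndim 3, i.e. no dummy leading batch dim.
    batch_ndim = 2
    inp_ndim, value_ndim = 4, 3
    missing_ndim = inp_ndim - value_ndim          # 1: dims Alloc adds on the left
    dims_to_squeeze = batch_ndim - missing_ndim   # 1: only this many are dummies
    # The old code squeezed batch_ndim = 2 leading dims, dropping a real axis.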
23 changes: 23 additions & 0 deletions pytensor/tensor/rewriting/linalg.py
@@ -4,9 +4,11 @@

from pytensor import Variable
from pytensor import tensor as pt
+ from pytensor.compile import optdb
from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import (
copy_stack_trace,
+     in2out,
node_rewriter,
)
from pytensor.scalar.basic import Mul
@@ -45,9 +47,11 @@
Cholesky,
Solve,
SolveBase,
+     _bilinear_solve_discrete_lyapunov,
block_diag,
cholesky,
solve,
+     solve_discrete_lyapunov,
solve_triangular,
)

@@ -966,3 +970,22 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
non_eye_input = pt.shape_padaxis(non_eye_input, -2)

return [eye_input * (non_eye_input**0.5)]


+ @node_rewriter([_bilinear_solve_discrete_lyapunov])
+ def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
+     """
+     Replace BilinearSolveDiscreteLyapunov with a direct computation that is supported by JAX
+     """
+     A, B = (cast(TensorVariable, x) for x in node.inputs)
+     result = solve_discrete_lyapunov(A, B, method="direct")
+
+     return [result]
+
+
+ optdb.register(
+     "jax_bilinaer_lyapunov_to_direct",
+     in2out(jax_bilinaer_lyapunov_to_direct),
+     "jax",
+     position=0.9,  # Run before canonicalization
+ )
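
Editorial aside: a usage sketch, not from the commit, of how this registration plays out. `BilinearSolveDiscreteLyapunov.perform` calls into scipy, which the JAX backend cannot execute, so the rewrite swaps in the pure-PyTensor "direct" graph before canonicalization:

    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.slinalg import solve_discrete_lyapunov

    A = pt.matrix("A")
    Q = pt.matrix("Q")
    X = solve_discrete_lyapunov(A, Q, method="bilinear")
    # Compiling with the JAX backend triggers the "jax"-tagged rewrite above.
    fn = pytensor.function([A, Q], X, mode="JAX")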
194 changes: 127 additions & 67 deletions pytensor/tensor/slinalg.py
@@ -2,7 +2,7 @@
import typing
import warnings
from functools import reduce
- from typing import TYPE_CHECKING, Literal, cast
+ from typing import Literal, cast

import numpy as np
import scipy.linalg
@@ -11,7 +11,7 @@
import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
- from pytensor.tensor import as_tensor_variable
+ from pytensor.tensor import TensorLike, as_tensor_variable
from pytensor.tensor import basic as ptb
from pytensor.tensor import math as ptm
from pytensor.tensor.blockwise import Blockwise
@@ -21,9 +21,6 @@
from pytensor.tensor.variable import TensorVariable


- if TYPE_CHECKING:
-     from pytensor.tensor import TensorLike

logger = logging.getLogger(__name__)


@@ -777,7 +774,16 @@ def perform(self, node, inputs, outputs):


class SolveContinuousLyapunov(Op):
"""
Solves a continuous Lyapunov equation, :math:`AX + XA^H = B`, for :math:`X.
Continuous time Lyapunov equations are special cases of Sylvester equations, :math:`AX + XB = C`, and can be solved
efficiently using the Bartels-Stewart algorithm. For more details, see the docstring for
scipy.linalg.solve_continuous_lyapunov
"""

__props__ = ()
+     gufunc_signature = "(m,m),(m,m)->(m,m)"

def make_node(self, A, B):
A = as_tensor_variable(A)
@@ -792,7 +798,8 @@ def perform(self, node, inputs, output_storage):
(A, B) = inputs
X = output_storage[0]

- X[0] = scipy.linalg.solve_continuous_lyapunov(A, B)
+ out_dtype = node.outputs[0].type.dtype
+ X[0] = scipy.linalg.solve_continuous_lyapunov(A, B).astype(out_dtype)

def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
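
Editorial aside: the `out_dtype` casts added in this and the following `perform` methods appear to address the "Fix float32 tests" and "Set dtype of Op outputs" items in the commit message; scipy solvers can return float64 even for float32 inputs, so the runtime value is cast to the dtype declared on the Op's symbolic output.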
@@ -813,7 +820,41 @@ def grad(self, inputs, output_grads):
return [A_bar, Q_bar]


+ _solve_continuous_lyapunov = Blockwise(SolveContinuousLyapunov())
+
+
+ def solve_continuous_lyapunov(A: TensorLike, Q: TensorLike) -> TensorVariable:
+     """
+     Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.
+
+     Parameters
+     ----------
+     A: TensorLike
+         Square matrix of shape ``N x N``.
+     Q: TensorLike
+         Square matrix of shape ``N x N``.
+
+     Returns
+     -------
+     X: TensorVariable
+         Square matrix of shape ``N x N``.
+     """
+     return cast(TensorVariable, _solve_continuous_lyapunov(A, Q))
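
Editorial aside: because the user-facing function now wraps the Op in `Blockwise` with signature "(m,m),(m,m)->(m,m)", batched inputs broadcast like a NumPy gufunc (hypothetical shapes):

    A = pt.tensor("A", shape=(10, 4, 4))
    Q = pt.tensor("Q", shape=(10, 4, 4))
    X = solve_continuous_lyapunov(A, Q)  # X has shape (10, 4, 4)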


class BilinearSolveDiscreteLyapunov(Op):
"""
Solves a discrete lyapunov equation, :math:`AXA^H - X = Q`, for :math:`X.
The solution is computed by first transforming the discrete-time problem into a continuous-time form. The continuous
time lyapunov is a special case of a Sylvester equation, and can be efficiently solved. For more details, see the
docstring for scipy.linalg.solve_discrete_lyapunov
"""

+     gufunc_signature = "(m,m),(m,m)->(m,m)"

def make_node(self, A, B):
A = as_tensor_variable(A)
B = as_tensor_variable(B)
@@ -827,7 +868,10 @@ def perform(self, node, inputs, output_storage):
(A, B) = inputs
X = output_storage[0]

- X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear")
+ out_dtype = node.outputs[0].type.dtype
+ X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype(
+     out_dtype
+ )

def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
@@ -849,83 +893,83 @@ def grad(self, inputs, output_grads):
return [A_bar, Q_bar]


- _solve_continuous_lyapunov = SolveContinuousLyapunov()
- _solve_bilinear_direct_lyapunov = cast(typing.Callable, BilinearSolveDiscreteLyapunov())
+ _bilinear_solve_discrete_lyapunov = Blockwise(BilinearSolveDiscreteLyapunov())


- def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
-     A_ = as_tensor_variable(A)
-     Q_ = as_tensor_variable(Q)
+ def _direct_solve_discrete_lyapunov(
+     A: TensorVariable, Q: TensorVariable
+ ) -> TensorVariable:
+     r"""
+     Directly solve the discrete Lyapunov equation :math:`A X A^H - X = Q` using the Kronecker method of Magnus and
+     Neudecker.
+
+     This involves constructing and inverting an intermediate matrix :math:`A \otimes A`, with shape :math:`N^2 x N^2`.
+     As a result, this method scales poorly with the size of :math:`N`, and should be avoided for large :math:`N`.
+     """

if "complex" in A_.type.dtype:
AA = kron(A_, A_.conj())
if A.type.dtype.startswith("complex"):
AxA = kron(A, A.conj())
else:
AA = kron(A_, A_)
AxA = kron(A, A)

+ eye = pt.eye(AxA.shape[-1])

- X = solve(pt.eye(AA.shape[0]) - AA, Q_.ravel())
- return cast(TensorVariable, reshape(X, Q_.shape))
+ vec_Q = Q.ravel()
+ vec_X = solve(eye - AxA, vec_Q, b_ndim=1)

+ return cast(TensorVariable, reshape(vec_X, A.shape))
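
Editorial aside, standard linear algebra rather than commit content: for the row-major vectorization computed by `.ravel()`, :math:`\operatorname{vec}(AXB) = (A \otimes B^T)\operatorname{vec}(X)`. Taking :math:`B = A^H` (so :math:`B^T = \bar{A}`) and using scipy's convention :math:`X = A X A^H + Q` gives the linear system solved above,

    :math:`(I_{N^2} - A \otimes \bar{A})\,\operatorname{vec}(X) = \operatorname{vec}(Q)`

with :math:`A \otimes \bar{A}` built as `kron(A, A.conj())`, or `kron(A, A)` for real inputs.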


def solve_discrete_lyapunov(
A: "TensorLike", Q: "TensorLike", method: Literal["direct", "bilinear"] = "direct"
A: TensorLike,
Q: TensorLike,
method: Literal["direct", "bilinear"] = "bilinear",
) -> TensorVariable:
"""Solve the discrete Lyapunov equation :math:`A X A^H - X = Q`.
Parameters
----------
- A
-     Square matrix of shape N x N; must have the same shape as Q
- Q
-     Square matrix of shape N x N; must have the same shape as A
- method
-     Solver method used, one of ``"direct"`` or ``"bilinear"``. ``"direct"``
-     solves the problem directly via matrix inversion. This has a pure
-     PyTensor implementation and can thus be cross-compiled to supported
-     backends, and should be preferred when ``N`` is not large. The direct
-     method scales poorly with the size of ``N``, and the bilinear can be
+ A: TensorLike
+     Square matrix of shape N x N
+ Q: TensorLike
+     Square matrix of shape N x N
+ method: str, one of ``"direct"`` or ``"bilinear"``
+     Solver method used. ``"direct"`` solves the problem directly via matrix inversion. This has a pure
+     PyTensor implementation and can thus be cross-compiled to supported backends, and should be preferred when
+     ``N`` is not large. The direct method scales poorly with the size of ``N``, and the bilinear can be
used in these cases.
Returns
-------
- Square matrix of shape ``N x N``, representing the solution to the
- Lyapunov equation
+ X: TensorVariable
+     Square matrix of shape ``N x N``. Solution to the Lyapunov equation
"""
- if method not in ["direct", "bilinear"]:
-     raise ValueError(
-         f'Parameter "method" must be one of "direct" or "bilinear", found {method}'
-     )

if method == "direct":
return _direct_solve_discrete_lyapunov(A, Q)
if method == "bilinear":
return cast(TensorVariable, _solve_bilinear_direct_lyapunov(A, Q))


def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
"""Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.
Parameters
----------
A
Square matrix of shape ``N x N``; must have the same shape as `Q`.
Q
Square matrix of shape ``N x N``; must have the same shape as `A`.
A = as_tensor_variable(A)
Q = as_tensor_variable(Q)

-     Returns
-     -------
-     Square matrix of shape ``N x N``, representing the solution to the
-     Lyapunov equation
+ if method == "direct":
+     signature = BilinearSolveDiscreteLyapunov.gufunc_signature
+     X = pt.vectorize(_direct_solve_discrete_lyapunov, signature=signature)(A, Q)
+     return cast(TensorVariable, X)

"""
elif method == "bilinear":
return cast(TensorVariable, _bilinear_solve_discrete_lyapunov(A, Q))

- return cast(TensorVariable, _solve_continuous_lyapunov(A, Q))
+ else:
+     raise ValueError(f"Unknown method {method}")
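
Editorial aside (hypothetical shapes): after this change both branches behave like gufuncs with signature "(m,m),(m,m)->(m,m)" (the direct branch through `pt.vectorize`, the bilinear branch through `Blockwise`), so leading batch dimensions broadcast:

    A = pt.tensor("A", shape=(5, 3, 3))
    Q = pt.tensor("Q", shape=(5, 3, 3))
    X = solve_discrete_lyapunov(A, Q)                          # bilinear by default
    X_direct = solve_discrete_lyapunov(A, Q, method="direct")
    # Both results have shape (5, 3, 3).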


- class SolveDiscreteARE(pt.Op):
+ class SolveDiscreteARE(Op):
__props__ = ("enforce_Q_symmetric",)
+     gufunc_signature = "(m,m),(m,n),(m,m),(n,n)->(m,m)"

- def __init__(self, enforce_Q_symmetric=False):
+ def __init__(self, enforce_Q_symmetric: bool = False):
self.enforce_Q_symmetric = enforce_Q_symmetric

def make_node(self, A, B, Q, R):
@@ -946,9 +990,8 @@ def perform(self, node, inputs, output_storage):
if self.enforce_Q_symmetric:
Q = 0.5 * (Q + Q.T)

- X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(
-     node.outputs[0].type.dtype
- )
+ out_dtype = node.outputs[0].type.dtype
+ X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype)

def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
@@ -960,14 +1003,16 @@ def grad(self, inputs, output_grads):
(dX,) = output_grads
X = self(A, B, Q, R)

- K_inner = R + pt.linalg.matrix_dot(B.T, X, B)
- K_inner_inv = pt.linalg.solve(K_inner, pt.eye(R.shape[0]))
- K = matrix_dot(K_inner_inv, B.T, X, A)
+ K_inner = R + matrix_dot(B.T, X, B)
+
+ # K_inner is guaranteed to be symmetric, because X and R are symmetric
+ K_inner_inv_BT = solve(K_inner, B.T, assume_a="sym")
+ K = matrix_dot(K_inner_inv_BT, X, A)

A_tilde = A - B.dot(K)

dX_symm = 0.5 * (dX + dX.T)
- S = solve_discrete_lyapunov(A_tilde, dX_symm).astype(dX.type.dtype)
+ S = solve_discrete_lyapunov(A_tilde, dX_symm)

A_bar = 2 * matrix_dot(X, A_tilde, S)
B_bar = -2 * matrix_dot(X, A_tilde, S, K.T)
Expand All @@ -977,30 +1022,45 @@ def grad(self, inputs, output_grads):
return [A_bar, B_bar, Q_bar, R_bar]


- def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
+ def solve_discrete_are(
+     A: TensorLike,
+     B: TensorLike,
+     Q: TensorLike,
+     R: TensorLike,
+     enforce_Q_symmetric: bool = False,
+ ) -> TensorVariable:
"""
Solve the discrete Algebraic Riccati equation :math:`A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q = 0`.
+ Discrete-time Algebraic Riccati equations arise in the context of optimal control and filtering problems, as the
+ solution to Linear-Quadratic Regulators (LQR), Linear-Quadratic-Gaussian (LQG) control problems, and as the
+ steady-state covariance of the Kalman Filter.
+
+ Such problems typically have many solutions, but we are generally only interested in the unique *stabilizing*
+ solution. This stable solution, if it exists, will be returned by this function.
Parameters
----------
- A: ArrayLike
+ A: TensorLike
Square matrix of shape M x M
- B: ArrayLike
+ B: TensorLike
Matrix of shape M x N
- Q: ArrayLike
+ Q: TensorLike
Symmetric square matrix of shape M x M
- R: ArrayLike
+ R: TensorLike
Square matrix of shape N x N
enforce_Q_symmetric: bool
If True, the provided Q matrix is transformed to 0.5 * (Q + Q.T) to ensure symmetry
Returns
-------
- X: pt.matrix
+ X: TensorVariable
Square matrix of shape M x M, representing the solution to the DARE
"""

- return cast(TensorVariable, SolveDiscreteARE(enforce_Q_symmetric)(A, B, Q, R))
+ return cast(
+     TensorVariable, Blockwise(SolveDiscreteARE(enforce_Q_symmetric))(A, B, Q, R)
+ )
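
Editorial aside: a sketch connecting the DARE to the LQR gain it is typically used for. The gain formula :math:`K = (R + B^T X B)^{-1} B^T X A` mirrors the `K_inner` computation in `grad` above; the snippet assumes the names from this diff:

    A, B = pt.matrix("A"), pt.matrix("B")
    Q, R = pt.matrix("Q"), pt.matrix("R")
    X = solve_discrete_are(A, B, Q, R)
    # Optimal state-feedback gain for the associated LQR problem
    K = solve(R + B.T @ X @ B, B.T @ X @ A, assume_a="sym")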


def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:
(Diffs for the two remaining changed files are not shown.)
