From fffb84c16af14bf4cdf08bb25c97874d4a7000ea Mon Sep 17 00:00:00 2001 From: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com> Date: Thu, 24 Oct 2024 13:45:05 +0200 Subject: [PATCH] Cleanup for Optimal Control Ops (#1045) * Blockwise optimal linear control ops * Add jax rewrite to eliminate `BilinearSolveDiscreteLyapunov` * set `solve_discrete_lyapunov` method default to bilinear * Appease mypy * restore method dispatching * Use `pt.vectorize` on base `solve_discrete_lyapunov` case * Apply JAX rewrite before canonicalization * Improve tests * Remove useless warning filters * Fix local_blockwise_alloc rewrite The rewrite was squeezing too many dimensions of the alloced value, when this didn't have dummy expand dims to the left. * Fix float32 tests * Test against complex inputs * Appease ViPy (Vieira-py type checking) * Remove condition from `TensorLike` import * Infer dtype from `node.outputs.type.dtype` * Remove unused mypy ignore * Don't manually set dtype of output Revert change to `_solve_discrete_lyapunov` * Set dtype of Op outputs --------- Co-authored-by: ricardoV94 --- pytensor/tensor/rewriting/blockwise.py | 4 +- pytensor/tensor/rewriting/linalg.py | 23 +++ pytensor/tensor/slinalg.py | 194 ++++++++++++++++--------- tests/link/jax/test_slinalg.py | 25 ++++ tests/tensor/test_slinalg.py | 179 ++++++++++++++++------- 5 files changed, 301 insertions(+), 124 deletions(-) diff --git a/pytensor/tensor/rewriting/blockwise.py b/pytensor/tensor/rewriting/blockwise.py index b62e6a73e7..97046bffe2 100644 --- a/pytensor/tensor/rewriting/blockwise.py +++ b/pytensor/tensor/rewriting/blockwise.py @@ -127,8 +127,8 @@ def local_blockwise_alloc(fgraph, node): value, *shape = inp.owner.inputs # Check what to do with the value of the Alloc - squeezed_value = _squeeze_left(value, batch_ndim) - missing_ndim = len(shape) - value.type.ndim + missing_ndim = inp.type.ndim - value.type.ndim + squeezed_value = _squeeze_left(value, (batch_ndim - missing_ndim)) if ( (((1,) * missing_ndim + value.type.broadcastable)[batch_ndim:]) != inp.type.broadcastable[batch_ndim:] diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 798d590d7f..a2418147cf 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -4,9 +4,11 @@ from pytensor import Variable from pytensor import tensor as pt +from pytensor.compile import optdb from pytensor.graph import Apply, FunctionGraph from pytensor.graph.rewriting.basic import ( copy_stack_trace, + in2out, node_rewriter, ) from pytensor.scalar.basic import Mul @@ -45,9 +47,11 @@ Cholesky, Solve, SolveBase, + _bilinear_solve_discrete_lyapunov, block_diag, cholesky, solve, + solve_discrete_lyapunov, solve_triangular, ) @@ -966,3 +970,22 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node): non_eye_input = pt.shape_padaxis(non_eye_input, -2) return [eye_input * (non_eye_input**0.5)] + + +@node_rewriter([_bilinear_solve_discrete_lyapunov]) +def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply): + """ + Replace BilinearSolveDiscreteLyapunov with a direct computation that is supported by JAX + """ + A, B = (cast(TensorVariable, x) for x in node.inputs) + result = solve_discrete_lyapunov(A, B, method="direct") + + return [result] + + +optdb.register( + "jax_bilinaer_lyapunov_to_direct", + in2out(jax_bilinaer_lyapunov_to_direct), + "jax", + position=0.9, # Run before canonicalization +) diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 0f2ef5c740..802ca6e543 
100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -2,7 +2,7 @@ import typing import warnings from functools import reduce -from typing import TYPE_CHECKING, Literal, cast +from typing import Literal, cast import numpy as np import scipy.linalg @@ -11,7 +11,7 @@ import pytensor.tensor as pt from pytensor.graph.basic import Apply from pytensor.graph.op import Op -from pytensor.tensor import as_tensor_variable +from pytensor.tensor import TensorLike, as_tensor_variable from pytensor.tensor import basic as ptb from pytensor.tensor import math as ptm from pytensor.tensor.blockwise import Blockwise @@ -21,9 +21,6 @@ from pytensor.tensor.variable import TensorVariable -if TYPE_CHECKING: - from pytensor.tensor import TensorLike - logger = logging.getLogger(__name__) @@ -777,7 +774,16 @@ def perform(self, node, inputs, outputs): class SolveContinuousLyapunov(Op): + """ + Solves a continuous Lyapunov equation, :math:`AX + XA^H = B`, for :math:`X`. + + Continuous time Lyapunov equations are special cases of Sylvester equations, :math:`AX + XB = C`, and can be solved + efficiently using the Bartels-Stewart algorithm. For more details, see the docstring for + scipy.linalg.solve_continuous_lyapunov + """ + __props__ = () + gufunc_signature = "(m,m),(m,m)->(m,m)" def make_node(self, A, B): A = as_tensor_variable(A) @@ -792,7 +798,8 @@ def perform(self, node, inputs, output_storage): (A, B) = inputs X = output_storage[0] - X[0] = scipy.linalg.solve_continuous_lyapunov(A, B) + out_dtype = node.outputs[0].type.dtype + X[0] = scipy.linalg.solve_continuous_lyapunov(A, B).astype(out_dtype) def infer_shape(self, fgraph, node, shapes): return [shapes[0]] @@ -813,7 +820,41 @@ def grad(self, inputs, output_grads): return [A_bar, Q_bar] +_solve_continuous_lyapunov = Blockwise(SolveContinuousLyapunov()) + + +def solve_continuous_lyapunov(A: TensorLike, Q: TensorLike) -> TensorVariable: + """ + Solve the continuous Lyapunov equation :math:`A X + X A^H = Q`. + + Parameters + ---------- + A: TensorLike + Square matrix of shape ``N x N``. + Q: TensorLike + Square matrix of shape ``N x N``. + + Returns + ------- + X: TensorVariable + Square matrix of shape ``N x N`` + + """ + + return cast(TensorVariable, _solve_continuous_lyapunov(A, Q)) + + class BilinearSolveDiscreteLyapunov(Op): + """ + Solves a discrete Lyapunov equation, :math:`AXA^H - X = Q`, for :math:`X`. + + The solution is computed by first transforming the discrete-time problem into a continuous-time form. The continuous + time Lyapunov equation is a special case of a Sylvester equation, and can be efficiently solved.
For more details, see the + docstring for scipy.linalg.solve_discrete_lyapunov + """ + + gufunc_signature = "(m,m),(m,m)->(m,m)" + def make_node(self, A, B): A = as_tensor_variable(A) B = as_tensor_variable(B) @@ -827,7 +868,10 @@ def perform(self, node, inputs, output_storage): (A, B) = inputs X = output_storage[0] - X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear") + out_dtype = node.outputs[0].type.dtype + X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype( + out_dtype + ) def infer_shape(self, fgraph, node, shapes): return [shapes[0]] @@ -849,46 +893,56 @@ def grad(self, inputs, output_grads): return [A_bar, Q_bar] -_solve_continuous_lyapunov = SolveContinuousLyapunov() -_solve_bilinear_direct_lyapunov = cast(typing.Callable, BilinearSolveDiscreteLyapunov()) +_bilinear_solve_discrete_lyapunov = Blockwise(BilinearSolveDiscreteLyapunov()) -def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable: - A_ = as_tensor_variable(A) - Q_ = as_tensor_variable(Q) +def _direct_solve_discrete_lyapunov( + A: TensorVariable, Q: TensorVariable +) -> TensorVariable: + r""" + Directly solve the discrete Lyapunov equation :math:`A X A^H - X = Q` using the Kronecker method of Magnus and + Neudecker. + + This involves constructing and inverting an intermediate matrix :math:`A \otimes A`, with shape :math:`N^2 \times N^2`. + As a result, this method scales poorly with the size of :math:`N`, and should be avoided for large :math:`N`. + """ - if "complex" in A_.type.dtype: - AA = kron(A_, A_.conj()) + if A.type.dtype.startswith("complex"): + AxA = kron(A, A.conj()) else: - AA = kron(A_, A_) + AxA = kron(A, A) + + eye = pt.eye(AxA.shape[-1]) - X = solve(pt.eye(AA.shape[0]) - AA, Q_.ravel()) - return cast(TensorVariable, reshape(X, Q_.shape)) + vec_Q = Q.ravel() + vec_X = solve(eye - AxA, vec_Q, b_ndim=1) + + return cast(TensorVariable, reshape(vec_X, A.shape)) def solve_discrete_lyapunov( - A: "TensorLike", Q: "TensorLike", method: Literal["direct", "bilinear"] = "direct" + A: TensorLike, + Q: TensorLike, + method: Literal["direct", "bilinear"] = "bilinear", ) -> TensorVariable: """Solve the discrete Lyapunov equation :math:`A X A^H - X = Q`. Parameters ---------- - A - Square matrix of shape N x N; must have the same shape as Q - Q - Square matrix of shape N x N; must have the same shape as A - method - Solver method used, one of ``"direct"`` or ``"bilinear"``. ``"direct"`` - solves the problem directly via matrix inversion. This has a pure - PyTensor implementation and can thus be cross-compiled to supported - backends, and should be preferred when ``N`` is not large. The direct - method scales poorly with the size of ``N``, and the bilinear can be + A: TensorLike + Square matrix of shape N x N + Q: TensorLike + Square matrix of shape N x N + method: str, one of ``"direct"`` or ``"bilinear"`` + Solver method used. ``"direct"`` solves the problem directly via matrix inversion. This has a pure + PyTensor implementation and can thus be cross-compiled to supported backends, and should be preferred when + ``N`` is not large. The direct method scales poorly with the size of ``N``, and the bilinear method can be used in these cases. Returns ------- - Square matrix of shape ``N x N``, representing the solution to the - Lyapunov equation + X: TensorVariable + Square matrix of shape ``N x N``.
Solution to the Lyapunov equation """ if method not in ["direct", "bilinear"]: @@ -896,36 +950,26 @@ def solve_discrete_lyapunov( f'Parameter "method" must be one of "direct" or "bilinear", found {method}' ) - if method == "direct": - return _direct_solve_discrete_lyapunov(A, Q) - if method == "bilinear": - return cast(TensorVariable, _solve_bilinear_direct_lyapunov(A, Q)) - - -def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable: - """Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`. - - Parameters - ---------- - A - Square matrix of shape ``N x N``; must have the same shape as `Q`. - Q - Square matrix of shape ``N x N``; must have the same shape as `A`. + A = as_tensor_variable(A) + Q = as_tensor_variable(Q) - Returns - ------- - Square matrix of shape ``N x N``, representing the solution to the - Lyapunov equation + if method == "direct": + signature = BilinearSolveDiscreteLyapunov.gufunc_signature + X = pt.vectorize(_direct_solve_discrete_lyapunov, signature=signature)(A, Q) + return cast(TensorVariable, X) - """ + elif method == "bilinear": + return cast(TensorVariable, _bilinear_solve_discrete_lyapunov(A, Q)) - return cast(TensorVariable, _solve_continuous_lyapunov(A, Q)) + else: + raise ValueError(f"Unknown method {method}") -class SolveDiscreteARE(pt.Op): +class SolveDiscreteARE(Op): __props__ = ("enforce_Q_symmetric",) + gufunc_signature = "(m,m),(m,n),(m,m),(n,n)->(m,m)" - def __init__(self, enforce_Q_symmetric=False): + def __init__(self, enforce_Q_symmetric: bool = False): self.enforce_Q_symmetric = enforce_Q_symmetric def make_node(self, A, B, Q, R): @@ -946,9 +990,8 @@ def perform(self, node, inputs, output_storage): if self.enforce_Q_symmetric: Q = 0.5 * (Q + Q.T) - X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype( - node.outputs[0].type.dtype - ) + out_dtype = node.outputs[0].type.dtype + X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype) def infer_shape(self, fgraph, node, shapes): return [shapes[0]] @@ -960,14 +1003,16 @@ def grad(self, inputs, output_grads): (dX,) = output_grads X = self(A, B, Q, R) - K_inner = R + pt.linalg.matrix_dot(B.T, X, B) - K_inner_inv = pt.linalg.solve(K_inner, pt.eye(R.shape[0])) - K = matrix_dot(K_inner_inv, B.T, X, A) + K_inner = R + matrix_dot(B.T, X, B) + + # K_inner is guaranteed to be symmetric, because X and R are symmetric + K_inner_inv_BT = solve(K_inner, B.T, assume_a="sym") + K = matrix_dot(K_inner_inv_BT, X, A) A_tilde = A - B.dot(K) dX_symm = 0.5 * (dX + dX.T) - S = solve_discrete_lyapunov(A_tilde, dX_symm).astype(dX.type.dtype) + S = solve_discrete_lyapunov(A_tilde, dX_symm) A_bar = 2 * matrix_dot(X, A_tilde, S) B_bar = -2 * matrix_dot(X, A_tilde, S, K.T) @@ -977,30 +1022,45 @@ def grad(self, inputs, output_grads): return [A_bar, B_bar, Q_bar, R_bar] -def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable: +def solve_discrete_are( + A: TensorLike, + B: TensorLike, + Q: TensorLike, + R: TensorLike, + enforce_Q_symmetric: bool = False, +) -> TensorVariable: """ Solve the discrete Algebraic Riccati equation :math:`A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q = 0`. + Discrete-time Algebraic Riccati equations arise in the context of optimal control and filtering problems, as the + solution to Linear-Quadratic Regulators (LQR), Linear-Quadratic-Gaussian (LQG) control problems, and as the + steady-state covariance of the Kalman Filter.
+ + Such problems typically have many solutions, but we are generally only interested in the unique *stabilizing* + solution. This stable solution, if it exists, will be returned by this function. + Parameters ---------- - A: ArrayLike + A: TensorLike Square matrix of shape M x M - B: ArrayLike + B: TensorLike Square matrix of shape M x M - Q: ArrayLike + Q: TensorLike Symmetric square matrix of shape M x M - R: ArrayLike + R: TensorLike Square matrix of shape N x N enforce_Q_symmetric: bool If True, the provided Q matrix is transformed to 0.5 * (Q + Q.T) to ensure symmetry Returns ------- - X: pt.matrix + X: TensorVariable Square matrix of shape M x M, representing the solution to the DARE """ - return cast(TensorVariable, SolveDiscreteARE(enforce_Q_symmetric)(A, B, Q, R)) + return cast( + TensorVariable, Blockwise(SolveDiscreteARE(enforce_Q_symmetric))(A, B, Q, R) + ) def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype: diff --git a/tests/link/jax/test_slinalg.py b/tests/link/jax/test_slinalg.py index 827666d37f..3320eb9e73 100644 --- a/tests/link/jax/test_slinalg.py +++ b/tests/link/jax/test_slinalg.py @@ -1,3 +1,6 @@ +from functools import partial +from typing import Literal + import numpy as np import pytest @@ -194,3 +197,25 @@ def test_jax_eigvalsh(lower): None, ], ) + + +@pytest.mark.parametrize("method", ["direct", "bilinear"]) +@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batch"]) +def test_jax_solve_discrete_lyapunov( + method: Literal["direct", "bilinear"], shape: tuple[int] +): + A = pt.tensor(name="A", shape=shape) + B = pt.tensor(name="B", shape=shape) + out = pt_slinalg.solve_discrete_lyapunov(A, B, method=method) + out_fg = FunctionGraph([A, B], [out]) + + atol = rtol = 1e-8 if config.floatX == "float64" else 1e-3 + compare_jax_and_py( + out_fg, + [ + np.random.normal(size=shape).astype(config.floatX), + np.random.normal(size=shape).astype(config.floatX), + ], + jax_mode="JAX", + assert_fn=partial(np.testing.assert_allclose, atol=atol, rtol=rtol), + ) diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index 28a0210278..3d4b6697b8 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -1,5 +1,6 @@ import functools import itertools +from typing import Literal import numpy as np import pytest @@ -514,75 +515,133 @@ def test_expm_grad_3(): utt.verify_grad(expm, [A], rng=rng) -def test_solve_discrete_lyapunov_via_direct_real(): - N = 5 +def recover_Q(A, X, continuous=True): + if continuous: + return A @ X + X @ A.conj().T + else: + return X - A @ X @ A.conj().T + + +vec_recover_Q = np.vectorize(recover_Q, signature="(m,m),(m,m),()->(m,m)") + + +@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"]) +@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batch"]) +@pytest.mark.parametrize("method", ["direct", "bilinear"]) +def test_solve_discrete_lyapunov( + use_complex, shape: tuple[int], method: Literal["direct", "bilinear"] +): rng = np.random.default_rng(utt.fetch_seed()) - a = pt.dmatrix("a") - q = pt.dmatrix("q") - f = function([a, q], [solve_discrete_lyapunov(a, q, method="direct")]) + dtype = config.floatX + if use_complex: + precision = int(dtype[-2:]) # 64 or 32 + dtype = f"complex{int(2 * precision)}" + + A1, A2 = rng.normal(size=(2, *shape)) + Q1, Q2 = rng.normal(size=(2, *shape)) + + if use_complex: + A = A1 + 1j * A2 + Q = Q1 + 1j * Q2 + else: + A = A1 + Q = Q1 + + A, Q = A.astype(dtype), Q.astype(dtype) + + a = 
pt.tensor(name="a", shape=shape, dtype=dtype) + q = pt.tensor(name="q", shape=shape, dtype=dtype) - A = rng.normal(size=(N, N)) - Q = rng.normal(size=(N, N)) + x = solve_discrete_lyapunov(a, q, method=method) + f = function([a, q], x) X = f(A, Q) - assert np.allclose(A @ X @ A.T - X + Q, 0.0) + Q_recovered = vec_recover_Q(A, X, continuous=False) - utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng) + atol = rtol = 1e-4 if config.floatX == "float32" else 1e-8 + np.testing.assert_allclose(Q_recovered, Q, atol=atol, rtol=rtol) -@pytest.mark.filterwarnings("ignore::UserWarning") -def test_solve_discrete_lyapunov_via_direct_complex(): - # Conj doesn't have C-op; filter the warning. +@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"]) +@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batch"]) +@pytest.mark.parametrize("method", ["direct", "bilinear"]) +def test_solve_discrete_lyapunov_gradient( + use_complex, shape: tuple[int], method: Literal["direct", "bilinear"] +): + if config.floatX == "float32": + pytest.skip(reason="Not enough precision in float32 to get a good gradient") + if use_complex: + pytest.skip(reason="Complex numbers are not supported in the gradient test") - N = 5 rng = np.random.default_rng(utt.fetch_seed()) - a = pt.zmatrix() - q = pt.zmatrix() - f = function([a, q], [solve_discrete_lyapunov(a, q, method="direct")]) + A = rng.normal(size=shape).astype(config.floatX) + Q = rng.normal(size=shape).astype(config.floatX) - A = rng.normal(size=(N, N)) + rng.normal(size=(N, N)) * 1j - Q = rng.normal(size=(N, N)) - X = f(A, Q) - np.testing.assert_array_less(A @ X @ A.conj().T - X + Q, 1e-12) - - # TODO: the .conj() method currently does not have a gradient; add this test when gradients are implemented. 
- # utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng) + utt.verify_grad( + functools.partial(solve_discrete_lyapunov, method=method), + pt=[A, Q], + rng=rng, + ) -def test_solve_discrete_lyapunov_via_bilinear(): - N = 5 +@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"]) +@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"]) +def test_solve_continuous_lyapunov(shape: tuple[int], use_complex: bool): + dtype = config.floatX + if use_complex and dtype == "float32": + pytest.skip( + "Not enough precision in complex64 to do schur decomposition " + "(ill-conditioned matrix errors arise)" + ) rng = np.random.default_rng(utt.fetch_seed()) - a = pt.dmatrix() - q = pt.dmatrix() - f = function([a, q], [solve_discrete_lyapunov(a, q, method="bilinear")]) - A = rng.normal(size=(N, N)) - Q = rng.normal(size=(N, N)) + if use_complex: + precision = int(dtype[-2:]) # 64 or 32 + dtype = f"complex{int(2 * precision)}" - X = f(A, Q) + A1, A2 = rng.normal(size=(2, *shape)) + Q1, Q2 = rng.normal(size=(2, *shape)) - np.testing.assert_array_less(A @ X @ A.conj().T - X + Q, 1e-12) - utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng) + if use_complex: + A = A1 + 1j * A2 + Q = Q1 + 1j * Q2 + else: + A = A1 + Q = Q1 + A, Q = A.astype(dtype), Q.astype(dtype) -def test_solve_continuous_lyapunov(): - N = 5 - rng = np.random.default_rng(utt.fetch_seed()) - a = pt.dmatrix() - q = pt.dmatrix() - f = function([a, q], [solve_continuous_lyapunov(a, q)]) + a = pt.tensor(name="a", shape=shape, dtype=dtype) + q = pt.tensor(name="q", shape=shape, dtype=dtype) + x = solve_continuous_lyapunov(a, q) + + f = function([a, q], x) - A = rng.normal(size=(N, N)) - Q = rng.normal(size=(N, N)) X = f(A, Q) - Q_recovered = A @ X + X @ A.conj().T + Q_recovered = vec_recover_Q(A, X, continuous=True) + + atol = rtol = 1e-2 if config.floatX == "float32" else 1e-8 + np.testing.assert_allclose(Q_recovered.squeeze(), Q, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"]) +@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"]) +def test_solve_continuous_lyapunov_grad(shape: tuple[int], use_complex): + if config.floatX == "float32": + pytest.skip(reason="Not enough precision in float32 to get a good gradient") + if use_complex: + pytest.skip(reason="Complex numbers are not supported in the gradient test") + + rng = np.random.default_rng(utt.fetch_seed()) + A = rng.normal(size=shape).astype(config.floatX) + Q = rng.normal(size=shape).astype(config.floatX) - np.testing.assert_allclose(Q_recovered.squeeze(), Q) utt.verify_grad(solve_continuous_lyapunov, pt=[A, Q], rng=rng) -def test_solve_discrete_are_forward(): +@pytest.mark.parametrize("add_batch_dim", [False, True]) +def test_solve_discrete_are_forward(add_batch_dim): # TEST CASE 4 : darex #1 -- taken from Scipy tests a, b, q, r = ( np.array([[4, 3], [-4.5, -3.5]]), @@ -590,29 +649,39 @@ def test_solve_discrete_are_forward(): np.array([[9, 6], [6, 4]]), np.array([[1]]), ) - a, b, q, r = (x.astype(config.floatX) for x in [a, b, q, r]) + if add_batch_dim: + a, b, q, r = (np.stack([x] * 5) for x in [a, b, q, r]) - x = solve_discrete_are(a, b, q, r).eval() - res = a.T.dot(x.dot(a)) - x + q - res -= ( - a.conj() - .T.dot(x.dot(b)) - .dot(np.linalg.solve(r + b.conj().T.dot(x.dot(b)), b.T).dot(x.dot(a))) - ) + a, b, q, r = (pt.as_tensor_variable(x).astype(config.floatX) for x in [a, b, q, r]) + + x = solve_discrete_are(a, b, q, r) + + def eval_fun(a, 
b, q, r, x): + term_1 = a.T @ x @ a + term_2 = a.T @ x @ b + term_3 = pt.linalg.solve(r + b.T @ x @ b, b.T) @ x @ a + + return term_1 - x - term_2 @ term_3 + q + + res = pt.vectorize(eval_fun, "(m,m),(m,n),(m,m),(n,n),(m,m)->(m,m)")(a, b, q, r, x) + res_np = res.eval() atol = 1e-4 if config.floatX == "float32" else 1e-12 - np.testing.assert_allclose(res, np.zeros_like(res), atol=atol) + np.testing.assert_allclose(res_np, np.zeros_like(res_np), atol=atol) -def test_solve_discrete_are_grad(): +@pytest.mark.parametrize("add_batch_dim", [False, True]) +def test_solve_discrete_are_grad(add_batch_dim): a, b, q, r = ( np.array([[4, 3], [-4.5, -3.5]]), np.array([[1], [-1]]), np.array([[9, 6], [6, 4]]), np.array([[1]]), ) - a, b, q, r = (x.astype(config.floatX) for x in [a, b, q, r]) + if add_batch_dim: + a, b, q, r = (np.stack([x] * 5) for x in [a, b, q, r]) + a, b, q, r = (x.astype(config.floatX) for x in [a, b, q, r]) rng = np.random.default_rng(utt.fetch_seed()) # TODO: Is there a "theoretically motivated" value to use here? I pulled 1e-4 out of a hat
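Usage note (illustrative, not part of the patch): a minimal sketch of how the updated `solve_discrete_lyapunov` entry point might be called after this change, assuming the public API shown in the diff (the default method is now "bilinear", and the Blockwise-wrapped Ops broadcast over leading batch dimensions). The shapes, variable names, and scaling below are hypothetical.

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import solve_discrete_lyapunov

# Batched inputs: leading dimensions are handled by the Blockwise-wrapped Op.
A = pt.tensor(name="A", shape=(10, 5, 5))
Q = pt.tensor(name="Q", shape=(10, 5, 5))

# method defaults to "bilinear" after this change; "direct" remains available
# and is what the JAX rewrite in pytensor/tensor/rewriting/linalg.py falls back to.
X = solve_discrete_lyapunov(A, Q, method="bilinear")

fn = pytensor.function([A, Q], X)

# Scale A so its spectral radius stays well below 1, ensuring a unique solution.
A_val = 0.1 * np.random.normal(size=(10, 5, 5)).astype(pytensor.config.floatX)
Q_val = np.random.normal(size=(10, 5, 5)).astype(pytensor.config.floatX)
print(fn(A_val, Q_val).shape)  # (10, 5, 5)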