From 2ad8a575075de5081adb87204d829140d87b7d39 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Sun, 15 Sep 2024 21:40:43 -0700 Subject: [PATCH] Transform helpers cleanup (#285) * Start transforms cleanup * Add back `sample_uniform()` * Add basic transform tests * Ensure dtype consistency * More tests, GH action * Add pytest dependency * Add hypothesis * Minor fixes --- .github/workflows/pytest.yml | 28 +++++ pyproject.toml | 2 + src/viser/transforms/_base.py | 54 +++++--- src/viser/transforms/_se2.py | 87 ++++++------- src/viser/transforms/_se3.py | 54 ++++---- src/viser/transforms/_so2.py | 51 ++++---- src/viser/transforms/_so3.py | 155 +++++++++++------------ src/viser/transforms/utils/__init__.py | 5 +- src/viser/transforms/utils/_utils.py | 25 +--- tests/test_transforms_axioms.py | 103 +++++++++++++++ tests/test_transforms_bijective.py | 167 +++++++++++++++++++++++++ tests/test_transforms_ops.py | 128 +++++++++++++++++++ tests/utils.py | 139 ++++++++++++++++++++ 13 files changed, 782 insertions(+), 216 deletions(-) create mode 100644 .github/workflows/pytest.yml create mode 100644 tests/test_transforms_axioms.py create mode 100644 tests/test_transforms_bijective.py create mode 100644 tests/test_transforms_ops.py create mode 100644 tests/utils.py diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml new file mode 100644 index 000000000..d34aafd47 --- /dev/null +++ b/.github/workflows/pytest.yml @@ -0,0 +1,28 @@ +name: pytest + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.8", "3.8", "3.9", "3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + pip install uv + uv pip install --system -e ".[dev,examples]" + - name: Test with pytest + run: | + pytest diff --git a/pyproject.toml b/pyproject.toml index d82a2478a..7dd76da91 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,8 @@ dev = [ "pyright>=1.1.308", "ruff==0.6.2", "pre-commit==3.3.2", + "pytest", + "hypothesis[numpy]", ] examples = [ "torch>=1.13.1", diff --git a/src/viser/transforms/_base.py b/src/viser/transforms/_base.py index f3660a095..523a926cb 100644 --- a/src/viser/transforms/_base.py +++ b/src/viser/transforms/_base.py @@ -9,9 +9,6 @@ class MatrixLieGroup(abc.ABC): """Interface definition for matrix Lie groups.""" - # Class properties. - # > These will be set in `_utils.register_lie_group()`. - matrix_dim: ClassVar[int] """Dimension of square matrix output from `.as_matrix()`.""" @@ -36,6 +33,19 @@ def __init__( """Construct a group object from its underlying parameters.""" raise NotImplementedError() + def __init_subclass__( + cls, + matrix_dim: int = 0, + parameters_dim: int = 0, + tangent_dim: int = 0, + space_dim: int = 0, + ) -> None: + """Set class properties for subclasses. We default to dummy values.""" + cls.matrix_dim = matrix_dim + cls.parameters_dim = parameters_dim + cls.tangent_dim = tangent_dim + cls.space_dim = space_dim + # Shared implementations. @overload @@ -66,11 +76,14 @@ def __matmul__( @classmethod @abc.abstractmethod - def identity(cls, batch_axes: Tuple[int, ...] = ()) -> Self: + def identity( + cls, batch_axes: Tuple[int, ...] = (), dtype: onpt.DTypeLike = onp.float64 + ) -> Self: """Returns identity element. Args: batch_axes: Any leading batch axes for the output transform. + dtype: Datatype for the output. Returns: Identity element. @@ -172,20 +185,25 @@ def normalize(self) -> Self: Normalized group member. """ - # @classmethod - # @abc.abstractmethod - # def sample_uniform(cls, key: onp.ndarray, batch_axes: Tuple[int, ...] = ()) -> Self: - # """Draw a uniform sample from the group. Translations (if applicable) are in the - # range [-1, 1]. - # - # Args: - # key: PRNG key, as returned by `jax.random.PRNGKey()`. - # batch_axes: Any leading batch axes for the output transforms. Each - # sampled transform will be different. - # - # Returns: - # Sampled group member. - # """ + @classmethod + @abc.abstractmethod + def sample_uniform( + cls, + rng: onp.random.Generator, + batch_axes: Tuple[int, ...] = (), + dtype: onpt.DTypeLike = onp.float64, + ) -> Self: + """Draw a uniform sample from the group. Translations (if applicable) are in the + range [-1, 1]. + + Args: + rng: numpy generator object. + batch_axes: Any leading batch axes for the output transforms. Each + sampled transform will be different. + + Returns: + Sampled group member. + """ @final def get_batch_axes(self) -> Tuple[int, ...]: diff --git a/src/viser/transforms/_se2.py b/src/viser/transforms/_se2.py index 8952684f9..e68b4208e 100644 --- a/src/viser/transforms/_se2.py +++ b/src/viser/transforms/_se2.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dataclasses from typing import Tuple, cast @@ -7,17 +9,17 @@ from . import _base, hints from ._so2 import SO2 -from .utils import broadcast_leading_axes, get_epsilon, register_lie_group +from .utils import broadcast_leading_axes, get_epsilon -@register_lie_group( +@dataclasses.dataclass(frozen=True) +class SE2( + _base.SEBase[SO2], matrix_dim=3, parameters_dim=4, tangent_dim=3, space_dim=2, -) -@dataclasses.dataclass(frozen=True) -class SE2(_base.SEBase[SO2]): +): """Special Euclidean group for proper rigid transforms in 2D. Broadcasting rules are the same as for numpy. @@ -39,7 +41,7 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}(unit_complex={unit_complex}, xy={xy})" @staticmethod - def from_xy_theta(x: hints.Scalar, y: hints.Scalar, theta: hints.Scalar) -> "SE2": + def from_xy_theta(x: hints.Scalar, y: hints.Scalar, theta: hints.Scalar) -> SE2: """Construct a transformation from standard 2D pose parameters. This is not the same as integrating over a length-3 twist. @@ -56,7 +58,7 @@ def from_rotation_and_translation( cls, rotation: SO2, translation: onpt.NDArray[onp.floating], - ) -> "SE2": + ) -> SE2: assert translation.shape[-1:] == (2,) rotation, translation = broadcast_leading_axes((rotation, translation)) return SE2( @@ -77,16 +79,18 @@ def translation(self) -> onpt.NDArray[onp.floating]: @classmethod @override - def identity(cls, batch_axes: Tuple[int, ...] = ()) -> "SE2": + def identity( + cls, batch_axes: Tuple[int, ...] = (), dtype: onpt.DTypeLike = onp.float64 + ) -> SE2: return SE2( unit_complex_xy=onp.broadcast_to( - onp.array([1.0, 0.0, 0.0, 0.0]), (*batch_axes, 4) + onp.array([1.0, 0.0, 0.0, 0.0], dtype=dtype), (*batch_axes, 4) ) ) @classmethod @override - def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> "SE2": + def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> SE2: assert matrix.shape[-2:] == (3, 3) or matrix.shape[-2:] == (2, 3) # Currently assumes bottom row is [0, 0, 1]. return SE2.from_rotation_and_translation( @@ -123,7 +127,7 @@ def as_matrix(self) -> onpt.NDArray[onp.floating]: @classmethod @override - def exp(cls, tangent: onpt.NDArray[onp.floating]) -> "SE2": + def exp(cls, tangent: onpt.NDArray[onp.floating]) -> SE2: # Reference: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se2.hpp#L558 # Also see: @@ -146,21 +150,15 @@ def exp(cls, tangent: onpt.NDArray[onp.floating]) -> "SE2": ) theta_sq = theta**2 - sin_over_theta = cast( - onp.ndarray, - onp.where( - use_taylor, - 1.0 - theta_sq / 6.0, - onp.sin(safe_theta) / safe_theta, - ), + sin_over_theta = onp.where( + use_taylor, + 1.0 - theta_sq / 6.0, + onp.sin(safe_theta) / safe_theta, ) - one_minus_cos_over_theta = cast( - onp.ndarray, - onp.where( - use_taylor, - 0.5 * theta - theta * theta_sq / 24.0, - (1.0 - onp.cos(safe_theta)) / safe_theta, - ), + one_minus_cos_over_theta = onp.where( + use_taylor, + 0.5 * theta - theta * theta_sq / 24.0, + (1.0 - onp.cos(safe_theta)) / safe_theta, ) V = onp.stack( @@ -172,9 +170,12 @@ def exp(cls, tangent: onpt.NDArray[onp.floating]) -> "SE2": ], axis=-1, ).reshape((*tangent.shape[:-1], 2, 2)) + return SE2.from_rotation_and_translation( rotation=SO2.from_radians(theta), - translation=onp.einsum("...ij,...j->...i", V, tangent[..., :2]), + translation=onp.einsum("...ij,...j->...i", V, tangent[..., :2]).astype( + tangent.dtype + ), ) @override @@ -224,10 +225,10 @@ def log(self) -> onpt.NDArray[onp.floating]: ], axis=-1, ) - return tangent + return tangent.astype(self.unit_complex_xy.dtype) @override - def adjoint(self: "SE2") -> onpt.NDArray[onp.floating]: + def adjoint(self: SE2) -> onpt.NDArray[onp.floating]: cos, sin, x, y = onp.moveaxis(self.unit_complex_xy, -1, 0) return onp.stack( [ @@ -244,21 +245,15 @@ def adjoint(self: "SE2") -> onpt.NDArray[onp.floating]: axis=-1, ).reshape((*self.get_batch_axes(), 3, 3)) - # @classmethod - # @override - # def sample_uniform( - # cls, key: onp.ndarray, batch_axes: jdc.Static[Tuple[int, ...]] = () - # ) -> "SE2": - # key0, key1 = jax.random.split(key) - # return SE2.from_rotation_and_translation( - # rotation=SO2.sample_uniform(key0, batch_axes=batch_axes), - # translation=jax.random.uniform( - # key=key1, - # shape=( - # *batch_axes, - # 2, - # ), - # minval=-1.0, - # maxval=1.0, - # ), - # ) + @classmethod + @override + def sample_uniform( + cls, + rng: onp.random.Generator, + batch_axes: Tuple[int, ...] = (), + dtype: onpt.DTypeLike = onp.float64, + ) -> SE2: + return SE2.from_rotation_and_translation( + SO2.sample_uniform(rng, batch_axes=batch_axes, dtype=dtype), + rng.uniform(low=-1.0, high=1.0, size=(*batch_axes, 2)).astype(dtype), + ) diff --git a/src/viser/transforms/_se3.py b/src/viser/transforms/_se3.py index 5ce5a3ea3..2bc0b187b 100644 --- a/src/viser/transforms/_se3.py +++ b/src/viser/transforms/_se3.py @@ -9,7 +9,7 @@ from . import _base from ._so3 import SO3 -from .utils import broadcast_leading_axes, get_epsilon, register_lie_group +from .utils import broadcast_leading_axes, get_epsilon def _skew(omega: onpt.NDArray[onp.floating]) -> onpt.NDArray[onp.floating]: @@ -23,14 +23,14 @@ def _skew(omega: onpt.NDArray[onp.floating]) -> onpt.NDArray[onp.floating]: ).reshape((*omega.shape[:-1], 3, 3)) -@register_lie_group( +@dataclasses.dataclass(frozen=True) +class SE3( + _base.SEBase[SO3], matrix_dim=4, parameters_dim=7, tangent_dim=6, space_dim=3, -) -@dataclasses.dataclass(frozen=True) -class SE3(_base.SEBase[SO3]): +): """Special Euclidean group for proper rigid transforms in 3D. Broadcasting rules are the same as for numpy. @@ -76,10 +76,13 @@ def translation(self) -> onpt.NDArray[onp.floating]: @classmethod @override - def identity(cls, batch_axes: Tuple[int, ...] = ()) -> SE3: + def identity( + cls, batch_axes: Tuple[int, ...] = (), dtype: onpt.DTypeLike = onp.float64 + ) -> SE3: return SE3( wxyz_xyz=onp.broadcast_to( - onp.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), (*batch_axes, 7) + onp.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=dtype), + (*batch_axes, 7), ) ) @@ -97,7 +100,7 @@ def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> SE3: @override def as_matrix(self) -> onpt.NDArray[onp.floating]: - out = onp.zeros((*self.get_batch_axes(), 4, 4)) + out = onp.zeros((*self.get_batch_axes(), 4, 4), dtype=self.wxyz_xyz.dtype) out[..., :3, :3] = self.rotation().as_matrix() out[..., :3, 3] = self.translation() out[..., 3, 3] = 1.0 @@ -154,7 +157,9 @@ def exp(cls, tangent: onpt.NDArray[onp.floating]) -> SE3: return SE3.from_rotation_and_translation( rotation=rotation, - translation=onp.einsum("...ij,...j->...i", V, tangent[..., :3]), + translation=onp.einsum("...ij,...j->...i", V, tangent[..., :3]).astype( + tangent.dtype + ), ) @override @@ -200,7 +205,7 @@ def log(self) -> onpt.NDArray[onp.floating]: ) return onp.concatenate( [onp.einsum("...ij,...j->...i", V_inv, self.translation()), omega], axis=-1 - ) + ).astype(self.wxyz_xyz.dtype) @override def adjoint(self) -> onpt.NDArray[onp.floating]: @@ -212,21 +217,24 @@ def adjoint(self) -> onpt.NDArray[onp.floating]: axis=-1, ), onp.concatenate( - [onp.zeros((*self.get_batch_axes(), 3, 3)), R], axis=-1 + [onp.zeros((*self.get_batch_axes(), 3, 3), dtype=R.dtype), R], + axis=-1, ), ], axis=-2, ) - # @classmethod - # @override - # def sample_uniform( - # cls, key: onp.ndarray, batch_axes: jdc.Static[Tuple[int, ...]] = () - # ) -> SE3: - # key0, key1 = jax.random.split(key) - # return SE3.from_rotation_and_translation( - # rotation=SO3.sample_uniform(key0, batch_axes=batch_axes), - # translation=jax.random.uniform( - # key=key1, shape=(*batch_axes, 3), minval=-1.0, maxval=1.0 - # ), - # ) + @classmethod + @override + def sample_uniform( + cls, + rng: onp.random.Generator, + batch_axes: Tuple[int, ...] = (), + dtype: onpt.DTypeLike = onp.float64, + ) -> SE3: + return SE3.from_rotation_and_translation( + rotation=SO3.sample_uniform(rng, batch_axes=batch_axes, dtype=dtype), + translation=rng.uniform(low=-1.0, high=1.0, size=(*batch_axes, 3)).astype( + dtype=dtype + ), + ) diff --git a/src/viser/transforms/_so2.py b/src/viser/transforms/_so2.py index db4421245..a6b9c5161 100644 --- a/src/viser/transforms/_so2.py +++ b/src/viser/transforms/_so2.py @@ -8,17 +8,17 @@ from typing_extensions import override from . import _base, hints -from .utils import broadcast_leading_axes, register_lie_group +from .utils import broadcast_leading_axes -@register_lie_group( +@dataclasses.dataclass(frozen=True) +class SO2( + _base.SOBase, matrix_dim=2, parameters_dim=2, tangent_dim=1, space_dim=2, -) -@dataclasses.dataclass(frozen=True) -class SO2(_base.SOBase): +): """Special orthogonal group for 2D rotations. Broadcasting rules are the same as for numpy. @@ -53,10 +53,13 @@ def as_radians(self) -> onpt.NDArray[onp.floating]: @classmethod @override - def identity(cls, batch_axes: Tuple[int, ...] = ()) -> SO2: + def identity( + cls, batch_axes: Tuple[int, ...] = (), dtype: onpt.DTypeLike = onp.float64 + ) -> SO2: return SO2( unit_complex=onp.stack( - [onp.ones(batch_axes), onp.zeros(batch_axes)], axis=-1 + [onp.ones(batch_axes, dtype=dtype), onp.zeros(batch_axes, dtype=dtype)], + axis=-1, ) ) @@ -64,7 +67,7 @@ def identity(cls, batch_axes: Tuple[int, ...] = ()) -> SO2: @override def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> SO2: assert matrix.shape[-2:] == (2, 2) - return SO2(unit_complex=onp.asarray(matrix[..., :, 0])) + return SO2(unit_complex=onp.array(matrix[..., :, 0])) # Accessors. @@ -74,7 +77,7 @@ def as_matrix(self) -> onpt.NDArray[onp.floating]: out = onp.stack( [ # [cos, -sin], - cos_sin * onp.array([1, -1]), + cos_sin * onp.array([1, -1], dtype=cos_sin.dtype), # [sin, cos], cos_sin[..., ::-1], ], @@ -119,11 +122,13 @@ def log(self) -> onpt.NDArray[onp.floating]: @override def adjoint(self) -> onpt.NDArray[onp.floating]: - return onp.ones((*self.get_batch_axes(), 1, 1)) + return onp.ones((*self.get_batch_axes(), 1, 1), dtype=self.unit_complex.dtype) @override def inverse(self) -> SO2: - return SO2(unit_complex=self.unit_complex * onp.array([1, -1])) + unit_complex = self.unit_complex.copy() + unit_complex[..., 1] *= -1 + return SO2(unit_complex) @override def normalize(self) -> SO2: @@ -132,14 +137,16 @@ def normalize(self) -> SO2: / onp.linalg.norm(self.unit_complex, axis=-1, keepdims=True) ) - # @classmethod - # @override - # def sample_uniform( - # cls, key: onp.ndarray, batch_axes: jdc.Static[Tuple[int, ...]] = () - # ) -> SO2: - # out = SO2.from_radians( - # jax.random.uniform( - # key=key, shape=batch_axes, minval=0.0, maxval=2.0 * onp.pi) - # ) - # assert out.get_batch_axes() == batch_axes - # return out + @classmethod + @override + def sample_uniform( + cls, + rng: onp.random.Generator, + batch_axes: Tuple[int, ...] = (), + dtype: onpt.DTypeLike = onp.float64, + ) -> SO2: + out = SO2.from_radians( + rng.uniform(0.0, 2.0 * onp.pi, size=batch_axes).astype(dtype=dtype) + ) + assert out.get_batch_axes() == batch_axes + return out diff --git a/src/viser/transforms/_so3.py b/src/viser/transforms/_so3.py index c1436604f..c8a7deb63 100644 --- a/src/viser/transforms/_so3.py +++ b/src/viser/transforms/_so3.py @@ -8,7 +8,7 @@ from typing_extensions import override from . import _base, hints -from .utils import broadcast_leading_axes, get_epsilon, register_lie_group +from .utils import broadcast_leading_axes, get_epsilon @dataclasses.dataclass(frozen=True) @@ -20,14 +20,14 @@ class RollPitchYaw: yaw: onpt.NDArray[onp.floating] -@register_lie_group( +@dataclasses.dataclass(frozen=True) +class SO3( + _base.SOBase, matrix_dim=3, parameters_dim=4, tangent_dim=3, space_dim=3, -) -@dataclasses.dataclass(frozen=True) -class SO3(_base.SOBase): +): """Special orthogonal group for 3D rotations. Broadcasting rules are the same as for numpy. @@ -173,9 +173,13 @@ def compute_yaw_radians(self) -> onpt.NDArray[onp.floating]: @classmethod @override - def identity(cls, batch_axes: Tuple[int, ...] = ()) -> SO3: + def identity( + cls, batch_axes: Tuple[int, ...] = (), dtype: onpt.DTypeLike = onp.float64 + ) -> SO3: return SO3( - wxyz=onp.broadcast_to(onp.array([1.0, 0.0, 0.0, 0.0]), (*batch_axes, 4)) + wxyz=onp.broadcast_to( + onp.array([1.0, 0.0, 0.0, 0.0], dtype=dtype), (*batch_axes, 4) + ) ) @classmethod @@ -260,26 +264,7 @@ def case3(m): onp.where(cond1[..., None], case0_q, case1_q), onp.where(cond2[..., None], case2_q, case3_q), ) - - # We can also choose to branch, but this is slower. - # t, q = jax.lax.cond( - # matrix[2, 2] < 0, - # true_fun=lambda matrix: jax.lax.cond( - # matrix[0, 0] > matrix[1, 1], - # true_fun=case0, - # false_fun=case1, - # operand=matrix, - # ), - # false_fun=lambda matrix: jax.lax.cond( - # matrix[0, 0] < -matrix[1, 1], - # true_fun=case2, - # false_fun=case3, - # operand=matrix, - # ), - # operand=matrix, - # ) - - return SO3(wxyz=q * 0.5 / onp.sqrt(t[..., None])) + return SO3(wxyz=(q * 0.5 / onp.sqrt(t[..., None])).astype(matrix.dtype)) # Accessors. @@ -288,20 +273,24 @@ def as_matrix(self) -> onpt.NDArray[onp.floating]: norm_sq = onp.sum(onp.square(self.wxyz), axis=-1, keepdims=True) q = self.wxyz * onp.sqrt(2.0 / norm_sq) # (*, 4) q_outer = onp.einsum("...i,...j->...ij", q, q) # (*, 4, 4) - return onp.stack( - [ - 1.0 - q_outer[..., 2, 2] - q_outer[..., 3, 3], - q_outer[..., 1, 2] - q_outer[..., 3, 0], - q_outer[..., 1, 3] + q_outer[..., 2, 0], - q_outer[..., 1, 2] + q_outer[..., 3, 0], - 1.0 - q_outer[..., 1, 1] - q_outer[..., 3, 3], - q_outer[..., 2, 3] - q_outer[..., 1, 0], - q_outer[..., 1, 3] - q_outer[..., 2, 0], - q_outer[..., 2, 3] + q_outer[..., 1, 0], - 1.0 - q_outer[..., 1, 1] - q_outer[..., 2, 2], - ], - axis=-1, - ).reshape(*q.shape[:-1], 3, 3) + return ( + onp.stack( + [ + 1.0 - q_outer[..., 2, 2] - q_outer[..., 3, 3], + q_outer[..., 1, 2] - q_outer[..., 3, 0], + q_outer[..., 1, 3] + q_outer[..., 2, 0], + q_outer[..., 1, 2] + q_outer[..., 3, 0], + 1.0 - q_outer[..., 1, 1] - q_outer[..., 3, 3], + q_outer[..., 2, 3] - q_outer[..., 1, 0], + q_outer[..., 1, 3] - q_outer[..., 2, 0], + q_outer[..., 2, 3] + q_outer[..., 1, 0], + 1.0 - q_outer[..., 1, 1] - q_outer[..., 2, 2], + ], + axis=-1, + ) + .reshape(*q.shape[:-1], 3, 3) + .astype(self.wxyz.dtype) + ) @override def parameters(self) -> onpt.NDArray[onp.floating]: @@ -316,7 +305,8 @@ def apply(self, target: onpt.NDArray[onp.floating]) -> onpt.NDArray[onp.floating # Compute using quaternion multiplys. padded_target = onp.concatenate( - [onp.zeros((*self.get_batch_axes(), 1)), target], axis=-1 + [onp.zeros((*self.get_batch_axes(), 1), dtype=target.dtype), target], + axis=-1, ) return (self @ SO3(wxyz=padded_target) @ self.inverse()).wxyz[..., 1:] @@ -357,14 +347,16 @@ def exp(cls, tangent: onpt.NDArray[onp.floating]) -> SO3: theta_squared, ) ) - safe_half_theta = 0.5 * safe_theta + # Fun fact: when safe_theta is a `float32` _scalar_, this + # multiplication will promote `safe_half_theta` to `float64`. We'll + # cast at the end to make sure our input/output dtypes match. + safe_half_theta = 0.5 * safe_theta real_factor = onp.where( use_taylor, 1.0 - theta_squared / 8.0 + theta_pow_4 / 384.0, onp.cos(safe_half_theta), ) - imaginary_factor = onp.where( use_taylor, 0.5 - theta_squared / 48.0 + theta_pow_4 / 3840.0, @@ -378,7 +370,7 @@ def exp(cls, tangent: onpt.NDArray[onp.floating]) -> SO3: imaginary_factor[..., None] * tangent, ], axis=-1, - ) + ).astype(tangent.dtype) ) @override @@ -409,12 +401,11 @@ def log(self) -> onpt.NDArray[onp.floating]: 2.0 / w_safe - 2.0 / 3.0 * norm_sq / w_safe**3, onp.where( onp.abs(w) < get_epsilon(w.dtype), - onp.where(w > 0, 1.0, -1.0) * onp.pi / norm_safe, + onp.where(w > 0, 1.0, -1.0).astype(dtype=w.dtype) * onp.pi / norm_safe, 2.0 * atan_n_over_w / norm_safe, ), ) - - return atan_factor[..., None] * self.wxyz[..., 1:] # type: ignore + return (atan_factor[..., None] * self.wxyz[..., 1:]).astype(self.wxyz.dtype) @override def adjoint(self) -> onpt.NDArray[onp.floating]: @@ -423,40 +414,44 @@ def adjoint(self) -> onpt.NDArray[onp.floating]: @override def inverse(self) -> SO3: # Negate complex terms. - return SO3(wxyz=self.wxyz * onp.array([1, -1, -1, -1])) + wxyz = self.wxyz.copy() + wxyz[..., 1:] *= -1 + return SO3(wxyz) @override def normalize(self) -> SO3: return SO3(wxyz=self.wxyz / onp.linalg.norm(self.wxyz, axis=-1, keepdims=True)) - # @classmethod - # @override - # def sample_uniform( - # cls, key: onp.ndarray, batch_axes: jdc.Static[Tuple[int, ...]] = () - # ) -> SO3: - # # Uniformly sample over S^3. - # # > Reference: http://planning.cs.uiuc.edu/node198.html - # u1, u2, u3 = onp.moveaxis( - # jax.random.uniform( - # key=key, - # shape=(*batch_axes, 3), - # minval=onp.zeros(3), - # maxval=onp.array([1.0, 2.0 * onp.pi, 2.0 * onp.pi]), - # ), - # -1, - # 0, - # ) - # a = onp.sqrt(1.0 - u1) - # b = onp.sqrt(u1) - # - # return SO3( - # wxyz=onp.stack( - # [ - # a * onp.sin(u2), - # a * onp.cos(u2), - # b * onp.sin(u3), - # b * onp.cos(u3), - # ], - # axis=-1, - # ) - # ) + @classmethod + @override + def sample_uniform( + cls, + rng: onp.random.Generator, + batch_axes: Tuple[int, ...] = (), + dtype: onpt.DTypeLike = onp.float64, + ) -> SO3: + # Uniformly sample over S^3. + # > Reference: http://planning.cs.uiuc.edu/node198.html + u1, u2, u3 = onp.moveaxis( + rng.uniform( + low=onp.zeros(3), + high=onp.array([1.0, 2.0 * onp.pi, 2.0 * onp.pi]), + size=(*batch_axes, 3), + ), + -1, + 0, + ) + a = onp.sqrt(1.0 - u1) + b = onp.sqrt(u1) + + return SO3( + wxyz=onp.stack( + [ + a * onp.sin(u2), + a * onp.cos(u2), + b * onp.sin(u3), + b * onp.cos(u3), + ], + axis=-1, + ).astype(dtype=dtype) + ) diff --git a/src/viser/transforms/utils/__init__.py b/src/viser/transforms/utils/__init__.py index 11980074b..3ecb41d21 100644 --- a/src/viser/transforms/utils/__init__.py +++ b/src/viser/transforms/utils/__init__.py @@ -1,3 +1,2 @@ -from ._utils import broadcast_leading_axes, get_epsilon, register_lie_group - -__all__ = ["get_epsilon", "register_lie_group", "broadcast_leading_axes"] +from ._utils import broadcast_leading_axes as broadcast_leading_axes +from ._utils import get_epsilon as get_epsilon diff --git a/src/viser/transforms/utils/_utils.py b/src/viser/transforms/utils/_utils.py index a25e0ac10..8ec773729 100644 --- a/src/viser/transforms/utils/_utils.py +++ b/src/viser/transforms/utils/_utils.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Callable, Tuple, Type, TypeVar, Union, cast +from typing import TYPE_CHECKING, Tuple, TypeVar, Union, cast import numpy as onp @@ -26,29 +26,6 @@ def get_epsilon(dtype: onp.dtype) -> float: assert False -def register_lie_group( - *, - matrix_dim: int, - parameters_dim: int, - tangent_dim: int, - space_dim: int, -) -> Callable[[Type[T]], Type[T]]: - """Decorator for registering Lie group dataclasses. - - Sets dimensionality class variables, and marks all methods for JIT compilation. - """ - - def _wrap(cls: Type[T]) -> Type[T]: - # Register dimensions as class attributes. - cls.matrix_dim = matrix_dim - cls.parameters_dim = parameters_dim - cls.tangent_dim = tangent_dim - cls.space_dim = space_dim - return cls - - return _wrap - - TupleOfBroadcastable = TypeVar( "TupleOfBroadcastable", bound="Tuple[Union[MatrixLieGroup, onp.ndarray], ...]", diff --git a/tests/test_transforms_axioms.py b/tests/test_transforms_axioms.py new file mode 100644 index 000000000..8ec6d3398 --- /dev/null +++ b/tests/test_transforms_axioms.py @@ -0,0 +1,103 @@ +"""Tests for group axioms. + +https://proofwiki.org/wiki/Definition:Group_Axioms +""" + +from typing import Tuple, Type + +import numpy as onp +import numpy.typing as onpt +from utils import ( + assert_arrays_close, + assert_transforms_close, + general_group_test, + sample_transform, +) + +import viser.transforms as vtf + + +@general_group_test +def test_closure( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check closure property.""" + transform_a = sample_transform(Group, batch_axes, dtype) + transform_b = sample_transform(Group, batch_axes, dtype) + + composed = transform_a @ transform_b + assert_transforms_close(composed, composed.normalize()) + composed = transform_b @ transform_a + assert_transforms_close(composed, composed.normalize()) + composed = Group.multiply(transform_a, transform_b) + assert_transforms_close(composed, composed.normalize()) + composed = Group.multiply(transform_b, transform_a) + assert_transforms_close(composed, composed.normalize()) + + +@general_group_test +def test_identity( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check identity property.""" + transform = sample_transform(Group, batch_axes, dtype) + identity = Group.identity(batch_axes, dtype=dtype) + assert_transforms_close(transform, identity @ transform) + assert_transforms_close(transform, transform @ identity) + assert_arrays_close( + transform.as_matrix(), + onp.einsum("...ij,...jk->...ik", identity.as_matrix(), transform.as_matrix()), + ) + assert_arrays_close( + transform.as_matrix(), + onp.einsum("...ij,...jk->...ik", transform.as_matrix(), identity.as_matrix()), + ) + + +@general_group_test +def test_inverse( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check inverse property.""" + transform = sample_transform(Group, batch_axes, dtype) + identity = Group.identity(batch_axes, dtype=dtype) + assert_transforms_close(identity, transform @ transform.inverse()) + assert_transforms_close(identity, transform.inverse() @ transform) + assert_transforms_close(identity, Group.multiply(transform, transform.inverse())) + assert_transforms_close(identity, Group.multiply(transform.inverse(), transform)) + assert_arrays_close( + onp.broadcast_to( + onp.eye(Group.matrix_dim, dtype=dtype), + (*batch_axes, Group.matrix_dim, Group.matrix_dim), + ), + onp.einsum( + "...ij,...jk->...ik", + transform.as_matrix(), + transform.inverse().as_matrix(), + ), + ) + assert_arrays_close( + onp.broadcast_to( + onp.eye(Group.matrix_dim, dtype=dtype), + (*batch_axes, Group.matrix_dim, Group.matrix_dim), + ), + onp.einsum( + "...ij,...jk->...ik", + transform.inverse().as_matrix(), + transform.as_matrix(), + ), + ) + + +@general_group_test +def test_associative( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check associative property.""" + transform_a = sample_transform(Group, batch_axes, dtype) + transform_b = sample_transform(Group, batch_axes, dtype) + transform_c = sample_transform(Group, batch_axes, dtype) + assert_transforms_close( + (transform_a @ transform_b) @ transform_c, + transform_a @ (transform_b @ transform_c), + ) diff --git a/tests/test_transforms_bijective.py b/tests/test_transforms_bijective.py new file mode 100644 index 000000000..a9f0c835c --- /dev/null +++ b/tests/test_transforms_bijective.py @@ -0,0 +1,167 @@ +"""Tests for general operation definitions.""" + +from typing import Tuple, Type + +import numpy as onp +import numpy.typing as onpt +from hypothesis import given, settings +from hypothesis import strategies as st +from utils import ( + assert_arrays_close, + assert_transforms_close, + general_group_test, + sample_transform, +) + +import viser.transforms as vtf + + +@general_group_test +def test_sample_uniform_valid( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check that sample_uniform() returns valid group members.""" + T = sample_transform( + Group, batch_axes, dtype + ) # Calls sample_uniform under the hood. + assert_transforms_close(T, T.normalize()) + + +@settings(deadline=None) +@given(_random_module=st.random_module()) +def test_so2_from_to_radians_bijective(_random_module): + """Check that we can convert from and to radians.""" + radians = onp.random.uniform(low=-onp.pi, high=onp.pi) + assert_arrays_close(vtf.SO2.from_radians(radians).as_radians(), radians) + + +@settings(deadline=None) +@given(_random_module=st.random_module()) +def test_so3_xyzw_bijective(_random_module): + """Check that we can convert between xyzw and wxyz quaternions.""" + T = sample_transform(vtf.SO3, (), dtype=onp.float64) + assert_transforms_close(T, vtf.SO3.from_quaternion_xyzw(T.as_quaternion_xyzw())) + + +@settings(deadline=None) +@given(_random_module=st.random_module()) +def test_so3_rpy_bijective(_random_module): + """Check that we can convert between quaternions and Euler angles.""" + T = sample_transform(vtf.SO3, (), dtype=onp.float64) + rpy = T.as_rpy_radians() + assert_transforms_close(T, vtf.SO3.from_rpy_radians(rpy.roll, rpy.pitch, rpy.yaw)) + + +@general_group_test +def test_log_exp_bijective( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check 1-to-1 mapping for log <=> exp operations.""" + transform = sample_transform(Group, batch_axes, dtype) + + assert transform.parameters().dtype == dtype + tangent = transform.log() + assert tangent.dtype == dtype + assert tangent.shape == (*batch_axes, Group.tangent_dim) + + exp_transform = Group.exp(tangent) + assert_transforms_close(transform, exp_transform) + assert_arrays_close(tangent, exp_transform.log()) + + +@general_group_test +def test_inverse_bijective( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check inverse of inverse.""" + transform = sample_transform(Group, batch_axes, dtype) + assert_transforms_close(transform, transform.inverse().inverse()) + + +@general_group_test +def test_matrix_bijective( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check that we can convert to and from matrices.""" + transform = sample_transform(Group, batch_axes, dtype) + assert_transforms_close(transform, Group.from_matrix(transform.as_matrix())) + + +@general_group_test +def test_adjoint( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check adjoint definition.""" + transform = sample_transform(Group, batch_axes, dtype) + omega = onp.random.randn(*batch_axes, Group.tangent_dim).astype(dtype) + assert (transform @ Group.exp(omega)).parameters().dtype == dtype + assert_transforms_close( + transform @ Group.exp(omega), + Group.exp(onp.einsum("...ij,...j->...i", transform.adjoint(), omega)) + @ transform, + ) + + +@general_group_test +def test_repr( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Smoke test for __repr__ implementations.""" + transform = sample_transform(Group, batch_axes, dtype) + print(transform) + + +@general_group_test +def test_apply( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check group action interfaces.""" + T_w_b = sample_transform(Group, batch_axes, dtype) + p_b = onp.random.randn(*batch_axes, Group.space_dim).astype(dtype) + + if Group.matrix_dim == Group.space_dim: + assert_arrays_close( + T_w_b @ p_b, + T_w_b.apply(p_b), + onp.einsum("...ij,...j->...i", T_w_b.as_matrix(), p_b), + ) + else: + # Homogeneous coordinates. + assert Group.matrix_dim == Group.space_dim + 1 + assert_arrays_close( + T_w_b @ p_b, + T_w_b.apply(p_b), + onp.einsum( + "...ij,...j->...i", + T_w_b.as_matrix(), + onp.concatenate([p_b, onp.ones_like(p_b[..., :1])], axis=-1), + )[..., :-1], + ) + + +@general_group_test +def test_multiply( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check multiply interfaces.""" + T_w_b = sample_transform(Group, batch_axes, dtype) + T_b_a = sample_transform(Group, batch_axes, dtype) + assert_arrays_close( + onp.einsum( + "...ij,...jk->...ik", T_w_b.as_matrix(), T_w_b.inverse().as_matrix() + ), + onp.broadcast_to( + onp.eye(Group.matrix_dim, dtype=dtype), + (*batch_axes, Group.matrix_dim, Group.matrix_dim), + ), + ) + assert_arrays_close( + onp.einsum( + "...ij,...jk->...ik", T_w_b.as_matrix(), onp.linalg.inv(T_w_b.as_matrix()) + ), + onp.broadcast_to( + onp.eye(Group.matrix_dim, dtype=dtype), + (*batch_axes, Group.matrix_dim, Group.matrix_dim), + ), + ) + assert_transforms_close(T_w_b @ T_b_a, Group.multiply(T_w_b, T_b_a)) diff --git a/tests/test_transforms_ops.py b/tests/test_transforms_ops.py new file mode 100644 index 000000000..6321b8896 --- /dev/null +++ b/tests/test_transforms_ops.py @@ -0,0 +1,128 @@ +"""Tests for general operation definitions.""" + +from typing import Tuple, Type + +import numpy as onp +import numpy.typing as onpt +from utils import ( + assert_arrays_close, + assert_transforms_close, + general_group_test, + sample_transform, +) + +import viser.transforms as vtf + + +@general_group_test +def test_sample_uniform_valid( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check that sample_uniform() returns valid group members.""" + T = sample_transform( + Group, batch_axes, dtype + ) # Calls sample_uniform under the hood. + assert_transforms_close(T, T.normalize()) + + +@general_group_test +def test_log_exp_bijective( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check 1-to-1 mapping for log <=> exp operations.""" + transform = sample_transform(Group, batch_axes, dtype) + + tangent = transform.log() + assert tangent.shape == (*batch_axes, Group.tangent_dim) + + exp_transform = Group.exp(tangent) + assert_transforms_close(transform, exp_transform) + assert_arrays_close(tangent, exp_transform.log()) + + +@general_group_test +def test_inverse_bijective( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check inverse of inverse.""" + transform = sample_transform(Group, batch_axes, dtype) + assert_transforms_close(transform, transform.inverse().inverse()) + + +@general_group_test +def test_matrix_bijective( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check that we can convert to and from matrices.""" + transform = sample_transform(Group, batch_axes, dtype) + assert_transforms_close(transform, Group.from_matrix(transform.as_matrix())) + + +@general_group_test +def test_adjoint( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check adjoint definition.""" + transform = sample_transform(Group, batch_axes, dtype) + omega = onp.random.randn(*batch_axes, Group.tangent_dim).astype(dtype=dtype) + assert_transforms_close( + transform @ Group.exp(omega), + Group.exp(onp.einsum("...ij,...j->...i", transform.adjoint(), omega)) + @ transform, + ) + + +@general_group_test +def test_repr( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Smoke test for __repr__ implementations.""" + transform = sample_transform(Group, batch_axes, dtype) + print(transform) + + +@general_group_test +def test_apply( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check group action interfaces.""" + T_w_b = sample_transform(Group, batch_axes, dtype) + p_b = onp.random.randn(*batch_axes, Group.space_dim).astype(dtype) + + if Group.matrix_dim == Group.space_dim: + assert_arrays_close( + T_w_b @ p_b, + T_w_b.apply(p_b), + onp.einsum("...ij,...j->...i", T_w_b.as_matrix(), p_b), + ) + else: + # Homogeneous coordinates. + assert Group.matrix_dim == Group.space_dim + 1 + assert_arrays_close( + T_w_b @ p_b, + T_w_b.apply(p_b), + onp.einsum( + "...ij,...j->...i", + T_w_b.as_matrix(), + onp.concatenate([p_b, onp.ones_like(p_b[..., :1])], axis=-1), + )[..., :-1], + ) + + +@general_group_test +def test_multiply( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check multiply interfaces.""" + T_w_b = sample_transform(Group, batch_axes, dtype) + T_b_a = sample_transform(Group, batch_axes, dtype) + assert_arrays_close( + onp.einsum( + "...ij,...jk->...ik", T_w_b.as_matrix(), onp.linalg.inv(T_w_b.as_matrix()) + ), + onp.broadcast_to( + onp.eye(Group.matrix_dim, dtype=dtype), + (*batch_axes, Group.matrix_dim, Group.matrix_dim), + ), + ) + assert_transforms_close(T_w_b @ T_b_a, Group.multiply(T_w_b, T_b_a)) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 000000000..ed7346e1e --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,139 @@ +import functools +import random +from typing import Any, Callable, Tuple, Type, TypeVar, Union, cast + +import numpy as onp +import numpy.typing as onpt +import pytest +from hypothesis import given, settings +from hypothesis import strategies as st + +import viser.transforms as vtf + +T = TypeVar("T", bound=vtf.MatrixLieGroup) + + +def sample_transform( + Group: Type[T], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +) -> T: + """Sample a random transform from a group.""" + seed = random.getrandbits(32) + strategy = random.randint(0, 2) + + if strategy == 0: + # Uniform sampling. + return cast( + T, + Group.sample_uniform( + onp.random.default_rng(seed), batch_axes=batch_axes, dtype=dtype + ), + ) + elif strategy == 1: + # Sample from normally-sampled tangent vector. + return cast( + T, + Group.exp( + onp.random.randn(*batch_axes, Group.tangent_dim).astype(dtype=dtype) + ), + ) + elif strategy == 2: + # Sample near identity. + return cast( + T, + Group.exp( + onp.random.randn(*batch_axes, Group.tangent_dim).astype(dtype=dtype) + * 1e-7 + ), + ) + else: + assert False + + +def general_group_test( + f: Callable[[Type[vtf.MatrixLieGroup], Tuple[int, ...], onpt.DTypeLike], None], + max_examples: int = 15, +) -> Callable[[Type[vtf.MatrixLieGroup], Tuple[int, ...], onpt.DTypeLike, Any], None]: + """Decorator for defining tests that run on all group types.""" + + # Disregard unused argument. + def f_wrapped( + Group: Type[vtf.MatrixLieGroup], + batch_axes: Tuple[int, ...], + dtype: onpt.DTypeLike, + _random_module, + ) -> None: + f(Group, batch_axes, dtype) + + # Disable timing check (first run requires JIT tracing and will be slower). + f_wrapped = settings(deadline=None, max_examples=max_examples)(f_wrapped) + + # Add _random_module parameter. + f_wrapped = given(_random_module=st.random_module())(f_wrapped) + + # Parametrize tests with each group type. + f_wrapped = pytest.mark.parametrize( + "Group", + [ + vtf.SO2, + vtf.SE2, + vtf.SO3, + vtf.SE3, + ], + )(f_wrapped) + + # Parametrize tests with each group type. + f_wrapped = pytest.mark.parametrize( + "batch_axes", + [ + (), + (1,), + (3, 1, 2, 1), + ], + )(f_wrapped) + + # Parametrize tests with each group type. + f_wrapped = pytest.mark.parametrize( + "dtype", + [onp.float32, onp.float64], + )(f_wrapped) + return f_wrapped + + +general_group_test_faster = functools.partial(general_group_test, max_examples=5) + + +def assert_transforms_close(a: vtf.MatrixLieGroup, b: vtf.MatrixLieGroup): + """Make sure two transforms are equivalent.""" + # Check matrix representation. + assert_arrays_close(a.as_matrix(), b.as_matrix()) + + # Flip signs for quaternions. + # We use `jnp.asarray` here in case inputs are onp arrays and don't support `.at()`. + p1 = a.parameters().copy() + p2 = b.parameters().copy() + if isinstance(a, vtf.SO3): + p1 = p1 * onp.sign(onp.sum(p1, axis=-1, keepdims=True)) + p2 = p2 * onp.sign(onp.sum(p2, axis=-1, keepdims=True)) + elif isinstance(a, vtf.SE3): + p1[..., :4] *= onp.sign(onp.sum(p1[..., :4], axis=-1, keepdims=True)) + p2[..., :4] *= onp.sign(onp.sum(p2[..., :4], axis=-1, keepdims=True)) + + # Make sure parameters are equal. + assert_arrays_close(p1, p2) + + +def assert_arrays_close(*arrays: Union[onpt.NDArray[onp.float64], float]): + """Make sure two arrays are close. (and not NaN)""" + for array1, array2 in zip(arrays[:-1], arrays[1:]): + assert onp.asarray(array1).dtype == onp.asarray(array2).dtype + + if isinstance(array1, (float, int)) or array1.dtype == onp.float64: + rtol = 1e-7 + atol = 1e-7 + else: + rtol = 1e-3 + atol = 1e-3 + + onp.testing.assert_allclose(array1, array2, rtol=rtol, atol=atol) + assert not onp.any(onp.isnan(array1)) + assert not onp.any(onp.isnan(array2))