Skip to content

Commit

Permalink
Transform helpers cleanup (#285)
Browse files Browse the repository at this point in the history
* Start transforms cleanup

* Add back `sample_uniform()`

* Add basic transform tests

* Ensure dtype consistency

* More tests, GH action

* Add pytest dependency

* Add hypothesis

* Minor fixes
  • Loading branch information
brentyi committed Sep 16, 2024
1 parent a1eec96 commit 2ad8a57
Show file tree
Hide file tree
Showing 13 changed files with 782 additions and 216 deletions.
28 changes: 28 additions & 0 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
name: pytest

on:
push:
branches: [main]
pull_request:
branches: [main]

jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.8", "3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install uv
uv pip install --system -e ".[dev,examples]"
- name: Test with pytest
run: |
pytest
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ dev = [
"pyright>=1.1.308",
"ruff==0.6.2",
"pre-commit==3.3.2",
"pytest",
"hypothesis[numpy]",
]
examples = [
"torch>=1.13.1",
Expand Down
54 changes: 36 additions & 18 deletions src/viser/transforms/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@
class MatrixLieGroup(abc.ABC):
"""Interface definition for matrix Lie groups."""

# Class properties.
# > These will be set in `_utils.register_lie_group()`.

matrix_dim: ClassVar[int]
"""Dimension of square matrix output from `.as_matrix()`."""

Expand All @@ -36,6 +33,19 @@ def __init__(
"""Construct a group object from its underlying parameters."""
raise NotImplementedError()

def __init_subclass__(
cls,
matrix_dim: int = 0,
parameters_dim: int = 0,
tangent_dim: int = 0,
space_dim: int = 0,
) -> None:
"""Set class properties for subclasses. We default to dummy values."""
cls.matrix_dim = matrix_dim
cls.parameters_dim = parameters_dim
cls.tangent_dim = tangent_dim
cls.space_dim = space_dim

# Shared implementations.

@overload
Expand Down Expand Up @@ -66,11 +76,14 @@ def __matmul__(

@classmethod
@abc.abstractmethod
def identity(cls, batch_axes: Tuple[int, ...] = ()) -> Self:
def identity(
cls, batch_axes: Tuple[int, ...] = (), dtype: onpt.DTypeLike = onp.float64
) -> Self:
"""Returns identity element.
Args:
batch_axes: Any leading batch axes for the output transform.
dtype: Datatype for the output.
Returns:
Identity element.
Expand Down Expand Up @@ -172,20 +185,25 @@ def normalize(self) -> Self:
Normalized group member.
"""

# @classmethod
# @abc.abstractmethod
# def sample_uniform(cls, key: onp.ndarray, batch_axes: Tuple[int, ...] = ()) -> Self:
# """Draw a uniform sample from the group. Translations (if applicable) are in the
# range [-1, 1].
#
# Args:
# key: PRNG key, as returned by `jax.random.PRNGKey()`.
# batch_axes: Any leading batch axes for the output transforms. Each
# sampled transform will be different.
#
# Returns:
# Sampled group member.
# """
@classmethod
@abc.abstractmethod
def sample_uniform(
cls,
rng: onp.random.Generator,
batch_axes: Tuple[int, ...] = (),
dtype: onpt.DTypeLike = onp.float64,
) -> Self:
"""Draw a uniform sample from the group. Translations (if applicable) are in the
range [-1, 1].
Args:
rng: numpy generator object.
batch_axes: Any leading batch axes for the output transforms. Each
sampled transform will be different.
Returns:
Sampled group member.
"""

@final
def get_batch_axes(self) -> Tuple[int, ...]:
Expand Down
87 changes: 41 additions & 46 deletions src/viser/transforms/_se2.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import dataclasses
from typing import Tuple, cast

Expand All @@ -7,17 +9,17 @@

from . import _base, hints
from ._so2 import SO2
from .utils import broadcast_leading_axes, get_epsilon, register_lie_group
from .utils import broadcast_leading_axes, get_epsilon


@register_lie_group(
@dataclasses.dataclass(frozen=True)
class SE2(
_base.SEBase[SO2],
matrix_dim=3,
parameters_dim=4,
tangent_dim=3,
space_dim=2,
)
@dataclasses.dataclass(frozen=True)
class SE2(_base.SEBase[SO2]):
):
"""Special Euclidean group for proper rigid transforms in 2D. Broadcasting
rules are the same as for numpy.
Expand All @@ -39,7 +41,7 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}(unit_complex={unit_complex}, xy={xy})"

@staticmethod
def from_xy_theta(x: hints.Scalar, y: hints.Scalar, theta: hints.Scalar) -> "SE2":
def from_xy_theta(x: hints.Scalar, y: hints.Scalar, theta: hints.Scalar) -> SE2:
"""Construct a transformation from standard 2D pose parameters.
This is not the same as integrating over a length-3 twist.
Expand All @@ -56,7 +58,7 @@ def from_rotation_and_translation(
cls,
rotation: SO2,
translation: onpt.NDArray[onp.floating],
) -> "SE2":
) -> SE2:
assert translation.shape[-1:] == (2,)
rotation, translation = broadcast_leading_axes((rotation, translation))
return SE2(
Expand All @@ -77,16 +79,18 @@ def translation(self) -> onpt.NDArray[onp.floating]:

@classmethod
@override
def identity(cls, batch_axes: Tuple[int, ...] = ()) -> "SE2":
def identity(
cls, batch_axes: Tuple[int, ...] = (), dtype: onpt.DTypeLike = onp.float64
) -> SE2:
return SE2(
unit_complex_xy=onp.broadcast_to(
onp.array([1.0, 0.0, 0.0, 0.0]), (*batch_axes, 4)
onp.array([1.0, 0.0, 0.0, 0.0], dtype=dtype), (*batch_axes, 4)
)
)

@classmethod
@override
def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> "SE2":
def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> SE2:
assert matrix.shape[-2:] == (3, 3) or matrix.shape[-2:] == (2, 3)
# Currently assumes bottom row is [0, 0, 1].
return SE2.from_rotation_and_translation(
Expand Down Expand Up @@ -123,7 +127,7 @@ def as_matrix(self) -> onpt.NDArray[onp.floating]:

@classmethod
@override
def exp(cls, tangent: onpt.NDArray[onp.floating]) -> "SE2":
def exp(cls, tangent: onpt.NDArray[onp.floating]) -> SE2:
# Reference:
# > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se2.hpp#L558
# Also see:
Expand All @@ -146,21 +150,15 @@ def exp(cls, tangent: onpt.NDArray[onp.floating]) -> "SE2":
)

theta_sq = theta**2
sin_over_theta = cast(
onp.ndarray,
onp.where(
use_taylor,
1.0 - theta_sq / 6.0,
onp.sin(safe_theta) / safe_theta,
),
sin_over_theta = onp.where(
use_taylor,
1.0 - theta_sq / 6.0,
onp.sin(safe_theta) / safe_theta,
)
one_minus_cos_over_theta = cast(
onp.ndarray,
onp.where(
use_taylor,
0.5 * theta - theta * theta_sq / 24.0,
(1.0 - onp.cos(safe_theta)) / safe_theta,
),
one_minus_cos_over_theta = onp.where(
use_taylor,
0.5 * theta - theta * theta_sq / 24.0,
(1.0 - onp.cos(safe_theta)) / safe_theta,
)

V = onp.stack(
Expand All @@ -172,9 +170,12 @@ def exp(cls, tangent: onpt.NDArray[onp.floating]) -> "SE2":
],
axis=-1,
).reshape((*tangent.shape[:-1], 2, 2))

return SE2.from_rotation_and_translation(
rotation=SO2.from_radians(theta),
translation=onp.einsum("...ij,...j->...i", V, tangent[..., :2]),
translation=onp.einsum("...ij,...j->...i", V, tangent[..., :2]).astype(
tangent.dtype
),
)

@override
Expand Down Expand Up @@ -224,10 +225,10 @@ def log(self) -> onpt.NDArray[onp.floating]:
],
axis=-1,
)
return tangent
return tangent.astype(self.unit_complex_xy.dtype)

@override
def adjoint(self: "SE2") -> onpt.NDArray[onp.floating]:
def adjoint(self: SE2) -> onpt.NDArray[onp.floating]:
cos, sin, x, y = onp.moveaxis(self.unit_complex_xy, -1, 0)
return onp.stack(
[
Expand All @@ -244,21 +245,15 @@ def adjoint(self: "SE2") -> onpt.NDArray[onp.floating]:
axis=-1,
).reshape((*self.get_batch_axes(), 3, 3))

# @classmethod
# @override
# def sample_uniform(
# cls, key: onp.ndarray, batch_axes: jdc.Static[Tuple[int, ...]] = ()
# ) -> "SE2":
# key0, key1 = jax.random.split(key)
# return SE2.from_rotation_and_translation(
# rotation=SO2.sample_uniform(key0, batch_axes=batch_axes),
# translation=jax.random.uniform(
# key=key1,
# shape=(
# *batch_axes,
# 2,
# ),
# minval=-1.0,
# maxval=1.0,
# ),
# )
@classmethod
@override
def sample_uniform(
cls,
rng: onp.random.Generator,
batch_axes: Tuple[int, ...] = (),
dtype: onpt.DTypeLike = onp.float64,
) -> SE2:
return SE2.from_rotation_and_translation(
SO2.sample_uniform(rng, batch_axes=batch_axes, dtype=dtype),
rng.uniform(low=-1.0, high=1.0, size=(*batch_axes, 2)).astype(dtype),
)
Loading

0 comments on commit 2ad8a57

Please sign in to comment.