Skip to content

Commit

Permalink
fixed NaN when calculating the gradient of SO3.log() #9 (#10)
Browse files Browse the repository at this point in the history
* fixed nan #9

* Add test + isort / black

---------

Co-authored-by: Brent Yi <[email protected]>
  • Loading branch information
Ending2015a and brentyi committed Mar 10, 2023
1 parent ad93513 commit 6698903
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 3 deletions.
2 changes: 2 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@
]
autoapi_add_toctree_entry = False


# Generate name aliases
def _gen_name_aliases():
"""Generate a name alias dictionary, which maps private names to ones in the public
Expand Down Expand Up @@ -319,6 +320,7 @@ def __init__(self, obj, **kwargs):

_override_function_documenter()


# Apply our inheritance alias to autoapi attribute annotations
def _override_attribute_documenter():
import autoapi
Expand Down
2 changes: 1 addition & 1 deletion examples/se3_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
import time
from typing import List, Literal, Tuple, Union

import tyro
import jax
import jax_dataclasses as jdc
import matplotlib.pyplot as plt
import optax
import tyro
from jax import numpy as jnp
from typing_extensions import assert_never

Expand Down
4 changes: 2 additions & 2 deletions jaxlie/_so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,14 +371,14 @@ def log(self) -> jnp.ndarray:
norm_sq,
)
)

w_safe = jnp.where(use_taylor, w, 1.0)
atan_n_over_w = jnp.arctan2(
jnp.where(w < 0, -norm_safe, norm_safe),
jnp.abs(w),
)
atan_factor = jnp.where(
use_taylor,
2.0 / w - 2.0 / 3.0 * norm_sq / w**3,
2.0 / w_safe - 2.0 / 3.0 * norm_sq / w_safe**3,
jnp.where(
jnp.abs(w) < get_epsilon(w.dtype),
jnp.where(w > 0, 1.0, -1.0) * jnp.pi / norm_safe,
Expand Down
15 changes: 15 additions & 0 deletions tests/test_autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,21 @@ def _exp(Group: Type[jaxlie.MatrixLieGroup], generator: jnp.ndarray) -> jnp.ndar
return cast(jnp.ndarray, Group.exp(generator).parameters())


def test_so3_nan():
"""Make sure we don't get NaNs from division when w == 0.
https://github.com/brentyi/jaxlie/issues/9"""

@jax.jit
@jax.grad
def func(x):
return jaxlie.SO3.exp(x).log().sum()

for omega in jnp.eye(3) * jnp.pi:
a = jnp.array(omega, dtype=jnp.float32)
assert all(onp.logical_not(onp.isnan(func(a))))


@general_group_test
def test_exp_random(Group: Type[jaxlie.MatrixLieGroup]):
"""Check that exp Jacobians are consistent, with randomly sampled transforms."""
Expand Down

0 comments on commit 6698903

Please sign in to comment.