diff --git a/hj_reachability/sets.py b/hj_reachability/sets.py index b9eb01c..0581189 100644 --- a/hj_reachability/sets.py +++ b/hj_reachability/sets.py @@ -3,6 +3,8 @@ from flax import struct import jax.numpy as jnp +from hj_reachability import utils + from typing import Any Array = Any @@ -61,7 +63,7 @@ class Ball(BoundedSet): def extreme_point(self, direction: Array) -> Array: """Computes the point `x` in the set such that the dot product `x @ direction` is greatest.""" - return self.center + self.radius * direction / jnp.linalg.norm(direction) + return self.center + self.radius * utils.unit_vector(direction) @property def bounding_box(self) -> "Box": diff --git a/hj_reachability/sets_test.py b/hj_reachability/sets_test.py new file mode 100644 index 0000000..ef6be9a --- /dev/null +++ b/hj_reachability/sets_test.py @@ -0,0 +1,32 @@ +from absl.testing import absltest +import jax +import numpy as np + +from hj_reachability import sets + + +class SetsTest(absltest.TestCase): + + def setUp(self): + np.random.seed(0) + + def test_box(self): + box = sets.Box(np.ones(3), 2 * np.ones(3)) + np.testing.assert_allclose(box.extreme_point(np.array([1, -1, 1])), np.array([2, 1, 2])) + self.assertTrue(np.all(np.isfinite(box.extreme_point(np.zeros(3))))) + self.assertEqual(box.bounding_box, box) + np.testing.assert_allclose(box.max_magnitudes, 2 * np.ones(3)) + self.assertEqual(box.ndim, 3) + + def test_ball(self): + ball = sets.Ball(np.ones(3), np.sqrt(3)) + np.testing.assert_allclose(ball.extreme_point(np.array([1, -1, 1])), np.array([2, 0, 2]), atol=1e-6) + self.assertTrue(np.all(np.isfinite(ball.extreme_point(np.zeros(3))))) + jax.tree_map(np.testing.assert_allclose, ball.bounding_box, + sets.Box((1 - np.sqrt(3)) * np.ones(3), (1 + np.sqrt(3)) * np.ones(3))) + np.testing.assert_allclose(ball.max_magnitudes, (1 + np.sqrt(3)) * np.ones(3)) + self.assertEqual(ball.ndim, 3) + + +if __name__ == "__main__": + absltest.main() diff --git a/hj_reachability/utils.py b/hj_reachability/utils.py index 2629309..11631dc 100644 --- a/hj_reachability/utils.py +++ b/hj_reachability/utils.py @@ -1,6 +1,7 @@ import functools import jax +import jax.numpy as jnp import numpy as np from typing import Any, Callable, Iterable, List, Mapping, Optional, TypeVar, Union @@ -56,3 +57,10 @@ def get_axis_sequence(axis_array: np.ndarray) -> List: vmap_kwargs = jax.tree_util.tree_transpose(jax.tree_util.tree_structure(multivmap_kwargs), axis_sequence_structure, jax.tree_map(get_axis_sequence, multivmap_kwargs)) return functools.reduce(lambda f, kwargs: jax.vmap(f, **kwargs), vmap_kwargs, fun) + + +def unit_vector(x): + """Normalizes a vector `x`, returning a unit vector in the same direction, or a zero vector if `x` is zero.""" + norm2 = jnp.sum(jnp.square(x)) + iszero = norm2 < jnp.finfo(jnp.zeros(()).dtype).eps**2 + return jnp.where(iszero, jnp.zeros_like(x), x / jnp.sqrt(jnp.where(iszero, 1, norm2))) diff --git a/hj_reachability/utils_test.py b/hj_reachability/utils_test.py new file mode 100644 index 0000000..adfdc94 --- /dev/null +++ b/hj_reachability/utils_test.py @@ -0,0 +1,40 @@ +from absl.testing import absltest +import jax +import jax.numpy as jnp +import numpy as np + +from hj_reachability import utils + + +class UtilsTest(absltest.TestCase): + + def setUp(self): + np.random.seed(0) + + def test_multivmap(self): + a = np.random.random((3, 4, 5, 6)) + np.testing.assert_allclose(utils.multivmap(jnp.max, np.array([0, 1]))(a), np.max(a, (2, 3))) + np.testing.assert_allclose(utils.multivmap(jnp.max, np.array([0, 1, 2]))(a), np.max(a, -1)) + np.testing.assert_allclose(utils.multivmap(jnp.max, np.array([0, 1, 3]), np.array([0, 1, 2]))(a), np.max(a, 2)) + np.testing.assert_allclose( + utils.multivmap(jnp.max, np.array([1, 0, 2]), np.array([0, 1, 2]))(a), + np.max(a, 3).swapaxes(0, 1)) + np.testing.assert_allclose( + utils.multivmap(jnp.max, np.array([3, 2]), np.array([0, 1]))(a), + np.max(a, (0, 1)).swapaxes(0, 1)) + + def test_unit_vector(self): + unsafe_unit_vector = lambda x: x / jnp.linalg.norm(x, axis=-1, keepdims=True) + for d in range(1, 4): + np.testing.assert_array_equal(utils.unit_vector(np.zeros(d)), np.zeros(d)) + self.assertTrue(np.all(np.isfinite(jax.jacobian(utils.unit_vector)(np.zeros(d))))) + self.assertTrue(np.all(np.isnan(jax.jacobian(unsafe_unit_vector)(np.zeros(d))))) + a = np.random.random((100, d)) + np.testing.assert_allclose(jax.vmap(utils.unit_vector)(a), unsafe_unit_vector(a), atol=1e-6) + np.testing.assert_allclose(jax.vmap(jax.jacobian(utils.unit_vector))(a), + jax.vmap(jax.jacobian(unsafe_unit_vector))(a), + atol=1e-6) + + +if __name__ == "__main__": + absltest.main()