From 12e873e9120331257f124c315de193235fbc4236 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Thu, 28 Jul 2022 20:59:56 -0700 Subject: [PATCH] Tweak tree traversal, bump version --- jax_dataclasses/_copy_and_mutate.py | 20 ++++++-------------- jax_dataclasses/_enforced_annotations.py | 1 - setup.py | 2 +- tests/test_dataclass.py | 2 +- tests/test_serialization.py | 2 +- 5 files changed, 9 insertions(+), 18 deletions(-) diff --git a/jax_dataclasses/_copy_and_mutate.py b/jax_dataclasses/_copy_and_mutate.py index 8004f11..a30a450 100644 --- a/jax_dataclasses/_copy_and_mutate.py +++ b/jax_dataclasses/_copy_and_mutate.py @@ -5,8 +5,7 @@ import jax from jax import numpy as jnp - -from . import _dataclasses +from jax._src.tree_util import _registry # Dangerous! T = TypeVar("T") @@ -21,18 +20,11 @@ def _tree_children(container: Any) -> Sequence[Any]: """Grab child nodes of a pytree. This would ideally be implemented using the pytree registry.""" - if isinstance(container, (tuple, list)): - return container - elif isinstance(container, dict): - return tuple(container.values()) - elif dataclasses.is_dataclass(container): - out = [] - for field in dataclasses.fields(container): - # Mark non-static fields as mutable. - if not field.metadata.get(_dataclasses.FIELD_METADATA_STATIC_MARKER, False): - out.append(getattr(container, field.name)) - return out - return () + registry_entry = _registry.get(type(container)) + if registry_entry is not None: + children, _metadata = registry_entry.to_iter(container) + return list(children) + return [] def _mark_mutable(obj: Any, mutable: _Mutability, visited: Set[Any]) -> None: diff --git a/jax_dataclasses/_enforced_annotations.py b/jax_dataclasses/_enforced_annotations.py index 5334b15..a0f756d 100644 --- a/jax_dataclasses/_enforced_annotations.py +++ b/jax_dataclasses/_enforced_annotations.py @@ -81,7 +81,6 @@ def __post_init__(self): # For each field... for field in dataclasses.fields(self): - type_hint = hint_from_name[field.name] value = self.__getattribute__(field.name) diff --git a/setup.py b/setup.py index 07b56ff..acf80fe 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ long_description = fh.read() setup( name="jax_dataclasses", - version="1.2.3", + version="1.3.0", description="Dataclasses + JAX", long_description=long_description, long_description_content_type="text/markdown", diff --git a/tests/test_dataclass.py b/tests/test_dataclass.py index 859ad5f..c4d3864 100644 --- a/tests/test_dataclass.py +++ b/tests/test_dataclass.py @@ -10,7 +10,7 @@ def _assert_pytree_allclose(x, y): - jax.tree_multimap( + jax.tree_map( lambda *arrays: onp.testing.assert_allclose(arrays[0], arrays[1]), x, y ) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 012db42..d7847c6 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -9,7 +9,7 @@ def _assert_pytree_allclose(x, y): - jax.tree_multimap( + jax.tree_map( lambda *arrays: onp.testing.assert_allclose(arrays[0], arrays[1]), x, y )