Skip to content

Commit

Permalink
Tweak tree traversal, bump version
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Jul 29, 2022
1 parent a1ace0b commit 12e873e
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 18 deletions.
20 changes: 6 additions & 14 deletions jax_dataclasses/_copy_and_mutate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@

import jax
from jax import numpy as jnp

from . import _dataclasses
from jax._src.tree_util import _registry # Dangerous!

T = TypeVar("T")

Expand All @@ -21,18 +20,11 @@ def _tree_children(container: Any) -> Sequence[Any]:
"""Grab child nodes of a pytree. This would ideally be implemented using the pytree
registry."""

if isinstance(container, (tuple, list)):
return container
elif isinstance(container, dict):
return tuple(container.values())
elif dataclasses.is_dataclass(container):
out = []
for field in dataclasses.fields(container):
# Mark non-static fields as mutable.
if not field.metadata.get(_dataclasses.FIELD_METADATA_STATIC_MARKER, False):
out.append(getattr(container, field.name))
return out
return ()
registry_entry = _registry.get(type(container))
if registry_entry is not None:
children, _metadata = registry_entry.to_iter(container)
return list(children)
return []


def _mark_mutable(obj: Any, mutable: _Mutability, visited: Set[Any]) -> None:
Expand Down
1 change: 0 additions & 1 deletion jax_dataclasses/_enforced_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ def __post_init__(self):

# For each field...
for field in dataclasses.fields(self):

type_hint = hint_from_name[field.name]
value = self.__getattribute__(field.name)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
long_description = fh.read()
setup(
name="jax_dataclasses",
version="1.2.3",
version="1.3.0",
description="Dataclasses + JAX",
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


def _assert_pytree_allclose(x, y):
jax.tree_multimap(
jax.tree_map(
lambda *arrays: onp.testing.assert_allclose(arrays[0], arrays[1]), x, y
)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def _assert_pytree_allclose(x, y):
jax.tree_multimap(
jax.tree_map(
lambda *arrays: onp.testing.assert_allclose(arrays[0], arrays[1]), x, y
)

Expand Down

0 comments on commit 12e873e

Please sign in to comment.