Skip to content

Commit

Permalink
More robust forward references, fix spurious print
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Aug 8, 2022
1 parent 398a8f6 commit 6faa039
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 17 deletions.
79 changes: 64 additions & 15 deletions jax_dataclasses/_dataclasses.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

import collections
import dataclasses
import inspect
import sys
from typing import Dict, List, Optional, Type, TypeVar

from jax import tree_util
from typing_extensions import Annotated, get_type_hints
from typing_extensions import Annotated

try:
# Attempt to import flax for serialization. The exception handling lets us drop
Expand Down Expand Up @@ -53,12 +56,60 @@ def deprecated_static_field(*args, **kwargs):
return dataclasses.field(*args, **kwargs)


def _register_pytree_dataclass(cls: Type[T]) -> Type[T]:
"""Register a dataclass as a flax-serializable pytree container.
class _UnresolvableForwardReference:
def __class_getitem__(cls, item) -> Type[_UnresolvableForwardReference]:
"""__getitem__ passthrough, for supporting generics."""
return _UnresolvableForwardReference


def _get_type_hints_partial(obj, include_extras=False):
"""Adapted from typing.get_type_hints(), but aimed at suppressing errors from not
(yet) resolvable forward references. Only used for detecting `jdc.Static[]`
annotations.
For example:
@jdc.pytree_dataclass
class A:
x: B
y: jdc.Static[bool]
@jdc.pytree_dataclass
class B:
x: jnp.ndarray
Args:
cls (Type[T]): Dataclass to wrap.
Note that the type annotations of `A` need to be parsed by the `pytree_dataclass`
decorator in order to register the static field, but `B` is not yet defined when the
decorator is run. We don't actually care about the details of the `B` annotation, so
we replace it in our annotation dictionary with a dummy value.
Differences:
1. `include_extras` must be True.
2. Only supports types.
3. Doesn't throw an error when a name is not found. Instead, replaces the value
with `_UnresolvableForwardReference`.
"""
assert include_extras
assert isinstance(obj, type)

hints = {}
for base in reversed(obj.__mro__):
# Replace any unresolvable names with _UnresolvableForwardReference.
base_globals = collections.defaultdict(lambda: _UnresolvableForwardReference)
base_globals.update(sys.modules[base.__module__].__dict__)

ann = base.__dict__.get("__annotations__", {})
for name, value in ann.items():
if value is None:
value = type(None)
if isinstance(value, str):
value = eval(value, base_globals)
hints[name] = value
return hints


def _register_pytree_dataclass(cls: Type[T]) -> Type[T]:
"""Register a dataclass as a flax-serializable pytree container."""

assert dataclasses.is_dataclass(cls)

Expand All @@ -67,22 +118,20 @@ def _register_pytree_dataclass(cls: Type[T]) -> Type[T]:
child_node_field_names: List[str] = []
static_field_names: List[str] = []

# We use get_type_hints() instead of field.type to make sure that forward references
# are resolved.
# We don't directly use field.type for postponed evaluation; we want to make sure
# that our types are interpreted as proper types and not as (string) forward
# references.
#
# We make a copy of the caller's local namespace, and add the registered class to it
# in advance: this prevents cyclic references from breaking get_type_hints().
caller_localns = dict(inspect.stack()[1][0].f_locals)
if cls.__name__ not in caller_localns:
caller_localns[cls.__name__] = None
type_from_name = get_type_hints(cls, localns=caller_localns, include_extras=True)
# Note that there are ocassionally situations where the @jdc.pytree_dataclass
# decorator is called before a referenced type is defined; to suppress this error,
# we resolve missing names to our subscriptible placeohlder object.
type_from_name = _get_type_hints_partial(cls, include_extras=True)

for field in dataclasses.fields(cls):
if not field.init:
continue

field_type = type_from_name[field.name]
print(field_type)

# Two ways to mark a field as static: either via the Static[] type or
# jdc.static_field().
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
long_description = fh.read()
setup(
name="jax_dataclasses",
version="1.4.1",
version="1.4.2",
description="Dataclasses + JAX",
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down
34 changes: 33 additions & 1 deletion tests/test_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from __future__ import annotations

from typing import Generic, TypeVar

import jax
import numpy as onp
import pytest
Expand Down Expand Up @@ -115,7 +117,7 @@ def test_no_init():
class A:
field1: float
field2: float = jdc.field()
field3: bool = jdc.static_field(init=False)
field3: jdc.Static[bool] = jdc.field(init=False)

def __post_init__(self):
object.__setattr__(self, "field3", False)
Expand All @@ -125,3 +127,33 @@ def construct_A(a: float) -> A:
return A(field1=a, field2=a * 2.0)

assert construct_A(5.0).field3 is False


def test_static_field_forward_ref():
@jdc.pytree_dataclass
class A:
field1: float
field2: float
field3: jdc.Static[Container[bool]]

T = TypeVar("T")

@jdc.pytree_dataclass
class Container(Generic[T]):
x: T

@jax.jit
def jitted_op(obj: A) -> float:
if obj.field3.x:
return obj.field1 + obj.field2
else:
return obj.field1 - obj.field2

with pytest.raises(ValueError):
# Cannot map over pytrees with different treedefs
_assert_pytree_allclose(
A(1.0, 2.0, Container(False)), A(1.0, 2.0, Container(True))
)

_assert_pytree_allclose(jitted_op(A(5.0, 3.0, Container(True))), 8.0)
_assert_pytree_allclose(jitted_op(A(5.0, 3.0, Container(False))), 2.0)

0 comments on commit 6faa039

Please sign in to comment.