Skip to content

Commit

Permalink
Fix issues with forward references, cycles
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Aug 4, 2022
1 parent cba8ac9 commit 398a8f6
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 6 deletions.
3 changes: 1 addition & 2 deletions jax_dataclasses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@
# Dataclass transforms serve a similar purpose, but are currently only supported in
# pyright and pylance.
# https://github.com/microsoft/pyright/blob/master/specs/dataclass_transforms.md
from dataclasses import dataclass as pytree_dataclass

# `static_field()` is deprecated, but not a lot of code to support, so leaving it
# for now...
from dataclasses import dataclass as pytree_dataclass
from dataclasses import field as static_field
else:
from ._dataclasses import pytree_dataclass
Expand Down
20 changes: 17 additions & 3 deletions jax_dataclasses/_dataclasses.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import dataclasses
import inspect
from typing import Dict, List, Optional, Type, TypeVar

from jax import tree_util
from typing_extensions import Annotated
from typing_extensions import Annotated, get_type_hints

try:
# Attempt to import flax for serialization. The exception handling lets us drop
Expand Down Expand Up @@ -66,15 +67,28 @@ def _register_pytree_dataclass(cls: Type[T]) -> Type[T]:
child_node_field_names: List[str] = []
static_field_names: List[str] = []

# We use get_type_hints() instead of field.type to make sure that forward references
# are resolved.
#
# We make a copy of the caller's local namespace, and add the registered class to it
# in advance: this prevents cyclic references from breaking get_type_hints().
caller_localns = dict(inspect.stack()[1][0].f_locals)
if cls.__name__ not in caller_localns:
caller_localns[cls.__name__] = None
type_from_name = get_type_hints(cls, localns=caller_localns, include_extras=True)

for field in dataclasses.fields(cls):
if not field.init:
continue

field_type = type_from_name[field.name]
print(field_type)

# Two ways to mark a field as static: either via the Static[] type or
# jdc.static_field().
if (
hasattr(field.type, "__metadata__")
and JDC_STATIC_MARKER in field.type.__metadata__
hasattr(field_type, "__metadata__")
and JDC_STATIC_MARKER in field_type.__metadata__
):
static_field_names.append(field.name)
continue
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
long_description = fh.read()
setup(
name="jax_dataclasses",
version="1.4.0",
version="1.4.1",
description="Dataclasses + JAX",
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down
2 changes: 2 additions & 0 deletions tests/test_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
static fields, etc.
"""

from __future__ import annotations

import jax
import numpy as onp
import pytest
Expand Down

0 comments on commit 398a8f6

Please sign in to comment.