From 398a8f64ae72e046db0630eef929188aeb7f1ed3 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Thu, 4 Aug 2022 04:05:56 -0700 Subject: [PATCH] Fix issues with forward references, cycles --- jax_dataclasses/__init__.py | 3 +-- jax_dataclasses/_dataclasses.py | 20 +++++++++++++++++--- setup.py | 2 +- tests/test_dataclass.py | 2 ++ 4 files changed, 21 insertions(+), 6 deletions(-) diff --git a/jax_dataclasses/__init__.py b/jax_dataclasses/__init__.py index 6e287f6..617e3b4 100644 --- a/jax_dataclasses/__init__.py +++ b/jax_dataclasses/__init__.py @@ -17,10 +17,9 @@ # Dataclass transforms serve a similar purpose, but are currently only supported in # pyright and pylance. # https://github.com/microsoft/pyright/blob/master/specs/dataclass_transforms.md - from dataclasses import dataclass as pytree_dataclass - # `static_field()` is deprecated, but not a lot of code to support, so leaving it # for now... + from dataclasses import dataclass as pytree_dataclass from dataclasses import field as static_field else: from ._dataclasses import pytree_dataclass diff --git a/jax_dataclasses/_dataclasses.py b/jax_dataclasses/_dataclasses.py index 21e95d6..10632de 100644 --- a/jax_dataclasses/_dataclasses.py +++ b/jax_dataclasses/_dataclasses.py @@ -1,8 +1,9 @@ import dataclasses +import inspect from typing import Dict, List, Optional, Type, TypeVar from jax import tree_util -from typing_extensions import Annotated +from typing_extensions import Annotated, get_type_hints try: # Attempt to import flax for serialization. The exception handling lets us drop @@ -66,15 +67,28 @@ def _register_pytree_dataclass(cls: Type[T]) -> Type[T]: child_node_field_names: List[str] = [] static_field_names: List[str] = [] + # We use get_type_hints() instead of field.type to make sure that forward references + # are resolved. + # + # We make a copy of the caller's local namespace, and add the registered class to it + # in advance: this prevents cyclic references from breaking get_type_hints(). + caller_localns = dict(inspect.stack()[1][0].f_locals) + if cls.__name__ not in caller_localns: + caller_localns[cls.__name__] = None + type_from_name = get_type_hints(cls, localns=caller_localns, include_extras=True) + for field in dataclasses.fields(cls): if not field.init: continue + field_type = type_from_name[field.name] + print(field_type) + # Two ways to mark a field as static: either via the Static[] type or # jdc.static_field(). if ( - hasattr(field.type, "__metadata__") - and JDC_STATIC_MARKER in field.type.__metadata__ + hasattr(field_type, "__metadata__") + and JDC_STATIC_MARKER in field_type.__metadata__ ): static_field_names.append(field.name) continue diff --git a/setup.py b/setup.py index 658f4cb..5a79ccd 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ long_description = fh.read() setup( name="jax_dataclasses", - version="1.4.0", + version="1.4.1", description="Dataclasses + JAX", long_description=long_description, long_description_content_type="text/markdown", diff --git a/tests/test_dataclass.py b/tests/test_dataclass.py index a9e4f70..0fba3c5 100644 --- a/tests/test_dataclass.py +++ b/tests/test_dataclass.py @@ -2,6 +2,8 @@ static fields, etc. """ +from __future__ import annotations + import jax import numpy as onp import pytest