Skip to content

Commit

Permalink
Introduce jdc.Static[]
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Aug 4, 2022
1 parent b1f28a9 commit cba8ac9
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 29 deletions.
33 changes: 21 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
* [Overview](#overview)
* [Installation](#installation)
* [Core interface](#core-interface)
* [Static fields](#static-fields)
* [Mutations](#mutations)
* [Shape and data-type annotations](#shape-and-data-type-annotations)
* [Alternatives](#alternatives)
Expand All @@ -35,8 +36,7 @@ Heavily influenced by some great existing work (the obvious one being

### Installation

The latest version of `jax_dataclasses` requires Python>=3.7. Python 3.6 will
work as well, but is missing support for shape annotations.
In Python >=3.7:

```bash
pip install jax_dataclasses
Expand All @@ -51,21 +51,30 @@ import jax_dataclasses as jdc
### Core interface

`jax_dataclasses` is meant to provide a drop-in replacement for
`dataclasses.dataclass`:

- <code>jdc.<strong>pytree_dataclass</strong></code> has the same interface as
`dataclasses.dataclass`, but also registers the target class as a pytree
container.
- <code>jdc.<strong>static_field</strong></code> has the same interface as
`dataclasses.field`, but will also mark the field as static. In a pytree node,
static fields will be treated as part of the treedef instead of as a child of
the node; all fields that are not explicitly marked static should contain
arrays or child nodes.
`dataclasses.dataclass`: <code>jdc.<strong>pytree_dataclass</strong></code> has
the same interface as `dataclasses.dataclass`, but also registers the target
class as a pytree node.

We also provide several aliases:
`jdc.[field, asdict, astuples, is_dataclass, replace]` are all identical to
their counterparts in the standard dataclasses library.

### Static fields

To mark a field as static (in this context: constant at compile-time), we can
wrap its type with <code>jdc.<strong>Static[]</strong></code>:

```python
@jdc.pytree_dataclass
class A:
a: jnp.ndarray
b: jdc.Static[bool]
```

In a pytree node, static fields will be treated as part of the treedef instead
of as a child of the node; all fields that are not explicitly marked static
should contain arrays or child nodes.

### Mutations

All dataclasses are automatically marked as frozen and thus immutable (even when
Expand Down
12 changes: 9 additions & 3 deletions jax_dataclasses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,19 @@
# Note that mypy will not follow aliases, so `from dataclasses import dataclass` is
# preferred over `dataclass = dataclasses.dataclass`.
#
# For the future, dataclass transforms may also be worth looking into:
# Dataclass transforms serve a similar purpose, but are currently only supported in
# pyright and pylance.
# https://github.com/microsoft/pyright/blob/master/specs/dataclass_transforms.md
from dataclasses import dataclass as pytree_dataclass

# `static_field()` is deprecated, but not a lot of code to support, so leaving it
# for now...
from dataclasses import field as static_field
else:
from ._dataclasses import pytree_dataclass, static_field
from ._dataclasses import pytree_dataclass
from ._dataclasses import deprecated_static_field as static_field

from ._dataclasses import Static
from ._enforced_annotations import EnforcedAnnotationsMixin

__all__ = [
Expand All @@ -32,6 +38,6 @@
"replace",
"copy_and_mutate",
"pytree_dataclass",
"static_field",
"Static",
"EnforcedAnnotationsMixin",
]
1 change: 0 additions & 1 deletion jax_dataclasses/_copy_and_mutate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import enum
from typing import Any, ContextManager, Sequence, Set, TypeVar

import jax
from jax import numpy as jnp
from jax import tree_util
from jax._src.tree_util import _registry # Dangerous!
Expand Down
33 changes: 25 additions & 8 deletions jax_dataclasses/_dataclasses.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import dataclasses
from typing import Dict, List, Optional, Type, TypeVar

import jax
from jax import tree_util
from typing_extensions import Annotated

try:
# Attempt to import flax for serialization. The exception handling lets us drop
Expand All @@ -16,7 +16,14 @@
T = TypeVar("T")


FIELD_METADATA_STATIC_MARKER = "__jax_dataclasses_static_field__"
JDC_STATIC_MARKER = "__jax_dataclasses_static_field__"


# Stolen from here: https://github.com/google/jax/issues/10476
InnerT = TypeVar("InnerT")
Static = Annotated[InnerT, JDC_STATIC_MARKER]
"""Annotates a type as static in the sense of JAX; in a pytree, fields marked as such
should be hashable and are treated as part of the treedef and not as a child node."""


def pytree_dataclass(cls: Optional[Type] = None, **kwargs):
Expand All @@ -36,11 +43,11 @@ def wrap(cls):
return wrap(cls)


def static_field(*args, **kwargs):
"""Substitute for dataclasses.field, which also marks a field as static."""
def deprecated_static_field(*args, **kwargs):
"""Deprecated, prefer `Static[]` on the type annotation instead."""

kwargs["metadata"] = kwargs.get("metadata", {})
kwargs["metadata"][FIELD_METADATA_STATIC_MARKER] = True
kwargs["metadata"][JDC_STATIC_MARKER] = True

return dataclasses.field(*args, **kwargs)

Expand All @@ -62,10 +69,20 @@ def _register_pytree_dataclass(cls: Type[T]) -> Type[T]:
for field in dataclasses.fields(cls):
if not field.init:
continue
if field.metadata.get(FIELD_METADATA_STATIC_MARKER, False):

# Two ways to mark a field as static: either via the Static[] type or
# jdc.static_field().
if (
hasattr(field.type, "__metadata__")
and JDC_STATIC_MARKER in field.type.__metadata__
):
static_field_names.append(field.name)
else:
child_node_field_names.append(field.name)
continue
if field.metadata.get(JDC_STATIC_MARKER, False):
static_field_names.append(field.name)
continue

child_node_field_names.append(field.name)

# Define flatten, unflatten operations: this simple converts our dataclass to a list
# of fields.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
long_description = fh.read()
setup(
name="jax_dataclasses",
version="1.3.0",
version="1.4.0",
description="Dataclasses + JAX",
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down
26 changes: 24 additions & 2 deletions tests/test_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,30 @@ def test_static_field():
@jdc.pytree_dataclass
class A:
field1: float
field2: float = jdc.field()
field3: bool = jdc.static_field()
field2: float
field3: jdc.Static[bool]

@jax.jit
def jitted_op(obj: A) -> float:
if obj.field3:
return obj.field1 + obj.field2
else:
return obj.field1 - obj.field2

with pytest.raises(ValueError):
# Cannot map over pytrees with different treedefs
_assert_pytree_allclose(A(1.0, 2.0, False), A(1.0, 2.0, True))

_assert_pytree_allclose(jitted_op(A(5.0, 3.0, True)), 8.0)
_assert_pytree_allclose(jitted_op(A(5.0, 3.0, False)), 2.0)


def test_static_field_deprecated():
@jdc.pytree_dataclass
class A:
field1: float
field2: float
field3: bool = jdc.static_field() # type: ignore

@jax.jit
def jitted_op(obj: A) -> float:
Expand Down
3 changes: 1 addition & 2 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
"""

import flax
import jax
import numpy as onp
from jax import tree_util

Expand All @@ -20,7 +19,7 @@ def test_serialization():
class A:
field1: int
field2: int
field3: bool = jdc.static_field()
field3: jdc.Static[bool]

obj = A(field1=5, field2=3, field3=True)

Expand Down

0 comments on commit cba8ac9

Please sign in to comment.