Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for jaxtyping #6

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 81 additions & 29 deletions jax_dataclasses/_enforced_annotations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import dataclasses
from typing import Any, List, Optional, Tuple
from typing import Any, List, Optional, Tuple, _AnnotatedAlias

try:
from jaxtyping.array_types import _MetaAbstractArray, _NamedVariadicDim
except ImportError:
_MetaAbstractArray = type(None)
_NamedVariadicDim = type(None)

from jax import numpy as jnp
from typing_extensions import TypeGuard
Expand Down Expand Up @@ -92,39 +98,71 @@ def __post_init__(self) -> None:
child_batch_axes_list.append(child_batch_axes)
continue

# Check for metadata from `typing.Annotated` value! Skip if no annotation.
if not hasattr(type_hint, "__metadata__"):
continue
metadata: Tuple[Any, ...] = type_hint.__metadata__
assert (
len(metadata) <= 2
), "We expect <= 2 metadata items; only shape and dtype are expected."

# Check data type.
metadata_dtype = tuple(filter(_is_dtype, metadata))
if len(metadata_dtype) > 0 and hasattr(value, "dtype"):
(dtype,) = metadata_dtype
assert jnp.issubdtype(
value.dtype, dtype
), f"Mismatched dtype, expected {dtype} but got {value.dtype}."

# Shape checks.
metadata_shape = tuple(filter(_is_expected_shape, metadata))
shape: Optional[Tuple[int, ...]] = None
if isinstance(value, (int, float)):
shape = ()
elif hasattr(value, "shape"):
shape = value.shape
if len(metadata_shape) > 0 and shape is not None:
# Get expected shape, sans batch axes.
(expected_shape,) = metadata_shape
field_batch_axes = _check_batch_axes(shape, expected_shape)
if isinstance(type_hint, _AnnotatedAlias):
if not hasattr(type_hint, "__metadata__"):
continue
metadata: Tuple[Any, ...] = type_hint.__metadata__
assert (
len(metadata) <= 2
), "We expect <= 2 metadata items; only shape and dtype are expected."

# Check data type.
metadata_dtype = tuple(filter(_is_dtype, metadata))
if len(metadata_dtype) > 0 and hasattr(value, "dtype"):
(dtype,) = metadata_dtype
assert jnp.issubdtype(
value.dtype, dtype
), f"Mismatched dtype, expected {dtype} but got {value.dtype}."

# Shape checks.
metadata_shape = tuple(filter(_is_expected_shape, metadata))
shape: Optional[Tuple[int, ...]] = None
if isinstance(value, (int, float)):
shape = ()
elif hasattr(value, "shape"):
shape = value.shape
if len(metadata_shape) > 0 and shape is not None:
# Get expected shape, sans batch axes.
(expected_shape,) = metadata_shape
field_batch_axes = _check_batch_axes_annotated(
shape, expected_shape
)
if batch_axes is None:
batch_axes = field_batch_axes
else:
assert (
batch_axes == field_batch_axes
), f"Batch axis mismatch: {batch_axes} and {field_batch_axes}."
elif isinstance(type_hint, _MetaAbstractArray):
if type_hint.index_variadic == 0 and len(type_hint.dims) == 1:
assert isinstance(value, type_hint)
continue

if type_hint.index_variadic is None:
batched_type_hint = type(
f"Batch{type_hint.__name__}",
(type_hint,),
{
"index_variadic": 0,
"dims": [_NamedVariadicDim(object(), False)]
+ type_hint.dims,
},
)
else:
batched_type_hint = type_hint
assert isinstance(value, batched_type_hint)

field_batch_axes = _check_batch_axes_jaxtyping(
value.shape, batched_type_hint
)
if batch_axes is None:
batch_axes = field_batch_axes
else:
assert (
batch_axes == field_batch_axes
), f"Batch axis mismatch: {batch_axes} and {field_batch_axes}."
else:
continue

# Check child batch axes: any batch axes present in the parent should be present
# in the children as well.
Expand All @@ -149,7 +187,7 @@ def get_batch_axes(self) -> Tuple[int, ...]:
return batch_axes


def _check_batch_axes(
def _check_batch_axes_annotated(
shape: Tuple[int, ...],
expected_shape: ExpectedShape,
) -> Tuple[int, ...]:
Expand Down Expand Up @@ -180,3 +218,17 @@ def _check_batch_axes(
batch_axes = shape

return batch_axes


def _check_batch_axes_jaxtyping(
shape: Tuple[int, ...],
expected_shape: ExpectedShape,
) -> Tuple[int, ...]:
batch_index = expected_shape.index_variadic
prefix = batch_index
suffix = len(expected_shape.dims) - batch_index - 1
if suffix == 0:
batch_axes = shape[prefix:]
else:
batch_axes = shape[prefix:-suffix]
return batch_axes
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
extras_require={
"testing": [
"flax", # Used for serialization tests.
"jaxtyping",
"pytest",
"pytest-cov",
]
Expand Down
Loading