Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for jaxtyping #6

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

lucagrementieri
Copy link

@lucagrementieri lucagrementieri commented Jan 7, 2023

This PR adds support for jaxtyping annotations preserving all the features and checks on tensor dimensions.

The PR doesn't update the README, since it could become messy very easily. I'll wait further indications to update the README.

Close #5.

@brentyi
Copy link
Owner

brentyi commented Jan 8, 2023

Thanks! Looks reasonable overall, my main concern is the private jaxtyping imports. I assume there's no way to get around this?

@lucagrementieri
Copy link
Author

lucagrementieri commented Jan 10, 2023

_MetaAbstractArray is the base class of all jaxtyping types and annotations so the check isinstance(type_hint, _MetaAbstractArray) is the best way to identify jaxtyping annotations. Surely there are workarounds to not use it, but they will be more fragile and less elegant.

For _NamedVariadicDim, I think there is no simple workaround because this class is required to support variadic dimensions, like the batch dimension.

@brentyi
Copy link
Owner

brentyi commented Jan 10, 2023

Okay, makes sense! It's definitely not ideal but having support for jaxtyping here seems useful enough to warrant it. I like how we don't have to worry about import hooks or @jaxtyped for the shape checks checking/getting the batch axes.

(cc @patrick-kidger for any warnings, are there any plans to rework the internals of jaxtyping?)

I can handle the rest of the PR. Some TODOs would be:

  • Consider adding a jaxtyping install dependency with a pinned version.
  • Fixing the tests for Python 3.7. (seems unrelated to this PR)
  • Updating the README.
  • Deprecation warnings for the proprietary shape typing interface.
  • Consider renaming EnforcedAnnotationsMixin while we're at it. This is a mouthful.

@patrick-kidger
Copy link

There aren't any current plans. But taking a quick glance at your code, I think this will fail for a type hint of the form tuple[Float[Array, "foo"], ...], i.e. one in which the array is nested within another type hint? (I didn't check that closely though, so I might be wrong.)

Anyway, jaxtyping hints are expected to be validated using a runtime type checker, such as typeguard or beartype. I'd recommend that you simply do the same thing, as they'll handle the details for you: both the nesting above, and avoiding the need to access private jaxtyping functionality.

Side note: if you're working on a project like this then you may find Equinox interesting. In particular equinox.Module is also a dataclass-pytree combo, with most (all?) of the expected bells-and-whistles: serialisation, immutability etc. (I suppose I've not really tested static type checking, as I'm mostly a non-user of that.)

I like the neat syntax of your copy_and_mutate, by the way. (Equinox's equivalent is equinox.tree_at, which is safer but a bit harder to use.)

@brentyi
Copy link
Owner

brentyi commented Jan 10, 2023

Thanks!

I've also been following Equinox; definitely the "how to build pytrees" + tooling compatibility landscapes have improved quite a bit since I started jax_dataclasses in ~late 2020. For now I think mypy compatibility + the Static[] API + copy_and_mutate are still nice enough for me to keep the library around, but sooner or later I should revisit whether the library still makes sense given developments in equinox, flax, etc (especially if a copy_and_mutate-style API is merged into flax google/flax#2735).

I also agree that typeguard or beartype makes sense for asserting that the shapes are correct, but for this PR the main purpose is to support jaxtyping annotations for the dataclass.get_batch_axes() helper that currently works for jax_dataclasses-proprietary shape annotations.

For this we need to figure out which axes in the array shapes correspond to the variadic dimension, which leaves the options of: (a) touching the private bits of jaxtyping, (b) trying to convince @patrick-kidger to expose a public API for reasoning about jaxtyping types*, or (c) not implementing this functionality.

*maybe something like: (jaxtyping type, array) -> labels for each axis in the array. Any chance you're open to something like this? (understand if not)

@patrick-kidger
Copy link

You should be able to replace isinstance(type_hint, _MetaAbstractArray) with issubclass(type_hint, AbstractArray). (Which is public API.)

At that point I can see that you'd want to modify its dimensions. I think the best way to do this would be to submit a PR against jaxtyping that records cls and item here:

https://github.com/google/jaxtyping/blob/59e8fb0d18325f990a9d59ee35e90c04b699cab8/jaxtyping/array_types.py#L400

so that you can then look these up, modify these as desired, and then recreate the type hint through the public jaxtyping API (e.g. cls[item] to recreate the same hint).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Use jaxtyping to enrich type annotations
3 participants