Use jaxtyping to enrich type annotations #5

Open
lucagrementieri opened this issue Nov 17, 2022 · 5 comments · May be fixed by #6

Comments

@lucagrementieri

I just discovered the jaxtyping library and I think it could be an interesting alternative to the current typing system proposed by jax_dataclasses.

jaxtyping supports variable-size axes and symbolic expressions in terms of other variable-size axes (see https://github.com/google/jaxtyping/blob/main/API.md), and it has very few requirements.

Do you think that it could be added to jax_dataclasses?

@brentyi
Owner

brentyi commented Nov 17, 2022

I've also been following jaxtyping and wouldn't be opposed to deprecating the shape annotation syntax and standardizing on theirs!

Observations on my end:

  • I haven't verified this myself, but for shape/datatype assertions I think jaxtyping should work out-of-the-box with @jdc.pytree_dataclass. (It seems like @jaxtyped has some recent dataclass-related additions that might help?)
  • The main additional feature that the current jdc.EnforcedAnnotationsMixin gives you is a .get_batch_axes() function. It seems like jaxtyping doesn't have any public API for accessing shape annotations, so bringing this functionality over would require manually parsing the annotations.
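
To illustrate the second point, here is a minimal sketch of what recovering batch axes from a shape annotation could look like. It assumes the annotation describes only the trailing axes of an array and that any leading axes are batch axes. `infer_batch_axes` is a hypothetical helper written for this example; it is not the jax_dataclasses implementation, and it uses plain tuples in place of real arrays and real annotations.

```python
from typing import Tuple


def infer_batch_axes(
    actual_shape: Tuple[int, ...], annotated_shape: Tuple[int, ...]
) -> Tuple[int, ...]:
    """Hypothetical helper: return the leading axes not covered by the annotation.

    Assumes `annotated_shape` describes the trailing axes of `actual_shape`.
    """
    n = len(annotated_shape)
    if len(actual_shape) < n:
        raise ValueError(
            f"shape {actual_shape} has fewer axes than annotation {annotated_shape}"
        )
    split = len(actual_shape) - n
    # The trailing axes must match the annotation exactly.
    if actual_shape[split:] != annotated_shape:
        raise ValueError(
            f"shape {actual_shape} does not end with annotation {annotated_shape}"
        )
    # Whatever precedes the annotated axes is treated as batch axes.
    return actual_shape[:split]


# A batch of 8 x 5 rotation matrices annotated as (3, 3).
print(infer_batch_axes((8, 5, 3, 3), (3, 3)))
```

With jaxtyping, the annotated shape would first have to be recovered from the annotation object itself, which is the part that may require manual parsing.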

Let me know if you have more thoughts on this, or bandwidth for contributing. 🙂

@lucagrementieri
Author

The first point should certainly be verified.
For the second point, I did some experiments and I think that typing.get_type_hints is sufficient to access shape annotations.
I'll get back to you after more experiments, maybe with a pull request.
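
As a rough illustration of the idea, the sketch below reads annotation metadata back from a dataclass with `typing.get_type_hints`. It uses only the standard library: the `Pose` class and its shape strings are hypothetical stand-ins built with `typing.Annotated`, not jaxtyping annotations, so it only shows the general mechanism and says nothing about which attributes jaxtyping's own annotation objects expose.

```python
from dataclasses import dataclass
from typing import Annotated, Any, Dict, get_type_hints


@dataclass
class Pose:
    # Hypothetical stand-in annotations: the metadata is a plain shape string.
    rotation: Annotated[Any, "*batch 3 3"]
    translation: Annotated[Any, "*batch 3"]


def shape_annotations(cls: type) -> Dict[str, str]:
    """Return {field name: shape string} for fields declared with Annotated."""
    # include_extras=True keeps the Annotated metadata instead of stripping it.
    hints = get_type_hints(cls, include_extras=True)
    shapes = {}
    for name, hint in hints.items():
        metadata = getattr(hint, "__metadata__", ())
        if metadata:
            shapes[name] = metadata[0]
    return shapes


print(shape_annotations(Pose))
```

Without `include_extras=True`, `get_type_hints` strips the `Annotated` wrapper and the shape strings are lost.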

@brentyi
Owner

brentyi commented Nov 18, 2022

Looking forward to it, thanks @lucagrementieri!

@lucagrementieri
Author

lucagrementieri commented Nov 25, 2022

From my experiments I think it's possible to support all the features using jaxtyping with no manual parsing!

I think my pull request should be ready by next week.

@lucagrementieri
Author

I finally had time to push my PR #6! Sorry for the delay!

@lucagrementieri lucagrementieri linked a pull request Jan 7, 2023 that will close this issue