-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for jaxtyping #6
base: main
Are you sure you want to change the base?
Conversation
Thanks! Looks reasonable overall, my main concern is the private |
For |
Okay, makes sense! It's definitely not ideal but having support for (cc @patrick-kidger for any warnings, are there any plans to rework the internals of I can handle the rest of the PR. Some TODOs would be:
|
There aren't any current plans. But taking a quick glance at your code, I think this will fail for a type hint of the form Anyway, jaxtyping hints are expected to be validated using a runtime type checker, such as typeguard or beartype. I'd recommend that you simply do the same thing, as they'll handle the details for you: both the nesting above, and avoiding the need to access private jaxtyping functionality. Side note: if you're working on a project like this then you may find Equinox interesting. In particular I like the neat syntax of your |
Thanks! I've also been following Equinox; definitely the "how to build pytrees" + tooling compatibility landscapes have improved quite a bit since I started I also agree that For this we need to figure out which axes in the array shapes correspond to the variadic dimension, which leaves the options of: (a) touching the private bits of jaxtyping, (b) trying to convince @patrick-kidger to expose a public API for reasoning about jaxtyping types*, or (c) not implementing this functionality. *maybe something like: (jaxtyping type, array) -> labels for each axis in the array. Any chance you're open to something like this? (understand if not) |
You should be able to replace At that point I can see that you'd want to modify its dimensions. I think the best way to do this would be to submit a PR against jaxtyping that records so that you can then look these up, modify these as desired, and then recreate the type hint through the public jaxtyping API (e.g. |
608f0f2
to
5767638
Compare
This PR adds support for jaxtyping annotations preserving all the features and checks on tensor dimensions.
The PR doesn't update the README, since it could become messy very easily. I'll wait further indications to update the README.
Close #5.