DifferentiationInterface testing #769
Thanks! In the file you linked, the source of truth seems to be Zygote? What would you use to validate the Zygote gradients themselves?
Currently I compute with Zygote and then test against other backends depending on the device.
Tracker and Zygote hit very different code paths in LuxLib (Zygote takes the optimized path, often with handwritten rules). In case of a mismatch, the general assumption is that Tracker (on GPU) or FiniteDifferences (on CPU) is the source of truth.
I have been meaning to try out FiniteDiff to validate the GPU gradients but haven't had the time to set it up.
Hi @avik-pal, JuliaDiff/DifferentiationInterface.jl#372 has a first version of Lux tests with ComponentArrays.jl encoding. A few points remain unclear to me:
I meant FiniteDifferences.jl originally. But you are looking at the new release where I migrated to FiniteDiff 😅
How do you handle the flattening and unflattening of ps for finite differences? The only code I found is this one, should I copy it inside DITest?
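For readers following the flattening question, here is a minimal sketch in plain Julia of what "flatten, perturb, unflatten" can look like for a nested NamedTuple of parameters. It is not the code from Lux, FiniteDiff, or DITest: the helper names `flatten`, `unflatten`, and `central_gradient` are hypothetical, the parameter values are made up, and the step size is an assumed rule of thumb. A ComponentArray would play the role of the flat vector plus the template in a single object.

```julia
# Hypothetical helpers for illustration only (not part of Lux, FiniteDiff, or DITest).

# Collect every array leaf of a (possibly nested) NamedTuple into one flat Float64 vector.
flatten(x::AbstractArray) = vec(float.(x))
flatten(nt::NamedTuple) = isempty(nt) ? Float64[] : reduce(vcat, map(flatten, values(nt)))

# Rebuild a value shaped like `template` from `v`, reading from index `i`.
# Returns the rebuilt value and the next unread index.
function unflatten(template::AbstractArray, v, i)
    n = length(template)
    return reshape(v[i:(i + n - 1)], size(template)), i + n
end
function unflatten(template::NamedTuple, v, i)
    vals = []
    for t in values(template)
        val, i = unflatten(t, v, i)
        push!(vals, val)
    end
    return NamedTuple{keys(template)}(Tuple(vals)), i
end
unflatten(template, v) = first(unflatten(template, v, 1))

# Central finite differences on a flat vector.
# The default step cbrt(eps) is an assumed rule of thumb, not a documented package default.
function central_gradient(f, v::AbstractVector; h=cbrt(eps(Float64)))
    g = similar(v)
    for k in eachindex(v)
        step = h * max(abs(v[k]), 1.0)
        vp = copy(v); vp[k] += step
        vm = copy(v); vm[k] -= step
        g[k] = (f(vp) - f(vm)) / (2step)
    end
    return g
end

# Made-up parameters mimicking a two-layer structure.
ps = (layer_1 = (weight = [1.0 2.0; 3.0 4.0], bias = [0.5, -0.5]), layer_2 = (scale = [2.0],))
loss(p) = sum(abs2, p.layer_1.weight) + sum(p.layer_1.bias) + p.layer_2.scale[1]^3

v = flatten(ps)
g = unflatten(ps, central_gradient(x -> loss(unflatten(ps, x)), v))
```

The gradient `g` has the same nested layout as `ps`, so it can be compared field by field against what an AD backend returns.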
No, states cannot be mutated (it is a bug in Lux if that happens for any of the layers). One particular case to be careful about is
How do you choose FiniteDiff parameters like the epsilon? Keep the package defaults? I'm looking at test failures in JuliaDiff/DifferentiationInterface.jl#372 which I think are due to the numerical errors in finite differencing, but I don't know if I should take even higher Are there layers that contain an |
Yes, keep the defaults. Normalization layers are tricky to test with finite differencing, especially for the reason you cited. In those cases, I rely on comparing Zygote with any of the other AD backends. For example, Tracker hits generic code paths while Zygote hits optimized code paths with custom rrules, so the assumption is that Tracker without custom rules gets the gradient right. Alternatively, for smaller systems, comparing against ForwardDiff is also an option.
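To illustrate why the step size and the float type matter for such comparisons, here is a small self-contained sketch in plain Julia. It does not call FiniteDiff; the function name `central_diff_error` is hypothetical, and the step `cbrt(eps(T))` is a common rule of thumb for central differences that is assumed here rather than read from the package.

```julia
# Relative error of a central finite difference of sin at x = 1,
# evaluated in float type T with step h (hypothetical helper for illustration).
function central_diff_error(::Type{T}, h) where {T<:AbstractFloat}
    x = one(T)
    approx = (sin(x + T(h)) - sin(x - T(h))) / (2 * T(h))
    return abs(Float64(approx) - cos(1.0)) / cos(1.0)
end

# Compare a large step, the rule-of-thumb step, and a very small step.
for T in (Float32, Float64), h in (1e-2, cbrt(eps(T)), 1e-6)
    println(T, "  h = ", h, "  rel. error = ", central_diff_error(T, h))
end
```

In Float32 the very small step is dominated by round-off and the large step by truncation error, which is one reason loose tolerances, or a comparison against another AD backend, can be preferable to tuning epsilon by hand.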
The |
Have you ever encountered this error with Tracker + ComponentArrays? The tests pass with Zygote, so I'm trying to add more backends.
Yes, you should call Tracker.param on the ComponentArray directly instead of
Hi @avik-pal!
I'm heading towards multi-argument and non-array support in DI, and I'd like to start testing Lux layers. For this I would need two things:
Do you think you could help me out?