Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add chain rules for function calls without dims #83

Closed
wants to merge 3 commits into from

Conversation

gaurav-arya
Copy link
Contributor

Addresses issue with existing chain rules observed in FluxML/Zygote.jl#1386

@codecov
Copy link

codecov bot commented Mar 6, 2023

Codecov Report

Patch coverage: 100.00% and project coverage change: +3.57 🎉

Comparison is base (7d698db) 84.13% compared to head (13cf3af) 87.71%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master      #83      +/-   ##
==========================================
+ Coverage   84.13%   87.71%   +3.57%     
==========================================
  Files           2        2              
  Lines         208      236      +28     
==========================================
+ Hits          175      207      +32     
+ Misses         33       29       -4     
Impacted Files Coverage Δ
src/chainrules.jl 100.00% <100.00%> (ø)
src/definitions.jl 72.11% <0.00%> (+3.84%) ⬆️

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

☔ View full report at Codecov.
📢 Do you have feedback about the report comment? Let us know in this issue.

@devmotion
Copy link
Member

Could we document that downstream packages have to implement the two-argument methods but not the ones without dims? That seems better to me than adding more rules.

Generally, the approach in the PR won't work anyway if a package has only implemented the one-argument version.

@gaurav-arya
Copy link
Contributor Author

gaurav-arya commented Mar 6, 2023

I'm not following how your suggested approach could mean that the extra rules here aren't needed?

One approach could be to replace this line:

$f(x::AbstractArray) = (y = to1(x); $pf(y) * y)

with

 $f(x::AbstractArray) =  $f(x::AbstractArray, 1:ndims(x))

Then, we wouldn't need the extra rule for no dims, so long as downstream packages never implement $f(x::AbstractArray) directly as you say. Is that what you're suggesting?

@gaurav-arya
Copy link
Contributor Author

I actually went in the opposite direction and generalized the chain rules to directly work with and without a dims argument, which addresses

Generally, the approach in the PR won't work anyway if a package has only implemented the one-argument version.

It makes the rules here a bit more complex, but now no assumptions whatsoever are made on what signatures downstream implementations support, so this is arguably the most robust solution.

Copy link
Member

@devmotion devmotion left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO this PR is still suboptimal and a better design would be desirable. With the latest changes now the signatures of the rules differs from the signatures of fft etc.

I think the cleanest solution is to only work with versions of fft etc. that implement dims and forward fft(x) etc. to the two-argument version. Otherwise we have to copy all rules and just remove dims everywhere. I think we should avoid such a code duplication.

# we explicitly handle both unprovided and provided dims arguments in all rules, which
# results in some additional complexity here but means no assumptions are made on what
# signatures downstream implementations support.
function ChainRulesCore.frule(Δargs, ::typeof(fft), x::AbstractArray, dims=nothing)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not happy about this PR because it means the signature of the AD rules is different from the signatures of fft etc. - we do not support dims = nothing in any of these methods.

Copy link
Contributor Author

@gaurav-arya gaurav-arya Mar 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A default positional argument simply expands to separate dispatches on the signatures fft(x, dims) and fft(x). The dims=nothing is just a way of sharing logic in these cases

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I would not say the signatures are different?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My point is: You can call frule(.., fft, x, nothing) but you cannot call fft(x, nothing). This breaks the correspondence between the primal function and the rules, and makes the signatures inconsistent.

There is no clean way to share code as long as fft(x) and fft(x, dims) are completely separate. Introducing fft(x) = fft(x, 1:ndims(x)) or fft(x) = fft(x, nothing), and demanding that downstream packages implement fft(x, dims) only would solve these issues. Otherwise you have to copy the code or use something like @eval to do it for you.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a fair point, I didn't realize the nothing case. Sharing code would be easy enough with a shared helper function, e.g. replacing my current function with something like _fft_rrule and calling it in both cases, so that all the dispatches are correct. If you're opposed to that, I can look into how to modify src/definitions.jl to support your solution.

@gaurav-arya
Copy link
Contributor Author

See my response to your comment -- I don't really agree that the signatures are different, and even explicitly writing out separate rules for fft(x) and fft(x, dims) would not require code reuse with a simple helper function.

Also, I see it as an inherit benefit to avoid modifying src/definitions.jl when possible in as old and fundamental a package such as this one.

@gaurav-arya
Copy link
Contributor Author

So it's possible to avoid code copying and get the dispatches right if one makes a helper function e.g. _fft_frule that accepts a has_dims of Val{false} or Val{true} and behaves based on that. But you're right it gets complicated, so I agree the best solution is to only handle the dims implementation...

@gaurav-arya gaurav-arya closed this Mar 6, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants