Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add chain rules for function calls without dims #83

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 81 additions & 50 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,121 +1,152 @@
# ffts
function ChainRulesCore.frule((_, Δx, _), ::typeof(fft), x::AbstractArray, dims)
y = fft(x, dims)
Δy = fft(Δx, dims)
# we explicitly handle both unprovided and provided dims arguments in all rules, which
# results in some additional complexity here but means no assumptions are made on what
# signatures downstream implementations support.
function ChainRulesCore.frule(Δargs, ::typeof(fft), x::AbstractArray, dims=nothing)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not happy about this PR because it means the signature of the AD rules is different from the signatures of fft etc. - we do not support dims = nothing in any of these methods.

Copy link
Contributor Author

@gaurav-arya gaurav-arya Mar 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A default positional argument simply expands to separate dispatches on the signatures fft(x, dims) and fft(x). The dims=nothing is just a way of sharing logic in these cases

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I would not say the signatures are different?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My point is: You can call frule(.., fft, x, nothing) but you cannot call fft(x, nothing). This breaks the correspondence between the primal function and the rules, and makes the signatures inconsistent.

There is no clean way to share code as long as fft(x) and fft(x, dims) are completely separate. Introducing fft(x) = fft(x, 1:ndims(x)) or fft(x) = fft(x, nothing), and demanding that downstream packages implement fft(x, dims) only would solve these issues. Otherwise you have to copy the code or use something like @eval to do it for you.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a fair point, I didn't realize the nothing case. Sharing code would be easy enough with a shared helper function, e.g. replacing my current function with something like _fft_rrule and calling it in both cases, so that all the dispatches are correct. If you're opposed to that, I can look into how to modify src/definitions.jl to support your solution.

Δx = Δargs[2]
dims_args = (dims === nothing) ? () : (dims,)
y = fft(x, dims_args...)
Δy = fft(Δx, dims_args...)
return y, Δy
end
function ChainRulesCore.rrule(::typeof(fft), x::AbstractArray, dims)
y = fft(x, dims)
function ChainRulesCore.rrule(::typeof(fft), x::AbstractArray, dims=nothing)
dims_args = (dims === nothing) ? () : (dims,)
y = fft(x, dims_args...)
project_x = ChainRulesCore.ProjectTo(x)
function fft_pullback(ȳ)
x̄ = project_x(bfft(ChainRulesCore.unthunk(ȳ), dims))
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
x̄ = project_x(bfft(ChainRulesCore.unthunk(ȳ), dims_args...))
dims_args_tangent = (dims === nothing) ? () : (ChainRulesCore.NoTangent(),)
return ChainRulesCore.NoTangent(), x̄, dims_args_tangent...
end
return y, fft_pullback
end

function ChainRulesCore.frule((_, Δx, _), ::typeof(rfft), x::AbstractArray{<:Real}, dims)
y = rfft(x, dims)
Δy = rfft(Δx, dims)
function ChainRulesCore.frule(Δargs, ::typeof(rfft), x::AbstractArray{<:Real}, dims=nothing)
Δx = Δargs[2]
dims_args = (dims === nothing) ? () : (dims,)
y = rfft(x, dims_args...)
Δy = rfft(Δx, dims_args...)
return y, Δy
end
function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims)
y = rfft(x, dims)
function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims=nothing)
dims_args = (dims === nothing) ? () : (dims,)
true_dims = (dims === nothing) ? (1:ndims(x)) : dims
y = rfft(x, dims_args...)

# compute scaling factors
halfdim = first(dims)
halfdim = first(true_dims)
d = size(x, halfdim)
n = size(y, halfdim)
scale = reshape(
[i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n],
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
ntuple(i -> i == first(true_dims) ? n : 1, Val(ndims(x))),
)

project_x = ChainRulesCore.ProjectTo(x)
function rfft_pullback(ȳ)
x̄ = project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ scale, d, dims))
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
x̄ = project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ scale, d, dims_args...))
dims_args_tangent = (dims === nothing) ? () : (ChainRulesCore.NoTangent(),)
return ChainRulesCore.NoTangent(), x̄, dims_args_tangent...
end
return y, rfft_pullback
end

function ChainRulesCore.frule((_, Δx, _), ::typeof(ifft), x::AbstractArray, dims)
y = ifft(x, dims)
Δy = ifft(Δx, dims)
function ChainRulesCore.frule(Δargs, ::typeof(ifft), x::AbstractArray, dims=nothing)
Δx = Δargs[2]
args = (dims === nothing) ? () : (dims,)
y = ifft(x, args...)
Δy = ifft(Δx, args...)
return y, Δy
end
function ChainRulesCore.rrule(::typeof(ifft), x::AbstractArray, dims)
y = ifft(x, dims)
invN = normalization(y, dims)
function ChainRulesCore.rrule(::typeof(ifft), x::AbstractArray, dims=nothing)
dims_args = (dims === nothing) ? () : (dims,)
true_dims = (dims === nothing) ? (1:ndims(x)) : dims
y = ifft(x, dims_args...)
invN = normalization(y, true_dims)
project_x = ChainRulesCore.ProjectTo(x)
function ifft_pullback(ȳ)
x̄ = project_x(invN .* fft(ChainRulesCore.unthunk(ȳ), dims))
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
x̄ = project_x(invN .* fft(ChainRulesCore.unthunk(ȳ), dims_args...))
dims_args_tangent = (dims === nothing) ? () : (ChainRulesCore.NoTangent(),)
return ChainRulesCore.NoTangent(), x̄, dims_args_tangent...
end
return y, ifft_pullback
end

function ChainRulesCore.frule((_, Δx, _, _), ::typeof(irfft), x::AbstractArray, d::Int, dims)
y = irfft(x, d, dims)
Δy = irfft(Δx, d, dims)
function ChainRulesCore.frule(Δargs, ::typeof(irfft), x::AbstractArray, d::Int, dims=nothing)
Δx = Δargs[2]
dims_args = (dims === nothing) ? () : (dims,)
y = irfft(x, d, dims_args...)
Δy = irfft(Δx, d, dims_args...)
return y, Δy
end
function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims)
y = irfft(x, d, dims)
function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims=nothing)
dims_args = (dims === nothing) ? () : (dims,)
true_dims = (dims === nothing) ? (1:ndims(x)) : dims
y = irfft(x, d, dims_args...)

# compute scaling factors
halfdim = first(dims)
halfdim = first(true_dims)
n = size(x, halfdim)
invN = normalization(y, dims)
invN = normalization(y, true_dims)
twoinvN = 2 * invN
scale = reshape(
[i == 1 || (i == n && 2 * (i - 1) == d) ? invN : twoinvN for i in 1:n],
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
ntuple(i -> i == first(true_dims) ? n : 1, Val(ndims(x))),
)

project_x = ChainRulesCore.ProjectTo(x)
function irfft_pullback(ȳ)
x̄ = project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims))
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent()
x̄ = project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims_args...))
dims_args_tangent = (dims === nothing) ? () : (ChainRulesCore.NoTangent(),)
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), dims_args_tangent...
end
return y, irfft_pullback
end

function ChainRulesCore.frule((_, Δx, _), ::typeof(bfft), x::AbstractArray, dims)
y = bfft(x, dims)
Δy = bfft(Δx, dims)
function ChainRulesCore.frule(Δargs, ::typeof(bfft), x::AbstractArray, dims=nothing)
Δx = Δargs[2]
dims_args = (dims === nothing) ? () : (dims,)
y = bfft(x, dims_args...)
Δy = bfft(Δx, dims_args...)
return y, Δy
end
function ChainRulesCore.rrule(::typeof(bfft), x::AbstractArray, dims)
y = bfft(x, dims)
function ChainRulesCore.rrule(::typeof(bfft), x::AbstractArray, dims=nothing)
dims_args = (dims === nothing) ? () : (dims,)
y = bfft(x, dims_args...)
project_x = ChainRulesCore.ProjectTo(x)
function bfft_pullback(ȳ)
x̄ = project_x(fft(ChainRulesCore.unthunk(ȳ), dims))
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
x̄ = project_x(fft(ChainRulesCore.unthunk(ȳ), dims_args...))
dims_args_tangent = (dims === nothing) ? () : (ChainRulesCore.NoTangent(),)
return ChainRulesCore.NoTangent(), x̄, dims_args_tangent...
end
return y, bfft_pullback
end

function ChainRulesCore.frule((_, Δx, _, _), ::typeof(brfft), x::AbstractArray, d::Int, dims)
y = brfft(x, d, dims)
Δy = brfft(Δx, d, dims)
function ChainRulesCore.frule(Δargs, ::typeof(brfft), x::AbstractArray, d::Int, dims=nothing)
Δx = Δargs[2]
dims_args = (dims === nothing) ? () : (dims,)
y = brfft(x, d, dims_args...)
Δy = brfft(Δx, d, dims_args...)
return y, Δy
end
function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims)
y = brfft(x, d, dims)
function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims=nothing)
dims_args = (dims === nothing) ? () : (dims,)
true_dims = (dims === nothing) ? (1:ndims(x)) : dims
y = brfft(x, d, dims_args...)

# compute scaling factors
halfdim = first(dims)
halfdim = first(true_dims)
n = size(x, halfdim)
scale = reshape(
[i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n],
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
ntuple(i -> i == first(true_dims) ? n : 1, Val(ndims(x))),
)

project_x = ChainRulesCore.ProjectTo(x)
function brfft_pullback(ȳ)
x̄ = project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims))
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent()
x̄ = project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims_args...))
dims_args_tangent = (dims === nothing) ? () : (ChainRulesCore.NoTangent(),)
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), dims_args_tangent...
end
return y, brfft_pullback
end
Expand Down
38 changes: 21 additions & 17 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,14 +216,14 @@ end
@testset "ChainRules" begin
@testset "shift functions" begin
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
# type inference checks of `rrule` fail on old Julia versions
# for higher-dimensional arrays:
# https://github.com/JuliaMath/AbstractFFTs.jl/pull/58#issuecomment-916530016
check_inferred = ndims(x) < 3 || VERSION >= v"1.6"

for dims in ((), 1, 2, (1,2), 1:2)
any(d > ndims(x) for d in dims) && continue

# type inference checks of `rrule` fail on old Julia versions
# for higher-dimensional arrays:
# https://github.com/JuliaMath/AbstractFFTs.jl/pull/58#issuecomment-916530016
check_inferred = ndims(x) < 3 || VERSION >= v"1.6"

test_frule(AbstractFFTs.fftshift, x, dims)
test_rrule(AbstractFFTs.fftshift, x, dims; check_inferred=check_inferred)

Expand All @@ -237,23 +237,27 @@ end
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
N = ndims(x)
complex_x = complex.(x)
for dims in unique((1, 1:N, N))
for dims in unique((1, 1:N, N, nothing))
# if dims=nothing, test handling of default dims argument
dims_args = (dims === nothing) ? () : (dims,)
true_dims = (dims === nothing) ? (1:N) : dims

for f in (fft, ifft, bfft)
test_frule(f, x, dims)
test_rrule(f, x, dims)
test_frule(f, complex_x, dims)
test_rrule(f, complex_x, dims)
test_frule(f, x, dims_args...)
test_rrule(f, x, dims_args...)
test_frule(f, complex_x, dims_args...)
test_rrule(f, complex_x, dims_args...)
end

test_frule(rfft, x, dims)
test_rrule(rfft, x, dims)
test_frule(rfft, x, dims_args...)
test_rrule(rfft, x, dims_args...)

for f in (irfft, brfft)
for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2)
test_frule(f, x, d, dims)
test_rrule(f, x, d, dims)
test_frule(f, complex_x, d, dims)
test_rrule(f, complex_x, d, dims)
for d in (2 * size(x, first(true_dims)) - 1, 2 * size(x, first(true_dims)) - 2)
test_frule(f, x, d, dims_args...)
test_rrule(f, x, d, dims_args...)
test_frule(f, complex_x, d, dims_args...)
test_rrule(f, complex_x, d, dims_args...)
end
end
end
Expand Down