
Fix rrule for rfft and ifft for CuArray #96

Closed
wants to merge 4 commits

Conversation

ziyiyin97

Fix #95

Right now there is no test on CuArray, so this fix cannot be tested easily. Any suggestions?
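
For context, here is a hedged sketch of the failing case this PR targets, pieced together from issue #95 and the CUDA session shown later in this thread. It requires a CUDA-capable GPU and is only an illustration: before the fix, the pullback builds its scale as a plain CPU Array, and mixing that with a GPU cotangent errors.

using CUDA, FFTW, Flux  # FFTW provides the CPU AbstractFFTs plans; CUDA provides them for CuArray

x = CUDA.randn(3)
# Before this PR: errors inside the rfft pullback (CPU scale array vs. GPU cotangent).
# After this PR: returns Grads(...), as in the transcript further down.
gradient(() -> sum(abs.(rfft(x))), Flux.params(x))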

@codecov

codecov bot commented Mar 28, 2023

Codecov Report

Patch coverage: 100.00% and project coverage change: +0.97% 🎉

Comparison is base (a25656d) 87.08% compared to head (a22c168) 88.05%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master      #96      +/-   ##
==========================================
+ Coverage   87.08%   88.05%   +0.97%     
==========================================
  Files           3        3              
  Lines         209      226      +17     
==========================================
+ Hits          182      199      +17     
  Misses         27       27              
Impacted Files Coverage Δ
ext/AbstractFFTsChainRulesCoreExt.jl 100.00% <100.00%> (ø)


@ziyiyin97 force-pushed the master branch 2 times, most recently from 55d2262 to 2a724d0 on March 29, 2023 01:07
@ziyiyin97 changed the title from "Fix rrule for rfft" to "Fix rrule for rfft and ifft for CuArray" on Mar 29, 2023
Contributor

@gaurav-arya left a comment


Regarding testing, #78 wants to make AbstractFFTsTestUtils a separate package for downstream packages to use. Once it's a separate dependency, we could probably lump in a ChainRulesTestUtils dependency there to test these chain rules, and CUDA can then use AbstractFFTsTestUtils in its tests. cc @devmotion

For now, I think we can only test this change in this package if there's a way to make this code error without using a GPU array?

@@ -37,7 +37,7 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims)

project_x = ChainRulesCore.ProjectTo(x)
function rfft_pullback(ȳ)
x̄ = project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ scale, d, dims))
x̄ = project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ typeof(x)(scale), d, dims))
Contributor

@gaurav-arya Mar 30, 2023


Is it better to make scale the appropriate array type at point of construction, rather than converting it here, to avoid unnecessary allocations?

@ziyiyin97
Author

Thanks for your suggestion. I addressed the scale issue and rebased. Could you propose a test? Unfortunately, I currently cannot find any error with CPU arrays.

@gaurav-arya
Contributor

How about an OffsetArray? Looks like broadcasting would fail without the fix:

julia> using OffsetArrays
julia> a = OffsetArray([1,2,3], 2:4)
3-element OffsetArray(::Vector{Int64}, 2:4) with eltype Int64 with indices 2:4:
 1
 2
 3

julia> b = [1,2,3]
3-element Vector{Int64}:
 1
 2
 3

julia> a ./ b
ERROR: DimensionMismatch: arrays could not be broadcast to a common size; got a dimension with lengths 3 and 3
Stacktrace:
 [1] _bcs1
   @ ./broadcast.jl:529 [inlined]
 [2] _bcs
   @ ./broadcast.jl:523 [inlined]
 [3] broadcast_shape
   @ ./broadcast.jl:517 [inlined]
 [4] combine_axes
   @ ./broadcast.jl:512 [inlined]
 [5] instantiate
   @ ./broadcast.jl:294 [inlined]
 [6] materialize(bc::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{OffsetVector{Int64, Vector{Int64}}, Vector{Int64}}})
   @ Base.Broadcast ./broadcast.jl:873
 [7] top-level scope
   @ REPL[27]:1

Seems like you might just be able to add an offset array e.g. OffsetArray(randn(3), 2:4) to the list of arrays tested in the chain rules, and verify the test fails before the fix and passes now.

(Unfortunately, one can imagine a situation where your fix here actually fails, e.g. if x is some sort of one-hot array type, the construction of scale will fail. But getting things working for a GPU seems like a much more important first step.)

@@ -30,10 +30,10 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims)
halfdim = first(dims)
d = size(x, halfdim)
n = size(y, halfdim)
scale = reshape(
scale = typeof(x)(reshape(
Contributor

@gaurav-arya Mar 30, 2023


Actually, would typeof(y) be better here, since scale is broadcast against y's adjoint?

(Since fft and friends don't preserve OffsetArrays in their output, this unfortunately means that the test I proposed won't work anymore: we'd need a case where fft does not produce a Vector, and I don't know of any other than GPU arrays. Maybe it's OK to satisfy ourselves with the existing tests passing for now?)

Contributor


If it's not too much trouble, an OffsetArrays test could still be good to add, though: it would at least catch that it should be typeof(y) rather than typeof(x) here.

@ziyiyin97
Author

I did

x = OffsetArray(randn(3), 2:4)
test_rrule(rfft, x, 1) # errors
test_rrule: rfft on OffsetVector{Int64, Vector{Int64}},Int64: Test Failed at /Users/ziyiyin/.julia/packages/ChainRulesTestUtils/lERVj/src/testers.jl:314
  Expression: ad_cotangent isa NoTangent
   Evaluated: [-8.2, -6.698780366869516, 5.598780366869516] isa NoTangent
Stacktrace:
 [1] macro expansion
   @ /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/Test/src/Test.jl:464 [inlined]
 [2] _test_cotangent(::NoTangent, ad_cotangent::Any, ::NoTangent; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/lERVj/src/testers.jl:314
Test Summary:                                                | Pass  Fail  Total  Time
test_rrule: rfft on OffsetVector{Int64, Vector{Int64}},Int64 |    6     1      7  0.0s
ERROR: Some tests did not pass: 6 passed, 1 failed, 0 errored, 0 broken.

Seems nasty, and I am not quite sure how ChainRulesTestUtils works in this case. Suggestions?

@gaurav-arya
Contributor

gaurav-arya commented Mar 30, 2023

It might be fixable with some type piracy to FiniteDifferences.to_vec, but I'm not sure if it's worth the effort given that this is an imperfect test of the CuArray case anyway.

I should also note that the typeof(y)(...) is an imperfect solution, e.g. I think it would fail in the multidimensional case if y were a static array with its shape encoded in the type. I don't think this happens in practice, though (FFTW's outputs are always plain Arrays, and y is only a different type in the GPU case afaik). Constructing a CPU array first and then converting it into typeof(y) also feels imperfect; I don't know if it has any performance consequences for GPU programming.

If this fixes the CUDA behaviour locally, it looks good to me to merge as is and improve later as needed; hopefully we would get CUDA tests of this behaviour in the near future. Someone with merge rights would need to sign off though.

@ziyiyin97
Author

ziyiyin97 commented Mar 30, 2023

Yes it works on my end locally.

               _
   _       _ _(_)_     |  Documentation: https://docs.julialang.org
  (_)     | (_) (_)    |
   _ _   _| |_  __ _   |  Type "?" for help, "]?" for Pkg help.
  | | | | | | |/ _` |  |
  | | |_| | | | (_| |  |  Version 1.8.5 (2023-01-08)
 _/ |\__'_|_|_|\__'_|  |  Official https://julialang.org/ release
|__/                   |

(@v1.8) pkg> activate --temp
  Activating new project at `/tmp/jl_S6w38G`

(jl_S6w38G) pkg> add https://github.com/ziyiyin97/AbstractFFTs.jl.git
    Updating git-repo `https://github.com/ziyiyin97/AbstractFFTs.jl.git`
    Updating registry at `~/.julia/registries/General.toml`
   Resolving package versions...
    Updating `/tmp/jl_S6w38G/Project.toml`
  [621f4979] + AbstractFFTs v1.3.1 `https://github.com/ziyiyin97/AbstractFFTs.jl.git#master`
    Updating `/tmp/jl_S6w38G/Manifest.toml`
  [621f4979] + AbstractFFTs v1.3.1 `https://github.com/ziyiyin97/AbstractFFTs.jl.git#master`
  [d360d2e6] + ChainRulesCore v1.15.7
  [34da2185] + Compat v4.6.1
  [56f22d72] + Artifacts
  [ade2ca70] + Dates
  [8f399da3] + Libdl
  [37e2e46d] + LinearAlgebra
  [de0858da] + Printf
  [9a3f8284] + Random
  [ea8e919c] + SHA v0.7.0
  [9e88b42a] + Serialization
  [2f01184e] + SparseArrays
  [cf7118a7] + UUIDs
  [4ec0a83e] + Unicode
  [e66e0078] + CompilerSupportLibraries_jll v1.0.1+0
  [4536629a] + OpenBLAS_jll v0.3.20+0
  [8e850b90] + libblastrampoline_jll v5.1.1+0
Precompiling project...
  1 dependency successfully precompiled in 1 seconds. 5 already precompiled.

(jl_S6w38G) pkg> add CUDA, Flux, FFTW
   Resolving package versions...
   Installed CUDA_Driver_jll ──────── v0.5.0+0
   Installed CUDA_Runtime_Discovery ─ v0.2.0
   Installed CUDA_Runtime_jll ─────── v0.5.0+0
   Installed OrderedCollections ───── v1.6.0
   Installed CUDA ─────────────────── v4.1.2
    Updating `/tmp/jl_S6w38G/Project.toml`
  [052768ef] + CUDA v4.1.2
  [7a1cc6ca] + FFTW v1.6.0
  [587475ba] + Flux v0.13.14
    Updating `/tmp/jl_S6w38G/Manifest.toml`
  [7d9f7c33] + Accessors v0.1.28
  [79e6a3ab] + Adapt v3.6.1
  [dce04be8] + ArgCheck v2.3.0
  [a9b6321e] + Atomix v0.1.0
  [ab4f0b2a] + BFloat16s v0.4.2
  [198e06fe] + BangBang v0.3.37
  [9718e550] + Baselet v0.1.1
  [fa961155] + CEnum v0.4.2
  [052768ef] + CUDA v4.1.2
  [1af6417a] + CUDA_Runtime_Discovery v0.2.0
  [082447d4] + ChainRules v1.48.0
  [9e997f8a] + ChangesOfVariables v0.1.6
  [bbf7d656] + CommonSubexpressions v0.3.0
  [a33af91c] + CompositionsBase v0.1.1
  [187b0558] + ConstructionBase v1.5.1
  [6add18c4] + ContextVariablesX v0.1.3
  [9a962f9c] + DataAPI v1.14.0
  [864edb3b] + DataStructures v0.18.13
  [e2d170a0] + DataValueInterfaces v1.0.0
  [244e2a9f] + DefineSingletons v0.1.2
  [163ba53b] + DiffResults v1.1.0
  [b552c78f] + DiffRules v1.13.0
  [ffbed154] + DocStringExtensions v0.9.3
  [e2ba6199] + ExprTools v0.1.9
  [7a1cc6ca] + FFTW v1.6.0
  [cc61a311] + FLoops v0.2.1
  [b9860ae5] + FLoopsBase v0.1.1
  [1a297f60] + FillArrays v0.13.11
  [587475ba] + Flux v0.13.14
  [9c68100b] + FoldsThreads v0.1.1
  [f6369f11] + ForwardDiff v0.10.35
  [069b7b12] + FunctionWrappers v1.1.3
  [d9f16b24] + Functors v0.4.4
  [0c68f7d7] + GPUArrays v8.6.5
  [46192b85] + GPUArraysCore v0.1.4
  [61eb1bfa] + GPUCompiler v0.18.0
  [7869d1d1] + IRTools v0.4.9
  [22cec73e] + InitialValues v0.3.1
  [3587e190] + InverseFunctions v0.1.8
  [92d709cd] + IrrationalConstants v0.2.2
  [82899510] + IteratorInterfaceExtensions v1.0.0
  [692b3bcd] + JLLWrappers v1.4.1
  [b14d175d] + JuliaVariables v0.2.4
  [63c18a36] + KernelAbstractions v0.9.1
⌅ [929cbde3] + LLVM v4.17.1
  [2ab3a3ac] + LogExpFunctions v0.3.23
  [d8e11817] + MLStyle v0.4.17
  [f1d291b0] + MLUtils v0.4.1
  [1914dd2f] + MacroTools v0.5.10
  [128add7d] + MicroCollections v0.1.4
  [e1d29d7a] + Missings v1.1.0
  [872c559c] + NNlib v0.8.19
  [a00861dc] + NNlibCUDA v0.2.7
  [77ba4419] + NaNMath v1.0.2
  [71a1bf82] + NameResolution v0.1.5
  [0b1bfda6] + OneHotArrays v0.2.3
  [3bd65402] + Optimisers v0.2.17
  [bac558e1] + OrderedCollections v1.6.0
  [21216c6a] + Preferences v1.3.0
  [8162dcfd] + PrettyPrint v0.2.0
  [33c8b6b6] + ProgressLogging v0.1.4
  [74087812] + Random123 v1.6.0
  [e6cf234a] + RandomNumbers v1.5.3
  [c1ae055f] + RealDot v0.1.0
  [189a3867] + Reexport v1.2.2
  [ae029012] + Requires v1.3.0
  [efcf1570] + Setfield v1.1.1
  [605ecd9f] + ShowCases v0.1.0
  [699a6c99] + SimpleTraits v0.9.4
  [66db9d55] + SnoopPrecompile v1.0.3
  [a2af1166] + SortingAlgorithms v1.1.0
  [276daf66] + SpecialFunctions v2.2.0
  [171d559e] + SplittablesBase v0.1.15
  [90137ffa] + StaticArrays v1.5.19
  [1e83bf80] + StaticArraysCore v1.4.0
  [82ae8749] + StatsAPI v1.6.0
  [2913bbd2] + StatsBase v0.33.21
  [09ab397b] + StructArrays v0.6.15
  [3783bdb8] + TableTraits v1.0.1
  [bd369af6] + Tables v1.10.1
  [a759f4b9] + TimerOutputs v0.5.22
  [28d57a85] + Transducers v0.4.75
  [013be700] + UnsafeAtomics v0.2.1
⌃ [d80eeb9a] + UnsafeAtomicsLLVM v0.1.1
  [e88e6eb3] + Zygote v0.6.59
  [700de1a5] + ZygoteRules v0.2.3
  [02a925ec] + cuDNN v1.0.2
  [4ee394cb] + CUDA_Driver_jll v0.5.0+0
  [76a88914] + CUDA_Runtime_jll v0.5.0+0
  [62b44479] + CUDNN_jll v8.8.1+0
  [f5851436] + FFTW_jll v3.3.10+0
  [1d5cc7b8] + IntelOpenMP_jll v2018.0.3+2
⌅ [dad2f222] + LLVMExtra_jll v0.0.18+0
  [856f044c] + MKL_jll v2022.2.0+0
  [efe28fd5] + OpenSpecFun_jll v0.5.5+0
  [0dad84c5] + ArgTools v1.1.1
  [2a0f44e3] + Base64
  [8bb1440f] + DelimitedFiles
  [8ba89e20] + Distributed
  [f43a241f] + Downloads v1.6.0
  [7b1f6079] + FileWatching
  [9fa8497b] + Future
  [b77e0a4c] + InteractiveUtils
  [4af54fe1] + LazyArtifacts
  [b27032c2] + LibCURL v0.6.3
  [76f85450] + LibGit2
  [56ddb016] + Logging
  [d6f4376e] + Markdown
  [a63ad114] + Mmap
  [ca575930] + NetworkOptions v1.2.0
  [44cfe95a] + Pkg v1.8.0
  [3fa0cd96] + REPL
  [6462fe0b] + Sockets
  [10745b16] + Statistics
  [fa267f1f] + TOML v1.0.0
  [a4e569a6] + Tar v1.10.1
  [8dfed614] + Test
  [deac9b47] + LibCURL_jll v7.84.0+0
  [29816b5a] + LibSSH2_jll v1.10.2+0
  [c8ffd9c3] + MbedTLS_jll v2.28.0+0
  [14a3606d] + MozillaCACerts_jll v2022.2.1
  [05823500] + OpenLibm_jll v0.8.1+0
  [83775a58] + Zlib_jll v1.2.12+3
  [8e850ede] + nghttp2_jll v1.48.0+0
  [3f19e933] + p7zip_jll v17.4.0+0
        Info Packages marked with ⌃ and ⌅ have new versions available, but those with ⌅ are restricted by compatibility constraints from upgrading. To see why use `status --outdated -m`
Precompiling project...
  19 dependencies successfully precompiled in 56 seconds. 86 already precompiled.

julia> using CUDA, FFTW, Flux
[ Info: Precompiling FFTW [7a1cc6ca-52ef-59f5-83cd-3a7055c09341]

julia> x = CUDA.randn(3)
3-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}:
 -0.08069154
  0.49658614
 -0.9882421

julia> gradient(()->sum(abs.(rfft(x))), Flux.params(x))
Grads(...)

julia> y = rfft(x)
2-element CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}:
 -0.57234746f0 + 0.0f0im
  0.16513643f0 - 1.2858989f0im

julia> gradient(()->sum(abs.(irfft(y, 3))), Flux.params(x))
Grads(...)

julia> gradient(()->sum(abs.(brfft(y, 3))), Flux.params(x))
Grads(...)

@ziyiyin97
Author

Could you please help me notify the reviewers with merge rights? Thanks again

@gaurav-arya
Contributor

gaurav-arya commented Apr 3, 2023

One worry I have is whether the way scale is being constructed as a CPU array, converted to GPU, and then broadcasted would tank performance on the GPU. A simple test seems to suggest that it could lead to a significant overhead:

using CUDA

N = 100000000
x = CuArray(rand(N))

# f1 is similar to our current approach
function f1(x)
    y = typeof(x)([i == 7 ? 2 : 1 for i in 1:N])
    return x ./ y
end

# f2 does not explicitly construct a scale array
function f2(x)
    y = copy(x)
    y[7] *= 2
    return y
end

@time f1(x) # 0.633765 seconds (105.59 k allocations: 1.496 GiB, 3.71% gc time, 6.02% compilation time)
@time f2(x) # 0.007896 seconds (4.60 k allocations: 224.639 KiB, 73.60% compilation time)

An alternative might be to write out the division in the pullback without broadcasting, similar to f2 above. E.g. here's my attempt to write the RFFT pullback in a way that might be more GPU friendly:

function rfft_pullback(ȳ)
      dY = ChainRulesCore.unthunk(ȳ)
      dY_scaled = similar(dY)
      dY_scaled .= dY
      dY_scaled ./= 2
      selectdim(dY_scaled, halfdim, 1) .*= 2
      if 2 * (n - 1) == d
          selectdim(dY_scaled, halfdim, n) .*= 2
      end
      x̄ = project_x(brfft(dY_scaled, d, dims))
      return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
  end

I'm not experienced enough with CUDA array programming in Julia to be sure about this. It would also be nice if there were a way of keeping the "style" of the original broadcasting solution without having to allocate a CPU array. Pinging @maleadt for any more info :)

@gaurav-arya
Contributor

gaurav-arya commented Apr 3, 2023

Here's an approach that keeps the original broadcasting style, but makes the scale array in what is hopefully a GPU friendly way:

# Make scaling array in a GPU friendly way
scale = similar(y, ntuple(i -> i == halfdim ? n : 1, Val(ndims(x))))
scale .= 2
selectdim(scale, halfdim, 1) .= 1 
if 2 * (n - 1) == d
    selectdim(scale, halfdim, n) .= 1 
end

My worry is that the current approach in this PR would be slow on the GPU, since it makes a large CPU array allocation within the rule -- @ziyiyin97 do you think the same, and if so, could you replace the construction of the scale array with something like one of the two proposed solutions?

cc'ing @devmotion as the author of the original scaling code.
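
To make this concrete, here is a rough sketch of how the proposed scale construction could slot into the existing rfft rrule in ext/AbstractFFTsChainRulesCoreExt.jl. The names (y, halfdim, d, n, project_x) follow the surrounding diff; this is only an illustration of the idea, not the code committed in this PR.

# Sketch only -- assumes it lives in the ChainRulesCore extension, where rfft/brfft and
# ChainRulesCore are already in scope.
function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims)
    y = rfft(x, dims)
    halfdim = first(dims)
    d = size(x, halfdim)
    n = size(y, halfdim)

    # Build the scale with similar(y, ...) so it is allocated wherever y lives (CPU or GPU).
    scale = similar(y, ntuple(i -> i == halfdim ? n : 1, Val(ndims(x))))
    scale .= 2
    selectdim(scale, halfdim, 1) .= 1
    if 2 * (n - 1) == d
        selectdim(scale, halfdim, n) .= 1
    end

    project_x = ChainRulesCore.ProjectTo(x)
    function rfft_pullback(ȳ)
        x̄ = project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ scale, d, dims))
        return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
    end
    return y, rfft_pullback
end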

@ziyiyin97
Author

Sorry for the late response. I think the solution you proposed sounds great! I've applied your approach to rfft, irfft, and brfft. Let me know how it looks to you now.

@ziyiyin97
Author

ziyiyin97 commented Apr 11, 2023

It seems that Julia 1.0 does not like it:
ERROR: LoadError: LoadError: syntax: invalid assignment location "selectdim(dY, halfdim, 1)"

Comment on lines 36 to 40
dY = ChainRulesCore.unthunk(ȳ) ./ 2
selectdim(dY, halfdim, 1) .*= 2
if 2 * (n - 1) == d
selectdim(dY, halfdim, n) .*= 2
end
Member


This seems both inefficient (many unnecessary computations) and it breaks arrays that do not support mutation.

Author


Good point. Any suggestions? In the previous commit I used the code block below, but @gaurav-arya made a good point that this type conversion might be slow in certain cases.

scale = typeof(y)(reshape(
        [i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n],
        ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
    ))

Contributor

@gaurav-arya Apr 11, 2023


I tried to write this code to be GPU compatible, i.e. avoiding broadcasting with CPU arrays. That's why I divide the whole array by 2, then multiply some slices of the array by 2. I wasn't sure if what I did was the right approach, so I'd appreciate any feedback on how to write it better :)

Regarding the speed issue, I benchmarked the following code before and after this PR:

using FFTW
using ChainRulesCore

function tobenchmark(x, dims)
    y, pb = rrule(rfft, x, dims)
    return pb(y)
end

julia> @btime tobenchmark(rand(1000, 1000), 1:2);
13.913 ms (71 allocations: 45.85 MiB) [BEFORE]
13.897 ms (78 allocations: 45.84 MiB) [AFTER]

Regarding the mutable array issue, that's why I used similar in the code I originally suggested, which is semantically guaranteed to return a mutable array. I agree it's not a perfect solution for the immutable array case (perhaps using Adapt.jl or ArrayInterface.jl could help with that). But also, note that this is about the type of the output array rather than the input, and afaik there is no existing case in the ecosystem where the output array is immutable: FFTW always returns plain Arrays for CPU inputs. So it's not perfect, but it did seem to fix the CUDA case which the previous approach didn't support (and with similar it would even be correct for a hypothetical static array, although admittedly not an ideal approach) -- hopefully that helps explain my reasoning :)
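
As a tiny illustration of the similar point (assumes CUDA.jl is available; purely illustrative, not part of the PR):

using CUDA

y = CuArray(rand(ComplexF32, 4))
buf = similar(y)   # another CuArray{ComplexF32}: similar always gives a freshly allocated, mutable array
buf .= y ./ 2      # the broadcast stays on the device; no CPU array is involved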

@gaurav-arya
Contributor

gaurav-arya commented Apr 11, 2023

@ziyiyin97 regarding the invalid assignment location, the following workaround seemed to work for me on Julia 1.0:

julia> x = rand(2,2)
2×2 Array{Float64,2}:
 0.000300982  0.405891
 0.903893     0.814312

julia> v = selectdim(x, 1, 2); # place view in a separate variable (workaround for Julia <1.2)

julia> v .+= 1
2-element view(::Array{Float64,2}, 2, :) with eltype Float64:
 1.9038934352514685
 1.8143118255443202

It looks like it's a bug that was fixed only in Julia 1.2: https://discourse.julialang.org/t/invalid-assignment-location-on-function-call-returning-view/23346/5. It only seems to appear for the .*= in your case; the .= lines are safe.

dX_scaled = similar(dX)
dX_scaled .= dX
dX_scaled .*= 2
v = selectdim(dX_scaled, halfdim, 1)
Contributor

@gaurav-arya Apr 12, 2023


maybe a comment above this line saying something like # assign view to a separate variable before assignment, to support Julia <1.2?

dX = rfft(real.(ChainRulesCore.unthunk(ȳ)), dims) .* 2
selectdim(dX, halfdim, 1) ./= 2
dX = rfft(real.(ChainRulesCore.unthunk(ȳ)), dims)
# apply scaling
Contributor


maybe also add a comment saying something like # below approach is for ensuring GPU compatibility, see PR #96?

@@ -33,10 +33,12 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims)
project_x = ChainRulesCore.ProjectTo(x)
function rfft_pullback(ȳ)
dY = ChainRulesCore.unthunk(ȳ)
# apply scaling
# apply scaling; below approach is for GPU CuArray compatibility, see PR #96
dY_scaled = similar(dY)
dY_scaled .= dY
dY_scaled ./= 2
Contributor


you could actually fuse this line with the previous one, I think, e.g. dY_scaled .= dY ./ 2, and similarly for the other scalings
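
For example (same names as the diff above; just a sketch of the suggested fusion):

# current: two broadcast passes over the new buffer
dY_scaled = similar(dY)
dY_scaled .= dY
dY_scaled ./= 2

# fused: a single broadcast pass writing directly into the new buffer
dY_scaled = similar(dY)
dY_scaled .= dY ./ 2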

@ziyiyin97
Author

Any other comments?

@gaurav-arya
Contributor

I don't have any other comments, but this would need an approval from someone with merge rights, if @stevengj or @devmotion have time for a quick look?

@ziyiyin97
Author

Any update?

@devmotion
Member

#105 seems simpler.

@ziyiyin97
Author

I think this conversion approach was discussed (and opposed) earlier here in #96 (comment) ... but I'm fine with it. I can close this PR if #105 is merged. Feel free to let me know.

@stevengj
Member

Closed by #105

@stevengj stevengj closed this Jun 27, 2023
@gaurav-arya
Contributor

gaurav-arya commented Aug 8, 2023

@stevengj @devmotion, could we revisit this? It does not suffer from either of the issues observed in #112 that were introduced by #105. I'd be happy to open a new PR for review if you think this is the right path. If there's a better solution, that would of course be wonderful too 🙂

@devmotion
Member

I wonder if we could just use convert(typeof(y), ...) instead? This could (or rather should) even be computed outside of the pullback. I guess we always expect that ybar is compatible with y and this should avoid issues with "strange" pullback inputs such as Zygote.OneElement.
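
For illustration, a rough sketch of that suggestion applied to the rfft rule (names follow the existing code; this only shows the idea and is not necessarily what #114 ended up doing):

# Build the scale on the CPU as before, then convert it once, outside the pullback,
# to the array type of the primal output y (e.g. Vector -> CuArray when y is a CuArray).
# Assumes typeof(y) supports convert from a CPU array, as CuArray does.
scale = convert(typeof(y), reshape(
    [i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n],
    ntuple(i -> i == halfdim ? n : 1, Val(ndims(x))),
))

function rfft_pullback(ȳ)
    x̄ = project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ scale, d, dims))
    return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
end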

@gaurav-arya
Contributor

I think that should fix the Zygote issue, which is the more important one, so I'd be happy with that as a stopgap. I don't think it would handle the SubArray issue, though.

@gaurav-arya
Contributor

gaurav-arya commented Aug 9, 2023

Actually, it looks like a reasonable solution, since it just relies on the scale being convertible to the type of the FFT function's output. I've implemented it in #114.

Edit: This would fix reverse rules for non-plans (issue #115), but not reverse rules for plans (issue #112), because in the current design we delegate to adjoint_mul, which would still suffer from #112, and the trick of using the type of the primal output is not applicable within adjoint_mul. Adapt.jl is probably the right solution for solving #112 too, but I'll leave it to your judgement whether to add that dependency.

@devmotion
Member

I think a solution without new dependencies would be preferable (hence I didn't suggest Adapt) but maybe it's not possible.

@gaurav-arya
Contributor

gaurav-arya commented Aug 9, 2023

Ok. I think #114 is the right way forward as a first step, to fix downstream first. #112 at least is not going to be a regression in any case, because the rules for real plans in Zygote are currently incorrect.

Edit: As for a solution without dependencies, the similar-based one here is a valid one. But I like it a bit less now than the out-of-place / more functional broadcasting solution, since it violates the principle of not unnecessarily using mutating code in non-mutating operations; see https://github.com/SciML/SciMLStyle#functions-should-either-attempt-to-be-non-allocating-and-reuse-caches-or-treat-inputs-as-immutable

Successfully merging this pull request may close these issues.

differentiating rfft on CuArray leads to error