differentiating rfft on CuArray leads to error #95

Open
ziyiyin97 opened this issue Mar 27, 2023 · 0 comments

Redirected from FluxML/Zygote.jl#1406

Taking a gradient (through the ChainRules rrule) of rfft (the fast Fourier transform for real-valued input, as provided by FFTW.jl) on a CuArray throws an error. By contrast, differentiating fft on a CuArray, or rfft on a CPU array, both work fine.

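As @ToucheSir pointed out, this happens because the AbstractFFTs rrule for rfft unconditionally creates a CPU array and uses it in the pullback; see https://github.com/JuliaMath/AbstractFFTs.jl/blob/v1.3.1/ext/AbstractFFTsChainRulesCoreExt.jl#L33-L40. When the cotangent is a CuArray, that CPU array ends up as an argument of a GPU broadcast, which cannot be compiled because a Vector is not isbits.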
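
A minimal sketch of the difference, assuming the pullback divides the incoming cotangent by per-bin factors (1 for the first bin and, when the input length `d` is even, for the last bin; 2 elsewhere). That assumption matches the `typeof(/)` broadcast against a `Vector{Int64}` reported in the error. The names `scale_with_host_vector` and `scale_by_index` are hypothetical and not part of AbstractFFTs, and the second variant is only an illustration of a device-agnostic approach; it has not been run on a GPU here.

```julia
# Pattern suggested by the error: the factors live in a host Vector{Int},
# so `ybar ./ scale` mixes a device array with a CPU array when ybar is a CuArray.
function scale_with_host_vector(ybar::AbstractVector, d::Integer)
    n = length(ybar)
    scale = [i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n]
    return ybar ./ scale
end

# Hypothetical alternative: compute each factor from the element's index,
# so no separate array of factors is allocated on the host.
function scale_by_index(ybar::AbstractVector, d::Integer)
    n = length(ybar)
    return map(ybar, eachindex(ybar)) do y, i
        (i == 1 || (i == n && 2 * (i - 1) == d)) ? y : y / 2
    end
end

ybar = ComplexF32[1, 2, 3]          # stand-in cotangent with n = 3 bins
scale_with_host_vector(ybar, 4) == scale_by_index(ybar, 4)  # even input length
scale_with_host_vector(ybar, 5) == scale_by_index(ybar, 5)  # odd input length
```

On CPU arrays the two functions agree; the point of the second is only that its output is allocated by `map` on whatever array type `ybar` already is.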

(@v1.8) pkg> activate --temp
  Activating new project at `/tmp/jl_yR7NT1`

(jl_yR7NT1) pkg> add CUDA, Flux, FFTW
    Updating registry at `~/.julia/registries/General.toml`
   Resolving package versions...
   Installed Optimisers ─ v0.2.17
   Installed CUDA ─────── v4.1.1
    Updating `/tmp/jl_yR7NT1/Project.toml`
  [052768ef] + CUDA v4.1.1
  [7a1cc6ca] + FFTW v1.6.0
  [587475ba] + Flux v0.13.14
    Updating `/tmp/jl_yR7NT1/Manifest.toml`
  [621f4979] + AbstractFFTs v1.3.1
  [7d9f7c33] + Accessors v0.1.28
  [79e6a3ab] + Adapt v3.6.1
  [dce04be8] + ArgCheck v2.3.0
  [a9b6321e] + Atomix v0.1.0
  [ab4f0b2a] + BFloat16s v0.4.2
  [198e06fe] + BangBang v0.3.37
  [9718e550] + Baselet v0.1.1
  [fa961155] + CEnum v0.4.2
  [052768ef] + CUDA v4.1.1
  [1af6417a] + CUDA_Runtime_Discovery v0.1.1
  [082447d4] + ChainRules v1.48.0
  [d360d2e6] + ChainRulesCore v1.15.7
  [9e997f8a] + ChangesOfVariables v0.1.6
  [bbf7d656] + CommonSubexpressions v0.3.0
  [34da2185] + Compat v4.6.1
  [a33af91c] + CompositionsBase v0.1.1
  [187b0558] + ConstructionBase v1.5.1
  [6add18c4] + ContextVariablesX v0.1.3
  [9a962f9c] + DataAPI v1.14.0
  [864edb3b] + DataStructures v0.18.13
  [e2d170a0] + DataValueInterfaces v1.0.0
  [244e2a9f] + DefineSingletons v0.1.2
  [163ba53b] + DiffResults v1.1.0
  [b552c78f] + DiffRules v1.13.0
  [ffbed154] + DocStringExtensions v0.9.3
  [e2ba6199] + ExprTools v0.1.9
  [7a1cc6ca] + FFTW v1.6.0
  [cc61a311] + FLoops v0.2.1
  [b9860ae5] + FLoopsBase v0.1.1
  [1a297f60] + FillArrays v0.13.10
  [587475ba] + Flux v0.13.14
  [9c68100b] + FoldsThreads v0.1.1
  [f6369f11] + ForwardDiff v0.10.35
  [069b7b12] + FunctionWrappers v1.1.3
  [d9f16b24] + Functors v0.4.3
  [0c68f7d7] + GPUArrays v8.6.5
  [46192b85] + GPUArraysCore v0.1.4
  [61eb1bfa] + GPUCompiler v0.18.0
  [7869d1d1] + IRTools v0.4.9
  [22cec73e] + InitialValues v0.3.1
  [3587e190] + InverseFunctions v0.1.8
  [92d709cd] + IrrationalConstants v0.2.2
  [82899510] + IteratorInterfaceExtensions v1.0.0
  [692b3bcd] + JLLWrappers v1.4.1
  [b14d175d] + JuliaVariables v0.2.4
  [63c18a36] + KernelAbstractions v0.9.1
  [929cbde3] + LLVM v4.17.1
  [2ab3a3ac] + LogExpFunctions v0.3.23
  [d8e11817] + MLStyle v0.4.17
  [f1d291b0] + MLUtils v0.4.1
  [1914dd2f] + MacroTools v0.5.10
  [128add7d] + MicroCollections v0.1.4
  [e1d29d7a] + Missings v1.1.0
  [872c559c] + NNlib v0.8.19
  [a00861dc] + NNlibCUDA v0.2.7
  [77ba4419] + NaNMath v1.0.2
  [71a1bf82] + NameResolution v0.1.5
  [0b1bfda6] + OneHotArrays v0.2.3
  [3bd65402] + Optimisers v0.2.17
  [bac558e1] + OrderedCollections v1.4.1
  [21216c6a] + Preferences v1.3.0
  [8162dcfd] + PrettyPrint v0.2.0
  [33c8b6b6] + ProgressLogging v0.1.4
  [74087812] + Random123 v1.6.0
  [e6cf234a] + RandomNumbers v1.5.3
  [c1ae055f] + RealDot v0.1.0
  [189a3867] + Reexport v1.2.2
  [ae029012] + Requires v1.3.0
  [efcf1570] + Setfield v1.1.1
  [605ecd9f] + ShowCases v0.1.0
  [699a6c99] + SimpleTraits v0.9.4
  [66db9d55] + SnoopPrecompile v1.0.3
  [a2af1166] + SortingAlgorithms v1.1.0
  [276daf66] + SpecialFunctions v2.2.0
  [171d559e] + SplittablesBase v0.1.15
  [90137ffa] + StaticArrays v1.5.19
  [1e83bf80] + StaticArraysCore v1.4.0
  [82ae8749] + StatsAPI v1.5.0
  [2913bbd2] + StatsBase v0.33.21
  [09ab397b] + StructArrays v0.6.15
  [3783bdb8] + TableTraits v1.0.1
  [bd369af6] + Tables v1.10.1
  [a759f4b9] + TimerOutputs v0.5.22
  [28d57a85] + Transducers v0.4.75
  [013be700] + UnsafeAtomics v0.2.1
  [d80eeb9a] + UnsafeAtomicsLLVM v0.1.0
  [e88e6eb3] + Zygote v0.6.59
  [700de1a5] + ZygoteRules v0.2.3
  [02a925ec] + cuDNN v1.0.2
⌅ [4ee394cb] + CUDA_Driver_jll v0.4.0+2
  [76a88914] + CUDA_Runtime_jll v0.4.0+2
  [62b44479] + CUDNN_jll v8.8.1+0
  [f5851436] + FFTW_jll v3.3.10+0
  [1d5cc7b8] + IntelOpenMP_jll v2018.0.3+2
⌅ [dad2f222] + LLVMExtra_jll v0.0.18+0
  [856f044c] + MKL_jll v2022.2.0+0
  [efe28fd5] + OpenSpecFun_jll v0.5.5+0
  [0dad84c5] + ArgTools v1.1.1
  [56f22d72] + Artifacts
  [2a0f44e3] + Base64
  [ade2ca70] + Dates
  [8bb1440f] + DelimitedFiles
  [8ba89e20] + Distributed
  [f43a241f] + Downloads v1.6.0
  [7b1f6079] + FileWatching
  [9fa8497b] + Future
  [b77e0a4c] + InteractiveUtils
  [4af54fe1] + LazyArtifacts
  [b27032c2] + LibCURL v0.6.3
  [76f85450] + LibGit2
  [8f399da3] + Libdl
  [37e2e46d] + LinearAlgebra
  [56ddb016] + Logging
  [d6f4376e] + Markdown
  [a63ad114] + Mmap
  [ca575930] + NetworkOptions v1.2.0
  [44cfe95a] + Pkg v1.8.0
  [de0858da] + Printf
  [3fa0cd96] + REPL
  [9a3f8284] + Random
  [ea8e919c] + SHA v0.7.0
  [9e88b42a] + Serialization
  [6462fe0b] + Sockets
  [2f01184e] + SparseArrays
  [10745b16] + Statistics
  [fa267f1f] + TOML v1.0.0
  [a4e569a6] + Tar v1.10.1
  [8dfed614] + Test
  [cf7118a7] + UUIDs
  [4ec0a83e] + Unicode
  [e66e0078] + CompilerSupportLibraries_jll v1.0.1+0
  [deac9b47] + LibCURL_jll v7.84.0+0
  [29816b5a] + LibSSH2_jll v1.10.2+0
  [c8ffd9c3] + MbedTLS_jll v2.28.0+0
  [14a3606d] + MozillaCACerts_jll v2022.2.1
  [4536629a] + OpenBLAS_jll v0.3.20+0
  [05823500] + OpenLibm_jll v0.8.1+0
  [83775a58] + Zlib_jll v1.2.12+3
  [8e850b90] + libblastrampoline_jll v5.1.1+0
  [8e850ede] + nghttp2_jll v1.48.0+0
  [3f19e933] + p7zip_jll v17.4.0+0
        Info Packages marked with ⌅ have new versions available but compatibility constraints restrict them from upgrading. To see why use `status --outdated -m`
Precompiling project...
  5 dependencies successfully precompiled in 54 seconds. 100 already precompiled.

julia> using CUDA, FFTW, Flux

julia> x = CUDA.randn(3)
3-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}:
 0.19325934
 0.55793864
 0.08928435

julia> gradient(()->sum(abs.(fft(x))), Flux.params(x)) # this works
Grads(...)

julia> gradient(()->sum(abs.(rfft(x))), Flux.params(x))
ERROR: GPU compilation of broadcast_kernel(CUDA.CuKernelContext, CuDeviceVector{ComplexF32, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(/), Tuple{Base.Broadcast.Extruded{CuDeviceVector{ComplexF32, 1}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Vector{Int64}, Tuple{Bool}, Tuple{Int64}}}}, Int64) in world 32592 failed
KernelError: passing and using non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(/), Tuple{Base.Broadcast.Extruded{CuDeviceVector{ComplexF32, 1}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Vector{Int64}, Tuple{Bool}, Tuple{Int64}}}}, which is not isbits:
  .args is of type Tuple{Base.Broadcast.Extruded{CuDeviceVector{ComplexF32, 1}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Vector{Int64}, Tuple{Bool}, Tuple{Int64}}} which is not isbits.
    .2 is of type Base.Broadcast.Extruded{Vector{Int64}, Tuple{Bool}, Tuple{Int64}} which is not isbits.
      .x is of type Vector{Int64} which is not isbits.


Stacktrace:
  [1] check_invocation(job::GPUCompiler.CompilerJob)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/anMCs/src/validation.jl:101
  [2] macro expansion
    @ ~/.julia/packages/GPUCompiler/anMCs/src/driver.jl:154 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/TimerOutputs/LHjFw/src/TimerOutput.jl:253 [inlined]
  [4] macro expansion
    @ ~/.julia/packages/GPUCompiler/anMCs/src/driver.jl:152 [inlined]
  [5] emit_julia(job::GPUCompiler.CompilerJob; validate::Bool)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/anMCs/src/utils.jl:83
  [6] emit_julia
    @ ~/.julia/packages/GPUCompiler/anMCs/src/utils.jl:77 [inlined]
  [7] compile(job::GPUCompiler.CompilerJob, ctx::LLVM.Context)
    @ CUDA ~/.julia/packages/CUDA/N71Iw/src/compiler/compilation.jl:105
  [8] #203
    @ ~/.julia/packages/CUDA/N71Iw/src/compiler/compilation.jl:100 [inlined]
  [9] JuliaContext(f::CUDA.var"#203#204"{GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/anMCs/src/driver.jl:76
 [10] compile
    @ ~/.julia/packages/CUDA/N71Iw/src/compiler/compilation.jl:99 [inlined]
 [11] actual_compilation(cache::Dict{UInt64, Any}, key::UInt64, cfg::GPUCompiler.CompilerConfig{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}, ft::Type, tt::Type, world::UInt64, compiler::typeof(CUDA.compile), linker::typeof(CUDA.link))
    @ GPUCompiler ~/.julia/packages/GPUCompiler/anMCs/src/cache.jl:184
 [12] cached_compilation(cache::Dict{UInt64, Any}, cfg::GPUCompiler.CompilerConfig{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}, ft::Type, tt::Type, compiler::Function, linker::Function)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/anMCs/src/cache.jl:163
 [13] macro expansion
    @ ~/.julia/packages/CUDA/N71Iw/src/compiler/execution.jl:310 [inlined]
 [14] macro expansion
    @ ./lock.jl:223 [inlined]
 [15] cufunction(f::GPUArrays.var"#broadcast_kernel#28", tt::Type{Tuple{CUDA.CuKernelContext, CuDeviceVector{ComplexF32, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(/), Tuple{Base.Broadcast.Extruded{CuDeviceVector{ComplexF32, 1}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Vector{Int64}, Tuple{Bool}, Tuple{Int64}}}}, Int64}}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ CUDA ~/.julia/packages/CUDA/N71Iw/src/compiler/execution.jl:306
 [16] cufunction
    @ ~/.julia/packages/CUDA/N71Iw/src/compiler/execution.jl:303 [inlined]
 [17] macro expansion
    @ ~/.julia/packages/CUDA/N71Iw/src/compiler/execution.jl:104 [inlined]
 [18] #launch_heuristic#244
    @ ~/.julia/packages/CUDA/N71Iw/src/gpuarrays.jl:17 [inlined]
 [19] _copyto!
    @ ~/.julia/packages/GPUArrays/XR4WO/src/host/broadcast.jl:65 [inlined]
 [20] copyto!
    @ ~/.julia/packages/GPUArrays/XR4WO/src/host/broadcast.jl:46 [inlined]
 [21] copy
    @ ~/.julia/packages/GPUArrays/XR4WO/src/host/broadcast.jl:37 [inlined]
 [22] materialize
    @ ./broadcast.jl:860 [inlined]
 [23] (::AbstractFFTs.AbstractFFTsChainRulesCoreExt.var"#rfft_pullback#6"{UnitRange{Int64}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Vector{Int64}, Int64})(ȳ::CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer})
    @ AbstractFFTs.AbstractFFTsChainRulesCoreExt ~/.julia/packages/AbstractFFTs/0uOAT/ext/AbstractFFTsChainRulesCoreExt.jl:40
 [24] ZBack
    @ ~/.julia/packages/Zygote/TSj5C/src/compiler/chainrules.jl:211 [inlined]
 [25] Pullback
    @ ~/.julia/packages/AbstractFFTs/0uOAT/src/definitions.jl:62 [inlined]
 [26] Pullback
    @ ./REPL[8]:1 [inlined]
 [27] (::Zygote.Pullback{Tuple{var"#5#6"}, Tuple{Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(rfft), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.Pullback{Tuple{typeof(ndims), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.ZBack{AbstractFFTs.AbstractFFTsChainRulesCoreExt.var"#rfft_pullback#6"{UnitRange{Int64}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Vector{Int64}, Int64}}, Zygote.ZBack{ChainRules.var"#:_pullback#275"{Tuple{Int64, Int64}}}}}, Zygote.var"#4160#back#1438"{Zygote.var"#1434#1437"{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(abs), CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1169"{Tuple{Nothing, Nothing}}}}, Zygote.var"#2841#back#683"{Zygote.var"#map_back#677"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4128#back#1421"{Zygote.var"#bc_fwd_back#1409"{1, CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 1, CUDA.Mem.DeviceBuffer}, Tuple{CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}}, Val{1}}}}}, Zygote.var"#1982#back#200"{typeof(identity)}}}, Zygote.var"#1955#back#190"{Zygote.var"#186#189"{Zygote.Context{true}, GlobalRef, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/TSj5C/src/compiler/interface2.jl:0
 [28] (::Zygote.var"#118#119"{Zygote.Params{Zygote.Buffer{Any, Vector{Any}}}, Zygote.Pullback{Tuple{var"#5#6"}, Tuple{Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(rfft), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.Pullback{Tuple{typeof(ndims), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.ZBack{AbstractFFTs.AbstractFFTsChainRulesCoreExt.var"#rfft_pullback#6"{UnitRange{Int64}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Vector{Int64}, Int64}}, Zygote.ZBack{ChainRules.var"#:_pullback#275"{Tuple{Int64, Int64}}}}}, Zygote.var"#4160#back#1438"{Zygote.var"#1434#1437"{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(abs), CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1169"{Tuple{Nothing, Nothing}}}}, Zygote.var"#2841#back#683"{Zygote.var"#map_back#677"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4128#back#1421"{Zygote.var"#bc_fwd_back#1409"{1, CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 1, CUDA.Mem.DeviceBuffer}, Tuple{CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}}, Val{1}}}}}, Zygote.var"#1982#back#200"{typeof(identity)}}}, Zygote.var"#1955#back#190"{Zygote.var"#186#189"{Zygote.Context{true}, GlobalRef, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, Zygote.Context{true}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/TSj5C/src/compiler/interface.jl:389
 [29] gradient(f::Function, args::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/packages/Zygote/TSj5C/src/compiler/interface.jl:97
 [30] top-level scope
    @ REPL[8]:1
 [31] top-level scope
    @ ~/.julia/packages/CUDA/N71Iw/src/initialization.jl:163
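
Until the rrule is changed, one possible stopgap is to avoid the rfft rrule altogether. Since differentiating fft on a CuArray works (see above), the same spectrum can be obtained by keeping the first `div(n, 2) + 1` bins of the full transform. The sketch below only checks that identity on the CPU; `rfft_via_fft` is a hypothetical helper name, it costs a full complex FFT, and its gradient on a CuArray has not been verified here.

```julia
using FFTW  # fft and rfft for CPU arrays

# Hypothetical helper: the non-redundant half of the full complex FFT,
# i.e. the same div(n, 2) + 1 bins that rfft returns for a real vector.
rfft_via_fft(x::AbstractVector{<:Real}) = fft(x)[1:div(length(x), 2) + 1]

x = randn(Float32, 6)
rfft_via_fft(x) ≈ rfft(x)  # compare against the real transform on the CPU
```

If this holds up on the GPU, `gradient(() -> sum(abs.(rfft_via_fft(x))), Flux.params(x))` would be the analogue of the failing call.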
ziyiyin97 added a commit to ziyiyin97/AbstractFFTs.jl that referenced this issue Mar 28, 2023
ziyiyin97 added a commit to ziyiyin97/AbstractFFTs.jl that referenced this issue Mar 29, 2023
ziyiyin97 added a commit to ziyiyin97/AbstractFFTs.jl that referenced this issue Mar 30, 2023