From e4da59e96220b40c0196900338f3b9ed727bd4c1 Mon Sep 17 00:00:00 2001 From: Ziyi Yin Date: Wed, 12 Apr 2023 11:25:29 -0400 Subject: [PATCH] add more comments --- ext/AbstractFFTsChainRulesCoreExt.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/ext/AbstractFFTsChainRulesCoreExt.jl b/ext/AbstractFFTsChainRulesCoreExt.jl index 119c079..016f347 100644 --- a/ext/AbstractFFTsChainRulesCoreExt.jl +++ b/ext/AbstractFFTsChainRulesCoreExt.jl @@ -33,10 +33,12 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims) project_x = ChainRulesCore.ProjectTo(x) function rfft_pullback(ȳ) dY = ChainRulesCore.unthunk(ȳ) - # apply scaling + # apply scaling; below approach is for GPU CuArray compatibility, see PR #96 dY_scaled = similar(dY) dY_scaled .= dY dY_scaled ./= 2 + # assign view to a separate variable before assignment, to support Julia <1.2 + # see https://github.com/JuliaLang/julia/issues/31295 v = selectdim(dY_scaled, halfdim, 1) v .*= 2 if 2 * (n - 1) == d @@ -80,10 +82,12 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims) project_x = ChainRulesCore.ProjectTo(x) function irfft_pullback(ȳ) dX = rfft(real.(ChainRulesCore.unthunk(ȳ)), dims) - # apply scaling + # apply scaling; below approach is for GPU CuArray compatibility, see PR #96 dX_scaled = similar(dX) dX_scaled .= dX dX_scaled .*= invN .* 2 + # assign view to a separate variable before assignment, to support Julia <1.2 + # see https://github.com/JuliaLang/julia/issues/31295 v = selectdim(dX_scaled, halfdim, 1) v ./= 2 if 2 * (n - 1) == d @@ -125,10 +129,12 @@ function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims) project_x = ChainRulesCore.ProjectTo(x) function brfft_pullback(ȳ) dX = rfft(real.(ChainRulesCore.unthunk(ȳ)), dims) - # apply scaling + # apply scaling; below approach is for GPU CuArray compatibility, see PR #96 dX_scaled = similar(dX) dX_scaled .= dX dX_scaled .*= 2 + # assign view to a separate variable before assignment, to support Julia <1.2 + # see https://github.com/JuliaLang/julia/issues/31295 v = selectdim(dX_scaled, halfdim, 1) v ./= 2 if 2 * (n - 1) == d