Skip to content

Commit

Permalink
add more comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ziyiyin97 committed Apr 12, 2023
1 parent dc2cfab commit e4da59e
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions ext/AbstractFFTsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims)
project_x = ChainRulesCore.ProjectTo(x)
function rfft_pullback(ȳ)
dY = ChainRulesCore.unthunk(ȳ)
# apply scaling
# apply scaling; below approach is for GPU CuArray compatibility, see PR #96
dY_scaled = similar(dY)
dY_scaled .= dY
dY_scaled ./= 2
# assign view to a separate variable before assignment, to support Julia <1.2
# see https://github.com/JuliaLang/julia/issues/31295
v = selectdim(dY_scaled, halfdim, 1)
v .*= 2
if 2 * (n - 1) == d
Expand Down Expand Up @@ -80,10 +82,12 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims)
project_x = ChainRulesCore.ProjectTo(x)
function irfft_pullback(ȳ)
dX = rfft(real.(ChainRulesCore.unthunk(ȳ)), dims)
# apply scaling
# apply scaling; below approach is for GPU CuArray compatibility, see PR #96
dX_scaled = similar(dX)
dX_scaled .= dX
dX_scaled .*= invN .* 2
# assign view to a separate variable before assignment, to support Julia <1.2
# see https://github.com/JuliaLang/julia/issues/31295
v = selectdim(dX_scaled, halfdim, 1)
v ./= 2
if 2 * (n - 1) == d
Expand Down Expand Up @@ -125,10 +129,12 @@ function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims)
project_x = ChainRulesCore.ProjectTo(x)
function brfft_pullback(ȳ)
dX = rfft(real.(ChainRulesCore.unthunk(ȳ)), dims)
# apply scaling
# apply scaling; below approach is for GPU CuArray compatibility, see PR #96
dX_scaled = similar(dX)
dX_scaled .= dX
dX_scaled .*= 2
# assign view to a separate variable before assignment, to support Julia <1.2
# see https://github.com/JuliaLang/julia/issues/31295
v = selectdim(dX_scaled, halfdim, 1)
v ./= 2
if 2 * (n - 1) == d
Expand Down

0 comments on commit e4da59e

Please sign in to comment.