Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
ziyiyin97 committed Mar 30, 2023
1 parent a25656d commit e0e667f
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions ext/AbstractFFTsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims)
halfdim = first(dims)
d = size(x, halfdim)
n = size(y, halfdim)
scale = reshape(
scale = typeof(y)(reshape(
[i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n],
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
)
))

project_x = ChainRulesCore.ProjectTo(x)
function rfft_pullback(ȳ)
Expand Down Expand Up @@ -72,10 +72,10 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims)
n = size(x, halfdim)
invN = AbstractFFTs.normalization(y, dims)
twoinvN = 2 * invN
scale = reshape(
scale = typeof(y)(reshape(
[i == 1 || (i == n && 2 * (i - 1) == d) ? invN : twoinvN for i in 1:n],
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
)
))

project_x = ChainRulesCore.ProjectTo(x)
function irfft_pullback(ȳ)
Expand Down Expand Up @@ -111,10 +111,10 @@ function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims)
# compute scaling factors
halfdim = first(dims)
n = size(x, halfdim)
scale = reshape(
scale = typeof(y)(reshape(
[i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n],
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
)
))

project_x = ChainRulesCore.ProjectTo(x)
function brfft_pullback(ȳ)
Expand Down

0 comments on commit e0e667f

Please sign in to comment.