Skip to content

Commit

Permalink
Correct inverse plan caching logic in test plans
Browse files Browse the repository at this point in the history
  • Loading branch information
gaurav-arya committed Jul 2, 2022
1 parent 3e7d412 commit 5821ae4
Showing 1 changed file with 24 additions and 24 deletions.
48 changes: 24 additions & 24 deletions test/testplans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,20 @@ end
function AbstractFFTs.plan_bfft(x::AbstractArray{T}, region; kwargs...) where {T}
return InverseTestPlan{T}(region, size(x))
end

function AbstractFFTs.plan_inv(p::TestPlan{T}) where {T}
unscaled_pinv = InverseTestPlan{T}(p.region, p.sz)
unscaled_pinv.pinv = p
pinv = AbstractFFTs.ScaledPlan(
unscaled_pinv, AbstractFFTs.normalization(T, p.sz, p.region),
)
N = AbstractFFTs.normalization(T, p.sz, p.region)
unscaled_pinv.pinv = AbstractFFTs.ScaledPlan(p, N)
pinv = AbstractFFTs.ScaledPlan(unscaled_pinv, N)
return pinv
end
function AbstractFFTs.plan_inv(p::InverseTestPlan{T}) where {T}
unscaled_pinv = TestPlan{T}(p.region, p.sz)
unscaled_pinv.pinv = p
pinv = AbstractFFTs.ScaledPlan(
unscaled_pinv, AbstractFFTs.normalization(T, p.sz, p.region),
)
return pinv
function AbstractFFTs.plan_inv(pinv::InverseTestPlan{T}) where {T}
unscaled_p = TestPlan{T}(pinv.region, pinv.sz)
N = AbstractFFTs.normalization(T, pinv.sz, pinv.region)
unscaled_p.pinv = AbstractFFTs.ScaledPlan(pinv, N)
p = AbstractFFTs.ScaledPlan(unscaled_p, N)
return p
end

# Just a helper function since forward and backward are nearly identical
Expand Down Expand Up @@ -118,22 +117,23 @@ function AbstractFFTs.plan_inv(p::TestRPlan{T,N}) where {T,N}
firstdim = first(p.region)::Int
d = p.sz[firstdim]
sz = ntuple(i -> i == firstdim ? d ÷ 2 + 1 : p.sz[i], Val(N))
_N = AbstractFFTs.normalization(T, p.sz, p.region)

unscaled_pinv = InverseTestRPlan{T}(d, p.region, sz)
unscaled_pinv.pinv = p
pinv = AbstractFFTs.ScaledPlan(
unscaled_pinv, AbstractFFTs.normalization(T, p.sz, p.region),
)
unscaled_pinv.pinv = AbstractFFTs.ScaledPlan(p, _N)
pinv = AbstractFFTs.ScaledPlan(unscaled_pinv, _N)
return pinv
end
function AbstractFFTs.plan_inv(p::InverseTestRPlan{T,N}) where {T,N}
firstdim = first(p.region)::Int
sz = ntuple(i -> i == firstdim ? p.d : p.sz[i], Val(N))
unscaled_pinv = TestRPlan{T}(p.region, sz)
unscaled_pinv.pinv = p
pinv = AbstractFFTs.ScaledPlan(
unscaled_pinv, AbstractFFTs.normalization(T, sz, p.region),
)
return pinv

function AbstractFFTs.plan_inv(pinv::InverseTestRPlan{T,N}) where {T,N}
firstdim = first(pinv.region)::Int
sz = ntuple(i -> i == firstdim ? pinv.d : pinv.sz[i], Val(N))
_N = AbstractFFTs.normalization(T, sz, pinv.region)

unscaled_p = TestRPlan{T}(pinv.region, sz)
unscaled_p.pinv = AbstractFFTs.ScaledPlan(pinv, _N)
p = AbstractFFTs.ScaledPlan(unscaled_p, _N)
return p
end

Base.size(p::TestRPlan) = p.sz
Expand Down

0 comments on commit 5821ae4

Please sign in to comment.