Skip to content

Commit

Permalink
Correct inverse plan logic (#69)
Browse files Browse the repository at this point in the history
* Correct inverse plan caching logic in test plans

* Use inv rather than plan_inv in scaled plan
  • Loading branch information
gaurav-arya authored Jul 3, 2022
1 parent 3e7d412 commit 7d34bf2
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 25 deletions.
2 changes: 1 addition & 1 deletion src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ plan_ifft(x::AbstractArray, region; kws...) =
plan_ifft!(x::AbstractArray, region; kws...) =
ScaledPlan(plan_bfft!(x, region; kws...), normalization(x, region))

plan_inv(p::ScaledPlan) = ScaledPlan(plan_inv(p.p), inv(p.scale))
plan_inv(p::ScaledPlan) = ScaledPlan(inv(p.p), inv(p.scale))

LinearAlgebra.mul!(y::AbstractArray, p::ScaledPlan, x::AbstractArray) =
LinearAlgebra.lmul!(p.scale, LinearAlgebra.mul!(y, p.p, x))
Expand Down
48 changes: 24 additions & 24 deletions test/testplans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,20 @@ end
function AbstractFFTs.plan_bfft(x::AbstractArray{T}, region; kwargs...) where {T}
return InverseTestPlan{T}(region, size(x))
end

function AbstractFFTs.plan_inv(p::TestPlan{T}) where {T}
unscaled_pinv = InverseTestPlan{T}(p.region, p.sz)
unscaled_pinv.pinv = p
pinv = AbstractFFTs.ScaledPlan(
unscaled_pinv, AbstractFFTs.normalization(T, p.sz, p.region),
)
N = AbstractFFTs.normalization(T, p.sz, p.region)
unscaled_pinv.pinv = AbstractFFTs.ScaledPlan(p, N)
pinv = AbstractFFTs.ScaledPlan(unscaled_pinv, N)
return pinv
end
function AbstractFFTs.plan_inv(p::InverseTestPlan{T}) where {T}
unscaled_pinv = TestPlan{T}(p.region, p.sz)
unscaled_pinv.pinv = p
pinv = AbstractFFTs.ScaledPlan(
unscaled_pinv, AbstractFFTs.normalization(T, p.sz, p.region),
)
return pinv
function AbstractFFTs.plan_inv(pinv::InverseTestPlan{T}) where {T}
unscaled_p = TestPlan{T}(pinv.region, pinv.sz)
N = AbstractFFTs.normalization(T, pinv.sz, pinv.region)
unscaled_p.pinv = AbstractFFTs.ScaledPlan(pinv, N)
p = AbstractFFTs.ScaledPlan(unscaled_p, N)
return p
end

# Just a helper function since forward and backward are nearly identical
Expand Down Expand Up @@ -118,22 +117,23 @@ function AbstractFFTs.plan_inv(p::TestRPlan{T,N}) where {T,N}
firstdim = first(p.region)::Int
d = p.sz[firstdim]
sz = ntuple(i -> i == firstdim ? d ÷ 2 + 1 : p.sz[i], Val(N))
_N = AbstractFFTs.normalization(T, p.sz, p.region)

unscaled_pinv = InverseTestRPlan{T}(d, p.region, sz)
unscaled_pinv.pinv = p
pinv = AbstractFFTs.ScaledPlan(
unscaled_pinv, AbstractFFTs.normalization(T, p.sz, p.region),
)
unscaled_pinv.pinv = AbstractFFTs.ScaledPlan(p, _N)
pinv = AbstractFFTs.ScaledPlan(unscaled_pinv, _N)
return pinv
end
function AbstractFFTs.plan_inv(p::InverseTestRPlan{T,N}) where {T,N}
firstdim = first(p.region)::Int
sz = ntuple(i -> i == firstdim ? p.d : p.sz[i], Val(N))
unscaled_pinv = TestRPlan{T}(p.region, sz)
unscaled_pinv.pinv = p
pinv = AbstractFFTs.ScaledPlan(
unscaled_pinv, AbstractFFTs.normalization(T, sz, p.region),
)
return pinv

function AbstractFFTs.plan_inv(pinv::InverseTestRPlan{T,N}) where {T,N}
firstdim = first(pinv.region)::Int
sz = ntuple(i -> i == firstdim ? pinv.d : pinv.sz[i], Val(N))
_N = AbstractFFTs.normalization(T, sz, pinv.region)

unscaled_p = TestRPlan{T}(pinv.region, sz)
unscaled_p.pinv = AbstractFFTs.ScaledPlan(pinv, _N)
p = AbstractFFTs.ScaledPlan(unscaled_p, _N)
return p
end

Base.size(p::TestRPlan) = p.sz
Expand Down

0 comments on commit 7d34bf2

Please sign in to comment.