Skip to content

Commit

Permalink
Merge branch 'master' into adjoint
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Jun 30, 2023
2 parents fe3b06a + b5109aa commit 266c88f
Show file tree
Hide file tree
Showing 11 changed files with 129 additions and 43 deletions.
7 changes: 7 additions & 0 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/" # Location of package manifests
schedule:
interval: "weekly"
17 changes: 4 additions & 13 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,23 @@ jobs:
version:
- '1.0'
- '1'
# - 'nightly'
- 'nightly'
os:
- ubuntu-latest
- macOS-latest
- windows-latest
arch:
- x64
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: actions/cache@v1
env:
cache-name: cache-artifacts
with:
path: ~/.julia/artifacts
key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
restore-keys: |
${{ runner.os }}-test-${{ env.cache-name }}-
${{ runner.os }}-test-
${{ runner.os }}-
- uses: julia-actions/cache@v1
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v1
- uses: codecov/codecov-action@v3
with:
file: lcov.info
6 changes: 3 additions & 3 deletions .github/workflows/Documenter.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ jobs:
name: Documentation
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: julia-actions/julia-buildpkg@latest
- uses: julia-actions/julia-docdeploy@latest
- uses: actions/checkout@v3
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-docdeploy@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}
7 changes: 4 additions & 3 deletions .github/workflows/IntegrationTest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@ jobs:
os: [ubuntu-latest]
package:
- {user: JuliaMath, repo: FFTW.jl}
- {user: JuliaApproximation, repo: FastTransforms.jl}
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.julia-version }}
arch: x64
- uses: julia-actions/julia-buildpkg@latest
- uses: julia-actions/julia-buildpkg@v1
- name: Clone Downstream
uses: actions/checkout@v2
uses: actions/checkout@v3
with:
repository: ${{ matrix.package.user }}/${{ matrix.package.repo }}
path: downstream
Expand Down
40 changes: 40 additions & 0 deletions .github/workflows/Invalidations.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
name: Invalidations

on:
pull_request:

concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: always.
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
evaluate:
# Only run on PRs to the default branch.
# In the PR trigger above branches can be specified only explicitly whereas this check should work for master, main, or any other default branch
if: github.base_ref == github.event.repository.default_branch
runs-on: ubuntu-latest
steps:
- uses: julia-actions/setup-julia@v1
with:
version: '1'
- uses: actions/checkout@v3
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-invalidations@v1
id: invs_pr

- uses: actions/checkout@v3
with:
ref: ${{ github.event.repository.default_branch }}
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-invalidations@v1
id: invs_default

- name: Report invalidation counts
run: |
echo "Invalidations on default branch: ${{ steps.invs_default.outputs.total }} (${{ steps.invs_default.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY
echo "This branch: ${{ steps.invs_pr.outputs.total }} (${{ steps.invs_pr.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY
- name: Check whether the number of invalidations increased
if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total
run: exit 1
11 changes: 9 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,21 +1,28 @@
name = "AbstractFFTs"
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
version = "1.2.1"
version = "1.3.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

[extensions]
AbstractFFTsChainRulesCoreExt = "ChainRulesCore"

[compat]
ChainRulesCore = "1"
julia = "^1.0"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["ChainRulesTestUtils", "FiniteDifferences", "Random", "Test", "Unitful"]
test = ["ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Random", "Test", "Unitful"]
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

A general framework for fast Fourier transforms (FFTs) in Julia.

[![Travis](https://travis-ci.org/JuliaMath/AbstractFFTs.jl.svg?branch=master)](https://travis-ci.org/JuliaMath/AbstractFFTs.jl)
[![Coveralls](https://coveralls.io/repos/github/JuliaMath/AbstractFFTs.jl/badge.svg?branch=master)](https://coveralls.io/github/JuliaMath/AbstractFFTs.jl?branch=master)
[![GHA](https://github.com/JuliaMath/AbstractFFTs.jl/workflows/CI/badge.svg)](https://github.com/JuliaMath/AbstractFFTs.jl/actions?query=workflow%3ACI+branch%3Amaster)
[![Codecov](http://codecov.io/github/JuliaMath/AbstractFFTs.jl/coverage.svg?branch=master)](http://codecov.io/github/JuliaMath/AbstractFFTs.jl?branch=master)

Documentation:
[![](https://img.shields.io/badge/docs-stable-blue.svg)](https://JuliaMath.github.io/AbstractFFTs.jl/stable)
[![](https://img.shields.io/badge/docs-latest-blue.svg)](https://JuliaMath.github.io/AbstractFFTs.jl/latest)
[![](https://img.shields.io/badge/docs-latest-blue.svg)](https://JuliaMath.github.io/AbstractFFTs.jl/dev)

This package is mainly not intended to be used directly.
Instead, developers of packages that implement FFTs (such as [FFTW.jl](https://github.com/JuliaMath/FFTW.jl) or [FastTransforms.jl](https://github.com/JuliaApproximation/FastTransforms.jl))
Expand All @@ -17,3 +17,4 @@ This allows multiple FFT packages to co-exist with the same underlying `fft(x)`
## Developer information

To define a new FFT implementation in your own module, see [defining a new implementation](https://juliamath.github.io/AbstractFFTs.jl/stable/implementations/#Defining-a-new-implementation).

31 changes: 21 additions & 10 deletions src/chainrules.jl → ext/AbstractFFTsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# ffts
module AbstractFFTsChainRulesCoreExt

using AbstractFFTs
import ChainRulesCore

function ChainRulesCore.frule((_, Δx, _), ::typeof(fft), x::AbstractArray, dims)
y = fft(x, dims)
Δy = fft(Δx, dims)
Expand Down Expand Up @@ -33,7 +37,9 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims)

project_x = ChainRulesCore.ProjectTo(x)
function rfft_pullback(ȳ)
= project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ scale, d, dims))
ybar = ChainRulesCore.unthunk(ȳ)
_scale = convert(typeof(ybar),scale)
= project_x(brfft(ybar ./ _scale, d, dims))
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
end
return y, rfft_pullback
Expand All @@ -46,7 +52,7 @@ function ChainRulesCore.frule((_, Δx, _), ::typeof(ifft), x::AbstractArray, dim
end
function ChainRulesCore.rrule(::typeof(ifft), x::AbstractArray, dims)
y = ifft(x, dims)
invN = normalization(y, dims)
invN = AbstractFFTs.normalization(y, dims)
project_x = ChainRulesCore.ProjectTo(x)
function ifft_pullback(ȳ)
= project_x(invN .* fft(ChainRulesCore.unthunk(ȳ), dims))
Expand All @@ -66,7 +72,7 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims)
# compute scaling factors
halfdim = first(dims)
n = size(x, halfdim)
invN = normalization(y, dims)
invN = AbstractFFTs.normalization(y, dims)
twoinvN = 2 * invN
scale = reshape(
[i == 1 || (i == n && 2 * (i - 1) == d) ? invN : twoinvN for i in 1:n],
Expand All @@ -75,7 +81,9 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims)

project_x = ChainRulesCore.ProjectTo(x)
function irfft_pullback(ȳ)
= project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims))
ybar = ChainRulesCore.unthunk(ȳ)
_scale = convert(typeof(ybar),scale)
= project_x(_scale .* rfft(real.(ybar), dims))
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent()
end
return y, irfft_pullback
Expand Down Expand Up @@ -152,12 +160,12 @@ function ChainRulesCore.rrule(::typeof(ifftshift), x::AbstractArray, dims)
end

# plans
function ChainRulesCore.frule((_, _, Δx), ::typeof(*), P::Plan, x::AbstractArray)
function ChainRulesCore.frule((_, _, Δx), ::typeof(*), P::AbstractFFTs.Plan, x::AbstractArray)
y = P * x
Δy = P * Δx
return y, Δy
end
function ChainRulesCore.rrule(::typeof(*), P::Plan, x::AbstractArray)
function ChainRulesCore.rrule(::typeof(*), P::AbstractFFTs.Plan, x::AbstractArray)
y = P * x
project_x = ChainRulesCore.ProjectTo(x)
Pt = P'
Expand All @@ -168,22 +176,25 @@ function ChainRulesCore.rrule(::typeof(*), P::Plan, x::AbstractArray)
return y, mul_plan_pullback
end

function ChainRulesCore.frule((_, ΔP, Δx), ::typeof(*), P::ScaledPlan, x::AbstractArray)
function ChainRulesCore.frule((_, ΔP, Δx), ::typeof(*), P::AbstractFFTs.ScaledPlan, x::AbstractArray)
y = P * x
Δy = P * Δx .+ (ΔP.scale / P.scale) .* y
return y, Δy
end
function ChainRulesCore.rrule(::typeof(*), P::ScaledPlan, x::AbstractArray)
function ChainRulesCore.rrule(::typeof(*), P::AbstractFFTs.ScaledPlan, x::AbstractArray)
y = P * x
Pt = P'
scale = P.scale
project_x = ChainRulesCore.ProjectTo(x)
project_scale = ChainRulesCore.ProjectTo(scale)
function mul_scaledplan_pullback(ȳ)
= ChainRulesCore.@thunk(project_x(Pt * ȳ))
scale_tangent = ChainRulesCore.@thunk(project_scale(dot(y, ȳ) / conj(scale)))
scale_tangent = ChainRulesCore.@thunk(project_scale(AbstractFFTs.dot(y, ȳ) / conj(scale)))
plan_tangent = ChainRulesCore.Tangent{typeof(P)}(;p=ChainRulesCore.NoTangent(), scale=scale_tangent)
return ChainRulesCore.NoTangent(), plan_tangent, x̄
end
return y, mul_scaledplan_pullback
end

end # module

7 changes: 4 additions & 3 deletions src/AbstractFFTs.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
module AbstractFFTs

import ChainRulesCore

export fft, ifft, bfft, fft!, ifft!, bfft!,
plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!,
rfft, irfft, brfft, plan_rfft, plan_irfft, plan_brfft,
fftdims, fftshift, ifftshift, fftshift!, ifftshift!, Frequencies, fftfreq, rfftfreq

include("definitions.jl")
include("chainrules.jl")

if !isdefined(Base, :get_extension)
include("../ext/AbstractFFTsChainRulesCoreExt.jl")
end

end # module
12 changes: 6 additions & 6 deletions src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ _to1(::Tuple, x) = copy1(eltype(x), x)
for f in (:fft, :bfft, :ifft, :fft!, :bfft!, :ifft!, :rfft)
pf = Symbol("plan_", f)
@eval begin
$f(x::AbstractArray) = (y = to1(x); $pf(y) * y)
$f(x::AbstractArray) = $f(x, 1:ndims(x))
$f(x::AbstractArray, region) = (y = to1(x); $pf(y, region) * y)
$pf(x::AbstractArray; kws...) = (y = to1(x); $pf(y, 1:ndims(y); kws...))
end
Expand Down Expand Up @@ -208,9 +208,9 @@ bfft!
for f in (:fft, :bfft, :ifft)
pf = Symbol("plan_", f)
@eval begin
$f(x::AbstractArray{<:Real}, region=1:ndims(x)) = $f(complexfloat(x), region)
$f(x::AbstractArray{<:Real}, region) = $f(complexfloat(x), region)
$pf(x::AbstractArray{<:Real}, region; kws...) = $pf(complexfloat(x), region; kws...)
$f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region=1:ndims(x)) = $f(complexfloat(x), region)
$f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region) = $f(complexfloat(x), region)
$pf(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region; kws...) = $pf(complexfloat(x), region; kws...)
end
end
Expand Down Expand Up @@ -299,16 +299,16 @@ LinearAlgebra.mul!(y::AbstractArray, p::ScaledPlan, x::AbstractArray) =
for f in (:brfft, :irfft)
pf = Symbol("plan_", f)
@eval begin
$f(x::AbstractArray, d::Integer) = $pf(x, d) * x
$f(x::AbstractArray, d::Integer) = $f(x, d, 1:ndims(x))
$f(x::AbstractArray, d::Integer, region) = $pf(x, d, region) * x
$pf(x::AbstractArray, d::Integer;kws...) = $pf(x, d, 1:ndims(x);kws...)
end
end

for f in (:brfft, :irfft)
@eval begin
$f(x::AbstractArray{<:Real}, d::Integer, region=1:ndims(x)) = $f(complexfloat(x), d, region)
$f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, d::Integer, region=1:ndims(x)) = $f(complexfloat(x), d, region)
$f(x::AbstractArray{<:Real}, d::Integer, region) = $f(complexfloat(x), d, region)
$f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, d::Integer, region) = $f(complexfloat(x), d, region)
end
end

Expand Down
27 changes: 27 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,33 @@ end
end
end

# Test that dims defaults to 1:ndims for fft-like functions
@testset "Default dims" begin
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
N = ndims(x)
complex_x = complex.(x)
@test fft(x) fft(x, 1:N)
@test ifft(x) ifft(x, 1:N)
@test bfft(x) bfft(x, 1:N)
@test rfft(x) rfft(x, 1:N)
d = 2 * size(x, 1) - 1
@test irfft(x, d) irfft(x, d, 1:N)
@test brfft(x, d) brfft(x, d, 1:N)
end
end

@testset "Complex float promotion" begin
for x in (rand(-5:5, 3), rand(-5:5, 3, 4), rand(-5:5, 3, 4, 5))
N = ndims(x)
@test fft(x) fft(complex.(x)) fft(complex.(float.(x)))
@test ifft(x) ifft(complex.(x)) ifft(complex.(float.(x)))
@test bfft(x) bfft(complex.(x)) bfft(complex.(float.(x)))
d = 2 * size(x, 1) - 1
@test irfft(x, d) irfft(complex.(x), d) irfft(complex.(float.(x)), d)
@test brfft(x, d) brfft(complex.(x), d) brfft(complex.(float.(x)), d)
end
end

@testset "ChainRules" begin
@testset "shift functions" begin
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
Expand Down

0 comments on commit 266c88f

Please sign in to comment.