From 20458e42a26640e17a8d0414f296ec26f754b268 Mon Sep 17 00:00:00 2001 From: Hendrik Ranocha Date: Fri, 4 Nov 2022 17:13:24 +0100 Subject: [PATCH 01/14] Create Invalidations.yml (#79) This is based on https://github.com/julia-actions/julia-invalidations. Adding such checks came up in https://discourse.julialang.org/t/potential-performance-regressions-in-julia-1-8-for-special-un-precompiled-type-dispatches-and-how-to-fix-them/86359. I suggest to add this check here since this package is widely used as a dependency. --- .github/workflows/Invalidations.yml | 40 +++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 .github/workflows/Invalidations.yml diff --git a/.github/workflows/Invalidations.yml b/.github/workflows/Invalidations.yml new file mode 100644 index 0000000..770a9b1 --- /dev/null +++ b/.github/workflows/Invalidations.yml @@ -0,0 +1,40 @@ +name: Invalidations + +on: + pull_request: + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: always. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + evaluate: + # Only run on PRs to the default branch. + # In the PR trigger above branches can be specified only explicitly whereas this check should work for master, main, or any other default branch + if: github.base_ref == github.event.repository.default_branch + runs-on: ubuntu-latest + steps: + - uses: julia-actions/setup-julia@v1 + with: + version: '1' + - uses: actions/checkout@v3 + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-invalidations@v1 + id: invs_pr + + - uses: actions/checkout@v3 + with: + ref: ${{ github.event.repository.default_branch }} + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-invalidations@v1 + id: invs_default + + - name: Report invalidation counts + run: | + echo "Invalidations on default branch: ${{ steps.invs_default.outputs.total }} (${{ steps.invs_default.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY + echo "This branch: ${{ steps.invs_pr.outputs.total }} (${{ steps.invs_pr.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY + - name: Check whether the number of invalidations increased + if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total + run: exit 1 From 7d698db082f4ea1898afe40fbf2755e915ee5ff7 Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Fri, 4 Nov 2022 09:26:26 -0700 Subject: [PATCH 02/14] Update badges and use LaTeX for math - We're using GHA + Codecov rather than Travis + Coveralls - GitHub now renders LaTeX math in Markdown --- README.md | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 5b33c59..4a884c1 100644 --- a/README.md +++ b/README.md @@ -2,12 +2,12 @@ A general framework for fast Fourier transforms (FFTs) in Julia. -[![Travis](https://travis-ci.org/JuliaMath/AbstractFFTs.jl.svg?branch=master)](https://travis-ci.org/JuliaMath/AbstractFFTs.jl) -[![Coveralls](https://coveralls.io/repos/github/JuliaMath/AbstractFFTs.jl/badge.svg?branch=master)](https://coveralls.io/github/JuliaMath/AbstractFFTs.jl?branch=master) +[![GHA](https://github.com/JuliaMath/AbstractFFTs.jl/workflows/CI/badge.svg)](https://github.com/JuliaMath/AbstractFFTs.jl/actions?query=workflow%3ACI+branch%3Amaster) +[![Codecov](http://codecov.io/github/JuliaMath/AbstractFFTs.jl/coverage.svg?branch=master)](http://codecov.io/github/JuliaMath/AbstractFFTs.jl?branch=master) Documentation: [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://JuliaMath.github.io/AbstractFFTs.jl/stable) -[![](https://img.shields.io/badge/docs-latest-blue.svg)](https://JuliaMath.github.io/AbstractFFTs.jl/latest) +[![](https://img.shields.io/badge/docs-latest-blue.svg)](https://JuliaMath.github.io/AbstractFFTs.jl/dev) This package is mainly not intended to be used directly. Instead, developers of packages that implement FFTs (such as [FFTW.jl](https://github.com/JuliaMath/FFTW.jl) or [FastTransforms.jl](https://github.com/JuliaApproximation/FastTransforms.jl)) @@ -36,5 +36,6 @@ To define a new FFT implementation in your own module, you should * You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs. -The normalization convention for your FFT should be that it computes yₖ = ∑ⱼ xⱼ exp(-2πi jk/n) for a transform of -length n, and the "backwards" (unnormalized inverse) transform computes the same thing but with exp(+2πi jk/n). +The normalization convention for your FFT should be that it computes $y_k = \sum_j \exp\(-2 \pi i \cdot \frac{j k}{n}\)$ +for a transform of length $n$, and the "backwards" (unnormalized inverse) transform computes the same thing but with +$\exp\(+2 \pi i \cdot \frac{j k}{n}\)$. From b2dd69cc19b4dfaeed377fd9f81d7631459b079d Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 7 Mar 2023 22:36:43 +0100 Subject: [PATCH 03/14] Make ChainRulesCore a weak dependency on Julia >= 1.9 (#85) * Make ChainRulesCore a weak dependency on Julia >= 1.9 * Qualify `normalization` * Check on nightly if extension works correctly --- .github/workflows/CI.yml | 2 +- Project.toml | 11 +++++++++-- .../AbstractFFTsChainRulesCoreExt.jl | 12 +++++++++--- src/AbstractFFTs.jl | 7 ++++--- 4 files changed, 23 insertions(+), 9 deletions(-) rename src/chainrules.jl => ext/AbstractFFTsChainRulesCoreExt.jl (96%) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 8d43117..1444ca5 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -15,7 +15,7 @@ jobs: version: - '1.0' - '1' -# - 'nightly' + - 'nightly' os: - ubuntu-latest - macOS-latest diff --git a/Project.toml b/Project.toml index a639c5d..572c7a2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,20 +1,27 @@ name = "AbstractFFTs" uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.2.1" +version = "1.3.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +[weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + +[extensions] +AbstractFFTsChainRulesCoreExt = "ChainRulesCore" + [compat] ChainRulesCore = "1" julia = "^1.0" [extras] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["ChainRulesTestUtils", "Random", "Test", "Unitful"] +test = ["ChainRulesCore", "ChainRulesTestUtils", "Random", "Test", "Unitful"] diff --git a/src/chainrules.jl b/ext/AbstractFFTsChainRulesCoreExt.jl similarity index 96% rename from src/chainrules.jl rename to ext/AbstractFFTsChainRulesCoreExt.jl index 97d4d22..f0c788e 100644 --- a/src/chainrules.jl +++ b/ext/AbstractFFTsChainRulesCoreExt.jl @@ -1,4 +1,8 @@ -# ffts +module AbstractFFTsChainRulesCoreExt + +using AbstractFFTs +import ChainRulesCore + function ChainRulesCore.frule((_, Δx, _), ::typeof(fft), x::AbstractArray, dims) y = fft(x, dims) Δy = fft(Δx, dims) @@ -46,7 +50,7 @@ function ChainRulesCore.frule((_, Δx, _), ::typeof(ifft), x::AbstractArray, dim end function ChainRulesCore.rrule(::typeof(ifft), x::AbstractArray, dims) y = ifft(x, dims) - invN = normalization(y, dims) + invN = AbstractFFTs.normalization(y, dims) project_x = ChainRulesCore.ProjectTo(x) function ifft_pullback(ȳ) x̄ = project_x(invN .* fft(ChainRulesCore.unthunk(ȳ), dims)) @@ -66,7 +70,7 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims) # compute scaling factors halfdim = first(dims) n = size(x, halfdim) - invN = normalization(y, dims) + invN = AbstractFFTs.normalization(y, dims) twoinvN = 2 * invN scale = reshape( [i == 1 || (i == n && 2 * (i - 1) == d) ? invN : twoinvN for i in 1:n], @@ -150,3 +154,5 @@ function ChainRulesCore.rrule(::typeof(ifftshift), x::AbstractArray, dims) end return y, ifftshift_pullback end + +end # module diff --git a/src/AbstractFFTs.jl b/src/AbstractFFTs.jl index 56d7123..00f6dc2 100644 --- a/src/AbstractFFTs.jl +++ b/src/AbstractFFTs.jl @@ -1,13 +1,14 @@ module AbstractFFTs -import ChainRulesCore - export fft, ifft, bfft, fft!, ifft!, bfft!, plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!, rfft, irfft, brfft, plan_rfft, plan_irfft, plan_brfft, fftdims, fftshift, ifftshift, fftshift!, ifftshift!, Frequencies, fftfreq, rfftfreq include("definitions.jl") -include("chainrules.jl") + +if !isdefined(Base, :get_extension) + include("../ext/AbstractFFTsChainRulesCoreExt.jl") +end end # module From 1e3df24dc91cc77ada2e3847f8281a8fa787b7ad Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Wed, 8 Mar 2023 07:14:17 -0500 Subject: [PATCH 04/14] Ensure all fft-like functions fallback to version with region when region not provided (#84) * Ensure all fft-like functions fallback to version with region when region not provided * Add testset for default dims * Add tests for complex float promotion * Test complex float promotion for fft,ifft,bfft too --- src/definitions.jl | 12 ++++++------ test/runtests.jl | 27 +++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/src/definitions.jl b/src/definitions.jl index 4532650..1cf542b 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -59,7 +59,7 @@ _to1(::Tuple, x) = copy1(eltype(x), x) for f in (:fft, :bfft, :ifft, :fft!, :bfft!, :ifft!, :rfft) pf = Symbol("plan_", f) @eval begin - $f(x::AbstractArray) = (y = to1(x); $pf(y) * y) + $f(x::AbstractArray) = $f(x, 1:ndims(x)) $f(x::AbstractArray, region) = (y = to1(x); $pf(y, region) * y) $pf(x::AbstractArray; kws...) = (y = to1(x); $pf(y, 1:ndims(y); kws...)) end @@ -207,9 +207,9 @@ bfft! for f in (:fft, :bfft, :ifft) pf = Symbol("plan_", f) @eval begin - $f(x::AbstractArray{<:Real}, region=1:ndims(x)) = $f(complexfloat(x), region) + $f(x::AbstractArray{<:Real}, region) = $f(complexfloat(x), region) $pf(x::AbstractArray{<:Real}, region; kws...) = $pf(complexfloat(x), region; kws...) - $f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region=1:ndims(x)) = $f(complexfloat(x), region) + $f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region) = $f(complexfloat(x), region) $pf(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region; kws...) = $pf(complexfloat(x), region; kws...) end end @@ -297,7 +297,7 @@ LinearAlgebra.mul!(y::AbstractArray, p::ScaledPlan, x::AbstractArray) = for f in (:brfft, :irfft) pf = Symbol("plan_", f) @eval begin - $f(x::AbstractArray, d::Integer) = $pf(x, d) * x + $f(x::AbstractArray, d::Integer) = $f(x, d, 1:ndims(x)) $f(x::AbstractArray, d::Integer, region) = $pf(x, d, region) * x $pf(x::AbstractArray, d::Integer;kws...) = $pf(x, d, 1:ndims(x);kws...) end @@ -305,8 +305,8 @@ end for f in (:brfft, :irfft) @eval begin - $f(x::AbstractArray{<:Real}, d::Integer, region=1:ndims(x)) = $f(complexfloat(x), d, region) - $f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, d::Integer, region=1:ndims(x)) = $f(complexfloat(x), d, region) + $f(x::AbstractArray{<:Real}, d::Integer, region) = $f(complexfloat(x), d, region) + $f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, d::Integer, region) = $f(complexfloat(x), d, region) end end diff --git a/test/runtests.jl b/test/runtests.jl index 4d402c5..9cb528a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -213,6 +213,33 @@ end @test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10 end +# Test that dims defaults to 1:ndims for fft-like functions +@testset "Default dims" begin + for x in (randn(3), randn(3, 4), randn(3, 4, 5)) + N = ndims(x) + complex_x = complex.(x) + @test fft(x) ≈ fft(x, 1:N) + @test ifft(x) ≈ ifft(x, 1:N) + @test bfft(x) ≈ bfft(x, 1:N) + @test rfft(x) ≈ rfft(x, 1:N) + d = 2 * size(x, 1) - 1 + @test irfft(x, d) ≈ irfft(x, d, 1:N) + @test brfft(x, d) ≈ brfft(x, d, 1:N) + end +end + +@testset "Complex float promotion" begin + for x in (rand(-5:5, 3), rand(-5:5, 3, 4), rand(-5:5, 3, 4, 5)) + N = ndims(x) + @test fft(x) ≈ fft(complex.(x)) ≈ fft(complex.(float.(x))) + @test ifft(x) ≈ ifft(complex.(x)) ≈ ifft(complex.(float.(x))) + @test bfft(x) ≈ bfft(complex.(x)) ≈ bfft(complex.(float.(x))) + d = 2 * size(x, 1) - 1 + @test irfft(x, d) ≈ irfft(complex.(x), d) ≈ irfft(complex.(float.(x)), d) + @test brfft(x, d) ≈ brfft(complex.(x), d) ≈ brfft(complex.(float.(x)), d) + end +end + @testset "ChainRules" begin @testset "shift functions" begin for x in (randn(3), randn(3, 4), randn(3, 4, 5)) From 79789f2250aed0d70d2e1766667cc4dde7b43896 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 8 Mar 2023 13:14:44 +0100 Subject: [PATCH 05/14] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 572c7a2..8e7206c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AbstractFFTs" uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.3.0" +version = "1.3.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From a52b2ab54e1e5e3ad2b4c6f3a1b0890ed06f0a1c Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sat, 11 Mar 2023 19:51:47 -0500 Subject: [PATCH 06/14] Add integration test for FastTransforms (#86) * Add integration test for FastTransforms * Revert temporary fix --------- Co-authored-by: David Widmann --- .github/workflows/IntegrationTest.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/IntegrationTest.yml b/.github/workflows/IntegrationTest.yml index 34d8ca4..9e75160 100644 --- a/.github/workflows/IntegrationTest.yml +++ b/.github/workflows/IntegrationTest.yml @@ -18,6 +18,7 @@ jobs: os: [ubuntu-latest] package: - {user: JuliaMath, repo: FFTW.jl} + - {user: JuliaApproximation, repo: FastTransforms.jl} steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 From 8c5712b8eab7e588672f2f06db37322cade67b27 Mon Sep 17 00:00:00 2001 From: Hendrik Ranocha Date: Tue, 14 Mar 2023 19:08:51 +0100 Subject: [PATCH 07/14] enable dependabot for GitHub actions (#89) --- .github/dependabot.yml | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 .github/dependabot.yml diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..700707c --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,7 @@ +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" # Location of package manifests + schedule: + interval: "weekly" From a226e0c12486a6724fcf268b44707b69ce8a3b12 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 15 Mar 2023 12:29:12 -0400 Subject: [PATCH 08/14] Bump codecov/codecov-action from 1 to 3 (#90) Bumps [codecov/codecov-action](https://github.com/codecov/codecov-action) from 1 to 3. - [Release notes](https://github.com/codecov/codecov-action/releases) - [Changelog](https://github.com/codecov/codecov-action/blob/main/CHANGELOG.md) - [Commits](https://github.com/codecov/codecov-action/compare/v1...v3) --- updated-dependencies: - dependency-name: codecov/codecov-action dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 1444ca5..6b9ef2b 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -41,6 +41,6 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v1 + - uses: codecov/codecov-action@v3 with: file: lcov.info From a25656dfabf0f6c5067bc3f90a591c242da4b9be Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 16 Mar 2023 16:32:41 +0100 Subject: [PATCH 09/14] Update actions (#93) --- .github/workflows/CI.yml | 13 ++----------- .github/workflows/Documenter.yml | 6 +++--- .github/workflows/IntegrationTest.yml | 6 +++--- 3 files changed, 8 insertions(+), 17 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 6b9ef2b..7395a98 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -23,21 +23,12 @@ jobs: arch: - x64 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - - uses: actions/cache@v1 - env: - cache-name: cache-artifacts - with: - path: ~/.julia/artifacts - key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} - restore-keys: | - ${{ runner.os }}-test-${{ env.cache-name }}- - ${{ runner.os }}-test- - ${{ runner.os }}- + - uses: julia-actions/cache@v1 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 - uses: julia-actions/julia-processcoverage@v1 diff --git a/.github/workflows/Documenter.yml b/.github/workflows/Documenter.yml index 8675804..b9c4408 100644 --- a/.github/workflows/Documenter.yml +++ b/.github/workflows/Documenter.yml @@ -10,9 +10,9 @@ jobs: name: Documentation runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: julia-actions/julia-buildpkg@latest - - uses: julia-actions/julia-docdeploy@latest + - uses: actions/checkout@v3 + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-docdeploy@v1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} diff --git a/.github/workflows/IntegrationTest.yml b/.github/workflows/IntegrationTest.yml index 9e75160..59837f6 100644 --- a/.github/workflows/IntegrationTest.yml +++ b/.github/workflows/IntegrationTest.yml @@ -20,14 +20,14 @@ jobs: - {user: JuliaMath, repo: FFTW.jl} - {user: JuliaApproximation, repo: FastTransforms.jl} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.julia-version }} arch: x64 - - uses: julia-actions/julia-buildpkg@latest + - uses: julia-actions/julia-buildpkg@v1 - name: Clone Downstream - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} path: downstream From 09d703ae926943e08a5c3a449060883134824e6d Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Mon, 24 Apr 2023 12:24:38 -0400 Subject: [PATCH 10/14] correct formula (closes #100) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 4a884c1..fedb8c2 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,6 @@ To define a new FFT implementation in your own module, you should * You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs. -The normalization convention for your FFT should be that it computes $y_k = \sum_j \exp\(-2 \pi i \cdot \frac{j k}{n}\)$ +The normalization convention for your FFT should be that it computes $y_k = \sum_j \exp\(-2 \pi i \cdot \frac{j k}{n}\) x_j$ for a transform of length $n$, and the "backwards" (unnormalized inverse) transform computes the same thing but with $\exp\(+2 \pi i \cdot \frac{j k}{n}\)$. From 3a3f0e4ebc3a9ab81bfb655003759ffa9c6bd1bb Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Mon, 19 Jun 2023 09:47:40 -0400 Subject: [PATCH 11/14] put chainrules core in a requires block, rm it from deps (#107) * put chainrules core in a requires block, rm it from deps * test conditional loading of CRC --- Project.toml | 14 +++++++------- src/AbstractFFTs.jl | 12 ++++++++++-- test/runtests.jl | 14 +++++++++++++- 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index 8e7206c..111b84f 100644 --- a/Project.toml +++ b/Project.toml @@ -3,19 +3,16 @@ uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" version = "1.3.1" [deps] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[weakdeps] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - -[extensions] -AbstractFFTsChainRulesCoreExt = "ChainRulesCore" +Requires = "ae029012-a4dd-5104-9daa-d747884805df" [compat] ChainRulesCore = "1" julia = "^1.0" +[extensions] +AbstractFFTsChainRulesCoreExt = "ChainRulesCore" + [extras] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" @@ -25,3 +22,6 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] test = ["ChainRulesCore", "ChainRulesTestUtils", "Random", "Test", "Unitful"] + +[weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/AbstractFFTs.jl b/src/AbstractFFTs.jl index 00f6dc2..77e194b 100644 --- a/src/AbstractFFTs.jl +++ b/src/AbstractFFTs.jl @@ -7,8 +7,16 @@ export fft, ifft, bfft, fft!, ifft!, bfft!, include("definitions.jl") -if !isdefined(Base, :get_extension) - include("../ext/AbstractFFTsChainRulesCoreExt.jl") +@static if !isdefined(Base, :get_extension) + import Requires +end + +@static if !isdefined(Base, :get_extension) + function __init__() + Requires.@require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" begin + include("../ext/AbstractFFTsChainRulesCoreExt.jl") + end + end end end # module diff --git a/test/runtests.jl b/test/runtests.jl index 9cb528a..c41debe 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,7 +2,6 @@ using AbstractFFTs using AbstractFFTs: Plan -using ChainRulesTestUtils using LinearAlgebra using Random @@ -241,6 +240,19 @@ end end @testset "ChainRules" begin + + if isdefined(Base, :get_extension) + CRCEXT = Base.get_extension(AbstractFFTs, :AbstractFFTsChainRulesCoreExt) + @test isnothing(CRCEXT) + end + + using ChainRulesTestUtils + + if isdefined(Base, :get_extension) + CRCEXT = Base.get_extension(AbstractFFTs, :AbstractFFTsChainRulesCoreExt) + @test !isnothing(CRCEXT) + end + @testset "shift functions" begin for x in (randn(3), randn(3, 4), randn(3, 4, 5)) for dims in ((), 1, 2, (1,2), 1:2) From 27c37a07560bd8fca09ba2329e9816dca3990f9a Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Tue, 27 Jun 2023 14:18:09 -0400 Subject: [PATCH 12/14] For conversion of scale vector in adjoint (#105) * For conversion of scale vector in adjoint It always defines an `Array` which can fail on the GPU. This forces it to be the right type. One could also use `adapt` here, but since the element type promotion would have to occur anyways in the subsequent broadcast it seems you might as well convert all at once. * Update AbstractFFTsChainRulesCoreExt.jl --- ext/AbstractFFTsChainRulesCoreExt.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ext/AbstractFFTsChainRulesCoreExt.jl b/ext/AbstractFFTsChainRulesCoreExt.jl index f0c788e..d58f5fa 100644 --- a/ext/AbstractFFTsChainRulesCoreExt.jl +++ b/ext/AbstractFFTsChainRulesCoreExt.jl @@ -37,7 +37,9 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims) project_x = ChainRulesCore.ProjectTo(x) function rfft_pullback(ȳ) - x̄ = project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ scale, d, dims)) + ybar = ChainRulesCore.unthunk(ȳ) + _scale = convert(typeof(ybar),scale) + x̄ = project_x(brfft(ybar ./ _scale, d, dims)) return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() end return y, rfft_pullback @@ -79,7 +81,9 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims) project_x = ChainRulesCore.ProjectTo(x) function irfft_pullback(ȳ) - x̄ = project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims)) + ybar = ChainRulesCore.unthunk(ȳ) + _scale = convert(typeof(ybar),scale) + x̄ = project_x(_scale .* rfft(real.(ybar), dims)) return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent() end return y, irfft_pullback From 48a5d1c6a2eb3befa9fe7a7b4b24fea1da399b4a Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 27 Jun 2023 20:18:54 +0200 Subject: [PATCH 13/14] Revert "put chainrules core in a requires block, rm it from deps (#107)" (#108) This reverts commit 3a3f0e4ebc3a9ab81bfb655003759ffa9c6bd1bb. --- Project.toml | 14 +++++++------- src/AbstractFFTs.jl | 12 ++---------- test/runtests.jl | 14 +------------- 3 files changed, 10 insertions(+), 30 deletions(-) diff --git a/Project.toml b/Project.toml index 111b84f..8e7206c 100644 --- a/Project.toml +++ b/Project.toml @@ -3,16 +3,19 @@ uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" version = "1.3.1" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Requires = "ae029012-a4dd-5104-9daa-d747884805df" -[compat] -ChainRulesCore = "1" -julia = "^1.0" +[weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" [extensions] AbstractFFTsChainRulesCoreExt = "ChainRulesCore" +[compat] +ChainRulesCore = "1" +julia = "^1.0" + [extras] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" @@ -22,6 +25,3 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] test = ["ChainRulesCore", "ChainRulesTestUtils", "Random", "Test", "Unitful"] - -[weakdeps] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/AbstractFFTs.jl b/src/AbstractFFTs.jl index 77e194b..00f6dc2 100644 --- a/src/AbstractFFTs.jl +++ b/src/AbstractFFTs.jl @@ -7,16 +7,8 @@ export fft, ifft, bfft, fft!, ifft!, bfft!, include("definitions.jl") -@static if !isdefined(Base, :get_extension) - import Requires -end - -@static if !isdefined(Base, :get_extension) - function __init__() - Requires.@require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" begin - include("../ext/AbstractFFTsChainRulesCoreExt.jl") - end - end +if !isdefined(Base, :get_extension) + include("../ext/AbstractFFTsChainRulesCoreExt.jl") end end # module diff --git a/test/runtests.jl b/test/runtests.jl index c41debe..9cb528a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,7 @@ using AbstractFFTs using AbstractFFTs: Plan +using ChainRulesTestUtils using LinearAlgebra using Random @@ -240,19 +241,6 @@ end end @testset "ChainRules" begin - - if isdefined(Base, :get_extension) - CRCEXT = Base.get_extension(AbstractFFTs, :AbstractFFTsChainRulesCoreExt) - @test isnothing(CRCEXT) - end - - using ChainRulesTestUtils - - if isdefined(Base, :get_extension) - CRCEXT = Base.get_extension(AbstractFFTs, :AbstractFFTsChainRulesCoreExt) - @test !isnothing(CRCEXT) - end - @testset "shift functions" begin for x in (randn(3), randn(3, 4), randn(3, 4, 5)) for dims in ((), 1, 2, (1,2), 1:2) From b5109aab2d0d610a0963a65e6012907605b14bec Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 27 Jun 2023 20:34:50 +0200 Subject: [PATCH 14/14] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8e7206c..498fac8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AbstractFFTs" uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.3.1" +version = "1.3.2" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"