Skip to content

Commit

Permalink
Expand basis functions to operate on arbitrary dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jun 12, 2024
1 parent 39aed1e commit 48e16d7
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Boltz"
uuid = "4544d5e4-abc5-4dea-817f-29e4c205d9c8"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.3.7"
version = "0.3.8"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
82 changes: 66 additions & 16 deletions src/basis.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
module Basis

using ArgCheck: @argcheck
using ..Boltz: _unsqueeze1
using ChainRulesCore: ChainRulesCore, NoTangent
using ConcreteStructs: @concrete
using LuxDeviceUtils: get_device, LuxCPUDevice
using Markdown: @doc_str

const CRC = ChainRulesCore
Expand All @@ -11,63 +13,103 @@ const CRC = ChainRulesCore
@concrete struct GeneralBasisFunction{name}
f
n::Int
dim::Int
end

function Base.show(io::IO, basis::GeneralBasisFunction{name}) where {name}
print(io, "Basis.$(name)(order=$(basis.n))")
end

@inline function (basis::GeneralBasisFunction{name, F})(x::AbstractArray) where {name, F}
return basis.f.(1:(basis.n), _unsqueeze1(x))
@inline function (basis::GeneralBasisFunction{name, F})(x::AbstractArray,
grid::Union{AbstractRange, AbstractVector}=1:1:(basis.n)) where {name, F}
@argcheck length(grid) == basis.n
if basis.dim == 1 # Fast path where we don't need to materialize the range
return basis.f.(grid, _unsqueeze1(x))
end

@argcheck ndims(x) + 1 basis.dim
new_x_size = ntuple(
i -> i == basis.dim ? 1 : (i < basis.dim ? size(x, i) : size(x, i - 1)),
ndims(x) + 1)
x_new = reshape(x, new_x_size)
if grid isa AbstractRange
dev = get_device(x)
grid = dev isa LuxCPUDevice ? collect(grid) : dev(grid)
end
grid_shape = ntuple(i -> i == basis.dim ? basis.n : 1, ndims(x) + 1)
grid_new = reshape(grid, grid_shape)
return basis.f.(grid_new, x_new)
end

const DIM_KWARG_DOC = " - `dim::Int=1`: The dimension along which the basis functions are applied."

@doc doc"""
Chebyshev(n)
Chebyshev(n; dim::Int=1)
Constructs a Chebyshev basis of the form $[T_{0}(x), T_{1}(x), \dots, T_{n-1}(x)]$ where
$T_j(.)$ is the $j^{th}$ Chebyshev polynomial of the first kind.
## Arguments
- `n`: number of terms in the polynomial expansion.
## Keyword Arguments
$(DIM_KWARG_DOC)
"""
Chebyshev(n) = GeneralBasisFunction{:Chebyshev}(__chebyshev, n)
Chebyshev(n; dim::Int=1) = GeneralBasisFunction{:Chebyshev}(__chebyshev, n, dim)

@inline __chebyshev(i, x) = @fastmath cos(i * acos(x))

@doc doc"""
Sin(n)
Sin(n; dim::Int=1)
Constructs a sine basis of the form $[\sin(x), \sin(2x), \dots, \sin(nx)]$.
## Arguments
- `n`: number of terms in the sine expansion.
## Keyword Arguments
$(DIM_KWARG_DOC)
"""
Sin(n) = GeneralBasisFunction{:Sin}(@fastmath(sin∘*), n)
Sin(n; dim::Int=1) = GeneralBasisFunction{:Sin}(@fastmath(sin∘*), n, dim)

@doc doc"""
Cos(n)
Cos(n; dim::Int=1)
Constructs a cosine basis of the form $[\cos(x), \cos(2x), \dots, \cos(nx)]$.
## Arguments
- `n`: number of terms in the cosine expansion.
## Keyword Arguments
$(DIM_KWARG_DOC)
"""
Cos(n) = GeneralBasisFunction{:Cos}(@fastmath(cos∘*), n)
Cos(n; dim::Int=1) = GeneralBasisFunction{:Cos}(@fastmath(cos∘*), n, dim)

@doc doc"""
Fourier(n)
Fourier(n; dim=1)
Constructs a Fourier basis of the form
$F_j(x) = j is even ? cos((j÷2)x) : sin((j÷2)x)$ => $[F_0(x), F_1(x), \dots, F_n(x)]$.
$$F_j(x) = \begin{cases}
cos\left(\frac{j}{2}x\right) & \text{if } j \text{ is even} \\
sin\left(\frac{j}{2}x\right) & \text{if } j \text{ is odd}
\end{cases}$$
## Arguments
- `n`: number of terms in the Fourier expansion.
## Keyword Arguments
$(DIM_KWARG_DOC)
"""
Fourier(n) = GeneralBasisFunction{:Fourier}(__fourier, n)
Fourier(n; dim::Int=1) = GeneralBasisFunction{:Fourier}(__fourier, n, dim)

@inline @fastmath function __fourier(i, x::AbstractFloat)
s, c = sincos(i * x / 2)
Expand Down Expand Up @@ -96,16 +138,20 @@ end
end

@doc doc"""
Legendre(n)
Legendre(n; dim::Int=1)
Constructs a Legendre basis of the form $[P_{0}(x), P_{1}(x), \dots, P_{n-1}(x)]$ where
$P_j(.)$ is the $j^{th}$ Legendre polynomial.
## Arguments
- `n`: number of terms in the polynomial expansion.
## Keyword Arguments
$(DIM_KWARG_DOC)
"""
Legendre(n) = GeneralBasisFunction{:Legendre}(__legendre_poly, n)
Legendre(n; dim::Int=1) = GeneralBasisFunction{:Legendre}(__legendre_poly, n, dim)

## Source: https://github.com/ranocha/PolynomialBases.jl/blob/master/src/legendre.jl
@inline function __legendre_poly(i, x)
Expand All @@ -124,15 +170,19 @@ Legendre(n) = GeneralBasisFunction{:Legendre}(__legendre_poly, n)
end

@doc doc"""
Polynomial(n)
Polynomial(n; dim::Int=1)
Constructs a Polynomial basis of the form $[1, x, \dots, x^(n-1)]$.
Constructs a Polynomial basis of the form $[1, x, \dots, x^{(n-1)}]$.
## Arguments
- `n`: number of terms in the polynomial expansion.
## Keyword Arguments
$(DIM_KWARG_DOC)
"""
Polynomial(n) = GeneralBasisFunction{:Polynomial}(__polynomial, n)
Polynomial(n; dim::Int=1) = GeneralBasisFunction{:Polynomial}(__polynomial, n, dim)

@inline __polynomial(i, x) = x^(i - 1)

Expand Down
34 changes: 34 additions & 0 deletions test/layer_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,37 @@ end
end
end
end

@testitem "Basis Functions" setup=[SharedTestSetup] tags=[:layers] begin
@testset "$(mode)" for (mode, aType, dev, ongpu) in MODES
@testset "$(basis)" for basis in (Basis.Chebyshev, Basis.Sin, Basis.Cos,
Basis.Fourier, Basis.Legendre, Basis.Polynomial)
x = tanh.(randn(Float32, 2, 4)) |> aType
grid = collect(1:3) |> aType

fn = basis(3)
@test size(fn(x)) == (3, 2, 4)
@jet fn(x)
@test size(fn(x, grid)) == (3, 2, 4)
@jet fn(x, grid)

fn = basis(3; dim=2)
@test size(fn(x)) == (2, 3, 4)
@jet fn(x)
@test size(fn(x, grid)) == (2, 3, 4)
@jet fn(x, grid)

fn = basis(3; dim=3)
@test size(fn(x)) == (2, 4, 3)
@jet fn(x)
@test size(fn(x, grid)) == (2, 4, 3)
@jet fn(x, grid)

fn = basis(3; dim=4)
@test_throws ArgumentError fn(x)

grid = 1:5 |> aType
@test_throws ArgumentError fn(x, grid)
end
end
end

0 comments on commit 48e16d7

Please sign in to comment.