Commit 2372e67: Add Tensor Product Layer

avik-pal committed Jun 9, 2024
1 parent e545bd0 commit 2372e67

Showing 8 changed files with 174 additions and 6 deletions.
1 change: 1 addition & 0 deletions Project.toml
@@ -14,6 +14,7 @@ JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
5 changes: 4 additions & 1 deletion src/Boltz.jl
@@ -23,6 +23,9 @@ include("utils.jl")
include("initialize.jl")
include("patch.jl")

# Basis Functions
include("basis.jl")

# Layers
include("layers/Layers.jl")

@@ -32,6 +35,6 @@ include("vision/Vision.jl")
# deprecated
include("deprecated.jl")

export Layers, Vision
export Basis, Layers, Vision

end
114 changes: 114 additions & 0 deletions src/basis.jl
@@ -0,0 +1,114 @@
module Basis

using ..Boltz: _unsqueeze1
using ConcreteStructs: @concrete
using Markdown: @doc_str

@concrete struct GeneralBasisFunction{name}
f
n::Int
end

function Base.show(io::IO, basis::GeneralBasisFunction{name}) where {name}
print(io, "Basis.$(name)(order=$(basis.n))")
end

function (basis::GeneralBasisFunction{name, F})(x::AbstractArray) where {name, F}
return basis.f.(1:(basis.n), _unsqueeze1(x))
end

@doc doc"""
Chebyshev(n)
Constructs a Chebyshev basis of the form $[T_{0}(x), T_{1}(x), \dots, T_{n-1}(x)]$ where
$T_j(.)$ is the $j^{th}$ Chebyshev polynomial of the first kind.
## Arguments
- `n`: number of terms in the polynomial expansion.
"""
Chebyshev(n) = GeneralBasisFunction{:Chebyshev}(__chebyshev, n)

@inline __chebyshev(i, x) = @fastmath cos(i * acos(x))

@doc doc"""
Sin(n)
Constructs a sine basis of the form $[\sin(x), \sin(2x), \dots, \sin(nx)]$.
## Arguments
- `n`: number of terms in the sine expansion.
"""
Sin(n) = GeneralBasisFunction{:Sin}(@fastmath(sin∘*), n)

@doc doc"""
Cos(n)
Constructs a cosine basis of the form $[\cos(x), \cos(2x), \dots, \cos(nx)]$.
## Arguments
- `n`: number of terms in the cosine expansion.
"""
Cos(n) = GeneralBasisFunction{:Cos}(@fastmath(cos∘*), n)

@doc doc"""
Fourier(n)
Constructs a Fourier basis of the form
$F_j(x) = j is even ? cos((j÷2)x) : sin((j÷2)x)$ => $[F_0(x), F_1(x), \dots, F_n(x)]$.
## Arguments
- `n`: number of terms in the Fourier expansion.
"""
Fourier(n) = GeneralBasisFunction{:Fourier}(__fourier, n)

@inline function __fourier(i, x)
s, c = @fastmath sincos(i * x / 2)
return ifelse(iseven(i), c, s)
end

@doc doc"""
Legendre(n)
Constructs a Legendre basis of the form $[P_{0}(x), P_{1}(x), \dots, P_{n-1}(x)]$ where
$P_j(.)$ is the $j^{th}$ Legendre polynomial.
## Arguments
- `n`: number of terms in the polynomial expansion.
"""
Legendre(n) = GeneralBasisFunction{:Legendre}(__legendre_poly, n)

## Source: https://github.com/ranocha/PolynomialBases.jl/blob/master/src/legendre.jl
@inline function __legendre_poly(i, x)
p = i - 1
a = one(x)
b = x

p == 0 && return a
p == 1 && return b

# Three-term recurrence: j * P_j(x) = (2j - 1) * x * P_{j-1}(x) - (j - 1) * P_{j-2}(x)
for j in 2:p
    a, b = b, @fastmath(((2j - 1) * x * b - (j - 1) * a) / j)
end

return b
end

@doc doc"""
Polynomial(n)
Constructs a Polynomial basis of the form $[1, x, \dots, x^(n-1)]$.
## Arguments
- `n`: number of terms in the polynomial expansion.
"""
Polynomial(n) = GeneralBasisFunction{:Polynomial}(__polynomial, n)

@inline __polynomial(i, x) = x^(i - 1)

end
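For orientation, a minimal usage sketch of the bases defined above (not part of the commit; it assumes `Boltz` is loaded and `Basis` is exported as in `src/Boltz.jl`):

```julia
using Boltz

# A GeneralBasisFunction of order n maps an array of size (dims...) to an
# array of size (n, dims...), one leading slice per basis function, via
# `basis.f.(1:n, _unsqueeze1(x))`.
x = rand(Float32, 2, 3)         # values in [0, 1), safely inside [-1, 1]
cheb = Basis.Chebyshev(4)       # T_0(x), ..., T_3(x)
@assert size(cheb(x)) == (4, 2, 3)

four = Basis.Fourier(3)         # sin(x/2), cos(x), sin(3x/2)
@assert size(four(x)) == (3, 2, 3)
```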
5 changes: 4 additions & 1 deletion src/layers/Layers.jl
@@ -5,11 +5,12 @@ using PrecompileTools: @recompile_invalidations
@recompile_invalidations begin
using ArgCheck: @argcheck
using ADTypes: AutoForwardDiff, AutoZygote
using ..Boltz: _fast_chunk, _should_type_assert
using ..Boltz: Boltz, _fast_chunk, _should_type_assert, _batchview, _unsqueezeN
using ConcreteStructs: @concrete
using ChainRulesCore: ChainRulesCore
using Lux: Lux, StatefulLuxLayer
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
using Markdown: @doc_str
using NNlib: NNlib
using Random: AbstractRNG
using WeightInitializers: zeros32, randn32
@@ -30,5 +31,7 @@ include("encoder.jl")
include("embeddings.jl")
include("hamiltonian.jl")
include("mlp.jl")
include("spline.jl")
include("tensor_product.jl")

end
9 changes: 5 additions & 4 deletions src/layers/hamiltonian.jl
@@ -47,12 +47,13 @@ end

function HamiltonianNN{FST}(model; autodiff=nothing) where {FST}
if autodiff === nothing # Select best possible backend
autodiff = Lux._is_extension_loaded(Val(:Zygote)) ? AutoZygote() :
Lux._is_extension_loaded(Val(:ForwardDiff)) ? AutoForwardDiff() : nothing
autodiff = Boltz._is_extension_loaded(Val(:Zygote)) ? AutoZygote() :
Boltz._is_extension_loaded(Val(:ForwardDiff)) ? AutoForwardDiff() :
nothing
elseif autodiff isa AutoForwardDiff
autodiff = Lux._is_extension_loaded(Val(:ForwardDiff)) ? autodiff : nothing
autodiff = Boltz._is_extension_loaded(Val(:ForwardDiff)) ? autodiff : nothing
elseif autodiff isa AutoZygote
autodiff = Lux._is_extension_loaded(Val(:Zygote)) ? autodiff : nothing
autodiff = Boltz._is_extension_loaded(Val(:Zygote)) ? autodiff : nothing
else
throw(ArgumentError("Invalid autodiff backend: $(autodiff). Available options: \
`AutoForwardDiff`, or `AutoZygote`."))
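A hedged sketch of how this backend selection plays out from user code (the layer and backend names follow the surrounding diff; the Boolean type parameter value and the stand-in model are assumptions):

```julia
using ADTypes: AutoForwardDiff
using Boltz, Lux, Zygote  # loading Zygote activates the corresponding extension

model = Lux.Chain(Lux.Dense(2 => 16, tanh), Lux.Dense(16 => 1))

# autodiff === nothing: probe Zygote first, then ForwardDiff, else fall
# back to `nothing`.
hnn = Layers.HamiltonianNN{true}(model)

# An explicit backend is kept only if its extension is loaded; otherwise
# it is silently replaced with `nothing`.
hnn_fd = Layers.HamiltonianNN{true}(model; autodiff=AutoForwardDiff())
```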
1 change: 1 addition & 0 deletions src/layers/spline.jl
@@ -0,0 +1 @@

40 changes: 40 additions & 0 deletions src/layers/tensor_product.jl
@@ -0,0 +1,40 @@
@doc doc"""
TensorProductLayer(model, out_dim::Int; init_weight = randn32)
Constructs the Tensor Product Layer, which takes as input an array of n tensor product
basis, $[B_1, B_2, \dots, B_n]$ a data point x, computes
$$z_i = W_{i, :} ⨀ [B_1(x_1) ⨂ B_2(x_2) ⨂ \dots ⨂ B_n(x_n)]$$
where $W$ is the layer's weight, and returns $[z_1, \dots, z_{out}]$.
## Arguments
- `basis_fns`: Array of TensorProductBasis $[B_1(n_1), \dots, B_k(n_k)]$, where $k$
corresponds to the dimension of the input.
- `out_dim`: Dimension of the output.
- `init_weight`: Initializer for the weight matrix. Defaults to `randn32`.
"""
function TensorProductLayer(basis_fns, out_dim::Int; init_weight::F=randn32) where {F}
dense = Lux.Dense(
prod(Base.Fix2(getproperty, :n), basis_fns) => out_dim; use_bias=false, init_weight)
return Lux.@compact(; basis_fns=Tuple(basis_fns), dense,
out_dim, dispatch=:TensorProductLayer) do x::AbstractArray # I1 x I2 x ... x T x B
x_ = Lux._eachslice(x, Val(ndims(x) - 1)) # [I1 x I2 x ... x B] x T
@argcheck length(x_) == length(basis_fns)

y_ = mapfoldl(_kron, zip(basis_fns, x_)) do (fn, xᵢ)
eachcol(reshape(fn(xᵢ), :, prod(size(xᵢ))))
end # [[D₁ x ... x Dₙ] x (I1 x I2 x ... x B)]

@return reshape(dense(stack(y_)), size(x)[1:(end - 2)]..., out_dim, size(x)[end])
end
end

# CUDA `kron` exists only for `CuMatrix` so we define `kron` directly by converting to
# a matrix
@noinline _kron(a, b) = map(__kron, a, b)
@noinline function __kron(a, b)
    return vec(kron(reshape(a, :, 1), reshape(b, 1, :)))
end
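A usage sketch for the new layer (illustrative, not part of the commit; it assumes the layer is reachable as `Layers.TensorProductLayer` and pairs it with the `Basis` functions added above):

```julia
using Boltz, Lux, Random

# Two input coordinates, each expanded in an order-3 basis, projected to
# 4 output channels. The Dense weight has prod(nᵢ) = 3 * 3 = 9 columns.
layer = Layers.TensorProductLayer([Basis.Chebyshev(3), Basis.Fourier(3)], 4)
ps, st = Lux.setup(Random.default_rng(), layer)

# The second-to-last axis must match length(basis_fns) (here 2).
x = rand(Float32, 10, 2, 5)     # 10 points x 2 coordinates x batch of 5
y, _ = layer(x, ps, st)
@assert size(y) == (10, 4, 5)   # coordinate axis replaced by out_dim
```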
5 changes: 5 additions & 0 deletions src/utils.jl
@@ -40,3 +40,8 @@ type-assert for `x`.
"""
@inline _should_type_assert(::AbstractArray{T}) where {T} = isbitstype(T)
@inline _should_type_assert(x) = true

@inline _unsqueeze1(x::AbstractArray) = reshape(x, 1, size(x)...)
@inline _unsqueezeN(x::AbstractArray) = reshape(x, size(x)..., 1)

@inline _batchview(x::AbstractArray{T, N}) where {T, N} = Lux._eachslice(x, Val(N))
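These are internal shape helpers; a quick sketch of their behavior (illustrative only):

```julia
using Boltz

x = rand(Float32, 3, 4)

size(Boltz._unsqueeze1(x))    # (1, 3, 4): prepend a singleton axis
size(Boltz._unsqueezeN(x))    # (3, 4, 1): append a singleton axis
length(Boltz._batchview(x))   # 4: one view per entry of the last (batch) axis
```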
