Add Tensor Product Layer Tests
avik-pal committed Jun 9, 2024
1 parent 2372e67 commit f10ba33
Showing 6 changed files with 50 additions and 10 deletions.
3 changes: 3 additions & 0 deletions Project.toml
@@ -14,6 +14,7 @@ JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
 LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
+LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
 Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
@@ -49,8 +50,10 @@ Lux = "0.5.50"
 LuxAMDGPU = "0.2.3"
 LuxCUDA = "0.3.2"
 LuxCore = "0.1.15"
+LuxDeviceUtils = "0.1.21"
 LuxLib = "0.3.26"
 LuxTestUtils = "0.1.15"
+Markdown = "1.10"
 Metalhead = "0.9"
 NNlib = "0.9.17"
 Pkg = "1.10"
5 changes: 3 additions & 2 deletions src/basis.jl
@@ -13,7 +13,7 @@ function Base.show(io::IO, basis::GeneralBasisFunction{name}) where {name}
     print(io, "Basis.$(name)(order=$(basis.n))")
 end
 
-function (basis::GeneralBasisFunction{name, F})(x::AbstractArray) where {name, F}
+@inline function (basis::GeneralBasisFunction{name, F})(x::AbstractArray) where {name, F}
     return basis.f.(1:(basis.n), _unsqueeze1(x))
 end
 
@@ -66,7 +66,8 @@ $F_j(x) = j is even ? cos((j÷2)x) : sin((j÷2)x)$ => $[F_0(x), F_1(x), \dots, F_n(x)]$
 Fourier(n) = GeneralBasisFunction{:Fourier}(__fourier, n)
 
 @inline function __fourier(i, x)
-    s, c = @fastmath sincos(i * x / 2)
+    # Don't use @fastmath here: fast mode needs floats, but Zygote uses Duals for broadcast
+    s, c = sincos(i * x / 2)
     return ifelse(iseven(i), c, s)
 end
 
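For context, here is a minimal sketch of how the Fourier basis evaluates after this change (an illustration, not part of the commit; it assumes the `Basis` module is accessible as in the tests below):

using Boltz: Basis

# Basis.Fourier(4) broadcasts __fourier over i = 1:4 against the unsqueezed input,
# so a length-L vector yields a 4 × L matrix: row i holds sin.(i .* x ./ 2) for odd i
# and cos.(i .* x ./ 2) for even i.
x = Float32[0.0, 0.5, 1.0]
F = Basis.Fourier(4)
y = F(x)
@assert size(y) == (4, 3)
@assert y[2, :] ≈ cos.(x)  # i = 2 is even, so row 2 is cos(2x/2) = cos(x)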
3 changes: 2 additions & 1 deletion src/layers/Layers.jl
@@ -5,11 +5,12 @@ using PrecompileTools: @recompile_invalidations
 @recompile_invalidations begin
     using ArgCheck: @argcheck
     using ADTypes: AutoForwardDiff, AutoZygote
-    using ..Boltz: Boltz, _fast_chunk, _should_type_assert, _batchview, _unsqueezeN
+    using ..Boltz: Boltz, _fast_chunk, _should_type_assert, _stack
     using ConcreteStructs: @concrete
     using ChainRulesCore: ChainRulesCore
     using Lux: Lux, StatefulLuxLayer
     using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
+    using LuxDeviceUtils: get_device, LuxCPUDevice, LuxCUDADevice
     using Markdown: @doc_str
     using NNlib: NNlib
     using Random: AbstractRNG
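As an aside, the new LuxDeviceUtils import exists to support the device guard added to the tensor product layer below. A minimal sketch of that pattern, assuming LuxDeviceUtils' `get_device` at the pinned 0.1.21 API (it returns a device object for an array):

using LuxDeviceUtils: get_device, LuxCPUDevice, LuxCUDADevice

x = rand(Float32, 3, 3)
dev = get_device(x)  # LuxCPUDevice() for a plain Array
dev isa LuxCPUDevice || dev isa LuxCUDADevice ||
    throw(ArgumentError("kron-based layers support only CPU and CUDA arrays"))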
18 changes: 12 additions & 6 deletions src/layers/tensor_product.jl
@@ -4,7 +4,7 @@
 Constructs the Tensor Product Layer, which takes as input an array of n tensor-product
 basis functions $[B_1, B_2, \dots, B_n]$ and a data point $x$, and computes
-$$z_i = W_{i, :} [B_1(x_1) B_2(x_2) \dots B_n(x_n)]$$
+$$z_i = W_{i, :} \odot [B_1(x_1) \otimes B_2(x_2) \otimes \dots \otimes B_n(x_n)]$$
 where $W$ is the layer's weight, and returns $[z_1, \dots, z_{out}]$.
@@ -14,27 +14,33 @@ where $W$ is the layer's weight, and returns $[z_1, \dots, z_{out}]$.
   corresponds to the dimension of the input.
 - `out_dim`: Dimension of the output.
 - `init_weight`: Initializer for the weight matrix. Defaults to `randn32`.
+
+!!! warning
+
+    This layer currently only works on CPU and CUDA devices.
 """
 function TensorProductLayer(basis_fns, out_dim::Int; init_weight::F=randn32) where {F}
     dense = Lux.Dense(
         prod(Base.Fix2(getproperty, :n), basis_fns) => out_dim; use_bias=false, init_weight)
     return Lux.@compact(; basis_fns=Tuple(basis_fns), dense,
         out_dim, dispatch=:TensorProductLayer) do x::AbstractArray # I1 x I2 x ... x T x B
+        dev = get_device(x)
+        @argcheck dev isa LuxCPUDevice || dev isa LuxCUDADevice # kron is not widely supported
+
         x_ = Lux._eachslice(x, Val(ndims(x) - 1)) # [I1 x I2 x ... x B] x T
         @argcheck length(x_) == length(basis_fns)
 
-        y_ = mapfoldl(_kron, zip(basis_fns, x_)) do (fn, xᵢ)
+        y = mapfoldl(_kron, zip(basis_fns, x_)) do (fn, xᵢ)
             eachcol(reshape(fn(xᵢ), :, prod(size(xᵢ))))
         end # [[D₁ x ... x Dₙ] x (I1 x I2 x ... x B)]
 
-        @return reshape(dense(stack(y_)), size(x)[1:(end - 2)]..., out_dim, size(x)[end])
+        @return reshape(dense(_stack(y)), size(x)[1:(end - 2)]..., out_dim, size(x)[end])
     end
 end
 
 # CUDA `kron` exists only for `CuMatrix` so we define `kron` directly by converting to
 # a matrix
-@noinline _kron(a, b) = map(__kron, a, b)
-@noinline function __kron(a, b)
-    @show size(a), size(b)
+@inline _kron(a, b) = map(__kron, a, b)
+@inline function __kron(a::AbstractVector, b::AbstractVector)
     return vec(kron(reshape(a, :, 1), reshape(b, 1, :)))
 end
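To make the data flow concrete, here is a plain-Julia sketch of the contraction this layer performs for a single sample (illustrative only; `bases`, `W`, and `x` are made-up stand-ins, not package API):

using LinearAlgebra: kron

# Three toy basis functions with sizes (3, 2, 2). Each Bᵢ(xᵢ) is a vector; their
# Kronecker product has length 3 * 2 * 2 = 12, and a bias-free dense weight
# W (out_dim × 12) contracts it: z[i] = W[i, :] ⋅ (B₁(x₁) ⊗ B₂(x₂) ⊗ B₃(x₃)).
bases = (x -> [one(x), x, 2x^2 - 1],  # Chebyshev T₀, T₁, T₂
    x -> [sin(x), sin(2x)],           # toy sine basis
    x -> [cos(x), cos(2x)])           # toy cosine basis
x = (0.3f0, 0.7f0, -0.2f0)
v = reduce(kron, map((f, xᵢ) -> f(xᵢ), bases, x))
W = randn(Float32, 4, 12)             # out_dim = 4
z = W * v                             # one z per output row
@assert size(z) == (4,)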
5 changes: 4 additions & 1 deletion src/utils.jl
@@ -44,4 +44,7 @@ type-assert for `x`.
 @inline _unsqueeze1(x::AbstractArray) = reshape(x, 1, size(x)...)
 @inline _unsqueezeN(x::AbstractArray) = reshape(x, size(x)..., 1)
 
-@inline _batchview(x::AbstractArray{T, N}) where {T, N} = Lux._eachslice(x, Val(N))
+@inline _catN(x::AbstractArray{T, N}, y::AbstractArray{T, N}) where {T, N} = cat(
+    x, y; dims=Val(N))
+
+@inline _stack(xs) = mapreduce(_unsqueezeN, _catN, xs)
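For intuition, a small self-contained sketch of what `_stack` computes under these definitions (K arrays of size (m, n) become one (m, n, K) array; the `Base.stack` comparison assumes Julia ≥ 1.9):

# _stack appends a trailing singleton dimension to each array (_unsqueezeN) and
# concatenates along that new last dimension (_catN).
_unsqueezeN(x::AbstractArray) = reshape(x, size(x)..., 1)
_catN(x::AbstractArray{T, N}, y::AbstractArray{T, N}) where {T, N} = cat(x, y; dims=Val(N))
_stack(xs) = mapreduce(_unsqueezeN, _catN, xs)

xs = [rand(Float32, 2, 3) for _ in 1:4]
y = _stack(xs)
@assert size(y) == (2, 3, 4)
@assert y == stack(xs)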
26 changes: 26 additions & 0 deletions test/layer_tests.jl
@@ -81,3 +81,29 @@ end
 end
 end
 end
+
+@testitem "Tensor Product Layer" setup=[SharedTestSetup] tags=[:layers] begin
+    @testset "$(mode)" for (mode, aType, dev, ongpu) in MODES
+        mode === "AMDGPU" && continue
+
+        @testset "$(basis)" for basis in (Basis.Chebyshev, Basis.Sin, Basis.Cos,
+            Basis.Fourier, Basis.Legendre, Basis.Polynomial)
+            tensor_project = Layers.TensorProductLayer([basis(n + 2) for n in 1:3], 4)
+            ps, st = Lux.setup(Xoshiro(0), tensor_project) |> dev
+
+            x = tanh.(randn(Float32, 2, 4, 5)) |> aType
+
+            @test_throws ArgumentError tensor_project(x, ps, st)
+
+            x = tanh.(randn(Float32, 2, 3, 5)) |> aType
+
+            y, st = tensor_project(x, ps, st)
+            @test size(y) == (2, 4, 5)
+
+            @jet tensor_project(x, ps, st)
+
+            __f = (x, ps) -> sum(abs2, first(tensor_project(x, ps, st)))
+            @eval @test_gradients $(__f) $x $ps gpu_testing=$(ongpu) atol=1e-3 rtol=1e-3 skip_tracker=true
+        end
+    end
+end
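A note on the shape checks in this test: the layer slices its input along dimension ndims(x) - 1, so a 2 × 4 × 5 input produces four slices against three basis functions and the @argcheck raises an ArgumentError; with a 2 × 3 × 5 input the axis matches and out_dim = 4 replaces it, giving the expected 2 × 4 × 5 output.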