diff --git a/Project.toml b/Project.toml
index e1e4629..82241d9 100644
--- a/Project.toml
+++ b/Project.toml
@@ -14,6 +14,7 @@ JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
 LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
+LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
 Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
@@ -49,8 +50,10 @@ Lux = "0.5.50"
 LuxAMDGPU = "0.2.3"
 LuxCUDA = "0.3.2"
 LuxCore = "0.1.15"
+LuxDeviceUtils = "0.1.21"
 LuxLib = "0.3.26"
 LuxTestUtils = "0.1.15"
+Markdown = "1.10"
 Metalhead = "0.9"
 NNlib = "0.9.17"
 Pkg = "1.10"
diff --git a/src/basis.jl b/src/basis.jl
index b1880b6..973da6b 100644
--- a/src/basis.jl
+++ b/src/basis.jl
@@ -13,7 +13,7 @@ function Base.show(io::IO, basis::GeneralBasisFunction{name}) where {name}
     print(io, "Basis.$(name)(order=$(basis.n))")
 end
 
-function (basis::GeneralBasisFunction{name, F})(x::AbstractArray) where {name, F}
+@inline function (basis::GeneralBasisFunction{name, F})(x::AbstractArray) where {name, F}
     return basis.f.(1:(basis.n), _unsqueeze1(x))
 end
 
@@ -66,7 +66,8 @@ $F_j(x) = j is even ? cos((j÷2)x) : sin((j÷2)x)$ => $[F_0(x), F_1(x), \dots, F
 Fourier(n) = GeneralBasisFunction{:Fourier}(__fourier, n)
 
 @inline function __fourier(i, x)
-    s, c = @fastmath sincos(i * x / 2)
+    # Don't use @fastmath here: fast math needs floats, but Zygote uses Duals for broadcast
+    s, c = sincos(i * x / 2)
     return ifelse(iseven(i), c, s)
 end
 
diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl
index c20b11f..fda83df 100644
--- a/src/layers/Layers.jl
+++ b/src/layers/Layers.jl
@@ -5,11 +5,12 @@ using PrecompileTools: @recompile_invalidations
 @recompile_invalidations begin
     using ArgCheck: @argcheck
     using ADTypes: AutoForwardDiff, AutoZygote
-    using ..Boltz: Boltz, _fast_chunk, _should_type_assert, _batchview, _unsqueezeN
+    using ..Boltz: Boltz, _fast_chunk, _should_type_assert, _stack
     using ConcreteStructs: @concrete
     using ChainRulesCore: ChainRulesCore
     using Lux: Lux, StatefulLuxLayer
     using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
+    using LuxDeviceUtils: get_device, LuxCPUDevice, LuxCUDADevice
     using Markdown: @doc_str
     using NNlib: NNlib
     using Random: AbstractRNG
diff --git a/src/layers/tensor_product.jl b/src/layers/tensor_product.jl
index a50fe26..a320e04 100644
--- a/src/layers/tensor_product.jl
+++ b/src/layers/tensor_product.jl
@@ -4,7 +4,7 @@
 Constructs the Tensor Product Layer, which takes as input an array of n tensor product
 basis, $[B_1, B_2, \dots, B_n]$ a data point x, computes
 
-$$z_i = W_{i, :} ⨀ [B_1(x_1) ⨂ B_2(x_2) ⨂ \dots ⨂ B_n(x_n)]$$
+$$z_i = W_{i, :} \odot [B_1(x_1) \otimes B_2(x_2) \otimes \dots \otimes B_n(x_n)]$$
 
 where $W$ is the layer's weight, and returns $[z_1, \dots, z_{out}]$.
 
@@ -14,27 +14,33 @@ where $W$ is the layer's weight, and returns $[z_1, \dots, z_{out}]$.
     corresponds to the dimension of the input.
   - `out_dim`: Dimension of the output.
   - `init_weight`: Initializer for the weight matrix. Defaults to `randn32`.
+
+!!! warning
+
+    This layer currently only works on CPU and CUDA devices.
""" function TensorProductLayer(basis_fns, out_dim::Int; init_weight::F=randn32) where {F} dense = Lux.Dense( prod(Base.Fix2(getproperty, :n), basis_fns) => out_dim; use_bias=false, init_weight) return Lux.@compact(; basis_fns=Tuple(basis_fns), dense, out_dim, dispatch=:TensorProductLayer) do x::AbstractArray # I1 x I2 x ... x T x B + dev = get_device(x) + @argcheck dev isa LuxCPUDevice || dev isa LuxCUDADevice # kron is not widely supported + x_ = Lux._eachslice(x, Val(ndims(x) - 1)) # [I1 x I2 x ... x B] x T @argcheck length(x_) == length(basis_fns) - y_ = mapfoldl(_kron, zip(basis_fns, x_)) do (fn, xᵢ) + y = mapfoldl(_kron, zip(basis_fns, x_)) do (fn, xᵢ) eachcol(reshape(fn(xᵢ), :, prod(size(xᵢ)))) end # [[D₁ x ... x Dₙ] x (I1 x I2 x ... x B)] - @return reshape(dense(stack(y_)), size(x)[1:(end - 2)]..., out_dim, size(x)[end]) + @return reshape(dense(_stack(y)), size(x)[1:(end - 2)]..., out_dim, size(x)[end]) end end # CUDA `kron` exists only for `CuMatrix` so we define `kron` directly by converting to # a matrix -@noinline _kron(a, b) = map(__kron, a, b) -@noinline function __kron(a, b) - @show size(a), size(b) +@inline _kron(a, b) = map(__kron, a, b) +@inline function __kron(a::AbstractVector, b::AbstractVector) return vec(kron(reshape(a, :, 1), reshape(b, 1, :))) end diff --git a/src/utils.jl b/src/utils.jl index 08a4cfc..dfee33d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -44,4 +44,7 @@ type-assert for `x`. @inline _unsqueeze1(x::AbstractArray) = reshape(x, 1, size(x)...) @inline _unsqueezeN(x::AbstractArray) = reshape(x, size(x)..., 1) -@inline _batchview(x::AbstractArray{T, N}) where {T, N} = Lux._eachslice(x, Val(N)) +@inline _catN(x::AbstractArray{T, N}, y::AbstractArray{T, N}) where {T, N} = cat( + x, y; dims=Val(N)) + +@inline _stack(xs) = mapreduce(_unsqueezeN, _catN, xs) diff --git a/test/layer_tests.jl b/test/layer_tests.jl index 3ef4802..32d987d 100644 --- a/test/layer_tests.jl +++ b/test/layer_tests.jl @@ -81,3 +81,29 @@ end end end end + +@testitem "Tensor Product Layer" setup=[SharedTestSetup] tags=[:layers] begin + @testset "$(mode)" for (mode, aType, dev, ongpu) in MODES + mode === "AMDGPU" && continue + + @testset "$(basis)" for basis in (Basis.Chebyshev, Basis.Sin, Basis.Cos, + Basis.Fourier, Basis.Legendre, Basis.Polynomial) + tensor_project = Layers.TensorProductLayer([basis(n + 2) for n in 1:3], 4) + ps, st = Lux.setup(Xoshiro(0), tensor_project) |> dev + + x = tanh.(randn(Float32, 2, 4, 5)) |> aType + + @test_throws ArgumentError tensor_project(x, ps, st) + + x = tanh.(randn(Float32, 2, 3, 5)) |> aType + + y, st = tensor_project(x, ps, st) + @test size(y) == (2, 4, 5) + + @jet tensor_project(x, ps, st) + + __f = (x, ps) -> sum(abs2, first(tensor_project(x, ps, st))) + @eval @test_gradients $(__f) $x $ps gpu_testing=$(ongpu) atol=1e-3 rtol=1e-3 skip_tracker=true + end + end +end