Compiling Recurrent Models with Reactant #1025

Open
avik-pal opened this issue Nov 4, 2024 · 3 comments · May be fixed by #1026

Comments

@avik-pal
Member

avik-pal commented Nov 4, 2024

using Lux, Reactant, Random

model = Recurrence(RNNCell(4 => 4))
ps, st = Lux.setup(Xoshiro(123), model) |> Reactant.to_rarray
x = rand(Float32, 4, 16, 12) |> Reactant.ConcreteRArray

@code_hlo model(x, ps, st)

The issue originates from the following function, which moves the initial hidden state to a ConcreteRArray instead of a TracedRArray when it runs inside a compilation context. I could use @reactant_override to define custom dispatches that apply inside Reactant.compile, but I am not sure that should be the recommended usage.

function init_rnn_hidden_state(rng::AbstractRNG, rnn, x::AbstractMatrix)
    # TODO: Once we support moving `rng` to the device, we can directly initialize on the
    #       device
    return rnn.init_state(rng, rnn.out_dims, Base.size(x, 2)) |> get_device(x)
end
@avik-pal
Member Author

avik-pal commented Nov 4, 2024

Error for full context:

1-element ExceptionStack:
LoadError: MethodError: no method matching Float32(::Reactant.TracedRNumber{Float32})
The type `Float32` exists, but no method is defined for this combination of argument types when trying to construct it.

Closest candidates are:
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:900
  Float32(::IrrationalConstants.Log2π)
   @ IrrationalConstants ~/.julia/packages/IrrationalConstants/vp5v4/src/macro.jl:113
  Float32(::IrrationalConstants.Halfπ)
   @ IrrationalConstants ~/.julia/packages/IrrationalConstants/vp5v4/src/macro.jl:113
  ...

Stacktrace:
  [1] convert(::Type{Float32}, x::Reactant.TracedRNumber{Float32})
    @ Base ./number.jl:7
  [2] unsafe_store!(p::Ptr{Float32}, x::Reactant.TracedRNumber{Float32}, i::Int64)
    @ Base ./pointer.jl:180
  [3] setindex!(::ConcreteRArray{Float32, 2}, ::Reactant.TracedRNumber{Float32}, ::Int64, ::Int64)
    @ Reactant /mnt/software/lux/Reactant.jl/src/ConcreteRArray.jl:233
  [4] macro expansion
    @ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:894 [inlined]
  [5] macro expansion
    @ ./simdloop.jl:77 [inlined]
  [6] _generic_matmatmul!(C::ConcreteRArray{Float32, 2}, A::Reactant.TracedRArray{Float32, 2}, B::ConcreteRArray{Float32, 2}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:893
  [7] generic_matmatmul!
    @ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:868 [inlined]
  [8] _mul!
    @ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:287 [inlined]
  [9] mul!
    @ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:285 [inlined]
 [10] mul!
    @ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:253 [inlined]
 [11] *
    @ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:114 [inlined]
 [12] muladd(A::Reactant.TracedRArray{Float32, 2}, y::ConcreteRArray{Float32, 2}, z::Reactant.TracedRArray{Float32, 1})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:180
 [13] matmuladd
    @ ~/.julia/packages/LuxLib/I9RHW/src/impl/matmul.jl:12 [inlined]
 [14] matmuladd
    @ ~/.julia/packages/LuxLib/I9RHW/src/impl/matmul.jl:7 [inlined]
 [15] fused_dense
    @ ~/.julia/packages/LuxLib/I9RHW/src/impl/dense.jl:6 [inlined]
 [16] fused_dense_bias_activation
    @ ~/.julia/packages/LuxLib/I9RHW/src/api/dense.jl:35 [inlined]
 [17] RNNCell
    @ ~/.julia/packages/Lux/atwzZ/src/layers/recurrent.jl:291 [inlined]
 [18] (::RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True})(x::SubArray{Float32, 2, Reactant.TracedRArray{Float32, 3}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64, Base.Slice{Base.OneTo{Int64}}}, false}, ps::@NamedTuple{weight_ih::Reactant.TracedRArray{Float32, 2}, weight_hh::Reactant.TracedRArray{Float32, 2}, bias_ih::Reactant.TracedRArray{Float32, 1}, bias_hh::Reactant.TracedRArray{Float32, 1}}, st::@NamedTuple{rng::Xoshiro})
    @ Lux ~/.julia/packages/Lux/atwzZ/src/layers/recurrent.jl:277
 [19] apply
    @ ~/.julia/packages/LuxCore/IBKvY/src/LuxCore.jl:155 [inlined]
 [20] (::Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex})(x::Vector{SubArray{Float32, 2, Reactant.TracedRArray{Float32, 3}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64, Base.Slice{Base.OneTo{Int64}}}, false}}, ps::@NamedTuple{weight_ih::Reactant.TracedRArray{Float32, 2}, weight_hh::Reactant.TracedRArray{Float32, 2}, bias_ih::Reactant.TracedRArray{Float32, 1}, bias_hh::Reactant.TracedRArray{Float32, 1}}, st::@NamedTuple{rng::Xoshiro})
    @ Lux ~/.julia/packages/Lux/atwzZ/src/layers/recurrent.jl:118
 [21] apply
    @ ~/.julia/packages/LuxCore/IBKvY/src/LuxCore.jl:155 [inlined]
 [22] Recurrence
    @ ~/.julia/packages/Lux/atwzZ/src/layers/recurrent.jl:114 [inlined]
 [23] #apply#19
    @ /mnt/software/lux/Reactant.jl/src/utils.jl:33 [inlined]
 [24] apply
    @ /mnt/software/lux/Reactant.jl/src/utils.jl:32 [inlined]
 [25] (::Tuple{})(none::Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, none::Tuple{Reactant.TracedRArray{Float32, 3}, @NamedTuple{weight_ih::Reactant.TracedRArray{Float32, 2}, weight_hh::Reactant.TracedRArray{Float32, 2}, bias_ih::Reactant.TracedRArray{Float32, 1}, bias_hh::Reactant.TracedRArray{Float32, 1}}, @NamedTuple{rng::Xoshiro}})
    @ Base.Experimental ./<missing>:0
 [26] (::Reactant.var"#26#35"{Bool, typeof(Reactant.apply), Tuple{Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, ConcreteRArray{Float32, 3}, @NamedTuple{weight_ih::ConcreteRArray{Float32, 2}, weight_hh::ConcreteRArray{Float32, 2}, bias_ih::ConcreteRArray{Float32, 1}, bias_hh::ConcreteRArray{Float32, 1}}, @NamedTuple{rng::Xoshiro}}, Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, Tuple{Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, Reactant.TracedRArray{Float32, 3}, @NamedTuple{weight_ih::Reactant.TracedRArray{Float32, 2}, weight_hh::Reactant.TracedRArray{Float32, 2}, bias_ih::Reactant.TracedRArray{Float32, 1}, bias_hh::Reactant.TracedRArray{Float32, 1}}, @NamedTuple{rng::Xoshiro}}})()
    @ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:139
 [27] block!(f::Reactant.var"#26#35"{Bool, typeof(Reactant.apply), Tuple{Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, ConcreteRArray{Float32, 3}, @NamedTuple{weight_ih::ConcreteRArray{Float32, 2}, weight_hh::ConcreteRArray{Float32, 2}, bias_ih::ConcreteRArray{Float32, 1}, bias_hh::ConcreteRArray{Float32, 1}}, @NamedTuple{rng::Xoshiro}}, Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, Tuple{Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, Reactant.TracedRArray{Float32, 3}, @NamedTuple{weight_ih::Reactant.TracedRArray{Float32, 2}, weight_hh::Reactant.TracedRArray{Float32, 2}, bias_ih::Reactant.TracedRArray{Float32, 1}, bias_hh::Reactant.TracedRArray{Float32, 1}}, @NamedTuple{rng::Xoshiro}}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201
 [28] make_mlir_fn(f::Function, args::Tuple{Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, ConcreteRArray{Float32, 3}, @NamedTuple{weight_ih::ConcreteRArray{Float32, 2}, weight_hh::ConcreteRArray{Float32, 2}, bias_ih::ConcreteRArray{Float32, 1}, bias_hh::ConcreteRArray{Float32, 1}}, @NamedTuple{rng::Xoshiro}}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool)
    @ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:112
 [29] make_mlir_fn(f::Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, args::Vector{Any}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool)
    @ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:48
 [30] make_mlir_fn
    @ /mnt/software/lux/Reactant.jl/src/utils.jl:36 [inlined]
 [31] #6
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:270 [inlined]
 [32] block!(f::Reactant.Compiler.var"#6#11"{Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, Vector{Any}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201
 [33] #5
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:269 [inlined]
 [34] mmodule!(f::Reactant.Compiler.var"#5#10"{Reactant.MLIR.IR.Module, Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, Vector{Any}}, blk::Reactant.MLIR.IR.Module)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Module.jl:93
 [35] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, args::Vector{Any}; optimize::Bool)
    @ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:266
 [36] compile_mlir!
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:265 [inlined]
 [37] #2
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:260 [inlined]
 [38] context!(f::Reactant.Compiler.var"#2#3"{@Kwargs{optimize::Bool}, Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, Vector{Any}}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Context.jl:76
 [39] #compile_mlir#1
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:258 [inlined]
 [40] top-level scope
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:419
 [41] eval
    @ ./boot.jl:430 [inlined]
 [42] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
    @ Base ./loading.jl:2643
 [43] invokelatest(::Any, ::Any, ::Vararg{Any}; kwargs::@Kwargs{})
    @ Base ./essentials.jl:1055
 [44] invokelatest(::Any, ::Any, ::Vararg{Any})
    @ Base ./essentials.jl:1052
 [45] inlineeval(m::Module, code::String, code_line::Int64, code_column::Int64, file::String; softscope::Bool)
    @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:271
 [46] (::VSCodeServer.var"#69#74"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})()
    @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:181
 [47] withpath(f::VSCodeServer.var"#69#74"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams}, path::String)
    @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/repl.jl:276
 [48] (::VSCodeServer.var"#68#73"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})()
    @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:179
 [49] hideprompt(f::VSCodeServer.var"#68#73"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})
    @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/repl.jl:38
 [50] #67
    @ ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:150 [inlined]
 [51] with_logstate(f::VSCodeServer.var"#67#72"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams}, logstate::Base.CoreLogging.LogState)
    @ Base.CoreLogging ./logging/logging.jl:522
 [52] with_logger
    @ ./logging/logging.jl:632 [inlined]
 [53] (::VSCodeServer.var"#66#71"{VSCodeServer.ReplRunCodeRequestParams})()
    @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:263
 [54] #invokelatest#2
    @ ./essentials.jl:1055 [inlined]
 [55] invokelatest(::Any)
    @ Base ./essentials.jl:1052
in expression starting at /mnt/software/lux/Reactant.jl/envs/lux/rnn.jl:7

@avik-pal
Member Author

avik-pal commented Nov 4, 2024

The other, probably nicer, way is to write the init function so that it does a copyto! into an array allocated with similar.
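
A minimal sketch of that idea in plain Julia, under stated assumptions: `ToyCell`, `toy_zeros32`, and `init_rnn_hidden_state_sketch` are hypothetical names introduced here for illustration, not Lux or Reactant API, and the linked pull request may do this differently. The destination array is allocated from `x` with `similar`, so it inherits the array type of `x`, and the host-side initial state is then copied into it. Whether this traces cleanly under Reactant is not verified here.

```julia
using Random

# Hypothetical stand-in for a recurrent cell; it models only the two fields
# that the snippet in the issue reads (`out_dims` and `init_state`).
struct ToyCell{F}
    out_dims::Int
    init_state::F
end

# Hypothetical initializer with the (rng, dims...) calling convention used above.
toy_zeros32(rng::AbstractRNG, dims::Integer...) = zeros(Float32, dims...)

function init_rnn_hidden_state_sketch(rng::AbstractRNG, rnn, x::AbstractMatrix)
    # Allocate the hidden state from `x`, so it has the same array type as `x`.
    h = similar(x, (rnn.out_dims, size(x, 2)))
    # Fill it from the host-side initializer instead of moving a fresh array
    # to the device of `x`.
    copyto!(h, rnn.init_state(rng, rnn.out_dims, size(x, 2)))
    return h
end

# Usage with ordinary arrays: 4 features, batch of 12.
cell = ToyCell(4, toy_zeros32)
x = rand(Float32, 4, 12)
h = init_rnn_hidden_state_sketch(Xoshiro(123), cell, x)
```

The design point is that `similar` dispatches on the type of `x`, so the function never has to name a concrete array type or query a device.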

avik-pal linked a pull request on Nov 4, 2024 that will close this issue
@wsmoses
Contributor

wsmoses commented Nov 5, 2024

@avik-pal from the error log above, it looks like we don't override the in-place mul! method (and should).

Specifically:

  [9] mul!
    @ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:285 [inlined]
 [10] mul!
    @ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:253 [inlined]
 [11] *
    @ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:114 [inlined]
 [12] muladd(A::Reactant.TracedRArray{Float32, 2}, y::ConcreteRArray{Float32, 2}, z::Reactant.TracedRArray{Float32, 1})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:180
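
To illustrate what overriding the in-place method means in terms of Julia dispatch, here is a small sketch with a hypothetical wrapper type. `ToyTraced` and `OVERRIDE_HITS` are made up for this example and are not Reactant types. The sketch only shows that a more specific five-argument `LinearAlgebra.mul!` method is selected ahead of the generic fallback, which is the path that reaches `_generic_matmatmul!` and scalar `setindex!` in the trace above. It is not how Reactant implements or should implement the override.

```julia
using LinearAlgebra

# Hypothetical wrapper standing in for a traced array type.
# It is NOT Reactant's TracedRArray.
struct ToyTraced{T} <: AbstractMatrix{T}
    data::Matrix{T}
end
Base.size(A::ToyTraced) = size(A.data)
Base.getindex(A::ToyTraced, i::Int, j::Int) = A.data[i, j]

# Counts how often the specialized method below is selected.
const OVERRIDE_HITS = Ref(0)

# Specialized in-place multiply: C = A*B*α + C*β, computed on whole arrays
# rather than through an element-by-element loop.
function LinearAlgebra.mul!(C::Matrix, A::ToyTraced, B::Matrix, α::Number, β::Number)
    OVERRIDE_HITS[] += 1
    C .= (A.data * B) .* α .+ C .* β
    return C
end

A = ToyTraced(Float32[1 2; 3 4])
B = Float32[1 0; 0 1]
C = zeros(Float32, 2, 2)
# The three-argument form forwards to the five-argument method with α = true, β = false.
mul!(C, A, B)
```

Defining the five-argument form is enough here because the three-argument `mul!` forwards to it, matching frames [9] and [10] of the trace.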
