-
Notifications
You must be signed in to change notification settings - Fork 63
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Compiling Recurrent Models with Reactant #1025
Comments
Error for full context: 1-element ExceptionStack:
LoadError: MethodError: no method matching Float32(::Reactant.TracedRNumber{Float32})
The type `Float32` exists, but no method is defined for this combination of argument types when trying to construct it.
Closest candidates are:
(::Type{T})(::T) where T<:Number
@ Core boot.jl:900
Float32(::IrrationalConstants.Log2π)
@ IrrationalConstants ~/.julia/packages/IrrationalConstants/vp5v4/src/macro.jl:113
Float32(::IrrationalConstants.Halfπ)
@ IrrationalConstants ~/.julia/packages/IrrationalConstants/vp5v4/src/macro.jl:113
...
Stacktrace:
[1] convert(::Type{Float32}, x::Reactant.TracedRNumber{Float32})
@ Base ./number.jl:7
[2] unsafe_store!(p::Ptr{Float32}, x::Reactant.TracedRNumber{Float32}, i::Int64)
@ Base ./pointer.jl:180
[3] setindex!(::ConcreteRArray{Float32, 2}, ::Reactant.TracedRNumber{Float32}, ::Int64, ::Int64)
@ Reactant /mnt/software/lux/Reactant.jl/src/ConcreteRArray.jl:233
[4] macro expansion
@ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:894 [inlined]
[5] macro expansion
@ ./simdloop.jl:77 [inlined]
[6] _generic_matmatmul!(C::ConcreteRArray{Float32, 2}, A::Reactant.TracedRArray{Float32, 2}, B::ConcreteRArray{Float32, 2}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
@ LinearAlgebra ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:893
[7] generic_matmatmul!
@ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:868 [inlined]
[8] _mul!
@ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:287 [inlined]
[9] mul!
@ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:285 [inlined]
[10] mul!
@ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:253 [inlined]
[11] *
@ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:114 [inlined]
[12] muladd(A::Reactant.TracedRArray{Float32, 2}, y::ConcreteRArray{Float32, 2}, z::Reactant.TracedRArray{Float32, 1})
@ LinearAlgebra ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:180
[13] matmuladd
@ ~/.julia/packages/LuxLib/I9RHW/src/impl/matmul.jl:12 [inlined]
[14] matmuladd
@ ~/.julia/packages/LuxLib/I9RHW/src/impl/matmul.jl:7 [inlined]
[15] fused_dense
@ ~/.julia/packages/LuxLib/I9RHW/src/impl/dense.jl:6 [inlined]
[16] fused_dense_bias_activation
@ ~/.julia/packages/LuxLib/I9RHW/src/api/dense.jl:35 [inlined]
[17] RNNCell
@ ~/.julia/packages/Lux/atwzZ/src/layers/recurrent.jl:291 [inlined]
[18] (::RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True})(x::SubArray{Float32, 2, Reactant.TracedRArray{Float32, 3}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64, Base.Slice{Base.OneTo{Int64}}}, false}, ps::@NamedTuple{weight_ih::Reactant.TracedRArray{Float32, 2}, weight_hh::Reactant.TracedRArray{Float32, 2}, bias_ih::Reactant.TracedRArray{Float32, 1}, bias_hh::Reactant.TracedRArray{Float32, 1}}, st::@NamedTuple{rng::Xoshiro})
@ Lux ~/.julia/packages/Lux/atwzZ/src/layers/recurrent.jl:277
[19] apply
@ ~/.julia/packages/LuxCore/IBKvY/src/LuxCore.jl:155 [inlined]
[20] (::Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex})(x::Vector{SubArray{Float32, 2, Reactant.TracedRArray{Float32, 3}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64, Base.Slice{Base.OneTo{Int64}}}, false}}, ps::@NamedTuple{weight_ih::Reactant.TracedRArray{Float32, 2}, weight_hh::Reactant.TracedRArray{Float32, 2}, bias_ih::Reactant.TracedRArray{Float32, 1}, bias_hh::Reactant.TracedRArray{Float32, 1}}, st::@NamedTuple{rng::Xoshiro})
@ Lux ~/.julia/packages/Lux/atwzZ/src/layers/recurrent.jl:118
[21] apply
@ ~/.julia/packages/LuxCore/IBKvY/src/LuxCore.jl:155 [inlined]
[22] Recurrence
@ ~/.julia/packages/Lux/atwzZ/src/layers/recurrent.jl:114 [inlined]
[23] #apply#19
@ /mnt/software/lux/Reactant.jl/src/utils.jl:33 [inlined]
[24] apply
@ /mnt/software/lux/Reactant.jl/src/utils.jl:32 [inlined]
[25] (::Tuple{})(none::Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, none::Tuple{Reactant.TracedRArray{Float32, 3}, @NamedTuple{weight_ih::Reactant.TracedRArray{Float32, 2}, weight_hh::Reactant.TracedRArray{Float32, 2}, bias_ih::Reactant.TracedRArray{Float32, 1}, bias_hh::Reactant.TracedRArray{Float32, 1}}, @NamedTuple{rng::Xoshiro}})
@ Base.Experimental ./<missing>:0
[26] (::Reactant.var"#26#35"{Bool, typeof(Reactant.apply), Tuple{Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, ConcreteRArray{Float32, 3}, @NamedTuple{weight_ih::ConcreteRArray{Float32, 2}, weight_hh::ConcreteRArray{Float32, 2}, bias_ih::ConcreteRArray{Float32, 1}, bias_hh::ConcreteRArray{Float32, 1}}, @NamedTuple{rng::Xoshiro}}, Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, Tuple{Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, Reactant.TracedRArray{Float32, 3}, @NamedTuple{weight_ih::Reactant.TracedRArray{Float32, 2}, weight_hh::Reactant.TracedRArray{Float32, 2}, bias_ih::Reactant.TracedRArray{Float32, 1}, bias_hh::Reactant.TracedRArray{Float32, 1}}, @NamedTuple{rng::Xoshiro}}})()
@ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:139
[27] block!(f::Reactant.var"#26#35"{Bool, typeof(Reactant.apply), Tuple{Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, ConcreteRArray{Float32, 3}, @NamedTuple{weight_ih::ConcreteRArray{Float32, 2}, weight_hh::ConcreteRArray{Float32, 2}, bias_ih::ConcreteRArray{Float32, 1}, bias_hh::ConcreteRArray{Float32, 1}}, @NamedTuple{rng::Xoshiro}}, Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, Tuple{Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, Reactant.TracedRArray{Float32, 3}, @NamedTuple{weight_ih::Reactant.TracedRArray{Float32, 2}, weight_hh::Reactant.TracedRArray{Float32, 2}, bias_ih::Reactant.TracedRArray{Float32, 1}, bias_hh::Reactant.TracedRArray{Float32, 1}}, @NamedTuple{rng::Xoshiro}}}, blk::Reactant.MLIR.IR.Block)
@ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201
[28] make_mlir_fn(f::Function, args::Tuple{Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, ConcreteRArray{Float32, 3}, @NamedTuple{weight_ih::ConcreteRArray{Float32, 2}, weight_hh::ConcreteRArray{Float32, 2}, bias_ih::ConcreteRArray{Float32, 1}, bias_hh::ConcreteRArray{Float32, 1}}, @NamedTuple{rng::Xoshiro}}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool)
@ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:112
[29] make_mlir_fn(f::Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, args::Vector{Any}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool)
@ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:48
[30] make_mlir_fn
@ /mnt/software/lux/Reactant.jl/src/utils.jl:36 [inlined]
[31] #6
@ /mnt/software/lux/Reactant.jl/src/Compiler.jl:270 [inlined]
[32] block!(f::Reactant.Compiler.var"#6#11"{Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, Vector{Any}}, blk::Reactant.MLIR.IR.Block)
@ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201
[33] #5
@ /mnt/software/lux/Reactant.jl/src/Compiler.jl:269 [inlined]
[34] mmodule!(f::Reactant.Compiler.var"#5#10"{Reactant.MLIR.IR.Module, Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, Vector{Any}}, blk::Reactant.MLIR.IR.Module)
@ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Module.jl:93
[35] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, args::Vector{Any}; optimize::Bool)
@ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:266
[36] compile_mlir!
@ /mnt/software/lux/Reactant.jl/src/Compiler.jl:265 [inlined]
[37] #2
@ /mnt/software/lux/Reactant.jl/src/Compiler.jl:260 [inlined]
[38] context!(f::Reactant.Compiler.var"#2#3"{@Kwargs{optimize::Bool}, Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, Vector{Any}}, ctx::Reactant.MLIR.IR.Context)
@ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Context.jl:76
[39] #compile_mlir#1
@ /mnt/software/lux/Reactant.jl/src/Compiler.jl:258 [inlined]
[40] top-level scope
@ /mnt/software/lux/Reactant.jl/src/Compiler.jl:419
[41] eval
@ ./boot.jl:430 [inlined]
[42] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
@ Base ./loading.jl:2643
[43] invokelatest(::Any, ::Any, ::Vararg{Any}; kwargs::@Kwargs{})
@ Base ./essentials.jl:1055
[44] invokelatest(::Any, ::Any, ::Vararg{Any})
@ Base ./essentials.jl:1052
[45] inlineeval(m::Module, code::String, code_line::Int64, code_column::Int64, file::String; softscope::Bool)
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:271
[46] (::VSCodeServer.var"#69#74"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})()
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:181
[47] withpath(f::VSCodeServer.var"#69#74"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams}, path::String)
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/repl.jl:276
[48] (::VSCodeServer.var"#68#73"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})()
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:179
[49] hideprompt(f::VSCodeServer.var"#68#73"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/repl.jl:38
[50] #67
@ ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:150 [inlined]
[51] with_logstate(f::VSCodeServer.var"#67#72"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams}, logstate::Base.CoreLogging.LogState)
@ Base.CoreLogging ./logging/logging.jl:522
[52] with_logger
@ ./logging/logging.jl:632 [inlined]
[53] (::VSCodeServer.var"#66#71"{VSCodeServer.ReplRunCodeRequestParams})()
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:263
[54] #invokelatest#2
@ ./essentials.jl:1055 [inlined]
[55] invokelatest(::Any)
@ Base ./essentials.jl:1052
in expression starting at /mnt/software/lux/Reactant.jl/envs/lux/rnn.jl:7 |
The other, probably nicer, way is to just write the
@avik-pal From the error log above, it looks like we don't override the in-place `mul!` method (and we should). Specifically:
|
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The issue originates from the following function, which moves the data to a `ConcreteRArray` instead of a `TracedRArray` when run inside a compilation context. I could use
`@reactant_override`
to define custom dispatches when used inside `Reactant.compile`, but I am not sure that should be the recommended usage.
The text was updated successfully, but these errors were encountered: