diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 5b5df1a8..f139fc03 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -81,7 +81,7 @@ function eval_tree_array( error("Please load the LoopVectorization.jl package to use this feature.") end if (v_turbo isa Val{true} || v_bumper isa Val{true}) && !(T <: Number) - error("Bumper feature only works with numbers") + error("Bumper and LoopVectorization features are only compatible with numeric element types") end if v_bumper isa Val{true} return bumper_eval_tree_array(tree, cX, operators, v_turbo) @@ -97,7 +97,7 @@ function eval_tree_array( operators::OperatorEnum; kws... ) where {T} - return eval_tree_array(tree, reshape(cX, (size(cX)[1], 1))::AbstractMatrix{T}, operators; kws...) + return eval_tree_array(tree, reshape(cX, (size(cX, 1), 1)), operators; kws...) end function eval_tree_array( diff --git a/test/test_non_number_eval_tree_array.jl b/test/test_non_number_eval_tree_array.jl index f7ec071d..088dba02 100644 --- a/test/test_non_number_eval_tree_array.jl +++ b/test/test_non_number_eval_tree_array.jl @@ -8,7 +8,7 @@ struct SVM{T} scalar :: T vector :: Vector{T} matrix :: Matrix{T} - SVM{T}() where {T} = new(Int8(0), zero(T), T[], T[;;]) + SVM{T}() where {T} = new(Int8(0), zero(T), T[], Array{T}(undef, 0, 0)) SVM{T}(scalar :: W) where {T, W <: Number} = new(Int8(0), Base.convert(T, scalar), T[], T[;;]) SVM{T}(vector :: Vector{W}) where {T, W <: Number} = new(Int8(1), zero(T), Vector{T}(vector) , T[;;]) SVM{T}(matrix :: Matrix{W}) where {T, W <: Number} = new(Int8(2), zero(T), T[], Matrix{T}(matrix)) @@ -29,8 +29,7 @@ end function Base.:(==)(x::SVM{T}, y::SVM{T}) where T if x.dims !== y.dims return false - end - if x.dims == 0 + elseif x.dims == 0 return x.scalar == y.scalar elseif val.dims == 1 return x.vector == y.vector @@ -61,7 +60,7 @@ Base.invokelatest(() -> begin @test !hasmethod(a, Tuple{Node{SVM{Float32}}, Node{SVM{Float32}}}) tree = a(Node{SVM{Float64}}(; feature=1), SVM{Float64}(3.0)) - results = tree([SVM{Float64}(1.0);; SVM{Float64}(2.0);; SVM{Float64}(3.0)]) + results = tree([SVM{Float64}(1.0) SVM{Float64}(2.0) SVM{Float64}(3.0)]) @test results == [SVM{Float64}(4), SVM{Float64}(5), SVM{Float64}(6)]