Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Miles Cranmer <[email protected]>
  • Loading branch information
gca30 and MilesCranmer authored Jun 26, 2024
1 parent da1886f commit 8d2797e
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/Evaluate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ function eval_tree_array(
error("Please load the LoopVectorization.jl package to use this feature.")
end
if (v_turbo isa Val{true} || v_bumper isa Val{true}) && !(T <: Number)
error("Bumper feature only works with numbers")
error("Bumper and LoopVectorization features are only compatible with numeric element types")
end
if v_bumper isa Val{true}
return bumper_eval_tree_array(tree, cX, operators, v_turbo)
Expand All @@ -97,7 +97,7 @@ function eval_tree_array(
operators::OperatorEnum;
kws...
) where {T}
return eval_tree_array(tree, reshape(cX, (size(cX)[1], 1))::AbstractMatrix{T}, operators; kws...)
return eval_tree_array(tree, reshape(cX, (size(cX, 1), 1)), operators; kws...)
end

function eval_tree_array(
Expand Down
7 changes: 3 additions & 4 deletions test/test_non_number_eval_tree_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ struct SVM{T}
scalar :: T
vector :: Vector{T}
matrix :: Matrix{T}
SVM{T}() where {T} = new(Int8(0), zero(T), T[], T[;;])
SVM{T}() where {T} = new(Int8(0), zero(T), T[], Array{T}(undef, 0, 0))
SVM{T}(scalar :: W) where {T, W <: Number} = new(Int8(0), Base.convert(T, scalar), T[], T[;;])
SVM{T}(vector :: Vector{W}) where {T, W <: Number} = new(Int8(1), zero(T), Vector{T}(vector) , T[;;])
SVM{T}(matrix :: Matrix{W}) where {T, W <: Number} = new(Int8(2), zero(T), T[], Matrix{T}(matrix))
Expand All @@ -29,8 +29,7 @@ end
function Base.:(==)(x::SVM{T}, y::SVM{T}) where T
if x.dims !== y.dims
return false
end
if x.dims == 0
elseif x.dims == 0
return x.scalar == y.scalar
elseif val.dims == 1
return x.vector == y.vector
Expand Down Expand Up @@ -61,7 +60,7 @@ Base.invokelatest(() -> begin
@test !hasmethod(a, Tuple{Node{SVM{Float32}}, Node{SVM{Float32}}})

tree = a(Node{SVM{Float64}}(; feature=1), SVM{Float64}(3.0))
results = tree([SVM{Float64}(1.0);; SVM{Float64}(2.0);; SVM{Float64}(3.0)])
results = tree([SVM{Float64}(1.0) SVM{Float64}(2.0) SVM{Float64}(3.0)])
@test results == [SVM{Float64}(4), SVM{Float64}(5), SVM{Float64}(6)]


Expand Down

0 comments on commit 8d2797e

Please sign in to comment.