Add parameter to disable early exit of expression evaluation #91

Merged · Jul 28, 2024 · 44 commits

Commits
8df4bfe
add early_exit argument
nmheim Jun 30, 2024
8f1d2bd
make tests pass; add first test for early_exit
nmheim Jul 1, 2024
5cbe6a5
bumper & loopvec
nmheim Jul 3, 2024
82c4224
format
nmheim Jul 3, 2024
984a1ba
Merge branch 'SymbolicML:master' into nh/early-exit
nmheim Jul 3, 2024
1ee6884
introduce EvaluationOptions
nmheim Jul 8, 2024
86f8316
Merge branch 'master' into pr/nmheim/91
MilesCranmer Jul 19, 2024
6cab047
style: formatting
MilesCranmer Jul 19, 2024
6d46df9
style: more formatting
MilesCranmer Jul 19, 2024
b8f1087
style: clean up redundant options
MilesCranmer Jul 19, 2024
ee7d7c1
style: rename to `eval_options`
MilesCranmer Jul 19, 2024
c0b5a46
fix: merge edits to eval options
MilesCranmer Jul 19, 2024
17ae595
fix: fix generic eval errors
MilesCranmer Jul 19, 2024
2b98acf
refactor: test_evaluation.jl
MilesCranmer Jul 19, 2024
5399625
fix: specific branch calls
MilesCranmer Jul 19, 2024
204a9df
fix: `v_throw_errors` typo
MilesCranmer Jul 19, 2024
2fc5e87
fix: error catching for generic eval
MilesCranmer Jul 19, 2024
660d6f8
style: rename `EvaluationOptions` to `EvalOptions`
MilesCranmer Jul 19, 2024
dd24df6
test: fix initial errors test
MilesCranmer Jul 19, 2024
6302012
fix: type unstalbe tests
nmheim Jul 22, 2024
a73a04f
add doc strings
nmheim Jul 22, 2024
fe30e8b
update docs
nmheim Jul 22, 2024
944b2e8
format
nmheim Jul 22, 2024
905f5e1
approx equal
nmheim Jul 23, 2024
0a2bb96
fix enzyme test
nmheim Jul 24, 2024
6bd504b
Update docs/src/eval.md
nmheim Jul 25, 2024
54c6398
test: disable enzyme test
MilesCranmer Jul 24, 2024
958b9af
test: skip Enzyme test completely
MilesCranmer Jul 25, 2024
1365a55
test: fix Enzyme test
MilesCranmer Jul 26, 2024
cbcd221
style: fix formatting
MilesCranmer Jul 26, 2024
87c6225
ci: install fixed Enzyme
MilesCranmer Jul 26, 2024
86d2096
fix issue due to https://github.com/JuliaLang/Pkg.jl/issues/1585
MilesCranmer Jul 26, 2024
796cbae
fix custom enzyme install
MilesCranmer Jul 26, 2024
8a1ce63
Merge branch 'master' into nh/early-exit
MilesCranmer Jul 27, 2024
5a30d47
ci: remove Enzyme revision test
MilesCranmer Jul 27, 2024
64c797c
refactor: reduce code complexity of eval options
MilesCranmer Jul 27, 2024
a87016e
docs: render `EvalOptions` in docs
MilesCranmer Jul 27, 2024
ace5c19
test: more coverage of `EvalOptions` branches
MilesCranmer Jul 28, 2024
2c34b4e
test: prevent soft scope problem
MilesCranmer Jul 28, 2024
3113499
refactor: clean up evaluation
MilesCranmer Jul 28, 2024
586baf8
benchmarks: fix benchmark eval options
MilesCranmer Jul 28, 2024
3d7b529
fix: note instability in kw deprecation
MilesCranmer Jul 28, 2024
09b7a3d
feat: also include `early_exit` in scalar checks
MilesCranmer Jul 28, 2024
17a4a24
fix: incorporate `@return_on_nonfinite_val` in LoopVectorization exte…
MilesCranmer Jul 28, 2024
9 changes: 8 additions & 1 deletion benchmark/benchmarks.jl
@@ -74,11 +74,18 @@ function benchmark_evaluation()
         extra_kws...
     )
     suite[T]["evaluation$(extra_key)"] = @benchmarkable(
-        [eval_tree_array(tree, X, $operators; turbo=$turbo, $extra_kws...) for tree in trees],
+        [eval_tree_array(tree, X, $operators; kws...) for tree in trees],
         setup=(
             X=randn(MersenneTwister(0), $T, 5, $n);
             treesize=20;
             ntrees=100;
+            kws=$(
+                if @isdefined(EvalOptions)
+                    (; eval_options=EvalOptions(; turbo=turbo, extra_kws...))
+                else
+                    (; turbo, extra_kws...)
+                end
+            );
             trees=[gen_random_tree_fixed_size(treesize, $operators, 5, $T) for _ in 1:ntrees]
         )
     )
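The hunk above keeps the benchmark working across package versions by constructing the options struct only when `EvalOptions` exists. A standalone sketch of the same compatibility pattern (the operator set and tree here are our own illustration, not from the PR):

```julia
using DynamicExpressions
using Random: MersenneTwister

# Illustrative setup; the OperatorEnum constructor defines the operator
# overloads on Node by default, so trees can be built with ordinary syntax.
operators = OperatorEnum(; binary_operators=(+, *, -), unary_operators=(cos,))
x1 = Node{Float64}(; feature=1)
tree = cos(x1 * 0.5) + x1

X = randn(MersenneTwister(0), Float64, 1, 100)

# Pick whichever keyword form the installed version supports:
kws = if @isdefined(EvalOptions)
    (; eval_options=EvalOptions(; turbo=false))  # new API from this PR
else
    (; turbo=false)                              # legacy keyword
end
y, ok = eval_tree_array(tree, X, operators; kws...)
```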
23 changes: 17 additions & 6 deletions docs/src/eval.md
@@ -6,14 +6,18 @@ Given an expression tree specified with a `Node` type, you may evaluate the expression
 over an array of data with the following command:
 
 ```@docs
-eval_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum) where {T<:Number}
+eval_tree_array(
+    tree::AbstractExpressionNode{T},
+    cX::AbstractMatrix{T},
+    operators::OperatorEnum;
+    eval_options::Union{EvalOptions,Nothing}=nothing,
+) where {T}
 ```
 
-Assuming you are only using a single `OperatorEnum`, you can also use
-the following shorthand by using the expression as a function:
+You can also use the following shorthand by using the expression as a function:
 
 ```
-(tree::AbstractExpressionNode)(X::AbstractMatrix{T}, operators::OperatorEnum; turbo::Union{Bool,Val}=false, bumper::Union{Bool,Val}=Val(false))
+(tree::AbstractExpressionNode)(X, operators::OperatorEnum; kws...)
 
 Evaluate a binary tree (equation) over a given input data matrix. The
 operators contain all of the operators used. This function fuses doublets
@@ -23,8 +27,7 @@ and triplets of operations for lower memory usage.
 - `tree::AbstractExpressionNode`: The root node of the tree to evaluate.
 - `cX::AbstractMatrix{T}`: The input data to evaluate the tree on.
 - `operators::OperatorEnum`: The operators used in the tree.
-- `turbo::Union{Bool,Val}`: Use LoopVectorization.jl for faster evaluation.
-- `bumper::Union{Bool,Val}`: Use Bumper.jl for faster evaluation.
+- `kws...`: Passed to [`eval_tree_array`](@ref).
 
 # Returns
 - `output::AbstractVector{T}`: the result, which is a 1D array.
@@ -53,6 +56,14 @@ It also re-defines `print`, `show`, and the various operators, to work with the
 Thus, if you define an expression with one `OperatorEnum`, and then try to
 evaluate it or print it with a different `OperatorEnum`, you will get undefined behavior!
 
+For safer behavior, you should use [`Expression`](@ref) objects.
+
+Evaluation options are specified using `EvalOptions`:
+
+```@docs
+EvalOptions
+```
+
 You can also work with arbitrary types, by defining a `GenericOperatorEnum` instead.
 The notation is the same for `eval_tree_array`, though it will return `nothing`
 when it can't find a method, and not do any NaN checks:
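The example that originally followed this sentence is truncated in the diff view. As our own hedged sketch (not the file's actual example), `GenericOperatorEnum` evaluation over a non-numeric type might look like:

```julia
using DynamicExpressions

# Arbitrary (non-numeric) element type: strings, with `*` as concatenation.
operators = GenericOperatorEnum(; binary_operators=(*,))
@extend_operators operators

x1 = Node{String}(; feature=1)
tree = x1 * x1

X = ["ab" "cd"]  # 1 feature × 2 samples
out, ok = eval_tree_array(tree, X, operators)
# If an operator has no matching method for the element type, `out` is
# `nothing` and `ok` is `false`; no NaN checks are performed either way.
```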
44 changes: 25 additions & 19 deletions ext/DynamicExpressionsBumperExt.jl
@@ -2,7 +2,7 @@ module DynamicExpressionsBumperExt
 
 using Bumper: @no_escape, @alloc
 using DynamicExpressions:
-    OperatorEnum, AbstractExpressionNode, tree_mapreduce, is_valid_array
+    OperatorEnum, AbstractExpressionNode, tree_mapreduce, is_valid_array, EvalOptions
 using DynamicExpressions.UtilsModule: ResultOk, counttuple
 
 import DynamicExpressions.ExtensionInterfaceModule:
@@ -12,8 +12,8 @@ function bumper_eval_tree_array(
     tree::AbstractExpressionNode{T},
     cX::AbstractMatrix{T},
     operators::OperatorEnum,
-    ::Val{turbo},
-) where {T,turbo}
+    eval_options::EvalOptions{turbo,true,early_exit},
+) where {T,turbo,early_exit}
     result = similar(cX, axes(cX, 2))
     n = size(cX, 2)
     all_ok = Ref(false)
@@ -26,7 +26,7 @@ function bumper_eval_tree_array(
             ok = if leaf_node.constant
                 v = leaf_node.val
                 ar .= v
-                isfinite(v)
+                early_exit ? isfinite(v) : true
             else
                 ar .= view(cX, leaf_node.feature, :)
                 true
@@ -38,7 +38,7 @@ function bumper_eval_tree_array(
             # In the evaluation kernel, we combine the branch nodes
             # with the arrays created by the leaf nodes:
             ((args::Vararg{Any,M}) where {M}) ->
-                dispatch_kerns!(operators, args..., Val(turbo)),
+                dispatch_kerns!(operators, args..., eval_options),
             tree;
             break_sharing=Val(true),
         )
@@ -49,55 +49,61 @@
     return (result, all_ok[])
 end
 
-function dispatch_kerns!(operators, branch_node, cumulator, ::Val{turbo}) where {turbo}
+function dispatch_kerns!(
+    operators, branch_node, cumulator, eval_options::EvalOptions{<:Any,true,early_exit}
+) where {early_exit}
     cumulator.ok || return cumulator
 
-    out = dispatch_kern1!(operators.unaops, branch_node.op, cumulator.x, Val(turbo))
-    return ResultOk(out, is_valid_array(out))
+    out = dispatch_kern1!(operators.unaops, branch_node.op, cumulator.x, eval_options)
+    return ResultOk(out, early_exit ? is_valid_array(out) : true)
 end
 function dispatch_kerns!(
-    operators, branch_node, cumulator1, cumulator2, ::Val{turbo}
-) where {turbo}
+    operators,
+    branch_node,
+    cumulator1,
+    cumulator2,
+    eval_options::EvalOptions{<:Any,true,early_exit},
+) where {early_exit}
     cumulator1.ok || return cumulator1
     cumulator2.ok || return cumulator2
 
     out = dispatch_kern2!(
-        operators.binops, branch_node.op, cumulator1.x, cumulator2.x, Val(turbo)
+        operators.binops, branch_node.op, cumulator1.x, cumulator2.x, eval_options
     )
-    return ResultOk(out, is_valid_array(out))
+    return ResultOk(out, early_exit ? is_valid_array(out) : true)
 end
 
-@generated function dispatch_kern1!(unaops, op_idx, cumulator, ::Val{turbo}) where {turbo}
+@generated function dispatch_kern1!(unaops, op_idx, cumulator, eval_options::EvalOptions)
     nuna = counttuple(unaops)
     quote
         Base.@nif(
             $nuna,
             i -> i == op_idx,
             i -> let op = unaops[i]
-                return bumper_kern1!(op, cumulator, Val(turbo))
+                return bumper_kern1!(op, cumulator, eval_options)
             end,
         )
     end
 end
 @generated function dispatch_kern2!(
-    binops, op_idx, cumulator1, cumulator2, ::Val{turbo}
-) where {turbo}
+    binops, op_idx, cumulator1, cumulator2, eval_options::EvalOptions
+)
     nbin = counttuple(binops)
     quote
         Base.@nif(
             $nbin,
             i -> i == op_idx,
             i -> let op = binops[i]
-                return bumper_kern2!(op, cumulator1, cumulator2, Val(turbo))
+                return bumper_kern2!(op, cumulator1, cumulator2, eval_options)
             end,
        )
    end
 end
-function bumper_kern1!(op::F, cumulator, ::Val{false}) where {F}
+function bumper_kern1!(op::F, cumulator, ::EvalOptions{false,true}) where {F}
     @. cumulator = op(cumulator)
     return cumulator
 end
-function bumper_kern2!(op::F, cumulator1, cumulator2, ::Val{false}) where {F}
+function bumper_kern2!(op::F, cumulator1, cumulator2, ::EvalOptions{false,true}) where {F}
     @. cumulator1 = op(cumulator1, cumulator2)
     return cumulator1
 end
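The kernels above carry all three options as type parameters of `EvalOptions{turbo,bumper,early_exit}`, so each combination compiles to its own specialized method. A self-contained toy reduction of that dispatch pattern (the names here are ours, not the package's):

```julia
# Options encoded as type parameters, mirroring EvalOptions{turbo,bumper,early_exit}.
struct ToyOptions{turbo,bumper,early_exit} end
function ToyOptions(; turbo=false, bumper=false, early_exit=true)
    return ToyOptions{turbo,bumper,early_exit}()
end

# Kernel specialized on the non-turbo path, like `bumper_kern1!` above:
kern!(op::F, buf, ::ToyOptions{false}) where {F} = (buf .= op.(buf); buf)

# The validity check is skipped entirely when early_exit=false, like
# `early_exit ? is_valid_array(out) : true` in the diff:
check(out, ::ToyOptions{<:Any,<:Any,true}) = all(isfinite, out)
check(out, ::ToyOptions{<:Any,<:Any,false}) = true

buf = [1.0, -2.0, Inf]
kern!(abs, buf, ToyOptions())
check(buf, ToyOptions())                    # false: the Inf fails the check
check(buf, ToyOptions(; early_exit=false))  # true: the check is disabled
```

Because the flags live in the type, the compiler can eliminate the disabled branch entirely instead of testing a runtime boolean in every kernel call.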
35 changes: 25 additions & 10 deletions ext/DynamicExpressionsLoopVectorizationExt.jl
@@ -3,7 +3,7 @@ module DynamicExpressionsLoopVectorizationExt
 using LoopVectorization: @turbo
 using DynamicExpressions: AbstractExpressionNode
 using DynamicExpressions.UtilsModule: ResultOk, fill_similar
-using DynamicExpressions.EvaluateModule: @return_on_check
+using DynamicExpressions.EvaluateModule: @return_on_check, EvalOptions
 import DynamicExpressions.EvaluateModule:
     deg1_eval,
     deg2_eval,
@@ -18,7 +18,10 @@ import DynamicExpressions.ExtensionInterfaceModule:
 _is_loopvectorization_loaded(::Int) = true
 
 function deg2_eval(
-    cumulator_l::AbstractVector{T}, cumulator_r::AbstractVector{T}, op::F, ::Val{true}
+    cumulator_l::AbstractVector{T},
+    cumulator_r::AbstractVector{T},
+    op::F,
+    ::EvalOptions{true},
 )::ResultOk where {T<:Number,F}
     @turbo for j in eachindex(cumulator_l)
         x = op(cumulator_l[j], cumulator_r[j])
@@ -28,7 +31,7 @@ function deg2_eval(
 end
 
 function deg1_eval(
-    cumulator::AbstractVector{T}, op::F, ::Val{true}
+    cumulator::AbstractVector{T}, op::F, ::EvalOptions{true}
 )::ResultOk where {T<:Number,F}
     @turbo for j in eachindex(cumulator)
         x = op(cumulator[j])
@@ -38,7 +41,11 @@
 end
 
 function deg1_l2_ll0_lr0_eval(
-    tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{true}
+    tree::AbstractExpressionNode{T},
+    cX::AbstractMatrix{T},
+    op::F,
+    op_l::F2,
+    ::EvalOptions{true},
 ) where {T<:Number,F,F2}
     if tree.l.l.constant && tree.l.r.constant
         val_ll = tree.l.l.val
@@ -86,7 +93,11 @@ function deg1_l2_ll0_lr0_eval(
 end
 
 function deg1_l1_ll0_eval(
-    tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{true}
+    tree::AbstractExpressionNode{T},
+    cX::AbstractMatrix{T},
+    op::F,
+    op_l::F2,
+    ::EvalOptions{true},
 ) where {T<:Number,F,F2}
     if tree.l.l.constant
         val_ll = tree.l.l.val
@@ -109,7 +120,7 @@
 end
 
 function deg2_l0_r0_eval(
-    tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, ::Val{true}
+    tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, ::EvalOptions{true}
 ) where {T<:Number,F}
     if tree.l.constant && tree.r.constant
         val_l = tree.l.val
@@ -157,7 +168,7 @@ function deg2_l0_eval(
     cumulator::AbstractVector{T},
     cX::AbstractArray{T},
     op::F,
-    ::Val{true},
+    ::EvalOptions{true},
 ) where {T<:Number,F}
     if tree.l.constant
         val = tree.l.val
@@ -182,7 +193,7 @@ function deg2_r0_eval(
     cumulator::AbstractVector{T},
     cX::AbstractArray{T},
     op::F,
-    ::Val{true},
+    ::EvalOptions{true},
 ) where {T<:Number,F}
     if tree.r.constant
         val = tree.r.val
@@ -203,11 +214,15 @@
 end
 
 ## Interface with Bumper.jl
-function bumper_kern1!(op::F, cumulator, ::Val{true}) where {F}
+function bumper_kern1!(
+    op::F, cumulator, ::EvalOptions{true,true,early_exit}
+) where {F,early_exit}
     @turbo @. cumulator = op(cumulator)
     return cumulator
 end
-function bumper_kern2!(op::F, cumulator1, cumulator2, ::Val{true}) where {F}
+function bumper_kern2!(
+    op::F, cumulator1, cumulator2, ::EvalOptions{true,true,early_exit}
+) where {F,early_exit}
     @turbo @. cumulator1 = op(cumulator1, cumulator2)
     return cumulator1
 end
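These `EvalOptions{true}` methods are only reachable once LoopVectorization.jl is loaded, which activates this package extension. A usage sketch (assuming a DynamicExpressions version that includes this PR's `EvalOptions`):

```julia
using DynamicExpressions
using LoopVectorization  # loading this package activates the extension above

operators = OperatorEnum(; binary_operators=(+, *), unary_operators=(cos,))
x1 = Node{Float64}(; feature=1)
tree = cos(x1) * x1 + x1

X = rand(Float64, 1, 1000)

# turbo=true routes evaluation through the @turbo kernels in this file:
y, ok = eval_tree_array(tree, X, operators; eval_options=EvalOptions(; turbo=true))
```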
3 changes: 2 additions & 1 deletion src/DynamicExpressions.jl
@@ -70,7 +70,8 @@ import .NodeModule:
 @reexport import .OperatorEnumModule: AbstractOperatorEnum
 @reexport import .OperatorEnumConstructionModule:
     OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names!
-@reexport import .EvaluateModule: eval_tree_array, differentiable_eval_tree_array
+@reexport import .EvaluateModule:
+    eval_tree_array, differentiable_eval_tree_array, EvalOptions
 @reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array
 @reexport import .ChainRulesModule: NodeTangent, extract_gradient
 @reexport import .SimplifyModule: combine_operators, simplify_tree!
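With `EvalOptions` re-exported from the top-level module, downstream code can construct it directly. A minimal sketch of the PR's headline feature, disabling the early exit (keyword names as introduced here; we avoid domain-erroring operators and let `0.0/0.0` produce the NaN):

```julia
using DynamicExpressions

operators = OperatorEnum(; binary_operators=(+, /))
x1 = Node{Float64}(; feature=1)
tree = x1 / x1 + x1

X = [0.0 2.0]  # 0.0/0.0 gives NaN for the first sample

# Default: evaluation short-circuits on the NaN and reports ok=false,
# so the output array contents should not be trusted.
y1, ok1 = eval_tree_array(tree, X, operators)

# With the early exit disabled, evaluation runs to completion and the NaN
# is left in the output for the caller to handle.
y2, ok2 = eval_tree_array(
    tree, X, operators; eval_options=EvalOptions(; early_exit=false)
)
```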