Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid closure in tree_mapreduce #103

Merged
merged 10 commits into from
Oct 12, 2024
4 changes: 1 addition & 3 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@ jobs:
- macOS-latest
include:
- os: ubuntu-latest
julia-version: '1.7'
- os: ubuntu-latest
julia-version: '1.6'
julia-version: '1.10'

steps:
- uses: actions/checkout@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/benchmark_pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
with:
version: "1.9"
version: "1"
- uses: julia-actions/cache@v1
- name: Extract Package Name from Project.toml
id: extract-package-name
Expand Down
8 changes: 2 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
name = "DynamicExpressions"
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
authors = ["MilesCranmer <[email protected]>"]
version = "1.0.1"
version = "1.1.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
Interfaces = "85a1e053-f937-4924-92a5-1367d23b7b87"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand All @@ -32,18 +30,16 @@ DynamicExpressionsZygoteExt = "Zygote"
[compat]
Bumper = "0.6"
ChainRulesCore = "1"
Compat = "3.37, 4"
DispatchDoctor = "0.4"
Interfaces = "0.3"
LoopVectorization = "0.12"
MacroTools = "0.4, 0.5"
Optim = "0.19, 1"
PackageExtensionCompat = "1"
PrecompileTools = "1"
Reexport = "1"
SymbolicUtils = "0.19, ^1.0.5, 2, 3"
Zygote = "0.6"
julia = "1.6"
julia = "1.10"

[extras]
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
Expand Down
1 change: 0 additions & 1 deletion ext/DynamicExpressionsOptimExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ using DynamicExpressions:
get_scalar_constants,
set_scalar_constants!,
get_number_type
using Compat: @inline

import Optim: Optim, OptimizationResults, NLSolversBase

Expand Down
5 changes: 0 additions & 5 deletions src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ using DispatchDoctor: @stable, @unstable
include("StructuredExpression.jl")
end

import PackageExtensionCompat: @require_extensions
import Reexport: @reexport
macro ignore(args...) end

Expand Down Expand Up @@ -104,10 +103,6 @@ end
import .InterfacesModule:
ExpressionInterface, NodeInterface, all_ei_methods_except, all_ni_methods_except

function __init__()
@require_extensions
end

include("deprecated.jl")

import TOML: parsefile
Expand Down
2 changes: 1 addition & 1 deletion src/Node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module NodeModule
using DispatchDoctor: @unstable

import ..OperatorEnumModule: AbstractOperatorEnum
import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap, Undefined
import ..UtilsModule: deprecate_varmap, Undefined

const DEFAULT_NODE_TYPE = Float32

Expand Down
1 change: 0 additions & 1 deletion src/NodeUtils.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
module NodeUtilsModule

import Compat: Returns
import ..NodeModule:
AbstractNode,
AbstractExpressionNode,
Expand Down
1 change: 0 additions & 1 deletion src/Random.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
module RandomModule

using Compat: Returns, @inline
using Random: AbstractRNG
using ..NodeModule: AbstractNode, tree_mapreduce, filter_map
using ..ExpressionModule: AbstractExpression, get_tree
Expand Down
97 changes: 0 additions & 97 deletions src/Utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,103 +13,6 @@ macro return_on_false2(flag, retval, retval2)
)
end

"""
@memoize_on tree [postprocess] function my_function_on_tree(tree::AbstractExpressionNode)
...
end

This macro takes a function definition and creates a second version of the
function with an additional `id_map` argument. When passed this argument (an
IdDict()), it will use use the `id_map` to avoid recomputing the same value
for the same node in a tree. Use this to automatically create functions that
work with trees that have shared child nodes.

Can optionally take a `postprocess` function, which will be applied to the
result of the function before returning it, taking the result as the
first argument and a boolean for whether the result was memoized as the
second argument. This is useful for functions that need to count the number
of unique nodes in a tree, for example.
"""
macro memoize_on(tree, args...)
if length(args) ∉ (1, 2)
error("Expected 2 or 3 arguments to @memoize_on")
end
postprocess = length(args) == 1 ? :((r, _) -> r) : args[1]
def = length(args) == 1 ? args[1] : args[2]
idmap_def = _memoize_on(tree, postprocess, def)

return quote
$(esc(def)) # The normal function
$(esc(idmap_def)) # The function with an id_map argument
end
end
function _memoize_on(tree::Symbol, postprocess, def)
sdef = splitdef(def)

# Add an id_map argument
push!(sdef[:args], :(id_map::AbstractDict))

f_name = sdef[:name]

# Forward id_map argument to all calls of the same function
# within the function body:
sdef[:body] = postwalk(sdef[:body]) do ex
if @capture(ex, f_(args__))
if f == f_name
return Expr(:call, f, args..., :id_map)
end
end
return ex
end

# Wrap the function body in a get!(id_map, tree) do ... end block:
@gensym key is_memoized result body
sdef[:body] = quote
$key = objectid($tree)
$is_memoized = haskey(id_map, $key)
function $body()
return $(sdef[:body])
end
$result = if $is_memoized
@inbounds(id_map[$key])
else
id_map[$key] = $body()
end
return $postprocess($result, $is_memoized)
end

return combinedef(sdef)
end

"""
@with_memoize(call, id_map)

This simple macro simply puts the `id_map`
into the call, to be consistent with the `@memoize_on` macro.

```
@with_memoize(_copy_node(tree), IdDict{Any,Any}())
````

is converted to

```
_copy_node(tree, IdDict{Any,Any}())
```

"""
macro with_memoize(def, id_map)
idmap_def = _add_idmap_to_call(def, id_map)
return quote
$(esc(idmap_def))
end
end

function _add_idmap_to_call(def::Expr, id_map::Union{Symbol,Expr})
@assert def.head == :call
return Expr(:call, def.args[1], def.args[2:end]..., id_map)
end

@inline function fill_similar(value::T, array, args...) where {T}
out_array = similar(array, args...)
fill!(out_array, value)
Expand Down
64 changes: 44 additions & 20 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ import Base:
sum

using DispatchDoctor: @unstable
using Compat: @inline, Returns
using ..UtilsModule: @memoize_on, @with_memoize, Undefined
using ..UtilsModule: Undefined

"""
tree_mapreduce(
Expand Down Expand Up @@ -94,41 +93,66 @@ function tree_mapreduce(
f_on_shared::H=(result, is_shared) -> result,
break_sharing::Val{BS}=Val(false),
) where {F1<:Function,F2<:Function,G<:Function,H<:Function,RT,BS}

# Trick taken from here:
# https://discourse.julialang.org/t/recursive-inner-functions-a-thousand-times-slower/85604/5
# to speed up recursive closure
@memoize_on t f_on_shared function inner(inner, t)
if t.degree == 0
return @inline(f_leaf(t))
elseif t.degree == 1
return @inline(op(@inline(f_branch(t)), inner(inner, t.l)))
else
return @inline(op(@inline(f_branch(t)), inner(inner, t.l), inner(inner, t.r)))
end
end

sharing = preserve_sharing(typeof(tree)) && !BS

RT == Undefined &&
sharing &&
throw(ArgumentError("Need to specify `result_type` if nodes are shared.."))

if sharing && RT != Undefined
d = allocate_id_map(tree, RT)
return @with_memoize inner(inner, tree) d
id_map = allocate_id_map(tree, RT)
reducer = TreeMapreducer(Val(2), id_map, f_leaf, f_branch, op, f_on_shared)
return call_mapreducer(reducer, tree)
else
reducer = TreeMapreducer(Val(2), nothing, f_leaf, f_branch, op, f_on_shared)
return call_mapreducer(reducer, tree)
end
end

struct TreeMapreducer{
D,ID<:Union{Nothing,Dict},F1<:Function,F2<:Function,G<:Function,H<:Function
}
max_degree::Val{D}
id_map::ID
f_leaf::F1
f_branch::F2
op::G
f_on_shared::H
end

function call_mapreducer(mapreducer::TreeMapreducer{2,ID}, tree::AbstractNode) where {ID}
key = ID <: Dict ? objectid(tree) : nothing
if ID <: Dict && haskey(mapreducer.id_map, key)
result = @inbounds(mapreducer.id_map[key])
return mapreducer.f_on_shared(result, true)
else
return inner(inner, tree)
result = if tree.degree == 0
mapreducer.f_leaf(tree)
elseif tree.degree == 1
mapreducer.op(mapreducer.f_branch(tree), call_mapreducer(mapreducer, tree.l))
else
mapreducer.op(
mapreducer.f_branch(tree),
call_mapreducer(mapreducer, tree.l),
call_mapreducer(mapreducer, tree.r),
)
end
if ID <: Dict
mapreducer.id_map[key] = result
return mapreducer.f_on_shared(result, false)
else
return result
end
end
end

function allocate_id_map(tree::AbstractNode, ::Type{RT}) where {RT}
d = Dict{UInt,RT}()
# Preallocate maximum storage (counting with duplicates is fast)
N = length(tree; break_sharing=Val(true))
sizehint!(d, N)
return d
end

# TODO: Raise Julia issue for this.
# Surprisingly Dict{UInt,RT} is faster than IdDict{Node{T},RT} here!
# I think it's because `setindex!` is declared with `@nospecialize` in IdDict.
Expand Down
6 changes: 5 additions & 1 deletion test/test_derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ end
end

@testset "Test many operators" begin
using DispatchDoctor

# Since we use `@nif` in evaluating expressions,
# we can see if there are any issues with LARGE numbers of operators.
num_ops = 100
Expand Down Expand Up @@ -198,7 +200,9 @@ end
tree = gen_random_tree_fixed_size(20, only_basic_ops_operator, n_features, Float64)
X = randn(Float64, n_features, 10)
basic_eval = tree'(X, only_basic_ops_operator)
many_ops_eval = tree'(X, many_ops_operators)
many_ops_eval = allow_unstable() do
tree'(X, many_ops_operators)
end
@test (all(isnan, basic_eval) && all(isnan, many_ops_eval)) ||
basic_eval ≈ many_ops_eval
end
Expand Down
Loading
Loading