Skip to content

Commit

Permalink
feat: generalize OperatorEnum to additional degree
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Jul 20, 2024
1 parent 068b8eb commit d57645c
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 21 deletions.
6 changes: 4 additions & 2 deletions src/Evaluate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,10 @@ function eval_tree_array(
return eval_tree_array(tree, cX, operators; turbo, bumper)
end

get_nuna(::Type{<:OperatorEnum{B,U}}) where {B,U} = counttuple(U)
get_nbin(::Type{<:OperatorEnum{B}}) where {B} = counttuple(B)
#! format: off
get_nuna(::Type{<:OperatorEnum{<:Tuple{Tuple{Vararg{Any,U}},Vararg}}}) where {U} = U
get_nbin(::Type{<:OperatorEnum{<:Tuple{Tuple{Vararg},Tuple{Vararg{Any,B}},Vararg}}}) where {B} = B
#! format: on

function _eval_tree_array(
tree::AbstractExpressionNode{T},
Expand Down
32 changes: 19 additions & 13 deletions src/OperatorEnum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,35 @@ module OperatorEnumModule
abstract type AbstractOperatorEnum end

"""
OperatorEnum
OperatorEnum{T}
Defines an enum over operators, along with their derivatives.
# Fields
- `binops`: A tuple of binary operators. Scalar input type.
- `unaops`: A tuple of unary operators. Scalar input type.
`.ops` is a tuple of operators, where `.ops[1]` is the unary
operators and `.ops[2]` is the binary operators, and so on.
"""
struct OperatorEnum{B,U} <: AbstractOperatorEnum
binops::B
unaops::U
struct OperatorEnum{T<:Tuple} <: AbstractOperatorEnum
ops::T
end

"""
GenericOperatorEnum
Defines an enum over operators, along with their derivatives.
# Fields
- `binops`: A tuple of binary operators.
- `unaops`: A tuple of unary operators.
This is equivalent to [`OperatorEnum`](@ref), but dispatches
to generic evaluation for non-numeric types.
"""
struct GenericOperatorEnum{B,U} <: AbstractOperatorEnum
binops::B
unaops::U
struct GenericOperatorEnum{T<:Tuple} <: AbstractOperatorEnum
ops::T
end

function Base.getproperty(op::AbstractOperatorEnum, name::Symbol)
if name == :unaops
return getfield(op, :ops)[1]
elseif name == :binops
return getfield(op, :ops)[2]
else
return getfield(op, name)
end
end

Base.copy(op::AbstractOperatorEnum) = op
Expand Down
14 changes: 8 additions & 6 deletions src/OperatorEnumConstruction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ redefine operators for `AbstractExpressionNode` types, as well as `show`, `print
@warn "Using `BroadcastFunction` in an `OperatorEnum` is not yet stable"
end

operators = OperatorEnum(Tuple(binary_operators), Tuple(unary_operators))
operators = OperatorEnum((Tuple(unary_operators), Tuple(binary_operators)))

if define_helper_functions
@extend_operators_base operators empty_old_operators = empty_old_operators
Expand Down Expand Up @@ -498,7 +498,7 @@ and `(::AbstractExpressionNode)(X)`.
)
@assert length(binary_operators) > 0 || length(unary_operators) > 0

operators = GenericOperatorEnum(Tuple(binary_operators), Tuple(unary_operators))
operators = GenericOperatorEnum((Tuple(unary_operators), Tuple(binary_operators)))

if define_helper_functions
@extend_operators_base operators empty_old_operators = empty_old_operators
Expand All @@ -513,12 +513,14 @@ end
function _overload_common_operators()
# Overload the operators in batches (so that we don't hit the warning
# about too many operators)
operators = OperatorEnum(
(+, -, *, /, ^, max, min, mod),
operators = OperatorEnum((
(sin, cos, tan, exp, log, log1p, log2, log10, sqrt, cbrt, abs, sinh),
)
(+, -, *, /, ^, max, min, mod),
))
@extend_operators(operators, empty_old_operators = false, internal = true)
operators = OperatorEnum((), (cosh, tanh, atan, asinh, acosh, round, sign, floor, ceil))
operators = OperatorEnum((
(cosh, tanh, atan, asinh, acosh, round, sign, floor, ceil), ()
))
@extend_operators(operators, empty_old_operators = true, internal = true)

empty!(LATEST_UNARY_OPERATOR_MAPPING)
Expand Down

0 comments on commit d57645c

Please sign in to comment.