diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 88999e62..c1c764d0 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -113,8 +113,10 @@ function eval_tree_array( return eval_tree_array(tree, cX, operators; turbo, bumper) end -get_nuna(::Type{<:OperatorEnum{B,U}}) where {B,U} = counttuple(U) -get_nbin(::Type{<:OperatorEnum{B}}) where {B} = counttuple(B) +#! format: off +get_nuna(::Type{<:OperatorEnum{<:Tuple{Tuple{Vararg{Any,U}},Vararg}}}) where {U} = U +get_nbin(::Type{<:OperatorEnum{<:Tuple{Tuple{Vararg},Tuple{Vararg{Any,B}},Vararg}}}) where {B} = B +#! format: on function _eval_tree_array( tree::AbstractExpressionNode{T}, diff --git a/src/OperatorEnum.jl b/src/OperatorEnum.jl index 6da6779c..5806bb99 100644 --- a/src/OperatorEnum.jl +++ b/src/OperatorEnum.jl @@ -3,29 +3,35 @@ module OperatorEnumModule abstract type AbstractOperatorEnum end """ - OperatorEnum + OperatorEnum{T} Defines an enum over operators, along with their derivatives. -# Fields -- `binops`: A tuple of binary operators. Scalar input type. -- `unaops`: A tuple of unary operators. Scalar input type. +`.ops` is a tuple of operators, where `.ops[1]` is the unary +operators and `.ops[2]` is the binary operators, and so on. """ -struct OperatorEnum{B,U} <: AbstractOperatorEnum - binops::B - unaops::U +struct OperatorEnum{T<:Tuple} <: AbstractOperatorEnum + ops::T end """ GenericOperatorEnum Defines an enum over operators, along with their derivatives. -# Fields -- `binops`: A tuple of binary operators. -- `unaops`: A tuple of unary operators. +This is equivalent to [`OperatorEnum`](@ref), but dispatches +to generic evaluation for non-numeric types. """ -struct GenericOperatorEnum{B,U} <: AbstractOperatorEnum - binops::B - unaops::U +struct GenericOperatorEnum{T<:Tuple} <: AbstractOperatorEnum + ops::T +end + +function Base.getproperty(op::AbstractOperatorEnum, name::Symbol) + if name == :unaops + return getfield(op, :ops)[1] + elseif name == :binops + return getfield(op, :ops)[2] + else + return getfield(op, name) + end end Base.copy(op::AbstractOperatorEnum) = op diff --git a/src/OperatorEnumConstruction.jl b/src/OperatorEnumConstruction.jl index c5c557af..68d24e7a 100644 --- a/src/OperatorEnumConstruction.jl +++ b/src/OperatorEnumConstruction.jl @@ -461,7 +461,7 @@ redefine operators for `AbstractExpressionNode` types, as well as `show`, `print @warn "Using `BroadcastFunction` in an `OperatorEnum` is not yet stable" end - operators = OperatorEnum(Tuple(binary_operators), Tuple(unary_operators)) + operators = OperatorEnum((Tuple(unary_operators), Tuple(binary_operators))) if define_helper_functions @extend_operators_base operators empty_old_operators = empty_old_operators @@ -498,7 +498,7 @@ and `(::AbstractExpressionNode)(X)`. ) @assert length(binary_operators) > 0 || length(unary_operators) > 0 - operators = GenericOperatorEnum(Tuple(binary_operators), Tuple(unary_operators)) + operators = GenericOperatorEnum((Tuple(unary_operators), Tuple(binary_operators))) if define_helper_functions @extend_operators_base operators empty_old_operators = empty_old_operators @@ -513,12 +513,14 @@ end function _overload_common_operators() # Overload the operators in batches (so that we don't hit the warning # about too many operators) - operators = OperatorEnum( - (+, -, *, /, ^, max, min, mod), + operators = OperatorEnum(( (sin, cos, tan, exp, log, log1p, log2, log10, sqrt, cbrt, abs, sinh), - ) + (+, -, *, /, ^, max, min, mod), + )) @extend_operators(operators, empty_old_operators = false, internal = true) - operators = OperatorEnum((), (cosh, tanh, atan, asinh, acosh, round, sign, floor, ceil)) + operators = OperatorEnum(( + (cosh, tanh, atan, asinh, acosh, round, sign, floor, ceil), () + )) @extend_operators(operators, empty_old_operators = true, internal = true) empty!(LATEST_UNARY_OPERATOR_MAPPING)