Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Functional alias to Base.Cartesian.@nif #55093

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
50 changes: 50 additions & 0 deletions base/cartesian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,56 @@ macro nif(N, condition, operation...)
ex
end

"""
nif(condition, expression, [else_expression,] ::Val{N}) where {N}

Generate a sequence of `if ... elseif ... else ... end` statements.

# Arguments
- `condition`: A function that takes an integer between `1` and `N-1` and
returns a boolean condition.
- `expression`: A function that takes an integer between `1` and `N` (or,
only up to `N-1`, if `else_expression` is provided) and is called if
the condition is true.
- `else_expression`: (optional) A function that takes `N` as input
returns an expression to be evaluated if all conditions are false.
- `N`: The number of conditions to check, passed as a `Val{N}` instance.

This function is similar to the `@nif` macro but can be used in cases
where `N` is not known at parse time.

# Examples

For example, here we find the first index of a positive element in a
fixed-size tuple using `nif`:

```jldoctest
julia> x = (0, -1, 1, 0)
(0, -1, 1, 0)

julia> Base.Cartesian.nif(d -> x[d] > 0, d -> d, Val(4))
3
```
"""
@inline function nif(condition::F, expression::G, ::Val{N}) where {F,G,N}
nif(condition, expression, expression, Val(N))
end
@inline function nif(condition::F, expression::G, else_expression::H, ::Val{N}) where {F,G,H,N}
n = N::Int # Can improve inference; see #54544
(n >= 0) || throw(ArgumentError(LazyString("if statement length should be ≥ 0, got ", n)))
if @generated
return :(@nif $N d -> condition(d) d -> expression(d) d -> else_expression(d))
MilesCranmer marked this conversation as resolved.
Show resolved Hide resolved
else
for d = 1:(n - 1)
if condition(d)
return expression(d)
end
end
return else_expression(n)
end
end
typeof(function nif end).name.max_methods = UInt8(2)

## Utilities

# Simplify expressions like :(d->3:size(A,d)-3) given an explicit value for d
Expand Down
53 changes: 53 additions & 0 deletions test/cartesian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -565,3 +565,56 @@ end
end
@test t3 == (1, 2, 0)
end

@testset "nif" begin
let nif = Base.Cartesian.nif

x = (0, -1, 1, 0)
@test nif(d -> x[d] > 0, d -> d, Val(4)) == 3

@test nif(d -> d > 1, d -> "A", d -> "B", Val(1)) == "B"
@test nif(d -> d > 3, d -> "A", d -> "B", Val(3)) == "B"

# Test with N = 0
@test nif(d -> d > 0, d -> "", d -> "A", Val(0)) == "A"

# Specific branch true
@test nif(d -> d == 2, d -> d, d -> "else", Val(3)) == 2

# Test with condition only true for last branch
@test nif(d -> d == 5, d -> "A", d -> "B", Val(5)) == "B"

# Test with bad input:
@test_throws ArgumentError("if statement length should be ≥ 0, got -1") nif(identity, identity, Val(-1))

# Non-Int64 also throws
@test_throws TypeError nif(identity, identity, Val(1.5))

# Make sure all conditions are actually evaluated
result = let c = Ref(0)
nif(
d -> (c[] += 1; false),
d -> 1,
Val(4)
)
c[]
end
@test result == 3

# Test inference is good
t = ("i am not an int", ntuple(d -> d, Val(10))...)
function extract_from_tuple(t::Tuple, i)
nif(
d -> d == i,
d -> t[d + 1], # We skip the non-integer element
Val(length(t) - 1)
)
end
# Normally, had we used getindex here, inference would have
# not been able to infer that the return type never includes
# the first element. But since we used an `nif`, the compiler
# knows all possible branches and can infer the correct type.
@test @inferred(extract_from_tuple(t, 3)) == 3

end
end