Skip to content

Commit

Permalink
Merge #1334
Browse files Browse the repository at this point in the history
1334: Extend and improve RecursiveApply, fix DSS inference r=charleskawczynski a=charleskawczynski

This PR:
 - Extends and improves RecursiveApply (incorporates some changes/functionality in #1326) by working with Tuples/NamedTuples in a type-stable way
 - Adds RecursiveApply tests
 - Fix DSS inference and corresponding test

Co-authored-by: Charles Kawczynski <[email protected]>
  • Loading branch information
bors[bot] and charleskawczynski committed Jun 16, 2023
2 parents 046c63f + 5c8bd0e commit 1cfa698
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 15 deletions.
79 changes: 70 additions & 9 deletions src/RecursiveApply/RecursiveApply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,52 @@ module RecursiveApply

export , ,

# These functions need to be generated for type stability (since T.parameters is
# a SimpleVector, the compiler cannot always infer its size and elements).
@generated first_param(::Type{T}) where {T} = :($(first(T.parameters)))
@generated tail_params(::Type{T}) where {T} =
:($(Tuple{Base.tail((T.parameters...,))...}))

# Applying `rmaptype` returns `Tuple{...}` for tuple
# types, which cannot follow the recursion pattern as
# it cannot be splatted, so we add a separate method,
# `rmaptype_Tuple`, for the part of the recursion.
rmaptype_Tuple(fn::F, ::Type{Tuple{}}) where {F} = ()
rmaptype_Tuple(fn::F, ::Type{T}) where {F, E, T <: Tuple{E}} =
(rmaptype(fn, first_param(T)),)
rmaptype_Tuple(fn::F, ::Type{T}) where {F, T <: Tuple} =
(rmaptype(fn, first_param(T)), rmaptype_Tuple(fn, tail_params(T))...)

rmaptype_Tuple(_, ::Type{Tuple{}}, ::Type{Tuple{}}) = ()
rmaptype_Tuple(_, ::Type{Tuple{}}, ::Type{T}) where {T <: Tuple} = ()
rmaptype_Tuple(_, ::Type{T}, ::Type{Tuple{}}) where {T <: Tuple} = ()
rmaptype_Tuple(
fn::F,
::Type{T1},
::Type{T2},
) where {F, T1 <: Tuple, T2 <: Tuple} = (
rmaptype(fn, first_param(T1), first_param(T2)),
rmaptype_Tuple(fn, tail_params(T1), tail_params(T2))...,
)

"""
rmap(fn, X...)
Recursively apply `fn` to each element of `X`
"""
rmap(fn::F, X) where {F} = fn(X)
rmap(fn::F, X::Tuple{}) where {F} = ()
rmap(fn::F, X::Tuple) where {F} =
(rmap(fn, first(X)), rmap(fn, Base.tail(X))...)
rmap(fn::F, X::NamedTuple{names}) where {F, names} =
NamedTuple{names}(rmap(fn, Tuple(X)))

rmap(fn::F, X, Y) where {F} = fn(X, Y)
rmap(fn::F, X::Tuple) where {F} = map(x -> rmap(fn, x), X)
rmap(fn, X::Tuple{}, Y::Tuple{}) = ()
rmap(fn::F, X::Tuple{}, Y::Tuple{}) where {F} = ()
rmap(fn::F, X::Tuple{}, Y::Tuple) where {F} = ()
rmap(fn::F, X::Tuple, Y::Tuple{}) where {F} = ()
rmap(fn::F, X::Tuple, Y::Tuple) where {F} =
(rmap(fn, first(X), first(Y)), rmap(fn, Base.tail(X), Base.tail(Y))...)
rmap(fn::F, X::NamedTuple{names}) where {F, names} =
NamedTuple{names}(rmap(fn, Tuple(X)))
rmap(fn::F, X::NamedTuple{names}, Y::NamedTuple{names}) where {F, names} =
NamedTuple{names}(rmap(fn, Tuple(X), Tuple(Y)))

Expand All @@ -32,17 +65,45 @@ rmax(X, Y) = rmap(max, X, Y)

"""
rmaptype(fn, T)
rmaptype(fn, T1, T2)
The return type of `rmap(fn, X::T)`.
Recursively apply `fn` to each type parameter of the type `T`, or to each type
parameter of the types `T1` and `T2`, where `fn` returns a type.
"""
rmaptype(fn::F, ::Type{T}) where {F, T} = fn(T)
rmaptype(fn::F, ::Type{T}) where {F, T <: Tuple} =
Tuple{map(fn, tuple(T.parameters...))...}
Tuple{rmaptype_Tuple(fn, T)...}
rmaptype(fn::F, ::Type{T}) where {F, names, Tup, T <: NamedTuple{names, Tup}} =
NamedTuple{names, rmaptype(fn, Tup)}

rmaptype(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1, T2} = fn(T1, T2)
rmaptype(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1 <: Tuple, T2 <: Tuple} =
Tuple{rmaptype_Tuple(fn, T1, T2)...}
rmaptype(
fn::F,
::Type{T},
) where {F, T <: NamedTuple{names, tup}} where {names, tup} =
NamedTuple{names, rmaptype(fn, tup)}
::Type{T1},
::Type{T2},
) where {
F,
names,
Tup1,
Tup2,
T1 <: NamedTuple{names, Tup1},
T2 <: NamedTuple{names, Tup2},
} = NamedTuple{names, rmaptype(fn, Tup1, Tup2)}

"""
rzero(T)
Recursively compute the zero value of type `T`.
"""
rzero(::Type{T}) where {T} = zero(T)
rzero(::Type{Tuple{}}) = ()
rzero(::Type{T}) where {E, T <: Tuple{E}} = (rzero(E),)
rzero(::Type{T}) where {T <: Tuple} =
(rzero(first_param(T)), rzero(tail_params(T))...)
rzero(::Type{Tup}) where {names, T, Tup <: NamedTuple{names, T}} =
NamedTuple{names}(rzero(T))

"""
rmul(X, Y)
Expand Down
7 changes: 3 additions & 4 deletions src/Spaces/dss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,7 @@ function dss_local_vertices!(
sum_data = mapreduce(
,
vertex;
init = RecursiveApply.rmap(zero, slab(perimeter_data, 1, 1)[1]),
init = RecursiveApply.rzero(eltype(slab(perimeter_data, 1, 1))),
) do (lidx, vert)
ip = Topologies.perimeter_vertex_node_index(vert)
perimeter_slab = slab(perimeter_data, level, lidx)
Expand Down Expand Up @@ -906,9 +906,8 @@ function dss_local_ghost!(
sum_data = mapreduce(
,
vertex;
init = RecursiveApply.rmap(
zero,
slab(perimeter_data, 1, 1)[1],
init = RecursiveApply.rzero(
eltype(slab(perimeter_data, 1, 1)),
),
) do (isghost, idx, vert)
ip = Topologies.perimeter_vertex_node_index(vert)
Expand Down
4 changes: 2 additions & 2 deletions test/Operators/spectralelement/benchmark_ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,14 +138,14 @@ using JET
p = @allocated kernel_complicated_field_dss!(kernel_args)
@test p == 0
p = @allocated kernel_complicated_field2_dss!(kernel_args)
@test_broken p == 0
@test p == 0
# Inference tests
JET.@test_opt kernel_scalar_dss!(kernel_args)
JET.@test_opt kernel_vector_dss!(kernel_args)
JET.@test_opt kernel_field_dss!(kernel_args)
JET.@test_opt kernel_ntuple_field_dss!(kernel_args)
JET.@test_opt kernel_ntuple_floats_dss!(kernel_args)
JET.@test_opt kernel_complicated_field_dss!(kernel_args)
# JET.@test_opt kernel_complicated_field2_dss!(kernel_args) # fails
JET.@test_opt kernel_complicated_field2_dss!(kernel_args)
end
end
46 changes: 46 additions & 0 deletions test/RecursiveApply/recursive_apply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,49 @@ end
RecursiveApply.rmul(x, FT(2))
end
end

@testset "Highly nested types" begin
FT = Float64
nested_types = [
FT,
Tuple{FT, FT},
NamedTuple{(, ), Tuple{FT, FT}},
Tuple{
NamedTuple{(, ), Tuple{FT, FT}},
NamedTuple{(, ), Tuple{FT, FT}},
},
Tuple{FT, FT},
NamedTuple{
(, :uₕ, :ρe_tot, :ρq_tot, :sgs⁰, :sgsʲs),
Tuple{
FT,
Tuple{FT, FT},
FT,
FT,
NamedTuple{(:ρatke,), Tuple{FT}},
Tuple{NamedTuple{(:ρa, :ρae_tot, :ρaq_tot), Tuple{FT, FT, FT}}},
},
},
NamedTuple{
(:u₃, :sgsʲs),
Tuple{Tuple{FT}, Tuple{NamedTuple{(:u₃,), Tuple{Tuple{FT}}}}},
},
]
for nt in nested_types
rz = RecursiveApply.rmap(RecursiveApply.rzero, nt)
@test typeof(rz) == nt
@inferred RecursiveApply.rmap(RecursiveApply.rzero, nt)

rz = RecursiveApply.rmap((x, y) -> RecursiveApply.rzero(x), nt, nt)
@test typeof(rz) == nt
@inferred RecursiveApply.rmap((x, y) -> RecursiveApply.rzero(x), nt, nt)

rz = RecursiveApply.rmaptype(identity, nt)
@test rz == nt
@inferred RecursiveApply.rmaptype(zero, nt)

rz = RecursiveApply.rmaptype((x, y) -> identity(x), nt, nt)
@test rz == nt
@inferred RecursiveApply.rmaptype((x, y) -> zero(x), nt, nt)
end
end

0 comments on commit 1cfa698

Please sign in to comment.