diff --git a/src/RecursiveApply/RecursiveApply.jl b/src/RecursiveApply/RecursiveApply.jl index dfe1907ac8..e42aec00ca 100755 --- a/src/RecursiveApply/RecursiveApply.jl +++ b/src/RecursiveApply/RecursiveApply.jl @@ -9,19 +9,51 @@ module RecursiveApply export ⊞, ⊠, ⊟ +# These functions need to be generated for type stability (since T.parameters is +# a SimpleVector, the compiler cannot always infer its size and elements). +@generated first_param(::Type{T}) where {T} = :($(first(T.parameters))) +@generated tail_params(::Type{T}) where {T} = + :($(Tuple{Base.tail((T.parameters...,))...})) + +# Applying `rmaptype` returns `Tuple{...}` for tuple +# types, which cannot follow the recursion pattern as +# it cannot be splatted, so we add a separate method, +# `rmaptype_Tuple`, for the part of the recursion. +rmaptype_Tuple(fn::F, ::Type{Tuple{}}) where {F} = () +rmaptype_Tuple(fn::F, ::Type{T}) where {F, E, T <: Tuple{E}} = + (rmaptype(fn, first_param(T)),) +rmaptype_Tuple(fn::F, ::Type{T}) where {F, T <: Tuple} = + (rmaptype(fn, first_param(T)), rmaptype_Tuple(fn, tail_params(T))...) + +rmaptype_Tuple(_, ::Type{Tuple{}}, ::Type{T}) where {T <: Tuple} = () +rmaptype_Tuple(_, ::Type{T}, ::Type{Tuple{}}) where {T <: Tuple} = () +rmaptype_Tuple( + fn::F, + ::Type{T1}, + ::Type{T2}, +) where {F, T1 <: Tuple, T2 <: Tuple} = ( + rmaptype(fn, first_param(T1), first_param(T2)), + rmaptype_Tuple(rmaptype, fn, tail_params(T1), tail_params(T2))..., +) + """ rmap(fn, X...) Recursively apply `fn` to each element of `X` """ rmap(fn::F, X) where {F} = fn(X) +rmap(fn::F, X::Tuple{}) where {F} = () +rmap(fn::F, X::Tuple) where {F} = + (rmap(fn, first(X)), rmap(fn, Base.tail(X))...) +rmap(fn::F, X::NamedTuple{names}) where {F, names} = + NamedTuple{names}(rmap(fn, Tuple(X))) + rmap(fn::F, X, Y) where {F} = fn(X, Y) -rmap(fn::F, X::Tuple) where {F} = map(x -> rmap(fn, x), X) -rmap(fn, X::Tuple{}, Y::Tuple{}) = () +rmap(fn::F, X::Tuple{}, Y::Tuple{}) where {F} = () +rmap(fn::F, X::Tuple{}, Y::Tuple) where {F} = () +rmap(fn::F, X::Tuple, Y::Tuple{}) where {F} = () rmap(fn::F, X::Tuple, Y::Tuple) where {F} = (rmap(fn, first(X), first(Y)), rmap(fn, Base.tail(X), Base.tail(Y))...) -rmap(fn::F, X::NamedTuple{names}) where {F, names} = - NamedTuple{names}(rmap(fn, Tuple(X))) rmap(fn::F, X::NamedTuple{names}, Y::NamedTuple{names}) where {F, names} = NamedTuple{names}(rmap(fn, Tuple(X), Tuple(Y))) @@ -32,17 +64,45 @@ rmax(X, Y) = rmap(max, X, Y) """ rmaptype(fn, T) + rmaptype(fn, T1, T2) -The return type of `rmap(fn, X::T)`. +Recursively apply `fn` to each type parameter of the type `T`, or to each type +parameter of the types `T1` and `T2`, where `fn` returns a type. """ rmaptype(fn::F, ::Type{T}) where {F, T} = fn(T) rmaptype(fn::F, ::Type{T}) where {F, T <: Tuple} = - Tuple{map(fn, tuple(T.parameters...))...} + Tuple{rmaptype_Tuple(fn, T)...} +rmaptype(fn::F, ::Type{T}) where {F, names, Tup, T <: NamedTuple{names, Tup}} = + NamedTuple{names, rmaptype(fn, Tup)} + +rmaptype(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1, T2} = fn(T1, T2) +rmaptype(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1 <: Tuple, T2 <: Tuple} = + Tuple{rmaptype_Tuple(fn, T1, T2)...} rmaptype( fn::F, - ::Type{T}, -) where {F, T <: NamedTuple{names, tup}} where {names, tup} = - NamedTuple{names, rmaptype(fn, tup)} + ::Type{T1}, + ::Type{T2}, +) where { + F, + names, + Tup1, + Tup2, + T1 <: NamedTuple{names, Tup1}, + T2 <: NamedTuple{names, Tup2}, +} = NamedTuple{names, rmaptype(fn, Tup1, Tup2)} + +""" + rzero(T) + +Recursively compute the zero value of type `T`. +""" +rzero(::Type{T}) where {T} = zero(T) +rzero(::Type{Tuple{}}) = () +rzero(::Type{T}) where {E, T <: Tuple{E}} = (rzero(E),) +rzero(::Type{T}) where {T <: Tuple} = + (rzero(first_param(T)), rzero(tail_params(T))...) +rzero(::Type{Tup}) where {names, T, Tup <: NamedTuple{names, T}} = + NamedTuple{names}(rzero(T)) """ rmul(X, Y) diff --git a/src/Spaces/dss.jl b/src/Spaces/dss.jl index abb955dabe..b5df5db645 100644 --- a/src/Spaces/dss.jl +++ b/src/Spaces/dss.jl @@ -836,7 +836,7 @@ function dss_local_vertices!( sum_data = mapreduce( ⊞, vertex; - init = RecursiveApply.rmap(zero, slab(perimeter_data, 1, 1)[1]), + init = RecursiveApply.rzero(eltype(slab(perimeter_data, 1, 1))), ) do (lidx, vert) ip = Topologies.perimeter_vertex_node_index(vert) perimeter_slab = slab(perimeter_data, level, lidx) @@ -906,9 +906,8 @@ function dss_local_ghost!( sum_data = mapreduce( ⊞, vertex; - init = RecursiveApply.rmap( - zero, - slab(perimeter_data, 1, 1)[1], + init = RecursiveApply.rzero( + eltype(slab(perimeter_data, 1, 1)), ), ) do (isghost, idx, vert) ip = Topologies.perimeter_vertex_node_index(vert) diff --git a/test/Operators/spectralelement/benchmark_ops.jl b/test/Operators/spectralelement/benchmark_ops.jl index 2dd94ab424..f11469dd22 100644 --- a/test/Operators/spectralelement/benchmark_ops.jl +++ b/test/Operators/spectralelement/benchmark_ops.jl @@ -138,7 +138,7 @@ using JET p = @allocated kernel_complicated_field_dss!(kernel_args) @test p == 0 p = @allocated kernel_complicated_field2_dss!(kernel_args) - @test_broken p == 0 + @test p == 0 # Inference tests JET.@test_opt kernel_scalar_dss!(kernel_args) JET.@test_opt kernel_vector_dss!(kernel_args) @@ -146,6 +146,6 @@ using JET JET.@test_opt kernel_ntuple_field_dss!(kernel_args) JET.@test_opt kernel_ntuple_floats_dss!(kernel_args) JET.@test_opt kernel_complicated_field_dss!(kernel_args) - # JET.@test_opt kernel_complicated_field2_dss!(kernel_args) # fails + JET.@test_opt kernel_complicated_field2_dss!(kernel_args) end end diff --git a/test/RecursiveApply/recursive_apply.jl b/test/RecursiveApply/recursive_apply.jl index 72c7d0d61d..cf9966505c 100644 --- a/test/RecursiveApply/recursive_apply.jl +++ b/test/RecursiveApply/recursive_apply.jl @@ -38,3 +38,42 @@ end RecursiveApply.rmul(x, FT(2)) end end + +@testset "Highly nested types" begin + FT = Float64 + nested_types = [ + FT, + Tuple{FT, FT}, + NamedTuple{(:ϕ, :ψ), Tuple{FT, FT}}, + Tuple{ + NamedTuple{(:ϕ, :ψ), Tuple{FT, FT}}, + NamedTuple{(:ϕ, :ψ), Tuple{FT, FT}}, + }, + Tuple{FT, FT}, + NamedTuple{ + (:ρ, :uₕ, :ρe_tot, :ρq_tot, :sgs⁰, :sgsʲs), + Tuple{ + FT, + Tuple{FT, FT}, + FT, + FT, + NamedTuple{(:ρatke,), Tuple{FT}}, + Tuple{NamedTuple{(:ρa, :ρae_tot, :ρaq_tot), Tuple{FT, FT, FT}}}, + }, + }, + NamedTuple{ + (:u₃, :sgsʲs), + Tuple{Tuple{FT}, Tuple{NamedTuple{(:u₃,), Tuple{Tuple{FT}}}}}, + }, + ] + for nt in nested_types + rz = RecursiveApply.rmap(RecursiveApply.rzero, nt) + @test typeof(rz) == nt + @inferred RecursiveApply.rmap(RecursiveApply.rzero, nt) + end + for nt in nested_types + rz = RecursiveApply.rmaptype(identity, nt) + @test rz == nt + @inferred RecursiveApply.rmaptype(zero, nt) + end +end