From e2035af251d484cb60ea4627cef7f102778a8a26 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Fri, 16 Jun 2023 10:59:20 -0700 Subject: [PATCH 1/9] wip --- src/RecursiveApply/RecursiveApply.jl | 96 +++++++++++++++++++++++--- test/RecursiveApply/recursive_apply.jl | 34 +++++++++ 2 files changed, 122 insertions(+), 8 deletions(-) diff --git a/src/RecursiveApply/RecursiveApply.jl b/src/RecursiveApply/RecursiveApply.jl index dfe1907ac8..17a9a1932c 100755 --- a/src/RecursiveApply/RecursiveApply.jl +++ b/src/RecursiveApply/RecursiveApply.jl @@ -9,6 +9,52 @@ module RecursiveApply export ⊞, ⊠, ⊟ +# These functions need to be generated for type stability (since T.parameters is +# a SimpleVector, the compiler cannot always infer its size and elements). +first_param(::Type{T}) where {T} = + if @generated + :($(first(T.parameters))) + else + first(T.parameters) + end +tail_params(::Type{T}) where {T} = + if @generated + :($(Tuple{Base.tail((T.parameters...,))...})) + else + Tuple{Base.tail((T.parameters...,))...} + end + +# This is a type-stable version of map(x -> rmap(fn, x), X) or +# map((x, y) -> rmap(fn, x, y), X, Y). +rmap_tuple(fn::F, X) where {F} = + isempty(X) ? () : (rmap(fn, first(X)), rmap_tuple(fn, Base.tail(X))...) +rmap_tuple(fn::F, X, Y) where {F} = + isempty(X) || isempty(Y) ? () : + ( + rmap(fn, first(X), first(Y)), + rmap_tuple(fn, Base.tail(X), Base.tail(Y))..., + ) + +# This is a type-stable version of map(T′ -> rfunc(fn, T′), T.parameters) or +# map((T1′, T2′) -> rfunc(fn, T1′, T2′), T1.parameters, T2.parameters), where +# rfunc can be either rmaptype or rmap_type2value. +rmap_Tuple(rfunc::R, fn::F, ::Type{Tuple{}}) where {R, F, T} = () +rmap_Tuple(rfunc::R, fn::F, ::Type{T}) where {R, F, T <: Tuple} = + (rfunc(fn, first_param(T)), rmap_Tuple(rfunc, fn, tail_params(T))...) + +rmap_Tuple(_, _, ::Type{Tuple{}}, ::Type{T}) where {T <: Tuple} = () +rmap_Tuple(_, _, ::Type{T}, ::Type{Tuple{}}) where {T <: Tuple} = () + +rmap_Tuple( + rfunc::R, + fn::F, + ::Type{T1}, + ::Type{T2}, +) where {R, F, T1 <: Tuple, T2 <: Tuple} = ( + rfunc(fn, first_param(T1), first_param(T2)), + rmap_Tuple(rfunc, fn, tail_params(T1), tail_params(T2))..., +) + """ rmap(fn, X...) @@ -16,10 +62,8 @@ Recursively apply `fn` to each element of `X` """ rmap(fn::F, X) where {F} = fn(X) rmap(fn::F, X, Y) where {F} = fn(X, Y) -rmap(fn::F, X::Tuple) where {F} = map(x -> rmap(fn, x), X) -rmap(fn, X::Tuple{}, Y::Tuple{}) = () -rmap(fn::F, X::Tuple, Y::Tuple) where {F} = - (rmap(fn, first(X), first(Y)), rmap(fn, Base.tail(X), Base.tail(Y))...) +rmap(fn::F, X::Tuple) where {F} = rmap_tuple(fn, X) +rmap(fn::F, X::Tuple, Y::Tuple) where {F} = rmap_tuple(fn, X, Y) rmap(fn::F, X::NamedTuple{names}) where {F, names} = NamedTuple{names}(rmap(fn, Tuple(X))) rmap(fn::F, X::NamedTuple{names}, Y::NamedTuple{names}) where {F, names} = @@ -32,17 +76,53 @@ rmax(X, Y) = rmap(max, X, Y) """ rmaptype(fn, T) + rmaptype(fn, T1, T2) -The return type of `rmap(fn, X::T)`. +Recursively apply `fn` to each type parameter of the type `T`, or to each type +parameter of the types `T1` and `T2`, where `fn` returns a type. """ rmaptype(fn::F, ::Type{T}) where {F, T} = fn(T) +rmaptype(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1, T2} = fn(T1, T2) rmaptype(fn::F, ::Type{T}) where {F, T <: Tuple} = - Tuple{map(fn, tuple(T.parameters...))...} + Tuple{rmap_Tuple(rmaptype, fn, T)...} +rmaptype(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1 <: Tuple, T2 <: Tuple} = + Tuple{rmap_Tuple(rmaptype, fn, T1, T2)...} +rmaptype(fn::F, ::Type{T}) where {F, names, Tup, T <: NamedTuple{names, Tup}} = + NamedTuple{names, rmaptype(fn, Tup)} rmaptype( + fn::F, + ::Type{T1}, + ::Type{T2}, +) where { + F, + names, + Tup1, + Tup2, + T1 <: NamedTuple{names, Tup1}, + T2 <: NamedTuple{names, Tup2}, +} = NamedTuple{names, rmaptype(fn, Tup1, Tup2)} + +""" + rmap_type2value(fn, T) + +Recursively apply `fn` to each type parameter of the type `T`, where `fn` +returns a value instead of a type. +""" +rmap_type2value(fn::F, ::Type{T}) where {F, T} = fn(T) +rmap_type2value(fn::F, ::Type{T}) where {F, T <: Tuple} = + rmap_Tuple(rmap_type2value, fn, T) +rmap_type2value( fn::F, ::Type{T}, -) where {F, T <: NamedTuple{names, tup}} where {names, tup} = - NamedTuple{names, rmaptype(fn, tup)} +) where {F, names, Tup, T <: NamedTuple{names, Tup}} = + NamedTuple{names}(rmap_type2value(fn, Tup)) + +""" + rzero(T) + +Recursively compute the zero value of type `T`. +""" +rzero(::Type{T}) where {T} = rmap_type2value(zero, T) """ rmul(X, Y) diff --git a/test/RecursiveApply/recursive_apply.jl b/test/RecursiveApply/recursive_apply.jl index 72c7d0d61d..0e98f30475 100644 --- a/test/RecursiveApply/recursive_apply.jl +++ b/test/RecursiveApply/recursive_apply.jl @@ -38,3 +38,37 @@ end RecursiveApply.rmul(x, FT(2)) end end + +@testset "Highly nested types" begin + FT = Float64 + nested_types = [ + FT, + Tuple{FT, FT}, + NamedTuple{(:ϕ, :ψ), Tuple{FT, FT}}, + Tuple{ + NamedTuple{(:ϕ, :ψ), Tuple{FT, FT}}, + NamedTuple{(:ϕ, :ψ), Tuple{FT, FT}}, + }, + Tuple{FT, FT}, + NamedTuple{ + (:ρ, :uₕ, :ρe_tot, :ρq_tot, :sgs⁰, :sgsʲs), + Tuple{ + FT, + Tuple{FT, FT}, + FT, + FT, + NamedTuple{(:ρatke,), Tuple{FT}}, + Tuple{NamedTuple{(:ρa, :ρae_tot, :ρaq_tot), Tuple{FT, FT, FT}}}, + }, + }, + NamedTuple{ + (:u₃, :sgsʲs), + Tuple{Tuple{FT}, Tuple{NamedTuple{(:u₃,), Tuple{Tuple{FT}}}}}, + }, + ] + for nt in nested_types + rz = RecursiveApply.rmap(RecursiveApply.rzero, nt) + @test typeof(rz) == nt + @inferred RecursiveApply.rmap(RecursiveApply.rzero, nt) + end +end From a539f1ef4bcd9feb6d880267f0a6aafcb48a8802 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Fri, 16 Jun 2023 11:24:24 -0700 Subject: [PATCH 2/9] Remove layer of indirection --- src/RecursiveApply/RecursiveApply.jl | 72 ++++++++++------------------ 1 file changed, 24 insertions(+), 48 deletions(-) diff --git a/src/RecursiveApply/RecursiveApply.jl b/src/RecursiveApply/RecursiveApply.jl index 17a9a1932c..95f23e54bf 100755 --- a/src/RecursiveApply/RecursiveApply.jl +++ b/src/RecursiveApply/RecursiveApply.jl @@ -11,18 +11,9 @@ export ⊞, ⊠, ⊟ # These functions need to be generated for type stability (since T.parameters is # a SimpleVector, the compiler cannot always infer its size and elements). -first_param(::Type{T}) where {T} = - if @generated - :($(first(T.parameters))) - else - first(T.parameters) - end -tail_params(::Type{T}) where {T} = - if @generated - :($(Tuple{Base.tail((T.parameters...,))...})) - else - Tuple{Base.tail((T.parameters...,))...} - end +@generated first_param(::Type{T}) where {T} = :($(first(T.parameters))) +@generated tail_params(::Type{T}) where {T} = + :($(Tuple{Base.tail((T.parameters...,))...})) # This is a type-stable version of map(x -> rmap(fn, x), X) or # map((x, y) -> rmap(fn, x, y), X, Y). @@ -35,25 +26,20 @@ rmap_tuple(fn::F, X, Y) where {F} = rmap_tuple(fn, Base.tail(X), Base.tail(Y))..., ) -# This is a type-stable version of map(T′ -> rfunc(fn, T′), T.parameters) or -# map((T1′, T2′) -> rfunc(fn, T1′, T2′), T1.parameters, T2.parameters), where -# rfunc can be either rmaptype or rmap_type2value. -rmap_Tuple(rfunc::R, fn::F, ::Type{Tuple{}}) where {R, F, T} = () -rmap_Tuple(rfunc::R, fn::F, ::Type{T}) where {R, F, T <: Tuple} = - (rfunc(fn, first_param(T)), rmap_Tuple(rfunc, fn, tail_params(T))...) +# This is a type-stable version of map(T′ -> rmaptype(fn, T′), T.parameters) or +# map((T1′, T2′) -> rmaptype(fn, T1′, T2′), T1.parameters, T2.parameters), where +rmap_Tuple(fn::F, ::Type{Tuple{}}) where {F} = () +rmap_Tuple(fn::F, ::Type{T}) where {F, T <: Tuple} = + (rmaptype(fn, first_param(T)), rmap_Tuple(rmaptype, fn, tail_params(T))...) -rmap_Tuple(_, _, ::Type{Tuple{}}, ::Type{T}) where {T <: Tuple} = () -rmap_Tuple(_, _, ::Type{T}, ::Type{Tuple{}}) where {T <: Tuple} = () +rmap_Tuple(_, ::Type{Tuple{}}, ::Type{T}) where {T <: Tuple} = () +rmap_Tuple(_, ::Type{T}, ::Type{Tuple{}}) where {T <: Tuple} = () -rmap_Tuple( - rfunc::R, - fn::F, - ::Type{T1}, - ::Type{T2}, -) where {R, F, T1 <: Tuple, T2 <: Tuple} = ( - rfunc(fn, first_param(T1), first_param(T2)), - rmap_Tuple(rfunc, fn, tail_params(T1), tail_params(T2))..., -) +rmap_Tuple(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1 <: Tuple, T2 <: Tuple} = + ( + rmaptype(fn, first_param(T1), first_param(T2)), + rmap_Tuple(rmaptype, fn, tail_params(T1), tail_params(T2))..., + ) """ rmap(fn, X...) @@ -83,10 +69,9 @@ parameter of the types `T1` and `T2`, where `fn` returns a type. """ rmaptype(fn::F, ::Type{T}) where {F, T} = fn(T) rmaptype(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1, T2} = fn(T1, T2) -rmaptype(fn::F, ::Type{T}) where {F, T <: Tuple} = - Tuple{rmap_Tuple(rmaptype, fn, T)...} +rmaptype(fn::F, ::Type{T}) where {F, T <: Tuple} = Tuple{rmap_Tuple(fn, T)...} rmaptype(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1 <: Tuple, T2 <: Tuple} = - Tuple{rmap_Tuple(rmaptype, fn, T1, T2)...} + Tuple{rmap_Tuple(fn, T1, T2)...} rmaptype(fn::F, ::Type{T}) where {F, names, Tup, T <: NamedTuple{names, Tup}} = NamedTuple{names, rmaptype(fn, Tup)} rmaptype( @@ -102,27 +87,18 @@ rmaptype( T2 <: NamedTuple{names, Tup2}, } = NamedTuple{names, rmaptype(fn, Tup1, Tup2)} -""" - rmap_type2value(fn, T) - -Recursively apply `fn` to each type parameter of the type `T`, where `fn` -returns a value instead of a type. -""" -rmap_type2value(fn::F, ::Type{T}) where {F, T} = fn(T) -rmap_type2value(fn::F, ::Type{T}) where {F, T <: Tuple} = - rmap_Tuple(rmap_type2value, fn, T) -rmap_type2value( - fn::F, - ::Type{T}, -) where {F, names, Tup, T <: NamedTuple{names, Tup}} = - NamedTuple{names}(rmap_type2value(fn, Tup)) - """ rzero(T) Recursively compute the zero value of type `T`. """ -rzero(::Type{T}) where {T} = rmap_type2value(zero, T) +rzero(::Type{T}) where {T} = zero(T) +rzero(::Type{Tuple{}}) = () +rzero(::Type{T}) where {E, T <: Tuple{E}} = (rzero(E),) +rzero(::Type{T}) where {T <: Tuple} = + (rzero(first_param(T)), rzero(tail_params(T))...) +rzero(::Type{Tup}) where {names, T, Tup <: NamedTuple{names, T}} = + NamedTuple{names}(rzero(T)) """ rmul(X, Y) From 25a7de07c95b9e4e29d9a5c0d13f0cebc70bcb81 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Fri, 16 Jun 2023 11:55:13 -0700 Subject: [PATCH 3/9] Separate 1 and 2 argument functions --- src/RecursiveApply/RecursiveApply.jl | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/RecursiveApply/RecursiveApply.jl b/src/RecursiveApply/RecursiveApply.jl index 95f23e54bf..b234384ca6 100755 --- a/src/RecursiveApply/RecursiveApply.jl +++ b/src/RecursiveApply/RecursiveApply.jl @@ -47,11 +47,12 @@ rmap_Tuple(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1 <: Tuple, T2 <: Tuple} = Recursively apply `fn` to each element of `X` """ rmap(fn::F, X) where {F} = fn(X) -rmap(fn::F, X, Y) where {F} = fn(X, Y) rmap(fn::F, X::Tuple) where {F} = rmap_tuple(fn, X) -rmap(fn::F, X::Tuple, Y::Tuple) where {F} = rmap_tuple(fn, X, Y) rmap(fn::F, X::NamedTuple{names}) where {F, names} = NamedTuple{names}(rmap(fn, Tuple(X))) + +rmap(fn::F, X, Y) where {F} = fn(X, Y) +rmap(fn::F, X::Tuple, Y::Tuple) where {F} = rmap_tuple(fn, X, Y) rmap(fn::F, X::NamedTuple{names}, Y::NamedTuple{names}) where {F, names} = NamedTuple{names}(rmap(fn, Tuple(X), Tuple(Y))) @@ -68,12 +69,13 @@ Recursively apply `fn` to each type parameter of the type `T`, or to each type parameter of the types `T1` and `T2`, where `fn` returns a type. """ rmaptype(fn::F, ::Type{T}) where {F, T} = fn(T) -rmaptype(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1, T2} = fn(T1, T2) rmaptype(fn::F, ::Type{T}) where {F, T <: Tuple} = Tuple{rmap_Tuple(fn, T)...} -rmaptype(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1 <: Tuple, T2 <: Tuple} = - Tuple{rmap_Tuple(fn, T1, T2)...} rmaptype(fn::F, ::Type{T}) where {F, names, Tup, T <: NamedTuple{names, Tup}} = NamedTuple{names, rmaptype(fn, Tup)} + +rmaptype(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1, T2} = fn(T1, T2) +rmaptype(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1 <: Tuple, T2 <: Tuple} = + Tuple{rmap_Tuple(fn, T1, T2)...} rmaptype( fn::F, ::Type{T1}, From a77ff83f2d038ea73bfa27fdd7520021e735c458 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Fri, 16 Jun 2023 12:23:53 -0700 Subject: [PATCH 4/9] Follow existing recursion pattern with rmap --- src/RecursiveApply/RecursiveApply.jl | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/src/RecursiveApply/RecursiveApply.jl b/src/RecursiveApply/RecursiveApply.jl index b234384ca6..aa318032f1 100755 --- a/src/RecursiveApply/RecursiveApply.jl +++ b/src/RecursiveApply/RecursiveApply.jl @@ -15,17 +15,6 @@ export ⊞, ⊠, ⊟ @generated tail_params(::Type{T}) where {T} = :($(Tuple{Base.tail((T.parameters...,))...})) -# This is a type-stable version of map(x -> rmap(fn, x), X) or -# map((x, y) -> rmap(fn, x, y), X, Y). -rmap_tuple(fn::F, X) where {F} = - isempty(X) ? () : (rmap(fn, first(X)), rmap_tuple(fn, Base.tail(X))...) -rmap_tuple(fn::F, X, Y) where {F} = - isempty(X) || isempty(Y) ? () : - ( - rmap(fn, first(X), first(Y)), - rmap_tuple(fn, Base.tail(X), Base.tail(Y))..., - ) - # This is a type-stable version of map(T′ -> rmaptype(fn, T′), T.parameters) or # map((T1′, T2′) -> rmaptype(fn, T1′, T2′), T1.parameters, T2.parameters), where rmap_Tuple(fn::F, ::Type{Tuple{}}) where {F} = () @@ -47,12 +36,20 @@ rmap_Tuple(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1 <: Tuple, T2 <: Tuple} = Recursively apply `fn` to each element of `X` """ rmap(fn::F, X) where {F} = fn(X) -rmap(fn::F, X::Tuple) where {F} = rmap_tuple(fn, X) +rmap(fn::F, X::Tuple{}) where {F} = () +rmap(fn::F, X::Tuple) where {F} = + (rmap(fn, first(X)), rmap(fn, Base.tail(X))...) rmap(fn::F, X::NamedTuple{names}) where {F, names} = NamedTuple{names}(rmap(fn, Tuple(X))) rmap(fn::F, X, Y) where {F} = fn(X, Y) -rmap(fn::F, X::Tuple, Y::Tuple) where {F} = rmap_tuple(fn, X, Y) +rmap(fn::F, X::Tuple{}, Y::Tuple{}) where {F} = () +rmap(fn::F, X::Tuple{}, Y::Tuple) where {F} = () +rmap(fn::F, X::Tuple, Y::Tuple{}) where {F} = () +rmap(fn::F, X::Tuple, Y::Tuple) where {F} = ( + rmap(fn, first(X), first(Y)), + rmap(fn, Base.tail(X), Base.tail(Y))..., + ) rmap(fn::F, X::NamedTuple{names}, Y::NamedTuple{names}) where {F, names} = NamedTuple{names}(rmap(fn, Tuple(X), Tuple(Y))) From 824c240ac112d097240517bd8bd5f06b85f844bb Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Fri, 16 Jun 2023 12:25:55 -0700 Subject: [PATCH 5/9] Rename rmap_Tuple to rmaptype_Tuple --- src/RecursiveApply/RecursiveApply.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/RecursiveApply/RecursiveApply.jl b/src/RecursiveApply/RecursiveApply.jl index aa318032f1..5619150d9c 100755 --- a/src/RecursiveApply/RecursiveApply.jl +++ b/src/RecursiveApply/RecursiveApply.jl @@ -17,17 +17,17 @@ export ⊞, ⊠, ⊟ # This is a type-stable version of map(T′ -> rmaptype(fn, T′), T.parameters) or # map((T1′, T2′) -> rmaptype(fn, T1′, T2′), T1.parameters, T2.parameters), where -rmap_Tuple(fn::F, ::Type{Tuple{}}) where {F} = () -rmap_Tuple(fn::F, ::Type{T}) where {F, T <: Tuple} = - (rmaptype(fn, first_param(T)), rmap_Tuple(rmaptype, fn, tail_params(T))...) +rmaptype_Tuple(fn::F, ::Type{Tuple{}}) where {F} = () +rmaptype_Tuple(fn::F, ::Type{T}) where {F, T <: Tuple} = + (rmaptype(fn, first_param(T)), rmaptype_Tuple(rmaptype, fn, tail_params(T))...) -rmap_Tuple(_, ::Type{Tuple{}}, ::Type{T}) where {T <: Tuple} = () -rmap_Tuple(_, ::Type{T}, ::Type{Tuple{}}) where {T <: Tuple} = () +rmaptype_Tuple(_, ::Type{Tuple{}}, ::Type{T}) where {T <: Tuple} = () +rmaptype_Tuple(_, ::Type{T}, ::Type{Tuple{}}) where {T <: Tuple} = () -rmap_Tuple(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1 <: Tuple, T2 <: Tuple} = +rmaptype_Tuple(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1 <: Tuple, T2 <: Tuple} = ( rmaptype(fn, first_param(T1), first_param(T2)), - rmap_Tuple(rmaptype, fn, tail_params(T1), tail_params(T2))..., + rmaptype_Tuple(rmaptype, fn, tail_params(T1), tail_params(T2))..., ) """ @@ -66,13 +66,13 @@ Recursively apply `fn` to each type parameter of the type `T`, or to each type parameter of the types `T1` and `T2`, where `fn` returns a type. """ rmaptype(fn::F, ::Type{T}) where {F, T} = fn(T) -rmaptype(fn::F, ::Type{T}) where {F, T <: Tuple} = Tuple{rmap_Tuple(fn, T)...} +rmaptype(fn::F, ::Type{T}) where {F, T <: Tuple} = Tuple{rmaptype_Tuple(fn, T)...} rmaptype(fn::F, ::Type{T}) where {F, names, Tup, T <: NamedTuple{names, Tup}} = NamedTuple{names, rmaptype(fn, Tup)} rmaptype(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1, T2} = fn(T1, T2) rmaptype(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1 <: Tuple, T2 <: Tuple} = - Tuple{rmap_Tuple(fn, T1, T2)...} + Tuple{rmaptype_Tuple(fn, T1, T2)...} rmaptype( fn::F, ::Type{T1}, From 3294bd41da0a298fa128c8027813e79bb17c3de7 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Fri, 16 Jun 2023 13:38:11 -0700 Subject: [PATCH 6/9] Add and fix test --- src/RecursiveApply/RecursiveApply.jl | 11 +++++++---- test/RecursiveApply/recursive_apply.jl | 5 +++++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/RecursiveApply/RecursiveApply.jl b/src/RecursiveApply/RecursiveApply.jl index 5619150d9c..b7b4d11511 100755 --- a/src/RecursiveApply/RecursiveApply.jl +++ b/src/RecursiveApply/RecursiveApply.jl @@ -15,15 +15,18 @@ export ⊞, ⊠, ⊟ @generated tail_params(::Type{T}) where {T} = :($(Tuple{Base.tail((T.parameters...,))...})) -# This is a type-stable version of map(T′ -> rmaptype(fn, T′), T.parameters) or -# map((T1′, T2′) -> rmaptype(fn, T1′, T2′), T1.parameters, T2.parameters), where +# Applying `rmaptype` returns `Tuple{...}` for tuple +# types, which cannot follow the recursion pattern as +# it cannot be splatted, so we add a separate method, +# `rmaptype_Tuple`, for the part of the recursion. rmaptype_Tuple(fn::F, ::Type{Tuple{}}) where {F} = () +rmaptype_Tuple(fn::F, ::Type{T}) where {F, E, T <: Tuple{E}} = + (rmaptype(fn, first_param(T)), ) rmaptype_Tuple(fn::F, ::Type{T}) where {F, T <: Tuple} = - (rmaptype(fn, first_param(T)), rmaptype_Tuple(rmaptype, fn, tail_params(T))...) + (rmaptype(fn, first_param(T)), rmaptype_Tuple(fn, tail_params(T))...) rmaptype_Tuple(_, ::Type{Tuple{}}, ::Type{T}) where {T <: Tuple} = () rmaptype_Tuple(_, ::Type{T}, ::Type{Tuple{}}) where {T <: Tuple} = () - rmaptype_Tuple(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1 <: Tuple, T2 <: Tuple} = ( rmaptype(fn, first_param(T1), first_param(T2)), diff --git a/test/RecursiveApply/recursive_apply.jl b/test/RecursiveApply/recursive_apply.jl index 0e98f30475..cf9966505c 100644 --- a/test/RecursiveApply/recursive_apply.jl +++ b/test/RecursiveApply/recursive_apply.jl @@ -71,4 +71,9 @@ end @test typeof(rz) == nt @inferred RecursiveApply.rmap(RecursiveApply.rzero, nt) end + for nt in nested_types + rz = RecursiveApply.rmaptype(identity, nt) + @test rz == nt + @inferred RecursiveApply.rmaptype(zero, nt) + end end From af8b6c7898ea1c15466f3d270a3ca36fe7ef5278 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Fri, 16 Jun 2023 13:39:06 -0700 Subject: [PATCH 7/9] Apply formatting --- src/RecursiveApply/RecursiveApply.jl | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/RecursiveApply/RecursiveApply.jl b/src/RecursiveApply/RecursiveApply.jl index b7b4d11511..e42aec00ca 100755 --- a/src/RecursiveApply/RecursiveApply.jl +++ b/src/RecursiveApply/RecursiveApply.jl @@ -21,17 +21,20 @@ export ⊞, ⊠, ⊟ # `rmaptype_Tuple`, for the part of the recursion. rmaptype_Tuple(fn::F, ::Type{Tuple{}}) where {F} = () rmaptype_Tuple(fn::F, ::Type{T}) where {F, E, T <: Tuple{E}} = - (rmaptype(fn, first_param(T)), ) + (rmaptype(fn, first_param(T)),) rmaptype_Tuple(fn::F, ::Type{T}) where {F, T <: Tuple} = (rmaptype(fn, first_param(T)), rmaptype_Tuple(fn, tail_params(T))...) rmaptype_Tuple(_, ::Type{Tuple{}}, ::Type{T}) where {T <: Tuple} = () rmaptype_Tuple(_, ::Type{T}, ::Type{Tuple{}}) where {T <: Tuple} = () -rmaptype_Tuple(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1 <: Tuple, T2 <: Tuple} = - ( - rmaptype(fn, first_param(T1), first_param(T2)), - rmaptype_Tuple(rmaptype, fn, tail_params(T1), tail_params(T2))..., - ) +rmaptype_Tuple( + fn::F, + ::Type{T1}, + ::Type{T2}, +) where {F, T1 <: Tuple, T2 <: Tuple} = ( + rmaptype(fn, first_param(T1), first_param(T2)), + rmaptype_Tuple(rmaptype, fn, tail_params(T1), tail_params(T2))..., +) """ rmap(fn, X...) @@ -49,10 +52,8 @@ rmap(fn::F, X, Y) where {F} = fn(X, Y) rmap(fn::F, X::Tuple{}, Y::Tuple{}) where {F} = () rmap(fn::F, X::Tuple{}, Y::Tuple) where {F} = () rmap(fn::F, X::Tuple, Y::Tuple{}) where {F} = () -rmap(fn::F, X::Tuple, Y::Tuple) where {F} = ( - rmap(fn, first(X), first(Y)), - rmap(fn, Base.tail(X), Base.tail(Y))..., - ) +rmap(fn::F, X::Tuple, Y::Tuple) where {F} = + (rmap(fn, first(X), first(Y)), rmap(fn, Base.tail(X), Base.tail(Y))...) rmap(fn::F, X::NamedTuple{names}, Y::NamedTuple{names}) where {F, names} = NamedTuple{names}(rmap(fn, Tuple(X), Tuple(Y))) @@ -69,7 +70,8 @@ Recursively apply `fn` to each type parameter of the type `T`, or to each type parameter of the types `T1` and `T2`, where `fn` returns a type. """ rmaptype(fn::F, ::Type{T}) where {F, T} = fn(T) -rmaptype(fn::F, ::Type{T}) where {F, T <: Tuple} = Tuple{rmaptype_Tuple(fn, T)...} +rmaptype(fn::F, ::Type{T}) where {F, T <: Tuple} = + Tuple{rmaptype_Tuple(fn, T)...} rmaptype(fn::F, ::Type{T}) where {F, names, Tup, T <: NamedTuple{names, Tup}} = NamedTuple{names, rmaptype(fn, Tup)} From ba5c5567d3be389bed1969adbe708258225cb12f Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Fri, 16 Jun 2023 13:40:37 -0700 Subject: [PATCH 8/9] Fix dss inference --- src/Spaces/dss.jl | 7 +++---- test/Operators/spectralelement/benchmark_ops.jl | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/Spaces/dss.jl b/src/Spaces/dss.jl index abb955dabe..b5df5db645 100644 --- a/src/Spaces/dss.jl +++ b/src/Spaces/dss.jl @@ -836,7 +836,7 @@ function dss_local_vertices!( sum_data = mapreduce( ⊞, vertex; - init = RecursiveApply.rmap(zero, slab(perimeter_data, 1, 1)[1]), + init = RecursiveApply.rzero(eltype(slab(perimeter_data, 1, 1))), ) do (lidx, vert) ip = Topologies.perimeter_vertex_node_index(vert) perimeter_slab = slab(perimeter_data, level, lidx) @@ -906,9 +906,8 @@ function dss_local_ghost!( sum_data = mapreduce( ⊞, vertex; - init = RecursiveApply.rmap( - zero, - slab(perimeter_data, 1, 1)[1], + init = RecursiveApply.rzero( + eltype(slab(perimeter_data, 1, 1)), ), ) do (isghost, idx, vert) ip = Topologies.perimeter_vertex_node_index(vert) diff --git a/test/Operators/spectralelement/benchmark_ops.jl b/test/Operators/spectralelement/benchmark_ops.jl index 2dd94ab424..f11469dd22 100644 --- a/test/Operators/spectralelement/benchmark_ops.jl +++ b/test/Operators/spectralelement/benchmark_ops.jl @@ -138,7 +138,7 @@ using JET p = @allocated kernel_complicated_field_dss!(kernel_args) @test p == 0 p = @allocated kernel_complicated_field2_dss!(kernel_args) - @test_broken p == 0 + @test p == 0 # Inference tests JET.@test_opt kernel_scalar_dss!(kernel_args) JET.@test_opt kernel_vector_dss!(kernel_args) @@ -146,6 +146,6 @@ using JET JET.@test_opt kernel_ntuple_field_dss!(kernel_args) JET.@test_opt kernel_ntuple_floats_dss!(kernel_args) JET.@test_opt kernel_complicated_field_dss!(kernel_args) - # JET.@test_opt kernel_complicated_field2_dss!(kernel_args) # fails + JET.@test_opt kernel_complicated_field2_dss!(kernel_args) end end From 5c8bd0e472322335cd6be354f200a8206f30d948 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Fri, 16 Jun 2023 14:11:08 -0700 Subject: [PATCH 9/9] Fixes --- src/RecursiveApply/RecursiveApply.jl | 3 ++- test/RecursiveApply/recursive_apply.jl | 11 +++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/RecursiveApply/RecursiveApply.jl b/src/RecursiveApply/RecursiveApply.jl index e42aec00ca..2309cfa0ce 100755 --- a/src/RecursiveApply/RecursiveApply.jl +++ b/src/RecursiveApply/RecursiveApply.jl @@ -25,6 +25,7 @@ rmaptype_Tuple(fn::F, ::Type{T}) where {F, E, T <: Tuple{E}} = rmaptype_Tuple(fn::F, ::Type{T}) where {F, T <: Tuple} = (rmaptype(fn, first_param(T)), rmaptype_Tuple(fn, tail_params(T))...) +rmaptype_Tuple(_, ::Type{Tuple{}}, ::Type{Tuple{}}) = () rmaptype_Tuple(_, ::Type{Tuple{}}, ::Type{T}) where {T <: Tuple} = () rmaptype_Tuple(_, ::Type{T}, ::Type{Tuple{}}) where {T <: Tuple} = () rmaptype_Tuple( @@ -33,7 +34,7 @@ rmaptype_Tuple( ::Type{T2}, ) where {F, T1 <: Tuple, T2 <: Tuple} = ( rmaptype(fn, first_param(T1), first_param(T2)), - rmaptype_Tuple(rmaptype, fn, tail_params(T1), tail_params(T2))..., + rmaptype_Tuple(fn, tail_params(T1), tail_params(T2))..., ) """ diff --git a/test/RecursiveApply/recursive_apply.jl b/test/RecursiveApply/recursive_apply.jl index cf9966505c..e77c824c73 100644 --- a/test/RecursiveApply/recursive_apply.jl +++ b/test/RecursiveApply/recursive_apply.jl @@ -70,10 +70,17 @@ end rz = RecursiveApply.rmap(RecursiveApply.rzero, nt) @test typeof(rz) == nt @inferred RecursiveApply.rmap(RecursiveApply.rzero, nt) - end - for nt in nested_types + + rz = RecursiveApply.rmap((x, y) -> RecursiveApply.rzero(x), nt, nt) + @test typeof(rz) == nt + @inferred RecursiveApply.rmap((x, y) -> RecursiveApply.rzero(x), nt, nt) + rz = RecursiveApply.rmaptype(identity, nt) @test rz == nt @inferred RecursiveApply.rmaptype(zero, nt) + + rz = RecursiveApply.rmaptype((x, y) -> identity(x), nt, nt) + @test rz == nt + @inferred RecursiveApply.rmaptype((x, y) -> zero(x), nt, nt) end end