Skip to content

Commit

Permalink
Merge branch 'main' into ck/julia19_again
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jun 18, 2023
2 parents 4f5c5c5 + 3bf1ecb commit c511fd6
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 22 deletions.
4 changes: 2 additions & 2 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -476,8 +476,8 @@ steps:
key: "gpu_implicit_stencil_float32"
command: "julia -O0 --color=yes --check-bounds=yes --project=test test/Operators/finitedifference/implicit_stencils.jl --float_type Float32"
agents:
slurm_time: 2:00:00
gpus: 1
slurm_mem: 20GB
slurm_gpus: 1

- group: "Unit: Hypsography"
steps:
Expand Down
79 changes: 70 additions & 9 deletions src/RecursiveApply/RecursiveApply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,52 @@ module RecursiveApply

export , ,

# These functions need to be generated for type stability (since T.parameters is
# a SimpleVector, the compiler cannot always infer its size and elements).
@generated first_param(::Type{T}) where {T} = :($(first(T.parameters)))
@generated tail_params(::Type{T}) where {T} =
:($(Tuple{Base.tail((T.parameters...,))...}))

# Applying `rmaptype` returns `Tuple{...}` for tuple
# types, which cannot follow the recursion pattern as
# it cannot be splatted, so we add a separate method,
# `rmaptype_Tuple`, for the part of the recursion.
rmaptype_Tuple(fn::F, ::Type{Tuple{}}) where {F} = ()
rmaptype_Tuple(fn::F, ::Type{T}) where {F, E, T <: Tuple{E}} =
(rmaptype(fn, first_param(T)),)
rmaptype_Tuple(fn::F, ::Type{T}) where {F, T <: Tuple} =
(rmaptype(fn, first_param(T)), rmaptype_Tuple(fn, tail_params(T))...)

rmaptype_Tuple(_, ::Type{Tuple{}}, ::Type{Tuple{}}) = ()
rmaptype_Tuple(_, ::Type{Tuple{}}, ::Type{T}) where {T <: Tuple} = ()
rmaptype_Tuple(_, ::Type{T}, ::Type{Tuple{}}) where {T <: Tuple} = ()
rmaptype_Tuple(
fn::F,
::Type{T1},
::Type{T2},
) where {F, T1 <: Tuple, T2 <: Tuple} = (
rmaptype(fn, first_param(T1), first_param(T2)),
rmaptype_Tuple(fn, tail_params(T1), tail_params(T2))...,
)

"""
rmap(fn, X...)
Recursively apply `fn` to each element of `X`
"""
rmap(fn::F, X) where {F} = fn(X)
rmap(fn::F, X::Tuple{}) where {F} = ()
rmap(fn::F, X::Tuple) where {F} =
(rmap(fn, first(X)), rmap(fn, Base.tail(X))...)
rmap(fn::F, X::NamedTuple{names}) where {F, names} =
NamedTuple{names}(rmap(fn, Tuple(X)))

rmap(fn::F, X, Y) where {F} = fn(X, Y)
rmap(fn::F, X::Tuple) where {F} = map(x -> rmap(fn, x), X)
rmap(fn, X::Tuple{}, Y::Tuple{}) = ()
rmap(fn::F, X::Tuple{}, Y::Tuple{}) where {F} = ()
rmap(fn::F, X::Tuple{}, Y::Tuple) where {F} = ()
rmap(fn::F, X::Tuple, Y::Tuple{}) where {F} = ()
rmap(fn::F, X::Tuple, Y::Tuple) where {F} =
(rmap(fn, first(X), first(Y)), rmap(fn, Base.tail(X), Base.tail(Y))...)
rmap(fn::F, X::NamedTuple{names}) where {F, names} =
NamedTuple{names}(rmap(fn, Tuple(X)))
rmap(fn::F, X::NamedTuple{names}, Y::NamedTuple{names}) where {F, names} =
NamedTuple{names}(rmap(fn, Tuple(X), Tuple(Y)))

Expand All @@ -32,17 +65,45 @@ rmax(X, Y) = rmap(max, X, Y)

"""
rmaptype(fn, T)
rmaptype(fn, T1, T2)
The return type of `rmap(fn, X::T)`.
Recursively apply `fn` to each type parameter of the type `T`, or to each type
parameter of the types `T1` and `T2`, where `fn` returns a type.
"""
rmaptype(fn::F, ::Type{T}) where {F, T} = fn(T)
rmaptype(fn::F, ::Type{T}) where {F, T <: Tuple} =
Tuple{map(fn, tuple(T.parameters...))...}
Tuple{rmaptype_Tuple(fn, T)...}
rmaptype(fn::F, ::Type{T}) where {F, names, Tup, T <: NamedTuple{names, Tup}} =
NamedTuple{names, rmaptype(fn, Tup)}

rmaptype(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1, T2} = fn(T1, T2)
rmaptype(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1 <: Tuple, T2 <: Tuple} =
Tuple{rmaptype_Tuple(fn, T1, T2)...}
rmaptype(
fn::F,
::Type{T},
) where {F, T <: NamedTuple{names, tup}} where {names, tup} =
NamedTuple{names, rmaptype(fn, tup)}
::Type{T1},
::Type{T2},
) where {
F,
names,
Tup1,
Tup2,
T1 <: NamedTuple{names, Tup1},
T2 <: NamedTuple{names, Tup2},
} = NamedTuple{names, rmaptype(fn, Tup1, Tup2)}

"""
rzero(T)
Recursively compute the zero value of type `T`.
"""
rzero(::Type{T}) where {T} = zero(T)
rzero(::Type{Tuple{}}) = ()
rzero(::Type{T}) where {E, T <: Tuple{E}} = (rzero(E),)
rzero(::Type{T}) where {T <: Tuple} =
(rzero(first_param(T)), rzero(tail_params(T))...)
rzero(::Type{Tup}) where {names, T, Tup <: NamedTuple{names, T}} =
NamedTuple{names}(rzero(T))

"""
rmul(X, Y)
Expand Down
7 changes: 3 additions & 4 deletions src/Spaces/dss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,7 @@ function dss_local_vertices!(
sum_data = mapreduce(
,
vertex;
init = RecursiveApply.rmap(zero, slab(perimeter_data, 1, 1)[1]),
init = RecursiveApply.rzero(eltype(slab(perimeter_data, 1, 1))),
) do (lidx, vert)
ip = Topologies.perimeter_vertex_node_index(vert)
perimeter_slab = slab(perimeter_data, level, lidx)
Expand Down Expand Up @@ -906,9 +906,8 @@ function dss_local_ghost!(
sum_data = mapreduce(
,
vertex;
init = RecursiveApply.rmap(
zero,
slab(perimeter_data, 1, 1)[1],
init = RecursiveApply.rzero(
eltype(slab(perimeter_data, 1, 1)),
),
) do (isghost, idx, vert)
ip = Topologies.perimeter_vertex_node_index(vert)
Expand Down
6 changes: 5 additions & 1 deletion test/Operators/finitedifference/implicit_stencils_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ function get_ref_value(op1, op2, a0, a1, a2)
return ref_value
end

function test_pointwise_stencils_compose(all_ops)
function test_pointwise_stencils_compose(all_ops; full = false)
(;
ops_F2C_S2S,
ops_C2F_S2S,
Expand All @@ -285,6 +285,10 @@ function test_pointwise_stencils_compose(all_ops)
(a_FS, a_FV, a_CV, ops_F2C_V2S, (ops_C2F_V2V..., ops_C2F_V2S...)),
(a_CS, a_CV, a_FV, ops_C2F_V2S, (ops_F2C_V2V..., ops_F2C_V2S...)),
)
if !full
op1s = op1s[1:2]
op2s = op2s[1:2]
end
for op1 in op1s
for op2 in op2s
GC.gc()
Expand Down
4 changes: 2 additions & 2 deletions test/Operators/spectralelement/benchmark_ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,14 +138,14 @@ using JET
p = @allocated kernel_complicated_field_dss!(kernel_args)
@test p == 0
p = @allocated kernel_complicated_field2_dss!(kernel_args)
@test_broken p == 0
@test p == 0
# Inference tests
JET.@test_opt kernel_scalar_dss!(kernel_args)
JET.@test_opt kernel_vector_dss!(kernel_args)
JET.@test_opt kernel_field_dss!(kernel_args)
JET.@test_opt kernel_ntuple_field_dss!(kernel_args)
JET.@test_opt kernel_ntuple_floats_dss!(kernel_args)
JET.@test_opt kernel_complicated_field_dss!(kernel_args)
# JET.@test_opt kernel_complicated_field2_dss!(kernel_args) # fails
JET.@test_opt kernel_complicated_field2_dss!(kernel_args)
end
end
46 changes: 46 additions & 0 deletions test/RecursiveApply/recursive_apply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,49 @@ end
RecursiveApply.rmul(x, FT(2))
end
end

@testset "Highly nested types" begin
FT = Float64
nested_types = [
FT,
Tuple{FT, FT},
NamedTuple{(, ), Tuple{FT, FT}},
Tuple{
NamedTuple{(, ), Tuple{FT, FT}},
NamedTuple{(, ), Tuple{FT, FT}},
},
Tuple{FT, FT},
NamedTuple{
(, :uₕ, :ρe_tot, :ρq_tot, :sgs⁰, :sgsʲs),
Tuple{
FT,
Tuple{FT, FT},
FT,
FT,
NamedTuple{(:ρatke,), Tuple{FT}},
Tuple{NamedTuple{(:ρa, :ρae_tot, :ρaq_tot), Tuple{FT, FT, FT}}},
},
},
NamedTuple{
(:u₃, :sgsʲs),
Tuple{Tuple{FT}, Tuple{NamedTuple{(:u₃,), Tuple{Tuple{FT}}}}},
},
]
for nt in nested_types
rz = RecursiveApply.rmap(RecursiveApply.rzero, nt)
@test typeof(rz) == nt
@inferred RecursiveApply.rmap(RecursiveApply.rzero, nt)

rz = RecursiveApply.rmap((x, y) -> RecursiveApply.rzero(x), nt, nt)
@test typeof(rz) == nt
@inferred RecursiveApply.rmap((x, y) -> RecursiveApply.rzero(x), nt, nt)

rz = RecursiveApply.rmaptype(identity, nt)
@test rz == nt
@inferred RecursiveApply.rmaptype(zero, nt)

rz = RecursiveApply.rmaptype((x, y) -> identity(x), nt, nt)
@test rz == nt
@inferred RecursiveApply.rmaptype((x, y) -> zero(x), nt, nt)
end
end
8 changes: 4 additions & 4 deletions test/aqua.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ using Aqua
# If the number of ambiguities is less than the limit below,
# then please lower the limit based on the new number of ambiguities.
# We're trying to drive this number down to zero to reduce latency.
@test length(ambs) 15
# Uncomment for debugging:
# for method_ambiguity in ambs
# @show method_ambiguity
# end
for method_ambiguity in ambs
@show method_ambiguity
end
@test length(ambs) 16
end

@testset "Aqua tests (additional)" begin
Expand Down

0 comments on commit c511fd6

Please sign in to comment.