From 90b3999d9cbdbc82423bf4b82f2a29b15ea0b4ce Mon Sep 17 00:00:00 2001 From: Simon Byrne Date: Tue, 20 Jun 2023 10:03:03 -0700 Subject: [PATCH 1/2] define fill on Fields, remove TU.FieldFromNamedTuple --- src/Fields/Fields.jl | 18 ++++++++++++ src/Fields/broadcast.jl | 2 -- test/Fields/field.jl | 28 +++++++++---------- test/Fields/field_opt.jl | 12 ++++---- .../spectralelement/benchmark_utils.jl | 9 ++---- test/TestUtilities/TestUtilities.jl | 5 ---- 6 files changed, 40 insertions(+), 34 deletions(-) diff --git a/src/Fields/Fields.jl b/src/Fields/Fields.jl index f0fc54c2fe..025be6a020 100644 --- a/src/Fields/Fields.jl +++ b/src/Fields/Fields.jl @@ -230,6 +230,24 @@ function Base.copyto!(dest::Field{V, M}, src::Field{V, M}) where {V, M} return dest end +""" + fill!(field::Field, value) + +Fill `field` with `value`. +""" +function Base.fill!(field::Field, value) + fill!(field_values(field), value) + return field +end +""" + fill(value, space::AbstractSpace) + +Create a new `Field` on `space` and fill it with `value`. +""" +function Base.fill(value::FT, space::AbstractSpace) where {FT} + field = Field(FT, space) + return fill!(field, value) +end """ zeros(space::AbstractSpace) diff --git a/src/Fields/broadcast.jl b/src/Fields/broadcast.jl index fa1d1753ad..14cb646eb4 100644 --- a/src/Fields/broadcast.jl +++ b/src/Fields/broadcast.jl @@ -442,5 +442,3 @@ function Base.Broadcast.copyto!(field::Field, nt::NamedTuple) ), ) end - -Base.fill!(field::Fields.Field, val) = field .= val diff --git a/test/Fields/field.jl b/test/Fields/field.jl index a18ccb24dd..bbd8ae1636 100644 --- a/test/Fields/field.jl +++ b/test/Fields/field.jl @@ -98,7 +98,7 @@ end @testset "Constructing & broadcasting over empty fields" begin FT = Float32 for space in TU.all_spaces(FT) - f = TU.FieldFromNamedTuple(space, (;)) + f = fill((;), space) @. f += f end @@ -112,7 +112,7 @@ end @test_broken true end end - empty_field(space) = TU.FieldFromNamedTuple(space, (;)) + empty_field(space) = fill((;), space) # Broadcasting over the wrong size should error test_broken_throws(empty_field(TU.PointSpace(FT))) @@ -294,7 +294,7 @@ end FT = Float64 nt = (; x = FT(0), y = FT(0), tup = ntuple(i -> (; a = FT(1), b = FT(1)), 2)) - Y = TU.FieldFromNamedTuple(space, nt) + Y = fill(nt, space) prop_chains = Fields.property_chains(Y) @test prop_chains[1] == (:x,) @@ -321,7 +321,7 @@ end ClimaCore.Fields.truncate_printing_field_types() = true @testset "Truncated printing" begin nt = (; x = Float64(0), y = Float64(0)) - Y = TU.FieldFromNamedTuple(spectral_space_2D(), nt) + Y = fill(nt, spectral_space_2D()) @test sprint(show, typeof(Y); context = IOContext(stdout)) == "Field{(:x, :y)} (trunc disp)" end @@ -329,7 +329,7 @@ ClimaCore.Fields.truncate_printing_field_types() = false @testset "Standard printing" begin nt = (; x = Float64(0), y = Float64(0)) - Y = TU.FieldFromNamedTuple(spectral_space_2D(), nt) + Y = fill(nt, spectral_space_2D()) s = sprint(show, typeof(Y)) # just make sure this doesn't break end @@ -337,7 +337,7 @@ end space = spectral_space_2D() FT = Float64 nt = (; x = FT(0), y = FT(0)) - Y = TU.FieldFromNamedTuple(space, nt) + Y = fill(nt, space) foo(local_geom) = sin(local_geom.coordinates.x * local_geom.coordinates.y) + 3 Fields.set!(foo, Y.x) @@ -373,7 +373,7 @@ end FT = Float64 for space in TU.all_spaces(FT) TU.levelable(space) || continue - Y = TU.FieldFromNamedTuple(space, (; x = FT(2))) + Y = fill((; x = FT(2)), space) lg_space = Spaces.level(space, TU.fc_index(1, space)) lg_field_space = axes(Fields.level(Y, TU.fc_index(1, space))) @test all( @@ -388,14 +388,14 @@ end FT = Float64 for space in TU.all_spaces(FT) if space isa Spaces.SpectralElementSpace1D - Y = TU.FieldFromNamedTuple(space, (; x = FT(1))) + Y = fill((; x = FT(1)), space) point_space_from_field = axes(Fields.column(Y.x, 1, 1)) point_space = Spaces.column(space, 1, 1) @test Fields.ones(point_space) == Fields.ones(point_space_from_field) end if space isa Spaces.SpectralElementSpace2D - Y = TU.FieldFromNamedTuple(space, (; x = FT(1))) + Y = fill((; x = FT(1)), space) point_space_from_field = axes(Fields.column(Y.x, 1, 1, 1)) point_space = Spaces.column(space, 1, 1, 1) @test Fields.ones(point_space) == @@ -424,7 +424,7 @@ end for space in TU.all_spaces(FT) # Filter out spaces without z coordinates: TU.has_z_coordinates(space) || continue - Y = TU.FieldFromNamedTuple(space, (; x = FT(1))) + Y = fill((; x = FT(1)), space) ᶜz_surf = Spaces.level(Fields.coordinate_field(Y).z, TU.fc_index(1, space)) ᶜx_surf = copy(Spaces.level(Y.x, TU.fc_index(1, space))) @@ -434,7 +434,7 @@ end # Skip spaces incompatible with Fields.bycolumn: TU.bycolumnable(space) || continue - Yc = TU.FieldFromNamedTuple(space, (; x = FT(1))) + Yc = fill((; x = FT(1)), space) column_surface_bc!(Yc.x, ᶜz_surf, ᶜx_surf) @test Y.x == Yc.x nothing @@ -491,7 +491,7 @@ Base.broadcastable(x::InferenceFoo) = Ref(x) foo = InferenceFoo(2.0) for space in TU.all_spaces(FT) - Y = TU.FieldFromNamedTuple(space, (; a = FT(0), b = FT(1))) + Y = fill((; a = FT(0), b = FT(1)), space) @test_throws ErrorException("type InferenceFoo has no field bingo") FieldFromNamedTupleBroken( space, ics_foo, @@ -632,7 +632,7 @@ convergence_rate(err, Δh) = # Skip spaces incompatible with Fields.bycolumn: TU.bycolumnable(space) || continue - Y = TU.FieldFromNamedTuple(space, (; y = FT(1))) + Y = fill((; y = FT(1)), space) zcf = Fields.coordinate_field(Y.y).z Δz = Fields.Δz_field(axes(zcf)) Δz_col = Δz[Fields.ColumnIndex((1, 1), 1)] @@ -675,7 +675,7 @@ ClimaCore.enable_threading() = false # launching threads allocates for space in TU.all_spaces(FT) # Filter out spaces without z coordinates: TU.has_z_coordinates(space) || continue - Y = TU.FieldFromNamedTuple(space, (; y = FT(1))) + Y = fill((; y = FT(1)), space) zcf = Fields.coordinate_field(Y.y).z ∫y = Spaces.level(similar(Y.y), TU.fc_index(1, space)) ∫y .= 0 diff --git a/test/Fields/field_opt.jl b/test/Fields/field_opt.jl index e23cccea6a..46f80cc1d6 100644 --- a/test/Fields/field_opt.jl +++ b/test/Fields/field_opt.jl @@ -32,7 +32,7 @@ import .TestUtilities as TU end for space in TU.all_spaces(FT) TU.bycolumnable(space) || continue - Y = TU.FieldFromNamedTuple(space, (; x = FT(2))) + Y = fill((; x = FT(2)), space) # Plain broadcast Yx = Y.x @@ -56,11 +56,11 @@ end nothing end function callfill!(Y) - fill!(Y, ((; x = 2.0),)) + fill!(Y, (; x = 2.0)) nothing end for space in TU.all_spaces(FT) - Y = TU.FieldFromNamedTuple(space, (; x = FT(2))) + Y = fill((; x = FT(2)), space) allocs_test!(Y) p = @allocated allocs_test!(Y) @test p == 0 @@ -127,7 +127,7 @@ end @testset "Allocations StencilCoefs broadcasting" begin FT = Float64 for space in TU.all_spaces(FT) - Y = TU.FieldFromNamedTuple(space, (; x = sc(FT))) + Y = fill((; x = sc(FT)), space) allocs_test1!(Y) p = @allocated allocs_test1!(Y) @test p == 0 @@ -206,8 +206,8 @@ end FT = Float64 for space in TU.all_spaces(FT) Y = Fields.FieldVector(; - c = TU.FieldFromNamedTuple(space, (; x = FT(0))), - f = TU.FieldFromNamedTuple(space, (; x = FT(0))), + c = fill((; x = FT(0)), space), + f = fill((; x = FT(0)), space), ) Y .= 0 # compile first diff --git a/test/Operators/spectralelement/benchmark_utils.jl b/test/Operators/spectralelement/benchmark_utils.jl index 58fa5df090..0d68d24ab9 100644 --- a/test/Operators/spectralelement/benchmark_utils.jl +++ b/test/Operators/spectralelement/benchmark_utils.jl @@ -236,16 +236,11 @@ function setup_kernel_args(ARGS::Vector{String} = ARGS) end, ) - function FieldFromNamedTuple(space, nt::NamedTuple) - cmv(z) = nt - return cmv.(Fields.coordinate_field(space)) - end - ϕψ = combine.(ϕ, ψ) nt_ϕψ = combine_nt.(ϕ, ψ) nt_ϕψ_ft = combine_nt_ft.(ϕ) - f_comp = FieldFromNamedTuple(space, complicated_field(FT)) - f_comp2 = FieldFromNamedTuple(space, complicated_field2(FT)) + f_comp = fill(complicated_field(FT), space) + f_comp2 = fill(complicated_field2(FT), space) u = initial_velocity(space) du = initial_velocity(space) ϕ_buffer = Spaces.create_dss_buffer(ϕ) diff --git a/test/TestUtilities/TestUtilities.jl b/test/TestUtilities/TestUtilities.jl index c1533078c8..200dfc75bb 100644 --- a/test/TestUtilities/TestUtilities.jl +++ b/test/TestUtilities/TestUtilities.jl @@ -192,9 +192,4 @@ fc_index( has_z_coordinates(space) = :z in propertynames(Spaces.coordinates_data(space)) -function FieldFromNamedTuple(space, nt::NamedTuple) - cmv(z) = nt - return cmv.(Fields.coordinate_field(space)) -end - end From d4fe3c1ea86ad3ddfd4eeb2bc8e92e91e6493630 Mon Sep 17 00:00:00 2001 From: Simon Byrne Date: Thu, 22 Jun 2023 10:55:36 -0700 Subject: [PATCH 2/2] define fill for CuArray-backed DataLayouts --- src/DataLayouts/cuda.jl | 43 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/src/DataLayouts/cuda.jl b/src/DataLayouts/cuda.jl index bc57000a74..367208759e 100644 --- a/src/DataLayouts/cuda.jl +++ b/src/DataLayouts/cuda.jl @@ -71,6 +71,19 @@ function knl_copyto!(dest, src) return nothing end +function knl_fill!(dest, val) + i = CUDA.threadIdx().x + j = CUDA.threadIdx().y + + h = CUDA.blockIdx().x + v = CUDA.blockIdx().y + + I = CartesianIndex((i, j, 1, v, h)) + + @inbounds dest[I] = val + return nothing +end + function Base.copyto!( dest::IJFH{S, Nij}, bc::Union{IJFH{S, Nij, A}, Base.Broadcast.Broadcasted{IJFHStyle{Nij, A}}}, @@ -81,6 +94,18 @@ function Base.copyto!( end return dest end +function Base.fill!( + dest::IJFH{S, Nij, A}, + val, +) where {S, Nij, A <: CUDA.CuArray} + _, _, _, _, Nh = size(bc) + if Nh > 0 + CUDA.@cuda threads = (Nij, Nij) blocks = (Nh, 1) knl_fill!(dest, val) + end + return dest +end + + function Base.copyto!( dest::VIJFH{S, Nij}, @@ -92,6 +117,17 @@ function Base.copyto!( end return dest end +function Base.fill!( + dest::VIJFH{S, Nij, A}, + val, +) where {S, Nij, A <: CUDA.CuArray} + _, _, _, Nv, Nh = size(bc) + if Nv > 0 && Nh > 0 + CUDA.@cuda threads = (Nij, Nij) blocks = (Nh, Nv) knl_fill!(dest, val) + end + return dest +end + function Base.copyto!( dest::VF{S}, @@ -103,3 +139,10 @@ function Base.copyto!( end return dest end +function Base.fill!(dest::VF{S, A}, val) where {S, A <: CUDA.CuArray} + _, _, _, Nv, Nh = size(bc) + if Nv > 0 && Nh > 0 + CUDA.@cuda threads = (1, 1) blocks = (Nh, Nv) knl_fill!(dest, val) + end + return dest +end