Skip to content

Commit

Permalink
extend dss functions for FieldVectors
Browse files Browse the repository at this point in the history
  • Loading branch information
juliasloan25 committed Sep 24, 2024
1 parent 2c12dd3 commit 9075268
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 1 deletion.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ ClimaCore.jl Release Notes
main
-------

- Extended `create_dss_buffer` and `weighted_dss!` for `FieldVector`s, rather than
just `Field`s. PR [#2000](https://github.com/CliMA/ClimaCore.jl/pull/2000).

v0.14.16
-------

Expand Down
2 changes: 1 addition & 1 deletion src/Fields/Fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ function interpcoord(elemrange, x::Real)
end

"""
Spaces.weighted_dss!(f::Field[, ghost_buffer = Spaces.create_dss_buffer(field)])
Spaces.weighted_dss!(f::Field, dss_buffer = Spaces.create_dss_buffer(field))
Apply weighted direct stiffness summation (DSS) to `f`. This operates in-place
(i.e. it modifies the `f`). `ghost_buffer` contains the necessary information
Expand Down
35 changes: 35 additions & 0 deletions src/Fields/fieldvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,41 @@ end
return dest
end

"""
Spaces.create_dss_buffer(fv::FieldVector)
Create a NamedTuple of buffers for communicating neighbour information of
each Field in `fv`. In this NamedTuple, the name of each field is mapped
to the buffer.
"""
function Spaces.create_dss_buffer(fv::FieldVector)
NamedTuple{propertynames(fv)}(
map(
key -> Spaces.create_dss_buffer(getproperty(fv, key)),
propertynames(fv),
),
)
end

"""
Spaces.weighted_dss!(fv::FieldVector, dss_buffer = Spaces.create_dss_buffer(fv))
Apply weighted direct stiffness summation (DSS) to each field in `fv`.
If a `dss_buffer` object is not provided, a buffer will be created for each
field in `fv`.
Note that using the `Pair` interface here parallelizes the `weighted_dss!` calls.
"""
function Spaces.weighted_dss!(
fv::FieldVector,
dss_buffer = Spaces.create_dss_buffer(fv),
)
pairs = map(propertynames(fv)) do key
Pair(getproperty(fv, key), getproperty(dss_buffer, key))
end
Spaces.weighted_dss!(pairs...)
end


# Recursively call transform_bc_args() on broadcast arguments in a way that is statically reducible by the optimizer
# see Base.Broadcast.preprocess_args
@inline transform_bc_args(args::Tuple, inds...) = (
Expand Down
45 changes: 45 additions & 0 deletions test/Fields/field_opt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -393,4 +393,49 @@ using JET
@test_opt ifelsekernel!(S, ρ)
end

@testset "dss of FieldVectors" begin
function field_vec(center_space, face_space)
Y = Fields.FieldVector(
c = map(Fields.coordinate_field(center_space)) do coord
FT = Spaces.undertype(center_space)
(;
ρ = FT(coord.lat + coord.long),
uₕ = Geometry.Covariant12Vector(
FT(coord.lat),
FT(coord.long),
),
)
end,
f = map(Fields.coordinate_field(face_space)) do coord
FT = Spaces.undertype(face_space)
(; w = Geometry.Covariant3Vector(FT(coord.lat + coord.long)))
end,
)
return Y
end

fv = field_vec(toy_sphere(Float64)...)

c_copy = copy(getproperty(fv, :c))
f_copy = copy(getproperty(fv, :f))

# Test that dss_buffer is created and has the correct keys
dss_buffer = Spaces.create_dss_buffer(fv)
@test haskey(dss_buffer, :c)
@test haskey(dss_buffer, :f)

# Test weighted_dss! with and without preallocated buffer
Spaces.weighted_dss!(fv, dss_buffer)
@test getproperty(fv, :c) Spaces.weighted_dss!(c_copy)
@test getproperty(fv, :f) Spaces.weighted_dss!(f_copy)

fv = field_vec(toy_sphere(Float64)...)
c_copy = copy(getproperty(fv, :c))
f_copy = copy(getproperty(fv, :f))

Spaces.weighted_dss!(fv)
@test getproperty(fv, :c) Spaces.weighted_dss!(c_copy)
@test getproperty(fv, :f) Spaces.weighted_dss!(f_copy)
end

nothing

0 comments on commit 9075268

Please sign in to comment.