Skip to content

Commit

Permalink
Merge pull request #2072 from CliMA/ck/fv_unit
Browse files Browse the repository at this point in the history
Add fieldvector unit tests, update `rcompare`
  • Loading branch information
charleskawczynski authored Nov 5, 2024
2 parents 9a05ed2 + 46edf40 commit 36837c2
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 13 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ ClimaCore.jl Release Notes
main
-------

- A `strict = true` keyword was added to `rcompare`, which checks that the types match. If `strict = false`, then `rcompare` will return `true` for `FieldVector`s and `NamedTuple`s with the same properties but permuted order. For example:
- `rcompare((;a=1,b=2), (;b=2,a=1); strict = true)` will return `false` and
- `rcompare((;a=1,b=2), (;b=2,a=1); strict = false)` will return `true`
- We've added new datalayouts: `VIJHF`,`IJHF`,`IHF`,`VIHF`, to explore their performance compared to our existing datalayouts: `VIJFH`,`IJFH`,`IFH`,`VIFH`. PR [#2055](https://github.com/CliMA/ClimaCore.jl/pull/2053), PR [#2052](https://github.com/CliMA/ClimaCore.jl/pull/2055).
- We've refactored some modules to use less internals. PR [#2053](https://github.com/CliMA/ClimaCore.jl/pull/2053), PR [#2052](https://github.com/CliMA/ClimaCore.jl/pull/2052), [#2051](https://github.com/CliMA/ClimaCore.jl/pull/2051), [#2049](https://github.com/CliMA/ClimaCore.jl/pull/2049).
- Some work was done in attempt to reduce specializations and compile time. PR [#2042](https://github.com/CliMA/ClimaCore.jl/pull/2042), [#2041](https://github.com/CliMA/ClimaCore.jl/pull/2041)
Expand Down
44 changes: 31 additions & 13 deletions src/Fields/fieldvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -469,36 +469,54 @@ end


# Recursively compare contents of similar fieldvectors
_rcompare(pass, x::T, y::T) where {T <: Field} =
pass && _rcompare(pass, field_values(x), field_values(y))
_rcompare(pass, x::T, y::T) where {T <: DataLayouts.AbstractData} =
_rcompare(pass, x::T, y::T; strict) where {T <: Field} =
pass && _rcompare(pass, field_values(x), field_values(y); strict)
_rcompare(pass, x::T, y::T; strict) where {T <: DataLayouts.AbstractData} =
pass && (parent(x) == parent(y))
_rcompare(pass, x::T, y::T) where {T} = pass && (x == y)
_rcompare(pass, x::T, y::T; strict) where {T} = pass && (x == y)

function _rcompare(pass, x::T, y::T) where {T <: Union{FieldVector, NamedTuple}}
_rcompare(pass, x::NamedTuple, y::NamedTuple; strict) =
_rcompare_nt(pass, x, y; strict)
_rcompare(pass, x::FieldVector, y::FieldVector; strict) =
_rcompare_nt(pass, x, y; strict)

function _rcompare_nt(pass, x, y; strict)
length(propertynames(x)) length(propertynames(y)) && return false
if strict
typeof(x) == typeof(y) || return false
end
for pn in propertynames(x)
pass &= _rcompare(pass, getproperty(x, pn), getproperty(y, pn))
pass &= _rcompare(pass, getproperty(x, pn), getproperty(y, pn); strict)
end
return pass
end

"""
rcompare(x::T, y::T) where {T <: Union{FieldVector, NamedTuple}}
rcompare(x::T, y::T; strict = true) where {T <: Union{FieldVector, NamedTuple}}
Recursively compare given fieldvectors via `==`.
Returns `true` if `x == y` recursively.
FieldVectors with different types are considered different.
"""
rcompare(x::T, y::T) where {T <: Union{FieldVector, NamedTuple}} =
_rcompare(true, x, y)
rcompare(
x::T,
y::T;
strict = true,
) where {T <: Union{FieldVector, NamedTuple}} = _rcompare(true, x, y; strict)

rcompare(x::T, y::T) where {T <: FieldVector} = _rcompare(true, x, y)
rcompare(x::T, y::T; strict = true) where {T <: FieldVector} =
_rcompare(true, x, y; strict)

rcompare(x::T, y::T) where {T <: NamedTuple} = _rcompare(true, x, y)
rcompare(x::T, y::T; strict = true) where {T <: NamedTuple} =
_rcompare(true, x, y; strict)

# FieldVectors with different types are always different
rcompare(x::FieldVector, y::FieldVector) = false
rcompare(x::FieldVector, y::FieldVector; strict::Bool = true) =
strict ? false : _rcompare(true, x, y; strict)

rcompare(x::NamedTuple, y::NamedTuple; strict::Bool = true) =
strict ? false : _rcompare(true, x, y; strict)

# Define == to call rcompare for two fieldvectors
Base.:(==)(x::FieldVector, y::FieldVector) = rcompare(x, y)
Base.:(==)(x::FieldVector, y::FieldVector) = rcompare(x, y; strict = true)
42 changes: 42 additions & 0 deletions test/Fields/unit_field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ ClimaComms.@import_required_backends
using OrderedCollections
using StaticArrays, IntervalSets
import ClimaCore
import ClimaCore.InputOutput
import ClimaCore.Utilities: PlusHalf
import ClimaCore.DataLayouts
import ClimaCore.DataLayouts: IJFH
Expand Down Expand Up @@ -330,6 +331,47 @@ end
@test occursin("==================== Difference found:", s)
end

@testset "Nested FieldVector broadcasting with permuted order" begin
FT = Float32
context = ClimaComms.context()
vertdomain = Domains.IntervalDomain(
Geometry.ZPoint{FT}(-3.5),
Geometry.ZPoint{FT}(0);
boundary_names = (:bottom, :top),
)
vertmesh = Meshes.IntervalMesh(vertdomain; nelems = 10)
device = ClimaComms.device()
vert_center_space = Spaces.CenterFiniteDifferenceSpace(device, vertmesh)
horzdomain = Domains.SphereDomain(FT(100))
horzmesh = Meshes.EquiangularCubedSphere(horzdomain, 1)
horztopology = Topologies.Topology2D(context, horzmesh)
quad = Spaces.Quadratures.GLL{2}()
space = Spaces.SpectralElementSpace2D(horztopology, quad)

vars1 = (; # order is different!
bucket = (; # nesting is needed!
T = Fields.Field(FT, space),
W = Fields.Field(FT, space),
)
)
vars2 = (; # order is different!
bucket = (; # nesting is needed!
W = Fields.Field(FT, space),
T = Fields.Field(FT, space),
)
)
Y1 = Fields.FieldVector(; vars1...)
Y1.bucket.T .= 280.0
Y1.bucket.W .= 0.05

Y2 = Fields.FieldVector(; vars2...)
Y2.bucket.T .= 280.0
Y2.bucket.W .= 0.05

Y1 .= Y2 # FieldVector broadcasting
@test Fields.rcompare(Y1, Y2; strict = false)
end

# https://github.com/CliMA/ClimaCore.jl/issues/1465
@testset "Diagonal FieldVector broadcast expressions" begin
FT = Float64
Expand Down

0 comments on commit 36837c2

Please sign in to comment.