diff --git a/Project.toml b/Project.toml index e4c98c7b93..8a1ccbbda1 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.10.40" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" @@ -31,6 +32,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" [compat] Adapt = "3" +BandedMatrices = "0.17" BlockArrays = "0.16" CUDA = "3, 4.2.0" ClimaComms = "0.4.2" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 990879d0c0..00021ea2b8 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -98,6 +98,12 @@ git-tree-sha1 = "dbf84058d0a8cbbadee18d25cf606934b22d7c66" uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" version = "0.4.2" +[[deps.BandedMatrices]] +deps = ["ArrayLayouts", "FillArrays", "LinearAlgebra", "PrecompileTools", "SparseArrays"] +git-tree-sha1 = "9ad46355045491b12eab409dee73e9de46293aa2" +uuid = "aae01518-5342-5314-be14-df237901396f" +version = "0.17.28" + [[deps.Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" @@ -216,7 +222,7 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" version = "0.4.2" [[deps.ClimaCore]] -deps = ["Adapt", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DiffEqBase", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "Rotations", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"] +deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DiffEqBase", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "Rotations", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"] path = ".." uuid = "d414da3d-4745-48bb-8d80-42e94e092884" version = "0.10.39" diff --git a/src/ClimaCore.jl b/src/ClimaCore.jl index b37bd8f397..72ada48e34 100644 --- a/src/ClimaCore.jl +++ b/src/ClimaCore.jl @@ -14,6 +14,7 @@ include("Topologies/Topologies.jl") include("Spaces/Spaces.jl") include("Fields/Fields.jl") include("Operators/Operators.jl") +include("MatrixFields/MatrixFields.jl") include("Hypsography/Hypsography.jl") include("Limiters/Limiters.jl") include("InputOutput/InputOutput.jl") diff --git a/src/Fields/mapreduce.jl b/src/Fields/mapreduce.jl index 1b9b058399..4ccf03dd49 100644 --- a/src/Fields/mapreduce.jl +++ b/src/Fields/mapreduce.jl @@ -1,4 +1,4 @@ -Base.map(fn, field::Field) = Base.broadcast(fn, field) +Base.map(fn, fields::Field...) = Base.broadcast(fn, fields...) """ Fields.local_sum(v::Field) diff --git a/src/Geometry/axistensors.jl b/src/Geometry/axistensors.jl index a66b723708..0b4b70ed9e 100644 --- a/src/Geometry/axistensors.jl +++ b/src/Geometry/axistensors.jl @@ -269,10 +269,30 @@ Base.propertynames(x::AxisVector) = symbols(axes(x, 1)) end end +const AdjointAxisTensor{T, N, A, S} = Adjoint{T, AxisTensor{T, N, A, S}} + +Base.show(io::IO, a::AdjointAxisTensor{T, N, A, S}) where {T, N, A, S} = + print(io, "adjoint($(a'))") + +components(a::AdjointAxisTensor) = components(parent(a))' + +Base.zero(a::AdjointAxisTensor) = zero(typeof(a)) +Base.zero(::Type{AdjointAxisTensor{T, N, A, S}}) where {T, N, A, S} = + zero(AxisTensor{T, N, A, S})' + +@inline +(a::AdjointAxisTensor) = (+a')' +@inline -(a::AdjointAxisTensor) = (-a')' +@inline +(a::AdjointAxisTensor, b::AdjointAxisTensor) = (a' + b')' +@inline -(a::AdjointAxisTensor, b::AdjointAxisTensor) = (a' - b')' +@inline *(a::Number, b::AdjointAxisTensor) = (a * b')' +@inline *(a::AdjointAxisTensor, b::Number) = (a' * b)' +@inline /(a::AdjointAxisTensor, b::Number) = (a' / b)' +@inline \(a::Number, b::AdjointAxisTensor) = (a \ b')' + +@inline (==)(a::AdjointAxisTensor, b::AdjointAxisTensor) = a' == b' const AdjointAxisVector{T, A1, S} = Adjoint{T, AxisVector{T, A1, S}} -components(va::AdjointAxisVector) = components(parent(va))' Base.@propagate_inbounds Base.getindex(va::AdjointAxisVector, i::Int) = getindex(components(va), i) Base.@propagate_inbounds Base.getindex(va::AdjointAxisVector, i::Int, j::Int) = @@ -286,7 +306,6 @@ Axis2Tensor( ) = AxisTensor(axes, components) const AdjointAxis2Tensor{T, A, S} = Adjoint{T, Axis2Tensor{T, A, S}} -components(va::AdjointAxis2Tensor) = components(parent(va))' const Axis2TensorOrAdj{T, A, S} = Union{Axis2Tensor{T, A, S}, AdjointAxis2Tensor{T, A, S}} diff --git a/src/MatrixFields/MatrixFields.jl b/src/MatrixFields/MatrixFields.jl new file mode 100644 index 0000000000..6ed80c32e3 --- /dev/null +++ b/src/MatrixFields/MatrixFields.jl @@ -0,0 +1,26 @@ +module MatrixFields + +import LinearAlgebra: UniformScaling, Adjoint +import StaticArrays: SArray, SMatrix, SVector +import BandedMatrices: BandedMatrix, band, _BandedMatrix +import ..Utilities: PlusHalf, half +import ..RecursiveApply: + rmap, rpromote_type, rzero, rconvert, radd, rsub, rmul, rdiv +import ..Geometry +import ..Spaces +import ..Fields +import ..Operators + +export ⋅ +export DiagonalMatrixRow, + BidiagonalMatrixRow, + TridiagonalMatrixRow, + QuaddiagonalMatrixRow, + PentadiagonalMatrixRow + +include("band_matrix_row.jl") +include("rmul_with_projection.jl") +include("matrix_multiplication.jl") +include("matrix_field_utils.jl") + +end diff --git a/src/MatrixFields/band_matrix_row.jl b/src/MatrixFields/band_matrix_row.jl new file mode 100644 index 0000000000..d637ef089c --- /dev/null +++ b/src/MatrixFields/band_matrix_row.jl @@ -0,0 +1,138 @@ +""" + BandMatrixRow{ld}(entries...) + +Stores the nonzero entries in a row of a band matrix, starting with the lowest +diagonal, which has index `ld`. Supported operations include accessing the entry +on the diagonal with index `d` by calling `row[d]`, taking linear combinations +with other band matrix rows (and with `LinearAlgebra.I`), and checking for +equality with other band matrix rows (and with `LinearAlgebra.I`). There are +several aliases defined for commonly-used subtypes of `BandMatrixRow` (with `T` +denoting the type of the row's entries): +- `DiagonalMatrixRow{T}` +- `BidiagonalMatrixRow{T}` +- `TridiagonalMatrixRow{T}` +- `QuaddiagonalMatrixRow{T}` +- `PentadiagonalMatrixRow{T}` +""" +struct BandMatrixRow{ld, bw, T} # bw is the bandwidth (the number of diagonals) + entries::NTuple{bw, T} + BandMatrixRow{ld, bw, T}(entries::NTuple{bw, Any}) where {ld, bw, T} = + new{ld, bw, T}(rconvert(NTuple{bw, T}, entries)) + # TODO: Remove this inner constructor once Julia's default convert function + # is type-stable for nested Tuple/NamedTuple types. +end +BandMatrixRow{ld}(entries::Vararg{Any, bw}) where {ld, bw} = + BandMatrixRow{ld, bw}(entries...) +BandMatrixRow{ld, bw}(entries::Vararg{Any, bw}) where {ld, bw} = + BandMatrixRow{ld, bw, rpromote_type(map(typeof, entries)...)}(entries) + +const DiagonalMatrixRow{T} = BandMatrixRow{0, 1, T} +const BidiagonalMatrixRow{T} = BandMatrixRow{-1 + half, 2, T} +const TridiagonalMatrixRow{T} = BandMatrixRow{-1, 3, T} +const QuaddiagonalMatrixRow{T} = BandMatrixRow{-2 + half, 4, T} +const PentadiagonalMatrixRow{T} = BandMatrixRow{-2, 5, T} + +""" + outer_diagonals(::Type{<:BandMatrixRow}) + +Gets the indices of the lower and upper diagonals, `ld` and `ud`, of the given +subtype of `BandMatrixRow`. +""" +outer_diagonals(::Type{<:BandMatrixRow{ld, bw}}) where {ld, bw} = + (ld, ld + bw - 1) + +""" + band_matrix_row_type(ld, ud, T) + +A shorthand for getting the subtype of `BandMatrixRow` that has entries of type +`T` on the diagonals with indices in the range `ld:ud`. +""" +band_matrix_row_type(ld, ud, T) = BandMatrixRow{ld, ud - ld + 1, T} + +Base.eltype(::Type{BandMatrixRow{ld, bw, T}}) where {ld, bw, T} = T + +Base.zero(::Type{BandMatrixRow{ld, bw, T}}) where {ld, bw, T} = + BandMatrixRow{ld}(ntuple(_ -> rzero(T), Val(bw))...) + +Base.map(f::F, rows::BandMatrixRow{ld}...) where {F, ld} = + BandMatrixRow{ld}(map(f, map(row -> row.entries, rows)...)...) + +Base.@propagate_inbounds Base.getindex(row::BandMatrixRow{ld}, d) where {ld} = + row.entries[d - ld + 1] + +function Base.promote_rule( + ::Type{BMR1}, + ::Type{BMR2}, +) where {BMR1 <: BandMatrixRow, BMR2 <: BandMatrixRow} + ld1, ud1 = outer_diagonals(BMR1) + ld2, ud2 = outer_diagonals(BMR2) + typeof(ld1) == typeof(ld2) || error( + "Cannot promote the $(ld1 isa PlusHalf ? "non-" : "")square matrix \ + row type $BMR1 and the $(ld2 isa PlusHalf ? "non-" : "")square matrix \ + row type $BMR2 to a common type", + ) + T = rpromote_type(eltype(BMR1), eltype(BMR2)) + return band_matrix_row_type(min(ld1, ld2), max(ud1, ud2), T) +end + +Base.promote_rule( + ::Type{BMR}, + ::Type{US}, +) where {BMR <: BandMatrixRow, US <: UniformScaling} = + promote_rule(BMR, DiagonalMatrixRow{eltype(US)}) + +function Base.convert( + ::Type{BMR}, + row::BandMatrixRow, +) where {BMR <: BandMatrixRow} + old_ld, old_ud = outer_diagonals(typeof(row)) + new_ld, new_ud = outer_diagonals(BMR) + typeof(old_ld) == typeof(new_ld) || error( + "Cannot convert a $(old_ld isa PlusHalf ? "non-" : "")square matrix \ + row of type $(typeof(row)) to the \ + $(new_ld isa PlusHalf ? "non-" : "")square matrix row type $BMR", + ) + new_ld <= old_ld && new_ud >= old_ud || error( + "Cannot convert a $(typeof(row)) to a $BMR, since that would require \ + dropping potentially non-zero row entries", + ) + first_zeros = ntuple(_ -> rzero(eltype(BMR)), Val(old_ld - new_ld)) + last_zeros = ntuple(_ -> rzero(eltype(BMR)), Val(new_ud - old_ud)) + return BMR((first_zeros..., row.entries..., last_zeros...)) +end + +Base.convert(::Type{BMR}, row::UniformScaling) where {BMR <: BandMatrixRow} = + convert(BMR, DiagonalMatrixRow(row.λ)) + +Base.:(==)(row1::BMR, row2::BMR) where {BMR <: BandMatrixRow} = + row1.entries == row2.entries +Base.:(==)(row1::BandMatrixRow, row2::BandMatrixRow) = + ==(promote(row1, row2)...) +Base.:(==)(row1::BandMatrixRow, row2::UniformScaling) = + ==(promote(row1, row2)...) +Base.:(==)(row1::UniformScaling, row2::BandMatrixRow) = + ==(promote(row1, row2)...) + +Base.:+(row::BandMatrixRow) = map(radd, row) +Base.:+(row1::BandMatrixRow, row2::BandMatrixRow) = + map(radd, promote(row1, row2)...) +Base.:+(row1::BandMatrixRow, row2::UniformScaling) = + map(radd, promote(row1, row2)...) +Base.:+(row1::UniformScaling, row2::BandMatrixRow) = + map(radd, promote(row1, row2)...) + +Base.:-(row::BandMatrixRow) = map(rsub, row) +Base.:-(row1::BandMatrixRow, row2::BandMatrixRow) = + map(rsub, promote(row1, row2)...) +Base.:-(row1::BandMatrixRow, row2::UniformScaling) = + map(rsub, promote(row1, row2)...) +Base.:-(row1::UniformScaling, row2::BandMatrixRow) = + map(rsub, promote(row1, row2)...) + +Base.:*(row::BandMatrixRow, value::Number) = + map(entry -> rmul(entry, value), row) +Base.:*(value::Number, row::BandMatrixRow) = + map(entry -> rmul(value, entry), row) + +Base.:/(row::BandMatrixRow, value::Number) = + map(entry -> rdiv(entry, value), row) diff --git a/src/MatrixFields/matrix_field_utils.jl b/src/MatrixFields/matrix_field_utils.jl new file mode 100644 index 0000000000..735e9b7999 --- /dev/null +++ b/src/MatrixFields/matrix_field_utils.jl @@ -0,0 +1,133 @@ +function banded_matrix_info(field) + space = axes(field) + field_ld, field_ud = outer_diagonals(eltype(field)) + + # Find the diagonal index of the value that ends up in the bottom-right + # corner of the matrix, as well as the amount by which the field's diagonal + # indices get shifted when it is converted into a matrix. + bottom_corner_matrix_d, matrix_d_minus_field_d = if field_ld isa PlusHalf + if space.staggering isa Spaces.CellCenter + 1, half # field is a face-to-center matrix + else + -1, -half # field is a center-to-face matrix + end + else + 0, 0 # field is either a center-to-center or face-to-face matrix + end + + n_rows = Spaces.nlevels(space) + n_cols = n_rows + bottom_corner_matrix_d + matrix_ld = field_ld + matrix_d_minus_field_d + matrix_ud = field_ud + matrix_d_minus_field_d + matrix_ld <= 0 && matrix_ud >= 0 || + error("BandedMatrices.jl does not yet support matrices that have \ + diagonals with indices in the range $matrix_ld:$matrix_ud") + + return n_rows, n_cols, matrix_ld, matrix_ud +end + +""" + column_field2array(field) + +Converts a field defined on a `FiniteDifferenceSpace` into a `Vector` or a +`BandedMatrix`, depending on whether or not the elements of the field are +`BandMatrixRow`s. This involves copying the data stored in the field. +""" +function column_field2array(field) + space = axes(field) + space isa Spaces.FiniteDifferenceSpace || + error("column_field2array requires a field on a FiniteDifferenceSpace") + if eltype(field) <: BandMatrixRow # field represents a matrix + n_rows, n_cols, matrix_ld, matrix_ud = banded_matrix_info(field) + matrix = BandedMatrix{eltype(eltype(field))}( + undef, + (n_rows, n_cols), + (-matrix_ld, matrix_ud), + ) + for (index_of_field_entry, matrix_d) in enumerate(matrix_ld:matrix_ud) + # Find the rows for which field_diagonal[row] is inside the matrix. + # Note: The matrix index (1, 1) corresponds to the diagonal index 0, + # and the matrix index (n_rows, n_cols) corresponds to the diagonal + # index n_cols - n_rows. + first_row = matrix_d < 0 ? 1 - matrix_d : 1 + last_row = matrix_d < n_cols - n_rows ? n_rows : n_cols - matrix_d + + # Copy the value in each row from field_diagonal to matrix_diagonal. + field_diagonal = field.entries.:($index_of_field_entry) + matrix_diagonal = view(matrix, band(matrix_d)) + for (index_along_diagonal, row) in enumerate(first_row:last_row) + matrix_diagonal[index_along_diagonal] = + Fields.field_values(field_diagonal)[row] + end + end + return matrix + else # field represents a vector + n_rows = Spaces.nlevels(space) + return map(row -> Fields.field_values(field)[row], 1:n_rows) + end +end + +""" + field2arrays(field) + +Converts a field defined on a `FiniteDifferenceSpace` or on an +`ExtrudedFiniteDifferenceSpace` into a tuple of arrays, each of which +corresponds to a column of the field. This is done by calling +`column_field2array` on each of the field's columns. +""" +function field2arrays(field) + space = axes(field) + column_indices = if space isa Spaces.FiniteDifferenceSpace + (((1, 1), 1),) + elseif space isa Spaces.ExtrudedFiniteDifferenceSpace + (Spaces.all_nodes(Spaces.horizontal_space(space))...,) + else + error("Invalid space type: $(typeof(space).name.wrapper)") + end + return map(column_indices) do ((i, j), h) + column_field2array(Spaces.column(field, i, j, h)) + end +end + +""" + column_field2array_view(field) + +Similar to `column_field2array(field)`, except that this version avoids copying +the data stored in the field. +""" +function column_field2array_view(field) + space = axes(field) + space isa Spaces.FiniteDifferenceSpace || + error("column_field2array_view requires a field on a \ + FiniteDifferenceSpace") + if eltype(field) <: BandMatrixRow # field represents a matrix + n_rows, n_cols, matrix_ld, matrix_ud = banded_matrix_info(field) + data_transpose = reinterpret(eltype(eltype(field)), parent(field)') + matrix_transpose = + _BandedMatrix(data_transpose, n_cols, matrix_ud, -matrix_ld) + return permutedims(matrix_transpose) + # TODO: Despite not copying any data, this function still allocates a + # small amount of memory because of _BandedMatrix and permutedims. + else # field represents a vector + return vec(reinterpret(eltype(field), parent(field)')) + end +end + +function Base.show( + io::IO, + field::Fields.Field{<:Fields.AbstractData{<:BandMatrixRow}}, +) + print(io, eltype(field), "-valued Field") + if eltype(eltype(field)) <: Number + if axes(field) isa Spaces.FiniteDifferenceSpace + println(io, " that corresponds to the matrix") + else + println(io, " whose first column corresponds to the matrix") + end + column_field = Fields.column(field, 1, 1, 1) + Base.print_array(io, column_field2array_view(column_field)) + else # A BandedMatrix with non-number entries currently errors when printed. + print(io, ":") + Fields._show_compact_field(io, field, " ", true) + end +end diff --git a/src/MatrixFields/matrix_multiplication.jl b/src/MatrixFields/matrix_multiplication.jl new file mode 100644 index 0000000000..612a196c37 --- /dev/null +++ b/src/MatrixFields/matrix_multiplication.jl @@ -0,0 +1,381 @@ +""" + MultiplyColumnwiseBandMatrixField + +An operator that multiplies a columnwise band matrix field (a field of +`BandMatrixRow`s) by a regular field or by another columnwise band matrix field, +i.e., matrix-vector or matrix-matrix multiplication. The `⋅` symbol is an alias +for `MultiplyColumnwiseBandMatrixField()`. +""" +struct MultiplyColumnwiseBandMatrixField <: Operators.FiniteDifferenceOperator end +const ⋅ = MultiplyColumnwiseBandMatrixField() + +#= +TODO: Rewrite the following derivation in LaTeX and move it into the ClimaCore +documentation. + +Notation: + +For any single-column field F, let F[idx] denote the value of F at level idx. +For any single-column BandMatrixRow field M, let + M[idx, idx′] = M[idx][idx′ - idx]. +If there are multiple columns, the following equations apply per column. + +Matrix-Vector Multiplication: + +Consider a BandMatrixRow field M and a scalar (non-BandMatrixRow) field V. +From the definition of matrix-vector multiplication, + (M ⋅ V)[idx] = ∑_{idx′} M[idx, idx′] * V[idx′]. +If V[idx] is only defined when left_idx ≤ idx ≤ right_idx, this becomes + (M ⋅ V)[idx] = ∑_{idx′ ∈ left_idx:right_idx} M[idx, idx′] * V[idx′]. +If M[idx, idx′] is only defined when idx + ld ≤ idx′ ≤ idx + ud, this becomes + (M ⋅ V)[idx] = + ∑_{idx′ ∈ max(left_idx, idx + ld):min(right_idx, idx + ud)} + M[idx, idx′] * V[idx′]. +Replacing the variable idx′ with the variable d = idx′ - idx gives us + (M ⋅ V)[idx] = + ∑_{d ∈ max(left_idx - idx, ld):min(right_idx - idx, ud)} + M[idx, idx + d] * V[idx + d]. +This can be rewritten using the standard indexing notation as + (M ⋅ V)[idx] = + ∑_{d ∈ max(left_idx - idx, ld):min(right_idx - idx, ud)} + M[idx][d] * V[idx + d]. +Finally, we can express this in terms of left/right boundaries and an interior: + (M ⋅ V)[idx] = + ∑_{ + d ∈ + if idx < left_idx - ld + (left_idx - idx):ud + elseif idx > right_idx - ud + ld:(right_idx - idx) + else + ld:ud + end + } M[idx][d] * V[idx + d]. + +Matrix-Matrix Multiplication: + +Consider a BandMatrixRow field M1 and another BandMatrixRow field M2. +From the definition of matrix-matrix multiplication, + (M1 ⋅ M2)[idx, idx′] = ∑_{idx′′} M1[idx, idx′′] * M2[idx′′, idx′]. +If M2[idx′′] is only defined when left_idx ≤ idx′′ ≤ right_idx, this becomes + (M1 ⋅ M2)[idx, idx′] = + ∑_{idx′′ ∈ left_idx:right_idx} M1[idx, idx′′] * M2[idx′′, idx′]. +If M1[idx, idx′′] is only defined when idx + ld1 ≤ idx′′ ≤ idx + ud1, this becomes + (M1 ⋅ M2)[idx, idx′] = + ∑_{idx′′ ∈ max(left_idx, idx + ld1):min(right_idx, idx + ud1)} + M1[idx, idx′′] * M2[idx′′, idx′]. +If M2[idx′′, idx′] is only defined when idx′′ + ld2 ≤ idx′ ≤ idx′′ + ud2, or, +equivalently, when idx′ - ud2 ≤ idx′′ ≤ idx′ - ld2, this becomes + (M1 ⋅ M2)[idx, idx′] = + ∑_{ + idx′′ ∈ + max(left_idx, idx + ld1, idx′ - ud2): + min(right_idx, idx + ud1, idx′ - ld2) + } M1[idx, idx′′] * M2[idx′′, idx′]. +Replacing the variable idx′ with the variable prod_d = idx′ - idx gives us + (M1 ⋅ M2)[idx, idx + prod_d] = + ∑_{ + idx′′ ∈ + max(left_idx, idx + ld1, idx + prod_d - ud2): + min(right_idx, idx + ud1, idx + prod_d - ld2) + } M1[idx, idx′′] * M2[idx′′, idx + prod_d]. +Replacing the variable idx′′ with the variable d = idx′′ - idx gives us + (M1 ⋅ M2)[idx, idx + prod_d] = + ∑_{ + d ∈ + max(left_idx - idx, ld1, prod_d - ud2): + min(right_idx - idx, ud1, prod_d - ld2) + } M1[idx, idx + d] * M2[idx + d, idx + prod_d]. +This can be rewritten using the standard indexing notation as + (M1 ⋅ M2)[idx][prod_d] = + ∑_{ + d ∈ + max(left_idx - idx, ld1, prod_d - ud2): + min(right_idx - idx, ud1, prod_d - ld2) + } M1[idx][d] * M2[idx + d][prod_d - d]. +Finally, we can express this in terms of left/right boundaries and an interior: + (M1 ⋅ M2)[idx][prod_d] = + ∑_{ + d ∈ + if idx < left_idx - ld1 + max(left_idx - idx, prod_d - ud2):min(ud1, prod_d - ld2) + elseif idx > right_idx - ud1 + max(ld1, prod_d - ud2):min(right_idx - idx, prod_d - ld2) + else + max(ld1, prod_d - ud2):min(ud1, prod_d - ld2) + end + } M1[idx][d] * M2[idx + d][prod_d - d]. + +We only need to define (M1 ⋅ M2)[idx][prod_d] when it has a nonzero value in the +interior, which will be the case when + max(ld1, prod_d - ud2) ≤ min(ud1, prod_d - ld2). +This can be rewritten as a system of four inequalities: + ld1 ≤ ud1, + ld1 ≤ prod_d - ld2, + prod_d - ud2 ≤ ud1, and + prod_d - ud2 ≤ prod_d - ld2. +By definition, ld1 ≤ ud1 and ld2 ≤ ud2, so the first and last inequality are +always true. Rearranging the remaining two inequalities gives us + ld1 + ld2 ≤ prod_d ≤ ud1 + ud2. +=# + +struct TopLeftMatrixCorner <: Operators.AbstractBoundaryCondition end +struct BottomRightMatrixCorner <: Operators.AbstractBoundaryCondition end + +Operators.has_boundary( + ::MultiplyColumnwiseBandMatrixField, + ::Operators.LeftBoundaryWindow{name}, +) where {name} = true +Operators.has_boundary( + ::MultiplyColumnwiseBandMatrixField, + ::Operators.RightBoundaryWindow{name}, +) where {name} = true + +Operators.get_boundary( + ::MultiplyColumnwiseBandMatrixField, + ::Operators.LeftBoundaryWindow{name}, +) where {name} = TopLeftMatrixCorner() +Operators.get_boundary( + ::MultiplyColumnwiseBandMatrixField, + ::Operators.RightBoundaryWindow{name}, +) where {name} = BottomRightMatrixCorner() + +Operators.stencil_interior_width( + ::MultiplyColumnwiseBandMatrixField, + matrix1, + arg, +) = ((0, 0), outer_diagonals(eltype(matrix1))) + +# Interior indices of a center-to-center or face-to-face matrix. +matrix_left_interior_index(space, lbw::Integer) = + Operators.left_idx(space) - lbw +matrix_right_interior_index(space, rbw::Integer) = + Operators.right_idx(space) - rbw + +# Interior indices of a face-to-center matrix. +matrix_left_interior_index( + space::Union{ + Spaces.CenterFiniteDifferenceSpace, + Spaces.CenterExtrudedFiniteDifferenceSpace, + }, + lbw::PlusHalf, +) = Operators.left_idx(space) - lbw - half +matrix_right_interior_index( + space::Union{ + Spaces.CenterFiniteDifferenceSpace, + Spaces.CenterExtrudedFiniteDifferenceSpace, + }, + rbw::PlusHalf, +) = Operators.right_idx(space) - rbw + half + +# Interior indices of a center-to-face matrix. +matrix_left_interior_index( + space::Union{ + Spaces.FaceFiniteDifferenceSpace, + Spaces.FaceExtrudedFiniteDifferenceSpace, + }, + lbw::PlusHalf, +) = Operators.left_idx(space) - lbw + half +matrix_right_interior_index( + space::Union{ + Spaces.FaceFiniteDifferenceSpace, + Spaces.FaceExtrudedFiniteDifferenceSpace, + }, + rbw::PlusHalf, +) = Operators.right_idx(space) - rbw - half + +Operators.left_interior_idx( + space::Spaces.AbstractSpace, + ::MultiplyColumnwiseBandMatrixField, + ::TopLeftMatrixCorner, + matrix1, + arg, +) = matrix_left_interior_index(space, outer_diagonals(eltype(matrix1))[1]) +Operators.right_interior_idx( + space::Spaces.AbstractSpace, + ::MultiplyColumnwiseBandMatrixField, + ::BottomRightMatrixCorner, + matrix1, + arg, +) = matrix_right_interior_index(space, outer_diagonals(eltype(matrix1))[2]) + +function rmul_type(::Type{T1}, ::Type{T2}, ::Type{LG}) where {T1, T2, LG} + type = Base._return_type(rmul_with_projection, Tuple{T1, T2, LG}) + type == Union{} && + error("Unable to infer result type: Calling rmul_with_projection with \ + arguments of types $T1 and $T2 will cause it to throw an error") + return type +end + +function Operators.return_eltype( + ::MultiplyColumnwiseBandMatrixField, + matrix1, + arg, +) + eltype(matrix1) <: BandMatrixRow || error( + "The first argument of ⋅ must have elements of type BandMatrixRow, but \ + the given argument has elements of type $(eltype(matrix1))", + ) + lg_type = eltype(Fields.local_geometry_field(axes(arg))) + if eltype(arg) <: BandMatrixRow # matrix-matrix multiplication + matrix2 = arg + ld1, ud1 = outer_diagonals(eltype(matrix1)) + ld2, ud2 = outer_diagonals(eltype(matrix2)) + prod_ld, prod_ud = ld1 + ld2, ud1 + ud2 + prod_value_type = + rmul_type(eltype(eltype(matrix1)), eltype(eltype(matrix2)), lg_type) + return band_matrix_row_type(prod_ld, prod_ud, prod_value_type) + else # matrix-vector multiplication + vector = arg + return rmul_type(eltype(eltype(matrix1)), eltype(vector), lg_type) + end +end + +Operators.return_space(::MultiplyColumnwiseBandMatrixField, space1, space2) = + space1 + +# TODO: Use @propagate_inbounds here, and remove @inbounds from this function. +# As of Julia 1.8, doing this increases compilation by more than an order of +# magnitude, and it also makes type inference fail for some complicated matrix +# field broadcast expressions. Unfortunately, not using @propagate_inbounds +# makes matrix field broadcast expressions take roughly 3 times longer to +# evaluate. However, since they are sufficiently fast as is, this is an +# acceptable performance loss. +function multiply_matrix_at_index( + loc, + space, + idx, + hidx, + matrix1, + arg, + boundary_ld1 = nothing, + boundary_ud1 = nothing, +) + matrix1_row = Operators.getidx(space, matrix1, loc, idx, hidx) + ld1, ud1 = outer_diagonals(eltype(matrix1)) + ld1_or_boundary_ld1 = isnothing(boundary_ld1) ? ld1 : boundary_ld1 + ud1_or_boundary_ud1 = isnothing(boundary_ud1) ? ud1 : boundary_ud1 + prod_type = Operators.return_eltype(⋅, matrix1, arg) + if eltype(arg) <: BandMatrixRow # matrix-matrix multiplication + matrix2 = arg + matrix2_rows = map((ld1:ud1...,)) do d + # TODO: Use @propagate_inbounds_meta instead of @inline_meta. + Base.@_inline_meta + if ( + (isnothing(boundary_ld1) || d >= boundary_ld1) && + (isnothing(boundary_ud1) || d <= boundary_ud1) + ) + @inbounds Operators.getidx(space, matrix2, loc, idx + d, hidx) + else + rzero(eltype(matrix2)) # This value is never used. + end + end # The rows are precomputed to avoid recomputing them multiple times. + matrix2_rows_wrapper = BandMatrixRow{ld1}(matrix2_rows...) + ld2, ud2 = outer_diagonals(eltype(matrix2)) + prod_ld, prod_ud = outer_diagonals(prod_type) + zero_value = rzero(eltype(prod_type)) + prod_entries = map((prod_ld:prod_ud...,)) do prod_d + # TODO: Use @propagate_inbounds_meta instead of @inline_meta. + Base.@_inline_meta + min_d = max(ld1_or_boundary_ld1, prod_d - ud2) + max_d = min(ud1_or_boundary_ud1, prod_d - ld2) + # Note: If min_d:max_d is an empty range, then the current entry + # lies outside of the product matrix, so it should never be used in + # any computations. By initializing prod_entry to zero_value, we are + # implicitly setting all such entries to 0. We could alternatively + # set all such entries to NaN (in order to more easily catch user + # errors that involve accidentally using these entires), but that + # would not generalize to non-floating-point types like Int or Bool. + prod_entry = zero_value + @inbounds for d in min_d:max_d + value1 = matrix1_row[d] + value2 = matrix2_rows_wrapper[d][prod_d - d] + value2_lg = Geometry.LocalGeometry(space, idx + d, hidx) + prod_entry = radd( + prod_entry, + rmul_with_projection(value1, value2, value2_lg), + ) + end # Using this for-loop is currently faster than using mapreduce. + prod_entry + end + return BandMatrixRow{prod_ld}(prod_entries...) + else # matrix-vector multiplication + vector = arg + prod_value = rzero(prod_type) + @inbounds for d in ld1_or_boundary_ld1:ud1_or_boundary_ud1 + value1 = matrix1_row[d] + value2 = Operators.getidx(space, vector, loc, idx + d, hidx) + value2_lg = Geometry.LocalGeometry(space, idx + d, hidx) + prod_value = radd( + prod_value, + rmul_with_projection(value1, value2, value2_lg), + ) + end # Using this for-loop is currently faster than using mapreduce. + return prod_value + end +end + +Base.@propagate_inbounds Operators.stencil_interior( + ::MultiplyColumnwiseBandMatrixField, + loc, + space, + idx, + hidx, + matrix1, + arg, +) = multiply_matrix_at_index(loc, space, idx, hidx, matrix1, arg) + +Base.@propagate_inbounds Operators.stencil_left_boundary( + ::MultiplyColumnwiseBandMatrixField, + ::TopLeftMatrixCorner, + loc, + space, + idx, + hidx, + matrix1, + arg, +) = multiply_matrix_at_index( + loc, + space, + idx, + hidx, + matrix1, + arg, + Operators.left_idx( + Operators.reconstruct_placeholder_space(axes(arg), space), + ) - idx, + nothing, +) + +Base.@propagate_inbounds Operators.stencil_right_boundary( + ::MultiplyColumnwiseBandMatrixField, + ::BottomRightMatrixCorner, + loc, + space, + idx, + hidx, + matrix1, + arg, +) = multiply_matrix_at_index( + loc, + space, + idx, + hidx, + matrix1, + arg, + nothing, + Operators.right_idx( + Operators.reconstruct_placeholder_space(axes(arg), space), + ) - idx, +) + +# For matrix field broadcast expressions involving 4 or more matrices, we +# sometimes hit a recursion limit and de-optimize. +# We know that the recursion will terminate due to the fact that broadcast +# expressions are not self-referential. +if hasfield(Method, :recursion_relation) + dont_limit = (args...) -> true + for m in methods(multiply_matrix_at_index) + m.recursion_relation = dont_limit + end +end diff --git a/src/MatrixFields/rmul_with_projection.jl b/src/MatrixFields/rmul_with_projection.jl new file mode 100644 index 0000000000..76f9846b68 --- /dev/null +++ b/src/MatrixFields/rmul_with_projection.jl @@ -0,0 +1,32 @@ +const SingleValue = + Union{Number, Geometry.AxisTensor, Geometry.AdjointAxisTensor} + +mul_with_projection(x, y, _) = x * y +mul_with_projection(x::Geometry.AdjointAxisVector, y::Geometry.AxisTensor, lg) = + x * Geometry.project(Geometry.dual(axes(x, 2)), y, lg) +mul_with_projection(::Geometry.AdjointAxisTensor, ::Geometry.AxisTensor, _) = + error("mul_with_projection is currently only implemented for covectors, \ + and higher-order cotensors are not supported") +# We should add methods for other cotensors (e.g., AdjointAxis2Tensor) when they +# are needed (e.g., when we need to support matrices that represent the +# divergence of higher-order tensors). + +""" + rmul_with_projection(x, y, lg) + +Similar to `rmul(x, y)`, but with automatic projection of `y` when `x` contains +a covector (i.e, an `AdjointAxisVector`). For example, if `x` is a covector +along the `Covariant3Axis` (e.g., `Covariant3Vector(1)'`), then `y` (or each +element of `y`) will be projected onto the `Contravariant3Axis`. In general, +each covector in `x` will cause `y` (or each corresponding element of `y`) to +be projected onto the dual axis of the covector. In the future, we may extend +this behavior to higher-order cotensors. +""" +rmul_with_projection(x, y, lg) = + rmap((x′, y′) -> mul_with_projection(x′, y′, lg), x, y) +rmul_with_projection(x::SingleValue, y, lg) = + rmap(y′ -> mul_with_projection(x, y′, lg), y) +rmul_with_projection(x, y::SingleValue, lg) = + rmap(x′ -> mul_with_projection(x′, y, lg), x) +rmul_with_projection(x::SingleValue, y::SingleValue, lg) = + mul_with_projection(x, y, lg) diff --git a/src/RecursiveApply/RecursiveApply.jl b/src/RecursiveApply/RecursiveApply.jl index 2309cfa0ce..dfe43ed074 100755 --- a/src/RecursiveApply/RecursiveApply.jl +++ b/src/RecursiveApply/RecursiveApply.jl @@ -92,6 +92,13 @@ rmaptype( T2 <: NamedTuple{names, Tup2}, } = NamedTuple{names, rmaptype(fn, Tup1, Tup2)} +""" + rpromote_type(Ts...) + +Recursively apply `promote_type` to the input types. +""" +rpromote_type(Ts...) = reduce((T1, T2) -> rmaptype(promote_type, T1, T2), Ts) + """ rzero(T) @@ -105,6 +112,17 @@ rzero(::Type{T}) where {T <: Tuple} = rzero(::Type{Tup}) where {names, T, Tup <: NamedTuple{names, T}} = NamedTuple{names}(rzero(T)) +""" + rconvert(T, X) + +Identical to `convert(T, X)`, but with improved type stability for nested types. +""" +rconvert(::Type{T}, X::T) where {T} = X +rconvert(::Type{T}, X) where {T} = + rmap((zero_value, x) -> convert(typeof(zero_value), x), rzero(T), X) +# TODO: Remove this function once Julia's default convert function is +# type-stable for nested Tuple/NamedTuple types. + """ rmul(X, Y) X ⊠ Y diff --git a/test/Geometry/axistensors.jl b/test/Geometry/axistensors.jl index de8c46c061..93bde79693 100644 --- a/test/Geometry/axistensors.jl +++ b/test/Geometry/axistensors.jl @@ -27,6 +27,12 @@ ClimaCore.Geometry.assert_exact_transform() = true @test M[:, 1] == Geometry.Cartesian12Vector(1.0, 0.5) @test M[1, :] == Geometry.Covariant12Vector(1.0, 0.0) + @test x + zero(x) == x + @test x' + zero(x') == x' + + @test -x + x * 2 - x / 2 == -x + 2 * x - 2 \ x == x / 2 + @test -x' + x' * 2 - x' / 2 == -x' + 2 * x' - 2 \ x' == (x / 2)' + @test x * y' == x ⊗ y == Geometry.AxisTensor( diff --git a/test/MatrixFields/band_matrix_row.jl b/test/MatrixFields/band_matrix_row.jl new file mode 100644 index 0000000000..0077948b4a --- /dev/null +++ b/test/MatrixFields/band_matrix_row.jl @@ -0,0 +1,66 @@ +using Test +using JET +using LinearAlgebra: I + +using ClimaCore.MatrixFields +import ClimaCore: Geometry + +macro test_all(expression) + return quote + local test_func() = $(esc(expression)) + @test test_func() # correctness + @test (@allocated test_func()) == 0 # allocations + @test_opt test_func() # type instabilities + end +end + +@testset "BandMatrixRow Unit Tests" begin + @test_all DiagonalMatrixRow(1) == + DiagonalMatrixRow(0.5) + DiagonalMatrixRow(1 // 2) == + DiagonalMatrixRow(1.5) - DiagonalMatrixRow(1 // 2) == + DiagonalMatrixRow(0.5) * 2 == + 0.5 * DiagonalMatrixRow(2) == + DiagonalMatrixRow(2) / 2 == + I + + @test_all DiagonalMatrixRow(1 // 2) + 0.5 * I === DiagonalMatrixRow(1.0) + @test_all BidiagonalMatrixRow(1 // 2, 0.5) === BidiagonalMatrixRow(1, 1) / 2 + + @test_all convert(TridiagonalMatrixRow{Int}, DiagonalMatrixRow(1)) === + convert(TridiagonalMatrixRow{Int}, I) === + TridiagonalMatrixRow(0, 1, 0) + + @test_all QuaddiagonalMatrixRow(0.5, 1, 1, 1 // 2) + + BidiagonalMatrixRow(-0.5, -1 // 2) == + QuaddiagonalMatrixRow(1, 1, 1, 1) / 2 + @test_all PentadiagonalMatrixRow(0, 0.5, 1, 1 // 2, 0) - + TridiagonalMatrixRow(1, 0, 1) / 2 - 0.5 * DiagonalMatrixRow(2) == + PentadiagonalMatrixRow(0, 0, 0, 0, 0) + + @test_all PentadiagonalMatrixRow(0, 0.5, 1, 1 // 2, 0) - + TridiagonalMatrixRow(1, 0, 1) / 2 - I == + zero(PentadiagonalMatrixRow{Int}) + + T(value) = (; a = (), b = value, c = (value, (; d = (value,)), (;))) + @test_all QuaddiagonalMatrixRow(T(0.5), T(1), T(1), T(1 // 2)) + + BidiagonalMatrixRow(T(-0.5), T(-1 // 2)) == + QuaddiagonalMatrixRow(T(1), T(1), T(1), T(1)) / 2 + @test_all PentadiagonalMatrixRow(T(0), T(0.5), T(1), T(1 // 2), T(0)) - + TridiagonalMatrixRow(T(1), T(0), T(1)) / 2 - + 0.5 * DiagonalMatrixRow(T(2)) == + PentadiagonalMatrixRow(T(0), T(0), T(0), T(0), T(0)) + + @test_throws "Cannot promote" BidiagonalMatrixRow(1, 1) + I + @test_throws "Cannot promote" BidiagonalMatrixRow(1, 1) + + DiagonalMatrixRow(1) + + @test_throws "Cannot convert" convert(BidiagonalMatrixRow{Int}, I) + @test_throws "Cannot convert" convert( + BidiagonalMatrixRow{Int}, + DiagonalMatrixRow(1), + ) + @test_throws "Cannot convert" convert( + TridiagonalMatrixRow{Int}, + PentadiagonalMatrixRow(0, 0, 1, 0, 0), + ) +end diff --git a/test/MatrixFields/field2arrays.jl b/test/MatrixFields/field2arrays.jl new file mode 100644 index 0000000000..9141ee071d --- /dev/null +++ b/test/MatrixFields/field2arrays.jl @@ -0,0 +1,82 @@ +using Test +using JET + +import ClimaCore: Geometry, Domains, Meshes, Spaces, Fields, MatrixFields + +@testset "field2arrays Unit Tests" begin + FT = Float64 + domain = Domains.IntervalDomain( + Geometry.ZPoint(FT(1)), + Geometry.ZPoint(FT(4)); + boundary_tags = (:bottom, :top), + ) + mesh = Meshes.IntervalMesh(domain, nelems = 3) + center_space = Spaces.CenterFiniteDifferenceSpace(mesh) + face_space = Spaces.FaceFiniteDifferenceSpace(center_space) + ᶜz = Fields.coordinate_field(center_space).z + ᶠz = Fields.coordinate_field(face_space).z + + ᶜᶜmat = map(z -> MatrixFields.TridiagonalMatrixRow(2 * z, 4 * z, 8 * z), ᶜz) + ᶜᶠmat = map(z -> MatrixFields.BidiagonalMatrixRow(2 * z, 4 * z), ᶜz) + ᶠᶠmat = map(z -> MatrixFields.TridiagonalMatrixRow(2 * z, 4 * z, 8 * z), ᶠz) + ᶠᶜmat = map(z -> MatrixFields.BidiagonalMatrixRow(2 * z, 4 * z), ᶠz) + + @test MatrixFields.column_field2array(ᶜz) == + MatrixFields.column_field2array_view(ᶜz) == + [1.5, 2.5, 3.5] + + @test MatrixFields.column_field2array(ᶠz) == + MatrixFields.column_field2array_view(ᶠz) == + [1, 2, 3, 4] + + @test MatrixFields.column_field2array(ᶜᶜmat) == + MatrixFields.column_field2array_view(ᶜᶜmat) == + [ + 6 12 0 + 5 10 20 + 0 7 14 + ] + + @test MatrixFields.column_field2array(ᶜᶠmat) == + MatrixFields.column_field2array_view(ᶜᶠmat) == + [ + 3 6 0 0 + 0 5 10 0 + 0 0 7 14 + ] + + @test MatrixFields.column_field2array(ᶠᶠmat) == + MatrixFields.column_field2array_view(ᶠᶠmat) == + [ + 4 8 0 0 + 4 8 16 0 + 0 6 12 24 + 0 0 8 16 + ] + + @test MatrixFields.column_field2array(ᶠᶜmat) == + MatrixFields.column_field2array_view(ᶠᶜmat) == + [ + 4 0 0 + 4 8 0 + 0 6 12 + 0 0 8 + ] + + ᶜᶜmat_array_not_view = MatrixFields.column_field2array(ᶜᶜmat) + ᶜᶜmat_array_view = MatrixFields.column_field2array_view(ᶜᶜmat) + ᶜᶜmat .*= 2 + @test ᶜᶜmat_array_not_view == MatrixFields.column_field2array(ᶜᶜmat) ./ 2 + @test ᶜᶜmat_array_view == MatrixFields.column_field2array(ᶜᶜmat) + + @test MatrixFields.field2arrays(ᶜᶜmat) == + (MatrixFields.column_field2array(ᶜᶜmat),) + + # Check for type instabilities. + @test_opt MatrixFields.column_field2array(ᶜᶜmat) + @test_opt MatrixFields.column_field2array_view(ᶜᶜmat) + @test_opt MatrixFields.field2arrays(ᶜᶜmat) + + # Because this test is broken, printing matrix fields allocates some memory. + @test_broken MatrixFields.column_field2array_view(ᶜᶜmat) +end diff --git a/test/MatrixFields/matrix_field_broadcasting.jl b/test/MatrixFields/matrix_field_broadcasting.jl new file mode 100644 index 0000000000..a6c4882c6c --- /dev/null +++ b/test/MatrixFields/matrix_field_broadcasting.jl @@ -0,0 +1,818 @@ +using Test +using JET +using Random: seed! +using LinearAlgebra: I, mul! +using BandedMatrices: band + +using ClimaCore.MatrixFields +import ClimaCore: + Geometry, Domains, Meshes, Topologies, Hypsography, Spaces, Fields +import ClimaComms + +# Using @benchmark from BenchmarkTools is extremely slow; it appears to keep +# triggering recompilations and allocating a lot of memory in the process. +# This macro returns the minimum time (in seconds) required to run the +# expression after it has been compiled. +macro benchmark(expression) + return quote + $(esc(expression)) # Compile the expression first. Use esc for hygiene. + best_time = Inf + start_time = time_ns() + while time_ns() - start_time < 1e8 # Benchmark for 0.1 s (1e8 ns). + best_time = min(best_time, @elapsed $(esc(expression))) + end + best_time + end +end + +# This has to be its own function in order to correctly measure its allocations. +call_array_func( + ref_set_result!::F, + ref_result_arrays, + inputs_arrays, + temp_values_arrays, +) where {F} = foreach( + ref_set_result!, + ref_result_arrays, + inputs_arrays..., + temp_values_arrays..., +) + +function test_matrix_broadcast_against_array_reference(; + test_name, + inputs, + get_result::F1, + set_result!::F2, + temp_values = (), + ref_set_result!::F3, + print_summary = true, + max_error_limit = 3, +) where {F1, F2, F3} + @testset "$test_name" begin + result = get_result(inputs...) + + # Fill all output fields with NaNs for testing correctness. + result .*= NaN + for temp_value in temp_values + temp_value .*= NaN + end + + ref_result_arrays = MatrixFields.field2arrays(result) + inputs_arrays = map(MatrixFields.field2arrays, inputs) + temp_values_arrays = map(MatrixFields.field2arrays, temp_values) + + best_time = @benchmark set_result!(result, inputs...) + best_ref_time = @benchmark call_array_func( + ref_set_result!, + ref_result_arrays, + inputs_arrays, + temp_values_arrays, + ) + + # Compute the maximum error as an integer multiple of machine epsilon. + result_arrays = MatrixFields.field2arrays(result) + max_error = + maximum(zip(result_arrays, ref_result_arrays)) do (array, ref_array) + maximum(zip(array, ref_array)) do (value, ref_value) + Int(abs(value - ref_value) / eps(ref_value)) + end + end + + if print_summary + @info "$test_name:\n\tBest Time = $best_time s\n\tBest Reference \ + Time = $best_ref_time s\n\tMaximum Error = $max_error eps" + end + + # Test that set_result! is performant compared to ref_set_result!. The + # factor of 12 is needed in order for this test to pass on CI, though it + # is not needed when running this test locally. + @test best_time < 12 * best_ref_time + + # Test set_result! for correctness, allocations, and type instabilities. + @test max_error <= max_error_limit + @test (@allocated set_result!(result, inputs...)) == 0 + @test_opt set_result!(result, inputs...) + + # Test ref_set_result! for allocations and type instabilities. This is + # helpful for ensuring that the performance comparison is fair. + @test (@allocated call_array_func( + ref_set_result!, + ref_result_arrays, + inputs_arrays, + temp_values_arrays, + )) == 0 + @test_opt call_array_func( + ref_set_result!, + ref_result_arrays, + inputs_arrays, + temp_values_arrays, + ) + + # Test get_result (the allocating version of set_result!) for type + # instabilities. + @test_opt get_result(inputs...) + end +end + +function test_matrix_broadcast_against_reference(; + test_name, + inputs, + get_result::F1, + set_result!::F2, + ref_inputs, + ref_set_result!::F3, + print_summary = true, +) where {F1, F2, F3} + @testset "$test_name" begin + result = get_result(inputs...) + + # Fill the output field with NaNs for testing correctness. + result .*= NaN + + ref_result = copy(result) + + best_time = @benchmark set_result!(result, inputs...) + best_ref_time = @benchmark ref_set_result!(ref_result, ref_inputs...) + + if print_summary + @info "$test_name:\n\tBest Time = $best_time s\n\tBest Reference \ + Time = $best_ref_time s" + end + + # Test that set_result! is performant compared to ref_set_result!. + @test best_time < best_ref_time + + # Test set_result! for correctness, allocations, and type instabilities. + @test result == ref_result + @test (@allocated set_result!(result, inputs...)) == 0 + @test_opt set_result!(result, inputs...) + + # Test ref_set_result! for allocations and type instabilities. This is + # helpful for ensuring that the performance comparison is fair. + @test (@allocated ref_set_result!(ref_result, ref_inputs...)) == 0 + @test_opt ref_set_result!(ref_result, ref_inputs...) + + # Test get_result (the allocating version of set_result!) for type + # instabilities. + @test_opt get_result(inputs...) + end +end + +function random_test_fields(::Type{FT}) where {FT} + velem = 20 # This should be big enough to test high-bandwidth matrices. + helem = npoly = 1 # These should be small enough for the tests to be fast. + + hdomain = Domains.SphereDomain(FT(10)) + hmesh = Meshes.EquiangularCubedSphere(hdomain, helem) + htopology = Topologies.Topology2D(ClimaComms.SingletonCommsContext(), hmesh) + quad = Spaces.Quadratures.GLL{npoly + 1}() + hspace = Spaces.SpectralElementSpace2D(htopology, quad) + vdomain = Domains.IntervalDomain( + Geometry.ZPoint(FT(0)), + Geometry.ZPoint(FT(10)); + boundary_tags = (:bottom, :top), + ) + vmesh = Meshes.IntervalMesh(vdomain, nelems = velem) + vspace = Spaces.CenterFiniteDifferenceSpace(vmesh) + sfc_coord = Fields.coordinate_field(hspace) + hypsography = Hypsography.LinearAdaption( + @. cosd(sfc_coord.lat) + cosd(sfc_coord.long) + 1 + ) + center_space = + Spaces.ExtrudedFiniteDifferenceSpace(hspace, vspace, hypsography) + face_space = Spaces.FaceExtrudedFiniteDifferenceSpace(center_space) + ᶜcoord = Fields.coordinate_field(center_space) + ᶠcoord = Fields.coordinate_field(face_space) + + seed!(1) # ensures reproducibility + ᶜᶜmat = map(_ -> DiagonalMatrixRow(rand(FT, 1)...), ᶜcoord) + ᶜᶠmat = map(_ -> BidiagonalMatrixRow(rand(FT, 2)...), ᶜcoord) + ᶠᶠmat = map(_ -> TridiagonalMatrixRow(rand(FT, 3)...), ᶠcoord) + ᶠᶜmat = map(_ -> QuaddiagonalMatrixRow(rand(FT, 4)...), ᶠcoord) + ᶜvec = map(_ -> rand(FT), ᶜcoord) + ᶠvec = map(_ -> rand(FT), ᶠcoord) + + return ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat, ᶜvec, ᶠvec +end + +@testset "Scalar Matrix Field Broadcasting" begin + ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat, ᶜvec, ᶠvec = random_test_fields(Float64) + + test_matrix_broadcast_against_array_reference(; + test_name = "diagonal matrix times vector", + inputs = (ᶜᶜmat, ᶜvec), + get_result = (ᶜᶜmat, ᶜvec) -> (@. ᶜᶜmat ⋅ ᶜvec), + set_result! = (result, ᶜᶜmat, ᶜvec) -> (@. result = ᶜᶜmat ⋅ ᶜvec), + ref_set_result! = (_result, _ᶜᶜmat, _ᶜvec) -> + mul!(_result, _ᶜᶜmat, _ᶜvec), + ) + + test_matrix_broadcast_against_array_reference(; + test_name = "tri-diagonal matrix times vector", + inputs = (ᶠᶠmat, ᶠvec), + get_result = (ᶠᶠmat, ᶠvec) -> (@. ᶠᶠmat ⋅ ᶠvec), + set_result! = (result, ᶠᶠmat, ᶠvec) -> (@. result = ᶠᶠmat ⋅ ᶠvec), + ref_set_result! = (_result, _ᶠᶠmat, _ᶠvec) -> + mul!(_result, _ᶠᶠmat, _ᶠvec), + ) + + test_matrix_broadcast_against_array_reference(; + test_name = "quad-diagonal matrix times vector", + inputs = (ᶠᶜmat, ᶜvec), + get_result = (ᶠᶜmat, ᶜvec) -> (@. ᶠᶜmat ⋅ ᶜvec), + set_result! = (result, ᶠᶜmat, ᶜvec) -> (@. result = ᶠᶜmat ⋅ ᶜvec), + ref_set_result! = (_result, _ᶠᶜmat, _ᶜvec) -> + mul!(_result, _ᶠᶜmat, _ᶜvec), + ) + + test_matrix_broadcast_against_array_reference(; + test_name = "diagonal matrix times bi-diagonal matrix", + inputs = (ᶜᶜmat, ᶜᶠmat), + get_result = (ᶜᶜmat, ᶜᶠmat) -> (@. ᶜᶜmat ⋅ ᶜᶠmat), + set_result! = (result, ᶜᶜmat, ᶜᶠmat) -> (@. result = ᶜᶜmat ⋅ ᶜᶠmat), + ref_set_result! = (_result, _ᶜᶜmat, _ᶜᶠmat) -> + mul!(_result, _ᶜᶜmat, _ᶜᶠmat), + ) + + test_matrix_broadcast_against_array_reference(; + test_name = "tri-diagonal matrix times tri-diagonal matrix", + inputs = (ᶠᶠmat,), + get_result = (ᶠᶠmat,) -> (@. ᶠᶠmat ⋅ ᶠᶠmat), + set_result! = (result, ᶠᶠmat) -> (@. result = ᶠᶠmat ⋅ ᶠᶠmat), + ref_set_result! = (_result, _ᶠᶠmat) -> mul!(_result, _ᶠᶠmat, _ᶠᶠmat), + ) + + test_matrix_broadcast_against_array_reference(; + test_name = "quad-diagonal matrix times diagonal matrix", + inputs = (ᶠᶜmat, ᶜᶜmat), + get_result = (ᶠᶜmat, ᶜᶜmat) -> (@. ᶠᶜmat ⋅ ᶜᶜmat), + set_result! = (result, ᶠᶜmat, ᶜᶜmat) -> (@. result = ᶠᶜmat ⋅ ᶜᶜmat), + ref_set_result! = (_result, _ᶠᶜmat, _ᶜᶜmat) -> + mul!(_result, _ᶠᶜmat, _ᶜᶜmat), + ) + + test_matrix_broadcast_against_array_reference(; + test_name = "diagonal matrix times bi-diagonal matrix times \ + tri-diagonal matrix times quad-diagonal matrix", + inputs = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat), + get_result = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat) -> + (@. ᶜᶜmat ⋅ ᶜᶠmat ⋅ ᶠᶠmat ⋅ ᶠᶜmat), + set_result! = (result, ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat) -> + (@. result = ᶜᶜmat ⋅ ᶜᶠmat ⋅ ᶠᶠmat ⋅ ᶠᶜmat), + temp_values = ((@. ᶜᶜmat ⋅ ᶜᶠmat), (@. ᶜᶜmat ⋅ ᶜᶠmat ⋅ ᶠᶠmat)), + ref_set_result! = ( + _result, + _ᶜᶜmat, + _ᶜᶠmat, + _ᶠᶠmat, + _ᶠᶜmat, + _temp1, + _temp2, + ) -> begin + mul!(_temp1, _ᶜᶜmat, _ᶜᶠmat) + mul!(_temp2, _temp1, _ᶠᶠmat) + mul!(_result, _temp2, _ᶠᶜmat) + end, + ) + + test_matrix_broadcast_against_array_reference(; + test_name = "diagonal matrix times bi-diagonal matrix times \ + tri-diagonal matrix times quad-diagonal matrix, but with \ + forced right-associativity", + inputs = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat), + get_result = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat) -> + (@. ᶜᶜmat ⋅ (ᶜᶠmat ⋅ (ᶠᶠmat ⋅ ᶠᶜmat))), + set_result! = (result, ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat) -> + (@. result = ᶜᶜmat ⋅ (ᶜᶠmat ⋅ (ᶠᶠmat ⋅ ᶠᶜmat))), + temp_values = ((@. ᶠᶠmat ⋅ ᶠᶜmat), (@. ᶜᶠmat ⋅ (ᶠᶠmat ⋅ ᶠᶜmat))), + ref_set_result! = ( + _result, + _ᶜᶜmat, + _ᶜᶠmat, + _ᶠᶠmat, + _ᶠᶜmat, + _temp1, + _temp2, + ) -> begin + mul!(_temp1, _ᶠᶠmat, _ᶠᶜmat) + mul!(_temp2, _ᶜᶠmat, _temp1) + mul!(_result, _ᶜᶜmat, _temp2) + end, + ) + + test_matrix_broadcast_against_array_reference(; + test_name = "diagonal matrix times bi-diagonal matrix times \ + tri-diagonal matrix times quad-diagonal matrix times \ + vector", + inputs = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat, ᶜvec), + get_result = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat, ᶜvec) -> + (@. ᶜᶜmat ⋅ ᶜᶠmat ⋅ ᶠᶠmat ⋅ ᶠᶜmat ⋅ ᶜvec), + set_result! = (result, ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat, ᶜvec) -> + (@. result = ᶜᶜmat ⋅ ᶜᶠmat ⋅ ᶠᶠmat ⋅ ᶠᶜmat ⋅ ᶜvec), + temp_values = ( + (@. ᶜᶜmat ⋅ ᶜᶠmat), + (@. ᶜᶜmat ⋅ ᶜᶠmat ⋅ ᶠᶠmat), + (@. ᶜᶜmat ⋅ ᶜᶠmat ⋅ ᶠᶠmat ⋅ ᶠᶜmat), + ), + ref_set_result! = ( + _result, + _ᶜᶜmat, + _ᶜᶠmat, + _ᶠᶠmat, + _ᶠᶜmat, + _ᶜvec, + _temp1, + _temp2, + _temp3, + ) -> begin + mul!(_temp1, _ᶜᶜmat, _ᶜᶠmat) + mul!(_temp2, _temp1, _ᶠᶠmat) + mul!(_temp3, _temp2, _ᶠᶜmat) + mul!(_result, _temp3, _ᶜvec) + end, + ) + + test_matrix_broadcast_against_array_reference(; + test_name = "diagonal matrix times bi-diagonal matrix times \ + tri-diagonal matrix times quad-diagonal matrix times \ + vector, but with forced right-associativity", + inputs = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat, ᶜvec), + get_result = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat, ᶜvec) -> + (@. ᶜᶜmat ⋅ (ᶜᶠmat ⋅ (ᶠᶠmat ⋅ (ᶠᶜmat ⋅ ᶜvec)))), + set_result! = (result, ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat, ᶜvec) -> + (@. result = ᶜᶜmat ⋅ (ᶜᶠmat ⋅ (ᶠᶠmat ⋅ (ᶠᶜmat ⋅ ᶜvec)))), + temp_values = ( + (@. ᶠᶜmat ⋅ ᶜvec), + (@. ᶠᶠmat ⋅ (ᶠᶜmat ⋅ ᶜvec)), + (@. ᶜᶠmat ⋅ (ᶠᶠmat ⋅ (ᶠᶜmat ⋅ ᶜvec))), + ), + ref_set_result! = ( + _result, + _ᶜᶜmat, + _ᶜᶠmat, + _ᶠᶠmat, + _ᶠᶜmat, + _ᶜvec, + _temp1, + _temp2, + _temp3, + ) -> begin + mul!(_temp1, _ᶠᶜmat, _ᶜvec) + mul!(_temp2, _ᶠᶠmat, _temp1) + mul!(_temp3, _ᶜᶠmat, _temp2) + mul!(_result, _ᶜᶜmat, _temp3) + end, + ) + + test_matrix_broadcast_against_array_reference(; + test_name = "linear combination of matrix products and LinearAlgebra.I", + inputs = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat), + get_result = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat) -> + (@. 2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat + ᶠᶠmat ⋅ ᶠᶠmat / 3 - (4I,)), + set_result! = (result, ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat) -> + (@. result = 2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat + ᶠᶠmat ⋅ ᶠᶠmat / 3 - (4I,)), + temp_values = ( + (@. 2 * ᶠᶜmat), + (@. 2 * ᶠᶜmat ⋅ ᶜᶜmat), + (@. 2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat), + (@. ᶠᶠmat ⋅ ᶠᶠmat), + ), + ref_set_result! = ( + _result, + _ᶜᶜmat, + _ᶜᶠmat, + _ᶠᶠmat, + _ᶠᶜmat, + _temp1, + _temp2, + _temp3, + _temp4, + ) -> begin + @. _temp1 = 0 + 2 * _ᶠᶜmat # This allocates without the `0 + `. + mul!(_temp2, _temp1, _ᶜᶜmat) + mul!(_temp3, _temp2, _ᶜᶠmat) + mul!(_temp4, _ᶠᶠmat, _ᶠᶠmat) + copyto!(_result, 4I) # We can't directly use I in array broadcasts. + @. _result = _temp3 + _temp4 / 3 - _result + end, + ) + + test_matrix_broadcast_against_array_reference(; + test_name = "another linear combination of matrix products and \ + LinearAlgebra.I", + inputs = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat), + get_result = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat) -> + (@. ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat * 2 - (ᶠᶠmat / 3) ⋅ ᶠᶠmat + (4I,)), + set_result! = (result, ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat) -> (@. result = + ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat * 2 - (ᶠᶠmat / 3) ⋅ ᶠᶠmat + (4I,)), + temp_values = ( + (@. ᶠᶜmat ⋅ ᶜᶜmat), + (@. ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat), + (@. ᶠᶠmat / 3), + (@. (ᶠᶠmat / 3) ⋅ ᶠᶠmat), + ), + ref_set_result! = ( + _result, + _ᶜᶜmat, + _ᶜᶠmat, + _ᶠᶠmat, + _ᶠᶜmat, + _temp1, + _temp2, + _temp3, + _temp4, + ) -> begin + mul!(_temp1, _ᶠᶜmat, _ᶜᶜmat) + mul!(_temp2, _temp1, _ᶜᶠmat) + @. _temp3 = 0 + _ᶠᶠmat / 3 # This allocates without the `0 + `. + mul!(_temp4, _temp3, _ᶠᶠmat) + copyto!(_result, 4I) # We can't directly use I in array broadcasts. + @. _result = _temp2 * 2 - _temp4 + _result + end, + max_error_limit = 512, # TODO: Why is the error so large on CI? + ) + + test_matrix_broadcast_against_array_reference(; + test_name = "matrix times linear combination", + inputs = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat), + get_result = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat) -> (@. ᶜᶠmat ⋅ + (2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat + ᶠᶠmat ⋅ ᶠᶠmat / 3 - (4I,))), + set_result! = (result, ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat) -> (@. result = + ᶜᶠmat ⋅ (2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat + ᶠᶠmat ⋅ ᶠᶠmat / 3 - (4I,))), + temp_values = ( + (@. 2 * ᶠᶜmat), + (@. 2 * ᶠᶜmat ⋅ ᶜᶜmat), + (@. 2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat), + (@. ᶠᶠmat ⋅ ᶠᶠmat), + (@. 2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat + ᶠᶠmat ⋅ ᶠᶠmat / 3 - (4I,)), + ), + ref_set_result! = ( + _result, + _ᶜᶜmat, + _ᶜᶠmat, + _ᶠᶠmat, + _ᶠᶜmat, + _temp1, + _temp2, + _temp3, + _temp4, + _temp5, + ) -> begin + @. _temp1 = 0 + 2 * _ᶠᶜmat # This allocates without the `0 + `. + mul!(_temp2, _temp1, _ᶜᶜmat) + mul!(_temp3, _temp2, _ᶜᶠmat) + mul!(_temp4, _ᶠᶠmat, _ᶠᶠmat) + copyto!(_temp5, 4I) # We can't directly use I in array broadcasts. + @. _temp5 = _temp3 + _temp4 / 3 - _temp5 + mul!(_result, _ᶜᶠmat, _temp5) + end, + max_error_limit = 276, # TODO: Why is the error so large on CI? + ) + + test_matrix_broadcast_against_array_reference(; + test_name = "linear combination times another linear combination", + inputs = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat), + get_result = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat) -> + (@. (2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat + ᶠᶠmat ⋅ ᶠᶠmat / 3 - (4I,)) ⋅ + (ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat * 2 - (ᶠᶠmat / 3) ⋅ ᶠᶠmat + (4I,))), + set_result! = (result, ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat) -> (@. result = + (2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat + ᶠᶠmat ⋅ ᶠᶠmat / 3 - (4I,)) ⋅ + (ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat * 2 - (ᶠᶠmat / 3) ⋅ ᶠᶠmat + (4I,))), + temp_values = ( + (@. 2 * ᶠᶜmat), + (@. 2 * ᶠᶜmat ⋅ ᶜᶜmat), + (@. 2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat), + (@. ᶠᶠmat ⋅ ᶠᶠmat), + (@. 2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat + ᶠᶠmat ⋅ ᶠᶠmat / 3 - (4I,)), + (@. ᶠᶜmat ⋅ ᶜᶜmat), + (@. ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat), + (@. ᶠᶠmat / 3), + (@. (ᶠᶠmat / 3) ⋅ ᶠᶠmat), + (@. ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat * 2 - (ᶠᶠmat / 3) ⋅ ᶠᶠmat + (4I,)), + ), + ref_set_result! = ( + _result, + _ᶜᶜmat, + _ᶜᶠmat, + _ᶠᶠmat, + _ᶠᶜmat, + _temp1, + _temp2, + _temp3, + _temp4, + _temp5, + _temp6, + _temp7, + _temp8, + _temp9, + _temp10, + ) -> begin + @. _temp1 = 0 + 2 * _ᶠᶜmat # This allocates without the `0 + `. + mul!(_temp2, _temp1, _ᶜᶜmat) + mul!(_temp3, _temp2, _ᶜᶠmat) + mul!(_temp4, _ᶠᶠmat, _ᶠᶠmat) + copyto!(_temp5, 4I) # We can't directly use I in array broadcasts. + @. _temp5 = _temp3 + _temp4 / 3 - _temp5 + mul!(_temp6, _ᶠᶜmat, _ᶜᶜmat) + mul!(_temp7, _temp6, _ᶜᶠmat) + @. _temp8 = 0 + _ᶠᶠmat / 3 # This allocates without the `0 + `. + mul!(_temp9, _temp8, _ᶠᶠmat) + copyto!(_temp10, 4I) # We can't directly use I in array broadcasts. + @. _temp10 = _temp7 * 2 - _temp9 + _temp10 + mul!(_result, _temp5, _temp10) + end, + max_error_limit = 44, # TODO: Why is the error so large on CI? + ) + + test_matrix_broadcast_against_array_reference(; + test_name = "matrix times matrix times linear combination times matrix \ + times another linear combination times matrix", + inputs = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat), + get_result = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat) -> (@. ᶠᶜmat ⋅ ᶜᶠmat ⋅ + (2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat + ᶠᶠmat ⋅ ᶠᶠmat / 3 - (4I,)) ⋅ + ᶠᶠmat ⋅ + (ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat * 2 - (ᶠᶠmat / 3) ⋅ ᶠᶠmat + (4I,)) ⋅ + ᶠᶠmat), + set_result! = (result, ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat) -> (@. result = + ᶠᶜmat ⋅ ᶜᶠmat ⋅ + (2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat + ᶠᶠmat ⋅ ᶠᶠmat / 3 - (4I,)) ⋅ + ᶠᶠmat ⋅ + (ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat * 2 - (ᶠᶠmat / 3) ⋅ ᶠᶠmat + (4I,)) ⋅ + ᶠᶠmat), + temp_values = ( + (@. ᶠᶜmat ⋅ ᶜᶠmat), + (@. 2 * ᶠᶜmat), + (@. 2 * ᶠᶜmat ⋅ ᶜᶜmat), + (@. 2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat), + (@. ᶠᶠmat ⋅ ᶠᶠmat), + (@. 2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat + ᶠᶠmat ⋅ ᶠᶠmat / 3 - (4I,)), + (@. ᶠᶜmat ⋅ ᶜᶠmat ⋅ + (2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat + ᶠᶠmat ⋅ ᶠᶠmat / 3 - (4I,))), + (@. ᶠᶜmat ⋅ ᶜᶠmat ⋅ + (2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat + ᶠᶠmat ⋅ ᶠᶠmat / 3 - (4I,)) ⋅ + ᶠᶠmat), + (@. ᶠᶜmat ⋅ ᶜᶜmat), + (@. ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat), + (@. ᶠᶠmat / 3), + (@. (ᶠᶠmat / 3) ⋅ ᶠᶠmat), + (@. ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat * 2 - (ᶠᶠmat / 3) ⋅ ᶠᶠmat + (4I,)), + (@. ᶠᶜmat ⋅ ᶜᶠmat ⋅ + (2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat + ᶠᶠmat ⋅ ᶠᶠmat / 3 - (4I,)) ⋅ + ᶠᶠmat ⋅ + (ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat * 2 - (ᶠᶠmat / 3) ⋅ ᶠᶠmat + (4I,))), + ), + ref_set_result! = ( + _result, + _ᶜᶜmat, + _ᶜᶠmat, + _ᶠᶠmat, + _ᶠᶜmat, + _temp1, + _temp2, + _temp3, + _temp4, + _temp5, + _temp6, + _temp7, + _temp8, + _temp9, + _temp10, + _temp11, + _temp12, + _temp13, + _temp14, + ) -> begin + mul!(_temp1, _ᶠᶜmat, _ᶜᶠmat) + @. _temp2 = 0 + 2 * _ᶠᶜmat # This allocates without the `0 + `. + mul!(_temp3, _temp2, _ᶜᶜmat) + mul!(_temp4, _temp3, _ᶜᶠmat) + mul!(_temp5, _ᶠᶠmat, _ᶠᶠmat) + copyto!(_temp6, 4I) # We can't directly use I in array broadcasts. + @. _temp6 = _temp4 + _temp5 / 3 - _temp6 + mul!(_temp7, _temp1, _temp6) + mul!(_temp8, _temp7, _ᶠᶠmat) + mul!(_temp9, _ᶠᶜmat, _ᶜᶜmat) + mul!(_temp10, _temp9, _ᶜᶠmat) + @. _temp11 = 0 + _ᶠᶠmat / 3 # This allocates without the `0 + `. + mul!(_temp12, _temp11, _ᶠᶠmat) + copyto!(_temp13, 4I) # We can't directly use I in array broadcasts. + @. _temp13 = _temp10 * 2 - _temp12 + _temp13 + mul!(_temp14, _temp8, _temp13) + mul!(_result, _temp14, _ᶠᶠmat) + end, + max_error_limit = 3038, # TODO: Why is the error so large on CI? + ) + + test_matrix_broadcast_against_array_reference(; + test_name = "matrix constructions and multiplications", + inputs = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat, ᶜvec, ᶠvec), + get_result = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat, ᶜvec, ᶠvec) -> + (@. BidiagonalMatrixRow(ᶜᶠmat ⋅ ᶠvec, ᶜᶜmat ⋅ ᶜvec) ⋅ + TridiagonalMatrixRow(ᶠvec, ᶠᶜmat ⋅ ᶜvec, 1) ⋅ ᶠᶠmat ⋅ + DiagonalMatrixRow(DiagonalMatrixRow(ᶠvec) ⋅ ᶠvec)), + set_result! = (result, ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat, ᶜvec, ᶠvec) -> + (@. result = + BidiagonalMatrixRow(ᶜᶠmat ⋅ ᶠvec, ᶜᶜmat ⋅ ᶜvec) ⋅ + TridiagonalMatrixRow(ᶠvec, ᶠᶜmat ⋅ ᶜvec, 1) ⋅ ᶠᶠmat ⋅ + DiagonalMatrixRow(DiagonalMatrixRow(ᶠvec) ⋅ ᶠvec)), + temp_values = ( + (@. BidiagonalMatrixRow(ᶜᶠmat ⋅ ᶠvec, ᶜᶜmat ⋅ ᶜvec)), + (@. TridiagonalMatrixRow(ᶠvec, ᶠᶜmat ⋅ ᶜvec, 1)), + (@. BidiagonalMatrixRow(ᶜᶠmat ⋅ ᶠvec, ᶜᶜmat ⋅ ᶜvec) ⋅ + TridiagonalMatrixRow(ᶠvec, ᶠᶜmat ⋅ ᶜvec, 1)), + (@. BidiagonalMatrixRow(ᶜᶠmat ⋅ ᶠvec, ᶜᶜmat ⋅ ᶜvec) ⋅ + TridiagonalMatrixRow(ᶠvec, ᶠᶜmat ⋅ ᶜvec, 1) ⋅ ᶠᶠmat), + (@. DiagonalMatrixRow(ᶠvec)), + (@. DiagonalMatrixRow(DiagonalMatrixRow(ᶠvec) ⋅ ᶠvec)), + ), + ref_set_result! = ( + _result, + _ᶜᶜmat, + _ᶜᶠmat, + _ᶠᶠmat, + _ᶠᶜmat, + _ᶜvec, + _ᶠvec, + _temp1, + _temp2, + _temp3, + _temp4, + _temp5, + _temp6, + ) -> begin + mul!(view(_temp1, band(0)), _ᶜᶠmat, _ᶠvec) + mul!(view(_temp1, band(1)), _ᶜᶜmat, _ᶜvec) + copyto!(view(_temp2, band(-1)), 1, _ᶠvec, 2) + mul!(view(_temp2, band(0)), _ᶠᶜmat, _ᶜvec) + fill!(view(_temp2, band(1)), 1) + mul!(_temp3, _temp1, _temp2) + mul!(_temp4, _temp3, _ᶠᶠmat) + copyto!(view(_temp5, band(0)), 1, _ᶠvec, 1) + mul!(view(_temp6, band(0)), _temp5, _ᶠvec) + mul!(_result, _temp4, _temp6) + end, + ) +end + +@testset "Non-scalar Matrix Field Broadcasting" begin + ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat, ᶜvec, ᶠvec = random_test_fields(Float64) + + ᶜlg = Fields.local_geometry_field(ᶜvec) + ᶠlg = Fields.local_geometry_field(ᶠvec) + + ᶜᶠmat2 = map(row -> map(sin, row), ᶜᶠmat) + ᶜᶠmat3 = map(row -> map(cos, row), ᶜᶠmat) + ᶠᶜmat2 = map(row -> map(sin, row), ᶠᶜmat) + ᶠᶜmat3 = map(row -> map(cos, row), ᶠᶜmat) + + ᶜᶠmat_AC1 = map(row -> map(adjoint ∘ Geometry.Covariant1Vector, row), ᶜᶠmat) + ᶜᶠmat_C12 = map( + (row1, row2) -> map(Geometry.Covariant12Vector, row1, row2), + ᶜᶠmat2, + ᶜᶠmat3, + ) + ᶠᶜmat_AC1 = map(row -> map(adjoint ∘ Geometry.Covariant1Vector, row), ᶠᶜmat) + ᶠᶜmat_C12 = map( + (row1, row2) -> map(Geometry.Covariant12Vector, row1, row2), + ᶠᶜmat2, + ᶠᶜmat3, + ) + + test_matrix_broadcast_against_reference(; + test_name = "matrix of covectors times matrix of vectors", + inputs = (ᶜᶠmat_AC1, ᶠᶜmat_C12), + get_result = (ᶜᶠmat_AC1, ᶠᶜmat_C12) -> (@. ᶜᶠmat_AC1 ⋅ ᶠᶜmat_C12), + set_result! = (result, ᶜᶠmat_AC1, ᶠᶜmat_C12) -> + (@. result = ᶜᶠmat_AC1 ⋅ ᶠᶜmat_C12), + ref_inputs = (ᶜᶠmat, ᶠᶜmat2, ᶠᶜmat3, ᶠlg), + ref_set_result! = (result, ᶜᶠmat, ᶠᶜmat2, ᶠᶜmat3, ᶠlg) -> (@. result = + ᶜᶠmat ⋅ ( + DiagonalMatrixRow(ᶠlg.gⁱʲ.components.data.:1) ⋅ ᶠᶜmat2 + + DiagonalMatrixRow(ᶠlg.gⁱʲ.components.data.:2) ⋅ ᶠᶜmat3 + )), + ) + + test_matrix_broadcast_against_reference(; + test_name = "matrix of covectors times matrix of vectors times matrix \ + of numbers times matrix of covectors times matrix of \ + vectors", + inputs = (ᶜᶠmat_AC1, ᶠᶜmat_C12, ᶜᶠmat, ᶠᶜmat_AC1, ᶜᶠmat_C12), + get_result = (ᶜᶠmat_AC1, ᶠᶜmat_C12, ᶜᶠmat, ᶠᶜmat_AC1, ᶜᶠmat_C12) -> + (@. ᶜᶠmat_AC1 ⋅ ᶠᶜmat_C12 ⋅ ᶜᶠmat ⋅ ᶠᶜmat_AC1 ⋅ ᶜᶠmat_C12), + set_result! = ( + result, + ᶜᶠmat_AC1, + ᶠᶜmat_C12, + ᶜᶠmat, + ᶠᶜmat_AC1, + ᶜᶠmat_C12, + ) -> + (@. result = ᶜᶠmat_AC1 ⋅ ᶠᶜmat_C12 ⋅ ᶜᶠmat ⋅ ᶠᶜmat_AC1 ⋅ ᶜᶠmat_C12), + ref_inputs = (ᶜᶠmat, ᶜᶠmat2, ᶜᶠmat3, ᶠᶜmat, ᶠᶜmat2, ᶠᶜmat3, ᶜlg, ᶠlg), + ref_set_result! = ( + result, + ᶜᶠmat, + ᶜᶠmat2, + ᶜᶠmat3, + ᶠᶜmat, + ᶠᶜmat2, + ᶠᶜmat3, + ᶜlg, + ᶠlg, + ) -> (@. result = + ᶜᶠmat ⋅ ( + DiagonalMatrixRow(ᶠlg.gⁱʲ.components.data.:1) ⋅ ᶠᶜmat2 + + DiagonalMatrixRow(ᶠlg.gⁱʲ.components.data.:2) ⋅ ᶠᶜmat3 + ) ⋅ ᶜᶠmat ⋅ ᶠᶜmat ⋅ ( + DiagonalMatrixRow(ᶜlg.gⁱʲ.components.data.:1) ⋅ ᶜᶠmat2 + + DiagonalMatrixRow(ᶜlg.gⁱʲ.components.data.:2) ⋅ ᶜᶠmat3 + )), + ) + + ᶜᶠmat_AC1_num = + map((row1, row2) -> map(tuple, row1, row2), ᶜᶠmat_AC1, ᶜᶠmat) + ᶜᶠmat_num_C12 = + map((row1, row2) -> map(tuple, row1, row2), ᶜᶠmat, ᶜᶠmat_C12) + ᶠᶜmat_C12_AC1 = + map((row1, row2) -> map(tuple, row1, row2), ᶠᶜmat_C12, ᶠᶜmat_AC1) + + test_matrix_broadcast_against_reference(; + test_name = "matrix of covectors and numbers times matrix of vectors \ + and covectors times matrix of numbers and vectors times \ + vector of numbers", + inputs = (ᶜᶠmat_AC1_num, ᶠᶜmat_C12_AC1, ᶜᶠmat_num_C12, ᶠvec), + get_result = (ᶜᶠmat_AC1_num, ᶠᶜmat_C12_AC1, ᶜᶠmat_num_C12, ᶠvec) -> + (@. ᶜᶠmat_AC1_num ⋅ ᶠᶜmat_C12_AC1 ⋅ ᶜᶠmat_num_C12 ⋅ ᶠvec), + set_result! = ( + result, + ᶜᶠmat_AC1_num, + ᶠᶜmat_C12_AC1, + ᶜᶠmat_num_C12, + ᶠvec, + ) -> (@. result = ᶜᶠmat_AC1_num ⋅ ᶠᶜmat_C12_AC1 ⋅ ᶜᶠmat_num_C12 ⋅ ᶠvec), + ref_inputs = ( + ᶜᶠmat, + ᶜᶠmat2, + ᶜᶠmat3, + ᶠᶜmat, + ᶠᶜmat2, + ᶠᶜmat3, + ᶠvec, + ᶜlg, + ᶠlg, + ), + ref_set_result! = ( + result, + ᶜᶠmat, + ᶜᶠmat2, + ᶜᶠmat3, + ᶠᶜmat, + ᶠᶜmat2, + ᶠᶜmat3, + ᶠvec, + ᶜlg, + ᶠlg, + ) -> (@. result = tuple( + ᶜᶠmat ⋅ ( + DiagonalMatrixRow(ᶠlg.gⁱʲ.components.data.:1) ⋅ ᶠᶜmat2 + + DiagonalMatrixRow(ᶠlg.gⁱʲ.components.data.:2) ⋅ ᶠᶜmat3 + ) ⋅ ᶜᶠmat ⋅ ᶠvec, + ᶜᶠmat ⋅ ᶠᶜmat ⋅ ( + DiagonalMatrixRow(ᶜlg.gⁱʲ.components.data.:1) ⋅ ᶜᶠmat2 + + DiagonalMatrixRow(ᶜlg.gⁱʲ.components.data.:2) ⋅ ᶜᶠmat3 + ) ⋅ ᶠvec, + )), + ) + + T(value1, value2, value3) = + (; a = (), b = value1, c = (value2, (; d = (value3,)), (;))) + ᶜᶠmat_T = map((rows...) -> map(T, rows...), ᶜᶠmat, ᶜᶠmat2, ᶜᶠmat3) + ᶠᶜmat_T = map((rows...) -> map(T, rows...), ᶠᶜmat, ᶠᶜmat2, ᶠᶜmat3) + ᶜvec_T = @. T(ᶜvec, ᶜvec, ᶜvec) + + test_matrix_broadcast_against_reference(; + test_name = "matrix of nested values times matrix of nested values \ + times matrix of numbers times matrix of numbers times \ + times vector of nested values", + inputs = (ᶜᶠmat_T, ᶠᶜmat, ᶜᶠmat, ᶠᶜmat_T, ᶜvec_T), + get_result = (ᶜᶠmat_T, ᶠᶜmat, ᶜᶠmat, ᶠᶜmat_T, ᶜvec_T) -> + (@. ᶜᶠmat_T ⋅ ᶠᶜmat ⋅ ᶜᶠmat ⋅ ᶠᶜmat_T ⋅ ᶜvec_T), + set_result! = (result, ᶜᶠmat_T, ᶠᶜmat, ᶜᶠmat, ᶠᶜmat_T, ᶜvec_T) -> + (@. result = ᶜᶠmat_T ⋅ ᶠᶜmat ⋅ ᶜᶠmat ⋅ ᶠᶜmat_T ⋅ ᶜvec_T), + ref_inputs = (ᶜᶠmat, ᶜᶠmat2, ᶜᶠmat3, ᶠᶜmat, ᶠᶜmat2, ᶠᶜmat3, ᶜvec), + ref_set_result! = ( + result, + ᶜᶠmat, + ᶜᶠmat2, + ᶜᶠmat3, + ᶠᶜmat, + ᶠᶜmat2, + ᶠᶜmat3, + ᶜvec, + ) -> (@. result = T( + ᶜᶠmat ⋅ ᶠᶜmat ⋅ ᶜᶠmat ⋅ ᶠᶜmat ⋅ ᶜvec, + ᶜᶠmat2 ⋅ ᶠᶜmat ⋅ ᶜᶠmat ⋅ ᶠᶜmat2 ⋅ ᶜvec, + ᶜᶠmat3 ⋅ ᶠᶜmat ⋅ ᶜᶠmat ⋅ ᶠᶜmat3 ⋅ ᶜvec, + )), + ) +end diff --git a/test/MatrixFields/rmul_with_projection.jl b/test/MatrixFields/rmul_with_projection.jl new file mode 100644 index 0000000000..d7f4855865 --- /dev/null +++ b/test/MatrixFields/rmul_with_projection.jl @@ -0,0 +1,123 @@ +using Test +using JET +using Random: seed! +using StaticArrays: @SMatrix + +import ClimaCore: Geometry +import ClimaCore.MatrixFields: rmul_with_projection + +function test_rmul_with_projection(x, y, lg, expected_result) + result = rmul_with_projection(x, y, lg) + + # Compute the maximum error as an integer multiple of machine epsilon. + FT = Geometry.undertype(typeof(lg)) + object2tuple(obj) = + reinterpret(NTuple{sizeof(obj) ÷ sizeof(FT), FT}, [obj])[1] + max_error = maximum( + ((value, expected_value),) -> + Int(abs(value - expected_value) / eps(expected_value)), + zip(object2tuple(result), object2tuple(expected_result)), + ) + + @test max_error <= 1 # correctness + @test (@allocated rmul_with_projection(x, y, lg)) == 0 # allocations + @test_opt rmul_with_projection(x, y, lg) # type instabilities +end + +@testset "rmul_with_projection Unit Tests" begin + seed!(1) # ensures reproducibility + + FT = Float64 + coord = Geometry.LatLongZPoint(rand(FT), rand(FT), rand(FT)) + ∂x∂ξ = Geometry.AxisTensor( + (Geometry.LocalAxis{(1, 2, 3)}(), Geometry.CovariantAxis{(1, 2, 3)}()), + (@SMatrix rand(FT, 3, 3)), + ) + lg = Geometry.LocalGeometry(coord, rand(FT), rand(FT), ∂x∂ξ) + + number = rand(FT) + covector = Geometry.Covariant12Vector(rand(FT), rand(FT))' + vector = Geometry.Covariant123Vector(rand(FT), rand(FT), rand(FT)) + tensor = vector * vector' + projected_vector = + Geometry.project(Geometry.Contravariant12Axis(), vector, lg) + projected_tensor = + Geometry.project(Geometry.Contravariant12Axis(), tensor, lg) + + # Test all required combinations of single values. + test_rmul_with_projection(number, number, lg, number * number) + test_rmul_with_projection(number, covector, lg, number * covector) + test_rmul_with_projection(number, vector, lg, number * vector) + test_rmul_with_projection(number, tensor, lg, number * tensor) + test_rmul_with_projection(covector, number, lg, covector * number) + test_rmul_with_projection(vector, number, lg, vector * number) + test_rmul_with_projection(tensor, number, lg, tensor * number) + test_rmul_with_projection(covector, vector, lg, covector * projected_vector) + test_rmul_with_projection(covector, tensor, lg, covector * projected_tensor) + + # Test some combinations of complicated nested values. + T(value1, value2, value3) = + (; a = (), b = value1, c = (value2, (; d = (value3,)), (;))) + test_rmul_with_projection( + number, + T(covector, vector, tensor), + lg, + T(number * covector, number * vector, number * tensor), + ) + test_rmul_with_projection( + T(covector, vector, tensor), + number, + lg, + T(covector * number, vector * number, tensor * number), + ) + test_rmul_with_projection( + vector, + T(number, number, number), + lg, + T(vector * number, vector * number, vector * number), + ) + test_rmul_with_projection( + T(number, number, number), + covector, + lg, + T(number * covector, number * covector, number * covector), + ) + test_rmul_with_projection( + T(number, vector, number), + T(covector, number, tensor), + lg, + T(number * covector, vector * number, number * tensor), + ) + test_rmul_with_projection( + T(covector, number, tensor), + T(number, vector, number), + lg, + T(covector * number, number * vector, tensor * number), + ) + test_rmul_with_projection( + covector, + T(vector, number, tensor), + lg, + T( + covector * projected_vector, + covector * number, + covector * projected_tensor, + ), + ) + test_rmul_with_projection( + T(covector, number, covector), + vector, + lg, + T( + covector * projected_vector, + number * vector, + covector * projected_vector, + ), + ) + test_rmul_with_projection( + T(covector, number, covector), + T(number, vector, tensor), + lg, + T(covector * number, number * vector, covector * projected_tensor), + ) +end diff --git a/test/Project.toml b/test/Project.toml index 530d76c116..b1de2515a9 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,6 +2,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63" AssociatedLegendrePolynomials = "2119f1ac-fb78-50f5-8cc0-dda848ebdb19" +BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" diff --git a/test/runtests.jl b/test/runtests.jl index f04ac240fd..27f5123f8a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -75,6 +75,11 @@ if !Sys.iswindows() @safetestset "Hybrid - dss opt" begin @time include("Operators/hybrid/dss_opt.jl") end @safetestset "Hybrid - opt" begin @time include("Operators/hybrid/opt.jl") end + @safetestset "MatrixFields - BandMatrixRow" begin @time include("MatrixFields/band_matrix_row.jl") end + @safetestset "MatrixFields - rmul_with_projection" begin @time include("MatrixFields/rmul_with_projection.jl") end + @safetestset "MatrixFields - field2arrays" begin @time include("MatrixFields/field2arrays.jl") end + @safetestset "MatrixFields - matrix field broadcasting" begin @time include("MatrixFields/matrix_field_broadcasting.jl") end + @safetestset "Hypsography - 2d" begin @time include("Hypsography/2d.jl") end @safetestset "Hypsography - 3d sphere" begin @time include("Hypsography/3dsphere.jl") end