Skip to content

Commit

Permalink
Try #1326:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] committed Jun 21, 2023
2 parents eb6f155 + 6265bdf commit 2973c23
Show file tree
Hide file tree
Showing 18 changed files with 1,861 additions and 4 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.10.40"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
Expand All @@ -31,6 +32,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[compat]
Adapt = "3"
BandedMatrices = "0.17"
BlockArrays = "0.16"
CUDA = "3, 4.2.0"
ClimaComms = "0.4.2"
Expand Down
8 changes: 7 additions & 1 deletion docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ git-tree-sha1 = "dbf84058d0a8cbbadee18d25cf606934b22d7c66"
uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
version = "0.4.2"

[[deps.BandedMatrices]]
deps = ["ArrayLayouts", "FillArrays", "LinearAlgebra", "PrecompileTools", "SparseArrays"]
git-tree-sha1 = "9ad46355045491b12eab409dee73e9de46293aa2"
uuid = "aae01518-5342-5314-be14-df237901396f"
version = "0.17.28"

[[deps.Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

Expand Down Expand Up @@ -216,7 +222,7 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.4.2"

[[deps.ClimaCore]]
deps = ["Adapt", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DiffEqBase", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "Rotations", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DiffEqBase", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "Rotations", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"]
path = ".."
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
version = "0.10.39"
Expand Down
1 change: 1 addition & 0 deletions src/ClimaCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ include("Topologies/Topologies.jl")
include("Spaces/Spaces.jl")
include("Fields/Fields.jl")
include("Operators/Operators.jl")
include("MatrixFields/MatrixFields.jl")
include("Hypsography/Hypsography.jl")
include("Limiters/Limiters.jl")
include("InputOutput/InputOutput.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/Fields/mapreduce.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Base.map(fn, field::Field) = Base.broadcast(fn, field)
Base.map(fn, fields::Field...) = Base.broadcast(fn, fields...)

"""
Fields.local_sum(v::Field)
Expand Down
23 changes: 21 additions & 2 deletions src/Geometry/axistensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,10 +269,30 @@ Base.propertynames(x::AxisVector) = symbols(axes(x, 1))
end
end

const AdjointAxisTensor{T, N, A, S} = Adjoint{T, AxisTensor{T, N, A, S}}

Base.show(io::IO, a::AdjointAxisTensor{T, N, A, S}) where {T, N, A, S} =
print(io, "adjoint($(a'))")

components(a::AdjointAxisTensor) = components(parent(a))'

Base.zero(a::AdjointAxisTensor) = zero(typeof(a))
Base.zero(::Type{AdjointAxisTensor{T, N, A, S}}) where {T, N, A, S} =
zero(AxisTensor{T, N, A, S})'

@inline +(a::AdjointAxisTensor) = (+a')'
@inline -(a::AdjointAxisTensor) = (-a')'
@inline +(a::AdjointAxisTensor, b::AdjointAxisTensor) = (a' + b')'
@inline -(a::AdjointAxisTensor, b::AdjointAxisTensor) = (a' - b')'
@inline *(a::Number, b::AdjointAxisTensor) = (a * b')'
@inline *(a::AdjointAxisTensor, b::Number) = (a' * b)'
@inline /(a::AdjointAxisTensor, b::Number) = (a' / b)'
@inline \(a::Number, b::AdjointAxisTensor) = (a \ b')'

@inline (==)(a::AdjointAxisTensor, b::AdjointAxisTensor) = a' == b'

const AdjointAxisVector{T, A1, S} = Adjoint{T, AxisVector{T, A1, S}}

components(va::AdjointAxisVector) = components(parent(va))'
Base.@propagate_inbounds Base.getindex(va::AdjointAxisVector, i::Int) =
getindex(components(va), i)
Base.@propagate_inbounds Base.getindex(va::AdjointAxisVector, i::Int, j::Int) =
Expand All @@ -286,7 +306,6 @@ Axis2Tensor(
) = AxisTensor(axes, components)

const AdjointAxis2Tensor{T, A, S} = Adjoint{T, Axis2Tensor{T, A, S}}
components(va::AdjointAxis2Tensor) = components(parent(va))'

const Axis2TensorOrAdj{T, A, S} =
Union{Axis2Tensor{T, A, S}, AdjointAxis2Tensor{T, A, S}}
Expand Down
26 changes: 26 additions & 0 deletions src/MatrixFields/MatrixFields.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
module MatrixFields

import LinearAlgebra: UniformScaling, Adjoint
import StaticArrays: SArray, SMatrix, SVector
import BandedMatrices: BandedMatrix, band, _BandedMatrix
import ..Utilities: PlusHalf, half
import ..RecursiveApply:
rmap, rpromote_type, rzero, rconvert, radd, rsub, rmul, rdiv
import ..Geometry
import ..Spaces
import ..Fields
import ..Operators

export
export DiagonalMatrixRow,
BidiagonalMatrixRow,
TridiagonalMatrixRow,
QuaddiagonalMatrixRow,
PentadiagonalMatrixRow

include("band_matrix_row.jl")
include("rmul_with_projection.jl")
include("matrix_multiplication.jl")
include("matrix_field_utils.jl")

end
138 changes: 138 additions & 0 deletions src/MatrixFields/band_matrix_row.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""
BandMatrixRow{ld}(entries...)
Stores the nonzero entries in a row of a band matrix, starting with the lowest
diagonal, which has index `ld`. Supported operations include accessing the entry
on the diagonal with index `d` by calling `row[d]`, taking linear combinations
with other band matrix rows (and with `LinearAlgebra.I`), and checking for
equality with other band matrix rows (and with `LinearAlgebra.I`). There are
several aliases defined for commonly-used subtypes of `BandMatrixRow` (with `T`
denoting the type of the row's entries):
- `DiagonalMatrixRow{T}`
- `BidiagonalMatrixRow{T}`
- `TridiagonalMatrixRow{T}`
- `QuaddiagonalMatrixRow{T}`
- `PentadiagonalMatrixRow{T}`
"""
struct BandMatrixRow{ld, bw, T} # bw is the bandwidth (the number of diagonals)
entries::NTuple{bw, T}
BandMatrixRow{ld, bw, T}(entries::NTuple{bw, Any}) where {ld, bw, T} =
new{ld, bw, T}(rconvert(NTuple{bw, T}, entries))
# TODO: Remove this inner constructor once Julia's default convert function
# is type-stable for nested Tuple/NamedTuple types.
end
BandMatrixRow{ld}(entries::Vararg{Any, bw}) where {ld, bw} =
BandMatrixRow{ld, bw}(entries...)
BandMatrixRow{ld, bw}(entries::Vararg{Any, bw}) where {ld, bw} =
BandMatrixRow{ld, bw, rpromote_type(map(typeof, entries)...)}(entries)

const DiagonalMatrixRow{T} = BandMatrixRow{0, 1, T}
const BidiagonalMatrixRow{T} = BandMatrixRow{-1 + half, 2, T}
const TridiagonalMatrixRow{T} = BandMatrixRow{-1, 3, T}
const QuaddiagonalMatrixRow{T} = BandMatrixRow{-2 + half, 4, T}
const PentadiagonalMatrixRow{T} = BandMatrixRow{-2, 5, T}

"""
outer_diagonals(::Type{<:BandMatrixRow})
Gets the indices of the lower and upper diagonals, `ld` and `ud`, of the given
subtype of `BandMatrixRow`.
"""
outer_diagonals(::Type{<:BandMatrixRow{ld, bw}}) where {ld, bw} =
(ld, ld + bw - 1)

"""
band_matrix_row_type(ld, ud, T)
A shorthand for getting the subtype of `BandMatrixRow` that has entries of type
`T` on the diagonals with indices in the range `ld:ud`.
"""
band_matrix_row_type(ld, ud, T) = BandMatrixRow{ld, ud - ld + 1, T}

Base.eltype(::Type{BandMatrixRow{ld, bw, T}}) where {ld, bw, T} = T

Base.zero(::Type{BandMatrixRow{ld, bw, T}}) where {ld, bw, T} =
BandMatrixRow{ld}(ntuple(_ -> rzero(T), Val(bw))...)

Base.map(f::F, rows::BandMatrixRow{ld}...) where {F, ld} =
BandMatrixRow{ld}(map(f, map(row -> row.entries, rows)...)...)

Base.@propagate_inbounds Base.getindex(row::BandMatrixRow{ld}, d) where {ld} =
row.entries[d - ld + 1]

function Base.promote_rule(
::Type{BMR1},
::Type{BMR2},
) where {BMR1 <: BandMatrixRow, BMR2 <: BandMatrixRow}
ld1, ud1 = outer_diagonals(BMR1)
ld2, ud2 = outer_diagonals(BMR2)
typeof(ld1) == typeof(ld2) || error(
"Cannot promote the $(ld1 isa PlusHalf ? "non-" : "")square matrix \
row type $BMR1 and the $(ld2 isa PlusHalf ? "non-" : "")square matrix \
row type $BMR2 to a common type",
)
T = rpromote_type(eltype(BMR1), eltype(BMR2))
return band_matrix_row_type(min(ld1, ld2), max(ud1, ud2), T)
end

Base.promote_rule(
::Type{BMR},
::Type{US},
) where {BMR <: BandMatrixRow, US <: UniformScaling} =
promote_rule(BMR, DiagonalMatrixRow{eltype(US)})

function Base.convert(
::Type{BMR},
row::BandMatrixRow,
) where {BMR <: BandMatrixRow}
old_ld, old_ud = outer_diagonals(typeof(row))
new_ld, new_ud = outer_diagonals(BMR)
typeof(old_ld) == typeof(new_ld) || error(
"Cannot convert a $(old_ld isa PlusHalf ? "non-" : "")square matrix \
row of type $(typeof(row)) to the \
$(new_ld isa PlusHalf ? "non-" : "")square matrix row type $BMR",
)
new_ld <= old_ld && new_ud >= old_ud || error(
"Cannot convert a $(typeof(row)) to a $BMR, since that would require \
dropping potentially non-zero row entries",
)
first_zeros = ntuple(_ -> rzero(eltype(BMR)), Val(old_ld - new_ld))
last_zeros = ntuple(_ -> rzero(eltype(BMR)), Val(new_ud - old_ud))
return BMR((first_zeros..., row.entries..., last_zeros...))
end

Base.convert(::Type{BMR}, row::UniformScaling) where {BMR <: BandMatrixRow} =
convert(BMR, DiagonalMatrixRow(row.λ))

Base.:(==)(row1::BMR, row2::BMR) where {BMR <: BandMatrixRow} =
row1.entries == row2.entries
Base.:(==)(row1::BandMatrixRow, row2::BandMatrixRow) =
==(promote(row1, row2)...)
Base.:(==)(row1::BandMatrixRow, row2::UniformScaling) =
==(promote(row1, row2)...)
Base.:(==)(row1::UniformScaling, row2::BandMatrixRow) =
==(promote(row1, row2)...)

Base.:+(row::BandMatrixRow) = map(radd, row)
Base.:+(row1::BandMatrixRow, row2::BandMatrixRow) =
map(radd, promote(row1, row2)...)
Base.:+(row1::BandMatrixRow, row2::UniformScaling) =
map(radd, promote(row1, row2)...)
Base.:+(row1::UniformScaling, row2::BandMatrixRow) =
map(radd, promote(row1, row2)...)

Base.:-(row::BandMatrixRow) = map(rsub, row)
Base.:-(row1::BandMatrixRow, row2::BandMatrixRow) =
map(rsub, promote(row1, row2)...)
Base.:-(row1::BandMatrixRow, row2::UniformScaling) =
map(rsub, promote(row1, row2)...)
Base.:-(row1::UniformScaling, row2::BandMatrixRow) =
map(rsub, promote(row1, row2)...)

Base.:*(row::BandMatrixRow, value::Number) =
map(entry -> rmul(entry, value), row)
Base.:*(value::Number, row::BandMatrixRow) =
map(entry -> rmul(value, entry), row)

Base.:/(row::BandMatrixRow, value::Number) =
map(entry -> rdiv(entry, value), row)
133 changes: 133 additions & 0 deletions src/MatrixFields/matrix_field_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
function banded_matrix_info(field)
space = axes(field)
field_ld, field_ud = outer_diagonals(eltype(field))

# Find the diagonal index of the value that ends up in the bottom-right
# corner of the matrix, as well as the amount by which the field's diagonal
# indices get shifted when it is converted into a matrix.
bottom_corner_matrix_d, matrix_d_minus_field_d = if field_ld isa PlusHalf
if space.staggering isa Spaces.CellCenter
1, half # field is a face-to-center matrix
else
-1, -half # field is a center-to-face matrix
end
else
0, 0 # field is either a center-to-center or face-to-face matrix
end

n_rows = Spaces.nlevels(space)
n_cols = n_rows + bottom_corner_matrix_d
matrix_ld = field_ld + matrix_d_minus_field_d
matrix_ud = field_ud + matrix_d_minus_field_d
matrix_ld <= 0 && matrix_ud >= 0 ||
error("BandedMatrices.jl does not yet support matrices that have \
diagonals with indices in the range $matrix_ld:$matrix_ud")

return n_rows, n_cols, matrix_ld, matrix_ud
end

"""
column_field2array(field)
Converts a field defined on a `FiniteDifferenceSpace` into a `Vector` or a
`BandedMatrix`, depending on whether or not the elements of the field are
`BandMatrixRow`s. This involves copying the data stored in the field.
"""
function column_field2array(field)
space = axes(field)
space isa Spaces.FiniteDifferenceSpace ||
error("column_field2array requires a field on a FiniteDifferenceSpace")
if eltype(field) <: BandMatrixRow # field represents a matrix
n_rows, n_cols, matrix_ld, matrix_ud = banded_matrix_info(field)
matrix = BandedMatrix{eltype(eltype(field))}(
undef,
(n_rows, n_cols),
(-matrix_ld, matrix_ud),
)
for (index_of_field_entry, matrix_d) in enumerate(matrix_ld:matrix_ud)
# Find the rows for which field_diagonal[row] is inside the matrix.
# Note: The matrix index (1, 1) corresponds to the diagonal index 0,
# and the matrix index (n_rows, n_cols) corresponds to the diagonal
# index n_cols - n_rows.
first_row = matrix_d < 0 ? 1 - matrix_d : 1
last_row = matrix_d < n_cols - n_rows ? n_rows : n_cols - matrix_d

# Copy the value in each row from field_diagonal to matrix_diagonal.
field_diagonal = field.entries.:($index_of_field_entry)
matrix_diagonal = view(matrix, band(matrix_d))
for (index_along_diagonal, row) in enumerate(first_row:last_row)
matrix_diagonal[index_along_diagonal] =
Fields.field_values(field_diagonal)[row]
end
end
return matrix
else # field represents a vector
n_rows = Spaces.nlevels(space)
return map(row -> Fields.field_values(field)[row], 1:n_rows)
end
end

"""
field2arrays(field)
Converts a field defined on a `FiniteDifferenceSpace` or on an
`ExtrudedFiniteDifferenceSpace` into a tuple of arrays, each of which
corresponds to a column of the field. This is done by calling
`column_field2array` on each of the field's columns.
"""
function field2arrays(field)
space = axes(field)
column_indices = if space isa Spaces.FiniteDifferenceSpace
(((1, 1), 1),)
elseif space isa Spaces.ExtrudedFiniteDifferenceSpace
(Spaces.all_nodes(Spaces.horizontal_space(space))...,)
else
error("Invalid space type: $(typeof(space).name.wrapper)")
end
return map(column_indices) do ((i, j), h)
column_field2array(Spaces.column(field, i, j, h))
end
end

"""
column_field2array_view(field)
Similar to `column_field2array(field)`, except that this version avoids copying
the data stored in the field.
"""
function column_field2array_view(field)
space = axes(field)
space isa Spaces.FiniteDifferenceSpace ||
error("column_field2array_view requires a field on a \
FiniteDifferenceSpace")
if eltype(field) <: BandMatrixRow # field represents a matrix
n_rows, n_cols, matrix_ld, matrix_ud = banded_matrix_info(field)
data_transpose = reinterpret(eltype(eltype(field)), parent(field)')
matrix_transpose =
_BandedMatrix(data_transpose, n_cols, matrix_ud, -matrix_ld)
return permutedims(matrix_transpose)
# TODO: Despite not copying any data, this function still allocates a
# small amount of memory because of _BandedMatrix and permutedims.
else # field represents a vector
return vec(reinterpret(eltype(field), parent(field)'))
end
end

function Base.show(
io::IO,
field::Fields.Field{<:Fields.AbstractData{<:BandMatrixRow}},
)
print(io, eltype(field), "-valued Field")
if eltype(eltype(field)) <: Number
if axes(field) isa Spaces.FiniteDifferenceSpace
println(io, " that corresponds to the matrix")
else
println(io, " whose first column corresponds to the matrix")
end
column_field = Fields.column(field, 1, 1, 1)
Base.print_array(io, column_field2array_view(column_field))
else # A BandedMatrix with non-number entries currently errors when printed.
print(io, ":")
Fields._show_compact_field(io, field, " ", true)
end
end
Loading

0 comments on commit 2973c23

Please sign in to comment.