-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add new MatrixFields module, along with unit tests and performance tests
- Loading branch information
1 parent
eb6f155
commit f68577d
Showing
18 changed files
with
1,861 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
module MatrixFields | ||
|
||
import LinearAlgebra: UniformScaling, Adjoint | ||
import StaticArrays: SArray, SMatrix, SVector | ||
import BandedMatrices: BandedMatrix, band, _BandedMatrix | ||
import ..Utilities: PlusHalf, half | ||
import ..RecursiveApply: | ||
rmap, rpromote_type, rzero, rconvert, radd, rsub, rmul, rdiv | ||
import ..Geometry | ||
import ..Spaces | ||
import ..Fields | ||
import ..Operators | ||
|
||
export ⋅ | ||
export DiagonalMatrixRow, | ||
BidiagonalMatrixRow, | ||
TridiagonalMatrixRow, | ||
QuaddiagonalMatrixRow, | ||
PentadiagonalMatrixRow | ||
|
||
include("band_matrix_row.jl") | ||
include("rmul_with_projection.jl") | ||
include("matrix_multiplication.jl") | ||
include("matrix_field_utils.jl") | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
""" | ||
BandMatrixRow{ld}(entries...) | ||
Stores the nonzero entries in a row of a band matrix, starting with the lowest | ||
diagonal, which has index `ld`. Supported operations include accessing the entry | ||
on the diagonal with index `d` by calling `row[d]`, taking linear combinations | ||
with other band matrix rows (and with `LinearAlgebra.I`), and checking for | ||
equality with other band matrix rows (and with `LinearAlgebra.I`). There are | ||
several aliases defined for commonly-used subtypes of `BandMatrixRow` (with `T` | ||
denoting the type of the row's entries): | ||
- `DiagonalMatrixRow{T}` | ||
- `BidiagonalMatrixRow{T}` | ||
- `TridiagonalMatrixRow{T}` | ||
- `QuaddiagonalMatrixRow{T}` | ||
- `PentadiagonalMatrixRow{T}` | ||
""" | ||
struct BandMatrixRow{ld, bw, T} # bw is the bandwidth (the number of diagonals) | ||
entries::NTuple{bw, T} | ||
BandMatrixRow{ld, bw, T}(entries::NTuple{bw, Any}) where {ld, bw, T} = | ||
new{ld, bw, T}(rconvert(NTuple{bw, T}, entries)) | ||
# TODO: Remove this inner constructor once Julia's default convert function | ||
# is type-stable for nested Tuple/NamedTuple types. | ||
end | ||
BandMatrixRow{ld}(entries::Vararg{Any, bw}) where {ld, bw} = | ||
BandMatrixRow{ld, bw}(entries...) | ||
BandMatrixRow{ld, bw}(entries::Vararg{Any, bw}) where {ld, bw} = | ||
BandMatrixRow{ld, bw, rpromote_type(map(typeof, entries)...)}(entries) | ||
|
||
const DiagonalMatrixRow{T} = BandMatrixRow{0, 1, T} | ||
const BidiagonalMatrixRow{T} = BandMatrixRow{-1 + half, 2, T} | ||
const TridiagonalMatrixRow{T} = BandMatrixRow{-1, 3, T} | ||
const QuaddiagonalMatrixRow{T} = BandMatrixRow{-2 + half, 4, T} | ||
const PentadiagonalMatrixRow{T} = BandMatrixRow{-2, 5, T} | ||
|
||
""" | ||
outer_diagonals(::Type{<:BandMatrixRow}) | ||
Gets the indices of the lower and upper diagonals, `ld` and `ud`, of the given | ||
subtype of `BandMatrixRow`. | ||
""" | ||
outer_diagonals(::Type{<:BandMatrixRow{ld, bw}}) where {ld, bw} = | ||
(ld, ld + bw - 1) | ||
|
||
""" | ||
band_matrix_row_type(ld, ud, T) | ||
A shorthand for getting the subtype of `BandMatrixRow` that has entries of type | ||
`T` on the diagonals with indices in the range `ld:ud`. | ||
""" | ||
band_matrix_row_type(ld, ud, T) = BandMatrixRow{ld, ud - ld + 1, T} | ||
|
||
Base.eltype(::Type{BandMatrixRow{ld, bw, T}}) where {ld, bw, T} = T | ||
|
||
Base.zero(::Type{BandMatrixRow{ld, bw, T}}) where {ld, bw, T} = | ||
BandMatrixRow{ld}(ntuple(_ -> rzero(T), Val(bw))...) | ||
|
||
Base.map(f::F, rows::BandMatrixRow{ld}...) where {F, ld} = | ||
BandMatrixRow{ld}(map(f, map(row -> row.entries, rows)...)...) | ||
|
||
Base.@propagate_inbounds Base.getindex(row::BandMatrixRow{ld}, d) where {ld} = | ||
row.entries[d - ld + 1] | ||
|
||
function Base.promote_rule( | ||
::Type{BMR1}, | ||
::Type{BMR2}, | ||
) where {BMR1 <: BandMatrixRow, BMR2 <: BandMatrixRow} | ||
ld1, ud1 = outer_diagonals(BMR1) | ||
ld2, ud2 = outer_diagonals(BMR2) | ||
typeof(ld1) == typeof(ld2) || error( | ||
"Cannot promote the $(ld1 isa PlusHalf ? "non-" : "")square matrix \ | ||
row type $BMR1 and the $(ld2 isa PlusHalf ? "non-" : "")square matrix \ | ||
row type $BMR2 to a common type", | ||
) | ||
T = rpromote_type(eltype(BMR1), eltype(BMR2)) | ||
return band_matrix_row_type(min(ld1, ld2), max(ud1, ud2), T) | ||
end | ||
|
||
Base.promote_rule( | ||
::Type{BMR}, | ||
::Type{US}, | ||
) where {BMR <: BandMatrixRow, US <: UniformScaling} = | ||
promote_rule(BMR, DiagonalMatrixRow{eltype(US)}) | ||
|
||
function Base.convert( | ||
::Type{BMR}, | ||
row::BandMatrixRow, | ||
) where {BMR <: BandMatrixRow} | ||
old_ld, old_ud = outer_diagonals(typeof(row)) | ||
new_ld, new_ud = outer_diagonals(BMR) | ||
typeof(old_ld) == typeof(new_ld) || error( | ||
"Cannot convert a $(old_ld isa PlusHalf ? "non-" : "")square matrix \ | ||
row of type $(typeof(row)) to the \ | ||
$(new_ld isa PlusHalf ? "non-" : "")square matrix row type $BMR", | ||
) | ||
new_ld <= old_ld && new_ud >= old_ud || error( | ||
"Cannot convert a $(typeof(row)) to a $BMR, since that would require \ | ||
dropping potentially non-zero row entries", | ||
) | ||
first_zeros = ntuple(_ -> rzero(eltype(BMR)), Val(old_ld - new_ld)) | ||
last_zeros = ntuple(_ -> rzero(eltype(BMR)), Val(new_ud - old_ud)) | ||
return BMR((first_zeros..., row.entries..., last_zeros...)) | ||
end | ||
|
||
Base.convert(::Type{BMR}, row::UniformScaling) where {BMR <: BandMatrixRow} = | ||
convert(BMR, DiagonalMatrixRow(row.λ)) | ||
|
||
Base.:(==)(row1::BMR, row2::BMR) where {BMR <: BandMatrixRow} = | ||
row1.entries == row2.entries | ||
Base.:(==)(row1::BandMatrixRow, row2::BandMatrixRow) = | ||
==(promote(row1, row2)...) | ||
Base.:(==)(row1::BandMatrixRow, row2::UniformScaling) = | ||
==(promote(row1, row2)...) | ||
Base.:(==)(row1::UniformScaling, row2::BandMatrixRow) = | ||
==(promote(row1, row2)...) | ||
|
||
Base.:+(row::BandMatrixRow) = map(radd, row) | ||
Base.:+(row1::BandMatrixRow, row2::BandMatrixRow) = | ||
map(radd, promote(row1, row2)...) | ||
Base.:+(row1::BandMatrixRow, row2::UniformScaling) = | ||
map(radd, promote(row1, row2)...) | ||
Base.:+(row1::UniformScaling, row2::BandMatrixRow) = | ||
map(radd, promote(row1, row2)...) | ||
|
||
Base.:-(row::BandMatrixRow) = map(rsub, row) | ||
Base.:-(row1::BandMatrixRow, row2::BandMatrixRow) = | ||
map(rsub, promote(row1, row2)...) | ||
Base.:-(row1::BandMatrixRow, row2::UniformScaling) = | ||
map(rsub, promote(row1, row2)...) | ||
Base.:-(row1::UniformScaling, row2::BandMatrixRow) = | ||
map(rsub, promote(row1, row2)...) | ||
|
||
Base.:*(row::BandMatrixRow, value::Number) = | ||
map(entry -> rmul(entry, value), row) | ||
Base.:*(value::Number, row::BandMatrixRow) = | ||
map(entry -> rmul(value, entry), row) | ||
|
||
Base.:/(row::BandMatrixRow, value::Number) = | ||
map(entry -> rdiv(entry, value), row) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
function banded_matrix_info(field) | ||
space = axes(field) | ||
field_ld, field_ud = outer_diagonals(eltype(field)) | ||
|
||
# Find the diagonal index of the value that ends up in the bottom-right | ||
# corner of the matrix, as well as the amount by which the field's diagonal | ||
# indices get shifted when it is converted into a matrix. | ||
bottom_corner_matrix_d, matrix_d_minus_field_d = if field_ld isa PlusHalf | ||
if space.staggering isa Spaces.CellCenter | ||
1, half # field is a face-to-center matrix | ||
else | ||
-1, -half # field is a center-to-face matrix | ||
end | ||
else | ||
0, 0 # field is either a center-to-center or face-to-face matrix | ||
end | ||
|
||
n_rows = Spaces.nlevels(space) | ||
n_cols = n_rows + bottom_corner_matrix_d | ||
matrix_ld = field_ld + matrix_d_minus_field_d | ||
matrix_ud = field_ud + matrix_d_minus_field_d | ||
matrix_ld <= 0 && matrix_ud >= 0 || | ||
error("BandedMatrices.jl does not yet support matrices that have \ | ||
diagonals with indices in the range $matrix_ld:$matrix_ud") | ||
|
||
return n_rows, n_cols, matrix_ld, matrix_ud | ||
end | ||
|
||
""" | ||
column_field2array(field) | ||
Converts a field defined on a `FiniteDifferenceSpace` into a `Vector` or a | ||
`BandedMatrix`, depending on whether or not the elements of the field are | ||
`BandMatrixRow`s. This involves copying the data stored in the field. | ||
""" | ||
function column_field2array(field) | ||
space = axes(field) | ||
space isa Spaces.FiniteDifferenceSpace || | ||
error("column_field2array requires a field on a FiniteDifferenceSpace") | ||
if eltype(field) <: BandMatrixRow # field represents a matrix | ||
n_rows, n_cols, matrix_ld, matrix_ud = banded_matrix_info(field) | ||
matrix = BandedMatrix{eltype(eltype(field))}( | ||
undef, | ||
(n_rows, n_cols), | ||
(-matrix_ld, matrix_ud), | ||
) | ||
for (index_of_field_entry, matrix_d) in enumerate(matrix_ld:matrix_ud) | ||
# Find the rows for which field_diagonal[row] is inside the matrix. | ||
# Note: The matrix index (1, 1) corresponds to the diagonal index 0, | ||
# and the matrix index (n_rows, n_cols) corresponds to the diagonal | ||
# index n_cols - n_rows. | ||
first_row = matrix_d < 0 ? 1 - matrix_d : 1 | ||
last_row = matrix_d < n_cols - n_rows ? n_rows : n_cols - matrix_d | ||
|
||
# Copy the value in each row from field_diagonal to matrix_diagonal. | ||
field_diagonal = field.entries.:($index_of_field_entry) | ||
matrix_diagonal = view(matrix, band(matrix_d)) | ||
for (index_along_diagonal, row) in enumerate(first_row:last_row) | ||
matrix_diagonal[index_along_diagonal] = | ||
Fields.field_values(field_diagonal)[row] | ||
end | ||
end | ||
return matrix | ||
else # field represents a vector | ||
n_rows = Spaces.nlevels(space) | ||
return map(row -> Fields.field_values(field)[row], 1:n_rows) | ||
end | ||
end | ||
|
||
""" | ||
field2arrays(field) | ||
Converts a field defined on a `FiniteDifferenceSpace` or on an | ||
`ExtrudedFiniteDifferenceSpace` into a tuple of arrays, each of which | ||
corresponds to a column of the field. This is done by calling | ||
`column_field2array` on each of the field's columns. | ||
""" | ||
function field2arrays(field) | ||
space = axes(field) | ||
column_indices = if space isa Spaces.FiniteDifferenceSpace | ||
(((1, 1), 1),) | ||
elseif space isa Spaces.ExtrudedFiniteDifferenceSpace | ||
(Spaces.all_nodes(Spaces.horizontal_space(space))...,) | ||
else | ||
error("Invalid space type: $(typeof(space).name.wrapper)") | ||
end | ||
return map(column_indices) do ((i, j), h) | ||
column_field2array(Spaces.column(field, i, j, h)) | ||
end | ||
end | ||
|
||
""" | ||
column_field2array_view(field) | ||
Similar to `column_field2array(field)`, except that this version avoids copying | ||
the data stored in the field. | ||
""" | ||
function column_field2array_view(field) | ||
space = axes(field) | ||
space isa Spaces.FiniteDifferenceSpace || | ||
error("column_field2array_view requires a field on a \ | ||
FiniteDifferenceSpace") | ||
if eltype(field) <: BandMatrixRow # field represents a matrix | ||
n_rows, n_cols, matrix_ld, matrix_ud = banded_matrix_info(field) | ||
data_transpose = reinterpret(eltype(eltype(field)), parent(field)') | ||
matrix_transpose = | ||
_BandedMatrix(data_transpose, n_cols, matrix_ud, -matrix_ld) | ||
return permutedims(matrix_transpose) | ||
# TODO: Despite not copying any data, this function still allocates a | ||
# small amount of memory because of _BandedMatrix and permutedims. | ||
else # field represents a vector | ||
return vec(reinterpret(eltype(field), parent(field)')) | ||
end | ||
end | ||
|
||
function Base.show( | ||
io::IO, | ||
field::Fields.Field{<:Fields.AbstractData{<:BandMatrixRow}}, | ||
) | ||
print(io, eltype(field), "-valued Field") | ||
if eltype(eltype(field)) <: Number | ||
if axes(field) isa Spaces.FiniteDifferenceSpace | ||
println(io, " that corresponds to the matrix") | ||
else | ||
println(io, " whose first column corresponds to the matrix") | ||
end | ||
column_field = Fields.column(field, 1, 1, 1) | ||
Base.print_array(io, column_field2array_view(column_field)) | ||
else # A BandedMatrix with non-number entries currently errors when printed. | ||
print(io, ":") | ||
Fields._show_compact_field(io, field, " ", true) | ||
end | ||
end |
Oops, something went wrong.