-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add new MatrixFields module, along with unit tests and performance tests
- Loading branch information
1 parent
8e77c96
commit f15ba20
Showing
14 changed files
with
1,739 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
module MatrixFields | ||
|
||
import LinearAlgebra: UniformScaling, Adjoint | ||
import StaticArrays: SArray, SMatrix, SVector | ||
import BandedMatrices: BandedMatrix, band | ||
import ..Utilities: PlusHalf, half | ||
import ..RecursiveApply: rmap, rmaptype, rzero, radd, rsub, rmul, rdiv | ||
import ..Geometry | ||
import ..Spaces | ||
import ..Fields | ||
import ..Operators: | ||
FiniteDifferenceOperator, | ||
BoundaryCondition, | ||
LeftBoundaryWindow, | ||
RightBoundaryWindow, | ||
has_boundary, | ||
get_boundary, | ||
stencil_interior_width, | ||
boundary_width, | ||
return_eltype, | ||
return_space, | ||
stencil_interior, | ||
stencil_left_boundary, | ||
stencil_right_boundary, | ||
left_idx, | ||
right_idx, | ||
getidx, | ||
getidx_args | ||
|
||
include("band_matrix_row.jl") | ||
include("rmul_with_projection.jl") | ||
include("matrix_multiplication.jl") | ||
include("matrix_field_utils.jl") | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
""" | ||
BandMatrixRow{ld}(entries...) | ||
Stores the nonzero entries in a row of a band matrix, starting with the lowest | ||
diagonal, which has index `ld`. Supported operations include accessing the entry | ||
on the diagonal with index `d` by calling `row[d]`, taking linear combinations | ||
with other band matrix rows, and checking for equality with other band matrix | ||
rows. There are several aliases defined for commonly-used subtypes of | ||
`BandMatrixRow` (with `T` denoting the type of the row's entries): | ||
- `DiagonalMatrixRow{T}` | ||
- `BidiagonalMatrixRow{T}` | ||
- `TridiagonalMatrixRow{T}` | ||
- `QuaddiagonalMatrixRow{T}` | ||
- `PentadiagonalMatrixRow{T}` | ||
""" | ||
struct BandMatrixRow{ld, bw, T} # bw is the bandwidth (the number of diagonals) | ||
entries::NTuple{bw, T} | ||
end | ||
BandMatrixRow{ld}(entries::Vararg{Any, bw}) where {ld, bw} = | ||
BandMatrixRow{ld, bw}(entries...) | ||
function BandMatrixRow{ld, bw}(entries::Vararg{Any, bw}) where {ld, bw} | ||
promoted_entries = promote(entries...) | ||
return BandMatrixRow{ld, bw, eltype(promoted_entries)}(promoted_entries) | ||
end | ||
|
||
const DiagonalMatrixRow{T} = BandMatrixRow{0, 1, T} | ||
const BidiagonalMatrixRow{T} = BandMatrixRow{-1 + half, 2, T} | ||
const TridiagonalMatrixRow{T} = BandMatrixRow{-1, 3, T} | ||
const QuaddiagonalMatrixRow{T} = BandMatrixRow{-2 + half, 4, T} | ||
const PentadiagonalMatrixRow{T} = BandMatrixRow{-2, 5, T} | ||
|
||
function Base.show( | ||
io::IO, | ||
::Type{BMR}, | ||
) where {ld, bw, T, BMR <: BandMatrixRow{ld, bw, T}} | ||
string = if BMR <: DiagonalMatrixRow | ||
"DiagonalMatrixRow{$T}" | ||
elseif BMR <: BidiagonalMatrixRow | ||
"BidiagonalMatrixRow{$T}" | ||
elseif BMR <: TridiagonalMatrixRow | ||
"TridiagonalMatrixRow{$T}" | ||
elseif BMR <: QuaddiagonalMatrixRow | ||
"QuaddiagonalMatrixRow{$T}" | ||
elseif BMR <: PentadiagonalMatrixRow | ||
"PentadiagonalMatrixRow{$T}" | ||
else | ||
"BandMatrixRow{$ld, $bw, $T}" | ||
end | ||
print(io, string) | ||
end | ||
|
||
""" | ||
outer_diagonals(::Type{<:BandMatrixRow}) | ||
Gets the indices of the lower and upper diagonals, `ld` and `ud`, of the given | ||
subtype of `BandMatrixRow`. | ||
""" | ||
outer_diagonals(::Type{<:BandMatrixRow{ld, bw}}) where {ld, bw} = | ||
(ld, ld + bw - 1) | ||
|
||
""" | ||
band_matrix_row_type(ld, ud, T) | ||
A shorthand for getting the subtype of `BandMatrixRow` that has entries of type | ||
`T` on the diagonals with indices in the range `ld:ud`. | ||
""" | ||
band_matrix_row_type(ld, ud, T) = BandMatrixRow{ld, ud - ld + 1, T} | ||
|
||
Base.eltype(::Type{BandMatrixRow{ld, bw, T}}) where {ld, bw, T} = T | ||
|
||
Base.zero(::Type{BandMatrixRow{ld, bw, T}}) where {ld, bw, T} = | ||
BandMatrixRow{ld}(ntuple(_ -> rzero(T), Val(bw))...) | ||
|
||
Base.@propagate_inbounds Base.getindex(row::BandMatrixRow{ld}, d) where {ld} = | ||
row.entries[d - ld + 1] | ||
|
||
function Base.promote_rule( | ||
::Type{BMR1}, | ||
::Type{BMR2}, | ||
) where {BMR1 <: BandMatrixRow, BMR2 <: BandMatrixRow} | ||
ld1, ud1 = outer_diagonals(BMR1) | ||
ld2, ud2 = outer_diagonals(BMR2) | ||
typeof(ld1) == typeof(ld2) || error( | ||
"Cannot promote the $(ld1 isa PlusHalf ? "non-" : "")square matrix \ | ||
row type $BMR1 and the $(ld2 isa PlusHalf ? "non-" : "")square \ | ||
matrix row type $BMR2 to a common type", | ||
) | ||
T = promote_type(eltype(BMR1), eltype(BMR2)) | ||
return band_matrix_row_type(min(ld1, ld2), max(ud1, ud2), T) | ||
end | ||
|
||
Base.promote_rule( | ||
::Type{BMR}, | ||
::Type{US}, | ||
) where {BMR <: BandMatrixRow, US <: UniformScaling} = | ||
promote_rule(BMR, DiagonalMatrixRow{eltype(US)}) | ||
|
||
function Base.convert( | ||
::Type{BMR}, | ||
row::BandMatrixRow, | ||
) where {BMR <: BandMatrixRow} | ||
old_ld, old_ud = outer_diagonals(typeof(row)) | ||
new_ld, new_ud = outer_diagonals(BMR) | ||
typeof(old_ld) == typeof(new_ld) || | ||
error("Cannot convert a $(old_ld isa PlusHalf ? "non-" : "")square \ | ||
matrix row of type $(typeof(row)) to the \ | ||
$(new_ld isa PlusHalf ? "non-" : "")square matrix row type $BMR") | ||
new_ld <= old_ld && new_ud >= old_ud || | ||
error("Cannot convert a $(typeof(row)) to a $BMR, since that would \ | ||
require dropping potentially non-zero row entries") | ||
first_zeros = ntuple(_ -> rzero(eltype(BMR)), Val(old_ld - new_ld)) | ||
entries = map(entry -> convert(eltype(BMR), entry), row.entries) | ||
last_zeros = ntuple(_ -> rzero(eltype(BMR)), Val(new_ud - old_ud)) | ||
return BandMatrixRow{new_ld}(first_zeros..., entries..., last_zeros...) | ||
end | ||
|
||
Base.convert(::Type{BMR}, row::UniformScaling) where {BMR <: BandMatrixRow} = | ||
convert(BMR, DiagonalMatrixRow(row.λ)) | ||
|
||
Base.:(==)(row1::BMR, row2::BMR) where {BMR <: BandMatrixRow} = | ||
row1.entries == row2.entries | ||
Base.:(==)(row1::BandMatrixRow, row2::BandMatrixRow) = | ||
==(promote(row1, row2)...) | ||
Base.:(==)(row1::BandMatrixRow, row2::UniformScaling) = | ||
==(promote(row1, row2)...) | ||
Base.:(==)(row1::UniformScaling, row2::BandMatrixRow) = | ||
==(promote(row1, row2)...) | ||
|
||
Base.map(f::F, rows::BandMatrixRow{ld}...) where {F, ld} = | ||
BandMatrixRow{ld}(map(f, map(row -> row.entries, rows)...)...) | ||
|
||
# Define all necessary operations for computing linear combinations. Use | ||
# methods from RecursiveApply to handle nested types. | ||
Base.:+(row::BandMatrixRow) = map(radd, row) | ||
Base.:+(row1::BandMatrixRow, row2::BandMatrixRow) = | ||
map(radd, promote(row1, row2)...) | ||
Base.:+(row1::BandMatrixRow, row2::UniformScaling) = | ||
map(radd, promote(row1, row2)...) | ||
Base.:+(row1::UniformScaling, row2::BandMatrixRow) = | ||
map(radd, promote(row1, row2)...) | ||
Base.:-(row::BandMatrixRow) = map(rsub, row) | ||
Base.:-(row1::BandMatrixRow, row2::BandMatrixRow) = | ||
map(rsub, promote(row1, row2)...) | ||
Base.:-(row1::BandMatrixRow, row2::UniformScaling) = | ||
map(rsub, promote(row1, row2)...) | ||
Base.:-(row1::UniformScaling, row2::BandMatrixRow) = | ||
map(rsub, promote(row1, row2)...) | ||
Base.:*(row::BandMatrixRow, value::Number) = | ||
map(entry -> rmul(entry, value), row) | ||
Base.:*(value::Number, row::BandMatrixRow) = | ||
map(entry -> rmul(value, entry), row) | ||
Base.:/(row::BandMatrixRow, value::Number) = | ||
map(entry -> rdiv(entry, value), row) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
""" | ||
column_field2array(field) | ||
Converts a field defined on a `FiniteDifferenceSpace` into a `Vector` or a | ||
`BandedMatrix`, depending on whether or not the elements of the field are | ||
`BandMatrixRow`s. This involves copying the data stored in the field. | ||
""" | ||
function column_field2array(field) | ||
space = axes(field) | ||
space isa Spaces.FiniteDifferenceSpace || | ||
error("column_field2array requires a field on a FiniteDifferenceSpace") | ||
n_rows = Spaces.nlevels(space) | ||
if eltype(field) <: BandMatrixRow # field represents a matrix | ||
field_ld, field_ud = outer_diagonals(eltype(field)) | ||
|
||
# Find the amount by which the field's diagonal indices get shifted when | ||
# it is converted into a matrix, as well as the diagonal index of the | ||
# value that ends up in the bottom-right corner of the matrix. | ||
matrix_d_minus_field_d, bottom_corner_matrix_d = | ||
if field_ld isa PlusHalf | ||
if axes(field).staggering isa Spaces.CellCenter | ||
half, 1 # field is a face-to-center matrix | ||
else | ||
-half, -1 # field is a center-to-face matrix | ||
end | ||
else | ||
0, 0 # field is either a center-to-center or face-to-face matrix | ||
end | ||
|
||
matrix_ld = field_ld + matrix_d_minus_field_d | ||
matrix_ud = field_ud + matrix_d_minus_field_d | ||
matrix_ld <= 0 && matrix_ud >= 0 || | ||
error("BandedMatrices.jl does not yet support matrices that have \ | ||
diagonals with indices in the range $matrix_ld:$matrix_ud") | ||
|
||
n_cols = n_rows + bottom_corner_matrix_d | ||
matrix = BandedMatrix{eltype(eltype(field))}( | ||
undef, | ||
(n_rows, n_cols), | ||
(-matrix_ld, matrix_ud), | ||
) | ||
for (index_of_field_entry, matrix_d) in enumerate(matrix_ld:matrix_ud) | ||
# Find the rows for which field_diagonal[row] is inside the matrix. | ||
# Note: The matrix index (1, 1) corresponds to the diagonal index 0, | ||
# and the matrix index (n_rows, n_cols) corresponds to the diagonal | ||
# index bottom_corner_matrix_d. | ||
first_row = matrix_d < 0 ? 1 - matrix_d : 1 | ||
last_row = | ||
matrix_d < bottom_corner_matrix_d ? n_rows : n_cols - matrix_d | ||
|
||
# Copy the value in each row from field_diagonal to matrix_diagonal. | ||
field_diagonal = field.entries.:($index_of_field_entry) | ||
matrix_diagonal = view(matrix, band(matrix_d)) | ||
for (index_along_diagonal, row) in enumerate(first_row:last_row) | ||
matrix_diagonal[index_along_diagonal] = | ||
Fields.field_values(field_diagonal)[row] | ||
end | ||
end | ||
return matrix | ||
else # field represents a vector | ||
return map(i -> Fields.field_values(field)[i], 1:n_rows) | ||
end | ||
end | ||
|
||
""" | ||
field2arrays(field) | ||
Converts a field defined on a `FiniteDifferenceSpace` or on an | ||
`ExtrudedFiniteDifferenceSpace` into a tuple of arrays, each of which | ||
corresponds to a column of the field. This is done by calling | ||
`column_field2array` on each of the field's columns. | ||
""" | ||
function field2arrays(field) | ||
space = axes(field) | ||
column_indices = if space isa Spaces.FiniteDifferenceSpace | ||
(((1, 1), 1),) | ||
elseif space isa Spaces.ExtrudedFiniteDifferenceSpace | ||
(Spaces.all_nodes(Spaces.horizontal_space(space))...,) | ||
else | ||
error("Invalid space type: $(typeof(space).name.wrapper)") | ||
end | ||
return map(column_indices) do ((i, j), h) | ||
return column_field2array(Spaces.column(field, i, j, h)) | ||
end | ||
end |
Oops, something went wrong.