Skip to content

Commit

Permalink
Try #1399:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] authored Aug 9, 2023
2 parents bbe68ea + 790f990 commit fd63920
Show file tree
Hide file tree
Showing 22 changed files with 2,105 additions and 559 deletions.
15 changes: 15 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,10 @@ steps:
key: unit_field2arrays
command: "julia --color=yes --check-bounds=yes --project=test test/MatrixFields/field2arrays.jl"

- label: "Unit: matrix multiplication at boundaries"
key: unit_matrix_multiplication_at_boundaries
command: "julia --color=yes --check-bounds=yes --project=test test/MatrixFields/matrix_multiplication_at_boundaries.jl"

- label: "Unit: matrix field broadcasting (CPU)"
key: unit_matrix_field_broadcasting_cpu
command: "julia --color=yes --check-bounds=yes --project=test test/MatrixFields/matrix_field_broadcasting.jl"
Expand All @@ -524,6 +528,17 @@ steps:
slurm_gpus: 1
slurm_mem: 40GB

- label: "Unit: operator matrices (CPU)"
key: unit_operator_matrices_cpu
command: "julia --color=yes --check-bounds=yes --project=test test/MatrixFields/operator_matrices.jl"

- label: "Unit: operator matrices (GPU)"
key: unit_operator_matrices_gpu
command: "julia --color=yes --project=test test/MatrixFields/operator_matrices.jl"
agents:
slurm_gpus: 1
slurm_mem: 40GB

- group: "Unit: Hypsography"
steps:

Expand Down
Binary file added docs/.DS_Store
Binary file not shown.
9 changes: 9 additions & 0 deletions docs/src/matrix_fields.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ BandMatrixRow
MultiplyColumnwiseBandMatrixField
```

## Operator Matrices

```@docs
operator_matrix
```

## Internals

```@docs
Expand All @@ -30,6 +36,9 @@ rmul_with_projection
mul_return_type
rmul_return_type
matrix_shape
column_axes
AbstractLazyOperator
replace_lazy_operator
```

## Utilities
Expand Down
11 changes: 9 additions & 2 deletions src/MatrixFields/MatrixFields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,16 @@ export DiagonalMatrixRow,
QuaddiagonalMatrixRow,
PentadiagonalMatrixRow

# Types that are teated as single values when using matrix fields.
const SingleValue =
Union{Number, Geometry.AxisTensor, Geometry.AdjointAxisTensor}

include("band_matrix_row.jl")
include("rmul_with_projection.jl")
include("matrix_shape.jl")
include("matrix_multiplication.jl")
include("lazy_operators.jl")
include("operator_matrices.jl")
include("field2arrays.jl")

const ColumnwiseBandMatrixField{V, S} = Fields.Field{
Expand All @@ -68,10 +74,11 @@ const ColumnwiseBandMatrixField{V, S} = Fields.Field{
function Base.show(io::IO, field::ColumnwiseBandMatrixField)
print(io, eltype(field), "-valued Field")
if eltype(eltype(field)) <: Number
shape = typeof(matrix_shape(field)).name.name
if field isa Fields.FiniteDifferenceField
println(io, " that corresponds to the matrix")
println(io, " that corresponds to the $shape matrix")
else
println(io, " whose first column corresponds to the matrix")
println(io, " whose first column corresponds to the $shape matrix")
end
column_field = Fields.column(field, 1, 1, 1)
io = IOContext(io, :compact => true, :limit => true)
Expand Down
17 changes: 12 additions & 5 deletions src/MatrixFields/band_matrix_row.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ several aliases for commonly used subtypes of `BandMatrixRow`:
- `TridiagonalMatrixRow(entry_1, entry_2, entry_3)`
- `QuaddiagonalMatrixRow(entry_1, entry_2, entry_3, entry_4)`
- `PentadiagonalMatrixRow(entry_1, entry_2, entry_3, entry_4, entry_5)`
It is also possible to change the type of a band matrix row by padding it with
zeros, e.g., `QuaddiagonalMatrixRow(BidiagonalMatrixRow(entry_1, entry_2))`.
"""
struct BandMatrixRow{ld, bw, T} # bw is the bandwidth (the number of diagonals)
entries::NTuple{bw, T}
Expand All @@ -20,11 +22,16 @@ struct BandMatrixRow{ld, bw, T} # bw is the bandwidth (the number of diagonals)
# TODO: Remove this inner constructor once Julia's default convert function
# is type-stable for nested Tuple/NamedTuple types.
end
BandMatrixRow{ld}(entries::Vararg{Any, bw}) where {ld, bw} =
BandMatrixRow{ld, bw}(entries...)
BandMatrixRow{ld, bw}(entries::Vararg{Any, bw}) where {ld, bw} =
BandMatrixRow{ld}(entries...) where {ld} =
BandMatrixRow{ld, length(entries)}(entries...)
BandMatrixRow{ld, bw}(entries...) where {ld, bw} =
BandMatrixRow{ld, bw, rpromote_type(map(typeof, entries)...)}(entries)

BandMatrixRow{ld, bw}(row::BandMatrixRow) where {ld, bw} =
BandMatrixRow{ld, bw, eltype(row)}(row)
BandMatrixRow{ld, bw, T}(row::BandMatrixRow) where {ld, bw, T} =
convert(BandMatrixRow{ld, bw, T}, row)

const DiagonalMatrixRow{T} = BandMatrixRow{0, 1, T}
const BidiagonalMatrixRow{T} = BandMatrixRow{-1 + half, 2, T}
const TridiagonalMatrixRow{T} = BandMatrixRow{-1, 3, T}
Expand Down Expand Up @@ -128,9 +135,9 @@ Base.:-(row1::BandMatrixRow, row2::UniformScaling) =
Base.:-(row1::UniformScaling, row2::BandMatrixRow) =
map(rsub, promote(row1, row2)...)

Base.:*(row::BandMatrixRow, value::Number) =
Base.:*(row::BandMatrixRow, value::SingleValue) =
map(entry -> rmul(entry, value), row)
Base.:*(value::Number, row::BandMatrixRow) =
Base.:*(value::SingleValue, row::BandMatrixRow) =
map(entry -> rmul(value, entry), row)

Base.:/(row::BandMatrixRow, value::Number) =
Expand Down
6 changes: 4 additions & 2 deletions src/MatrixFields/field2arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ all_columns(space::Spaces.ExtrudedFiniteDifferenceSpace) =
# Spaces.all_nodes(Spaces.horizontal_space(axes(field)))

column_map(f::F, field) where {F} =
map((((i, j), h),) -> f(Spaces.column(field, i, j, h)), all_columns(field))
Iterators.map(all_columns(field)) do ((i, j), h)
f(Spaces.column(field, i, j, h))
end

"""
field2arrays(field)
Expand All @@ -104,7 +106,7 @@ Converts a field defined on a `FiniteDifferenceSpace` or on an
corresponds to a column of the field. This is done by calling
`column_field2array` on each of the field's columns.
"""
field2arrays(field) = column_map(column_field2array, field)
field2arrays(field) = collect(column_map(column_field2array, field))

"""
field2arrays_view(field)
Expand Down
121 changes: 121 additions & 0 deletions src/MatrixFields/lazy_operators.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""
AbstractLazyOperator
Supertype for "lazy operators", i.e., operators that can be called without any
arguments by users, as long as they appear in broadcast expressions that contain
at least one `Field`. If `lazy_op` is an `AbstractLazyOperator`, the expression
`lazy_op.()` will internally be translated to `non_lazy_op.(fields...)`, as long
as it appears in a broadcast expression with at least one `Field`. This
translation is done by the function `replace_lazy_operator(space, lazy_op)`,
which must be implemented by every subtype of `AbstractLazyOperator`.
"""
abstract type AbstractLazyOperator end

struct LazyOperatorStyle <: Base.Broadcast.BroadcastStyle end

Base.Broadcast.broadcasted(op::AbstractLazyOperator) =
Base.Broadcast.broadcasted(LazyOperatorStyle(), op)

# Broadcasting over an AbstractLazyOperator and either a Ref, a Tuple, a Field,
# an Operator, or another AbstractLazyOperator involves using LazyOperatorStyle.
Base.Broadcast.BroadcastStyle(
::LazyOperatorStyle,
::Union{
Base.Broadcast.AbstractArrayStyle{0},
Base.Broadcast.Style{Tuple},
Fields.AbstractFieldStyle,
LazyOperatorStyle,
},
) = LazyOperatorStyle()

struct LazyOperatorBroadcasted{F, A} <:
Operators.OperatorBroadcasted{LazyOperatorStyle}
f::F
args::A
end

# TODO: This definition of Base.Broadcast.broadcasted results in 2 additional
# method invalidations when using Julia 1.8.5. However, if we were to delete it,
# we would also need to replace the following specializations on
# LazyOperatorBroadcasted with specializations on Base.Broadcast.Broadcasted.
# Specifically, we would need to modify Base.Broadcast.materialize so that it
# specializes on Base.Broadcast.Broadcasted{LazyOperatorStyle}, and this would
# result in 11 invalidations instead of 2.
Base.Broadcast.broadcasted(::LazyOperatorStyle, f::F, args...) where {F} =
LazyOperatorBroadcasted(f, args)

function Base.Broadcast.materialize(bc::LazyOperatorBroadcasted)
space = largest_space(bc)
isnothing(space) && error("Cannot materialize broadcast expression with \
AbstractLazyOperator because it does not contain any Fields")
return Base.Broadcast.materialize(replace_lazy_operators(space, bc))
end

Base.Broadcast.materialize!(dest::Fields.Field, bc::LazyOperatorBroadcasted) =
Base.Broadcast.materialize!(dest, replace_lazy_operators(axes(dest), bc))

replace_lazy_operators(_, arg) = arg
replace_lazy_operators(space, bc::LazyOperatorBroadcasted) =
bc.f isa AbstractLazyOperator ? replace_lazy_operator(space, bc.f) :
Base.Broadcast.broadcasted(
bc.f,
replace_lazy_operators_args(space, bc.args...)...,
)

replace_lazy_operators_args(_) = ()
replace_lazy_operators_args(space, arg, args...) = (
replace_lazy_operators(space, arg),
replace_lazy_operators_args(space, args...)...,
)

"""
replace_lazy_operator(space, lazy_op)
Generates an instance of `Base.AbstractBroadcasted` that corresponds to the
expression `lazy_op.()`, where the broadcast in which this expression appears is
being evaluated on the given `space`. Note that the staggering (`CellCenter` or
`CellFace`) of this `space` depends on the specifics of the broadcast and is not
predetermined.
"""
replace_lazy_operator(_, ::AbstractLazyOperator) =
error("Every subtype of AbstractLazyOperator must implement a method for
replace_lazy_operator(space, lazy_op)")

largest_space(_) = nothing
largest_space(field::Fields.Field) = axes(field)
largest_space(bc::Base.AbstractBroadcasted) = largest_space_args(bc.args...)

largest_space_args() = nothing
largest_space_args(arg, args...) =
larger_space(largest_space(arg), largest_space_args(args...))

larger_space(::Nothing, ::Nothing) = nothing
larger_space(space1, ::Nothing) = space1
larger_space(::Nothing, space2) = space2
larger_space(space1::S, ::S) where {S} = space1 # Neither space is larger.
larger_space(
space1::Spaces.FiniteDifferenceSpace,
::Spaces.FiniteDifferenceSpace,
) = space1 # The staggering does not matter here, so neither space is larger.
larger_space(
space1::Spaces.ExtrudedFiniteDifferenceSpace,
::Spaces.ExtrudedFiniteDifferenceSpace,
) = space1 # The staggering does not matter here, so neither space is larger.
larger_space(
space1::Spaces.ExtrudedFiniteDifferenceSpace,
::Spaces.FiniteDifferenceSpace,
) = space1 # The types indicate that space2 is a subspace of space1.
larger_space(
::Spaces.FiniteDifferenceSpace,
space2::Spaces.ExtrudedFiniteDifferenceSpace,
) = space2 # The types indicate that space1 is a subspace of space2.
larger_space(
space1::Spaces.ExtrudedFiniteDifferenceSpace,
::Spaces.AbstractSpectralElementSpace,
) = space1 # The types indicate that space2 is a subspace of space1.
larger_space(
::Spaces.AbstractSpectralElementSpace,
space2::Spaces.ExtrudedFiniteDifferenceSpace,
) = space2 # The types indicate that space1 is a subspace of space2.
larger_space(::S1, ::S2) where {S1, S2} =
error("Mismatched spaces ($(S1.name.name) and $(S2.name.name))")
Loading

0 comments on commit fd63920

Please sign in to comment.