From 00cada787403ec28c1cc5947943335f936344150 Mon Sep 17 00:00:00 2001 From: sriharshakandala Date: Mon, 26 Jun 2023 10:56:26 -0700 Subject: [PATCH] Add column_thomas_solve! to ClimaCore, along with unit tests --- .buildkite/pipeline.yml | 12 ++++ examples/hybrid/driver.jl | 2 +- examples/hybrid/schur_complement_W.jl | 11 +-- src/Operators/Operators.jl | 1 + src/Operators/thomas_algorithm.jl | 100 ++++++++++++++++++++++++++ test/Operators/thomas_algorithm.jl | 61 ++++++++++++++++ 6 files changed, 176 insertions(+), 11 deletions(-) create mode 100644 src/Operators/thomas_algorithm.jl create mode 100644 test/Operators/thomas_algorithm.jl diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 3defbf13c4..b060f6b934 100755 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -482,6 +482,18 @@ steps: slurm_mem: 20GB slurm_gpus: 1 + - label: "Unit: Thomas Algorithm" + key: "cpu_thomas_algorithm" + command: + - "julia --color=yes --check-bounds=yes --project=examples test/Operators/thomas_algorithm.jl" + + - label: "Unit: Thomas Algorithm" + key: "gpu_thomas_algorithm" + command: + - "julia --color=yes --check-bounds=yes --project=examples test/Operators/thomas_algorithm.jl" + agents: + slurm_gpus: 1 + - group: "Unit: Hypsography" steps: diff --git a/examples/hybrid/driver.jl b/examples/hybrid/driver.jl index 8c77d40406..140b5fad43 100644 --- a/examples/hybrid/driver.jl +++ b/examples/hybrid/driver.jl @@ -35,7 +35,7 @@ if is_distributed logger_stream = ClimaComms.iamroot(comms_ctx) ? stderr : devnull prev_logger = global_logger(ConsoleLogger(logger_stream, Logging.Info)) @info "Setting up distributed run on $nprocs \ - processor$(nprocs == 1 ? "" : "s")" + processor$(nprocs == 1 ? "" : "s") on a $(comms_ctx.device) device" else using TerminalLoggers: TerminalLogger prev_logger = global_logger(TerminalLogger()) diff --git a/examples/hybrid/schur_complement_W.jl b/examples/hybrid/schur_complement_W.jl index d4bf8e1367..0a0948488a 100644 --- a/examples/hybrid/schur_complement_W.jl +++ b/examples/hybrid/schur_complement_W.jl @@ -154,16 +154,7 @@ function linsolve!(::Type{Val{:init}}, f, u0; kwargs...) @. xᶠ𝕄 = bᶠ𝕄 + dtγ * (apply(∂ᶠ𝕄ₜ∂ᶜρ, bᶜρ) + apply(∂ᶠ𝕄ₜ∂ᶜ𝔼, bᶜ𝔼)) - # TODO: Do this with stencil_solve!. - Ni, Nj, _, _, Nh = size(Spaces.local_geometry_data(axes(xᶜρ))) - for h in 1:Nh, j in 1:Nj, i in 1:Ni - xᶠ𝕄_column_view = parent(Spaces.column(xᶠ𝕄, i, j, h)) - S_column = Spaces.column(S, i, j, h) - @views S_column_array.dl .= parent(S_column.coefs.:1)[2:end] - S_column_array.d .= parent(S_column.coefs.:2) - @views S_column_array.du .= parent(S_column.coefs.:3)[1:(end - 1)] - ldiv!(lu!(S_column_array), xᶠ𝕄_column_view) - end + Operators.column_thomas_solve!(S, xᶠ𝕄) @. xᶜρ = -bᶜρ + dtγ * apply(∂ᶜρₜ∂ᶠ𝕄, xᶠ𝕄) @. xᶜ𝔼 = -bᶜ𝔼 + dtγ * apply(∂ᶜ𝔼ₜ∂ᶠ𝕄, xᶠ𝕄) diff --git a/src/Operators/Operators.jl b/src/Operators/Operators.jl index 5ece12aef6..973d59b842 100644 --- a/src/Operators/Operators.jl +++ b/src/Operators/Operators.jl @@ -26,5 +26,6 @@ include("operator2stencil.jl") include("pointwisestencil.jl") include("remapping.jl") include("integrals.jl") +include("thomas_algorithm.jl") end # module diff --git a/src/Operators/thomas_algorithm.jl b/src/Operators/thomas_algorithm.jl new file mode 100644 index 0000000000..66a839d715 --- /dev/null +++ b/src/Operators/thomas_algorithm.jl @@ -0,0 +1,100 @@ +""" + column_thomas_solve!(A, b) + +Solves the linear system `A * x = b`, where `A` is a tri-diagonal matrix +(represented by a `Field` of tri-diagonal matrix rows), and where `b` is a +vector (represented by a `Field` of numbers). The data in `b` is overwritten +with the solution `x`, and the upper diagonal of `A` is also overwritten with +intermediate values used to compute `x`. The algorithm is described here: +https://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm. +""" +column_thomas_solve!(A, b) = + column_thomas_solve!(ClimaComms.device(axes(A)), A, b) + +column_thomas_solve!(::ClimaComms.AbstractCPUDevice, A, b) = + thomas_algorithm!(A, b) + +function column_thomas_solve!(::ClimaComms.CUDADevice, A, b) + Ni, Nj, _, _, Nh = size(Fields.field_values(A)) + nthreads, nblocks = Spaces._configure_threadblock(Ni * Nj * Nh) + @cuda threads = nthreads blocks = nblocks thomas_algorithm_kernel!(A, b) +end + +function thomas_algorithm_kernel!( + A::Fields.ExtrudedFiniteDifferenceField, + b::Fields.ExtrudedFiniteDifferenceField, +) + idx = threadIdx().x + (blockIdx().x - 1) * blockDim().x + Ni, Nj, _, _, Nh = size(Fields.field_values(A)) + if idx <= Ni * Nj * Nh + i, j, h = Spaces._get_idx((Ni, Nj, Nh), idx) + thomas_algorithm!(Spaces.column(A, i, j, h), Spaces.column(b, i, j, h)) + end + return nothing +end + +thomas_algorithm_kernel!(A::Fields.ColumnField, b::Fields.ColumnField) = + thomas_algorithm!(A, b) + +thomas_algorithm!( + A::Fields.ExtrudedFiniteDifferenceField, + b::Fields.ExtrudedFiniteDifferenceField, +) = Fields.bycolumn(colidx -> thomas_algorithm!(A[colidx], b[colidx]), axes(A)) + +function thomas_algorithm!(A::Fields.ColumnField, b::Fields.ColumnField) + nrows = Spaces.nlevels(axes(A)) + lower_diag = A.coefs.:1 + main_diag = A.coefs.:2 + upper_diag = A.coefs.:3 + + # first row + denominator = _getindex(main_diag, 1) + _setindex!(upper_diag, 1, _getindex(upper_diag, 1) / denominator) + _setindex!(b, 1, _getindex(b, 1) / denominator) + + # interior rows + for row in 2:(nrows - 1) + numerator = + _getindex(b, row) - + _getindex(lower_diag, row) * _getindex(b, row - 1) + denominator = + _getindex(main_diag, row) - + _getindex(lower_diag, row) * _getindex(upper_diag, row - 1) + _setindex!(upper_diag, row, _getindex(upper_diag, row) / denominator) + _setindex!(b, row, numerator / denominator) + end + + # last row + numerator = + _getindex(b, nrows) - + _getindex(lower_diag, nrows) * _getindex(b, nrows - 1) + denominator = + _getindex(main_diag, nrows) - + _getindex(lower_diag, nrows) * _getindex(upper_diag, nrows - 1) + _setindex!(b, nrows, numerator / denominator) + + # back substitution + for row in (nrows - 1):-1:1 + value = + _getindex(b, row) - + _getindex(upper_diag, row) * _getindex(b, row + 1) + _setindex!(b, row, value) + end +end + +# This is the same as @inbounds Fields.field_values(column_field)[index] +_getindex(column_field, index) = @inbounds Operators.getidx( + axes(column_field), + column_field, + Operators.Interior(), + index - 1 + Operators.left_idx(axes(column_field)), +) + +# This is the same as @inbounds Fields.field_values(column_field)[index] = value +_setindex!(column_field, index, value) = @inbounds Operators.setidx!( + axes(column_field), + column_field, + index - 1 + Operators.left_idx(axes(column_field)), + (1, 1, 1), + value, +) diff --git a/test/Operators/thomas_algorithm.jl b/test/Operators/thomas_algorithm.jl new file mode 100644 index 0000000000..7cb38a4922 --- /dev/null +++ b/test/Operators/thomas_algorithm.jl @@ -0,0 +1,61 @@ +using Test +import Random: seed! +import LinearAlgebra: Tridiagonal, norm +import ClimaCore +import ClimaCore: Geometry, Spaces, Fields, Operators + +include( + joinpath(pkgdir(ClimaCore), "test", "TestUtilities", "TestUtilities.jl"), +) +import .TestUtilities as TU + +function test_thomas_algorithm(space) + coords = Fields.coordinate_field(space) + + # Set the seed to ensure reproducibility. + seed!(1) + + # Set A to a random diagonally dominant tri-diagonal matrix. + A = map(coords) do coord + FT = Geometry.float_type(coord) + Operators.StencilCoefs{-1, 1}((rand(FT), 10 + rand(FT), rand(FT))) + end + + # Set b to a random vector. + b = map(coord -> rand(Geometry.float_type(coord)), coords) + + # Copy A and b, since they will be overwritten by column_thomas_solve!. + A_copy = copy(A) + b_copy = copy(b) + + Operators.column_thomas_solve!(A, b) + + # Verify that column_thomas_solve! correctly replaced b with A \ b. + Ni, Nj, _, _, Nh = size(Fields.field_values(A)) + for i in 1:Ni, j in 1:Nj, h in 1:Nh + A_column_data = Array(parent(Spaces.column(A_copy, i, j, h))) + A_column_array = Tridiagonal( + A_column_data[2:end, 1], + A_column_data[:, 2], + A_column_data[1:(end - 1), 3], + ) + b_column_array = Array(parent(Spaces.column(b_copy, i, j, h)))[:] + x_column_array = Array(parent(Spaces.column(b, i, j, h)))[:] + x_column_array_ref = A_column_array \ b_column_array + FT = Spaces.undertype(space) + @test all(@. abs(x_column_array - x_column_array_ref) < eps(FT)) + end +end + +@testset "Thomas Algorithm unit tests" begin + for FT in (Float32, Float64), + space in ( + TU.ColumnCenterFiniteDifferenceSpace(FT), + TU.ColumnFaceFiniteDifferenceSpace(FT), + TU.CenterExtrudedFiniteDifferenceSpace(FT), + TU.FaceExtrudedFiniteDifferenceSpace(FT), + ) + + test_thomas_algorithm(space) + end +end