Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GPU Thomas solver #1348

Merged
merged 1 commit into from
Jul 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,18 @@ steps:
slurm_mem: 20GB
slurm_gpus: 1

- label: "Unit: Thomas Algorithm"
key: "cpu_thomas_algorithm"
command:
- "julia --color=yes --check-bounds=yes --project=examples test/Operators/thomas_algorithm.jl"

- label: "Unit: Thomas Algorithm"
key: "gpu_thomas_algorithm"
command:
- "julia --color=yes --check-bounds=yes --project=examples test/Operators/thomas_algorithm.jl"
agents:
slurm_gpus: 1

- group: "Unit: Hypsography"
steps:

Expand Down
2 changes: 1 addition & 1 deletion examples/hybrid/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ if is_distributed
logger_stream = ClimaComms.iamroot(comms_ctx) ? stderr : devnull
prev_logger = global_logger(ConsoleLogger(logger_stream, Logging.Info))
@info "Setting up distributed run on $nprocs \
processor$(nprocs == 1 ? "" : "s")"
processor$(nprocs == 1 ? "" : "s") on a $(comms_ctx.device) device"
else
using TerminalLoggers: TerminalLogger
prev_logger = global_logger(TerminalLogger())
Expand Down
11 changes: 1 addition & 10 deletions examples/hybrid/schur_complement_W.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,16 +154,7 @@ function linsolve!(::Type{Val{:init}}, f, u0; kwargs...)

@. xᶠ𝕄 = bᶠ𝕄 + dtγ * (apply(∂ᶠ𝕄ₜ∂ᶜρ, bᶜρ) + apply(∂ᶠ𝕄ₜ∂ᶜ𝔼, bᶜ𝔼))

# TODO: Do this with stencil_solve!.
Ni, Nj, _, _, Nh = size(Spaces.local_geometry_data(axes(xᶜρ)))
for h in 1:Nh, j in 1:Nj, i in 1:Ni
xᶠ𝕄_column_view = parent(Spaces.column(xᶠ𝕄, i, j, h))
S_column = Spaces.column(S, i, j, h)
@views S_column_array.dl .= parent(S_column.coefs.:1)[2:end]
S_column_array.d .= parent(S_column.coefs.:2)
@views S_column_array.du .= parent(S_column.coefs.:3)[1:(end - 1)]
ldiv!(lu!(S_column_array), xᶠ𝕄_column_view)
end
Operators.column_thomas_solve!(S, xᶠ𝕄)

@. xᶜρ = -bᶜρ + dtγ * apply(∂ᶜρₜ∂ᶠ𝕄, xᶠ𝕄)
@. xᶜ𝔼 = -bᶜ𝔼 + dtγ * apply(∂ᶜ𝔼ₜ∂ᶠ𝕄, xᶠ𝕄)
Expand Down
1 change: 1 addition & 0 deletions src/Operators/Operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,6 @@ include("operator2stencil.jl")
include("pointwisestencil.jl")
include("remapping.jl")
include("integrals.jl")
include("thomas_algorithm.jl")

end # module
105 changes: 105 additions & 0 deletions src/Operators/thomas_algorithm.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""
    column_thomas_solve!(A, b)

Solve the tri-diagonal linear system `A * x = b` in place, using the Thomas
algorithm (https://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm).

# Arguments
- `A`: the tri-diagonal matrix, given as a `Field` of tri-diagonal matrix rows.
- `b`: the right-hand side vector, given as a `Field` of numbers.

Both arguments are modified: on return `b` holds the solution `x`, and the
upper diagonal of `A` holds intermediate values that were used to compute `x`.

The implementation is chosen by dispatching on `ClimaComms.device(axes(A))`.
"""
function column_thomas_solve!(A, b)
    device = ClimaComms.device(axes(A))
    return column_thomas_solve!(device, A, b)
end

# CPU devices: hand the fields straight to the Thomas algorithm.
function column_thomas_solve!(::ClimaComms.AbstractCPUDevice, A, b)
    return thomas_algorithm!(A, b)
end

# CUDA devices: launch `thomas_algorithm_kernel!` with a thread/block
# configuration obtained for `Ni * Nj * Nh` work items, where those sizes are
# the 1st, 2nd and 5th entries of the size of the field values of `A`.
function column_thomas_solve!(::ClimaComms.CUDADevice, A, b)
    data_size = size(Fields.field_values(A))
    ncolumns = data_size[1] * data_size[2] * data_size[5]
    nthreads, nblocks = Spaces._configure_threadblock(ncolumns)
    @cuda threads = nthreads blocks = nblocks thomas_algorithm_kernel!(A, b)
end

# GPU kernel for extruded fields: each thread whose linear index maps to a
# valid `(i, j, h)` triple solves the corresponding column of `A` and `b`.
function thomas_algorithm_kernel!(
    A::Fields.ExtrudedFiniteDifferenceField,
    b::Fields.ExtrudedFiniteDifferenceField,
)
    Ni, Nj, _, _, Nh = size(Fields.field_values(A))
    thread = (blockIdx().x - 1) * blockDim().x + threadIdx().x

    # Threads beyond the last column have nothing to solve.
    thread > Ni * Nj * Nh && return nothing

    i, j, h = Spaces._get_idx((Ni, Nj, Nh), thread)
    thomas_algorithm!(Spaces.column(A, i, j, h), Spaces.column(b, i, j, h))
    return nothing
end

# GPU kernel for single-column fields: there is only one column, so the
# Thomas algorithm is applied to the fields directly.
function thomas_algorithm_kernel!(
    A::Fields.FiniteDifferenceField,
    b::Fields.FiniteDifferenceField,
)
    return thomas_algorithm!(A, b)
end

# Extruded fields: apply the single-column Thomas algorithm to every column
# visited by `Fields.bycolumn`.
function thomas_algorithm!(
    A::Fields.ExtrudedFiniteDifferenceField,
    b::Fields.ExtrudedFiniteDifferenceField,
)
    return Fields.bycolumn(axes(A)) do colidx
        thomas_algorithm!(A[colidx], b[colidx])
    end
end

# Single-column Thomas algorithm (see
# https://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm).
#
# `A` is a column field of tri-diagonal matrix rows whose three coefficients
# are read as (lower, main, upper) diagonal entries, and `b` is a column field
# holding the right-hand side. On return `b` contains the solution, and the
# upper diagonal of `A` contains the scaled coefficients computed during the
# forward sweep. Always returns `nothing`.
function thomas_algorithm!(
    A::Fields.FiniteDifferenceField,
    b::Fields.FiniteDifferenceField,
)
    nrows = Spaces.nlevels(axes(A))
    lower_diag = A.coefs.:1
    main_diag = A.coefs.:2
    upper_diag = A.coefs.:3

    # first row
    denominator = _getindex(main_diag, 1)
    _setindex!(upper_diag, 1, _getindex(upper_diag, 1) / denominator)
    _setindex!(b, 1, _getindex(b, 1) / denominator)

    # A 1x1 system is fully solved by the first-row step. Without this guard,
    # the "last row" step below would read row 0 of `b` and `upper_diag`
    # through the unchecked accessors and overwrite the correct result.
    nrows == 1 && return nothing

    # interior rows (forward sweep)
    for row in 2:(nrows - 1)
        lower = _getindex(lower_diag, row)
        numerator = _getindex(b, row) - lower * _getindex(b, row - 1)
        denominator =
            _getindex(main_diag, row) - lower * _getindex(upper_diag, row - 1)
        _setindex!(upper_diag, row, _getindex(upper_diag, row) / denominator)
        _setindex!(b, row, numerator / denominator)
    end

    # last row: same elimination step, but the upper diagonal has no entry to
    # update here
    lower = _getindex(lower_diag, nrows)
    numerator = _getindex(b, nrows) - lower * _getindex(b, nrows - 1)
    denominator =
        _getindex(main_diag, nrows) - lower * _getindex(upper_diag, nrows - 1)
    _setindex!(b, nrows, numerator / denominator)

    # back substitution
    for row in (nrows - 1):-1:1
        value =
            _getindex(b, row) -
            _getindex(upper_diag, row) * _getindex(b, row + 1)
        _setindex!(b, row, value)
    end
    return nothing
end

# Read the entry of `column_field` at the 1-based position `index`.
# Equivalent to `@inbounds Fields.field_values(column_field)[index]`.
function _getindex(column_field, index)
    @inbounds begin
        space = axes(column_field)
        # Shift the 1-based position so that 1 maps to the space's left index.
        level = index - 1 + Operators.left_idx(space)
        Operators.getidx(space, column_field, Operators.Interior(), level)
    end
end
dennisYatunin marked this conversation as resolved.
Show resolved Hide resolved

# Store `value` in `column_field` at the 1-based position `index`.
# Equivalent to
# `@inbounds Fields.field_values(column_field)[index] = value`.
function _setindex!(column_field, index, value)
    @inbounds begin
        space = axes(column_field)
        # Shift the 1-based position so that 1 maps to the space's left index.
        level = index - 1 + Operators.left_idx(space)
        Operators.setidx!(space, column_field, level, (1, 1, 1), value)
    end
end
61 changes: 61 additions & 0 deletions test/Operators/thomas_algorithm.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
using Test
import Random: seed!
import LinearAlgebra: Tridiagonal, norm
import ClimaCore
import ClimaCore: Geometry, Spaces, Fields, Operators

include(
joinpath(pkgdir(ClimaCore), "test", "TestUtilities", "TestUtilities.jl"),
)
import .TestUtilities as TU

# Check `Operators.column_thomas_solve!` on `space`: build a random diagonally
# dominant tri-diagonal system, solve it in place, and compare every column of
# the result against `Tridiagonal \ vector` from LinearAlgebra.
function test_thomas_algorithm(space)
    coords = Fields.coordinate_field(space)

    # Absolute tolerance for the comparison below. It is computed once, and
    # deliberately not named `FT`, so that it cannot be captured by the `do`
    # block that assigns its own local `FT`.
    tolerance = eps(Spaces.undertype(space))

    # Set the seed to ensure reproducibility.
    seed!(1)

    # Set A to a random diagonally dominant tri-diagonal matrix.
    A = map(coords) do coord
        FT = Geometry.float_type(coord)
        Operators.StencilCoefs{-1, 1}((rand(FT), 10 + rand(FT), rand(FT)))
    end

    # Set b to a random vector.
    b = map(coord -> rand(Geometry.float_type(coord)), coords)

    # Copy A and b, since they will be overwritten by column_thomas_solve!.
    A_copy = copy(A)
    b_copy = copy(b)

    Operators.column_thomas_solve!(A, b)

    # Verify that column_thomas_solve! correctly replaced b with A \ b.
    Ni, Nj, _, _, Nh = size(Fields.field_values(A))
    for i in 1:Ni, j in 1:Nj, h in 1:Nh
        # Columns 1, 2 and 3 of the column data are used as the lower, main
        # and upper diagonals. The first lower entry and the last upper entry
        # are dropped, since they fall outside the matrix.
        A_column_data = Array(parent(Spaces.column(A_copy, i, j, h)))
        A_column_array = Tridiagonal(
            A_column_data[2:end, 1],
            A_column_data[:, 2],
            A_column_data[1:(end - 1), 3],
        )
        b_column_array = vec(Array(parent(Spaces.column(b_copy, i, j, h))))
        x_column_array = vec(Array(parent(Spaces.column(b, i, j, h))))
        x_column_array_ref = A_column_array \ b_column_array
        @test all(@. abs(x_column_array - x_column_array_ref) < tolerance)
    end
end

# Run the solver test for both float types on column and extruded spaces,
# each at cell centers and at cell faces.
@testset "Thomas Algorithm unit tests" begin
    space_constructors = (
        TU.ColumnCenterFiniteDifferenceSpace,
        TU.ColumnFaceFiniteDifferenceSpace,
        TU.CenterExtrudedFiniteDifferenceSpace,
        TU.FaceExtrudedFiniteDifferenceSpace,
    )
    for FT in (Float32, Float64)
        # Build all spaces for this float type up front, then test each one.
        spaces = map(make_space -> make_space(FT), space_constructors)
        for space in spaces
            test_thomas_algorithm(space)
        end
    end
end