Skip to content

Commit

Permalink
Try #1348:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] committed Jun 26, 2023
2 parents 5d7be91 + 276c106 commit f0ec8c8
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 39 deletions.
12 changes: 12 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,18 @@ steps:
agents:
slurm_ntasks: 2

# CPU step: runs the Thomas-algorithm test script in the `examples` project.
# No `agents` section is given for this step.
- label: ":computer: Thomas Algorithm test"
key: "cpu_thomas_algorithm_test"
command:
- "julia --color=yes --project=examples examples/hybrid/test_thomas.jl"

# GPU step: same command as the CPU step above, but the agent is asked for
# one GPU (`slurm_gpus: 1`). NOTE(review): the script itself does not select a
# device explicitly here; presumably the comms context picks the GPU - confirm.
- label: ":flower_playing_cards: Thomas Algorithm test"
key: "gpu_thomas_algorithm_test"
command:
- "julia --color=yes --project=examples examples/hybrid/test_thomas.jl"
agents:
slurm_gpus: 1

- group: "Examples: Sphere"
steps:

Expand Down
2 changes: 1 addition & 1 deletion examples/hybrid/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ if is_distributed
logger_stream = ClimaComms.iamroot(comms_ctx) ? stderr : devnull
prev_logger = global_logger(ConsoleLogger(logger_stream, Logging.Info))
@info "Setting up distributed run on $nprocs \
processor$(nprocs == 1 ? "" : "s")"
processor$(nprocs == 1 ? "" : "s") on a $(comms_ctx.device) device"
else
using TerminalLoggers: TerminalLogger
prev_logger = global_logger(TerminalLogger())
Expand Down
91 changes: 81 additions & 10 deletions examples/hybrid/schur_complement_W.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using LinearAlgebra
using CUDA

using ClimaCore: Spaces, Fields, Operators
using ClimaCore.Utilities: half
Expand Down Expand Up @@ -154,16 +155,7 @@ function linsolve!(::Type{Val{:init}}, f, u0; kwargs...)

@. xᶠ𝕄 = bᶠ𝕄 + dtγ * (apply(∂ᶠ𝕄ₜ∂ᶜρ, bᶜρ) + apply(∂ᶠ𝕄ₜ∂ᶜ𝔼, bᶜ𝔼))

# TODO: Do this with stencil_solve!.
Ni, Nj, _, _, Nh = size(Spaces.local_geometry_data(axes(xᶜρ)))
for h in 1:Nh, j in 1:Nj, i in 1:Ni
xᶠ𝕄_column_view = parent(Spaces.column(xᶠ𝕄, i, j, h))
S_column = Spaces.column(S, i, j, h)
@views S_column_array.dl .= parent(S_column.coefs.:1)[2:end]
S_column_array.d .= parent(S_column.coefs.:2)
@views S_column_array.du .= parent(S_column.coefs.:3)[1:(end - 1)]
ldiv!(lu!(S_column_array), xᶠ𝕄_column_view)
end
column_thomas_solve!(ClimaComms.device(axes(S)), S, xᶠ𝕄)

@. xᶜρ = -bᶜρ + dtγ * apply(∂ᶜρₜ∂ᶠ𝕄, xᶠ𝕄)
@. xᶜ𝔼 = -bᶜ𝔼 + dtγ * apply(∂ᶜ𝔼ₜ∂ᶠ𝕄, xᶠ𝕄)
Expand Down Expand Up @@ -206,3 +198,82 @@ function linsolve!(::Type{Val{:init}}, f, u0; kwargs...)
end
end
end

"""
thomas_algorithm!(A, b)
Thomas algorithm for solving a linear system A x = b,
where A is a tri-diagonal matrix.
A and b are overwritten.
Solution is written to b
"""
function thomas_algorithm!(A, b)
# A[_, 1] is lower diag, A[_, 2] is main diag, A[_, 3] is upper diag
nrows = size(A, 1)
# first row
A[1, 3] /= A[1, 2]
b[1] /= A[1, 2]
# interior rows
for row in 2:(nrows - 1)
fac = A[row, 2] - (A[row, 1] * A[row - 1, 3])
A[row, 3] /= fac
b[row] = (b[row] - A[row, 1] * b[row - 1]) / fac
end
# last row
fac = A[nrows, 2] - A[nrows - 1, 3] * A[nrows, 1]
b[nrows] = (b[nrows] - A[nrows, 1] * b[nrows - 1]) / fac
# back substitution
for row in (nrows - 1):-1:1
b[row] -= b[row + 1] * A[row, 3]
end
return nothing
end

"""
    column_thomas_solve!(::ClimaComms.CPUSingleThreaded, A, b)

Solve the tridiagonal system of every column of `A` and `b` sequentially on the
calling thread. Each column `(i, j, h)` is extracted with `Spaces.column`, and its
parent array is handed to `thomas_algorithm!`, which overwrites both the
coefficients in `A` and the right-hand side in `b` (the latter with the solution).

The column counts are read from the size of the local geometry data of `axes(A)`.
Returns `nothing`.
"""
function column_thomas_solve!(::ClimaComms.CPUSingleThreaded, A, b)
    n_i, n_j, _, _, n_h = size(Spaces.local_geometry_data(axes(A)))
    for h in 1:n_h
        for j in 1:n_j
            for i in 1:n_i
                thomas_algorithm!(
                    parent(Spaces.column(A, i, j, h)),
                    parent(Spaces.column(b, i, j, h)),
                )
            end
        end
    end
    return nothing
end

"""
    column_thomas_solve!(::ClimaComms.CPUMultiThreaded, A, b)

Solve the tridiagonal system of every column of `A` and `b`, distributing the
outermost index `h` over Julia threads with `Threads.@threads`. For a given `h`,
all `(i, j)` columns are processed by the same loop iteration.

No synchronization is used between iterations.
NOTE(review): this presumably relies on distinct columns not sharing storage -
confirm against the data layout.

`A` and `b` are overwritten by `thomas_algorithm!`; `b` ends up holding the
solution. Returns `nothing`.
"""
function column_thomas_solve!(::ClimaComms.CPUMultiThreaded, A, b)
    n_i, n_j, _, _, n_h = size(Spaces.local_geometry_data(axes(A)))
    @inbounds begin
        Threads.@threads for h in 1:n_h
            for j in 1:n_j
                for i in 1:n_i
                    coefs_column = parent(Spaces.column(A, i, j, h))
                    rhs_column = parent(Spaces.column(b, i, j, h))
                    thomas_algorithm!(coefs_column, rhs_column)
                end
            end
        end
    end
    return nothing
end

"""
    column_thomas_solve!(::ClimaComms.CUDADevice, A, b)

Solve the tridiagonal system of every column of `A` and `b` on the GPU by
launching `thomas_kernel` with one thread per column.

The number of columns is `Ni * Nj * Nh`, read from the size of the local geometry
data of `axes(A)`. At most 256 threads are used per block, and the number of
blocks is chosen so that every column is covered. When there are no columns the
function returns without launching a kernel.

`A` and `b` are overwritten by the kernel; `b` ends up holding the solution.
Returns `nothing`.
"""
function column_thomas_solve!(::ClimaComms.CUDADevice, A, b)
    Ni, Nj, _, _, Nh = size(Spaces.local_geometry_data(axes(A)))
    nitems = Ni * Nj * Nh
    # Nothing to solve: skip the launch. Without this guard `nthreads` would be
    # 0 and `cld(nitems, nthreads)` would throw a DivideError.
    nitems == 0 && return nothing
    max_threads = 256
    nthreads = min(nitems, max_threads)
    # round up so that a partially filled last block covers the remaining columns
    nblocks = cld(nitems, nthreads)
    @cuda threads = nthreads blocks = nblocks thomas_kernel(A, b)
    return nothing
end


"""
    thomas_kernel(A, b)

GPU kernel body: each thread solves the tridiagonal system of one column.

The global one-based thread index is decoded into the column indices `(i, j, h)`
with `i` varying fastest and `h` slowest. Threads whose index exceeds the number
of columns `Ni * Nj * Nh` do nothing. `A` and `b` are overwritten by
`thomas_algorithm!`. Returns `nothing`.
"""
function thomas_kernel(A, b)
    Ni, Nj, _, _, Nh = size(Spaces.local_geometry_data(axes(A)))
    tid = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    # surplus threads of the last block have no column to work on
    tid > Ni * Nj * Nh && return nothing
    # decode the zero-based linear index into zero-based (i, j, h) offsets
    h_offset, within_h = divrem(tid - 1, Ni * Nj)
    j_offset, i_offset = divrem(within_h, Ni)
    i = i_offset + 1
    j = j_offset + 1
    h = h_offset + 1
    thomas_algorithm!(
        parent(Spaces.column(A, i, j, h)),
        parent(Spaces.column(b, i, j, h)),
    )
    return nothing
end
80 changes: 80 additions & 0 deletions examples/hybrid/test_thomas.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
using Test
using IntervalSets
import Random: seed!
import LinearAlgebra: Tridiagonal, norm
import ClimaComms
import ClimaCore:
Geometry, Domains, Meshes, Topologies, Spaces, Fields, Operators

include("schur_complement_W.jl")

"""
    test_thomas_algorithm(::Type{FT}) where {FT}

Check `column_thomas_solve!` against a reference tridiagonal solve for element
type `FT`.

A field of random `StencilCoefs{-1, 1}` (main diagonal in `[4, 4.1)`, equal
off-diagonals in `[1, 1.1)`) and a random, normalized right-hand side are built
on an extruded space made of a 10-element vertical interval mesh and a 10 x 10
rectilinear horizontal mesh with `GLL{4}` quadrature. After solving in place,
each column is compared with `Tridiagonal(...) \\ b` computed from copies taken
before the solve; the relative error must be below `2 * eps(FT)`.
"""
function test_thomas_algorithm(::Type{FT}) where {FT}
    comms_ctx = ClimaComms.SingletonCommsContext()
    @info "Testing Thomas Algorithm on a $(comms_ctx.device) with \
           $(ClimaComms.nprocs(comms_ctx)) processes and eltype $FT"

    # vertical direction: periodic interval [1, 2] split into 10 elements
    z_domain = Domains.IntervalDomain(
        Geometry.ZPoint(FT(1)) .. Geometry.ZPoint(FT(2)),
        periodic = true,
    )
    z_mesh = Meshes.IntervalMesh(z_domain; nelems = 10)
    z_topology = Topologies.IntervalTopology(comms_ctx, z_mesh)

    # horizontal directions: doubly periodic square [1, 2] x [1, 2], 10 x 10 elements
    xy_domain = Domains.RectangleDomain(
        Geometry.XPoint(FT(1)) .. Geometry.XPoint(FT(2)),
        Geometry.YPoint(FT(1)) .. Geometry.YPoint(FT(2)),
        x1periodic = true,
        x2periodic = true,
    )
    xy_mesh = Meshes.RectilinearMesh(xy_domain, 10, 10)
    xy_topology = Topologies.Topology2D(comms_ctx, xy_mesh)

    quadrature = Spaces.Quadratures.GLL{4}()

    vertical_space = Spaces.CenterFiniteDifferenceSpace(z_topology)
    horizontal_space = Spaces.SpectralElementSpace2D(xy_topology, quadrature)
    hybrid_space =
        Spaces.ExtrudedFiniteDifferenceSpace(horizontal_space, vertical_space)

    coords = Fields.coordinate_field(hybrid_space)

    seed!(1) # ensure reproducibility
    # The main-diagonal value is drawn before the off-diagonal one at every point.
    matrix_field = map(coords) do _
        main_coef = FT(4) + rand(FT) / 10
        side_coef = FT(1) + rand(FT) / 10
        Operators.StencilCoefs{-1, 1}((side_coef, main_coef, side_coef))
    end
    rhs_field = map(coords) do _
        rand(FT)
    end
    rhs_field ./= norm(rhs_field)

    # The solver overwrites its inputs, so keep pristine copies for the reference.
    matrix_reference = copy(matrix_field)
    rhs_reference = copy(rhs_field)

    column_thomas_solve!(comms_ctx.device, matrix_field, rhs_field)

    Ni, Nj, _, _, Nh = size(Spaces.local_geometry_data(hybrid_space))
    for h in 1:Nh
        for j in 1:Nj
            for i in 1:Ni
                bands = Array(parent(Spaces.column(matrix_reference, i, j, h)))
                # columns 1, 2, 3 hold the lower, main and upper diagonals
                reference_matrix = @views Tridiagonal(
                    bands[2:end, 1],
                    bands[1:end, 2],
                    bands[1:(end - 1), 3],
                )
                reference_rhs =
                    Array(vec(parent(Spaces.column(rhs_reference, i, j, h))))
                x_array_exact = reference_matrix \ reference_rhs
                x_array_computed =
                    Array(vec(parent(Spaces.column(rhs_field, i, j, h))))

                @test norm(x_array_computed .- x_array_exact) / norm(x_array_exact) <
                      2 * eps(FT)
            end
        end
    end
end

# Run the Thomas-algorithm check once for each floating-point type. The loop
# variable `FT` is local to the `for` loop, so no prior assignment is needed.
@testset "Thomas algorithm tests for Float32 and Float64" begin
    for FT in (Float32, Float64)
        test_thomas_algorithm(FT)
    end
end
50 changes: 22 additions & 28 deletions perf/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.8.5"
manifest_format = "2.0"
project_hash = "10fc1e8e4023a6e12399288cc9246d605c84fee9"
project_hash = "d58ea912fb26188a8a99f4b060e3ee9411719dd0"

[[deps.AMD]]
deps = ["Libdl", "LinearAlgebra", "SparseArrays", "Test"]
Expand Down Expand Up @@ -149,10 +149,10 @@ uuid = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9"
version = "0.1.30"

[[deps.CUDA]]
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Preferences", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "UnsafeAtomicsLLVM"]
git-tree-sha1 = "442d989978ed3ff4e174c928ee879dc09d1ef693"
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Preferences", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "UnsafeAtomicsLLVM"]
git-tree-sha1 = "35160ef0f03b14768abfd68b830f8e3940e8e0dc"
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
version = "4.3.2"
version = "4.4.0"

[[deps.CUDA_Driver_jll]]
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"]
Expand Down Expand Up @@ -206,7 +206,7 @@ version = "0.5.0"
deps = ["Adapt", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DiffEqBase", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"]
path = ".."
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
version = "0.10.41"
version = "0.10.42"

[[deps.ClimaCorePlots]]
deps = ["ClimaCore", "RecipesBase", "StaticArrays", "TriplotBase"]
Expand Down Expand Up @@ -436,9 +436,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[deps.Distributions]]
deps = ["ChainRulesCore", "DensityInterface", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns", "Test"]
git-tree-sha1 = "4ed4a6df2548a72f66e03f3a285cd1f3b573035d"
git-tree-sha1 = "db40d3aff76ea6a3619fdd15a8c78299221a2394"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
version = "0.25.96"
version = "0.25.97"

[[deps.DocStringExtensions]]
deps = ["LibGit2"]
Expand Down Expand Up @@ -623,9 +623,9 @@ version = "3.3.8+0"

[[deps.GPUArrays]]
deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"]
git-tree-sha1 = "745847e65e72a475716952f0a8a7258d42338ce9"
git-tree-sha1 = "2e57b4a4f9cc15e85a24d603256fe08e527f48d1"
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "8.8.0"
version = "8.8.1"

[[deps.GPUArraysCore]]
deps = ["Adapt"]
Expand All @@ -635,9 +635,9 @@ version = "0.1.5"

[[deps.GPUCompiler]]
deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"]
git-tree-sha1 = "cb090aea21c6ca78d59672a7e7d13bd56d09de64"
git-tree-sha1 = "69a9aa4346bca723e46769ff6b6277e597c969b1"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "0.20.3"
version = "0.21.2"

[[deps.GR]]
deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Pkg", "Preferences", "Printf", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "UUIDs", "p7zip_jll"]
Expand Down Expand Up @@ -877,9 +877,9 @@ version = "3.0.0+1"

[[deps.LLVM]]
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "5007c1421563108110bbd57f63d8ad4565808818"
git-tree-sha1 = "7d5788011dd273788146d40eb5b1fbdc199d0296"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "5.2.0"
version = "6.0.1"

[[deps.LLVMExtra_jll]]
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
Expand Down Expand Up @@ -1003,10 +1003,10 @@ uuid = "4b2f31a3-9ecc-558c-b454-b3730dcb73e9"
version = "2.35.0+0"

[[deps.Libtiff_jll]]
deps = ["Artifacts", "JLLWrappers", "JpegTurbo_jll", "LERC_jll", "Libdl", "XZ_jll", "Zlib_jll", "Zstd_jll"]
git-tree-sha1 = "2da088d113af58221c52828a80378e16be7d037a"
deps = ["Artifacts", "JLLWrappers", "JpegTurbo_jll", "LERC_jll", "Libdl", "Pkg", "Zlib_jll", "Zstd_jll"]
git-tree-sha1 = "3eb79b0ca5764d4799c06699573fd8f533259713"
uuid = "89763e89-9b03-5906-acba-b20f662cd828"
version = "4.5.1+1"
version = "4.4.0+0"

[[deps.Libuuid_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
Expand Down Expand Up @@ -1641,9 +1641,9 @@ version = "1.21.0"

[[deps.SpecialFunctions]]
deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"]
git-tree-sha1 = "ef28127915f4229c971eb43f3fc075dd3fe91880"
git-tree-sha1 = "7beb031cf8145577fbccacd94b8a8f4ce78428d3"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "2.2.0"
version = "2.3.0"

[[deps.Static]]
deps = ["IfElse"]
Expand Down Expand Up @@ -1729,9 +1729,9 @@ version = "1.10.1"

[[deps.TaylorSeries]]
deps = ["LinearAlgebra", "Markdown", "Requires", "SparseArrays"]
git-tree-sha1 = "c274151bde5a608bb329d76160a9344af707bfb4"
git-tree-sha1 = "50718b4fc1ce20cecf28d85215028c78b4d875c2"
uuid = "6aa5eb33-94cf-58f4-a9d0-e4b2c4fc25ea"
version = "0.15.1"
version = "0.15.2"

[[deps.TempestRemap_jll]]
deps = ["Artifacts", "HDF5_jll", "JLLWrappers", "Libdl", "NetCDF_jll", "OpenBLAS32_jll", "Pkg"]
Expand Down Expand Up @@ -1832,9 +1832,9 @@ version = "0.2.1"

[[deps.UnsafeAtomicsLLVM]]
deps = ["LLVM", "UnsafeAtomics"]
git-tree-sha1 = "ea37e6066bf194ab78f4e747f5245261f17a7175"
git-tree-sha1 = "323e3d0acf5e78a56dfae7bd8928c989b4f3083e"
uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249"
version = "0.1.2"
version = "0.1.3"

[[deps.Unzip]]
git-tree-sha1 = "ca0969166a028236229f63514992fc073799bb78"
Expand Down Expand Up @@ -1894,12 +1894,6 @@ git-tree-sha1 = "91844873c4085240b95e795f692c4cec4d805f8a"
uuid = "aed1982a-8fda-507f-9586-7b0439959a61"
version = "1.1.34+0"

[[deps.XZ_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
git-tree-sha1 = "8abe223c2549ea70be752b20a53aa236a7868eb0"
uuid = "ffd25f8a-64ca-5728-b0f7-c24cf3aae800"
version = "5.4.3+0"

[[deps.Xorg_libX11_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libxcb_jll", "Xorg_xtrans_jll"]
git-tree-sha1 = "5be649d550f3f4b95308bf0183b82e2582876527"
Expand Down
1 change: 1 addition & 0 deletions perf/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
ClimaCore = "d414da3d-4745-48bb-8d80-42e94e092884"
ClimaCorePlots = "cf7c7e5a-b407-4c48-9047-11a94a308626"
Expand Down

0 comments on commit f0ec8c8

Please sign in to comment.