Skip to content

Commit

Permalink
Merge pull request #1967 from CliMA/ck/fill_bm
Browse files Browse the repository at this point in the history
Refactor fill benchmark
  • Loading branch information
charleskawczynski authored Sep 3, 2024
2 parents d0680b8 + db622ba commit 40e6cf3
Showing 1 changed file with 36 additions and 19 deletions.
55 changes: 36 additions & 19 deletions test/DataLayouts/benchmark_fill.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,36 @@ julia --project
using Revise; include(joinpath("test", "DataLayouts", "benchmark_fill.jl"))
=#
using Test
using ClimaCore
using ClimaCore.DataLayouts
using BenchmarkTools
import ClimaComms
@static pkgversion(ClimaComms) >= v"0.6" && ClimaComms.@import_required_backends

function benchmarkfill!(device, data, val, name)
println("Benchmarking ClimaCore fill! for $name DataLayout")
# Pick the device label that is later passed to `Benchmark(; float_type, device_name)`.
# CUDA is imported only inside the GPU branch, so it is never loaded on the CPU path.
if ClimaComms.device() isa ClimaComms.CUDADevice
import CUDA
# Ask CUDA for the name of the currently active device.
device_name = CUDA.name(CUDA.device()) # Move to ClimaComms
else
# Any non-CUDA device is reported under the generic label "CPU".
device_name = "CPU"
end

include(joinpath(pkgdir(ClimaCore), "benchmarks/scripts/benchmark_utils.jl"))

"""
    benchmarkfill!(bm, device, data, val, name)

Benchmark `fill!(data, val)` on `device` and record the result in `bm`.

Two benchmarks are run and printed to `stdout`:

1. `fill!` on the DataLayout `data` itself, and
2. `fill!` on the underlying array `parent(data)`, as a raw-array baseline.

Only the first (DataLayout) benchmark is recorded through `push_info`, because
the recorded `caller`, `problem_size` and `n_reads_writes` all describe `data`.

# Arguments
- `bm`: benchmark accumulator that `push_info` appends to
  (presumably the `Benchmark` object from `benchmark_utils.jl` -- confirm).
- `device`: device handed to `ClimaComms.@cuda_sync`.
- `data`: the DataLayout to fill; it is mutated.
- `val`: the fill value.
- `name`: label used in the raw-array benchmark message.

# Returns
Whatever `push_info` returns.
"""
function benchmarkfill!(bm, device, data, val, name)
    caller = string(nameof(typeof(data)))
    @info "Benchmarking $caller..."
    trial = @benchmark ClimaComms.@cuda_sync $device fill!($data, $val)
    show(stdout, MIME("text/plain"), trial)
    println()

    println("Benchmarking array fill! for $name DataLayout")
    # Keep the raw-array baseline in its own variable. Reusing `trial` here
    # would make the statistics recorded below describe the parent-array fill
    # while being labelled with the DataLayout's name and size.
    array_trial =
        @benchmark ClimaComms.@cuda_sync $device fill!($(parent(data)), $val)
    show(stdout, MIME("text/plain"), array_trial)
    println()

    t_min = minimum(trial.times) * 1e-9 # to seconds
    nreps = length(trial.times)
    n_reads_writes = DataLayouts.ncomponents(data) * 2
    return push_info(
        bm;
        kernel_time_s = t_min,
        nreps = nreps,
        caller,
        problem_size = size(data),
        n_reads_writes,
    )
end

@testset "fill! with Nf = 1" begin
Expand All @@ -30,17 +45,19 @@ end
Nij = 4
Nh = 30 * 30 * 6
Nk = 6
bm = Benchmark(; float_type = FT, device_name)
#! format: off
data = DataF{S}(device_zeros(FT,Nf)); benchmarkfill!(device, data, 3, "DataF" ); @test all(parent(data) .== 3)
data = IJFH{S, Nij, Nh}(device_zeros(FT,Nij,Nij,Nf,Nh)); benchmarkfill!(device, data, 3, "IJFH" ); @test all(parent(data) .== 3)
data = IFH{S, Nij, Nh}(device_zeros(FT,Nij,Nf,Nh)); benchmarkfill!(device, data, 3, "IFH" ); @test all(parent(data) .== 3)
data = IJF{S, Nij}(device_zeros(FT,Nij,Nij,Nf)); benchmarkfill!(device, data, 3, "IJF" ); @test all(parent(data) .== 3)
data = IF{S, Nij}(device_zeros(FT,Nij,Nf)); benchmarkfill!(device, data, 3, "IF" ); @test all(parent(data) .== 3)
data = VF{S, Nv}(device_zeros(FT,Nv,Nf)); benchmarkfill!(device, data, 3, "VF" ); @test all(parent(data) .== 3)
data = VIJFH{S,Nv,Nij,Nh}(device_zeros(FT,Nv,Nij,Nij,Nf,Nh));benchmarkfill!(device, data, 3, "VIJFH" ); @test all(parent(data) .== 3)
data = VIFH{S, Nv, Nij, Nh}(device_zeros(FT,Nv,Nij,Nf,Nh)); benchmarkfill!(device, data, 3, "VIFH" ); @test all(parent(data) .== 3)
data = DataF{S}(device_zeros(FT,Nf)); benchmarkfill!(bm, device, data, 3, "DataF" ); @test all(parent(data) .== 3)
data = IJFH{S, Nij, Nh}(device_zeros(FT,Nij,Nij,Nf,Nh)); benchmarkfill!(bm, device, data, 3, "IJFH" ); @test all(parent(data) .== 3)
data = IFH{S, Nij, Nh}(device_zeros(FT,Nij,Nf,Nh)); benchmarkfill!(bm, device, data, 3, "IFH" ); @test all(parent(data) .== 3)
data = IJF{S, Nij}(device_zeros(FT,Nij,Nij,Nf)); benchmarkfill!(bm, device, data, 3, "IJF" ); @test all(parent(data) .== 3)
data = IF{S, Nij}(device_zeros(FT,Nij,Nf)); benchmarkfill!(bm, device, data, 3, "IF" ); @test all(parent(data) .== 3)
data = VF{S, Nv}(device_zeros(FT,Nv,Nf)); benchmarkfill!(bm, device, data, 3, "VF" ); @test all(parent(data) .== 3)
data = VIJFH{S,Nv,Nij,Nh}(device_zeros(FT,Nv,Nij,Nij,Nf,Nh));benchmarkfill!(bm, device, data, 3, "VIJFH" ); @test all(parent(data) .== 3)
data = VIFH{S, Nv, Nij, Nh}(device_zeros(FT,Nv,Nij,Nf,Nh)); benchmarkfill!(bm, device, data, 3, "VIFH" ); @test all(parent(data) .== 3)
#! format: on

# data = IJKFVH{S}(device_zeros(FT,Nij,Nij,Nk,Nf,Nh)); benchmarkfill!(device, data, 3); @test all(parent(data) .== 3) # TODO: test
# data = IH1JH2{S}(device_zeros(FT,Nij,Nij,Nk,Nf,Nh)); benchmarkfill!(device, data, 3); @test all(parent(data) .== 3) # TODO: test
# data = IJKFVH{S}(device_zeros(FT,Nij,Nij,Nk,Nf,Nh)); benchmarkfill!(bm, device, data, 3); @test all(parent(data) .== 3) # TODO: test
# data = IH1JH2{S}(device_zeros(FT,Nij,Nij,Nk,Nf,Nh)); benchmarkfill!(bm, device, data, 3); @test all(parent(data) .== 3) # TODO: test
tabulate_benchmark(bm)
end

0 comments on commit 40e6cf3

Please sign in to comment.