Skip to content

Commit

Permalink
Update byslab function for CPU threading
Browse files Browse the repository at this point in the history
  • Loading branch information
sriharshakandala committed Jun 22, 2023
1 parent 8bc482c commit c498015
Showing 1 changed file with 67 additions and 3 deletions.
70 changes: 67 additions & 3 deletions src/Fields/indices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,40 @@ Base.@propagate_inbounds function slab(
slab(field, slabidx.v + half, slabidx.h)
end

function byslab(fn, space::Spaces.AbstractSpectralElementSpace)
function byslab(fn, space::Spaces.AbstractSpace)
byslab(fn, ClimaComms.device(space), space)
end

function byslab(
fn,
::ClimaComms.CPUSingleThreaded,
space::Spaces.AbstractSpectralElementSpace,
)
Nh = Topologies.nlocalelems(space.topology)::Int
@inbounds for h in 1:Nh
fn(SlabIndex(nothing, h))
end
end
function byslab(fn, space::Spaces.CenterExtrudedFiniteDifferenceSpace)

function byslab(
fn,
::ClimaComms.CPUMultiThreaded,
space::Spaces.AbstractSpectralElementSpace,
)
Nh = Topologies.nlocalelems(space.topology)::Int
@inbounds begin
Threads.@threads for h in 1:Nh
fn(SlabIndex(nothing, h))
end
end
end


function byslab(
fn,
::ClimaComms.CPUSingleThreaded,
space::Spaces.CenterExtrudedFiniteDifferenceSpace,
)
Nh = Topologies.nlocalelems(Spaces.topology(space))
Nv = Spaces.nlevels(space)
@inbounds begin
Expand All @@ -204,7 +231,28 @@ function byslab(fn, space::Spaces.CenterExtrudedFiniteDifferenceSpace)
end
end
end
function byslab(fn, space::Spaces.FaceExtrudedFiniteDifferenceSpace)

function byslab(
fn,
::ClimaComms.CPUMultiThreaded,
space::Spaces.CenterExtrudedFiniteDifferenceSpace,
)
Nh = Topologies.nlocalelems(Spaces.topology(space))
Nv = Spaces.nlevels(space)
@inbounds begin
Threads.@threads for h in 1:Nh
for v in 1:Nv
fn(SlabIndex(v, h))
end
end
end
end

function byslab(
fn,
::ClimaComms.CPUSingleThreaded,
space::Spaces.FaceExtrudedFiniteDifferenceSpace,
)
Nh = Topologies.nlocalelems(Spaces.topology(space))
Nv = Spaces.nlevels(space)
@inbounds begin
Expand All @@ -215,3 +263,19 @@ function byslab(fn, space::Spaces.FaceExtrudedFiniteDifferenceSpace)
end
end
end

function byslab(
fn,
::ClimaComms.CPUMultiThreaded,
space::Spaces.FaceExtrudedFiniteDifferenceSpace,
)
Nh = Topologies.nlocalelems(Spaces.topology(space))
Nv = Spaces.nlevels(space)
@inbounds begin
Threads.@threads for h in 1:Nh
for v in 1:Nv
fn(SlabIndex(v - half, h))
end
end
end
end

0 comments on commit c498015

Please sign in to comment.