Skip to content

Commit

Permalink
define fill for CuArray-backed DataLayouts
Browse files Browse the repository at this point in the history
  • Loading branch information
simonbyrne committed Jun 22, 2023
1 parent 90b3999 commit d4fe3c1
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions src/DataLayouts/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,19 @@ function knl_copyto!(dest, src)
return nothing
end

function knl_fill!(dest, val)
i = CUDA.threadIdx().x
j = CUDA.threadIdx().y

h = CUDA.blockIdx().x
v = CUDA.blockIdx().y

I = CartesianIndex((i, j, 1, v, h))

@inbounds dest[I] = val
return nothing
end

function Base.copyto!(
dest::IJFH{S, Nij},
bc::Union{IJFH{S, Nij, A}, Base.Broadcast.Broadcasted{IJFHStyle{Nij, A}}},
Expand All @@ -81,6 +94,18 @@ function Base.copyto!(
end
return dest
end
function Base.fill!(
dest::IJFH{S, Nij, A},
val,
) where {S, Nij, A <: CUDA.CuArray}
_, _, _, _, Nh = size(bc)
if Nh > 0
CUDA.@cuda threads = (Nij, Nij) blocks = (Nh, 1) knl_fill!(dest, val)
end
return dest
end



function Base.copyto!(
dest::VIJFH{S, Nij},
Expand All @@ -92,6 +117,17 @@ function Base.copyto!(
end
return dest
end
function Base.fill!(
dest::VIJFH{S, Nij, A},
val,
) where {S, Nij, A <: CUDA.CuArray}
_, _, _, Nv, Nh = size(bc)
if Nv > 0 && Nh > 0
CUDA.@cuda threads = (Nij, Nij) blocks = (Nh, Nv) knl_fill!(dest, val)
end
return dest
end


function Base.copyto!(
dest::VF{S},
Expand All @@ -103,3 +139,10 @@ function Base.copyto!(
end
return dest
end
function Base.fill!(dest::VF{S, A}, val) where {S, A <: CUDA.CuArray}
_, _, _, Nv, Nh = size(bc)
if Nv > 0 && Nh > 0
CUDA.@cuda threads = (1, 1) blocks = (Nh, Nv) knl_fill!(dest, val)
end
return dest
end

0 comments on commit d4fe3c1

Please sign in to comment.