From d4fe3c1ea86ad3ddfd4eeb2bc8e92e91e6493630 Mon Sep 17 00:00:00 2001 From: Simon Byrne Date: Thu, 22 Jun 2023 10:55:36 -0700 Subject: [PATCH] define fill for CuArray-backed DataLayouts --- src/DataLayouts/cuda.jl | 43 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/src/DataLayouts/cuda.jl b/src/DataLayouts/cuda.jl index bc57000a74..367208759e 100644 --- a/src/DataLayouts/cuda.jl +++ b/src/DataLayouts/cuda.jl @@ -71,6 +71,19 @@ function knl_copyto!(dest, src) return nothing end +function knl_fill!(dest, val) + i = CUDA.threadIdx().x + j = CUDA.threadIdx().y + + h = CUDA.blockIdx().x + v = CUDA.blockIdx().y + + I = CartesianIndex((i, j, 1, v, h)) + + @inbounds dest[I] = val + return nothing +end + function Base.copyto!( dest::IJFH{S, Nij}, bc::Union{IJFH{S, Nij, A}, Base.Broadcast.Broadcasted{IJFHStyle{Nij, A}}}, @@ -81,6 +94,18 @@ function Base.copyto!( end return dest end +function Base.fill!( + dest::IJFH{S, Nij, A}, + val, +) where {S, Nij, A <: CUDA.CuArray} + _, _, _, _, Nh = size(bc) + if Nh > 0 + CUDA.@cuda threads = (Nij, Nij) blocks = (Nh, 1) knl_fill!(dest, val) + end + return dest +end + + function Base.copyto!( dest::VIJFH{S, Nij}, @@ -92,6 +117,17 @@ function Base.copyto!( end return dest end +function Base.fill!( + dest::VIJFH{S, Nij, A}, + val, +) where {S, Nij, A <: CUDA.CuArray} + _, _, _, Nv, Nh = size(bc) + if Nv > 0 && Nh > 0 + CUDA.@cuda threads = (Nij, Nij) blocks = (Nh, Nv) knl_fill!(dest, val) + end + return dest +end + function Base.copyto!( dest::VF{S}, @@ -103,3 +139,10 @@ function Base.copyto!( end return dest end +function Base.fill!(dest::VF{S, A}, val) where {S, A <: CUDA.CuArray} + _, _, _, Nv, Nh = size(bc) + if Nv > 0 && Nh > 0 + CUDA.@cuda threads = (1, 1) blocks = (Nh, Nv) knl_fill!(dest, val) + end + return dest +end