From ed49dd843b5e20d3a504f78b7149d2b91e08c397 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Fri, 30 Aug 2024 22:32:43 -0400 Subject: [PATCH] Define CuArray on FieldArrays --- ext/cuda/data_layouts.jl | 14 ++++++++++++++ ext/cuda/data_layouts_copyto.jl | 4 ++-- src/DataLayouts/DataLayouts.jl | 11 ++++++++--- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/ext/cuda/data_layouts.jl b/ext/cuda/data_layouts.jl index 0c80540430..c63a6c0d4b 100644 --- a/ext/cuda/data_layouts.jl +++ b/ext/cuda/data_layouts.jl @@ -77,3 +77,17 @@ function Adapt.adapt_structure( Adapt.adapt(to, bc.axes), ) end + +import ClimaCore.DataLayouts as DL +import CUDA +function CUDA.CuArray(fa::DL.FieldArray{FD}) where {FD} + arrays = ntuple(Val(DL.ncomponents(fa))) do f + CUDA.CuArray(fa.arrays[f]) + end + return DL.FieldArray{FD}(arrays) +end + +DL.field_array( + array::CUDA.CuArray, + as::ArraySize +) = CUDA.CuArray(DL.field_array(Array(array), as)) diff --git a/ext/cuda/data_layouts_copyto.jl b/ext/cuda/data_layouts_copyto.jl index 4f88fd16ed..f855bdb999 100644 --- a/ext/cuda/data_layouts_copyto.jl +++ b/ext/cuda/data_layouts_copyto.jl @@ -70,11 +70,11 @@ function cuda_copyto!(dest::AbstractData, bc) auto_launch!( knl_copyto_linear!, (dest, bc′, us), - nitems; + n; auto = true, ) else - auto_launch!(knl_copyto_flat!, (dest, bc, us), nitems; auto = true) + auto_launch!(knl_copyto_flat!, (dest, bc, us), n; auto = true) end end return dest diff --git a/src/DataLayouts/DataLayouts.jl b/src/DataLayouts/DataLayouts.jl index 3e39d9e98e..e3539bd0fa 100644 --- a/src/DataLayouts/DataLayouts.jl +++ b/src/DataLayouts/DataLayouts.jl @@ -401,14 +401,19 @@ end Base.length(data::IJFH) = get_Nh(data) -Base.@propagate_inbounds slab(data::IJFH, h::Integer) = slab(data, 1, h) - @inline function slab(data::IJFH{S, Nij}, v::Integer, h::Integer) where {S, Nij} @boundscheck (v >= 1 && 1 <= h <= get_Nh(data)) || throw(BoundsError(data, (v, h))) - slab(data, h) + fa = field_array(data) + sub_arrays = ntuple(Val(ncomponents(fa))) do jf + view(fa.arrays[jf], :, :, h) + end + dataview = FieldArray{field_dim(IJF)}(sub_arrays) + IJF{S, Nij, typeof(dataview)}(dataview) end +Base.@propagate_inbounds slab(data::IJFH, h::Integer) = slab(data, 1, h) + @inline function column(data::IJFH{S, Nij}, i, j, h) where {S, Nij} @boundscheck (1 <= j <= Nij && 1 <= i <= Nij && 1 <= h <= get_Nh(data)) || throw(BoundsError(data, (i, j, h)))