Skip to content

Commit

Permalink
Define CuArray on FieldArrays
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Aug 31, 2024
1 parent 5df8fbb commit ed49dd8
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 5 deletions.
14 changes: 14 additions & 0 deletions ext/cuda/data_layouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,17 @@ function Adapt.adapt_structure(
Adapt.adapt(to, bc.axes),
)
end

import ClimaCore.DataLayouts as DL
import CUDA
function CUDA.CuArray(fa::DL.FieldArray{FD}) where {FD}
arrays = ntuple(Val(DL.ncomponents(fa))) do f
CUDA.CuArray(fa.arrays[f])
end
return DL.FieldArray{FD}(arrays)
end

DL.field_array(
array::CUDA.CuArray,
as::ArraySize
) = CUDA.CuArray(DL.field_array(Array(array), as))
4 changes: 2 additions & 2 deletions ext/cuda/data_layouts_copyto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ function cuda_copyto!(dest::AbstractData, bc)
auto_launch!(
knl_copyto_linear!,
(dest, bc′, us),
nitems;
n;
auto = true,
)
else
auto_launch!(knl_copyto_flat!, (dest, bc, us), nitems; auto = true)
auto_launch!(knl_copyto_flat!, (dest, bc, us), n; auto = true)
end
end
return dest
Expand Down
11 changes: 8 additions & 3 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -401,14 +401,19 @@ end

Base.length(data::IJFH) = get_Nh(data)

Base.@propagate_inbounds slab(data::IJFH, h::Integer) = slab(data, 1, h)

@inline function slab(data::IJFH{S, Nij}, v::Integer, h::Integer) where {S, Nij}
@boundscheck (v >= 1 && 1 <= h <= get_Nh(data)) ||
throw(BoundsError(data, (v, h)))
slab(data, h)
fa = field_array(data)
sub_arrays = ntuple(Val(ncomponents(fa))) do jf
view(fa.arrays[jf], :, :, h)
end
dataview = FieldArray{field_dim(IJF)}(sub_arrays)
IJF{S, Nij, typeof(dataview)}(dataview)
end

Base.@propagate_inbounds slab(data::IJFH, h::Integer) = slab(data, 1, h)

@inline function column(data::IJFH{S, Nij}, i, j, h) where {S, Nij}
@boundscheck (1 <= j <= Nij && 1 <= i <= Nij && 1 <= h <= get_Nh(data)) ||
throw(BoundsError(data, (i, j, h)))
Expand Down

0 comments on commit ed49dd8

Please sign in to comment.