Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Aug 23, 2024
1 parent 236262b commit 24ef689
Show file tree
Hide file tree
Showing 31 changed files with 592 additions and 330 deletions.
1 change: 1 addition & 0 deletions ext/ClimaCoreCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import ClimaCore.RecursiveApply:
, , , radd, rmul, rsub, rdiv, rmap, rzero, rmin, rmax
import ClimaCore.DataLayouts: get_N, get_Nv, get_Nij, get_Nij, get_Nh
import ClimaCore.DataLayouts: universal_size, UniversalSize
import ClimaCore.DataLayouts: ArraySize

include(joinpath("cuda", "cuda_utils.jl"))
include(joinpath("cuda", "data_layouts.jl"))
Expand Down
59 changes: 33 additions & 26 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,18 @@ function Base.show(io::IO, data::AbstractData)
(rows, cols) = displaysize(io)
println(io, summary(data))
print(io, " "^indent_width)
# @show similar(parent_array_type(data))
# fa = map(x -> vec(x), field_arrays(data))
print(
IOContext(
io,
:compact => true,
:limit => true,
:displaysize => (rows, cols - indent_width),
),
map(x -> vec(x), field_arrays(data)),
# collect(field_array(data)),
parent(data),
# map(x -> vec(x), field_arrays(data)),
)
return io
end
Expand Down Expand Up @@ -619,10 +623,7 @@ function IJF{S, Nij}(::Type{MArray}, ::Type{T}) where {S, Nij, T}
array = FieldArray{field_dim(IJF)}(ntuple(f->MArray{Tuple{Nij, Nij}, T, 2, Nij * Nij}(undef), Nf))
IJF{S, Nij}(array)
end
function SArray(ijf::IJF{S, Nij, FieldArray{FD, N, T}}) where {S, Nij, FD, N, T <: MArray}
IJF{S, Nij}(SArray(field_array(ijf)))
end
function SArray(ijf::IJF{S, Nij, <:MArray}) where {S, Nij}
function SArray(ijf::IJF{S, Nij, <:FieldArray}) where {S, Nij}
IJF{S, Nij}(SArray(field_array(ijf)))
end

Expand Down Expand Up @@ -681,15 +682,15 @@ end
function IF{S, Ni}(::Type{MArray}, ::Type{T}) where {S, Ni, T}
Nf = typesize(T, S)
# array = MArray{Tuple{Ni, Nf}, T, 2, Ni * Nf}(undef)
array = FieldArray{field_dim(IF)}(ntuple(f->MArray{Tuple{Ni}, T, 1, Ni}(undef), Nf))
IF{S, Ni}(array)
fa = FieldArray{field_dim(IF)}(ntuple(f->MArray{Tuple{Ni}, T, 1, Ni}(undef), Nf))
IF{S, Ni}(fa)
end
function SArray(data::IF{S, Ni, <:FieldArray{<:Any, <:Any, T}}) where {S, Ni, T <: MArray}
IF{S, Ni}(SArray(field_array(data)))
end
function SArray(data::IF{S, Ni, <:MArray}) where {S, Ni}
function SArray(data::IF{S, Ni, <:FieldArray}) where {S, Ni}
IF{S, Ni}(SArray(field_array(data)))
end
# function SArray(data::IF{S, Ni, <:MArray}) where {S, Ni}
# IF{S, Ni}(SArray(field_array(data)))
# end

@inline function column(data::IF{S, Ni}, i) where {S, Ni}
@boundscheck (1 <= i <= Ni) || throw(BoundsError(data, (i,)))
Expand Down Expand Up @@ -816,14 +817,16 @@ Base.length(data::VIJFH) = get_Nv(data) * get_Nh(data)
@boundscheck (1 <= v <= Nv && 1 <= h <= Nh) ||
throw(BoundsError(data, (v, h)))
Nf = ncomponents(data)
dataview = @inbounds view(
array,
v,
Base.Slice(Base.OneTo(Nij)),
Base.Slice(Base.OneTo(Nij)),
Base.Slice(Base.OneTo(Nf)),
h,
)
sub_arrays = @inbounds ntuple(Nf) do f
view(
array.arrays[f],
v,
Base.Slice(Base.OneTo(Nij)),
Base.Slice(Base.OneTo(Nij)),
h,
)
end
dataview = FieldArray{field_dim(IJF)}(sub_arrays)
IJF{S, Nij}(dataview)
end

Expand Down Expand Up @@ -1113,11 +1116,15 @@ type parameters.
@inline field_dim(::Type{<:VIJFH}) = 4
@inline field_dim(::Type{<:VIFH}) = 3

@inline to_data_specific_field_array(::IJFH, I::CartesianIndex{5}) = CartesianIndex(I.I[1], I.I[2], I.I[5])
@inline to_data_specific_field_array(::IFH, I::CartesianIndex{5}) = CartesianIndex(I.I[1], I.I[5])
@inline to_data_specific_field_array(::VIJFH, I::CartesianIndex{5}) = CartesianIndex(I.I[4], I.I[1], I.I[2], I.I[5])
@inline to_data_specific_field_array(::VIFH, I::CartesianIndex{5}) = CartesianIndex(I.I[4], I.I[1], I.I[5])
@inline to_data_specific_field_array(::DataSlab1D, I::CartesianIndex{5}) = CartesianIndex(I.I[1], I.I[1], I.I[5])
@inline to_data_specific_field_array(data::AbstractData, I::CartesianIndex) =
CartesianIndex(_to_data_specific_field_array(data, I.I))
@inline _to_data_specific_field_array(::VF, I::Tuple) = (I[4],)
@inline _to_data_specific_field_array(::IF, I::Tuple) = (I[1],)
@inline _to_data_specific_field_array(::IJF, I::Tuple) = (I[1], I[2])
@inline _to_data_specific_field_array(::IJFH, I::Tuple) = (I[1], I[2], I[5])
@inline _to_data_specific_field_array(::IFH, I::Tuple) = (I[1], I[5])
@inline _to_data_specific_field_array(::VIJFH, I::Tuple) = (I[4], I[1], I[2], I[5])
@inline _to_data_specific_field_array(::VIFH, I::Tuple) = (I[4], I[1], I[5])

@inline to_data_specific(data::AbstractData, I::CartesianIndex) =
CartesianIndex(_to_data_specific(data, I.I))
Expand Down Expand Up @@ -1349,7 +1356,7 @@ field_array(data::AbstractData{S}) where {S} = parent(data)
parent(data),
eltype(data),
Val(field_dim(data)),
to_data_specific(data, I),
to_data_specific_field_array(data, I),
)
end

Expand All @@ -1363,7 +1370,7 @@ end
parent(data),
convert(eltype(data), val),
Val(field_dim(data)),
to_data_specific(data, I),
to_data_specific_field_array(data, I),
)
end

Expand Down
86 changes: 57 additions & 29 deletions src/DataLayouts/broadcast.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import MultiBroadcastFusion as MBF
import MultiBroadcastFusion: fused_direct
import ..RecursiveApply

# Make a MultiBroadcastFusion type, `FusedMultiBroadcast`, and macro, `@fused`:
# via https://github.com/CliMA/MultiBroadcastFusion.jl
Expand All @@ -11,6 +12,25 @@ MBF.@make_fused fused_direct FusedMultiBroadcast fused_direct

abstract type DataStyle <: Base.BroadcastStyle end

"""
parent_array_type
Returns a UnionAll array type given the inputs.
For example: `Array`, `CuArray` etc.
# Note
The returned type must be a UnionAll array type
because we need to be able to promote broadcast
expressions with fields containing different number
of variables. The number of fields returns depends
on the function being broadcasted over, and we do
not have this number here.
# TODO: make this note more precise
"""
function parent_array_type end

abstract type Data0DStyle <: DataStyle end
struct DataFStyle{A} <: Data0DStyle end
DataStyle(::Type{DataF{S, A}}) where {S, A} = DataFStyle{parent_array_type(A)}()
Expand Down Expand Up @@ -291,45 +311,59 @@ function Base.similar(
bc::BroadcastedUnionDataF{<:Any, A},
::Type{Eltype},
) where {A, Eltype}
PA = parent_array_type(A)
array = similar(PA, (typesize(eltype(A), Eltype)))
return DataF{Eltype}(array)
Nf = typesize(eltype(A), Eltype)
_size = ()
as = ArraySize{field_dim(DataF), Nf, _size}()
fa = similar(rebuild_field_array_type(A, as), _size)
return DataF{Eltype}(fa)
end

function Base.similar(
bc::BroadcastedUnionIJFH{<:Any, Nij, Nh, A},
::Type{Eltype},
) where {Nij, Nh, A, Eltype}
PA = parent_array_type(A)
array = similar(PA, (Nij, Nij, typesize(eltype(A), Eltype), Nh))
return IJFH{Eltype, Nij, Nh}(array)
Nf = typesize(eltype(A), Eltype)
_size = (Nij, Nij, Nh)
as = ArraySize{field_dim(IJFH), Nf, _size}()
fa = similar(rebuild_field_array_type(A, as), _size)
return IJFH{Eltype, Nij, Nh}(fa)
end

function Base.similar(
bc::BroadcastedUnionIFH{<:Any, Ni, Nh, A},
::Type{Eltype},
) where {Ni, Nh, A, Eltype}
PA = parent_array_type(A)
array = similar(PA, (Ni, typesize(eltype(A), Eltype), Nh))
return IFH{Eltype, Ni, Nh}(array)
Nf = typesize(eltype(A), Eltype)
_size = (Ni, Nh)
as = ArraySize{field_dim(IFH), Nf, _size}()
fa = similar(rebuild_field_array_type(A, as), _size)
return IFH{Eltype, Ni, Nh}(fa)
end

function Base.similar(
::BroadcastedUnionIJF{<:Any, Nij, A},
::Type{Eltype},
) where {Nij, A, Eltype}
Nf = typesize(eltype(A), Eltype)
array = MArray{Tuple{Nij, Nij, Nf}, eltype(A), 3, Nij * Nij * Nf}(undef)
return IJF{Eltype, Nij}(array)
# array = MArray{Tuple{Nij, Nij, Nf}, eltype(A), 3, Nij * Nij * Nf}(undef)
MAT = MArray{Tuple{Nij, Nij}, eltype(A), 2, Nij * Nij}
_size = (Nij, Nij)
as = ArraySize{field_dim(IJF), Nf, ()}()
fa = similar(rebuild_field_array_type(A, as, MAT), _size)
return IJF{Eltype, Nij}(fa)
end

function Base.similar(
::BroadcastedUnionIF{<:Any, Ni, A},
::Type{Eltype},
) where {Ni, A, Eltype}
Nf = typesize(eltype(A), Eltype)
array = MArray{Tuple{Ni, Nf}, eltype(A), 2, Ni * Nf}(undef)
return IF{Eltype, Ni}(array)
# array = MArray{Tuple{Ni, Nf}, eltype(A), 2, Ni * Nf}(undef)
MAT = MArray{Tuple{Ni}, eltype(A), 2, Ni}
_size = (Ni, )
as = ArraySize{field_dim(IF), Nf, ()}() # size is unused
fa = similar(rebuild_field_array_type(A, as, MAT), _size)
return IF{Eltype, Ni}(fa)
end

Base.similar(
Expand All @@ -342,12 +376,10 @@ function Base.similar(
::Type{Eltype},
::Val{newNv},
) where {Nv, A, Eltype, newNv}
PA = parent_array_type(A)
# @show PA
Nf = typesize(eltype(A), Eltype)
# @show (newNv, Nf)
# array = similar(PA, (newNv, Nf))
fa = FieldArray{field_dim(VF)}(ntuple(i -> similar(PA, newNv), Nf))
_size = (newNv, )
as = ArraySize{field_dim(VF), Nf, _size}()
fa = similar(rebuild_field_array_type(A, as), _size)
return VF{Eltype, newNv, typeof(fa)}(fa)
end

Expand All @@ -361,9 +393,11 @@ function Base.similar(
::Type{Eltype},
::Val{newNv},
) where {Nv, Ni, Nh, A, Eltype, newNv}
PA = parent_array_type(A)
array = similar(PA, (newNv, Ni, typesize(eltype(A), Eltype), Nh))
return VIFH{Eltype, newNv, Ni, Nh}(array)
Nf = typesize(eltype(A), Eltype)
_size = (newNv, Ni, Nh)
as = ArraySize{field_dim(VIFH), Nf, _size}()
fa = similar(rebuild_field_array_type(A, as), _size)
return VIFH{Eltype, newNv, Ni, Nh}(fa)
end

Base.similar(
Expand All @@ -378,16 +412,10 @@ function Base.similar(
) where {Nv, Nij, Nh, A, Eltype, newNv}
T = eltype(A)
Nf = typesize(eltype(A), Eltype)
# fat = rebuild_type(A, Val(field_dim(VIJFH)), Val(Nf), Val(4))
_size = (newNv, Nij, Nij, Nh)
as = ArraySize{field_dim(VIJFH), Nf, _size}()
# fat = if A isa AbstractArray
# field_array_type(A, as)
# else
# end
array = similar(rebuild_field_array_type(A, as), _size)
vd = VIJFH{Eltype, newNv, Nij, Nh}(array)
return vd
fa = similar(rebuild_field_array_type(A, as), _size)
return VIJFH{Eltype, newNv, Nij, Nh}(fa)
end

# ============= FusedMultiBroadcast
Expand Down
7 changes: 5 additions & 2 deletions src/DataLayouts/copyto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
##### Dispatching and edge cases
#####

Base.copyto!(
function Base.copyto!(
dest::AbstractData,
bc::Union{AbstractData, Base.Broadcast.Broadcasted},
) = Base.copyto!(dest, bc, device_dispatch(dest))
)
ncomponents(dest) > 0 || return dest
Base.copyto!(dest, bc, device_dispatch(dest))
end

# Specialize on non-Broadcasted objects
function Base.copyto!(dest::D, src::D) where {D <: AbstractData}
Expand Down
Loading

0 comments on commit 24ef689

Please sign in to comment.