Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for linear indexing for pointwise kernels #1922

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions ext/cuda/data_layouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,16 @@ function Adapt.adapt_structure(
end,
)
end

import Adapt
import CUDA
function Adapt.adapt_structure(
to::CUDA.KernelAdaptor,
bc::DataLayouts.NonExtrudedBroadcasted{Style},
) where {Style}
DataLayouts.NonExtrudedBroadcasted{Style}(
adapt_f(to, bc.f),
Adapt.adapt(to, bc.args),
Adapt.adapt(to, bc.axes),
)
end
23 changes: 23 additions & 0 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,27 @@ empty_kernel_stats() = empty_kernel_stats(ClimaComms.device())
@inline get_Nij(::IJF{S, Nij}) where {S, Nij} = Nij
@inline get_Nij(::IF{S, Nij}) where {S, Nij} = Nij

# Returns the size of the backing array.
@inline array_size(::IJKFVH{S, Nij, Nk, Nv, Nh}) where {S, Nij, Nk, Nv, Nh} = (Nij, Nij, Nk, 1, Nv, Nh)
@inline array_size(::IJFH{S, Nij, Nh}) where {S, Nij, Nh} = (Nij, Nij, 1, Nh)
@inline array_size(::IFH{S, Ni, Nh}) where {S, Ni, Nh} = (Ni, 1, Nh)
@inline array_size(::DataF{S}) where {S} = (1,)
@inline array_size(::IJF{S, Nij}) where {S, Nij} = (Nij, Nij, 1)
@inline array_size(::IF{S, Ni}) where {S, Ni} = (Ni, 1)
@inline array_size(::VF{S, Nv}) where {S, Nv} = (Nv, 1)
@inline array_size(::VIJFH{S, Nv, Nij, Nh}) where {S, Nv, Nij, Nh} = (Nv, Nij, Nij, 1, Nh)
@inline array_size(::VIFH{S, Nv, Ni, Nh}) where {S, Nv, Ni, Nh} = (Nv, Ni, 1, Nh)

@inline farray_size(data::IJKFVH{S, Nij, Nk, Nv, Nh}) where {S, Nij, Nk, Nv, Nh} = (Nij, Nij, Nk, ncomponents(data), Nv, Nh)
@inline farray_size(data::IJFH{S, Nij, Nh}) where {S, Nij, Nh} = (Nij, Nij, ncomponents(data), Nh)
@inline farray_size(data::IFH{S, Ni, Nh}) where {S, Ni, Nh} = (Ni, ncomponents(data), Nh)
@inline farray_size(data::DataF{S}) where {S} = (ncomponents(data),)
@inline farray_size(data::IJF{S, Nij}) where {S, Nij} = (Nij, Nij, ncomponents(data))
@inline farray_size(data::IF{S, Ni}) where {S, Ni} = (Ni, ncomponents(data))
@inline farray_size(data::VF{S, Nv}) where {S, Nv} = (Nv, ncomponents(data))
@inline farray_size(data::VIJFH{S, Nv, Nij, Nh}) where {S, Nv, Nij, Nh} = (Nv, Nij, Nij, ncomponents(data), Nh)
@inline farray_size(data::VIFH{S, Nv, Ni, Nh}) where {S, Nv, Ni, Nh} = (Nv, Ni, ncomponents(data), Nh)

"""
field_dim(data::AbstractData)
field_dim(::Type{<:AbstractData})
Expand Down Expand Up @@ -1216,9 +1237,11 @@ _device_dispatch(x::AbstractData) = _device_dispatch(parent(x))
_device_dispatch(x::SArray) = ToCPU()
_device_dispatch(x::MArray) = ToCPU()

include("non_extruded_broadcasted.jl")
include("copyto.jl")
include("fused_copyto.jl")
include("fill.jl")
include("mapreduce.jl")
include("has_uniform_datalayouts.jl")

end # module
1 change: 1 addition & 0 deletions src/DataLayouts/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ DataSlab2DStyle(::Type{VIJFHStyle{Nv, Nij, Nh, A}}) where {Nv, Nij, Nh, A} =
#####

#! format: off
const BroadcastedUnionData = Union{Base.Broadcast.Broadcasted{<:DataStyle}, AbstractData}
const BroadcastedUnionIJFH{S, Nij, Nh, A} = Union{Base.Broadcast.Broadcasted{IJFHStyle{Nij, Nh, A}}, IJFH{S, Nij, Nh, A}}
const BroadcastedUnionIFH{S, Ni, Nh, A} = Union{Base.Broadcast.Broadcasted{IFHStyle{Ni, Nh, A}}, IFH{S, Ni, Nh, A}}
const BroadcastedUnionIJF{S, Nij, A} = Union{Base.Broadcast.Broadcasted{IJFStyle{Nij, A}}, IJF{S, Nij, A}}
Expand Down
18 changes: 15 additions & 3 deletions src/DataLayouts/copyto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,22 @@
##### Dispatching and edge cases
#####

Base.copyto!(
dest::AbstractData,
function Base.copyto!(
dest::AbstractData{S},
bc::Union{AbstractData, Base.Broadcast.Broadcasted},
) = Base.copyto!(dest, bc, device_dispatch(dest))
) where {S}
dev = device_dispatch(dest)
if dev isa ToCPU && has_uniform_datalayouts(bc) && !(dest isa DataF)
# Specialize on linear indexing case:
bc′ = Base.Broadcast.instantiate(to_non_extruded_broadcasted(bc))
@inbounds @simd for I in 1:get_N(UniversalSize(dest))
dest[I] = convert(S, bc′[I])
end
else
Base.copyto!(dest, bc, device_dispatch(dest))
end
return dest
end

# Specialize on non-Broadcasted objects
function Base.copyto!(dest::D, src::D) where {D <: AbstractData}
Expand Down
61 changes: 7 additions & 54 deletions src/DataLayouts/fill.jl
Original file line number Diff line number Diff line change
@@ -1,60 +1,13 @@
function Base.fill!(data::IJFH, val, ::ToCPU)
(_, _, _, _, Nh) = size(data)
@inbounds for h in 1:Nh
fill!(slab(data, h), val)
function Base.fill!(dest::AbstractData, val, ::ToCPU)
@inbounds @simd for I in 1:get_N(UniversalSize(dest))
dest[I] = val
end
return data
return dest
end

function Base.fill!(data::IFH, val, ::ToCPU)
(_, _, _, _, Nh) = size(data)
@inbounds for h in 1:Nh
fill!(slab(data, h), val)
end
return data
end

function Base.fill!(data::DataF, val, ::ToCPU)
@inbounds data[] = val
return data
end

function Base.fill!(data::IJF{S, Nij}, val, ::ToCPU) where {S, Nij}
@inbounds for j in 1:Nij, i in 1:Nij
data[CartesianIndex(i, j, 1, 1, 1)] = val
end
return data
end

function Base.fill!(data::IF{S, Ni}, val, ::ToCPU) where {S, Ni}
@inbounds for i in 1:Ni
data[CartesianIndex(i, 1, 1, 1, 1)] = val
end
return data
end

function Base.fill!(data::VF, val, ::ToCPU)
Nv = nlevels(data)
@inbounds for v in 1:Nv
data[CartesianIndex(1, 1, 1, v, 1)] = val
end
return data
end

function Base.fill!(data::VIJFH, val, ::ToCPU)
(Ni, Nj, _, Nv, Nh) = size(data)
@inbounds for h in 1:Nh, v in 1:Nv
fill!(slab(data, v, h), val)
end
return data
end

function Base.fill!(data::VIFH, val, ::ToCPU)
(Ni, _, _, Nv, Nh) = size(data)
@inbounds for h in 1:Nh, v in 1:Nv
fill!(slab(data, v, h), val)
end
return data
function Base.fill!(dest::DataF, val, ::ToCPU)
@inbounds dest[] = val
return dest
end

Base.fill!(dest::AbstractData, val) =
Expand Down
60 changes: 60 additions & 0 deletions src/DataLayouts/has_uniform_datalayouts.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
@inline function first_datalayout_in_bc(args::Tuple, rargs...)
x1 = first_datalayout_in_bc(args[1], rargs...)
x1 isa AbstractData && return x1
return first_datalayout_in_bc(Base.tail(args), rargs...)
end

@inline first_datalayout_in_bc(args::Tuple{Any}, rargs...) =
first_datalayout_in_bc(args[1], rargs...)
@inline first_datalayout_in_bc(args::Tuple{}, rargs...) = nothing
@inline first_datalayout_in_bc(x) = nothing
@inline first_datalayout_in_bc(x::AbstractData) = x

@inline first_datalayout_in_bc(bc::Base.Broadcast.Broadcasted) =
first_datalayout_in_bc(bc.args)

@inline _has_uniform_datalayouts_args(truesofar, start, args::Tuple, rargs...) =
truesofar &&
_has_uniform_datalayouts(truesofar, start, args[1], rargs...) &&
_has_uniform_datalayouts_args(truesofar, start, Base.tail(args), rargs...)

@inline _has_uniform_datalayouts_args(
truesofar,
start,
args::Tuple{Any},
rargs...,
) = truesofar && _has_uniform_datalayouts(truesofar, start, args[1], rargs...)
@inline _has_uniform_datalayouts_args(truesofar, _, args::Tuple{}, rargs...) =
truesofar

@inline function _has_uniform_datalayouts(
truesofar,
start,
bc::Base.Broadcast.Broadcasted,
)
return truesofar && _has_uniform_datalayouts_args(truesofar, start, bc.args)
end
for DL in (:IJKFVH, :IJFH, :IFH, :DataF, :IJF, :IF, :VF, :VIJFH, :VIFH)
@eval begin
@inline _has_uniform_datalayouts(truesofar, ::$(DL), ::$(DL)) = true
end
end
@inline _has_uniform_datalayouts(truesofar, _, x::AbstractData) = false
@inline _has_uniform_datalayouts(truesofar, _, x) = truesofar

"""
has_uniform_datalayouts
Find the first datalayout in the broadcast expression (BCE),
and compares against every other datalayout in the BCE. Returns
- `true` if the broadcasted object has only a single kind of datalayout (e.g. VF,VF, VIJFH,VIJFH)
- `false` if the broadcasted object has multiple kinds of datalayouts (e.g. VIJFH, VIFH)
Note: a broadcasted object can have different _types_,
e.g., `VIFJH{Float64}` and `VIFJH{Tuple{Float64,Float64}}`
but not different kinds, e.g., `VIFJH{Float64}` and `VF{Float64}`.
"""
function has_uniform_datalayouts end

@inline has_uniform_datalayouts(bc::Base.Broadcast.Broadcasted) =
_has_uniform_datalayouts_args(true, first_datalayout_in_bc(bc), bc.args)

@inline has_uniform_datalayouts(bc::AbstractData) = true
160 changes: 160 additions & 0 deletions src/DataLayouts/non_extruded_broadcasted.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
#! format: off
# ============================================================ Adapted from Base.Broadcast (julia version 1.10.4)
import Base.Broadcast: BroadcastStyle
struct NonExtrudedBroadcasted{
Style <: Union{Nothing, BroadcastStyle},
Axes,
F,
Args <: Tuple,
} <: Base.AbstractBroadcasted
style::Style
f::F
args::Args
axes::Axes # the axes of the resulting object (may be bigger than implied by `args` if this is nested inside a larger `NonExtrudedBroadcasted`)

NonExtrudedBroadcasted(style::Union{Nothing, BroadcastStyle}, f::Tuple, args::Tuple) =
error() # disambiguation: tuple is not callable
function NonExtrudedBroadcasted(
style::Union{Nothing, BroadcastStyle},
f::F,
args::Tuple,
axes = nothing,
) where {F}
# using Core.Typeof rather than F preserves inferrability when f is a type
return new{typeof(style), typeof(axes), Core.Typeof(f), typeof(args)}(
style,
f,
args,
axes,
)
end
function NonExtrudedBroadcasted(f::F, args::Tuple, axes = nothing) where {F}
NonExtrudedBroadcasted(combine_styles(args...)::BroadcastStyle, f, args, axes)
end
function NonExtrudedBroadcasted{Style}(f::F, args, axes = nothing) where {Style, F}
return new{Style, typeof(axes), Core.Typeof(f), typeof(args)}(
Style()::Style,
f,
args,
axes,
)
end
function NonExtrudedBroadcasted{Style, Axes, F, Args}(
f,
args,
axes,
) where {Style, Axes, F, Args}
return new{Style, Axes, F, Args}(Style()::Style, f, args, axes)
end
end

@inline to_non_extruded_broadcasted(bc::Base.Broadcast.Broadcasted) =
NonExtrudedBroadcasted(bc.style, bc.f, to_non_extruded_broadcasted(bc.args), bc.axes)
@inline to_non_extruded_broadcasted(x) = x
NonExtrudedBroadcasted(bc::Base.Broadcast.Broadcasted) = to_non_extruded_broadcasted(bc)

@inline to_non_extruded_broadcasted(args::Tuple) = (
to_non_extruded_broadcasted(args[1]),
to_non_extruded_broadcasted(Base.tail(args))...,
)
@inline to_non_extruded_broadcasted(args::Tuple{Any}) =
(to_non_extruded_broadcasted(args[1]),)
@inline to_non_extruded_broadcasted(args::Tuple{}) = ()

@inline _checkbounds(bc, _, I) = nothing # TODO: fix this case
@inline _checkbounds(bc, ::Tuple, I) = Base.checkbounds(bc, I)
@inline function Base.getindex(
bc::NonExtrudedBroadcasted,
I::Union{Integer, CartesianIndex},
)
@boundscheck _checkbounds(bc, axes(bc), I) # is this really the only issue?
@inbounds _broadcast_getindex(bc, I)
end

# --- here, we define our own bounds checks
@inline function Base.checkbounds(bc::NonExtrudedBroadcasted, I::Integer)
# Base.checkbounds_indices(Bool, axes(bc), (I,)) || Base.throw_boundserror(bc, (I,)) # from Base
Base.checkbounds_indices(Bool, (Base.OneTo(n_dofs(bc)),), (I,)) || Base.throw_boundserror(bc, (I,))
end

import StaticArrays
to_tuple(t::Tuple) = t
to_tuple(t::NTuple{N, <: Base.OneTo}) where {N} = map(x->x.stop, t)
to_tuple(t::NTuple{N, <: StaticArrays.SOneTo}) where {N} = map(x->x.stop, t)
n_dofs(bc::NonExtrudedBroadcasted) = prod(to_tuple(axes(bc)))
# ---

Base.@propagate_inbounds _broadcast_getindex(
A::Union{Ref, AbstractArray{<:Any, 0}, Number},
I::Integer,
) = A[] # Scalar-likes can just ignore all indices
Base.@propagate_inbounds _broadcast_getindex(
::Ref{Type{T}},
I::Integer,
) where {T} = T
# Tuples are statically known to be singleton or vector-like
Base.@propagate_inbounds _broadcast_getindex(A::Tuple{Any}, I::Integer) = A[1]
Base.@propagate_inbounds _broadcast_getindex(A::Tuple, I::Integer) = A[I[1]]
# Everything else falls back to dynamically dropping broadcasted indices based upon its axes
# Base.@propagate_inbounds _broadcast_getindex(A, I) = A[newindex(A, I)]
Base.@propagate_inbounds _broadcast_getindex(A, I::Integer) = A[I]
Base.@propagate_inbounds function _broadcast_getindex(
bc::NonExtrudedBroadcasted{<:Any, <:Any, <:Any, <:Any},
I::Integer,
)
args = _getindex(bc.args, I)
return _broadcast_getindex_evalf(bc.f, args...)
end
@inline _broadcast_getindex_evalf(f::Tf, args::Vararg{Any, N}) where {Tf, N} =
f(args...) # not propagate_inbounds
Base.@propagate_inbounds _getindex(args::Tuple, I) =
(_broadcast_getindex(args[1], I), _getindex(Base.tail(args), I)...)
Base.@propagate_inbounds _getindex(args::Tuple{Any}, I) =
(_broadcast_getindex(args[1], I),)
Base.@propagate_inbounds _getindex(args::Tuple{}, I) = ()

@inline Base.axes(bc::NonExtrudedBroadcasted) = _axes(bc, bc.axes)
_axes(::NonExtrudedBroadcasted, axes::Tuple) = axes
@inline _axes(bc::NonExtrudedBroadcasted, ::Nothing) = Base.Broadcast.combine_axes(bc.args...)
_axes(bc::NonExtrudedBroadcasted{<:Base.Broadcast.AbstractArrayStyle{0}}, ::Nothing) = ()
@inline Base.axes(bc::NonExtrudedBroadcasted{<:Any, <:NTuple{N}}, d::Integer) where {N} =
d <= N ? axes(bc)[d] : OneTo(1)
Base.IndexStyle(::Type{<:NonExtrudedBroadcasted{<:Any, <:Tuple{Any}}}) = IndexLinear()
@inline _axes(::NonExtrudedBroadcasted, axes) = axes
@inline Base.eltype(bc::NonExtrudedBroadcasted) = Base.Broadcast.combine_axes(bc.args...)


# ============================================================

#! format: on
# Datalayouts
@propagate_inbounds function linear_getindex(
data::AbstractData{S},
I::Integer,
) where {S}
s_array = farray_size(data)
ss = StaticSize(s_array, field_dim(data))
@inbounds get_struct_linear(parent(data), S, Val(field_dim(data)), I, ss)
end
@propagate_inbounds function linear_setindex!(
data::AbstractData{S},
val,
I::Integer,
) where {S}
s_array = farray_size(data)
ss = StaticSize(s_array, field_dim(data))
@inbounds set_struct_linear!(
parent(data),
convert(S, val),
Val(field_dim(data)),
I,
ss,
)
end

for DL in (:IJKFVH, :IJFH, :IFH, :IJF, :IF, :VF, :VIJFH, :VIFH) # Skip DataF, since we want that to MethodError.
@eval @propagate_inbounds Base.getindex(data::$(DL), I::Integer) =
linear_getindex(data, I)
@eval @propagate_inbounds Base.setindex!(data::$(DL), val, I::Integer) =
linear_setindex!(data, val, I)
end
Loading
Loading