Skip to content

Commit

Permalink
Use Broadcast.flatten on master
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jun 4, 2024
1 parent ca04c18 commit ec919be
Show file tree
Hide file tree
Showing 14 changed files with 117 additions and 22 deletions.
1 change: 1 addition & 0 deletions ext/ClimaCoreCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module ClimaCoreCUDAExt

import NVTX
import ClimaComms
import ClimaCore: broadcast_flatten
import ClimaCore: DataLayouts, Grids, Spaces, Fields
import ClimaCore: Geometry
import ClimaCore.Geometry: AxisTensor
Expand Down
12 changes: 8 additions & 4 deletions ext/cuda/data_layouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@ end

function Base.copyto!(
dest::IJFH{S, Nij},
bc::Union{IJFH{S, Nij, A}, Base.Broadcast.Broadcasted{IJFHStyle{Nij, A}}},
bc′::Union{IJFH{S, Nij, A}, Base.Broadcast.Broadcasted{IJFHStyle{Nij, A}}},
) where {S, Nij, A <: CUDA.CuArray}
bc = broadcast_flatten(bc′)
_, _, _, _, Nh = size(bc)
if Nh > 0
auto_launch!(
Expand Down Expand Up @@ -99,11 +100,12 @@ end

function Base.copyto!(
dest::VIJFH{S, Nv, Nij},
bc::Union{
bc′::Union{
VIJFH{S, Nv, Nij, A},
Base.Broadcast.Broadcasted{VIJFHStyle{Nv, Nij, A}},
},
) where {S, Nv, Nij, A <: CUDA.CuArray}
bc = broadcast_flatten(bc′)
_, _, _, _, Nh = size(bc)
if Nv > 0 && Nh > 0
Nv_per_block = min(Nv, fld(256, Nij * Nij))
Expand Down Expand Up @@ -140,8 +142,9 @@ end

function Base.copyto!(
dest::VF{S, Nv},
bc::Union{VF{S, Nv, A}, Base.Broadcast.Broadcasted{VFStyle{Nv, A}}},
bc′::Union{VF{S, Nv, A}, Base.Broadcast.Broadcasted{VFStyle{Nv, A}}},
) where {S, Nv, A <: CUDA.CuArray}
bc = broadcast_flatten(bc′)
_, _, _, _, Nh = size(dest)
if Nv > 0 && Nh > 0
auto_launch!(
Expand Down Expand Up @@ -170,8 +173,9 @@ end

function Base.copyto!(
dest::DataF{S},
bc::Union{DataF{S, A}, Base.Broadcast.Broadcasted{DataFStyle{A}}},
bc′::Union{DataF{S, A}, Base.Broadcast.Broadcasted{DataFStyle{A}}},
) where {S, A <: CUDA.CuArray}
bc = broadcast_flatten(bc′)
auto_launch!(
knl_copyto!,
(dest, bc),
Expand Down
3 changes: 2 additions & 1 deletion ext/cuda/operators_finite_difference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ AbstractStencilStyle(::ClimaComms.CUDADevice) = CUDAColumnStencilStyle

function Base.copyto!(
out::Field,
bc::Union{
bc′::Union{
StencilBroadcasted{CUDAColumnStencilStyle},
Broadcasted{CUDAColumnStencilStyle},
},
)
bc = broadcast_flatten(bc′)
space = axes(out)
if space isa Spaces.ExtrudedFiniteDifferenceSpace
QS = Spaces.quadrature_style(space)
Expand Down
3 changes: 2 additions & 1 deletion ext/cuda/operators_spectral_element.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,12 @@ end

function Base.copyto!(
out::Field,
sbc::Union{
sbc′::Union{
SpectralBroadcasted{CUDASpectralStyle},
Broadcasted{CUDASpectralStyle},
},
)
sbc = broadcast_flatten(sbc′)
space = axes(out)
QS = Spaces.quadrature_style(space)
Nq = Quadratures.degrees_of_freedom(QS)
Expand Down
1 change: 1 addition & 0 deletions src/ClimaCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using PkgVersion
const VERSION = PkgVersion.@Version
import ClimaComms

include("upstream.jl")
include("interface.jl")
include("devices.jl")
include("Utilities/Utilities.jl")
Expand Down
1 change: 1 addition & 0 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import ClimaComms
import MultiBroadcastFusion as MBF
import Adapt

import ..broadcast_flatten
import ..slab, ..slab_args, ..column, ..column_args, ..level
export slab, column, level, IJFH, IJF, IFH, IF, VF, VIJFH, VIFH, DataF

Expand Down
29 changes: 19 additions & 10 deletions src/DataLayouts/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -467,11 +467,12 @@ end
# Performance optimization for the common identity scalar case: dest .= val
function Base.copyto!(
dest::AbstractData,
bc::Base.Broadcast.Broadcasted{Style},
bc′::Base.Broadcast.Broadcasted{Style},
) where {
Style <:
Union{Base.Broadcast.AbstractArrayStyle{0}, Base.Broadcast.Style{Tuple}},
}
bc = broadcast_flatten(bc′)
bc = Base.Broadcast.instantiate(
Base.Broadcast.Broadcasted{Style}(bc.f, bc.args, ()),
)
Expand All @@ -481,16 +482,18 @@ end

function Base.copyto!(
dest::DataF{S},
bc::Union{DataF{S, A}, Base.Broadcast.Broadcasted{DataFStyle{A}}},
bc′::Union{DataF{S, A}, Base.Broadcast.Broadcasted{DataFStyle{A}}},
) where {S, A}
bc = broadcast_flatten(bc′)
@inbounds dest[] = convert(S, bc[])
return dest
end

function Base.copyto!(
dest::IJFH{S, Nij},
bc::Union{IJFH{S, Nij}, Base.Broadcast.Broadcasted{<:IJFHStyle{Nij}}},
bc′::Union{IJFH{S, Nij}, Base.Broadcast.Broadcasted{<:IJFHStyle{Nij}}},
) where {S, Nij}
bc = broadcast_flatten(bc′)
_, _, _, _, Nh = size(bc)
@inbounds for h in 1:Nh
slab_dest = slab(dest, h)
Expand All @@ -502,8 +505,9 @@ end

function Base.copyto!(
dest::IFH{S, Ni},
bc::Union{IFH{S, Ni}, Base.Broadcast.Broadcasted{<:IFHStyle{Ni}}},
bc′::Union{IFH{S, Ni}, Base.Broadcast.Broadcasted{<:IFHStyle{Ni}}},
) where {S, Ni}
bc = broadcast_flatten(bc′)
_, _, _, _, Nh = size(bc)
@inbounds for h in 1:Nh
slab_dest = slab(dest, h)
Expand All @@ -516,8 +520,9 @@ end
# inline inner slab(::DataSlab2D) copy
function Base.copyto!(
dest::IJF{S, Nij},
bc::Union{IJF{S, Nij, A}, Base.Broadcast.Broadcasted{IJFStyle{Nij, A}}},
bc′::Union{IJF{S, Nij, A}, Base.Broadcast.Broadcasted{IJFStyle{Nij, A}}},
) where {S, Nij, A}
bc = broadcast_flatten(bc′)
@inbounds for j in 1:Nij, i in 1:Nij
idx = CartesianIndex(i, j, 1, 1, 1)
dest[idx] = convert(S, bc[idx])
Expand All @@ -528,8 +533,9 @@ end
# inline inner slab(::DataSlab1D) copy
function Base.copyto!(
dest::IF{S, Ni},
bc::Base.Broadcast.Broadcasted{IFStyle{Ni, A}},
bc′::Base.Broadcast.Broadcasted{IFStyle{Ni, A}},
) where {S, Ni, A}
bc = broadcast_flatten(bc′)
@inbounds for i in 1:Ni
idx = CartesianIndex(i, 1, 1, 1, 1)
dest[idx] = convert(S, bc[idx])
Expand All @@ -540,8 +546,9 @@ end
# inline inner column(::DataColumn) copy
function Base.copyto!(
dest::VF{S, Nv},
bc::Union{VF{S, Nv, A}, Base.Broadcast.Broadcasted{VFStyle{Nv, A}}},
bc′::Union{VF{S, Nv, A}, Base.Broadcast.Broadcasted{VFStyle{Nv, A}}},
) where {S, Nv, A}
bc = broadcast_flatten(bc′)
@inbounds for v in 1:Nv
idx = CartesianIndex(1, 1, 1, v, 1)
dest[idx] = convert(S, bc[idx])
Expand Down Expand Up @@ -594,8 +601,9 @@ end

function Base.copyto!(
dest::VIFH{S, Nv, Ni},
bc::Base.Broadcast.Broadcasted{VIFHStyle{Nv, Ni, A}},
bc′::Base.Broadcast.Broadcasted{VIFHStyle{Nv, Ni, A}},
) where {S, Nv, Ni, A}
bc = broadcast_flatten(bc′)
return _serial_copyto!(dest, bc)
end

Expand Down Expand Up @@ -644,8 +652,9 @@ end

function Base.copyto!(
dest::VIJFH{S, Nv, Nij},
bc::Base.Broadcast.Broadcasted{VIJFHStyle{Nv, Nij, A}},
bc′::Base.Broadcast.Broadcasted{VIJFHStyle{Nv, Nij, A}},
) where {S, Nv, Nij, A}
bc = broadcast_flatten(bc′)
return _serial_copyto!(dest, bc)
end

Expand Down Expand Up @@ -674,7 +683,7 @@ function Base.copyto!(
else
bc
end
Pair(pair.first, bc′)
Pair(pair.first, broadcast_flatten(bc′))
end,
)
# check_fused_broadcast_axes(fmbc) # we should already have checked the axes
Expand Down
1 change: 1 addition & 0 deletions src/Fields/Fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module Fields

import ClimaComms
import MultiBroadcastFusion as MBF
import ..broadcast_flatten
import ..slab, ..slab_args, ..column, ..column_args, ..level
import ..DataLayouts:
DataLayouts,
Expand Down
6 changes: 4 additions & 2 deletions src/Fields/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,9 @@ end

@inline function Base.copyto!(
dest::Field,
bc::Base.Broadcast.Broadcasted{<:AbstractFieldStyle},
bc′::Base.Broadcast.Broadcasted{<:AbstractFieldStyle},
)
bc = broadcast_flatten(bc′)
copyto!(field_values(dest), Base.Broadcast.instantiate(todata(bc)))
return dest
end
Expand All @@ -156,7 +157,8 @@ function Base.copyto!(
) where {N, T <: NTuple{N, Pair{<:Field, <:Any}}}
fmb_data = FusedMultiBroadcast(
map(fmbc.pairs) do pair
bc = Base.Broadcast.instantiate(todata(pair.second))
bc′ = Base.Broadcast.instantiate(todata(pair.second))
bc = broadcast_flatten(bc′)
Pair(field_values(pair.first), bc)
end,
)
Expand Down
6 changes: 4 additions & 2 deletions src/Fields/fieldvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,9 @@ end

@inline function Base.copyto!(
dest::FieldVector,
bc::Base.Broadcast.Broadcasted{FieldVectorStyle},
bc′::Base.Broadcast.Broadcasted{FieldVectorStyle},
)
bc = broadcast_flatten(bc′)
map(propertynames(dest)) do symb
Base.@_inline_meta
p = parent(getfield(_values(dest), symb))
Expand All @@ -304,8 +305,9 @@ end

@inline function Base.copyto!(
dest::FieldVector,
bc::Base.Broadcast.Broadcasted{<:Base.Broadcast.AbstractArrayStyle{0}},
bc′::Base.Broadcast.Broadcasted{<:Base.Broadcast.AbstractArrayStyle{0}},
)
bc = broadcast_flatten(bc′)
map(propertynames(dest)) do symb
Base.@_inline_meta
p = parent(getfield(_values(dest), symb))
Expand Down
1 change: 1 addition & 0 deletions src/Operators/Operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using StaticArrays

import Base.Broadcast: Broadcasted

import ..broadcast_flatten
import ..slab, ..slab_args, ..column, ..column_args
import ClimaComms
import ..DataLayouts: DataLayouts, Data2D, DataSlab2D
Expand Down
3 changes: 2 additions & 1 deletion src/Operators/finitedifference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3354,11 +3354,12 @@ end

function Base.copyto!(
field_out::Field,
bc::Union{
bc′::Union{
StencilBroadcasted{ColumnStencilStyle},
Broadcasted{ColumnStencilStyle},
},
)
bc = broadcast_flatten(bc′)
space = axes(bc)
local_geometry = Spaces.local_geometry_data(space)
(Ni, Nj, _, _, Nh) = size(local_geometry)
Expand Down
3 changes: 2 additions & 1 deletion src/Operators/spectralelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,12 @@ end
# Functions for SlabBlockSpectralStyle
function Base.copyto!(
out::Field,
sbc::Union{
sbc′::Union{
SpectralBroadcasted{SlabBlockSpectralStyle},
Broadcasted{SlabBlockSpectralStyle},
},
)
sbc = broadcast_flatten(sbc′)
Fields.byslab(axes(out)) do slabidx
Base.@_inline_meta
@inbounds copyto_slab!(out, sbc, slabidx)
Expand Down
69 changes: 69 additions & 0 deletions src/upstream.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# https://github.com/JuliaArrays/StaticArrays.jl/pull/1186

"""
    broadcast_flatten(bc)

Return a flat equivalent of `bc`.

For a `Base.Broadcast.Broadcasted`, the result is a `Broadcasted` of the same style
whose argument tuple contains no nested `Broadcasted` nodes; the nested calls are
folded into the broadcast function. Any other `bc` is returned unchanged.
"""
function broadcast_flatten end

# Pass-through for everything that is not a `Base.Broadcast.Broadcasted`. This method
# must exist on every Julia version: `Base.Broadcast.flatten` only accepts
# `Broadcasted`, so aliasing it directly would drop this fallback.
broadcast_flatten(bc) = bc

# `Base.VERSION` is spelled out on purpose: a module that defines its own `VERSION`
# constant before including this file would otherwise shadow the Julia version and
# make this check compare against the package version instead.
if Base.VERSION >= v"1.11.0-DEV.103"
    # The fixed implementation is available upstream: defer to it.
    broadcast_flatten(bc::Base.Broadcast.Broadcasted) = Base.Broadcast.flatten(bc)
else
    using Base: tail
    using Base.Broadcast: isflat, Broadcasted

    # A type used as the broadcast function is replaced by a closure forwarding to
    # it; any other callable is returned as is.
    maybeconstructor(f) = f
    maybeconstructor(::Type{F}) where {F} =
        (args...; kwargs...) -> F(args...; kwargs...)

    function broadcast_flatten(bc::Broadcasted{Style}) where {Style}
        isflat(bc) && return bc
        # Concatenate the leaves of the nested broadcast tree into one flat tuple.
        args = cat_nested(bc)
        len = Val{length(args)}()
        # One function per top-level argument of `bc`; each takes the whole flat
        # argument tuple and rebuilds the value of that argument.
        makeargs = make_makeargs(bc.args, len, ntuple(_ -> true, len))
        f = maybeconstructor(bc.f)
        @inline newf(args...) = f(prepare_args(makeargs, args)...)
        return Broadcasted{Style}(newf, args, bc.axes)
    end

    # Collect the non-`Broadcasted` leaves of a broadcast tree, left to right.
    cat_nested(bc::Broadcasted) = cat_nested_args(bc.args)
    cat_nested_args(::Tuple{}) = ()
    cat_nested_args(t::Tuple) =
        (cat_nested(t[1])..., cat_nested_args(tail(t))...)
    cat_nested(@nospecialize(a)) = (a,)

    # Build the tuple of argument-reconstruction functions for `args`. Every flag
    # must have been consumed once the traversal is complete.
    function make_makeargs(args::Tuple, len, flags)
        makeargs, r = _make_makeargs(args, len, flags)
        r isa Tuple{} || error("Internal error. Please file a bug")
        return makeargs
    end

    # We build `makeargs` by traversing the broadcast nodes recursively.
    # note: `len` isa `Val` indicates the length of whole flattened argument list.
    # `flags` is a tuple of `Bool` with the same length of the rest arguments.
    @inline function _make_makeargs(args::Tuple, len::Val, flags::Tuple)
        head, flags′ = _make_makeargs1(args[1], len, flags)
        rest, flags″ = _make_makeargs(tail(args), len, flags′)
        (head, rest...), flags″
    end
    _make_makeargs(::Tuple{}, ::Val, x::Tuple) = (), x

    # For flat nodes:
    # 1. we just consume one argument, and return the "pick" function
    @inline function _make_makeargs1(
        @nospecialize(a),
        ::Val{N},
        flags::Tuple,
    ) where {N}
        pickargs(::Val{N}) where {N} = (@nospecialize(x::Tuple)) -> x[N]
        # `N - length(flags) + 1` is the position of this leaf in the flat tuple:
        # `flags` holds one entry per leaf that has not been consumed yet.
        return pickargs(Val{N - length(flags) + 1}()), tail(flags)
    end

    # For nested nodes, we form the `makeargs1` based on the child `makeargs`
    # (n += length(cat_nested(bc)))
    @inline function _make_makeargs1(bc::Broadcasted, len::Val, flags::Tuple)
        makeargs, flags′ = _make_makeargs(bc.args, len, flags)
        f = maybeconstructor(bc.f)
        @inline makeargs1(@nospecialize(args::Tuple)) =
            f(prepare_args(makeargs, args)...)
        makeargs1, flags′
    end

    # Apply each reconstruction function to the flat argument tuple `x`.
    prepare_args(::Tuple{}, @nospecialize(::Tuple)) = ()
    @inline prepare_args(makeargs::Tuple, @nospecialize(x::Tuple)) =
        (makeargs[1](x), prepare_args(tail(makeargs), x)...)
end

0 comments on commit ec919be

Please sign in to comment.