Skip to content

Commit

Permalink
Try #1380:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] committed Jul 19, 2023
2 parents 0e7a5f9 + d8aa6bc commit d70640b
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 28 deletions.
77 changes: 63 additions & 14 deletions src/Operators/integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,80 @@
The definite vertical column integral, `col∫field`, of field `field`.
"""
function column_integral_definite!(col∫field::Fields.Field, field::Fields.Field)
column_integral_definite!(col∫field::Fields.Field, field::Fields.Field) =
column_integral_definite!(ClimaComms.device(axes(field)), col∫field, field)

function column_integral_definite!(
::ClimaComms.CUDADevice,
col∫field::Fields.Field,
field::Fields.Field,
)
Ni, Nj, _, _, Nh = size(Fields.field_values(col∫field))
nthreads, nblocks = Spaces._configure_threadblock(Ni * Nj * Nh)
@cuda threads = nthreads blocks = nblocks column_integral_definite_kernel!(
col∫field,
field,
)
end

function column_integral_definite_kernel!(
col∫field::Fields.SpectralElementField,
field::Fields.ExtrudedFiniteDifferenceField,
)
idx = threadIdx().x + (blockIdx().x - 1) * blockDim().x
Ni, Nj, _, _, Nh = size(Fields.field_values(field))
if idx <= Ni * Nj * Nh
i, j, h = Spaces._get_idx((Ni, Nj, Nh), idx)
colfield = Spaces.column(field, i, j, h)
_column_integral_definite!(Spaces.column(col∫field, i, j, h), colfield)
end
return nothing
end

column_integral_definite_kernel!(
col∫field::Fields.PointField,
field::Fields.FiniteDifferenceField,
) = _column_integral_definite!(col∫field, field)

function column_integral_definite!(
::ClimaComms.AbstractCPUDevice,
col∫field::Fields.SpectralElementField,
field::Fields.ExtrudedFiniteDifferenceField,
)
Fields.bycolumn(axes(field)) do colidx
column_integral_definite!(col∫field[colidx], field[colidx])
_column_integral_definite!(col∫field[colidx], field[colidx])
nothing
end
return nothing
end

function column_integral_definite!(
column_integral_definite!(
::ClimaComms.AbstractCPUDevice,
col∫field::Fields.PointField,
field::Fields.FiniteDifferenceField,
) = _column_integral_definite!(col∫field, field)

function _column_integral_definite!(
col∫field::Fields.PointField,
field::Fields.ColumnField,
)
@inbounds col∫field[] = column_integral_definite(field)
return nothing
end
space = axes(field)
Δz_field = Fields.Δz_field(space)
Nv = Spaces.nlevels(space)

function column_integral_definite(field::Fields.ColumnField)
field_data = Fields.field_values(field)
Δz_data = Spaces.Δz_data(axes(field))
Nv = Spaces.nlevels(axes(field))
∫field = zero(eltype(field))
@inbounds for j in 1:Nv
∫field += field_data[j] * Δz_data[j]
col∫field[] = 0
@inbounds for idx in 1:Nv
col∫field[] +=
reduction_getindex(field, idx) * reduction_getindex(Δz_field, idx)
end
return ∫field
return nothing
end

reduction_getindex(column_field, index) = @inbounds getidx(
axes(column_field),
column_field,
Interior(),
index - 1 + left_idx(axes(column_field)),
)

# TODO: add support for indefinite integrals
6 changes: 2 additions & 4 deletions src/Spaces/finitedifference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,14 +212,12 @@ Base.@propagate_inbounds function level(
v::PlusHalf,
)
@inbounds local_geometry = level(local_geometry_data(space), v.i + 1)
context = ClimaComms.context(space)
PointSpace(context, local_geometry)
PointSpace(local_geometry)
end
Base.@propagate_inbounds function level(
space::CenterFiniteDifferenceSpace,
v::Int,
)
local_geometry = level(local_geometry_data(space), v)
context = ClimaComms.context(space)
PointSpace(context, local_geometry)
PointSpace(local_geometry)
end
11 changes: 6 additions & 5 deletions src/Spaces/pointspace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,20 @@ local_geometry_data(space::AbstractPointSpace) = space.local_geometry
A zero-dimensional space.
"""
struct PointSpace{C <: ClimaComms.AbstractCommsContext, LG} <:
AbstractPointSpace
context::C
struct PointSpace{LG <: DataLayouts.Data0D} <: AbstractPointSpace
local_geometry::LG
end

ClimaComms.device(space::PointSpace) = ClimaComms.device(space.context)
ClimaComms.context(space::PointSpace) = space.context
ClimaComms.context(space::PointSpace) =
ClimaComms.SingletonCommsContext(ClimaComms.CPUSingleThreaded())

#=
PointSpace(x::Geometry.LocalGeometry) =
PointSpace(ClimaComms.CPUSingleThreaded(), x)
PointSpace(x::Geometry.AbstractPoint) =
PointSpace(ClimaComms.CPUSingleThreaded(), x)
=#

function PointSpace(device::ClimaComms.AbstractDevice, x)
context = ClimaComms.SingletonCommsContext(device)
Expand All @@ -34,7 +35,7 @@ function PointSpace(
ArrayType = ClimaComms.array_type(ClimaComms.device(context))
local_geometry_data = DataLayouts.DataF{LG}(Array{FT})
local_geometry_data[] = local_geometry
return PointSpace(context, Adapt.adapt(ArrayType, local_geometry_data))
return PointSpace(Adapt.adapt(ArrayType, local_geometry_data))
end

function PointSpace(
Expand Down
6 changes: 2 additions & 4 deletions src/Spaces/spectralelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -644,16 +644,14 @@ Base.@propagate_inbounds slab(space::AbstractSpectralElementSpace, h) =

Base.@propagate_inbounds function column(space::SpectralElementSpace1D, i, h)
local_geometry = column(local_geometry_data(space), i, h)
context = ClimaComms.context(space)
PointSpace(context, local_geometry)
PointSpace(local_geometry)
end
Base.@propagate_inbounds column(space::SpectralElementSpace1D, i, j, h) =
column(space, i, h)

Base.@propagate_inbounds function column(space::SpectralElementSpace2D, i, j, h)
local_geometry = column(local_geometry_data(space), i, j, h)
context = ClimaComms.context(space)
PointSpace(context, local_geometry)
PointSpace(local_geometry)
end

# XXX: this cannot take `space` as it must be constructed beforehand so
Expand Down
2 changes: 1 addition & 1 deletion test/Fields/field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ convergence_rate(err, Δh) =
col_copy = similar(y[Fields.ColumnIndex((1, 1), 1)])
return Fields.Field(Fields.field_values(col_copy), axes(col_copy))
end
device = ClimaComms.CPUSingleThreaded()
device = ClimaComms.device()
context = ClimaComms.SingletonCommsContext(device)
for zelem in (2^2, 2^3, 2^4, 2^5)
for space in TU.all_spaces(FT; zelem, context)
Expand Down

0 comments on commit d70640b

Please sign in to comment.