From fadde5e39ef8ef952dd87e55741ce00a773104e6 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Mon, 26 Jun 2023 11:54:26 -0700 Subject: [PATCH 1/2] Add broken test for pow_n issue --- test/Fields/field.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/test/Fields/field.jl b/test/Fields/field.jl index ac1f730295..a5c303a5fb 100644 --- a/test/Fields/field.jl +++ b/test/Fields/field.jl @@ -94,6 +94,25 @@ end @test axes(point_field) isa Spaces.PointSpace end +# https://github.com/CliMA/ClimaCore.jl/issues/1126 +function pow_n(f) + @. f.x = f.x^2 + return nothing +end +@testset "Broadcasting with ^n" begin + FT = Float32 + for space in TU.all_spaces(FT) + f = fill((; x = FT(1)), space) + pow_n(f) # Compile first + p_allocated = @allocated pow_n(f) + if space isa PointSpace + @test p_allocated == 0 + else + @test_broken p_allocated == 0 + end + end +end + # Requires `--check-bounds=yes` @testset "Constructing & broadcasting over empty fields" begin FT = Float32 From e6d6800d82f059735833c3ebcdcd6e374809a208 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Mon, 26 Jun 2023 14:18:05 -0700 Subject: [PATCH 2/2] Fix pow2 allocations --- src/Fields/broadcast.jl | 7 +++++++ test/Fields/field.jl | 6 +----- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/Fields/broadcast.jl b/src/Fields/broadcast.jl index 14cb646eb4..a63c174e2a 100644 --- a/src/Fields/broadcast.jl +++ b/src/Fields/broadcast.jl @@ -294,6 +294,13 @@ Base.Broadcast.broadcasted( ) where {T <: Geometry.AxisTensor} = Base.Broadcast.broadcasted(fs, (x...) -> T(x...), args...) +Base.Broadcast.broadcasted( + ::typeof(Base.literal_pow), + ::typeof(^), + ::Field, + ::Val{n}, +) where {n} = Base.Broadcast.broadcasted(x -> Base.literal_pow(^, x, Val(n)), f) + # Specialize handling of +, *, muladd, so that we can support broadcasting over NamedTuple element types # Required for ODE solvers diff --git a/test/Fields/field.jl b/test/Fields/field.jl index a5c303a5fb..673054f64a 100644 --- a/test/Fields/field.jl +++ b/test/Fields/field.jl @@ -105,11 +105,7 @@ end f = fill((; x = FT(1)), space) pow_n(f) # Compile first p_allocated = @allocated pow_n(f) - if space isa PointSpace - @test p_allocated == 0 - else - @test_broken p_allocated == 0 - end + @test p_allocated == 0 end end