From 8cd0728a37ccde2f843da5b7f7e7f8c0456306cd Mon Sep 17 00:00:00 2001 From: ACEsuit Date: Fri, 31 May 2024 16:45:18 -0700 Subject: [PATCH] fix alloc for all cases --- src/WithAlloc.jl | 57 +++++++++++++++++++++++++++++++++++++----------- test/test1.jl | 50 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 91 insertions(+), 16 deletions(-) diff --git a/src/WithAlloc.jl b/src/WithAlloc.jl index 84439d2..c97f4c4 100644 --- a/src/WithAlloc.jl +++ b/src/WithAlloc.jl @@ -13,25 +13,56 @@ end map( a -> _bumper_alloc(a)[1], allocinfo ) end +# macro withalloc(ex) +# fncall = esc(ex.args[1]) +# args = esc.(ex.args[2:end]) +# quote +# # not sure why this isn't working ... +# # whatalloc($fncall, $(args...)) +# let +# allocinfo = whatalloc($fncall, $(args...), ) +# storobj = _bumper_alloc(allocinfo) +# $(fncall)(storobj..., $(args...), ) +# end +# end +# end + macro withalloc(ex) - fncall = esc(ex.args[1]) - args = esc.(ex.args[2:end]) + esc_args = esc.(ex.args) quote - # not sure why this isn't working ... - # whatalloc($fncall, $(args...)) - let - allocinfo = whatalloc($fncall, $(args...), ) - storobj = _bumper_alloc(allocinfo) - $(fncall)(storobj..., $(args...), ) - end + withalloc($(esc_args...)) end end -@inline function withalloc(fncall, args...) - allocinfo = whatalloc(fncall, args..., ) - storobj = _bumper_alloc(allocinfo) - fncall(storobj..., args..., ) +# For some reason that I don't understand the following implementation is allocating +# The @generated implementation below is to get around this. +# @inline function withalloc(fncall, args...) +# allocinfo = whatalloc(fncall, args..., ) +# storobj = _bumper_alloc(allocinfo) +# fncall(storobj..., args..., ) +# end + +@inline function withalloc(fncall, args...) + allocinfo = whatalloc(fncall, args...) + _genwithalloc(allocinfo, fncall, args...) +end + +@inline @generated function _genwithalloc(allocinfo::TT, fncall, args...) where {TT <: Tuple} + code = Expr[] + LEN = length(TT.types) + if TT.types[1] <: Tuple + for i in 1:LEN + push!(code, Meta.parse("tmp$i = _bumper_alloc(allocinfo[$i])[1]")) + end + else + push!(code, Meta.parse("tmp1 = _bumper_alloc(allocinfo)[1]")) + LEN = 1 + end + push!(code, Meta.parse("fncall($(join(["tmp$i, " for i in 1:LEN])) args...)")) + quote + $(code...) + end end diff --git a/test/test1.jl b/test/test1.jl index 0b4456d..1b7dd44 100644 --- a/test/test1.jl +++ b/test/test1.jl @@ -109,13 +109,13 @@ nalloc2 = let B = randn(5,10); C = randn(10, 3); D = randn(10, 5) @allocated alloctest2(B, C, D) end -@show nalloc2 # 64 +@test nalloc2 == 0 nalloc2_nm = let B = randn(5,10); C = randn(10, 3); D = randn(10, 5) @allocated alloctest2_nm(B, C, D) end -@show nalloc2_nm # 64 +@test nalloc2_nm == 0 @@ -164,6 +164,50 @@ nalloc3 = let @allocated alloctest2(B, C, D) end -@show nalloc3 # 64 +@test nalloc3 == 0 + + +## Reproduce #1 + +using WithAlloc, LinearAlgebra, Bumper, BenchmarkTools + +mymul2!(A1, A2, B, C, D) = mul!(A1, B, C), mul!(A2, B, D) + +function WithAlloc.whatalloc(::typeof(mymul2!), B, C, D) + T1 = promote_type(eltype(B), eltype(C)) + T2 = promote_type(eltype(B), eltype(D)) + return ( (T1, size(B, 1), size(C, 2)), + (T2, size(B, 1), size(D, 2)) ) +end + +function alloctest1(B, C, D) + @no_escape begin + a1, a2 = WithAlloc.whatalloc(mymul2!, B, C, D) + A1 = @alloc(a1...) + A2 = @alloc(a2...) + mymul2!(A1, A2, B, C, D) + sum(A1) + sum(A2) + end +end + +function alloctest2(B, C, D) + @no_escape begin + A1, A2 = @withalloc mymul2!(B, C, D) + sum(A1) + sum(A2) + end +end + +function alloctest3(B, C, D) + @no_escape begin + A1, A2 = WithAlloc.withalloc(mymul2!, B, C, D) + sum(A1) + sum(A2) + end +end + + +B = randn(5,10); C = randn(10, 3); D = randn(10, 5) +@btime alloctest1($B, $C, $D) +@btime alloctest2($B, $C, $D) +@btime alloctest3($B, $C, $D)