Skip to content

Commit

Permalink
fix alloc for all cases
Browse files Browse the repository at this point in the history
  • Loading branch information
ACEsuit committed May 31, 2024
1 parent 9908689 commit 8cd0728
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 16 deletions.
57 changes: 44 additions & 13 deletions src/WithAlloc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,56 @@ end
map( a -> _bumper_alloc(a)[1], allocinfo )
end

# macro withalloc(ex)
# fncall = esc(ex.args[1])
# args = esc.(ex.args[2:end])
# quote
# # not sure why this isn't working ...
# # whatalloc($fncall, $(args...))
# let
# allocinfo = whatalloc($fncall, $(args...), )
# storobj = _bumper_alloc(allocinfo)
# $(fncall)(storobj..., $(args...), )
# end
# end
# end

"""
    @withalloc fncall(args...)

Rewrite the call expression `fncall(args...)` into `withalloc(fncall, args...)`.

The callee and every argument are escaped so that they are evaluated in the caller's
scope, while `withalloc` is resolved in the module that defines this macro. The call is
emitted exactly once.

# Throws
- `ArgumentError` (at macro-expansion time) if the argument is not a plain call expression.
"""
macro withalloc(ex)
    Meta.isexpr(ex, :call) || throw(ArgumentError(
        "@withalloc expects a function call such as `f(args...)`, got: $ex"))
    # ex.args[1] is the callee and ex.args[2:end] are its arguments; forwarding them
    # in order yields `withalloc(callee, arguments...)`.
    esc_args = map(esc, ex.args)
    return :(withalloc($(esc_args...)))
end


@inline function withalloc(fncall, args...)
allocinfo = whatalloc(fncall, args..., )
storobj = _bumper_alloc(allocinfo)
fncall(storobj..., args..., )
# For reasons not yet understood, the straightforward implementation below allocates.
# The @generated implementation that follows works around this.
# @inline function withalloc(fncall, args...)
# allocinfo = whatalloc(fncall, args..., )
# storobj = _bumper_alloc(allocinfo)
# fncall(storobj..., args..., )
# end

"""
    withalloc(fncall, args...)

Query `whatalloc(fncall, args...)` for the allocation description and hand it, together
with `fncall` and `args`, to `_genwithalloc`, whose result is returned.
"""
@inline function withalloc(fncall, args...)
    return _genwithalloc(whatalloc(fncall, args...), fncall, args...)
end

# Generated helper behind `withalloc`.
#
# `allocinfo` is the tuple returned by `whatalloc`. Based on its field types the
# generator emits straight-line code:
#   * empty tuple: nothing is allocated and the result is `fncall(args...)`;
#   * first field is itself a tuple: one `tmpI = _bumper_alloc(allocinfo[I])[1]`
#     statement per field of `allocinfo`;
#   * otherwise: `allocinfo` is treated as a single spec and one
#     `tmp1 = _bumper_alloc(allocinfo)[1]` statement is emitted.
# The final statement calls `fncall(tmp1, ..., tmpN, args...)` and its value is returned.
#
# The expression is assembled from `Expr` objects rather than parsed from strings.
@inline @generated function _genwithalloc(allocinfo::TT, fncall, args...) where {TT <: Tuple}
    spec_types = fieldtypes(TT)
    # Guard the empty case: indexing the first field type would otherwise throw
    # during code generation.
    if isempty(spec_types)
        return :(fncall(args...))
    end
    body = Expr[]
    if first(spec_types) <: Tuple
        buffers = [Symbol(:tmp, i) for i in 1:length(spec_types)]
        for (i, buf) in enumerate(buffers)
            push!(body, :($buf = _bumper_alloc(allocinfo[$i])[1]))
        end
    else
        buffers = [:tmp1]
        push!(body, :(tmp1 = _bumper_alloc(allocinfo)[1]))
    end
    push!(body, :(fncall($(buffers...), args...)))
    return Expr(:block, body...)
end


Expand Down
50 changes: 47 additions & 3 deletions test/test1.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,13 @@ nalloc2 = let
B = randn(5,10); C = randn(10, 3); D = randn(10, 5)
@allocated alloctest2(B, C, D)
end
@show nalloc2 # 64
@test nalloc2 == 0

# Record the value of `@allocated` for one call to `alloctest2_nm` on random inputs
# created inside the `let` block.
nalloc2_nm = let
B = randn(5,10); C = randn(10, 3); D = randn(10, 5)
@allocated alloctest2_nm(B, C, D)
end
# NOTE(review): the trailing "64" looks like a previously observed value; the test
# below expects 0 — confirm and drop the stale number.
@show nalloc2_nm # 64
# The measured call is expected to allocate nothing.
@test nalloc2_nm == 0



Expand Down Expand Up @@ -164,6 +164,50 @@ nalloc3 = let
@allocated alloctest2(B, C, D)
end

@show nalloc3 # 64
@test nalloc3 == 0




## Reproduce #1

using WithAlloc, LinearAlgebra, Bumper, BenchmarkTools

"""
    mymul2!(A1, A2, B, C, D)

Compute `mul!(A1, B, C)` and `mul!(A2, B, D)` in that order and return both results as a
2-tuple.
"""
function mymul2!(A1, A2, B, C, D)
    first_product = mul!(A1, B, C)
    second_product = mul!(A2, B, D)
    return (first_product, second_product)
end

# Allocation description for `mymul2!`: one `(eltype, rows, cols)` tuple per output.
# Each element type is the promotion of `eltype(B)` with the element type of the
# respective right-hand factor; rows come from `B`, columns from `C` and `D`.
function WithAlloc.whatalloc(::typeof(mymul2!), B, C, D)
    nrows = size(B, 1)
    spec1 = (promote_type(eltype(B), eltype(C)), nrows, size(C, 2))
    spec2 = (promote_type(eltype(B), eltype(D)), nrows, size(D, 2))
    return (spec1, spec2)
end

# Reference variant: obtains the two allocation specs from `WithAlloc.whatalloc`,
# allocates both outputs explicitly with `@alloc` inside `@no_escape`, fills them with
# `mymul2!`, and returns the sum of all entries of both outputs.
function alloctest1(B, C, D)
@no_escape begin
# a1 and a2 are the two spec tuples returned by `whatalloc` for `mymul2!`
a1, a2 = WithAlloc.whatalloc(mymul2!, B, C, D)
A1 = @alloc(a1...)
A2 = @alloc(a2...)
mymul2!(A1, A2, B, C, D)
# the value of the block is a scalar reduction of both outputs
sum(A1) + sum(A2)
end
end

# Macro variant: same computation as the explicit version, but the outputs are obtained
# through `@withalloc mymul2!(B, C, D)` inside `@no_escape`. Returns the sum of all
# entries of both outputs.
function alloctest2(B, C, D)
@no_escape begin
# the macro call yields the two outputs of `mymul2!`
A1, A2 = @withalloc mymul2!(B, C, D)
sum(A1) + sum(A2)
end
end

# Function variant: same computation, but calls `WithAlloc.withalloc(mymul2!, B, C, D)`
# directly instead of going through the macro. Returns the sum of all entries of both
# outputs.
function alloctest3(B, C, D)
@no_escape begin
# direct function call form, for comparison with the `@withalloc` macro form
A1, A2 = WithAlloc.withalloc(mymul2!, B, C, D)
sum(A1) + sum(A2)
end
end


# Benchmark the three variants on the same random inputs. The `$` prefix interpolates
# the global arrays into the benchmarked expression.
B = randn(5,10); C = randn(10, 3); D = randn(10, 5)
@btime alloctest1($B, $C, $D)
@btime alloctest2($B, $C, $D)
@btime alloctest3($B, $C, $D)

0 comments on commit 8cd0728

Please sign in to comment.