Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
ACEsuit committed May 31, 2024
1 parent 8cd0728 commit e58c617
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 79 deletions.
39 changes: 6 additions & 33 deletions src/WithAlloc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,58 +5,31 @@ export whatalloc, @withalloc

function whatalloc end

@inline function _bumper_alloc(allocinfo::Tuple{<: Type, Vararg{Int, N}}) where {N}
(Bumper.alloc!(Bumper.default_buffer(), allocinfo...), )
end

@inline function _bumper_alloc(allocinfo)
map( a -> _bumper_alloc(a)[1], allocinfo )
end

# macro withalloc(ex)
# fncall = esc(ex.args[1])
# args = esc.(ex.args[2:end])
# quote
# # not sure why this isn't working ...
# # whatalloc($fncall, $(args...))
# let
# allocinfo = whatalloc($fncall, $(args...), )
# storobj = _bumper_alloc(allocinfo)
# $(fncall)(storobj..., $(args...), )
# end
# end
# end

macro withalloc(ex)
esc_args = esc.(ex.args)
quote
withalloc($(esc_args...))
end
end


# For some reason that I don't understand the following implementation is allocating
# The @generated implementation below is to get around this.
# @inline function withalloc(fncall, args...)
# allocinfo = whatalloc(fncall, args..., )
# storobj = _bumper_alloc(allocinfo)
# fncall(storobj..., args..., )
# end

@inline function withalloc(fncall, args...)
allocinfo = whatalloc(fncall, args...)
_genwithalloc(allocinfo, fncall, args...)
end

@inline function _bumper_alloc(allocinfo::Tuple{<: Type, Vararg{Int, N}}) where {N}
Bumper.alloc!(Bumper.default_buffer(), allocinfo...)
end

@inline @generated function _genwithalloc(allocinfo::TT, fncall, args...) where {TT <: Tuple}
code = Expr[]
LEN = length(TT.types)
if TT.types[1] <: Tuple
for i in 1:LEN
push!(code, Meta.parse("tmp$i = _bumper_alloc(allocinfo[$i])[1]"))
push!(code, Meta.parse("tmp$i = _bumper_alloc(allocinfo[$i])"))
end
else
push!(code, Meta.parse("tmp1 = _bumper_alloc(allocinfo)[1]"))
push!(code, Meta.parse("tmp1 = _bumper_alloc(allocinfo)"))
LEN = 1
end
push!(code, Meta.parse("fncall($(join(["tmp$i, " for i in 1:LEN])) args...)"))
Expand Down
46 changes: 0 additions & 46 deletions test/test1.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,49 +165,3 @@ nalloc3 = let
end

@test nalloc3 == 0




## Reproduce #1

using WithAlloc, LinearAlgebra, Bumper, BenchmarkTools

mymul2!(A1, A2, B, C, D) = mul!(A1, B, C), mul!(A2, B, D)

function WithAlloc.whatalloc(::typeof(mymul2!), B, C, D)
T1 = promote_type(eltype(B), eltype(C))
T2 = promote_type(eltype(B), eltype(D))
return ( (T1, size(B, 1), size(C, 2)),
(T2, size(B, 1), size(D, 2)) )
end

function alloctest1(B, C, D)
@no_escape begin
a1, a2 = WithAlloc.whatalloc(mymul2!, B, C, D)
A1 = @alloc(a1...)
A2 = @alloc(a2...)
mymul2!(A1, A2, B, C, D)
sum(A1) + sum(A2)
end
end

function alloctest2(B, C, D)
@no_escape begin
A1, A2 = @withalloc mymul2!(B, C, D)
sum(A1) + sum(A2)
end
end

function alloctest3(B, C, D)
@no_escape begin
A1, A2 = WithAlloc.withalloc(mymul2!, B, C, D)
sum(A1) + sum(A2)
end
end


B = randn(5,10); C = randn(10, 3); D = randn(10, 5)
@btime alloctest1($B, $C, $D)
@btime alloctest2($B, $C, $D)
@btime alloctest3($B, $C, $D)

0 comments on commit e58c617

Please sign in to comment.