diff --git a/README.md b/README.md index 5cc0dbc..a9789f2 100644 --- a/README.md +++ b/README.md @@ -6,8 +6,6 @@ This package implements a very small extension to [Bumper.jl](https://github.com/MasonProtter/Bumper.jl). -### Introduction and Motivation - A common pattern in our own (the developers') codes is the following: ```julia @no_escape begin @@ -26,9 +24,9 @@ After writing the same pattern 10 times, we wondered whether there is an easy wa end ``` -### Documentation +### Preliminary Documentation -For now, there is just a simple example. More soon ... +For now, there is are just a few simple use-case examples. ```julia using WithAlloc, LinearAlgebra, Bumper diff --git a/src/WithAlloc.jl b/src/WithAlloc.jl index a9c3fc8..1e44fe0 100644 --- a/src/WithAlloc.jl +++ b/src/WithAlloc.jl @@ -13,6 +13,17 @@ function _bumper_alloc(allocinfo::NTuple{N, <: Tuple}) where {N} ntuple(i -> Bumper.alloc!(Bumper.default_buffer(), allocinfo[i]...), N) end +macro withalloc1(ex) + fncall = esc(ex.args[1]) + args = esc.(ex.args[2:end]) + quote + let + allocinfo = whatalloc($fncall, $(args...), ) + storobj = Bumper.alloc!(Bumper.default_buffer(), allocinfo... ) + $(fncall)(storobj, $(args...), ) + end + end +end macro withalloc(ex) fncall = esc(ex.args[1]) diff --git a/test/_readme.jl b/test/_readme.jl index 97d87f8..87a7674 100644 --- a/test/_readme.jl +++ b/test/_readme.jl @@ -41,3 +41,26 @@ end using BenchmarkTools @btime alloctest($B, $C) # 125.284 ns (0 allocations: 0 bytes) +# ------------------------------------------------------------------------ + +# Multiple arrays is handled via tuples: + +B = randn(5,10) +C = randn(10, 3) +D = randn(10, 5) +A1 = B * C +A2 = B * D + +mymul2!(A1, A2, B, C, D) = mul!(A1, B, C), mul!(A2, B, D) + +function WithAlloc.whatalloc(::typeof(mymul2!), B, C, D) + T1 = promote_type(eltype(B), eltype(C)) + T2 = promote_type(eltype(B), eltype(D)) + return ( (T1, size(B, 1), size(C, 2)), + (T2, size(B, 1), size(D, 2)) ) +end + +@no_escape begin + A1b, A2b = WithAlloc.@withalloc mymul2!(B, C, D) + @show A1 ≈ A1b, A2 ≈ A2b # true, true +end