From 143f9297dc312b72fd60588d9e3ca9ca4beb4680 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Sun, 3 Apr 2022 14:23:56 -0700 Subject: [PATCH] Elide stack generation outside of non-looping control flow Co-authored-by: Keno Fischer --- src/compiler/emit.jl | 127 ++++++++++++++++++++++++++++++------- src/compiler/interface2.jl | 6 +- src/lib/lib.jl | 6 ++ test/compiler.jl | 62 ++++++++++++++---- test/features.jl | 2 +- test/runtests.jl | 1 + test/utils.jl | 1 - 7 files changed, 164 insertions(+), 41 deletions(-) diff --git a/src/compiler/emit.jl b/src/compiler/emit.jl index 1c82a44f1..f79f1ae33 100644 --- a/src/compiler/emit.jl +++ b/src/compiler/emit.jl @@ -34,27 +34,100 @@ xtuple(xs...) = xcall(:tuple, xs...) concrete(T::DataType) = T concrete(::Type{Type{T}}) where T = typeof(T) -concrete(T) = Any +concrete(@nospecialize _) = Any runonce(b) = b.id in (1, length(b.ir.blocks)) +# TODO use a more efficient algorithm such as Johnson (1975) +# https://epubs.siam.org/doi/abs/10.1137/0204007 +self_reaching(cfg, bid, visited = BitSet()) = reaches(cfg, bid, bid, visited) +function reaches(cfg, from, to, visited) + for succ in cfg[from] + if succ === to + return true + elseif succ ∉ visited + push!(visited, succ) + if reaches(cfg, succ, to, visited) + return true + end + end + end + return false +end + function forward_stacks!(adj, F) - stks, recs = [], [] + stks, recs = Tuple{Int, Alpha, Bool}[], Variable[] pr = adj.primal - for b in blocks(pr), α in alphauses(block(adj.adjoint, b.id)) - if runonce(b) - push!(recs, Variable(α)) - else - stk = pushfirst!(pr, xstack(Any)) - push!(recs, stk) - push!(b, xcall(Zygote, :_push!, stk, Variable(α))) + blks = blocks(pr) + last_block = length(blks) + cfg = IRTools.CFG(pr) + cfgᵀ = cfg' + doms = IRTools.dominators(cfg) + + reaching_visited = BitSet() + in_loop = map(1:last_block) do b + empty!(reaching_visited) + self_reaching(cfg, b, reaching_visited) + end + alphavars = Dict{Alpha, Variable}() + alpha_blocks = [α => b.id for b in blks for α in alphauses(block(adj.adjoint, b.id))] + for b in Iterators.reverse(blks) + filter!(alpha_blocks) do (α, bid) + if b.id in doms[bid] + # If a block dominates this block, α is guaranteed to be present here + αvar = Variable(α) + for br in branches(b) + map!(a -> a === α ? αvar : a, br.args, br.args) + end + push!(recs, b.id === last_block ? αvar : alphavars[α]) + push!(stks, (bid, α, false)) + elseif in_loop[bid] + # This block is in a loop, so we're forced to insert stacks + # Note: all alphas in loops will have stacks after the first iteration + stk = pushfirst!(pr, xstack(Any)) + push!(recs, stk) + push!(block(pr, bid), xcall(Zygote, :_push!, stk, Variable(α))) + push!(stks, (bid, α, true)) + else + # Fallback case, propagate alpha back through the CFG + argvar = nothing + if b.id > 1 + # Need to make sure all predecessors have a branch to add arguments to + IRTools.explicitbranch!(b) + argvar = argument!(b, insert=false) + end + if b.id === last_block + # This alpha has been threaded all the way through to the exit block + alphavars[α] = argvar + end + for br in branches(b) + map!(a -> a === α ? argvar : a, br.args, br.args) + end + for pred in cfgᵀ[b.id] + pred >= b.id && continue # TODO is this needed? + pred_branches = branches(block(pr, pred)) + idx = findfirst(br -> br.block === b.id, pred_branches) + if idx === nothing + throw(error("Predecessor $pred of block $(b.id) has no branch to $(b.id)")) + end + branch_here = pred_branches[idx] + push!(branch_here.args, α) + end + # We're not done with this alpha yet, revisit in predecessors + return true + end + return false + end + # Prune any alphas that don't exist on this path through the CFG + for br in branches(b) + map!(a -> a isa Alpha ? nothing : a, br.args, br.args) end - push!(stks, (b.id, alpha(α))) end - args = arguments(pr)[3:end] + @assert isempty(alpha_blocks) + rec = push!(pr, xtuple(recs...)) + # Pullback{F,Any} reduces specialisation P = length(pr.blocks) == 1 ? Pullback{F} : Pullback{F,Any} - # P = Pullback{F,Any} # reduce specialisation rec = push!(pr, Expr(:call, P, rec)) ret = xtuple(pr.blocks[end].branches[end].args[1], rec) ret = push!(pr, ret) @@ -62,22 +135,29 @@ function forward_stacks!(adj, F) return pr, stks end +# Helps constrain pullback function type in the backwards pass +# If we had the type, we could make this a PiNode +notnothing(::Nothing) = error() +notnothing(x) = x + function reverse_stacks!(adj, stks) ir = adj.adjoint - entry = blocks(ir)[end] + blcks = blocks(ir) + entry = blcks[end] self = argument!(entry, at = 1) - t = pushfirst!(blocks(ir)[end], xcall(:getfield, self, QuoteNode(:t))) - repl = Dict() - runonce(b) = b.id in (1, length(ir.blocks)) - for b in blocks(ir) - for (i, (b′, α)) in enumerate(stks) + t = pushfirst!(entry, xcall(:getfield, self, QuoteNode(:t))) + repl = Dict{Alpha,Variable}() + for b in blcks + for (i, (b′, α, use_stack)) in enumerate(stks) b.id == b′ || continue - if runonce(b) - val = insertafter!(ir, t, xcall(:getindex, t, i)) - else - stk = push!(entry, xcall(:getindex, t, i)) - stk = push!(entry, xcall(Zygote, :Stack, stk)) + # i.e. recs[i] from forward_stacks! + val = insertafter!(ir, t, xcall(:getindex, t, i)) + if use_stack + stk = push!(entry, xcall(Zygote, :Stack, val)) val = pushfirst!(b, xcall(:pop!, stk)) + elseif !runonce(b) + # The first and last blocks always run, so this check is redundant there + val = pushfirst!(b, xcall(Zygote, :notnothing, val)) end repl[α] = val end @@ -87,6 +167,7 @@ end function stacks!(adj, T) forw, stks = forward_stacks!(adj, T) + IRTools.domorder!(forw) back = reverse_stacks!(adj, stks) permute!(back, length(back.blocks):-1:1) IRTools.domorder!(back) diff --git a/src/compiler/interface2.jl b/src/compiler/interface2.jl index bf3692a30..5b24d99c0 100644 --- a/src/compiler/interface2.jl +++ b/src/compiler/interface2.jl @@ -33,10 +33,10 @@ end meta, forw, _ = g argnames!(meta, Symbol("#self#"), :ctx, :f, :args) forw = varargs!(meta, forw, 3) - # IRTools.verify(forw) + # verify(forw) forw = slots!(pis!(inlineable!(forw))) # be ready to swap to using chainrule if one is declared - cr_edge != nothing && edge!(meta, cr_edge) + cr_edge !== nothing && edge!(meta, cr_edge) return update!(meta.code, forw) end @@ -53,7 +53,7 @@ end end meta, _, back = g argnames!(meta, Symbol("#self#"), :Δ) - # IRTools.verify(back) + # verify(back) back = slots!(inlineable!(back)) return update!(meta.code, back) end diff --git a/src/lib/lib.jl b/src/lib/lib.jl index 52a734809..d61d7dabe 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -68,6 +68,12 @@ function accum_global(cx::Context, ref, x̄) return end +# Needed for nested AD +function _pullback(::typeof(accum_global), cx::Context, ref, x̄) + accum_global_pullback(_) = nothing + return accum_global(cx, ref, x̄), accum_global_pullback +end + unwrap(x) = x @adjoint unwrap(x) = unwrap(x), x̄ -> (accum_param(__context__, x, x̄),) diff --git a/test/compiler.jl b/test/compiler.jl index c5ddf1f38..dc1271ae3 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -1,4 +1,4 @@ -using Zygote, Test +using Zygote, IRTools, Test using Zygote: pullback, @adjoint, Context macro test_inferred(ex) @@ -18,24 +18,22 @@ end bad(x) = x @adjoint bad(x) = x, Δ -> error("bad") +bad_adjoint_line = @__LINE__ - 1 # source location of above function badly(x) x = x + 1 x = bad(x) return x end +bad_pullback_line = @__LINE__ - 3 # should match source location of Pullback y, back = pullback(badly, 2) @test y == 3 @test_throws Exception back(1) -bt = try back(1) catch e stacktrace(catch_backtrace()) end -@test trace_contains(bt, nothing, "compiler.jl", 20) -if VERSION >= v"1.6-" - @test_broken trace_contains(bt, :badly, "compiler.jl", 24) -else - @test trace_contains(bt, :badly, "compiler.jl", 24) -end +bt = try back(1) catch e stacktrace(catch_backtrace()) end +@test trace_contains(bt, nothing, "compiler.jl", bad_adjoint_line) +@test trace_contains(bt, nothing, "compiler.jl", bad_pullback_line) # Type inference checks @@ -58,10 +56,9 @@ y, back = @test_inferred pullback(f, 5) y, back = @test_inferred pullback(Core._apply, +, (1, 2, 3)) @test_inferred back(1) -# TODO fix bcast inference -# bcast(x) = x .* 5 -# y, back = @test_inferred pullback(bcast, [1,2,3]) -# @test_inferred back([1,1,1]) +bcast(x) = x .* 5 +y, back = @test_inferred pullback(bcast, [1,2,3]) +@test_inferred back([1,1,1]) foo = let a = 4 x -> x*a @@ -91,6 +88,45 @@ struct Funky y end +@testset "stack elision" begin + function isstackfree(T) + _, forw, back = Zygote._generate_pullback_via_decomposition(T) + for (_, stmt) in forw + expr = stmt.expr + expr.head == :call && first(expr.args) == GlobalRef(Zygote, :_push!) && return false + end + for (_, stmt) in back + expr = stmt.expr + expr.head == :call && first(expr.args) == GlobalRef(Zygote, :Stack) && return false + end + return true + end + + function knockoff_pow(x, n) + n == 0 && return 1 + n == 1 && return x + n == 2 && return x * x + n == 3 && return x * x * x + return x ^ n + end + + function roundabout_trig(x, fancy_sin, fancy_cos, fancy_tan) + if fancy_tan + s = fancy_sin ? inv(csc(x)) : sin(x) + c = fancy_cos ? inv(sec(x)) : cos(x) + s += 0 + c *= 1 + return s / c + else + return tan(x) + end + end + + @test !isstackfree(Tuple{typeof(pow), Int, Int}) + @test isstackfree(Tuple{typeof(knockoff_pow), Int, Int}) + @test isstackfree(Tuple{typeof(roundabout_trig), Float64, Bool, Bool, Bool}) +end + @testset "issue #851" begin f = Funky(1, 1); function Base.getproperty(f::Funky, i::Symbol) @@ -128,7 +164,7 @@ end d_two = Zygote.pullback(two_svds, X)[2](Δoutput) d_one = Zygote.pullback(one_svd, X)[2](Δoutput) @test d_one == d_two -end +end # this test fails if adjoint for literal_getproperty is added # https://github.com/FluxML/Zygote.jl/issues/922#issuecomment-804128905 diff --git a/test/features.jl b/test/features.jl index cdd513fdb..9c71a5273 100644 --- a/test/features.jl +++ b/test/features.jl @@ -396,7 +396,7 @@ end == (2,) global_param = 3 @testset "Global Params" begin - cx = Zygote.Context() + cx = Zygote.Context{true}() # only makes sense with implicit params y, back = Zygote._pullback(cx, x -> x*global_param, 2) @test y == 6 @test back(1) == (nothing, 3) diff --git a/test/runtests.jl b/test/runtests.jl index 565ad182f..7a326c53f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,7 @@ using Zygote, Test using Zygote: gradient, ZygoteRuleConfig using CUDA using CUDA: has_cuda +using LinearAlgebra @testset "all" begin # Overall testset ensures it keeps running after failure diff --git a/test/utils.jl b/test/utils.jl index 40b2e85b7..cb11437cf 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,4 +1,3 @@ -using LinearAlgebra using ForwardDiff using Zygote: hessian_dual, hessian_reverse