diff --git a/src/compiler/emit.jl b/src/compiler/emit.jl index 1c82a44f1..e1563ecfb 100644 --- a/src/compiler/emit.jl +++ b/src/compiler/emit.jl @@ -34,27 +34,100 @@ xtuple(xs...) = xcall(:tuple, xs...) concrete(T::DataType) = T concrete(::Type{Type{T}}) where T = typeof(T) -concrete(T) = Any +concrete(@nospecialize _) = Any runonce(b) = b.id in (1, length(b.ir.blocks)) +# TODO use a more efficient algorithm such as Johnson (1975) +# https://epubs.siam.org/doi/abs/10.1137/0204007 +self_reaching(cfg, bid, visited = BitSet()) = reaches(cfg, bid, bid, visited) +function reaches(cfg, from, to, visited) + for succ in cfg[from] + if succ === to + return true + elseif succ ∉ visited + push!(visited, succ) + if reaches(cfg, succ, to, visited) + return true + end + end + end + return false +end + function forward_stacks!(adj, F) - stks, recs = [], [] + stks, recs = Tuple{Int, Alpha, Bool}[], Variable[] pr = adj.primal - for b in blocks(pr), α in alphauses(block(adj.adjoint, b.id)) - if runonce(b) - push!(recs, Variable(α)) - else - stk = pushfirst!(pr, xstack(Any)) - push!(recs, stk) - push!(b, xcall(Zygote, :_push!, stk, Variable(α))) + blks = blocks(pr) + last_block = length(blks) + cfg = IRTools.CFG(pr) + cfgᵀ = cfg' + doms = IRTools.dominators(cfg) + + reaching_visited = BitSet() + in_loop = map(1:last_block) do b + empty!(reaching_visited) + self_reaching(cfg, b, reaching_visited) + end + alphavars = Dict{Alpha, Variable}() + alpha_blocks = [α => b.id for b in blks for α in alphauses(block(adj.adjoint, b.id))] + for b in Iterators.reverse(blks) + filter!(alpha_blocks) do (α, bid) + if b.id in doms[bid] + # If a block dominates this block, α is guaranteed to be present here + αvar = Variable(α) + for br in branches(b) + map!(a -> a === α ? αvar : a, br.args, br.args) + end + push!(recs, b.id === last_block ? αvar : alphavars[α]) + push!(stks, (bid, α, false)) + elseif in_loop[bid] + # This block is in a loop, so we're forced to insert stacks + # Note: all alphas in loops will have stacks after the first iteration + stk = pushfirst!(pr, xstack(Any)) + push!(recs, stk) + push!(block(pr, bid), xcall(Zygote, :_push!, stk, Variable(α))) + push!(stks, (bid, α, true)) + else + # Fallback case, propagate alpha back through the CFG + argvar = nothing + if b.id > 1 + # Need to make sure all predecessors have a branch to add arguments to + IRTools.explicitbranch!(b) + argvar = argument!(b, insert=false) + end + if b.id === last_block + # This alpha has been threaded all the way through to the exit block + alphavars[α] = argvar + end + for br in branches(b) + map!(a -> a === α ? argvar : a, br.args, br.args) + end + for pred in cfgᵀ[b.id] + pred >= b.id && continue # TODO is this needed? + pred_branches = branches(block(pr, pred)) + idx = findfirst(br -> br.block === b.id, pred_branches) + if idx === nothing + throw(error("Predecessor $pred of block $(b.id) has no branch to $(b.id)")) + end + branch_here = pred_branches[idx] + push!(branch_here.args, α) + end + # We're not done with this alpha yet, revisit in predecessors + return true + end + return false + end + # Prune any alphas that don't exist on this path through the CFG + for br in branches(b) + map!(a -> a isa Alpha ? nothing : a, br.args, br.args) end - push!(stks, (b.id, alpha(α))) end - args = arguments(pr)[3:end] + @assert isempty(alpha_blocks) + rec = push!(pr, xtuple(recs...)) + # Pullback{F,Any} reduces specialisation P = length(pr.blocks) == 1 ? Pullback{F} : Pullback{F,Any} - # P = Pullback{F,Any} # reduce specialisation rec = push!(pr, Expr(:call, P, rec)) ret = xtuple(pr.blocks[end].branches[end].args[1], rec) ret = push!(pr, ret) @@ -62,22 +135,29 @@ function forward_stacks!(adj, F) return pr, stks end +# Helps constrain pullback function type in the backwards pass +# If we had the type, we could make this a PiNode +notnothing(::Nothing) = error() +notnothing(x) = x + function reverse_stacks!(adj, stks) ir = adj.adjoint - entry = blocks(ir)[end] + blcks = blocks(ir) + entry = blcks[end] self = argument!(entry, at = 1) - t = pushfirst!(blocks(ir)[end], xcall(:getfield, self, QuoteNode(:t))) - repl = Dict() - runonce(b) = b.id in (1, length(ir.blocks)) - for b in blocks(ir) - for (i, (b′, α)) in enumerate(stks) + t = pushfirst!(entry, xcall(:getfield, self, QuoteNode(:t))) + repl = Dict{Alpha,Variable}() + for b in blcks + for (i, (b′, α, use_stack)) in enumerate(stks) b.id == b′ || continue - if runonce(b) - val = insertafter!(ir, t, xcall(:getindex, t, i)) - else - stk = push!(entry, xcall(:getindex, t, i)) - stk = push!(entry, xcall(Zygote, :Stack, stk)) + # i.e. recs[i] from forward_stacks! + val = insertafter!(ir, t, xcall(:getindex, t, i)) + if use_stack + stk = push!(entry, xcall(Zygote, :Stack, val)) val = pushfirst!(b, xcall(:pop!, stk)) + elseif !runonce(b) + # The first and last blocks always run, so this check is redundant there + val = pushfirst!(b, xcall(Zygote, :notnothing, val)) end repl[α] = val end @@ -87,6 +167,7 @@ end function stacks!(adj, T) forw, stks = forward_stacks!(adj, T) + IRTools.domorder!(forw) back = reverse_stacks!(adj, stks) permute!(back, length(back.blocks):-1:1) IRTools.domorder!(back) @@ -97,6 +178,7 @@ varargs(m::Method, n) = m.isva ? n - m.nargs + 1 : nothing function _generate_pullback_via_decomposition(T) (m = meta(T)) === nothing && return + # Core.println("decomp: ", T) va = varargs(m.method, length(T.parameters)) forw, back = stacks!(Adjoint(IR(m), varargs = va, normalise = false), T) m, forw, back diff --git a/src/compiler/interface2.jl b/src/compiler/interface2.jl index bf3692a30..5b24d99c0 100644 --- a/src/compiler/interface2.jl +++ b/src/compiler/interface2.jl @@ -33,10 +33,10 @@ end meta, forw, _ = g argnames!(meta, Symbol("#self#"), :ctx, :f, :args) forw = varargs!(meta, forw, 3) - # IRTools.verify(forw) + # verify(forw) forw = slots!(pis!(inlineable!(forw))) # be ready to swap to using chainrule if one is declared - cr_edge != nothing && edge!(meta, cr_edge) + cr_edge !== nothing && edge!(meta, cr_edge) return update!(meta.code, forw) end @@ -53,7 +53,7 @@ end end meta, _, back = g argnames!(meta, Symbol("#self#"), :Δ) - # IRTools.verify(back) + # verify(back) back = slots!(inlineable!(back)) return update!(meta.code, back) end diff --git a/src/lib/lib.jl b/src/lib/lib.jl index f154ecd2a..7f717591b 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -69,6 +69,9 @@ function accum_global(cx::Context, ref, x̄) return end +# Needed for nested AD +@nograd accum_global + unwrap(x) = x @adjoint unwrap(x) = unwrap(x), x̄ -> (accum_param(__context__, x, x̄),) diff --git a/test/compiler.jl b/test/compiler.jl index bc37d271e..4b342a36e 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -1,5 +1,6 @@ using Zygote, Test using Zygote: pullback, @adjoint +using IRTools macro test_inferred(ex) :(let res = nothing @@ -30,11 +31,11 @@ y, back = pullback(badly, 2) @test_throws Exception back(1) bt = try back(1) catch e stacktrace(catch_backtrace()) end -@test trace_contains(bt, nothing, "compiler.jl", 20) +@test trace_contains(bt, nothing, "compiler.jl", 21) if VERSION >= v"1.6-" - @test_broken trace_contains(bt, :badly, "compiler.jl", 24) + @test_broken trace_contains(bt, :badly, "compiler.jl", 25) else - @test trace_contains(bt, :badly, "compiler.jl", 24) + @test trace_contains(bt, :badly, "compiler.jl", 25) end # Type inference checks @@ -58,10 +59,9 @@ y, back = @test_inferred pullback(f, 5) y, back = @test_inferred pullback(Core._apply, +, (1, 2, 3)) @test_inferred back(1) -# TODO fix bcast inference -# bcast(x) = x .* 5 -# y, back = @test_inferred pullback(bcast, [1,2,3]) -# @test_inferred back([1,1,1]) +bcast(x) = x .* 5 +y, back = @test_inferred pullback(bcast, [1,2,3]) +@test_inferred back([1,1,1]) foo = let a = 4 x -> x*a @@ -91,6 +91,43 @@ struct Funky y end +@testset "stack elision" begin + function stackfree(T) + _, forw = Zygote._generate_pullback_via_decomposition(T) + for b in IRTools.blocks(forw) + bb = IRTools.BasicBlock(b) + for stmt in bb.stmts + expr = stmt.expr + expr.head == :call && expr.args[1:2] == [Zygote, :_push!] && return false + end + end + return true + end + + function knockoff_pow(x, n) + n == 0 && return 1 + n == 1 && return x + n == 2 && return x * x + n == 3 && return x * x * x + return x ^ n + end + + function roundabout_trig(x, fancy_sin, fancy_cos, fancy_tan) + if fancy_tan + s = fancy_sin ? inv(csc(x)) : sin(x) + c = fancy_cos ? inv(sec(x)) : cos(x) + s += 0 + c *= 1 + return s / c + else + return tan(x) + end + end + + @test stackfree(Tuple{typeof(knockoff_pow), Int, Int}) + @test stackfree(Tuple{typeof(roundabout_trig), Float64, Bool, Bool, Bool}) +end + @testset "issue #851" begin f = Funky(1, 1); function Base.getproperty(f::Funky, i::Symbol) diff --git a/test/runtests.jl b/test/runtests.jl index 17ebb3997..8434d8011 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,7 @@ using Zygote, Test using Zygote: gradient, ZygoteRuleConfig using CUDA using CUDA: has_cuda +using LinearAlgebra @testset "all" begin # Overall testset ensures it keeps running after failure diff --git a/test/utils.jl b/test/utils.jl index b6d6ed018..3c2a98e4c 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,4 +1,3 @@ -using LinearAlgebra using ForwardDiff using Zygote: hessian_dual, hessian_reverse