Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Don't use stacks for simple control flow #78

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft

Commits on Feb 24, 2019

  1. WIP: Don't use stacks for simple control flow

    Right now Zygote inserts stacks whenever it needs to use an ssa value
    not defined in the first basic block. This is of course unnecessary.
    The condition for needing stacks is that the basic block that defines
    it is self-reachable (i.e. in a loop). Otherwise, we can simply insert
    phi nodes to thread the desired SSA value through to the exit block
    (we don't need to do anything in the adjoint, since the reversal of
    the CFG ensures dominance). Removing stacks allows for both more
    efficient code generation and enables higher order auto-diff (since
    we use control flow in Zygote, but can't handle differentiating code
    that contains stacks). The headline example is something like the following:
    
    ```
    function foo(b, x)
        if b
            sin(x)
        else
            cos(x)
        end
    end
    ```
    
    Then looking at `@code_typed derivative(x->foo(true, x), 1.0)`, we get:
    
    Before:
    ```
    CodeInfo(
    1 ── %1  = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Int8,1}, svec(Any, Int64), :(:ccall), 2, Array{Int8,1}, 0, 0))::Array{Int8,1}
    │    %2  = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Any,1}, svec(Any, Int64), :(:ccall), 2, Array{Any,1}, 0, 0))::Array{Any,1}
    │    %3  = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Any,1}, svec(Any, Int64), :(:ccall), 2, Array{Any,1}, 0, 0))::Array{Any,1}
    │    %4  = Base.sin::typeof(sin)
    │          invoke %4(_3::Float64)::Float64
    │    %6  = %new(##334#335{Float64}, x)::##334#335{Float64}
    │    %7  = %new(##758#back#336{##334#335{Float64}}, %6)::##758#back#336{##334#335{Float64}}
    [snip]
    23 ─ %52 = invoke %47(1::Int8)::Tuple{Nothing,Nothing,Any}
    │    %53 = Base.getfield(%52, 3, true)::Any
    └───       goto #24
    24 ─       return %53
    ) => Any
    ```
    
    After:
    ```
    CodeInfo(
    1 ─ %1 = Base.sin::typeof(sin)
    │        invoke %1(_3::Float64)::Float64
    │   %3 = Core.Intrinsics.not_int(true)::Bool
    └──      goto #3 if not %3
    2 ─      invoke Zygote.notnothing(nothing::Nothing)::Union{}
    └──      $(Expr(:unreachable))::Union{}
    3 ┄ %7 = invoke Zygote.cos(_3::Float64)::Float64
    │   %8 = Base.mul_float(1.0, %7)::Float64
    └──      goto #4
    4 ─      goto #5
    5 ─      goto #6
    6 ─      goto #7
    7 ─      return %8
    ) => Float64
    ```
    
    Which is essentially perfect (there's a bit of junk left over, but LLVM
    can take care of that. The only thing that doesn't get removed is the
    useless invocation of `sin`, but that's a separate and known issue).
    Keno committed Feb 24, 2019
    Configuration menu
    Copy the full SHA
    b909b9b View commit details
    Browse the repository at this point in the history