Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Don't use stacks for simple control flow #78

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft

Conversation

Keno
Copy link
Contributor

@Keno Keno commented Feb 24, 2019

Right now Zygote inserts stacks whenever it needs to use an ssa value
not defined in the first basic block. This is of course unnecessary.
The condition for needing stacks is that the basic block that defines
it is self-reachable (i.e. in a loop). Otherwise, we can simply insert
phi nodes to thread the desired SSA value through to the exit block
(we don't need to do anything in the adjoint, since the reversal of
the CFG ensures dominance). Removing stacks allows for both more
efficient code generation and enables higher order auto-diff (since
we use control flow in Zygote, but can't handle differentiating code
that contains stacks). The headline example is something like the following:

function foo(b, x)
    if b
        sin(x)
    else
        cos(x)
    end
end

Then looking at @code_typed derivative(x->foo(true, x), 1.0), we get:

Before:

CodeInfo(
1 ── %1  = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Int8,1}, svec(Any, Int64), :(:ccall), 2, Array{Int8,1}, 0, 0))::Array{Int8,1}
│    %2  = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Any,1}, svec(Any, Int64), :(:ccall), 2, Array{Any,1}, 0, 0))::Array{Any,1}
│    %3  = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Any,1}, svec(Any, Int64), :(:ccall), 2, Array{Any,1}, 0, 0))::Array{Any,1}
│    %4  = Base.sin::typeof(sin)
│          invoke %4(_3::Float64)::Float64
│    %6  = %new(##334#335{Float64}, x)::##334#335{Float64}
│    %7  = %new(##758#back#336{##334#335{Float64}}, %6)::##758#back#336{##334#335{Float64}}
[snip]
23 ─ %52 = invoke %47(1::Int8)::Tuple{Nothing,Nothing,Any}
│    %53 = Base.getfield(%52, 3, true)::Any
└───       goto #24
24 ─       return %53
) => Any

After (assuming some slight improvements to inference precision, that either
need to be enabled upstream or worked around here):

CodeInfo(
1 ─ %1 = Base.sin::typeof(sin)
│        invoke %1(_3::Float64)::Float64
│   %3 = Core.Intrinsics.not_int(true)::Bool
└──      goto #3 if not %3
2 ─      invoke Zygote.notnothing(nothing::Nothing)::Union{}
└──      $(Expr(:unreachable))::Union{}
3 ┄ %7 = invoke Zygote.cos(_3::Float64)::Float64
│   %8 = Base.mul_float(1.0, %7)::Float64
└──      goto #4
4 ─      goto #5
5 ─      goto #6
6 ─      goto #7
7 ─      return %8
) => Float64

Which is essentially perfect (there's a bit of junk left over, but LLVM
can take care of that. The only thing that doesn't get removed is the
useless invocation of sin, but that's a separate and known issue).

Right now Zygote inserts stacks whenever it needs to use an ssa value
not defined in the first basic block. This is of course unnecessary.
The condition for needing stacks is that the basic block that defines
it is self-reachable (i.e. in a loop). Otherwise, we can simply insert
phi nodes to thread the desired SSA value through to the exit block
(we don't need to do anything in the adjoint, since the reversal of
the CFG ensures dominance). Removing stacks allows for both more
efficient code generation and enables higher order auto-diff (since
we use control flow in Zygote, but can't handle differentiating code
that contains stacks). The headline example is something like the following:

```
function foo(b, x)
    if b
        sin(x)
    else
        cos(x)
    end
end
```

Then looking at `@code_typed derivative(x->foo(true, x), 1.0)`, we get:

Before:
```
CodeInfo(
1 ── %1  = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Int8,1}, svec(Any, Int64), :(:ccall), 2, Array{Int8,1}, 0, 0))::Array{Int8,1}
│    %2  = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Any,1}, svec(Any, Int64), :(:ccall), 2, Array{Any,1}, 0, 0))::Array{Any,1}
│    %3  = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Any,1}, svec(Any, Int64), :(:ccall), 2, Array{Any,1}, 0, 0))::Array{Any,1}
│    %4  = Base.sin::typeof(sin)
│          invoke %4(_3::Float64)::Float64
│    %6  = %new(##334#335{Float64}, x)::##334#335{Float64}
│    %7  = %new(##758#back#336{##334#335{Float64}}, %6)::##758#back#336{##334#335{Float64}}
[snip]
23 ─ %52 = invoke %47(1::Int8)::Tuple{Nothing,Nothing,Any}
│    %53 = Base.getfield(%52, 3, true)::Any
└───       goto #24
24 ─       return %53
) => Any
```

After:
```
CodeInfo(
1 ─ %1 = Base.sin::typeof(sin)
│        invoke %1(_3::Float64)::Float64
│   %3 = Core.Intrinsics.not_int(true)::Bool
└──      goto #3 if not %3
2 ─      invoke Zygote.notnothing(nothing::Nothing)::Union{}
└──      $(Expr(:unreachable))::Union{}
3 ┄ %7 = invoke Zygote.cos(_3::Float64)::Float64
│   %8 = Base.mul_float(1.0, %7)::Float64
└──      goto #4
4 ─      goto #5
5 ─      goto #6
6 ─      goto #7
7 ─      return %8
) => Float64
```

Which is essentially perfect (there's a bit of junk left over, but LLVM
can take care of that. The only thing that doesn't get removed is the
useless invocation of `sin`, but that's a separate and known issue).
@ToucheSir
Copy link
Member

With the benefit of a couple years of hindsight, would something like this still be worth pursuing? nested reverse AD and performance of type unstable backwards passes have gained more prominence as issues recently, and my rudimentary understanding is that this PR could address a chunk of both?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants