Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Utilize ChainRulesCore thunks #966

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
5 changes: 3 additions & 2 deletions src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ module Zygote
using LinearAlgebra, Statistics
using LinearAlgebra: copytri!, AbstractTriangular

import ZygoteRules
import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback,
literal_getproperty, literal_getfield
literal_getproperty, literal_getfield, unthunk_tangent

using ChainRulesCore
using ChainRules: ChainRules, rrule, unthunk, canonicalize
using ChainRules: ChainRules, AbstractThunk, rrule, unthunk, canonicalize
using IRTools
using MacroTools, Requires
using MacroTools: @forward
Expand Down
12 changes: 11 additions & 1 deletion src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# TODO: Move some of this to ZygoteRules, or move the `unthunk_tangent` methods for Tuple and
# NamedTuple from ZygoteRules here?
# Generic function; the methods below turn possibly-thunked tangents into plain values.
function unthunk_tangent end
# A thunk is forced with `unthunk`, and the result is passed through `wrap_chainrules_output`.
@inline unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x))
# `NTuple`s of numbers and arrays of numbers are returned unchanged.
@inline unthunk_tangent(x::NTuple{N,<:Number}) where N = x
@inline unthunk_tangent(x::AbstractArray{<:Number,N}) where N = x
# Any other array is processed element-wise; `map` builds a new array for the result.
@inline unthunk_tangent(x::AbstractArray) = map(unthunk_tangent, x)
# Unthunk every value stored in an `IdDict` and return the result as a new `IdDict`.
# Keys are carried over untouched: an `IdDict` looks entries up by object identity (`===`),
# and running a key through `unthunk_tangent` could replace it with a fresh copy (the generic
# `AbstractArray` method rebuilds the array via `map`), after which a lookup with the
# original key object would no longer find the entry.
unthunk_tangent(d::IdDict) = IdDict([k => unthunk_tangent(v) for (k, v) in d])
oschulz marked this conversation as resolved.
Show resolved Hide resolved
@non_differentiable unthunk_tangent(::IdDict)


# Rule configuration that carries a context of type `CTX <: AContext`. The supertype
# parameters (`HasReverseMode`, `NoForwardsMode`) are fixed by the declaration below.
struct ZygoteRuleConfig{CTX<:AContext} <: RuleConfig{Union{HasReverseMode,NoForwardsMode}}
# The wrapped context.
context::CTX
end
Expand Down Expand Up @@ -102,7 +113,6 @@ is_kwfunc(k, ::Type{<:NamedTuple}, f, args...) = k===Core.kwftype(f)
Convert `x` from the differentials types ChainRules uses to the format Zygote uses internally.
"""
@inline wrap_chainrules_output(x) = x
@inline wrap_chainrules_output(x::AbstractThunk) = wrap_chainrules_output(unthunk(x)) # For now we are just not going to deal with thunks
oschulz marked this conversation as resolved.
Show resolved Hide resolved
@inline wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x)
# Zygote convention: when every partial is an AbstractZero (e.g. for a multi-input function), collapse them into a single `nothing`.
@inline wrap_chainrules_output(x::Tuple{Vararg{ChainRules.AbstractZero}}) = nothing
Expand Down
10 changes: 8 additions & 2 deletions src/compiler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,13 @@ end
_pullback(f, args...) = _pullback(Context(), f, args...)

tailmemaybe(::Nothing) = nothing
tailmemaybe(x::Tuple) = Base.tail(x)
tailmemaybe(x::Tuple) = unthunk_tangent(Base.tail(x))

# Unthunking is essentially an identity operation on a lazy value, but
# `@adjoint unthunk_tangent(x) = unthunk_tangent(x), ȳ -> (ȳ,)` is not enough to make
# nested AD work, so define an explicit adjoint for `tailmemaybe` instead:
@adjoint tailmemaybe(xs::Tuple) = tailmemaybe(xs), x̄s -> ((nothing, x̄s...),)


@inline pullback(f, args...) = pullback(f, Context(), args...)
function pullback(f, cx::AContext, args...)
Expand Down Expand Up @@ -376,7 +382,7 @@ function pullback(f, ps::Params)
cache(cx)[p] = nothing
end
back(Δ)
Grads(cx.cache, ps) # TODO make a copy
Grads(unthunk_tangent(cx.cache), ps) # TODO make a copy
end
end

Expand Down
12 changes: 12 additions & 0 deletions src/compiler/reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,18 @@ using IRTools: IR, Variable, Pipe, xcall, var, prewalk, postwalk,
insertafter!, finish, expand!, prune!, substitute!, substitute,
block, block!, branch!, return!, stmt, meta


# TODO: Temporary, to be removed when ChainRulesCore rrules are required to
oschulz marked this conversation as resolved.
Show resolved Hide resolved
# support thunks as input and all uses of `_adjoint_keepthunks` in
# Zygote have been replaced by rrules:
# Define an adjoint from `ex` by forwarding to `ZygoteRules.gradm(ex, false, true)`.
# NOTE(review): presumably the second argument selects the mutating (`@adjoint!`-style) form
# and the third keeps incoming thunks unevaluated — confirm against `ZygoteRules.gradm`.
macro _adjoint_keepthunks(ex)
ZygoteRules.gradm(ex, false, true)
end
# Same as `@_adjoint_keepthunks`, but passes `true` as the second argument to `gradm`.
macro _adjoint_keepthunks!(ex)
ZygoteRules.gradm(ex, true, true)
end


# `tuple_va` helpers. From the definitions below:
# - with a single trailing argument, that argument is returned as-is;
# - with several, the first is kept and the remainder is processed recursively, with the
#   recursive result splatted into the output tuple;
# - for a `Val{N}` and a `nothing` payload, a tuple of `N` `nothing`s is produced.
# NOTE(review): presumably used to pack gradients for varargs calls — confirm at call sites.
@inline tuple_va(N, xs) = xs
@inline tuple_va(N, x, xs...) = (x, tuple_va(N, xs...)...)
@inline tuple_va(::Val{N}, ::Nothing) where N = ntuple(_ -> nothing, Val(N))
Expand Down
3 changes: 2 additions & 1 deletion src/lib/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArr
Base.reducedim_initarray(A, region, nothing, Union{Nothing,eltype(A)})
end

function unbroadcast(x::AbstractArray, x̄)
function unbroadcast(x::AbstractArray, maybethunked_x̄)
x̄ = unthunk_tangent(maybethunked_x̄)
oschulz marked this conversation as resolved.
Show resolved Hide resolved
N = ndims(x̄)
if length(x) == length(x̄)
_project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors
Expand Down
44 changes: 22 additions & 22 deletions src/lib/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ function accum(x::RefValue, y::RefValue)
end

# Core functions
@adjoint deepcopy(x) = deepcopy(x), ȳ -> (ȳ,)
@_adjoint_keepthunks deepcopy(x) = deepcopy(x), ȳ -> (ȳ,)

@adjoint (::Type{V})(x...) where V<:Val = V(x...), _ -> nothing
@_adjoint_keepthunks (::Type{V})(x...) where V<:Val = V(x...), _ -> nothing

@adjoint ifelse(cond::Bool, t, f) =
@_adjoint_keepthunks ifelse(cond::Bool, t, f) =
ifelse(cond, t, f),
Δ -> cond ? (nothing, Δ, zero(Δ)) : (nothing, zero(Δ), Δ)

@adjoint Base.typeassert(x, T) = Base.typeassert(x, T), Δ -> (Δ, nothing)
@_adjoint_keepthunks Base.typeassert(x, T) = Base.typeassert(x, T), Δ -> (Δ, nothing)

accum_param(::Context{false}, _, Δ) = Δ
@generated function accum_param(cx::Context, x, Δ)
Expand All @@ -70,11 +70,11 @@ end

unwrap(x) = x

@adjoint unwrap(x) = unwrap(x), x̄ -> (accum_param(__context__, x, x̄),)
@_adjoint_keepthunks unwrap(x) = unwrap(x), x̄ -> (accum_param(__context__, x, x̄),)

unwrap(ref, x) = x

@adjoint unwrap(ref, x) = unwrap(x), function (x̄)
@_adjoint_keepthunks unwrap(ref, x) = unwrap(x), function (x̄)
accum_global(__context__, ref, x̄)
(accum_param(__context__, x, x̄),)
end
Expand All @@ -88,7 +88,7 @@ function global_set(ref, val)
end
end

@adjoint! function global_set(ref, x)
@_adjoint_keepthunks! function global_set(ref, x)
global_set(ref, x), function (x̄)
gs = cache(__context__)
x̄ = accum(get(gs, ref, nothing), x̄)
Expand All @@ -101,9 +101,9 @@ end

using Base: tail

@adjoint tuple(xs...) = xs, identity
@_adjoint_keepthunks tuple(xs...) = xs, identity

@adjoint function literal_getindex(xs::NTuple{N,Any}, ::Val{i}) where {N,i}
@_adjoint_keepthunks function literal_getindex(xs::NTuple{N,Any}, ::Val{i}) where {N,i}
val = xs[i]
function back(Δ)
accum_param(__context__, val, Δ) === nothing && return
Expand All @@ -112,7 +112,7 @@ using Base: tail
val, back
end

@adjoint function getindex(xs::NTuple{N,Any}, i::Integer) where N
@_adjoint_keepthunks function getindex(xs::NTuple{N,Any}, i::Integer) where N
val = xs[i]
function back(Δ)
accum_param(__context__, val, Δ) === nothing && return
Expand All @@ -121,10 +121,10 @@ end
return val, back
end

@adjoint getindex(xs::NTuple{N,Any}, r::AbstractUnitRange) where N =
@_adjoint_keepthunks getindex(xs::NTuple{N,Any}, r::AbstractUnitRange) where N =
(xs[r], Δ -> (ntuple(j -> j in r ? Δ[findfirst(isequal(j), r)] : nothing, Val(N)), nothing))

@adjoint function getindex(xs::NTuple{N,Any}, r::AbstractVector) where N
@_adjoint_keepthunks function getindex(xs::NTuple{N,Any}, r::AbstractVector) where N
val = xs[r]
function back(Δ)
dxs = ntuple(Val(length(xs))) do x
Expand Down Expand Up @@ -155,18 +155,18 @@ function _pullback(cx::AContext, ::typeof(literal_indexed_iterate), xs::Tuple, :
end

# Needed for iteration lowering
@adjoint Core.getfield(xs::NTuple{N,Any}, i::Int) where N =
@_adjoint_keepthunks Core.getfield(xs::NTuple{N,Any}, i::Int) where N =
(xs[i], Δ -> (ntuple(j -> i == j ? Δ : nothing, Val(N)), nothing))

@adjoint Core.getfield(xs::NamedTuple{K,<:NTuple{N,Any}}, i::Int) where {K,N} =
@_adjoint_keepthunks Core.getfield(xs::NamedTuple{K,<:NTuple{N,Any}}, i::Int) where {K,N} =
(xs[i], Δ -> (NamedTuple{K}(ntuple(j -> i == j ? Δ : nothing, Val(N))), nothing))

@adjoint function Base.first(xs::Tuple)
@_adjoint_keepthunks function Base.first(xs::Tuple)
drest = map(_->nothing, tail(xs))
first(xs), Δ -> ((Δ, drest...),)
end

@adjoint Base.tail(xs::Tuple) = tail(xs), x̄s -> ((nothing, x̄s...),)
@_adjoint_keepthunks Base.tail(xs::Tuple) = tail(xs), x̄s -> ((nothing, x̄s...),)

_empty(x) = length(x)
_empty(x::Union{Tuple,NamedTuple}) = map(_->nothing, x)
Expand All @@ -188,7 +188,7 @@ end

unapply(t, xs) = _unapply(t, xs)[1]

@adjoint! function Core._apply(f, args...)
@_adjoint_keepthunks! function Core._apply(f, args...)
y, back = Core._apply(_pullback, (__context__, f), args...)
st = map(_empty, args)
y, function (Δ)
Expand All @@ -199,7 +199,7 @@ unapply(t, xs) = _unapply(t, xs)[1]
end

if VERSION >= v"1.4.0-DEV.304"
@adjoint! function Core._apply_iterate(::typeof(iterate), f, args...)
@_adjoint_keepthunks! function Core._apply_iterate(::typeof(iterate), f, args...)
y, back = Core._apply(_pullback, (__context__, f), args...)
st = map(_empty, args)
y, function (Δ)
Expand All @@ -225,7 +225,7 @@ end
@generated pair(::Val{k}, v, _=nothing) where k = :($k = v,)
@generated pair(::Val{k}, v, ::NamedTuple{keys}) where {k,keys} = k isa Int ? :($(getfield(keys, k)) = v,) : :($k = v,)

@adjoint function literal_getfield(x, ::Val{f}) where f
@_adjoint_keepthunks function literal_getfield(x, ::Val{f}) where f
val = getfield(x, f)
function back(Δ)
accum_param(__context__, val, Δ) === nothing && return
Expand Down Expand Up @@ -273,7 +273,7 @@ function grad_mut(cx::Context, x)
end
end

@adjoint! function setfield!(x, f, val)
@_adjoint_keepthunks! function setfield!(x, f, val)
y = setfield!(x, f, val)
g = grad_mut(__context__, x)
y, function (_)
Expand All @@ -289,13 +289,13 @@ end

Jnew{T}(g) where T = Jnew{T,typeof(g)}(g)

@adjoint! function __new__(T, args...)
@_adjoint_keepthunks! function __new__(T, args...)
x = __new__(T, args...)
g = !ismutabletype(T) || fieldcount(T) == 0 ? nothing : grad_mut(__context__, x)
x, Jnew{T,typeof(g),false}(g)
end

@adjoint! function __splatnew__(T, args)
@_adjoint_keepthunks! function __splatnew__(T, args)
x = __splatnew__(T, args)
g = !ismutabletype(T) || fieldcount(T) == 0 ? nothing : grad_mut(__context__, x)
x, Jnew{T,typeof(g),true}(g)
Expand Down