Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Free CuArrays in the reverse pass #1340

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ MacroTools = "0.5"
NaNMath = "0.3, 1"
Requires = "1.1"
SpecialFunctions = "1.6, 2"
ZygoteRules = "0.2.1"
ZygoteRules = "0.2.3"
julia = "1.6"

[extras]
Expand Down
2 changes: 1 addition & 1 deletion src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using LinearAlgebra, Statistics
using LinearAlgebra: copytri!, AbstractTriangular

import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback,
literal_getproperty, literal_getfield
literal_getproperty, literal_getfield, maybe_final, @adjoint_final

using ChainRulesCore
using ChainRules: ChainRules, rrule, unthunk, canonicalize
Expand Down
38 changes: 29 additions & 9 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ struct ZygoteRuleConfig{CTX<:AContext} <: RuleConfig{Union{HasReverseMode,NoForw
end
ZygoteRuleConfig() = ZygoteRuleConfig(Context())

@inline only_once(::Type{<:AContext}) = false # can't directly use Context{true,true} as not defined yet

_is_rrule_redispatcher(m::Method) = m.sig == Tuple{typeof(rrule), RuleConfig, Vararg}

Expand Down Expand Up @@ -195,17 +196,26 @@ _project(x::AbstractArray, dx::Tuple) = _project(x, reshape(collect(dx), axes(x)
(project::ProjectTo{AbstractArray})(dx::Tangent) = dx

"""
ZBack{F}(back) <: Function
ZBack{Y,F}(y, back) <: Function

Wrapper for a ChainRules pullback `back`, that causes it to follow Zygote conventions.
(A functor here is used rather than a closure to avoid boxing issues);
Now captures the forward result to call `finalize(y)` when done, if `only_once` says this is safe.
"""
struct ZBack{F} <: Function
struct ZBack{Y,F} <: Function
fwd::Y
back::F
end
@inline (s::ZBack)(dy) = wrap_chainrules_output(s.back(wrap_chainrules_input(dy)))
@inline (s::ZBack{Nothing})(dy) = wrap_chainrules_output(s.back(wrap_chainrules_input(dy)))
@inline function (s::ZBack)(dy)
∇s = wrap_chainrules_output(s.back(wrap_chainrules_input(dy)))
maybe_final(s.fwd)
∇s
end

# `nothing->nothing` can be deleted after https://github.com/FluxML/Zygote.jl/issues/603
# though it might be worth keeping as a performance optimization (benchmarking pending)
@inline (s::ZBack{Nothing})(::Nothing) = nothing
@inline (s::ZBack)(::Nothing) = nothing

"""
Expand All @@ -214,9 +224,11 @@ end
Returns a the (primal) value of `f(args...)` and a pullback, by invoking `ChainRulesCore.rrule(f, args...)`.
The pullback is appropriately wrapped up to follow Zygote conventions.
"""
@inline function chain_rrule(config, f, args...)
# @inline function chain_rrule(config::ZygoteRuleConfig{Context{I,O}}, f::F, args...) where {I,O,F}
@inline function chain_rrule(config::ZygoteRuleConfig{C}, f::F, args...) where {C,F}
y, back = rrule(config, f, args...)
return y, ZBack(back)
free = only_once(C) ? y : nothing
return y, ZBack(free, back)
end


Expand All @@ -226,10 +238,12 @@ end
As per [`chain_rrule`](@ref) but with support for kwargs.
`kwf` should be the kwfunc matching to `f`, and `kwargs` are a `NamedTuple` of keyword arguments.
"""
@inline function chain_rrule_kw(config, kwf, kwargs, f, args...)
# @inline function chain_rrule_kw(config::ZygoteRuleConfig{Context{I,O}}, kwf, kwargs, f::F, args...) where {I,O,F}
@inline function chain_rrule_kw(config::ZygoteRuleConfig{C}, kwf, kwargs, f::F, args...) where {C,F}
y, back = rrule(config, f, args...; kwargs...)
free = only_once(C) ? y : nothing
function kw_zpullback(dy)
dxs = ZBack(back)(dy)
dxs = ZBack(free, back)(dy)
if dxs === nothing # if dxs is nothing, then all partiaols are nothing
# Zygote convention is a single nothing no mather how partials, if all are nothing
return nothing
Expand All @@ -240,7 +254,8 @@ As per [`chain_rrule`](@ref) but with support for kwargs.
return y, kw_zpullback
end

function ChainRulesCore.rrule_via_ad(config::ZygoteRuleConfig, f_args...; kwargs...)
# function ChainRulesCore.rrule_via_ad(config::ZygoteRuleConfig{Context{I,O}}, f_args...; kwargs...) where {I,O}
function ChainRulesCore.rrule_via_ad(config::ZygoteRuleConfig{C}, f_args...; kwargs...) where {C}
# first check whether there is an `rrule` which handles this directly
direcct = rrule(config, f_args...; kwargs...)
direcct === nothing || return direcct
Expand All @@ -255,7 +270,12 @@ function ChainRulesCore.rrule_via_ad(config::ZygoteRuleConfig, f_args...; kwargs
_pullback(config.context, f_args...)
end

ad_pullback(Δ) = zygote2differential(pb(wrap_chainrules_output(Δ)), f_args)
free = only_once(C) ? y : nothing
function ad_pullback(Δ)
∇s = zygote2differential(pb(wrap_chainrules_output(Δ)), f_args)
maybe_final(free)
∇s
end
return y, ad_pullback
end

Expand Down
113 changes: 82 additions & 31 deletions src/compiler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,20 @@ import Base.Broadcast: broadcasted, materialize!
# Internal container used to track accumulated gradients of mutable types (including params).
# Type param I ∈ (true, false) indicates whether implicit params are in use.
# By default, this should be false unless pullback(f, ::Params) is called.
mutable struct Context{I} <: AContext
# Type parameter O ∈ (true, false) indecates whether we know the reverse pass will be
# run at most once (e.g. within gradient), defaults to false (for pullback, and jacobain).
mutable struct Context{I,O} <: AContext
cache::Union{IdDict{Any,Any},Nothing}
end

Context() = Context{false}(nothing)
Context() = Context{false,false}(nothing)
Context{I}(cache=nothing) where {I} = Context{I,false}(cache)
Context{I,O}() where {I,O} = Context{I,O}(nothing)

cache(cx::Context) = cx.cache === nothing ? (cx.cache = IdDict()) : cx.cache

@inline only_once(::Type{<:Context{<:Any,true}}) = true

struct Pullback{S,T}
t::T
end
Expand Down Expand Up @@ -93,7 +99,9 @@ julia> gradient([7, 11], 0, 1) do x, y, d
```
"""
function gradient(f, args...)
y, back = pullback(f, args...)
# Type parameters for Context are implicit=false, once=true
cx = Context{false,true}(nothing)
y, back = pullback(f, cx, args...)
grad = back(sensitivity(y))
isnothing(grad) ? nothing : map(_project, args, grad)
end
Expand All @@ -104,6 +112,21 @@ Base.adjoint(f::Function) = x -> begin # still piracy! avoids projection for le
back(sensitivity(y))[1]
end

# This is inserted into @adjoint_final by ZygoteRules
@inline maybe_final(::Context{<:Any,true}, x) = maybe_final(x)
# The goal is to free CuArrays promptly.
@inline maybe_final(x::DenseArray) = finalize(x)

# Without an @adjoint rule for this, some hessian tests fail:
# Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_finalize_th), Nothing
# And if it in fact finalises, then other 2nd derivative tests fail. So do nothing:
@adjoint maybe_final(x) = nothing, _ -> nothing
@adjoint maybe_final(::Context, x) = nothing, _ -> nothing

# Probably just for testing:
maybe_final(x::Vector) = resize!(x, 0)
maybe_final(x::Array{<:AbstractFloat}) = fill!(x, NaN)

"""
withgradient(f, args...)
withgradient(f, ::Params)
Expand All @@ -129,40 +152,16 @@ julia> res.grad[w]
```
"""
function withgradient(f, args...)
y, back = pullback(f, args...)
# Type parameters for Context are implicit=false, once=true
cx = Context{false,true}()
y, back = pullback(f, cx, args...)
grad = back(sensitivity(y))
results = isnothing(grad) ? map(_ -> nothing, args) : map(_project, args, grad)
(val=y, grad=results)
end

# Param-style wrappers

"""
gradient(() -> loss(), ps::Params) -> Grads

Gradient with implicit parameters. Takes a zero-argument function,
and returns a dictionary-like container, whose keys are arrays `x in ps`.

See also [`withgradient`](@ref) to keep the value `loss()`.

```jldoctest; setup=:(using Zygote)
julia> x = [1 2 3; 4 5 6]; y = [7, 8]; z = [1, 10, 100];

julia> g = gradient(Params([x, y])) do
sum(x .* y .* z')
end
Grads(...)

julia> g[x]
2×3 Matrix{Float64}:
7.0 70.0 700.0
8.0 80.0 800.0

julia> haskey(g, z) # only x and y are parameters
false
```
"""
gradient
# Param-style wrappers

"""
Params([A, B])
Expand Down Expand Up @@ -391,6 +390,58 @@ function pullback(f, ps::Params)
end
end

"""
gradient(() -> loss(), ps::Params) -> Grads

Gradient with implicit parameters. Takes a zero-argument function,
and returns a dictionary-like container, whose keys are arrays `x in ps`.

See also [`withgradient`](@ref) to keep the value `loss()`.

```jldoctest; setup=:(using Zygote)
julia> x = [1 2 3; 4 5 6]; y = [7, 8]; z = [1, 10, 100];

julia> g = gradient(Params([x, y])) do
sum(x .* y .* z')
end
Grads(...)

julia> g[x]
2×3 Matrix{Float64}:
7.0 70.0 700.0
8.0 80.0 800.0

julia> haskey(g, z) # only x and y are parameters
false
```
"""
function gradient(f, ps::Params)
y, back = pullback(f, ps)
back(sensitivity(y))
end

"""
withgradient(f, ps::Params) -> Grads

Returns both the value of the function and the [`gradient`](@ref),
as a named tuple.

```jldoctest; setup=:(using Zygote)
julia> w = [3.0];

julia> res = withgradient(() -> sum(abs2, w), Params([w])) # implicit mode
(val = 9.0, grad = Grads(...))

julia> res.grad[w]
1-element Vector{Float64}:
6.0
```
"""
function withgradient(f, ps::Params)
y, back = pullback(f, ps)
(val=y, grad=back(sensitivity(y)))
end

# Code Reflection

function code_ir(f, T)
Expand Down
2 changes: 2 additions & 0 deletions src/compiler/interface2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ end
chain_rrule_f = :chain_rrule
end

# Here ZygoteRuleConfig{Zygote.Context{false, true}} is passed to chain_rrule

hascr, cr_edge = has_chain_rrule(cr_T)
hascr && return :($chain_rrule_f(ZygoteRuleConfig(ctx), f, args...))

Expand Down
4 changes: 3 additions & 1 deletion src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ for (mapfunc,∇mapfunc) in [(:map,:∇map),(:pmap,:∇pmap)]
ys = map(first, ys_and_backs)
arg_ax = map(_tryaxes, args)
function map_back(Δ)
if Base.issingletontype(F) && length(args) == 1
∇s = if Base.issingletontype(F) && length(args) == 1
Δarg = $mapfunc(((_,pb), δ) -> last_or_nothing(pb(δ)), ys_and_backs, Δ) # No unzip needed
(nothing, Δarg)
elseif Base.issingletontype(F)
Expand All @@ -207,6 +207,8 @@ for (mapfunc,∇mapfunc) in [(:map,:∇map),(:pmap,:∇pmap)]
Δargs = map(_restore, Δf_and_args[2:end], arg_ax)
(Δf, Δargs...)
end
maybe_final(cx, ys_and_backs)
∇s
end
map_back(::Nothing) = nothing
return ys, map_back
Expand Down
Loading