Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Attach rule to mapfoldl_impl not foldl #569

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 126 additions & 75 deletions src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -417,137 +417,188 @@ end
end

#####
##### `foldl`
##### `mapfoldl(f, g, ::Tuple)`
#####

using Base: mapfoldl_impl

# For tuples there should be no harm in handling `map` first.
# This will also catch `mapreduce`.

function rrule(
cfg::RuleConfig{>:HasReverseMode}, ::typeof(mapfoldl_impl), f::F, op::G, init, x::Tuple;
) where {F,G}
y, backmap = rrule(cfg, map, f, x)
z, backred = rrule(cfg, Base.mapfoldl_impl, identity, op, init, y)
function mapfoldl_pullback_tuple(dz)
_, _, dop, dinit, dy = backred(dz)
_, df, dx = backmap(dy)
return (NoTangent(), df, dop, dinit, dx)
end
return z, mapfoldl_pullback_tuple
end

#####
##### `foldl(f, ::Tuple)`
#####

# `foldl` guarantees to execute `f` in order, left to right. So it makes sense even when
# this `f` is stateful, in which case the gradient must be calculated in the reverse order.
# this `f` is stateful, in which case the gradient must be calculated in the reverse order.

# The implementation aims to be efficient for both tuples and arrays, although using accumulate
# to carry intermediate results along creates arrays of tuples which could be avoided; using a
# loop can be a few times faster. Note also that it does not return a gradient for `init`.
# The rule is attached to `Base.mapfoldl_impl` because this gets the `init` keyword as an argument,
# which is handled below. For tuples, `reduce` also comes here.

function rrule(
config::RuleConfig{>:HasReverseMode}, ::typeof(foldl), op::G, x::Union{AbstractArray, Tuple};
init=_InitialValue()
config::RuleConfig{>:HasReverseMode},
::typeof(Base.mapfoldl_impl),
::typeof(identity),
op::G,
init::Base._InitialValue,
x::Tuple;
) where {G}
list, start = if init === _InitialValue()
_drop1(x), first(x)
else
# Case with init keyword is simpler to understand first!
_reshape1(x, :), init # (vec is for Julia 1.0, accumulate is fussy)
end
hobbits = accumulate(list; init=(start, nothing)) do (a,_), b
hobbits = accumulate(Base.tail(x); init=(first(x), nothing)) do (a, _), b
# Here `a` is what we would normally cary forward, and `_` ignores
# the previous iteration's pullback function (needed later),
# while `b` is the fresh input from `list` as usual.
c, back = rrule_via_ad(config, op, a, b) # LHS is just documentation here!
c, back = rrule_via_ad(config, op, a, b)
# We don't really need to store every `c`, last one is `foldl` output.
# (The name, BTW, is because "there and back again" is the subtitle of Tolkien's book.)
end
y = first(last(hobbits))
axe = axes(x)
project = ProjectTo(x)
function unfoldl(dy)
trio = accumulate(_reverse1(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back)
function foldl_pullback_tuple(dy)
trio = accumulate(reverse(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back)
ds, da, db = back(dc)
# Don't need to store every `da`, need one for the next iteration + maybe last
# Don't need to store every `da`, need one for the next iteration + the last.
end
dop = sum(first, trio)
dx = map(last, _reverse1(trio))
if init === _InitialValue()
# `hobbits` is one short
dx = _vcat1(trio[end][2], dx)
end
return (NoTangent(), dop, project(_reshape1(dx, axe)))
dx = (trio[end][2], reverse(map(last, trio))...)
return (NoTangent(), NoTangent(), ProjectTo(op)(dop), NoTangent(), project(dx))
end
return y, unfoldl
return y, foldl_pullback_tuple
end

function rrule(
config::RuleConfig{>:HasReverseMode},
::typeof(Base.mapfoldl_impl),
::typeof(identity),
op::G,
init,
x::Tuple;
) where {G}
# Trivial case handled here to avoid ambiguities (and necc. because of Base.tail below)
foldl_pullback_empty(dy) = (NoTangent(), NoTangent(), NoTangent(), dy, NoTangent())
isempty(x) && return init, foldl_pullback_empty

# Treat `init` by simply appending it to the `x`:
y, back = rrule(config, Base.mapfoldl_impl, identity, op, Base._InitialValue(), (init, x...))
project_x = ProjectTo(x)
project_in = ProjectTo(init)
function foldl_pullback_tuple_init(dy)
_, _, dop, _, dxplus = back(dy)
return (NoTangent(), NoTangent(), dop, project_in(first(dxplus)), project_x(Base.tail(dxplus)))
end
return y, foldl_pullback_tuple_init
end

#####
##### Iterator-or-Tuple functions
##### `foldl(f, ::Array)`
#####

# This zoo of underscore functions helps `foldl` & `accumulate` handle both tuples and arrays,
# and also provides some alternatives for versions of Julia where iterators weren't supported.
# Inspired by `Base._reverse`, used in defn of `foldr`.
# The implementation was originally for both tuples and arrays, although using accumulate
# to carry intermediate results along creates arrays of tuples which could be avoided.
# Using a loop can be a few times faster, this should be replaced:
# https://github.com/FluxML/Zygote.jl/issues/644#issuecomment-628762305

# To support 2nd derivatives, some may need their own gradient rules. And _drop1 should perhaps
# be replaced by _peel1 like Iterators.peel
# Note also that it does not return a gradient for `init`, now marked `@not_implemented`.

_reverse1(x) = Iterators.reverse(x)
_drop1(x) = Iterators.drop(x, 1)
_zip2(x, y) = zip(x, y) # for `accumulate`, below

_reverse1(x::Tuple) = reverse(x)
_drop1(x::Tuple) = Base.tail(x)
_zip2(x::Tuple{Vararg{Any,N}}, y::Tuple{Vararg{Any,N}}) where N = ntuple(i -> (x[i],y[i]), N)

struct _InitialValue end # Old versions don't have `Base._InitialValue`
function rrule(
config::RuleConfig{>:HasReverseMode}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), op::G, init, x::Union{AbstractArray, Tuple};
) where {G}
start, list = if init === Base._InitialValue()
Iterators.peel(x)
else
# Case with init keyword is simpler to understand first!
init, x
end
hobbits = accumulate(list; init=(start, nothing)) do (a, _), b
c, back = rrule_via_ad(config, op, a, b)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any way to not capture the accumulated outputs (cs) in the pullback? It seems easy enough for tuples using map, but I'm unsure if the extra allocation would be welcomed for arrays.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you can write a for loop like this: FluxML/Zygote.jl#644 (comment) . IMO this array method should probably be replaced, but not today.

Carrying c by updating a variable from inside accumulate was very slow, IIRC it hits the closure issue.

Copy link
Contributor

@ToucheSir ToucheSir Aug 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually tried that in FluxML/Flux.jl#2003. The main challenges are nested differentiation and handling the case when typeof(x |> f) != typeof(x |> f |> f) (you must widen, which means preallocating an array is impossible without return_type shenanigans).

So, assuming type inference cooperates, the accumulate approach seems no less promising. Would there be any objections to a post-processing step like the following which allows the GC to clean up intermediate outputs before the pullback?

# ... y = first(last(hobbits))
# If outputs are (recursively) allocated inline, we're less worried about memory overhead
# and the GC can't free them individually anyways. 
if !isbitstype(eltype(hobbits))
  hobbits = map(((_, pb)) -> (nothing, pb),  hobbits)
end
# axe = axes(x) ...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes the mutation has these problems, but was much quicker, maybe it can be used when safe.

The intention with writing foldl in terms of accumulate was to allow for 2nd derivatives, but not sure this actually works right now.

Re saving memory, we can add something like unzip_accumulate(f, xs; init) = StructArrays.components(StructArray(Iterators.accumulate(f, xs; init))) to free the bits we don't need anymore.

julia> accumulate([1,2,3], init=(4,5)) do prev, this
         this .+ prev
       end
3-element Vector{Tuple{Int64, Int64}}:
 (5, 6)
 (7, 8)
 (10, 11)

julia> unzip_accumulate([1,2,3], init=(4,5)) do prev, this
         this .+ prev
       end
([5, 7, 10], [6, 8, 11])

But this PR would like to kick the can down the road on such improvements.

(And others -- it returns @not_implemented for accumulate's init, might be easy to do better, but tired of adding tests... at least it's no longer wrong.)

end
y = first(last(hobbits))
axe = axes(x)
project = ProjectTo(x)
function unfoldl(dy)
trio = accumulate(Iterators.reverse(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back)
ds, da, db = back(dc)
end
dop = sum(first, trio)
dx = map(last, Iterators.reverse(trio))
if init === Base._InitialValue() # `hobbits` is one short
dx = _vcat1(trio[end][2], dx)
end
d_init = @not_implemented "gradient for foldl does not at present include init, sorry"
Comment on lines +536 to +539
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if init === Base._InitialValue() # `hobbits` is one short
dx = _vcat1(trio[end][2], dx)
end
d_init = @not_implemented "gradient for foldl does not at present include init, sorry"
if init === Base._InitialValue() # `hobbits` is one short
dx = _vcat1(trio[end][2], dx)
d_init = NoTangent()
else
d_init = trio[end][2]
end

Would this work?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably!

It's been a while, but my memory is that I mostly got tired of making tests, so thought I'd leave that for later.

return (NoTangent(), NoTangent(), dop, d_init, project(reshape(dx, axe)))
end
return y, unfoldl
end

_vcat1(x, ys::AbstractVector) = vcat(x, ys)
_vcat1(x::AbstractArray, ys::AbstractVector) = vcat([x], ys)
_vcat1(x, ys::Tuple) = (x, ys...)

_reshape1(x::AbstractArray, axe) = reshape(x, axe)
_reshape1(x::Tuple, axe) = x

_no_tuple_tangent(dx::Tangent) = ChainRulesCore.backing(dx)
_no_tuple_tangent(dx) = dx


#####
##### `accumulate`
#####

# Like `foldl` this by definition works in order, so it makes sense to allow stateful `f`.

# Also like `foldl`, the version with a keyword `init` can't easily be given a gradient.
# Move it down to: `_accumulate!(op, B, A::AbstractVector, dims::Nothing, init::Nothing)`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But we don't at present support getting back a gradient for init, except if it's nothing and then it doesn't matter.
So we might as well put this on accumulate?

I am a little uncomfortable putting rules on mutating fuctions.
Though perhaps this one is safe as we are always fully overwriting y and never reading it?
A comment to that effect would be good if so.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The intention was to move both to functions with positional init, and this mutating function was the best option I could find in Base's dispatch.

Then it could have a gradient for init. I didn't get around to writing one, mostly got tired of fighting tests. But at least step 1 makes step 2 easier, it would be a small PR. And for now it returns @not_implemented which is better than a silent zero, in theory at least.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do think this is unsafe the same way that fill! is unsafe. Except that in practice, I think it's much less likely to cause problems, as anyone who gets to accumulate! has probably been trained out of hoping that mutation will work.

The originally envisaged use case was that the 2nd derivative of foldl would involve this accumulate gradient. But I don't recall whether I ever checked that.


function rrule(
config::RuleConfig{>:HasReverseMode}, ::typeof(accumulate), op::G, x::Union{AbstractArray, Tuple};
init=_InitialValue(), dims=nothing
config::RuleConfig{>:HasReverseMode},
::typeof(Base._accumulate!),
op::G, y::AbstractVector,
x::AbstractVector,
dims::Nothing,
init,
) where {G}
isnothing(dims) || dims == 1 && x isa Base.AbstractVecOrTuple || throw(
"accumulate(op, x; dims) is not currently supported by ChainRules, sorry"
# It's not supported by AD either, so no point calling back, and no regression:
# gradient(x -> sum(accumulate(/, x, dims=1)), rand(3,4))
# ERROR: Mutating arrays is not supported
)
list, start = if init === _InitialValue()
_drop1(x), first(x)

start, list = if init === nothing
Iterators.peel(x)
else
x, init
something(init), x
end
hobbits = accumulate(list; init = (start, nothing)) do (a, _), b
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
hobbits = accumulate(list; init = (start, nothing)) do (a, _), b
# "The Hobbit", or "There and Back Again"
hobbits = accumulate(list; init = (start, nothing)) do (a, _), b

c, back = rrule_via_ad(config, op, a, b)
end
y = map(first, hobbits)
if init === _InitialValue()
if init === nothing
# `hobbits` is one short, and first one doesn't invoke `op`
y = _vcat1(first(x), y)
y[1] = first(x)
map!(first, @view(y[2:end]), hobbits)
else
map!(first, y, hobbits)
end
axe = axes(x)
project = ProjectTo(x)
function decumulate(dy)
dy_plain = _no_tuple_tangent(unthunk(dy))
rev_list = if init === _InitialValue()
# Here we rely on `zip` to stop early. Begin explicit with _reverse1(_drop1(...))
# gets "no method matching iterate(::Base.Iterators.Reverse{Base.Iterators.Drop{Array{"
_zip2(_reverse1(hobbits), _reverse1(dy_plain))
else
_zip2(_reverse1(hobbits), _reverse1(dy_plain))
end
dy_plain = unthunk(dy)
rev_list = zip(Iterators.reverse(hobbits), Iterators.reverse(dy_plain))
# Here we rely on `zip` to stop early when init === nothing. Begin explicit with Iterators.reverse(Iterators.drop(..., 1))
# gets "no method matching iterate(::Base.Iterators.Reverse{Base.Iterators.Drop{Array{"
trio = accumulate(rev_list; init=(0, ZeroTangent(), 0)) do (_, dc, _), ((_, back), dz)
ds, da, db = back(dc + dz)
# Don't need to store every 'da', but need for next iteration, and the last one.
end
dop = sum(first, trio)
dx = map(last, _reverse1(trio))
if init == _InitialValue()
dx = map(last, Iterators.reverse(trio))
if init == nothing
# `hobbits` is one short, and the first one is weird
dx = _vcat1(trio[end][2] + dy_plain[1], dx)
end
return (NoTangent(), dop, project(_reshape1(dx, axe)))
dy = @not_implemented "no gradient for `B` in `accumulate!(f, B, A)`, the rule intends to support `accumulate` only"
d_init_not = @not_implemented "gradient for accumulate does not at present include init, sorry"
d_init = init === nothing ? NoTangent() : Tangent{typeof(init)}(; value = d_init_not)
return (NoTangent(), dop, dy, project(reshape(dx, axe)), NoTangent(), d_init)
end
return _reshape1(y, axe), decumulate
return reshape(y, axe), decumulate
end
Loading
Loading