Skip to content

Commit

Permalink
change accumulate -> _accumulate!
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Jul 18, 2022
1 parent e4b20da commit 87b4ea4
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 27 deletions.
36 changes: 20 additions & 16 deletions src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -506,34 +506,35 @@ _no_tuple_tangent(dx) = dx

# Like `foldl` this by definition works in order, so it makes sense to allow stateful `f`.

# Also like `foldl`, the version with a keyword `init` can't easily be given a gradient.
# Move it down to: `_accumulate!(op, B, A::AbstractVector, dims::Nothing, init::Nothing)`

function rrule(
config::RuleConfig{>:HasReverseMode}, ::typeof(accumulate), op::G, x::Union{AbstractArray, Tuple};
init=_INIT, dims=nothing
config::RuleConfig{>:HasReverseMode}, ::typeof(Base._accumulate!), op::G, y, x::AbstractVector, dims::Nothing, init,
) where {G}
isnothing(dims) || dims == 1 && x isa Base.AbstractVecOrTuple || throw(
"accumulate(op, x; dims) is not currently supported by ChainRules, sorry"
# It's not supported by AD either, so no point calling back, and no regression:
# gradient(x -> sum(accumulate(/, x, dims=1)), rand(3,4))
# ERROR: Mutating arrays is not supported
)
list, start = if init === _INIT

list, start = if init === nothing
_drop1(x), first(x)
else
x, init
x, something(init)
end
hobbits = accumulate(list; init = (start, nothing)) do (a, _), b
c, back = rrule_via_ad(config, op, a, b)
end
y = map(first, hobbits)
if init === _INIT
# y = map(first, hobbits)
if init === nothing
# `hobbits` is one short, and first one doesn't invoke `op`
y = _vcat1(first(x), y)
# y = _vcat1(first(x), y)
y[1] = first(x)
map!(first, @view(y[2:end]), hobbits)
else
map!(first, y, hobbits)
end
axe = axes(x)
project = ProjectTo(x)
function decumulate(dy)
dy_plain = _no_tuple_tangent(unthunk(dy))
rev_list = if init === _InitialValue()
rev_list = if init === nothing
# Here we rely on `zip` to stop early. Begin explicit with _reverse1(_drop1(...))
# gets "no method matching iterate(::Base.Iterators.Reverse{Base.Iterators.Drop{Array{"
_zip2(_reverse1(hobbits), _reverse1(dy_plain))
Expand All @@ -546,11 +547,14 @@ function rrule(
end
dop = sum(first, trio)
dx = map(last, _reverse1(trio))
if init == _INIT
if init == nothing
# `hobbits` is one short, and the first one is weird
dx = _vcat1(trio[end][2] + dy_plain[1], dx)
end
return (NoTangent(), dop, project(_reshape1(dx, axe)))
dy = @not_implemented "no gradient for `B` in `accumulate!(f, B, A)`, the rule intends to support `accumulate` only"
d_init_not = @not_implemented "gradient for accumulate does not at present include init, sorry"
d_init = init === nothing ? NoTangent() : Tangent{typeof(init)}(; value = d_init_not)
return (NoTangent(), dop, dy, project(_reshape1(dx, axe)), NoTangent(), d_init)
end
return _reshape1(y, axe), decumulate
end
48 changes: 37 additions & 11 deletions test/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
# `foldl(op, itr; init)` goes to `mapfoldr_impl(identity, op, init, itr)`. The rule is
# now attached there, as this is the simplest way to handle `init` keyword.
@eval using Base: mapfoldl_impl
@eval _INIT = VERSION >= v"1.5" ? Base._InitialValue() : NamedTuple()
_INIT = VERSION >= v"1.5" ? Base._InitialValue() : NamedTuple()

# Simple
y1, b1 = rrule(CFG, mapfoldl_impl, identity, *, 1, [1, 2, 3])
Expand Down Expand Up @@ -337,36 +337,45 @@ end
end # cumprod

@testset "accumulate(f, ::Array)" begin
# `accumulate(f, A; init)` goes to `_accumulate!(op, B, A, dims::Nothing, init::Nothing)`.
# The rule is now attached there, as this is the simplest way to handle `init` keyword.
@eval using Base: _accumulate!

# Simple
y1, b1 = rrule(CFG, accumulate, *, [1, 2, 3, 4]; init=1)
y1, b1 = rrule(CFG, _accumulate!, *, [0, 0, 0, 0], [1, 2, 3, 4], nothing, Some(1))
@test y1 == [1, 2, 6, 24]
@test b1([1, 1, 1, 1]) == (NoTangent(), NoTangent(), [33, 16, 10, 6])
@test b1([1, 1, 1, 1])[3] isa ChainRulesCore.NotImplemented
@test b1([1, 1, 1, 1])[4] == [33, 16, 10, 6]
@test b1([1, 1, 1, 1])[6] isa Tangent{Some{Int64}}
@test b1([1, 1, 1, 1])[6].value isa ChainRulesCore.NotImplemented

y2, b2 = rrule(CFG, accumulate, /, [1 2; 3 4])
@test y2 accumulate(/, [1 2; 3 4])
@test b2(ones(2, 2))[3] [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6

# Test execution order
c3 = Counter()
y3, b3 = rrule(CFG, accumulate, c3, [5, 7, 11]; init=3)
y3, b3 = rrule(CFG, _accumulate!, c3, [0, 0, 0], [5, 7, 11], nothing, Some(3))
@test c3 == Counter(3)
@test y3 == [8, 30, 123] == accumulate(Counter(), [5, 7, 11]; init=3)
@test b3([1, 1, 1]) == (NoTangent(), NoTangent(), [29169, 602, 23]) # the 23 is clear!
@test b3([1, 1, 1])[4] == [29169, 602, 23] # the 23 is clear!

c4 = Counter()
y4, b4 = rrule(CFG, accumulate, c4, [5, 7, 11])
y4, b4 = rrule(CFG, _accumulate!, c4, [0, 0, 0], [5, 7, 11], nothing, nothing)
@test c4 == Counter(2)
@test y4 == [5, (5+7)*1, ((5+7)*1 + 11)*2] == accumulate(Counter(), [5, 7, 11])
@test b4([1, 1, 1]) == (NoTangent(), NoTangent(), [417, 42*(1 + 12), 22])
@test b4([1, 1, 1])[4] == [417, 42*(1 + 12), 22]

# Test gradient of function
y7, b7 = rrule(CFG, accumulate, Multiplier(3), [5, 7, 11])
y7, b7 = rrule(CFG, _accumulate!, Multiplier(3), [0, 0, 0], [5, 7, 11], nothing, nothing)
@test y7 == accumulate((x,y)->x*y*3, [5, 7, 11])
@test b7([1, 1, 1]) == (NoTangent(), Tangent{Multiplier{Int}}(x = 2345,), [715, 510, 315])
@test b7([1, 1, 1])[2] == Tangent{Multiplier{Int}}(; x = 2345,)
@test b7([1, 1, 1])[4] == [715, 510, 315]

y8, b8 = rrule(CFG, accumulate, Multiplier(13), [5, 7, 11], init=3)
y8, b8 = rrule(CFG, _accumulate!, Multiplier(13), [0, 0, 0], [5, 7, 11], nothing, Some(3))
@test y8 == [195, 17745, 2537535] == accumulate((x,y)->x*y*13, [5, 7, 11], init=3)
@test b8([1, 1, 1]) == (NoTangent(), Tangent{Multiplier{Int}}(x = 588330,), [511095, 365040, 230685])
@test b8([1, 1, 1])[2] == Tangent{Multiplier{Int}}(; x = 588330,)
@test b8([1, 1, 1])[4] == [511095, 365040, 230685]
# To find these numbers:
# ForwardDiff.derivative(z -> sum(accumulate((x,y)->x*y*z, [5,7,11], init=3)), 13)
# ForwardDiff.gradient(z -> sum(accumulate((x,y)->x*y*13, z, init=3)), [5,7,11]) |> string
Expand All @@ -385,5 +394,22 @@ end
# Finite differencing
test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand()))
test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false)

test_rrule(_accumulate!, *, randn(5) NoTangent(), randn(5), nothing, nothing)
test_rrule(_accumulate!, /, randn(5) NoTangent(), randn(5), nothing, Some(1 + rand()))
# if VERSION >= v"1.5"
# test_rrule(accumulate, /, 1 .+ rand(3, 4))
# test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand()))
# end
end
# VERSION >= v"1.5" && @testset "accumulate(f, ::Tuple)" begin
# # Simple
# y1, b1 = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1)
# @test y1 == (1, 2, 6, 24)
# @test b1((1, 1, 1, 1)) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(33, 16, 10, 6))

# # Finite differencing
# test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand()))
# test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false)
# end
end

0 comments on commit 87b4ea4

Please sign in to comment.