Skip to content

Commit

Permalink
minimal change foldl -> mapfoldl_impl
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Jul 18, 2022
1 parent 7faaf5d commit e4b20da
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 31 deletions.
23 changes: 13 additions & 10 deletions src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -427,11 +427,12 @@ end
# to carry intermediate results along creates arrays of tuples which could be avoided; using a
# loop can be a few times faster. Note also that it does not return a gradient for `init`.

# Maybe that's a problem. Let's move the rule to `mapfoldr_impl(f, op, init, itr)`, where it's easier?

function rrule(
config::RuleConfig{>:HasReverseMode}, ::typeof(foldl), op::G, x::Union{AbstractArray, Tuple};
init=_InitialValue()
config::RuleConfig{>:HasReverseMode}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), op::G, init, x::Union{AbstractArray, Tuple};
) where {G}
list, start = if init === _InitialValue()
list, start = if init === _INIT
_drop1(x), first(x)
else
# Case with init keyword is simpler to understand first!
Expand All @@ -455,11 +456,12 @@ function rrule(
end
dop = sum(first, trio)
dx = map(last, _reverse1(trio))
if init === _InitialValue()
if init === _INIT
# `hobbits` is one short
dx = _vcat1(trio[end][2], dx)
end
return (NoTangent(), dop, project(_reshape1(dx, axe)))
d_init = @not_implemented "gradient for foldl does not at present include init, sorry"
return (NoTangent(), NoTangent(), dop, d_init, project(_reshape1(dx, axe)))
end
return y, unfoldl
end
Expand All @@ -484,7 +486,8 @@ _reverse1(x::Tuple) = reverse(x)
_drop1(x::Tuple) = Base.tail(x)
_zip2(x::Tuple{Vararg{Any,N}}, y::Tuple{Vararg{Any,N}}) where N = ntuple(i -> (x[i],y[i]), N)

struct _InitialValue end # Old versions don't have `Base._InitialValue`
# struct _InitialValue end # Old versions don't have `Base._InitialValue`
const _INIT = VERSION >= v"1.5" ? Base._InitialValue() : NamedTuple()

_vcat1(x, ys::AbstractVector) = vcat(x, ys)
_vcat1(x::AbstractArray, ys::AbstractVector) = vcat([x], ys)
Expand All @@ -505,15 +508,15 @@ _no_tuple_tangent(dx) = dx

function rrule(
config::RuleConfig{>:HasReverseMode}, ::typeof(accumulate), op::G, x::Union{AbstractArray, Tuple};
init=_InitialValue(), dims=nothing
init=_INIT, dims=nothing
) where {G}
isnothing(dims) || dims == 1 && x isa Base.AbstractVecOrTuple || throw(
"accumulate(op, x; dims) is not currently supported by ChainRules, sorry"
# It's not supported by AD either, so no point calling back, and no regression:
# gradient(x -> sum(accumulate(/, x, dims=1)), rand(3,4))
# ERROR: Mutating arrays is not supported
)
list, start = if init === _InitialValue()
list, start = if init === _INIT
_drop1(x), first(x)
else
x, init
Expand All @@ -522,7 +525,7 @@ function rrule(
c, back = rrule_via_ad(config, op, a, b)
end
y = map(first, hobbits)
if init === _InitialValue()
if init === _INIT
# `hobbits` is one short, and first one doesn't invoke `op`
y = _vcat1(first(x), y)
end
Expand All @@ -543,7 +546,7 @@ function rrule(
end
dop = sum(first, trio)
dx = map(last, _reverse1(trio))
if init == _InitialValue()
if init == _INIT
# `hobbits` is one short, and the first one is weird
dx = _vcat1(trio[end][2] + dy_plain[1], dx)
end
Expand Down
54 changes: 33 additions & 21 deletions test/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,60 +214,72 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
end # prod

@testset "foldl(f, ::Array)" begin
# `foldl(op, itr; init)` goes to `mapfoldr_impl(identity, op, init, itr)`. The rule is
# now attached there, as this is the simplest way to handle `init` keyword.
@eval using Base: mapfoldl_impl
@eval _INIT = VERSION >= v"1.5" ? Base._InitialValue() : NamedTuple()

# Simple
y1, b1 = rrule(CFG, foldl, *, [1, 2, 3]; init=1)
y1, b1 = rrule(CFG, mapfoldl_impl, identity, *, 1, [1, 2, 3])
@test y1 == 6
b1(7) == (NoTangent(), NoTangent(), [42, 21, 14])
@test b1(7)[1:3] == (NoTangent(), NoTangent(), NoTangent())
@test b1(7)[4] isa ChainRulesCore.NotImplemented
@test b1(7)[5] == [42, 21, 14]

y2, b2 = rrule(CFG, foldl, *, [1 2; 0 4]) # without init, needs vcat
y2, b2 = rrule(CFG, mapfoldl_impl, identity, *, _INIT, [1 2; 0 4]) # without init, needs vcat
@test y2 == 0
b2(8) == (NoTangent(), NoTangent(), [0 0; 64 0]) # matrix, needs reshape
@test b2(8)[5] == [0 0; 64 0] # matrix, needs reshape

# Test execution order
c5 = Counter()
y5, b5 = rrule(CFG, foldl, c5, [5, 7, 11])
y5, b5 = rrule(CFG, mapfoldl_impl, identity, c5, _INIT, [5, 7, 11])
@test c5 == Counter(2)
@test y5 == ((5 + 7)*1 + 11)*2 == foldl(Counter(), [5, 7, 11])
@test b5(1) == (NoTangent(), NoTangent(), [12*32, 12*42, 22])
@test b5(1)[5] == [12*32, 12*42, 22]
@test c5 == Counter(42)

c6 = Counter()
y6, b6 = rrule(CFG, foldl, c6, [5, 7, 11], init=3)
y6, b6 = rrule(CFG, mapfoldl_impl, identity, c6, 3, [5, 7, 11])
@test c6 == Counter(3)
@test y6 == (((3 + 5)*1 + 7)*2 + 11)*3 == foldl(Counter(), [5, 7, 11], init=3)
@test b6(1) == (NoTangent(), NoTangent(), [63*33*13, 43*13, 23])
@test b6(1)[5] == [63*33*13, 43*13, 23]
@test c6 == Counter(63)

# Test gradient of function
y7, b7 = rrule(CFG, foldl, Multiplier(3), [5, 7, 11])
y7, b7 = rrule(CFG, mapfoldl_impl, identity, Multiplier(3), _INIT, [5, 7, 11])
@test y7 == foldl((x,y)->x*y*3, [5, 7, 11])
@test b7(1) == (NoTangent(), Tangent{Multiplier{Int}}(x = 2310,), [693, 495, 315])
b7_1 = b7(1)
@test b7_1[3] == Tangent{Multiplier{Int}}(x = 2310,)
@test b7_1[5] == [693, 495, 315]

y8, b8 = rrule(CFG, foldl, Multiplier(13), [5, 7, 11], init=3)
y8, b8 = rrule(CFG, mapfoldl_impl, identity, Multiplier(13), 3, [5, 7, 11])
@test y8 == 2_537_535 == foldl((x,y)->x*y*13, [5, 7, 11], init=3)
@test b8(1) == (NoTangent(), Tangent{Multiplier{Int}}(x = 585585,), [507507, 362505, 230685])
b8_1 = b8(1)
@test b8_1[3] == Tangent{Multiplier{Int}}(x = 585585,)
@test b8_1[5] == [507507, 362505, 230685]
# To find these numbers:
# ForwardDiff.derivative(z -> foldl((x,y)->x*y*z, [5,7,11], init=3), 13)
# ForwardDiff.gradient(z -> foldl((x,y)->x*y*13, z, init=3), [5,7,11]) |> string

# Finite differencing
test_rrule(foldl, /, 1 .+ rand(3,4))
test_rrule(foldl, *, rand(ComplexF64,3,4); fkwargs=(; init=rand(ComplexF64)))
test_rrule(foldl, +, rand(ComplexF64,7); fkwargs=(; init=rand(ComplexF64)))
test_rrule(foldl, max, rand(3); fkwargs=(; init=999))
test_rrule(mapfoldl_impl, identity, /, _INIT, 1 .+ rand(3,4))
test_rrule(mapfoldl_impl, identity, *, rand(ComplexF64), rand(ComplexF64,3,4))
test_rrule(mapfoldl_impl, identity, +, rand(ComplexF64), rand(ComplexF64,7))
test_rrule(mapfoldl_impl, identity, max, 999, rand(3))
end
@testset "foldl(f, ::Tuple)" begin
y1, b1 = rrule(CFG, foldl, *, (1,2,3); init=1)
y1, b1 = rrule(CFG, mapfoldl_impl, identity, *, 1, (1,2,3))
@test y1 == 6
b1(7) == (NoTangent(), NoTangent(), Tangent{NTuple{3,Int}}(42, 21, 14))
@test b1(7)[5] == Tangent{NTuple{3,Int}}(42, 21, 14)

y2, b2 = rrule(CFG, foldl, *, (1, 2, 0, 4))
y2, b2 = rrule(CFG, mapfoldl_impl, identity, *, _INIT, (1, 2, 0, 4))
@test y2 == 0
b2(8) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(0, 0, 64, 0))
@test b2(8)[5] == Tangent{NTuple{4,Int}}(0, 0, 64, 0)

# Finite differencing
test_rrule(foldl, /, Tuple(1 .+ rand(5)))
test_rrule(foldl, *, Tuple(rand(ComplexF64, 5)))
test_rrule(mapfoldl_impl, identity, /, _INIT, Tuple(1 .+ rand(5)))
test_rrule(mapfoldl_impl, identity, *, _INIT, Tuple(rand(ComplexF64, 5)))
end
end

Expand Down

0 comments on commit e4b20da

Please sign in to comment.