From 1b8e121072352d649b7f6c4aaa80bbefb75f721d Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 29 Aug 2022 17:45:26 -0400 Subject: [PATCH] fix https://github.com/JuliaDiff/ChainRules.jl/issues/672 --- src/rulesets/Base/mapreduce.jl | 15 ++++++++++++++- test/rulesets/Base/mapreduce.jl | 4 ++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index fa8c1c576..83ed130c4 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -420,11 +420,13 @@ end ##### `mapfoldl(f, g, ::Tuple)` ##### +using Base: mapfoldl_impl + # For tuples there should be no harm in handling `map` first. # This will also catch `mapreduce`. function rrule( - cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.mapfoldl_impl), f::F, op::G, init, x::Tuple; + cfg::RuleConfig{>:HasReverseMode}, ::typeof(mapfoldl_impl), f::F, op::G, init, x::Tuple; ) where {F,G} y, backmap = rrule(cfg, map, f, x) z, backred = rrule(cfg, Base.mapfoldl_impl, identity, op, init, y) @@ -436,6 +438,11 @@ function rrule( return z, mapfoldl_pullback_tuple end +function rrule(::RuleConfig{>:HasReverseMode}, ::typeof(mapfoldl_impl), f, op, init, x::Tuple{}) + foldl_pullback_empty(dy) = (NoTangent(), NoTangent(), NoTangent(), dy, NoTangent()) + return init, foldl_pullback_empty +end + ##### ##### `foldl(f, ::Tuple)` ##### @@ -495,6 +502,12 @@ function rrule( return y, foldl_pullback_tuple_init end +# Base.tail doesn't work on (), trivial case: +function rrule(::RuleConfig{>:HasReverseMode}, ::typeof(mapfoldl_impl), ::typeof(identity), op, init, x::Tuple{}) + foldl_pullback_empty(dy) = (NoTangent(), NoTangent(), NoTangent(), dy, NoTangent()) + return init, foldl_pullback_empty +end + ##### ##### `foldl(f, ::Array)` ##### diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index 80bc58bc6..bd3031ed5 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -303,6 +303,10 @@ const _INIT = Base._InitialValue() # Finite differencing test_rrule(mapfoldl_impl, identity, /, _INIT, Tuple(1 .+ rand(5))) test_rrule(mapfoldl_impl, identity, *, 1+rand(), Tuple(rand(ComplexF64, 5))) + + # Trivial case + test_rrule(mapfoldl_impl, identity, /, 2pi, ()) + test_rrule(mapfoldl_impl, sqrt, /, 2pi, ()) end @testset "mapfoldl(f, g, ::Tuple)" begin test_rrule(mapfoldl_impl, cbrt, /, _INIT, Tuple(1 .+ rand(5)), check_inferred=false)