diff --git a/Project.toml b/Project.toml index 8d7f6fe85..bb1dfe177 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.38.0" +version = "1.39.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index ac712e965..6bea6e06c 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -200,3 +200,44 @@ function rrule(::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{3}) cube_pullback(dy) = (NoTangent(), NoTangent(), ProjectTo(x)(3 * x2 * dy), NoTangent()) return x2 * x, cube_pullback end + +##### +##### `map` +##### + +# Ideally reverse mode should always iterate in reverse order. For `map` and broadcasting +# this may matter with a stateful `f`, but in general their order isn't guaranteed anyway, +# so it's unclear how much effort should be spent on that. But `map` on Tuples normally +# gets unrolled, so perhaps it does guarantee order, and reversing it should be cheap too. + +function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, xs::Tuple...) where {F} + length_y = minimum(length, xs) + hobbits = ntuple(length_y) do i + args = getindex.(xs, i) + rrule_via_ad(config, f, args...) + end + y = map(first, hobbits) + num_xs = Val(length(xs)) + paddings = map(x -> ntuple(Returns(NoTangent()), (length(x) - length_y)), xs) + all(isempty, paddings) || @error """map(f, xs::Tuple...) does not allow mistmatched lengths! + But its `rrule` does; when JuliaLang/julia #42216 is fixed this warning should be removed.""" + function map_pullback(dy_raw) + dy = unthunk(dy_raw) + # We want to call the pullbacks in `rrule_via_ad` in reverse sequence to the forward pass: + backevals = ntuple(length_y) do i + rev_i = length_y - i + 1 + last(hobbits[rev_i])(dy[rev_i]) + end |> reverse + # This df doesn't infer, could test Base.issingletontype(F), but it's not the only inference problem. + df = ProjectTo(f)(sum(first, backevals)) + # Now unzip that. Because `map` like `zip` should when any `x` stops, some `dx`s may need padding. + # Although in fact, `map(+, (1,2), (3,4,5))` is an error... https://github.com/JuliaLang/julia/issues/42216 + dxs = ntuple(num_xs) do k + dx_short = map(bv -> bv[k+1], backevals) + ProjectTo(xs[k])((dx_short..., paddings[k]...)) # ProjectTo makes the Tangent for us + end + return (NoTangent(), df, dxs...) + end + map_back(dy::AbstractZero) = (NoTangent(), NoTangent(), ntuple(Returns(NoTangent()), num_xs)...) + return y, map_pullback +end diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 1c842feab..a83f72cc7 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -62,6 +62,17 @@ end ##### `sum(f, x)` ##### +function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(sum), f::F, xs::Tuple) where {F} + fxs, unmap = rrule(config, map, f, xs) + y, unsum = rrule(config, sum, fxs) + function sum_pullback_f(dy) + _, dfxs = unsum(dy) + _, df, dxs = unmap(dfxs) + (NoTangent(), df, dxs) + end + y, sum_pullback_f +end + function rrule( config::RuleConfig{>:HasReverseMode}, ::typeof(sum), diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 77dff1827..a65881747 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -216,4 +216,16 @@ @test frule(NoRules, 1.0) === nothing @test rrule(NoRules, 1.0) === nothing end + + @testset "map(f, ::Tuple...)" begin + test_rrule(map, identity, (1.0, 2.0), check_inferred=false) + test_rrule(map, +, (1.0, 2.0), (3.0, 4.0), check_inferred=false) + test_rrule(map, make_two_vec, (4.0, 5.0 + 6im), check_inferred=false) + test_rrule(map, Multiplier(rand() + im), Tuple(rand(3)), check_inferred=false) + + if try map(+, (1,), (2,3)); true catch e; false end + # True when https://github.com/JuliaLang/julia/issues/42216 has been fixed + test_rrule(map, Multiplier(4.5), (6.7, 8.9), (0.1, 0.2, 0.3), check_inferred=false) + end + end end diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index fdbda0a4a..23c4d6da5 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -67,6 +67,10 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig() end end # sum abs2 + @testset "sum(f, xs::Tuple)" begin + test_rrule(sum, sqrt, Tuple(rand(3)), check_inferred=false) + end + @testset "sum(f, xs)" begin # This calls back into AD test_rrule(sum, abs, [-4.0, 2.0, 2.0])