From a05ebb28cd6379b8221111823835f8172861e248 Mon Sep 17 00:00:00 2001 From: Frames White Date: Mon, 20 May 2024 23:46:27 +0800 Subject: [PATCH] Add rrule for NamedTuple merge --- src/rulesets/Base/base.jl | 22 ++++++++++++++++++++++ test/rulesets/Base/base.jl | 5 +++++ 2 files changed, 27 insertions(+) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 6c66d19ee..951f62922 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -291,3 +291,25 @@ function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(task_local_storage end return y, task_local_storage_pullback end + + +#### +#### merge +#### + +function rrule(::typeof(merge), nt1::NamedTuple{F1}, nt2::NamedTuple{F2}) where {F1, F2} + y = merge(nt1, nt2) + function merge_pullback(dy) + dnt1 = Tangent{typeof(nt1)}(; + (f1 => (f1 in F2 ? ZeroTangent() : getproperty(dy, f1)) for f1 in F1)... + ) + dnt2 = Tangent{typeof(nt2)}(; + (f2 => getproperty(dy, f2) for f2 in F2)... + ) + return (NoTangent(), dnt1, dnt2) + end + merge_pullback(dy::AbstractThunk) = merge_pullback(unthunk(dy)) + merge_pullback(x::AbstractZero) = (NoTangent(), x, x) + + return y, merge_pullback +end diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 25c755f55..ec87640b7 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -257,4 +257,9 @@ end test_rrule(map, Multiplier(4.5), (6.7, 8.9), (0.1, 0.2, 0.3), check_inferred=false) end end + + @testset "merge NamedTuple" begin + test_rrule(merge, (;a=1.0), (;b=2.0), check_inferred=false) + test_rrule(merge, (;a=1.0), (;a=2.0), check_inferred=false) + end end