Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
oxinabox and github-actions[bot] authored May 22, 2024
1 parent 44f1a5e commit 49d5ae7
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
9 changes: 5 additions & 4 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -296,11 +296,12 @@ end
#### merge
####
# need to work around inability to return closures from generated functions
struct MergePullback{T1, T2}
end
struct MergePullback{T1,T2} end
(this::MergePullback)(dy::AbstractThunk) = this(unthunk(dy))
(::MergePullback)(x::AbstractZero) = (NoTangent(), x, x)
@generated function(::MergePullback{T1,T2})(dy::Tangent) where {F1,T1<:NamedTuple{F1},F2,T2<:NamedTuple{F2}}
@generated function (::MergePullback{T1,T2})(
dy::Tangent
) where {F1,T1<:NamedTuple{F1},F2,T2<:NamedTuple{F2}}
_getproperty_kwexpr(key) = :($key = getproperty(dy, $(Meta.quot(key))))
quote
dnt1 = Tangent{T1}(; $(map(_getproperty_kwexpr, setdiff(F1, F2))...))
Expand All @@ -309,7 +310,7 @@ end
end
end

function rrule(::typeof(merge), nt1::T1, nt2::T2) where {T1<:NamedTuple, T2<:NamedTuple}
function rrule(::typeof(merge), nt1::T1, nt2::T2) where {T1<:NamedTuple,T2<:NamedTuple}
y = merge(nt1, nt2)
return y, MergePullback{T1,T2}()
end
4 changes: 2 additions & 2 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ end
end

@testset "merge NamedTuple" begin
test_rrule(merge, (;a=1.0), (;b=2.0))
test_rrule(merge, (;a=1.0), (;a=2.0))
test_rrule(merge, (; a=1.0), (; b=2.0))
test_rrule(merge, (; a=1.0), (; a=2.0))
end
end

0 comments on commit 49d5ae7

Please sign in to comment.