Skip to content

Commit

Permalink
update implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Miha Zgubic committed Mar 23, 2021
1 parent 4c7a68c commit 0e63d38
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/lib/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ accum(x::Tuple, y::Tuple) = accum.(x, y)
accum(x::AbstractArray, y::AbstractArray) = accum.(x, y)

@generated function accum(x::NamedTuple, y::NamedTuple)
# Zygote assumes that the NamedTuples will have the same keys
fieldnames(x) === fieldnames(y) || throw(ArgumentError("$x and $y keys must be the same"))
grad(x) = x in fieldnames(y) ? :(y.$x) : :nothing
# assumes that y has no keys apart from those also in x
fieldnames(y) fieldnames(x) || throw(ArgumentError("$y keys must be a subset of $x keys"))

grad(field) = field in fieldnames(y) ? :(y.$field) : :nothing
Expr(:tuple, [:($f=accum(x.$f, $(grad(f)))) for f in fieldnames(x)]...)
end

Expand Down

0 comments on commit 0e63d38

Please sign in to comment.