diff --git a/src/lib/lib.jl b/src/lib/lib.jl index b46341d39..dc389e316 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -16,9 +16,10 @@ accum(x::Tuple, y::Tuple) = accum.(x, y) accum(x::AbstractArray, y::AbstractArray) = accum.(x, y) @generated function accum(x::NamedTuple, y::NamedTuple) - # Zygote assumes that the NamedTuples will have the same keys - fieldnames(x) === fieldnames(y) || throw(ArgumentError("$x and $y keys must be the same")) - grad(x) = x in fieldnames(y) ? :(y.$x) : :nothing + # assumes that y has no keys apart from those also in x + fieldnames(y) ⊆ fieldnames(x) || throw(ArgumentError("$y keys must be a subset of $x keys")) + + grad(field) = field in fieldnames(y) ? :(y.$field) : :nothing Expr(:tuple, [:($f=accum(x.$f, $(grad(f)))) for f in fieldnames(x)]...) end