From 85807a7fe0915e79d12d814076b675f7f8af496c Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 21 Sep 2023 16:04:13 +0800 Subject: [PATCH] add frules for getfield --- src/rulesets/Base/indexing.jl | 9 ++++++--- test/rulesets/Base/indexing.jl | 14 ++++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 1334cc925..37ed8ca48 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -1,6 +1,10 @@ # Int rather than Int64/Integer is intentional -function frule((_, ẋ), ::typeof(getfield), x::Tuple, i::Int) - return x.i, ẋ.i +function ChainRulesCore.frule((_, Δ, _), ::typeof(getfield), strct, sym::Union{Int,Symbol}) + return (getfield(strct, sym), isa(Δ, NoTangent) ? NoTangent() : getproperty(Δ, sym)) +end + +function ChainRulesCore.frule((_, Δ, _, _), ::typeof(getfield), strct, sym::Union{Int,Symbol}, inbounds) + return (getfield(strct, sym, inbounds), isa(Δ, NoTangent) ? NoTangent() : getproperty(Δ, sym)) end "for a given tuple type, returns a Val{N} where N is the length of the tuple" @@ -21,7 +25,6 @@ function rrule(::typeof(getindex), x::T, i::Integer) where {T<:NTuple{<:Any,<:Nu dx = ntuple(j -> j == i ? dy : zero(dy), _tuple_N(T)) return (NoTangent(), Tangent{T}(dx...), NoTangent()) end - return x[i], getindex_back_2 end # Note Zygote has getindex(::Tuple, ::UnitRange) separately from getindex(::Tuple, ::AbstractVector), diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index d3c7ecfb4..c21bb8425 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -1,3 +1,17 @@ +@testset "getfield" begin + struct Foo + x::Float64 + y::Float64 + end + test_frule(getfield, Foo(1.5, 2.5), :x, check_inferred=false) + + test_frule(getfield, (; a=1.5, b=2.5), :a, check_inferred=false) + test_frule(getfield, (; a=1.5, b=2.5), 2) + + test_frule(getfield, (1.5, 2.5), 2) + test_frule(getfield, (1.5, 2.5), 2, true) +end + @testset "getindex" begin @testset "getindex(::Tuple, ...)" begin x = (1.2, 3.4, 5.6)