Skip to content

Commit

Permalink
add frules for getfield
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Sep 21, 2023
1 parent ba52ec8 commit 85807a7
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
9 changes: 6 additions & 3 deletions src/rulesets/Base/indexing.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Int rather than Int64/Integer is intentional
function frule((_, ẋ), ::typeof(getfield), x::Tuple, i::Int)
return x.i, ẋ.i
function ChainRulesCore.frule((_, Δ, _), ::typeof(getfield), strct, sym::Union{Int,Symbol})
return (getfield(strct, sym), isa(Δ, NoTangent) ? NoTangent() : getproperty(Δ, sym))
end

function ChainRulesCore.frule((_, Δ, _, _), ::typeof(getfield), strct, sym::Union{Int,Symbol}, inbounds)
return (getfield(strct, sym, inbounds), isa(Δ, NoTangent) ? NoTangent() : getproperty(Δ, sym))
end

"for a given tuple type, returns a Val{N} where N is the length of the tuple"
Expand All @@ -21,7 +25,6 @@ function rrule(::typeof(getindex), x::T, i::Integer) where {T<:NTuple{<:Any,<:Nu
dx = ntuple(j -> j == i ? dy : zero(dy), _tuple_N(T))
return (NoTangent(), Tangent{T}(dx...), NoTangent())
end
return x[i], getindex_back_2
end

# Note Zygote has getindex(::Tuple, ::UnitRange) separately from getindex(::Tuple, ::AbstractVector),
Expand Down
14 changes: 14 additions & 0 deletions test/rulesets/Base/indexing.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
@testset "getfield" begin
struct Foo
x::Float64
y::Float64
end
test_frule(getfield, Foo(1.5, 2.5), :x, check_inferred=false)

test_frule(getfield, (; a=1.5, b=2.5), :a, check_inferred=false)
test_frule(getfield, (; a=1.5, b=2.5), 2)

test_frule(getfield, (1.5, 2.5), 2)
test_frule(getfield, (1.5, 2.5), 2, true)
end

@testset "getindex" begin
@testset "getindex(::Tuple, ...)" begin
x = (1.2, 3.4, 5.6)
Expand Down

0 comments on commit 85807a7

Please sign in to comment.