Skip to content

Commit

Permalink
Rules for getindex(::Tuple) and sum(::Tuple) (#643)
Browse files Browse the repository at this point in the history
* getindex for tuples

* sum for tuples

* repeated indices in getindex, etc

* add colon case

* tidy, use Tanget

* first, tail

* simplify

* Apply 2 suggestions

Co-authored-by: Frames Catherine White <[email protected]>

* skip a test until JuliaDiff/ChainRulesTestUtils.jl#253, bump version

* comment on Zygote

Co-authored-by: Frames Catherine White <[email protected]>
  • Loading branch information
mcabbott and oxinabox authored Jul 14, 2022
1 parent dadb205 commit 8073c7c
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.37.0"
version = "1.38.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
78 changes: 78 additions & 0 deletions src/rulesets/Base/indexing.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,58 @@
#####
##### getindex(::Tuple)
#####

function frule((_, ẋ), ::typeof(getindex), x::Tuple, i::Integer)
return x[i], ẋ[i]
end

function frule((_, ẋ), ::typeof(getindex), x::Tuple, i)
y = x[i]
return y, Tangent{typeof(y)}(ẋ[i]...)
end

"for a given typle type, returns a Val{N} where N is the length of the tuple"
_tuple_N(::Type{<:Tuple{Vararg{<:Any, N}}}) where {N} = Val(N)

function rrule(::typeof(getindex), x::T, i::Integer) where {T<:Tuple}
function getindex_back_1(dy)
dx = ntuple(j -> j == i ? dy : NoTangent(), _tuple_N(T))
return (NoTangent(), Tangent{T}(dx...), NoTangent())
end
return x[i], getindex_back_1
end

# Special case for tuples of only numbers
function rrule(::typeof(getindex), x::T, i::Integer) where {T<:NTuple{<:Any,<:Number}}
function getindex_back_2(dy_raw)
dy = unthunk(dy_raw)
dx = ntuple(j -> j == i ? dy : zero(dy), _tuple_N(T))
return (NoTangent(), Tangent{T}(dx...), NoTangent())
end
return x[i], getindex_back_2
end

# Note Zygote has getindex(::Tuple, ::UnitRange) separately from getindex(::Tuple, ::AbstractVector),
# whether that's more efficient has not been investigated here.
# https://github.com/FluxML/Zygote.jl/blob/master/src/lib/lib.jl#L125-L142
function rrule(::typeof(getindex), x::T, inds) where {T<:Tuple} # e.g. ranges, not type-stable
function getindex_back_3(dy_raw)
dy = unthunk(dy_raw)
dx = ntuple(Returns(NoTangent()), _tuple_N(T))
for (dyi, i) in zip(dy, inds)
dx = Base.setindex(dx, dyi + dx[i], i)
end
return (NoTangent(), Tangent{T}(dx...), NoTangent())
end
return x[inds], getindex_back_3
end

function rrule(::typeof(getindex), x::Tuple, ::Colon)
getindex_back_4(dy) = (NoTangent(), dy, NoTangent())
return x, getindex_back_4
end


#####
##### getindex
#####
Expand Down Expand Up @@ -31,6 +86,29 @@ function rrule(::typeof(getindex), x::Array{<:Number}, inds...)
return y, getindex_pullback
end

#####
##### first, tail
#####

function frule((_, ẋ), ::typeof(first), x::Tuple)
return first(x), first(ẋ)
end

function rrule(::typeof(first), x::T) where {T<:Tuple}
first_back(dy) = (NoTangent(), Tangent{T}(ntuple(j -> j == 1 ? dy : NoTangent(), _tuple_N(T))...))
return first(x), first_back
end

function frule((_, ẋ), ::typeof(Base.tail), x::Tuple)
y = Base.tail(x)
return y, Tangent{typeof(y)}(Base.tail(Tuple(ẋ))...)
end

function rrule(::typeof(Base.tail), x::T) where {T<:Tuple}
tail_pullback(dy) = (NoTangent(), Tangent{T}(NoTangent(), dy...))
return Base.tail(x), tail_pullback
end

#####
##### view
#####
Expand Down
11 changes: 11 additions & 0 deletions src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,17 @@ function frule((_, ẏ, ẋ), ::typeof(sum!), y::AbstractArray, x::AbstractArray
return sum!(y, x), sum!(ẏ, ẋ)
end

function rrule(::typeof(sum), x::Tuple)
project = ProjectTo(x)
len = Val(length(x))
function sum_pullback(dy_raw)
dy = unthunk(dy_raw)
dx = dy isa AbstractZero ? dy : ntuple(Returns(dy), len)
return (NoTangent(), project(dx))
end
return sum(x), sum_pullback
end

function rrule(::typeof(sum), x::AbstractArray; dims=:)
project = ProjectTo(x)
y = sum(x; dims=dims)
Expand Down
38 changes: 38 additions & 0 deletions test/rulesets/Base/indexing.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,25 @@
@testset "getindex" begin
@testset "getindex(::Tuple, ...)" begin
x = (1.2, 3.4, 5.6)
x2 = (rand(2), (a=1.0, b=x))

# Forward
test_frule(getindex, x, 2)
test_frule(getindex, x2, 1)
test_frule(getindex, x, 1:2)
test_frule(getindex, x2, :)

# Reverse
test_rrule(getindex, x, 2)
@test_skip test_rrule(getindex, x2, 1, check_inferred=false) # method ambiguity, maybe fixed by https://github.com/JuliaDiff/ChainRulesTestUtils.jl/pull/253

test_rrule(getindex, x, 2:3; check_inferred=false)
test_rrule(getindex, x, [1, 1, 2], check_inferred=false)
test_rrule(getindex, x2, 1:2, check_inferred=false)

test_rrule(getindex, x, :)
end

@testset "getindex(::Matrix{<:Number}, ...)" begin
x = [1.0 2.0 3.0; 10.0 20.0 30.0]

Expand Down Expand Up @@ -58,6 +79,23 @@
end
end

@testset "first & tail" begin
x = (1.2, 3.4, 5.6)
x2 = (rand(2), (a=1.0, b=x))

test_frule(first, x)
test_frule(first, x2)

test_rrule(first, x)
# test_rrule(first, x2) # MethodError: (::ChainRulesTestUtils.var"#test_approx##kw")(::NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}, ::typeof(test_approx), ::NoTangent, ::Tangent{NamedTuple{(:a, :b), Tuple{Float64, Tuple{Float64, Float64, Float64}}}, NamedTuple{(:a, :b), Tuple{Float64, Tangent{Tuple{Float64, Float64, Float64}, Tuple{Float64, Float64, Float64}}}}}, ::String) is ambiguous

test_frule(Base.tail, x, check_inferred=false) # return type Tuple{Tuple{Float64, Float64}, Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}} does not match inferred return type Tuple{Tuple{Float64, Float64}, Tangent{Tuple{Float64, Float64}}}
test_frule(Base.tail, x2, check_inferred=false)

test_rrule(Base.tail, x)
test_rrule(Base.tail, x2)
end

@testset "view" begin
test_frule(view, rand(3, 4), :, 1)
test_frule(view, rand(3, 4), 2, [1, 1, 2])
Expand Down
5 changes: 5 additions & 0 deletions test/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
@testset "Reductions" begin
@testset "sum(::Tuple)" begin
test_frule(sum, Tuple(rand(5)))
test_frule(sum, (rand(2), rand(2)))

test_rrule(sum, Tuple(rand(5)))
test_rrule(sum, (1.2, 3.4 + 5im))
test_rrule(sum, (rand(2)', rand(1,2)))
end
@testset "sum(x; dims=$dims)" for dims in (:, 2, (1,3))
# Forward
Expand Down

2 comments on commit 8073c7c

@mcabbott
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/64270

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.38.0 -m "<description of version>" 8073c7c4638bdd46f4e822d2ab72423c051c5e4b
git push origin v1.38.0

Please sign in to comment.