From d2526d2d593accb7f49311530ce2ca1aa2070696 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Wed, 8 Jun 2022 18:36:45 +0100 Subject: [PATCH] corner case for repeat --- src/rulesets/Base/array.jl | 7 ++++--- test/rulesets/Base/array.jl | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 69d4559c4..7a0269ad4 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -166,18 +166,19 @@ function frule((_, ẋs), ::typeof(repeat), xs::AbstractArray, cnt...; kw...) return repeat(xs, cnt...; kw...), repeat(ẋs, cnt...; kw...) end -function rrule(::typeof(repeat), xs::AbstractArray; inner=ntuple(Returns(1), ndims(xs)), outer=ntuple(Returns(1), ndims(xs))) +function rrule(::typeof(repeat), xs::AbstractArray; inner=nothing, outer=nothing) project_Xs = ProjectTo(xs) S = size(xs) + inner_size = inner === nothing ? ntuple(Returns(1), ndims(xs)) : inner function repeat_pullback(ȳ) dY = unthunk(ȳ) Δ′ = zero(xs) # Loop through each element of Δ, calculate source dimensions, accumulate into Δ′ for (dest_idx, val) in pairs(IndexCartesian(), dY) - # First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then + # First, round dest_idx[dim] to nearest gridpoint defined by inner_dims[dim], then # wrap around based on original size S. - src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)] + src_idx = [mod1(div(dest_idx[dim] - 1, inner_size[dim]) + 1, S[dim]) for dim in 1:length(S)] Δ′[src_idx...] += val end x̄ = project_Xs(Δ′) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 173184289..b96450945 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -128,6 +128,7 @@ end test_rrule(repeat, rand(4, 5)) test_rrule(repeat, rand(4, 5); fkwargs = (outer=(1,2),)) test_rrule(repeat, rand(4, 5); fkwargs = (inner=(1,2), outer=(1,3))) + test_rrule(repeat, rand(4, 5); fkwargs = (outer=2,)) test_rrule(repeat, rand(4, ), 2) test_rrule(repeat, rand(4, 5), 2)