Skip to content

Commit

Permalink
corner case for repeat
Browse files Browse the repository at this point in the history
  • Loading branch information
Miha Zgubic committed Jun 8, 2022
1 parent e4029df commit d2526d2
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,18 +166,19 @@ function frule((_, ẋs), ::typeof(repeat), xs::AbstractArray, cnt...; kw...)
return repeat(xs, cnt...; kw...), repeat(ẋs, cnt...; kw...)
end

function rrule(::typeof(repeat), xs::AbstractArray; inner=ntuple(Returns(1), ndims(xs)), outer=ntuple(Returns(1), ndims(xs)))
function rrule(::typeof(repeat), xs::AbstractArray; inner=nothing, outer=nothing)

project_Xs = ProjectTo(xs)
S = size(xs)
inner_size = inner === nothing ? ntuple(Returns(1), ndims(xs)) : inner
function repeat_pullback(ȳ)
dY = unthunk(ȳ)
Δ′ = zero(xs)
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
for (dest_idx, val) in pairs(IndexCartesian(), dY)
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
# First, round dest_idx[dim] to nearest gridpoint defined by inner_dims[dim], then
# wrap around based on original size S.
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
src_idx = [mod1(div(dest_idx[dim] - 1, inner_size[dim]) + 1, S[dim]) for dim in 1:length(S)]
Δ′[src_idx...] += val
end
= project_Xs(Δ′)
Expand Down
1 change: 1 addition & 0 deletions test/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ end
test_rrule(repeat, rand(4, 5))
test_rrule(repeat, rand(4, 5); fkwargs = (outer=(1,2),))
test_rrule(repeat, rand(4, 5); fkwargs = (inner=(1,2), outer=(1,3)))
test_rrule(repeat, rand(4, 5); fkwargs = (outer=2,))

test_rrule(repeat, rand(4, ), 2)
test_rrule(repeat, rand(4, 5), 2)
Expand Down

0 comments on commit d2526d2

Please sign in to comment.