Skip to content

Commit

Permalink
Merge #1358
Browse files Browse the repository at this point in the history
1358: Fix BPTT by overriding stateful broadcast adjoint r=DhairyaLGandhi a=DhairyaLGandhi

Fixes #1209 

In this PR, we replace the regular broadcasting adjoint with that of the `map` equivalent which is better tested in terms of stateful cases. We ultimately will revert back to the broadacasting adjoint via FluxML/Zygote.jl#807 but this specialises the case for recurrent layers 


@oxinabox @ToucheSir Comments?
### PR Checklist

- [x] Tests are added
- [ ] Entry in NEWS.md
- [x] Documentation, if applicable
- [ ] Final review from `@dhairyagandhi96` (for API changes).


Co-authored-by: Dhairya Gandhi <[email protected]>
  • Loading branch information
bors[bot] and Dhairya Gandhi authored Oct 15, 2020
2 parents 9ed04bb + f6f9925 commit 98e7222
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,7 @@ See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))

@adjoint function Broadcast.broadcasted(f::Recur, args...)
Zygote.∇map(__context__, f, args...)
end
21 changes: 21 additions & 0 deletions test/layers/recurrent.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Ref FluxML/Flux.jl#1209
@testset "BPTT" begin
seq = [rand(2) for i = 1:3]
for r [RNN,]
rnn = r(2,3)
Flux.reset!(rnn)
grads_seq = gradient(Flux.params(rnn)) do
sum(rnn.(seq)[3])
end
Flux.reset!(rnn);
bptt = gradient(Wh->sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
tanh.(rnn.cell.Wi * seq[2] + Wh *
tanh.(rnn.cell.Wi * seq[1] +
Wh * rnn.init
+ rnn.cell.b)
+ rnn.cell.b)
+ rnn.cell.b)),
rnn.cell.Wh)
@test grads_seq[rnn.cell.Wh] bptt[1]
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ end
include("layers/basic.jl")
include("layers/normalisation.jl")
include("layers/stateless.jl")
include("layers/recurrent.jl")
include("layers/conv.jl")
end

Expand Down

0 comments on commit 98e7222

Please sign in to comment.