From 3c2ca02d796c440a916cec881eb1043af163537f Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Thu, 15 Oct 2020 20:40:14 +0530
Subject: [PATCH 1/5] use map adjoint for recurrent layers

---
 src/layers/recurrent.jl | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index a93c4a0aed..4c3cc0a612 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -184,3 +184,7 @@ See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
 for a good overview of the internals.
 """
 GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
+
+@adjoint function Broadcast.broadcasted(f::Recur, args...)
+  Zygote.∇map(__context__, f, args...)
+end

From dfa2d02b7de0a811278d76946bb6c40e59ae08d0 Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Thu, 15 Oct 2020 20:47:50 +0530
Subject: [PATCH 2/5] fix 1209

---
 test/layers/recurrent.jl | 20 ++++++++++++++++++++
 test/runtests.jl         |  1 +
 2 files changed, 21 insertions(+)
 create mode 100644 test/layers/recurrent.jl

diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl
new file mode 100644
index 0000000000..973fe83c77
--- /dev/null
+++ b/test/layers/recurrent.jl
@@ -0,0 +1,20 @@
+Ref FluxML/Flux.jl#1209
+@testset "BPTT" begin
+  seq = [rand(2) for i = 1:3]
+  for rnn ∈ [RNN, LSTM, GRU]
+    Flux.reset!(rnn)
+    grads_seq = gradient(Flux.params(rnn)) do
+      sum(rnn.(seq)[3])
+    end
+    Flux.reset!(rnn);
+    bptt = gradient(Wh->sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
+                            tanh.(rnn.cell.Wi * seq[2] + Wh *
+                            tanh.(rnn.cell.Wi * seq[1] +
+                            Wh * rnn.init
+                            + rnn.cell.b)
+                            + rnn.cell.b)
+                            + rnn.cell.b)),
+                    rnn.cell.Wh)
+    @test grads_seq[rnn.cell.Wh] ≈ bptt[1]
+  end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index c5861cd25c..d6240fe8bb 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -30,6 +30,7 @@ end
   include("layers/basic.jl")
   include("layers/normalisation.jl")
   include("layers/stateless.jl")
+  include("layers/recurrent.jl")
   include("layers/conv.jl")
 end

From 4f98435cee80f35a4cf6dbbd7276fed194ef4501 Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Thu, 15 Oct 2020 20:53:10 +0530
Subject: [PATCH 3/5] test only on RNN

---
 test/layers/recurrent.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl
index 973fe83c77..96f4306605 100644
--- a/test/layers/recurrent.jl
+++ b/test/layers/recurrent.jl
@@ -1,7 +1,7 @@
 Ref FluxML/Flux.jl#1209
 @testset "BPTT" begin
   seq = [rand(2) for i = 1:3]
-  for rnn ∈ [RNN, LSTM, GRU]
+  for rnn ∈ [RNN,]
     Flux.reset!(rnn)
     grads_seq = gradient(Flux.params(rnn)) do
       sum(rnn.(seq)[3])

From 4d27f4b52709c04ab79520c5258f74b7dcbb200b Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Thu, 15 Oct 2020 20:54:44 +0530
Subject: [PATCH 4/5] typo

---
 test/layers/recurrent.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl
index 96f4306605..173f8452ea 100644
--- a/test/layers/recurrent.jl
+++ b/test/layers/recurrent.jl
@@ -1,4 +1,4 @@
-Ref FluxML/Flux.jl#1209
+# Ref FluxML/Flux.jl#1209
 @testset "BPTT" begin
   seq = [rand(2) for i = 1:3]
   for rnn ∈ [RNN,]

From f6f9925beb83d9f05377ee0e96c2cdeb06987024 Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Thu, 15 Oct 2020 21:17:05 +0530
Subject: [PATCH 5/5] add rnn in test - oops

---
 test/layers/recurrent.jl | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl
index 173f8452ea..2bb093fc96 100644
--- a/test/layers/recurrent.jl
+++ b/test/layers/recurrent.jl
@@ -1,7 +1,8 @@
 # Ref FluxML/Flux.jl#1209
 @testset "BPTT" begin
   seq = [rand(2) for i = 1:3]
-  for rnn ∈ [RNN,]
+  for r ∈ [RNN,]
+    rnn = r(2,3)
     Flux.reset!(rnn)
     grads_seq = gradient(Flux.params(rnn)) do
      sum(rnn.(seq)[3])
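
Why PATCH 1/5 reuses the map adjoint: rnn.(seq) broadcasts a stateful Recur
over the sequence, so each call depends on the hidden state left behind by the
previous one. Routing Broadcast.broadcasted(f::Recur, args...) through
Zygote.∇map treats the broadcast as an ordered map, so the pullback chains the
hidden state across time steps; the BPTT testset added in PATCH 2/5 checks the
resulting gradient against a hand-unrolled tanh chain. Below is a minimal
sketch of the same check outside the test harness, assuming the Flux-0.11-era
API the patches target (RNN(in, out) returning a Recur whose cell has fields
Wi, Wh, b, with initial state rnn.init):

    using Flux   # gradient and params are re-exported by Flux

    rnn = RNN(2, 3)                 # stateful wrapper around an RNNCell
    seq = [rand(2) for _ in 1:3]

    Flux.reset!(rnn)                # start from the initial hidden state
    grads = gradient(Flux.params(rnn)) do
      sum(rnn.(seq)[3])             # broadcasting unrolls the RNN over time
    end
    grads[rnn.cell.Wh]              # accumulates contributions from all steps,
                                    # matching the hand-unrolled BPTT gradient

Without the adjoint from PATCH 1/5, Zygote's generic broadcast handling makes
no such ordering guarantee for a stateful callable, which is the miscomputed
gradient reported in FluxML/Flux.jl#1209.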