Merge #1367
1367: RNN update to drop CUDNN, fix LSTM bug and output type stability r=CarloLucibello a=jeremiedb

PR related to #1114, #1360 and #1365.

Some experiments with RNN handling.

The hidden state was dropped from each cell struct since it wasn't needed (AFAIK it was only used for size inference with CUDNN, and the bias size could serve as a substitute for the cells' `h` there as well).
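
For reference, a condensed sketch of the pattern the diff below moves to: the cell itself stores the initial state, `Recur` copies it on construction, and `reset!` restores it (simplified from the actual code in `src/layers/recurrent.jl`):

```julia
# Condensed sketch of the new state handling (simplified; see src/layers/recurrent.jl below).
struct RNNCell{F,A,V,S}
  σ::F
  Wi::A
  Wh::A
  b::V
  state::S          # initial hidden state, replacing the old `h` field
end

mutable struct Recur{T,S}
  cell::T
  state::S          # current hidden state, updated on every call
end

Recur(m::RNNCell) = Recur(m, m.state)         # start from the cell's initial state
reset!(m::Recur) = (m.state = m.cell.state)   # restore it; no separate `init` field needed
```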

I also looked at dropping the CUDNN dependence entirely, so this is now pure Flux/CUDA.jl; the file `src/cuda/curnn.jl` is no longer used. No modifications were made to the cell computations. Initial tests seem to show decent performance, but it has yet to be benchmarked.

Pending issue: despite the CUDNN dependency having been dropped completely, there still seems to be an instability when running on GPU. This is illustrated by the test at lines 1-50 of `test/rnn-test-jdb.jl`: on CPU it runs fine through all 100 iterations, but on GPU the same test throws NaNs after a couple dozen iterations.
My only hypothesis so far: when iterating over the sequence through `m.(x)` or `map(rnn, x)`, is the execution order safe? I.e. is it possible that there is no `sync()` on the CUDA side between those sequence steps, which could mess up the state?
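
For context, a hypothetical minimal sketch of the kind of training loop that triggers this (the actual `test/rnn-test-jdb.jl` is not part of this diff, so the model size, data and optimiser below are illustrative only):

```julia
using Flux, CUDA

m   = RNN(4, 4) |> gpu
x   = [CUDA.rand(Float32, 4, 8) for _ in 1:20]   # 20 time steps, batch size 8
ps  = Flux.params(m)
opt = ADAM(1e-2)

for iter in 1:100
  Flux.reset!(m)
  gs = gradient(ps) do
    sum(sum, m.(x))        # broadcast the layer over the sequence, as with `m.(x)` above
  end
  Flux.Optimise.update!(opt, ps, gs)
end
# Runs fine on CPU; on GPU the loss reportedly turns to NaN after a few dozen iterations.
```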

### PR Checklist

- [x] Tests are added
- [ ] Entry in NEWS.md
- [ ] Documentation, if applicable
- [ ] Final review from `@dhairyagandhi96` (for API changes).


Co-authored-by: jeremiedb <[email protected]>
Co-authored-by: jeremie.db <[email protected]>
3 people authored Nov 7, 2020
2 parents 8fb94be + b5c3b6f commit 0e0e2e7
Showing 7 changed files with 40 additions and 132 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -5,4 +5,5 @@
docs/build/
docs/site/
deps
# Manifest.toml
.vscode
# Manifest.toml
7 changes: 5 additions & 2 deletions src/cuda/cuda.jl
@@ -1,9 +1,12 @@
module CUDAint

using ..CUDA

using CUDA: CUDNN
include("curnn.jl")

import ..Flux: Flux
import Zygote
using Zygote: @adjoint

include("cudnn.jl")

end
89 changes: 0 additions & 89 deletions src/cuda/curnn.jl

This file was deleted.

55 changes: 25 additions & 30 deletions src/layers/recurrent.jl
@@ -1,3 +1,4 @@

gate(h, n) = (1:h) .+ h*(n-1)
gate(x::AbstractVector, h, n) = @view x[gate(h,n)]
gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]
@@ -24,21 +25,19 @@ rnn.(1:10) # apply to a sequence
rnn.state # 60
```
"""
mutable struct Recur{T}
mutable struct Recur{T,S}
cell::T
init
state
state::S
end

Recur(m, h = hidden(m)) = Recur(m, h, h)

function (m::Recur)(xs...)
h, y = m.cell(m.state, xs...)
m.state = h
return y
end

@functor Recur cell, init
@functor Recur
trainable(a::Recur) = (a.cell,)

Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")

@@ -52,34 +51,30 @@ Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to:
rnn.state = hidden(rnn.cell)
```
"""
reset!(m::Recur) = (m.state = m.init)
reset!(m::Recur) = (m.state = m.cell.state)
reset!(m) = foreach(reset!, functor(m)[1])

flip(f, xs) = reverse(f.(reverse(xs)))

# Vanilla RNN

mutable struct RNNCell{F,A,V}
struct RNNCell{F,A,V,S}
σ::F
Wi::A
Wh::A
b::V
h::V
state::S
end

RNNCell(in::Integer, out::Integer, σ = tanh;
init = glorot_uniform) =
RNNCell(σ, init(out, in), init(out, out),
init(out), zeros(out))
RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros, init_state=zeros) =
RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1))

function (m::RNNCell)(h, x)
σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
h = σ.(Wi*x .+ Wh*h .+ b)
return h, h
end

hidden(m::RNNCell) = m.h

@functor RNNCell

function Base.show(io::IO, l::RNNCell)
@@ -94,22 +89,23 @@ end
The most basic recurrent layer; essentially acts as a `Dense` layer, but with the
output fed back into the input each time step.
"""
Recur(m::RNNCell) = Recur(m, m.state)
RNN(a...; ka...) = Recur(RNNCell(a...; ka...))

# LSTM

mutable struct LSTMCell{A,V}
struct LSTMCell{A,V,S}
Wi::A
Wh::A
b::V
h::V
c::V
state::S
end

function LSTMCell(in::Integer, out::Integer;
init = glorot_uniform)
cell = LSTMCell(init(out * 4, in), init(out * 4, out), init(out * 4),
zeros(out), zeros(out))
init = glorot_uniform,
initb = zeros,
init_state = zeros)
cell = LSTMCell(init(out * 4, in), init(out * 4, out), initb(out * 4), (init_state(out,1), init_state(out,1)))
cell.b[gate(out, 2)] .= 1
return cell
end
@@ -126,8 +122,6 @@ function (m::LSTMCell)((h, c), x)
return (h′, c), h′
end

hidden(m::LSTMCell) = (m.h, m.c)

@functor LSTMCell

Base.show(io::IO, l::LSTMCell) =
@@ -142,20 +136,22 @@ recurrent layer. Behaves like an RNN but generally exhibits a longer memory span
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
# Recur(m::LSTMCell) = Recur(m, (zeros(length(m.b)÷4), zeros(length(m.b)÷4)),
# (zeros(length(m.b)÷4), zeros(length(m.b)÷4)))
Recur(m::LSTMCell) = Recur(m, m.state)
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))

# GRU

mutable struct GRUCell{A,V}
struct GRUCell{A,V,S}
Wi::A
Wh::A
b::V
h::V
state::S
end

GRUCell(in, out; init = glorot_uniform) =
GRUCell(init(out * 3, in), init(out * 3, out),
init(out * 3), zeros(out))
GRUCell(in, out; init = glorot_uniform, initb = zeros, init_state = zeros) =
GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out,1))

function (m::GRUCell)(h, x)
b, o = m.b, size(h, 1)
@@ -167,8 +163,6 @@ function (m::GRUCell)(h, x)
return h′, h′
end

hidden(m::GRUCell) = m.h

@functor GRUCell

Base.show(io::IO, l::GRUCell) =
@@ -183,6 +177,7 @@ RNN but generally exhibits a longer memory span over sequences.
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
Recur(m::GRUCell) = Recur(m, m.state)
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))

@adjoint function Broadcast.broadcasted(f::Recur, args...)
8 changes: 3 additions & 5 deletions test/cuda/curnn.jl
@@ -8,7 +8,7 @@ using Flux: pullback
Flux.reset!(m)
θ = gradient(() -> sum(m(x)), params(m))
@test x isa CuArray
@test_broken θ[m.cell.Wi] isa CuArray
@test θ[m.cell.Wi] isa CuArray
@test_broken collect(m̄[].cell[].Wi) == collect(θ[m.cell.Wi])
end

@@ -20,17 +20,15 @@ end
Flux.reset!(rnn)
Flux.reset!(curnn)
x = batch_size == 1 ?
rand(10) :
rand(10, batch_size)
rand(Float32, 10) :
rand(Float32, 10, batch_size)
cux = gpu(x)

y, back = pullback((r, x) -> r(x), rnn, x)
cuy, cuback = pullback((r, x) -> r(x), curnn, cux)

@test y ≈ collect(cuy)

@test haskey(Flux.CUDAint.descs, curnn.cell)

ȳ = randn(size(y))
m̄, x̄ = back(ȳ)
cum̄, cux̄ = cuback(gpu(ȳ))
4 changes: 2 additions & 2 deletions test/layers/recurrent.jl
@@ -1,6 +1,6 @@
# Ref FluxML/Flux.jl#1209
@testset "BPTT" begin
seq = [rand(2) for i = 1:3]
seq = [rand(Float32, (2,1)) for i = 1:3]
for r ∈ [RNN,]
rnn = r(2,3)
Flux.reset!(rnn)
@@ -11,7 +11,7 @@
bptt = gradient(Wh->sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
tanh.(rnn.cell.Wi * seq[2] + Wh *
tanh.(rnn.cell.Wi * seq[1] +
Wh * rnn.init
Wh * rnn.cell.state
+ rnn.cell.b)
+ rnn.cell.b)
+ rnn.cell.b)),
6 changes: 3 additions & 3 deletions test/utils.jl
@@ -101,16 +101,16 @@ end
m = Dense(10, 5)
@test size.(params(m)) == [(5, 10), (5,)]
m = RNN(10, 5)
@test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]
@test size.(params(m)) == [(5, 10), (5, 5), (5,), (5, 1)]

# Layer duplicated in same chain, params just once pls.
c = Chain(m, m)
@test size.(params(c)) == [(5, 10), (5, 5), (5,), (5,)]
@test size.(params(c)) == [(5, 10), (5, 5), (5,), (5, 1)]

# Self-referential array. Just want params, no stack overflow pls.
r = Any[nothing,m]
r[1] = r
@test size.(params(r)) == [(5, 10), (5, 5), (5,), (5,)]
@test size.(params(r)) == [(5, 10), (5, 5), (5,), (5, 1)]
end

@testset "Basic Stacking" begin
