From 238640185473d541aa9a729d320f9e7ed4d7fcb8 Mon Sep 17 00:00:00 2001
From: CarloLucibello
Date: Tue, 15 Oct 2024 06:49:52 +0200
Subject: [PATCH] LSTM

---
 src/layers/recurrent.jl                  | 145 +++++++++++++++++++----
 test/ext_cuda/{curnn.jl => recurrent.jl} |   0
 test/ext_cuda/runtests.jl                |   4 +-
 test/layers/recurrent.jl                 |  63 ++++++++++
 test/test_utils.jl                       |   3 +
 5 files changed, 187 insertions(+), 28 deletions(-)
 rename test/ext_cuda/{curnn.jl => recurrent.jl} (100%)

diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 36a47c2f95..9a40bae399 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -33,7 +33,7 @@ h^\prime = \sigma(W_i x + W_h h + b)
 ```
 and returns `h'`.

-See also [`RNN`](@ref).
+See [`RNN`](@ref) for a layer that processes entire sequences.

 # Arguments

@@ -121,9 +121,8 @@ In the forward pass computes
 h_t = \sigma(W_i x_t + W_h h_{t-1} + b)
 ```
 for all `len` steps `t` in the input sequence.
-Returns all hidden states `h_t` in a tensor of size `(out, len, batch_size)`.

-See also [`RNNCell`](@ref).
+See [`RNNCell`](@ref) for a layer that processes a single time step.

 # Arguments

@@ -139,7 +138,9 @@ See also [`RNNCell`](@ref).
 The arguments of the forward pass are:

 - `x`: The input to the RNN. It should be a matrix of size `in x len` or a tensor of size `in x len x batch_size`.
-- `h`: The hidden state of the RNN. It should be a vector of size `out` or a matrix of size `out x batch_size`.
+- `h`: The initial hidden state of the RNN. It should be a vector of size `out` or a matrix of size `out x batch_size`.
+
+Returns all new hidden states `h_t` in a tensor of size `(out, len, batch_size)`.

 # Examples

@@ -200,16 +201,69 @@ end

 # LSTM

+@doc raw"""
+    LSTMCell(in => out; init = glorot_uniform, bias = true)
+
+The [Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory) cell.
+Behaves like an RNN but generally exhibits a longer memory span over sequences.
+
+In the forward pass, computes
+
+```math
+i_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i)
+f_t = \sigma(W_{xf} x_t + W_{hf} h_{t-1} + b_f)
+c_t = f_t \odot c_{t-1} + i_t \odot \tanh(W_{xc} x_t + W_{hc} h_{t-1} + b_c)
+o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o)
+h_t = o_t \odot \tanh(c_t)
+```
+
+The `LSTMCell` returns the new hidden state `h_t` and cell state `c_t` for a single time step.
+See also [`LSTM`](@ref) for a layer that processes entire sequences.
+
+# Arguments
+
+- `in => out`: The input and output dimensions of the layer.
+- `init`: The initialization function to use for the weights. Default is `glorot_uniform`.
+- `bias`: Whether to include a bias term initialized to zero. Default is `true`.
+
+# Forward
+
+    lstmcell(x, (h, c))
+    lstmcell(x)
+
+The arguments of the forward pass are:
+- `x`: The input to the LSTM cell. It should be a vector of size `in` or a matrix of size `in x batch_size`.
+- `(h, c)`: A tuple containing the hidden and cell states of the LSTM.
+  They should be vectors of size `out` or matrices of size `out x batch_size`.
+  If not provided, they are assumed to be vectors of zeros.
+
+Returns a tuple `(h′, c′)` containing the new hidden state and cell state in tensors of size `out` or `out x batch_size`.
+
+# Examples
+
+```jldoctest
+julia> l = LSTMCell(3 => 5)
+LSTMCell(3 => 5)  # 180 parameters
+
+julia> h = zeros(Float32, 5); # hidden state
+
+julia> c = zeros(Float32, 5); # cell state
+
+julia> x = rand(Float32, 3, 4); # in x batch_size
+
+julia> h′, c′ = l(x, (h, c));
+
+julia> size(h′) # out x batch_size
+(5, 4)
+```
+"""
 struct LSTMCell{I,H,V}
   Wi::I
   Wh::H
   bias::V
 end

-function LSTMCell((in, out)::Pair;
-    init = glorot_uniform,
-    bias = true)
+@layer LSTMCell
+
+function LSTMCell((in, out)::Pair; init = glorot_uniform, bias = true)
   Wi = init(out * 4, in)
   Wh = init(out * 4, out)
   b = create_bias(Wi, bias, size(Wi, 1))
@@ -218,35 +272,61 @@ function LSTMCell((in, out)::Pair;
   return cell
 end

+
 function (m::LSTMCell)(x::AbstractVecOrMat, (h, c))
   _size_check(m, x, 1 => size(m.Wi, 2))
   b, o = m.bias, size(h, 1)
-  g = m.Wi * x .+ m.Wh*h .+ b
+  g = m.Wi * x .+ m.Wh * h .+ b
   input, forget, cell, output = multigate(g, o, Val(4))
   c′ = @. sigmoid_fast(forget) * c + sigmoid_fast(input) * tanh_fast(cell)
   h′ = @. sigmoid_fast(output) * tanh_fast(c′)
   return h′, c′
 end

-@layer LSTMCell
-
 Base.show(io::IO, l::LSTMCell) =
   print(io, "LSTMCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷4, ")")
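Editor's note (illustrative, not part of the patch): the cell above advances the recurrence by a single time step, so using it on a sequence means threading `(h, c)` through a loop by hand, which is exactly what the `LSTM` layer introduced below automates. A minimal sketch against the API added in this patch; the `rollout` helper is purely illustrative and not defined anywhere in Flux:

```julia
using Flux

# Illustrative only: roll an LSTMCell over a sequence by hand.
function rollout(cell, xs, h, c)
    for x_t in xs
        h, c = cell(x_t, (h, c))   # one step of the recurrence
    end
    return h, c
end

cell = LSTMCell(3 => 5)
xs = [rand(Float32, 3, 4) for _ in 1:6]   # length-6 sequence of (in x batch_size) inputs
h0 = zeros(Float32, 5, 4)                 # initial hidden state, out x batch_size
c0 = zeros(Float32, 5, 4)                 # initial cell state, out x batch_size
h, c = rollout(cell, xs, h0, c0)          # final states, each of size (5, 4)
```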

-"""
-    LSTM(in => out)
+
+@doc raw"""
+    LSTM(in => out; init = glorot_uniform, bias = true)

 [Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory) recurrent layer. Behaves like an RNN but generally exhibits a longer memory span over sequences.

-The arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
-
-This constructor is syntactic sugar for `Recur(LSTMCell(a...))`, and so LSTMs are stateful. Note that the state shape can change depending on the inputs, and so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below.
-
 See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) for a good overview of the internals.

+In the forward pass, computes
+
+```math
+i_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i)
+f_t = \sigma(W_{xf} x_t + W_{hf} h_{t-1} + b_f)
+c_t = f_t \odot c_{t-1} + i_t \odot \tanh(W_{xc} x_t + W_{hc} h_{t-1} + b_c)
+o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o)
+h_t = o_t \odot \tanh(c_t)
+```
+
+See also [`LSTMCell`](@ref).
+
 # Arguments

+- `in => out`: The input and output dimensions of the layer.
+- `init`: The initialization function to use for the weights. Default is `glorot_uniform`.
+- `bias`: Whether to include a bias term initialized to zero. Default is `true`.
+
+# Forward
+
+    lstm(x, (h, c))
+
+The arguments of the forward pass are:
+- `x`: The input to the LSTM. It should be a matrix of size `in x len` or a tensor of size `in x len x batch_size`.
+- `(h, c)`: A tuple containing the hidden and cell states of the LSTM. They should be vectors of size `out` or matrices of size `out x batch_size`.
+
+Returns a tuple `(h′, c′)` containing all new hidden states `h_t` and cell states `c_t`
+in tensors of size `out x len` or `out x len x batch_size`.
+
 # Examples

 ```jldoctest
-julia> l = LSTM(3 => 5)
-Recur(
-  LSTMCell(3 => 5),
-)
-
-julia> Flux.reset!(l);
-
-julia> l(rand(Float32, 3, 10)) |> size # batch size of 10
-(5, 10)
+julia> d_in, d_out, len, batch_size = 2, 3, 4, 5;
+
+julia> lstm = LSTM(d_in => d_out);
+
+julia> x = rand(Float32, d_in, len, batch_size);
+
+julia> h0, c0 = zeros(Float32, d_out), zeros(Float32, d_out);
+
+julia> h, c = lstm(x, (h0, c0));
+
+julia> size(h) # out x len x batch_size
+(3, 4, 5)
 ```
+"""
+struct LSTM{M}
+  cell::M
+end

-!!! warning "Batch size changes"
-    Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the example in [`RNN`](@ref).
+@layer :expand LSTM

-# Note:
-  `LSTMCell`s can be constructed directly by specifying the non-linear function, the `Wi` and `Wh` internal matrices, a bias vector `b`, and a learnable initial state `state0`. The `Wi` and `Wh` matrices do not need to be the same type. See the example in [`RNN`](@ref).
-"""
-LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
-Recur(m::LSTMCell) = Recur(m, m.state0)
+
+function LSTM((in, out)::Pair; init = glorot_uniform, bias = true)
+  cell = LSTMCell(in => out; init, bias)
+  return LSTM(cell)
+end
+
+function (m::LSTM)(x, (h, c))
+  @assert ndims(x) == 2 || ndims(x) == 3
+  h′ = []
+  c′ = []
+  for x_t in eachslice(x, dims=2)
+    h, c = m.cell(x_t, (h, c))
+    h′ = vcat(h′, [h])
+    c′ = vcat(c′, [c])
+  end
+  return stack(h′, dims=2), stack(c′, dims=2)
+end
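Editor's note (illustrative, not part of the patch): the loop above accumulates the per-step states with `vcat` into fresh arrays and then `stack`s them, rather than `push!`-ing into a preallocated vector, presumably so the unrolled loop stays free of mutation and Zygote can differentiate through it. A quick end-to-end check of that property, using only what the patch defines:

```julia
using Flux

lstm = LSTM(2 => 3)
x = rand(Float32, 2, 5, 4)                      # in x len x batch_size
h0, c0 = zeros(Float32, 3), zeros(Float32, 3)

# Gradient of a scalar loss w.r.t. the layer's parameters, through the whole unroll.
grads = Flux.gradient(m -> sum(abs2, m(x, (h0, c0))[1]), lstm)
@assert grads[1] !== nothing
```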
warning "Batch size changes" - Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the example in [`RNN`](@ref). +@layer :expand LSTM -# Note: - `LSTMCell`s can be constructed directly by specifying the non-linear function, the `Wi` and `Wh` internal matrices, a bias vector `b`, and a learnable initial state `state0`. The `Wi` and `Wh` matrices do not need to be the same type. See the example in [`RNN`](@ref). -""" -LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...)) -Recur(m::LSTMCell) = Recur(m, m.state0) +function LSTM((in, out)::Pair; init = glorot_uniform, bias = true) + cell = LSTMCell(in => out; init, bias) + return LSTM(cell) +end + +function (m::LSTM)(x, (h, c)) + @assert ndims(x) == 2 || ndims(x) == 3 + h′ = [] + c′ = [] + for x_t in eachslice(x, dims=2) + h, c = m.cell(x_t, (h, c)) + h′ = vcat(h′, [h]) + c′ = vcat(c′, [c]) + end + return stack(h′, dims=2), stack(c′, dims=2) +end # GRU @@ -280,17 +374,16 @@ function _gru_output(gxs, ghs, bs) return r, z end -struct GRUCell{I,H,V,S} +struct GRUCell{I,H,V} Wi::I Wh::H b::V - state0::S end GRUCell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state = zeros32) = GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out,1)) -function (m::GRUCell{I,H,V,<:AbstractMatrix{T}})(h, x::AbstractVecOrMat) where {I,H,V,T} +function (m::GRUCell{I,H,V})(h, x::AbstractVecOrMat) where {I,H,V} _size_check(m, x, 1 => size(m.Wi,2)) Wi, Wh, b, o = m.Wi, m.Wh, m.b, size(h, 1) xT = _match_eltype(m, T, x) diff --git a/test/ext_cuda/curnn.jl b/test/ext_cuda/recurrent.jl similarity index 100% rename from test/ext_cuda/curnn.jl rename to test/ext_cuda/recurrent.jl diff --git a/test/ext_cuda/runtests.jl b/test/ext_cuda/runtests.jl index 012a62d41a..d9802762c0 100644 --- a/test/ext_cuda/runtests.jl +++ b/test/ext_cuda/runtests.jl @@ -22,8 +22,8 @@ end @testset "cudnn" begin include("cudnn.jl") end -@testset "curnn" begin - include("curnn.jl") +@testset "recurrent" begin + include("recurrent.jl") end @testset "ctc" begin include("ctc.jl") diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl index 7f79adf342..aa58d7083e 100644 --- a/test/layers/recurrent.jl +++ b/test/layers/recurrent.jl @@ -82,3 +82,66 @@ end @test size(y) == (4, 3) test_gradients(model, x) end + +@testset "LSTMCell" begin + + function loss(r, x, hc) + h, c = hc + h′ = [] + c′ = [] + for x_t in x + h, c = r(x_t, (h, c)) + h′ = vcat(h′, [h]) + c′ = [c′..., c] + end + hnew = stack(h′, dims=2) + cnew = stack(c′, dims=2) + return mean(hnew.^2) + mean(cnew.^2) + end + + cell = LSTMCell(3 => 5) + @test length(Flux.trainables(cell)) == 3 + x = [rand(Float32, 3, 4) for _ in 1:6] + h = zeros(Float32, 5, 4) + c = zeros(Float32, 5, 4) + hnew, cnew = cell(x[1], (h, c)) + @test hnew isa Matrix{Float32} + @test cnew isa Matrix{Float32} + @test size(hnew) == (5, 4) + @test size(cnew) == (5, 4) + test_gradients(cell, x[1], (h, c), loss = (m, x, hc) -> mean(m(x, hc)[1])) + test_gradients(cell, x, (h, c), loss = loss) + + cell = LSTMCell(3 => 5, bias=false) + @test length(Flux.trainables(cell)) == 2 +end + +@testset "LSTM" begin + struct ModelLSTM + lstm::LSTM + h0::AbstractVector + c0::AbstractVector + end + + Flux.@layer :expand ModelLSTM + + (m::ModelLSTM)(x) = m.lstm(x, (m.h0, m.c0)) + + model = ModelLSTM(LSTM(2 => 4), zeros(Float32, 4), zeros(Float32, 4)) + + x = rand(Float32, 2, 3, 1) + h, c = model(x) + @test h isa Array{Float32, 3} + @test size(h) == (4, 3, 1) + @test c isa Array{Float32, 3} + @test size(c) == (4, 3, 
+    test_gradients(model, x, loss = (m, x) -> mean(m(x)[1]))
+
+    x = rand(Float32, 2, 3)
+    h, c = model(x)
+    @test h isa Array{Float32, 2}
+    @test size(h) == (4, 3)
+    @test c isa Array{Float32, 2}
+    @test size(c) == (4, 3)
+    test_gradients(model, x, loss = (m, x) -> mean(m(x)[1]))
+end
diff --git a/test/test_utils.jl b/test/test_utils.jl
index e9a2b0c04f..da55ebca03 100644
--- a/test/test_utils.jl
+++ b/test/test_utils.jl
@@ -47,6 +47,9 @@ function test_gradients(
       or CPU AD vs GPU AD.")
   end

+  ## Let's make sure first that the forward pass works.
+  @test loss(f, xs...) isa Number
+
   if test_grad_x
     # Zygote gradient with respect to input.
     y, g = Zygote.withgradient((xs...) -> loss(f, xs...), xs...)
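Editor's note, appended after the patch rather than woven into it: the GRUCell hunk drops the `state0` field and removes `T` from the forward method's signature, but two unchanged context lines look out of step with that change. The constructor still builds a fourth `init_state(out, 1)` argument for a struct that now has three fields, and the method body still calls `_match_eltype(m, T, x)` with `T` no longer bound. If the intent is the same stateless-cell design used for `LSTMCell`, a reconciliation might look roughly like the sketch below. This is an assumption on the editor's part, not something the patch contains, and the remainder of the method body is left as it already stands in the file:

```julia
# Hypothetical follow-up for GRUCell, not author-confirmed.
GRUCell((in, out)::Pair; init = glorot_uniform, initb = zeros32) =
  GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3))

function (m::GRUCell{I,H,V})(h, x::AbstractVecOrMat) where {I,H,V}
  _size_check(m, x, 1 => size(m.Wi, 2))
  Wi, Wh, b, o = m.Wi, m.Wh, m.b, size(h, 1)
  xT = _match_eltype(m, eltype(m.Wi), x)   # avoids the now-unbound `T`
  # ... gate computations continue unchanged from the existing method ...
end
```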