LSTM
CarloLucibello committed Oct 15, 2024
1 parent bac60d2 commit 2386401
Showing 5 changed files with 187 additions and 28 deletions.
145 changes: 119 additions & 26 deletions src/layers/recurrent.jl
@@ -33,7 +33,7 @@ h^\prime = \sigma(W_i x + W_h h + b)
```
and returns `h'`.
See also [`RNN`](@ref).
See [`RNN`](@ref) for a layer that processes entire sequences.
# Arguments
@@ -121,9 +121,8 @@ In the forward pass computes
h_t = \sigma(W_i x_t + W_h h_{t-1} + b)
```
for all `len` steps `t` in the input sequence.
Returns all hidden states `h_t` in a tensor of size `(out, len, batch_size)`.
See also [`RNNCell`](@ref).
See [`RNNCell`](@ref) for a layer that processes a single time step.
# Arguments
@@ -139,7 +138,9 @@ See also [`RNNCell`](@ref).
The arguments of the forward pass are:
- `x`: The input to the RNN. It should be a matrix size `in x len` or a tensor of size `in x len x batch_size`.
- `h`: The hidden state of the RNN. It should be a vector of size `out` or a matrix of size `out x batch_size`.
- `h`: The initial hidden state of the RNN. It should be a vector of size `out` or a matrix of size `out x batch_size`.
Returns all new hidden states `h_t` in a tensor of size `(out, len, batch_size)`.
# Examples
@@ -200,16 +201,69 @@ end


# LSTM
@doc raw"""
LSTMCell(in => out; init = glorot_uniform, bias = true)
The [Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory) cell.
Behaves like an RNN but generally exhibits a longer memory span over sequences.
In the forward pass, computes
```math
i_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i)
f_t = \sigma(W_{xf} x_t + W_{hf} h_{t-1} + b_f)
c_t = f_t \odot c_{t-1} + i_t \odot \tanh(W_{xc} x_t + W_{hc} h_{t-1} + b_c)
o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o)
h_t = o_t \odot \tanh(c_t)
```
The `LSTMCell` returns the new hidden state `h_t` and cell state `c_t` for a single time step.
See also [`LSTM`](@ref) for a layer that processes entire sequences.
# Arguments
- `in => out`: The input and output dimensions of the layer.
- `init`: The initialization function to use for the weights. Default is `glorot_uniform`.
- `bias`: Whether to include a bias term initialized to zero. Default is `true`.
# Forward
lstmcell(x, (h, c)) or lstmcell(x)
The arguments of the forward pass are:
- `x`: The input to the LSTM. It should be a vector of size `in` or a matrix of size `in x batch_size`.
- `(h, c)`: A tuple containing the hidden and cell states of the LSTM.
They should be vectors of size `out` or matrices of size `out x batch_size`.
If not provided, they are assumed to be vectors of zeros.
Returns a tuple `(h′, c′)` containing the new hidden state and cell state, as vectors of size `out` or matrices of size `out x batch_size`.
# Examples
```jldoctest
julia> l = LSTMCell(3 => 5)
LSTMCell(3 => 5) # 180 parameters
julia> h = zeros(Float32, 5); # hidden state
julia> c = zeros(Float32, 5); # cell state
julia> x = rand(Float32, 3, 4); # in x batch_size
julia> h′, c′ = l(x, (h, c));
julia> size(h′) # out x batch_size
(5, 4)
"""
struct LSTMCell{I,H,V}
Wi::I
Wh::H
bias::V
end

function LSTMCell((in, out)::Pair;
init = glorot_uniform,
bias = true)
@layer LSTMCell

function LSTMCell((in, out)::Pair; init = glorot_uniform, bias = true)
Wi = init(out * 4, in)
Wh = init(out * 4, out)
b = create_bias(Wi, bias, size(Wi, 1))
@@ -218,35 +272,61 @@ function LSTMCell((in, out)::Pair;
return cell
end
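
A quick check, not part of the diff, that the fused-gate layout above yields the parameter count quoted in the docstring example (`LSTMCell(3 => 5)  # 180 parameters`):

```julia
# Sketch: parameter count of LSTMCell(in => out) with four fused gates.
# Wi is (4*out, in), Wh is (4*out, out), and the bias has 4*out entries.
in_dim, out_dim = 3, 5
nparams = 4 * out_dim * in_dim + 4 * out_dim * out_dim + 4 * out_dim
@assert nparams == 180   # matches the "180 parameters" shown above
```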


function (m::LSTMCell)(x::AbstractVecOrMat, (h, c))
_size_check(m, x, 1 => size(m.Wi, 2))
b, o = m.bias, size(h, 1)
g = m.Wi * x .+ m.Wh*h .+ b
g = m.Wi * x .+ m.Wh * h .+ b
input, forget, cell, output = multigate(g, o, Val(4))
c′ = @. sigmoid_fast(forget) * c + sigmoid_fast(input) * tanh_fast(cell)
h′ = @. sigmoid_fast(output) * tanh_fast(c′)
return h′, c′
end
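
As a sketch (not part of the diff), the step above can be reproduced by hand. It assumes the `input, forget, cell, output` chunk order used by `multigate` and the `Wi`/`Wh`/`bias` field names of the struct; `sigmoid`/`tanh` stand in for the `sigmoid_fast`/`tanh_fast` kernels, so the match is only up to floating-point tolerance:

```julia
# Sketch: recompute one LSTMCell step from the fused gate pre-activations.
using Flux

cell = LSTMCell(3 => 5)
x = rand(Float32, 3)
h, c = zeros(Float32, 5), zeros(Float32, 5)

g = cell.Wi * x .+ cell.Wh * h .+ cell.bias          # 4*out pre-activations
i, f, g̃, o = g[1:5], g[6:10], g[11:15], g[16:20]     # input, forget, cell, output chunks
c_manual = @. sigmoid(f) * c + sigmoid(i) * tanh(g̃)
h_manual = @. sigmoid(o) * tanh(c_manual)

h′, c′ = cell(x, (h, c))
h_manual ≈ h′ && c_manual ≈ c′                       # expected to be true
```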

@layer LSTMCell

Base.show(io::IO, l::LSTMCell) =
print(io, "LSTMCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷4, ")")

"""
LSTM(in => out)

@doc raw"""
LSTM(in => out; init = glorot_uniform, bias = true)
[Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory)
recurrent layer. Behaves like an RNN but generally exhibits a longer memory span over sequences.
The arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as an `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
This constructor is syntactic sugar for `Recur(LSTMCell(a...))`, and so LSTMs are stateful. Note that the state shape can change depending on the inputs, and so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below.
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
In the forward pass, computes
```math
i_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i)
f_t = \sigma(W_{xf} x_t + W_{hf} h_{t-1} + b_f)
c_t = f_t \odot c_{t-1} + i_t \odot \tanh(W_{xc} x_t + W_{hc} h_{t-1} + b_c)
o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o)
h_t = o_t \odot \tanh(c_t)
```
See also [`LSTMCell`](@ref).
# Arguments
- `in => out`: The input and output dimensions of the layer.
- `init`: The initialization function to use for the weights. Default is `glorot_uniform`.
- `bias`: Whether to include a bias term initialized to zero. Default is `true`.
# Forward
lstm(x, (h, c))
The arguments of the forward pass are:
- `x`: The input to the LSTM. It should be a matrix of size `in x len` or a tensor of size `in x len x batch_size`.
- `(h, c)`: A tuple containing the hidden and cell states of the LSTM. They should be vectors of size `out` or matrices of size `out x batch_size`.
Returns a tuple `(h′, c′)` containing all new hidden states `h_t` and cell states `c_t`
in tensors of size `out x len` or `out x len x batch_size`.
# Examples
```jldoctest
julia> l = LSTM(3 => 5)
Recur(
@@ -262,15 +342,29 @@ julia> Flux.reset!(l);
julia> l(rand(Float32, 3, 10)) |> size # batch size of 10
(5, 10)
```
"""
struct LSTM{M}
cell::M
end

!!! warning "Batch size changes"
Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the example in [`RNN`](@ref).
@layer :expand LSTM

# Note:
`LSTMCell`s can be constructed directly by specifying the non-linear function, the `Wi` and `Wh` internal matrices, a bias vector `b`, and a learnable initial state `state0`. The `Wi` and `Wh` matrices do not need to be the same type. See the example in [`RNN`](@ref).
"""
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
Recur(m::LSTMCell) = Recur(m, m.state0)
function LSTM((in, out)::Pair; init = glorot_uniform, bias = true)
cell = LSTMCell(in => out; init, bias)
return LSTM(cell)
end

function (m::LSTM)(x, (h, c))
@assert ndims(x) == 2 || ndims(x) == 3
h′ = []
c′ = []
for x_t in eachslice(x, dims=2)
h, c = m.cell(x_t, (h, c))
h′ = vcat(h′, [h])
c′ = vcat(c′, [c])
end
return stack(h′, dims=2), stack(c′, dims=2)
end
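
A usage sketch of the new sequence-level forward pass (not part of the diff), assuming the `in x len x batch_size` input layout described in the docstring:

```julia
# Sketch: run the LSTM layer over a whole sequence at once.
using Flux

lstm = LSTM(2 => 4)
x  = rand(Float32, 2, 6, 3)      # in x len x batch_size
h0 = zeros(Float32, 4, 3)        # out x batch_size
c0 = zeros(Float32, 4, 3)

h, c = lstm(x, (h0, c0))
size(h), size(c)                 # ((4, 6, 3), (4, 6, 3))
```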

# GRU

@@ -280,17 +374,16 @@ function _gru_output(gxs, ghs, bs)
return r, z
end

struct GRUCell{I,H,V,S}
struct GRUCell{I,H,V}
Wi::I
Wh::H
b::V
state0::S
end

GRUCell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out,1))

function (m::GRUCell{I,H,V,<:AbstractMatrix{T}})(h, x::AbstractVecOrMat) where {I,H,V,T}
function (m::GRUCell{I,H,V})(h, x::AbstractVecOrMat) where {I,H,V}
_size_check(m, x, 1 => size(m.Wi,2))
Wi, Wh, b, o = m.Wi, m.Wh, m.b, size(h, 1)
xT = _match_eltype(m, T, x)
File renamed without changes.
4 changes: 2 additions & 2 deletions test/ext_cuda/runtests.jl
@@ -22,8 +22,8 @@ end
@testset "cudnn" begin
include("cudnn.jl")
end
@testset "curnn" begin
include("curnn.jl")
@testset "recurrent" begin
include("recurrent.jl")
end
@testset "ctc" begin
include("ctc.jl")
63 changes: 63 additions & 0 deletions test/layers/recurrent.jl
@@ -82,3 +82,66 @@ end
@test size(y) == (4, 3)
test_gradients(model, x)
end

@testset "LSTMCell" begin

function loss(r, x, hc)
h, c = hc
h′ = []
c′ = []
for x_t in x
h, c = r(x_t, (h, c))
h′ = vcat(h′, [h])
c′ = [c′..., c]
end
hnew = stack(h′, dims=2)
cnew = stack(c′, dims=2)
return mean(hnew.^2) + mean(cnew.^2)
end

cell = LSTMCell(3 => 5)
@test length(Flux.trainables(cell)) == 3
x = [rand(Float32, 3, 4) for _ in 1:6]
h = zeros(Float32, 5, 4)
c = zeros(Float32, 5, 4)
hnew, cnew = cell(x[1], (h, c))
@test hnew isa Matrix{Float32}
@test cnew isa Matrix{Float32}
@test size(hnew) == (5, 4)
@test size(cnew) == (5, 4)
test_gradients(cell, x[1], (h, c), loss = (m, x, hc) -> mean(m(x, hc)[1]))
test_gradients(cell, x, (h, c), loss = loss)

cell = LSTMCell(3 => 5, bias=false)
@test length(Flux.trainables(cell)) == 2
end

@testset "LSTM" begin
struct ModelLSTM
lstm::LSTM
h0::AbstractVector
c0::AbstractVector
end

Flux.@layer :expand ModelLSTM

(m::ModelLSTM)(x) = m.lstm(x, (m.h0, m.c0))

model = ModelLSTM(LSTM(2 => 4), zeros(Float32, 4), zeros(Float32, 4))

x = rand(Float32, 2, 3, 1)
h, c = model(x)
@test h isa Array{Float32, 3}
@test size(h) == (4, 3, 1)
@test c isa Array{Float32, 3}
@test size(c) == (4, 3, 1)
test_gradients(model, x, loss = (m, x) -> mean(m(x)[1]))

x = rand(Float32, 2, 3)
h, c = model(x)
@test h isa Array{Float32, 2}
@test size(h) == (4, 3)
@test c isa Array{Float32, 2}
@test size(c) == (4, 3)
test_gradients(model, x, loss = (m, x) -> mean(m(x)[1]))
end
3 changes: 3 additions & 0 deletions test/test_utils.jl
@@ -47,6 +47,9 @@ function test_gradients(
or CPU AD vs GPU AD.")
end

## Let's make sure first that the forward pass works.
@test loss(f, xs...) isa Number

if test_grad_x
# Zygote gradient with respect to input.
y, g = Zygote.withgradient((xs...) -> loss(f, xs...), xs...)
