Commit be799c6

gru

CarloLucibello committed Oct 15, 2024
1 parent 6458113 commit be799c6

Showing 2 changed files with 139 additions and 60 deletions.
2 changes: 1 addition & 1 deletion src/Flux.jl
@@ -35,7 +35,7 @@ Optimisers.base(dx::Zygote.Grads) = error("Optimisers.jl cannot be used with Zyg

export Chain, Dense, Embedding, EmbeddingBag,
Maxout, SkipConnection, Parallel, PairwiseFusion,
-    RNNCell, LSTMCell,
+    RNNCell, LSTMCell, GRUCell, GRUv3Cell,
RNN, LSTM, GRU, GRUv3,
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
197 changes: 138 additions & 59 deletions src/layers/recurrent.jl
@@ -447,113 +447,192 @@ end
Base.show(io::IO, l::GRUCell) =
print(io, "GRUCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷3, ")")

"""
GRU(in => out,
@doc raw"""
GRU(in => out; init = glorot_uniform, bias = true)
[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v1) layer. Behaves like an
RNN but generally exhibits a longer memory span over sequences. This implements
the variant proposed in v1 of the referenced paper.
-The integer arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as an `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
-This constructor is syntactic sugar for `Recur(GRUCell(a...))`, and so GRUs are stateful. Note that the state shape can change depending on the inputs, so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below.
+The forward pass computes
+```math
+r_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i)
+z_t = \sigma(W_{xz} x_t + W_{hz} h_{t-1} + b_z)
+h̃_t = \tanh(W_{xh} x_t + r_t \odot W_{hh} h_{t-1} + b_h)
+h_t = (1 - z_t) \odot h̃_t + z_t \odot h_{t-1}
+```
+for all `len` steps `t` in the input sequence.
+See [`GRUCell`](@ref) for a layer that processes a single time step.
-See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
-for a good overview of the internals.
+
+# Forward
+
+    gru(x, h)
+    gru(x)
+
+The arguments of the forward pass are:
+- `x`: The input to the GRU. It should be a matrix of size `in x len` or an array of size `in x len x batch_size`.
+- `h`: The initial hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
+  If not provided, it is assumed to be a vector of zeros.
+
+Returns all new hidden states `h_t` as an array of size `out x len x batch_size`.
# Examples
```jldoctest
-julia> g = GRU(3 => 5)
-Recur(
-  GRUCell(3 => 5),  # 140 parameters
-)  # Total: 4 trainable arrays, 140 parameters,
-   # plus 1 non-trainable, 5 parameters, summarysize 784 bytes.
+d_in, d_out, len, batch_size = 2, 3, 4, 5
+gru = GRU(d_in => d_out)
+x = rand(Float32, (d_in, len, batch_size))
+h0 = zeros(Float32, d_out)
+h = gru(x, h0)  # out x len x batch_size
```
"""
+struct GRU{M}
+  cell::M
+end
+
+@layer :expand GRU
+
+function GRU((in, out)::Pair; init = glorot_uniform, bias = true)
+  cell = GRUCell(in => out; init, bias)
+  return GRU(cell)
+end

-julia> g(rand(Float32, 3)) |> size
-(5,)
+function (m::GRU)(x)
+  h = zeros_like(x, size(m.cell.Wh, 2))
+  return m(x, h)
+end

-julia> Flux.reset!(g);
+function (m::GRU)(x, h)
+  @assert ndims(x) == 2 || ndims(x) == 3
+  h′ = []
+  for x_t in eachslice(x, dims=2)
+    h = m.cell(x_t, h)
+    h′ = vcat(h′, [h])
+  end
+  return stack(h′, dims=2)
+end

-julia> g(rand(Float32, 3, 10)) |> size  # batch size of 10
-(5, 10)
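The new docstring example runs as ordinary code rather than a `julia>` session. A hedged usage sketch, assuming a Flux build that contains this commit:

```julia
using Flux  # assumes a Flux build containing this commit

d_in, d_out, len, batch_size = 2, 3, 4, 5
gru = GRU(d_in => d_out)
x  = rand(Float32, d_in, len, batch_size)   # in × len × batch_size
h0 = zeros(Float32, d_out)                  # initial hidden state, shared across the batch
h  = gru(x, h0)
size(h)  # (d_out, len, batch_size): one hidden state per time step
```

Unlike the old `Recur`-based layer, the initial state is passed explicitly (or defaults to zeros), so there is no `reset!` to call when the batch size changes.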
+# GRU v3
+@doc raw"""
+    GRUv3Cell(in => out; init = glorot_uniform, bias = true)
+
+[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v3) layer.
+Behaves like an RNN but generally exhibits a longer memory span over sequences.
+This implements the variant proposed in v3 of the referenced paper.
+
+The forward pass computes
+```math
+r = \sigma(W_{xi} x + W_{hi} h + b_i)
+z = \sigma(W_{xz} x + W_{hz} h + b_z)
+h̃ = \tanh(W_{xh} x + W_{hh̃} (r \odot h) + b_h)
+h' = (1 - z) \odot h̃ + z \odot h
+```
+and returns `h'`. This is a single time step of the GRU.
-!!! warning "Batch size changes"
-    Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the example in [`RNN`](@ref).
+See [`GRUv3`](@ref) for a layer that processes entire sequences.
+See [`GRU`](@ref) and [`GRUCell`](@ref) for variants of this layer.
-# Note:
-`GRUCell`s can be constructed directly by specifying the non-linear function, the `Wi` and `Wh` internal matrices, a bias vector `b`, and a learnable initial state `state0`. The `Wi` and `Wh` matrices do not need to be the same type. See the example in [`RNN`](@ref).
-"""
-GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
-Recur(m::GRUCell) = Recur(m, m.state0)
-# GRU v3
+# Arguments
+
+- `in => out`: The input and output dimensions of the layer.
+- `init`: The initialization function to use for the weights. Default is `glorot_uniform`.
+- `bias`: Whether to include a bias term initialized to zero. Default is `true`.
+
+# Forward
+
+    gruv3cell(x, h)
+    gruv3cell(x)
+
+The arguments of the forward pass are:
+- `x`: The input to the GRU. It should be a vector of size `in` or a matrix of size `in x batch_size`.
+- `h`: The hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
+  If not provided, it is assumed to be a vector of zeros.
+
+Returns the new hidden state `h'` as an array of size `out` or `out x batch_size`.
+"""
+struct GRUv3Cell{I,H,V,HH}
+  Wi::I
+  Wh::H
+  b::V
+  Wh_h̃::HH
+end

-GRUv3Cell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
-  GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3),
-            init(out, out), init_state(out,1))
+@layer GRUv3Cell

+function GRUv3Cell((in, out)::Pair; init = glorot_uniform, bias = true)
+  Wi = init(out * 3, in)
+  Wh = init(out * 2, out)  # recurrent weights feed only the r and z gates; h̃ uses Wh_h̃
+  Wh_h̃ = init(out, out)
+  b = create_bias(Wi, bias, out * 3)
+  return GRUv3Cell(Wi, Wh, b, Wh_h̃)
+end
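Assuming the constructor above (note the fix: `Wh` must have `out * 2` rows, since `chunk(Wh * h, 2, dims=1)` splits it into exactly two gate blocks), the parameter shapes can be checked directly; a sketch under that assumption:

```julia
cell = GRUv3Cell(2 => 3)
size(cell.Wi)    # (9, 2): input weights for r, z and h̃, stacked
size(cell.Wh)    # (6, 3): recurrent weights for the r and z gates only
size(cell.Wh_h̃)  # (3, 3): dedicated recurrent matrix for the candidate state
length(cell.b)   # 9
```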

+(m::GRUv3Cell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2)))

function (m::GRUv3Cell)(x::AbstractVecOrMat, h)
  _size_check(m, x, 1 => size(m.Wi, 2))
-  Wi, Wh, b, Wh_h̃, o = m.Wi, m.Wh, m.b, m.Wh_h̃, size(h, 1)
-  xT = _match_eltype(m, T, x)
-  gxs = chunk(Wi * xT, 3, dims=1)
+  Wi, Wh, b, Wh_h̃ = m.Wi, m.Wh, m.b, m.Wh_h̃
+  gxs = chunk(Wi * x, 3, dims=1)
  ghs = chunk(Wh * h, 2, dims=1)
  bs = chunk(b, 3, dims=1)
-  r, z = _gru_output(gxs, ghs, bs)
+  r = @. sigmoid_fast(gxs[1] + ghs[1] + bs[1])
+  z = @. sigmoid_fast(gxs[2] + ghs[2] + bs[2])
  h̃ = tanh_fast.(gxs[3] .+ (Wh_h̃ * (r .* h)) .+ bs[3])
  h′ = @. (1 - z) * h̃ + z * h
-  return h′, reshape_cell_output(h′, x)
+  return h′
end
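A quick single-step sketch of the cell above (the dimensions and batch size are arbitrary). Unlike the v1 `GRUCell`, the candidate state multiplies `r .* h` by the dedicated matrix `Wh_h̃`:

```julia
cell = GRUv3Cell(2 => 3)
x  = rand(Float32, 2, 7)   # one time step for a batch of 7
h  = zeros(Float32, 3)     # zero state, broadcast across the batch
h′ = cell(x, h)
size(h′)  # (3, 7)
```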

-@layer GRUv3Cell

Base.show(io::IO, l::GRUv3Cell) =
print(io, "GRUv3Cell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷3, ")")

"""

+@doc raw"""
GRUv3(in => out)
[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v3) layer. Behaves like an
RNN but generally exhibits a longer memory span over sequences. This implements
the variant proposed in v3 of the referenced paper.
-The arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as an `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
-This constructor is syntactic sugar for `Recur(GRUv3Cell(a...))`, and so GRUv3s are stateful. Note that the state shape can change depending on the inputs, so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below.
-See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
-for a good overview of the internals.
+The forward pass computes
+```math
+r_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i)
+z_t = \sigma(W_{xz} x_t + W_{hz} h_{t-1} + b_z)
+h̃_t = \tanh(W_{xh} x_t + W_{hh̃} (r_t \odot h_{t-1}) + b_h)
+h_t = (1 - z_t) \odot h̃_t + z_t \odot h_{t-1}
+```
+for all `len` steps `t` in the input sequence.
+See [`GRUv3Cell`](@ref) for a layer that processes a single time step.
+See [`GRU`](@ref) and [`GRUCell`](@ref) for variants of this layer.
# Examples
-```jldoctest
-julia> g = GRUv3(3 => 5)
-Recur(
-  GRUv3Cell(3 => 5),  # 140 parameters
-)  # Total: 5 trainable arrays, 140 parameters,
-   # plus 1 non-trainable, 5 parameters, summarysize 840 bytes.
+TODO
+"""
+struct GRUv3{M}
+  cell::M
+end

-julia> g(rand(Float32, 3)) |> size
-(5,)
+@layer :expand GRUv3

-julia> Flux.reset!(g);
+function GRUv3((in, out)::Pair; init = glorot_uniform, bias = true)
+  cell = GRUv3Cell(in => out; init, bias)
+  return GRUv3(cell)
+end

-julia> g(rand(Float32, 3, 10)) |> size  # batch size of 10
-(5, 10)
-```
+function (m::GRUv3)(x)
+  h = zeros_like(x, size(m.cell.Wh, 2))
+  return m(x, h)
+end

-!!! warning "Batch size changes"
-    Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the example in [`RNN`](@ref).
+function (m::GRUv3)(x, h)
+  @assert ndims(x) == 2 || ndims(x) == 3
+  h′ = []
+  for x_t in eachslice(x, dims=2)
+    h = m.cell(x_t, h)
+    h′ = vcat(h′, [h])
+  end
+  return stack(h′, dims=2)
+end
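The `# Examples` section of the `GRUv3` docstring is still `TODO`; a hedged sketch of how the sequence layer would be called, mirroring the `GRU` docstring above:

```julia
gruv3 = GRUv3(2 => 3)
x = rand(Float32, 2, 4, 5)   # in × len × batch_size
h = gruv3(x)                 # zero initial state by default
size(h)  # (3, 4, 5): hidden states stacked along the time dimension
```

Accumulating states with `vcat` into a fresh array on each iteration, rather than mutating with `push!`, keeps the loop differentiable under Zygote; `stack` then lays the states out along the time dimension.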

-# Note:
-`GRUv3Cell`s can be constructed directly by specifying the non-linear function, the `Wi`, `Wh`, and `Wh_h̃` internal matrices, a bias vector `b`, and a learnable initial state `state0`. The `Wi`, `Wh`, and `Wh_h̃` matrices do not need to be the same type. See the example in [`RNN`](@ref).
-"""
-GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...))
-Recur(m::GRUv3Cell) = Recur(m, m.state0)
