From be799c6d5bf7cb1371c2564468ede03c8cc92cc4 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Tue, 15 Oct 2024 09:45:59 +0200
Subject: [PATCH] gru

---
 src/Flux.jl             |   2 +-
 src/layers/recurrent.jl | 197 ++++++++++++++++++++++++++++------------
 2 files changed, 139 insertions(+), 60 deletions(-)

diff --git a/src/Flux.jl b/src/Flux.jl
index c054997e7e..e0d01639ca 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -35,7 +35,7 @@ Optimisers.base(dx::Zygote.Grads) = error("Optimisers.jl cannot be used with Zyg
 export Chain, Dense, Embedding, EmbeddingBag,
        Maxout, SkipConnection, Parallel, PairwiseFusion,
-       RNNCell, LSTMCell,
+       RNNCell, LSTMCell, GRUCell, GRUv3Cell,
        RNN, LSTM, GRU, GRUv3,
        SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
        AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index d6ab4ab9df..37cef2fb2c 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -447,48 +447,111 @@ end
 Base.show(io::IO, l::GRUCell) =
   print(io, "GRUCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷3, ")")
 
-"""
-    GRU(in => out)
+@doc raw"""
+    GRU(in => out; init = glorot_uniform, bias = true)
 
 [Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v1) layer. Behaves like an
 RNN but generally exhibits a longer memory span over sequences. This implements
 the variant proposed in v1 of the referenced paper.
 
-The integer arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
+The forward pass computes
 
-This constructor is syntactic sugar for `Recur(GRUCell(a...))`, and so GRUs are stateful. Note that the state shape can change depending on the inputs, and so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below.
+```math
+r_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i)
+z_t = \sigma(W_{xz} x_t + W_{hz} h_{t-1} + b_z)
+h̃_t = \tanh(W_{xh} x_t + r_t \odot W_{hh} h_{t-1} + b_h)
+h_t = (1 - z_t) \odot h̃_t + z_t \odot h_{t-1}
+```
+for all `len` steps `t` in the input sequence.
+See [`GRUCell`](@ref) for a layer that processes a single time step.
 
-See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
-for a good overview of the internals.
+# Forward
+
+    gru(x, h)
+    gru(x)
+
+The arguments of the forward pass are:
+
+- `x`: The input to the GRU. It should be a matrix of size `in x len` or an array of size `in x len x batch_size`.
+- `h`: The initial hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
+  If not provided, it is assumed to be a vector of zeros.
+
+Returns all new hidden states `h_t` as an array of size `out x len` or `out x len x batch_size`.
 
 # Examples
+
 ```jldoctest
-julia> g = GRU(3 => 5)
-Recur(
-  GRUCell(3 => 5),  # 140 parameters
-) # Total: 4 trainable arrays, 140 parameters,
-  # plus 1 non-trainable, 5 parameters, summarysize 784 bytes.
+d_in, d_out, len, batch_size = 2, 3, 4, 5
+gru = GRU(d_in => d_out)
+x = rand(Float32, (d_in, len, batch_size))
+h0 = zeros(Float32, d_out)
+h = gru(x, h0)  # out x len x batch_size
+```
+"""
+struct GRU{M}
+  cell::M
+end
+
+@layer :expand GRU
+
+function GRU((in, out)::Pair; init = glorot_uniform, bias = true)
+  cell = GRUCell(in => out; init, bias)
+  return GRU(cell)
+end
 
-julia> g(rand(Float32, 3)) |> size
+function (m::GRU)(x)
+  h = zeros_like(x, size(m.cell.Wh, 2))
+  return m(x, h)
+end
 
-(5,)
-julia> Flux.reset!(g);
+function (m::GRU)(x, h)
+  @assert ndims(x) == 2 || ndims(x) == 3
+  h′ = []
+  for x_t in eachslice(x, dims=2)
+    h = m.cell(x_t, h)
+    h′ = vcat(h′, [h])
+  end
+  return stack(h′, dims=2)
+end
 
-julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
-(5, 10)
+# GRU v3
+@doc raw"""
+    GRUv3Cell(in => out; init = glorot_uniform, bias = true)
+
+[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v3) layer.
+Behaves like an RNN but generally exhibits a longer memory span over sequences.
+This implements the variant proposed in v3 of the referenced paper.
+
+The forward pass computes
+```math
+r = \sigma(W_{xi} x + W_{hi} h + b_i)
+z = \sigma(W_{xz} x + W_{hz} h + b_z)
+h̃ = \tanh(W_{xh} x + W_{hh̃} (r \odot h) + b_h)
+h' = (1 - z) \odot h̃ + z \odot h
 ```
+and returns `h'`. This is a single time step of the GRU.
 
-!!! warning "Batch size changes"
-    Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the example in [`RNN`](@ref).
+See [`GRUv3`](@ref) for a layer that processes entire sequences.
+See [`GRU`](@ref) and [`GRUCell`](@ref) for variants of this layer.
 
-# Note:
-  `GRUCell`s can be constructed directly by specifying the non-linear function, the `Wi` and `Wh` internal matrices, a bias vector `b`, and a learnable initial state `state0`. The `Wi` and `Wh` matrices do not need to be the same type. See the example in [`RNN`](@ref).
-"""
-GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
-Recur(m::GRUCell) = Recur(m, m.state0)
+# Arguments
 
-# GRU v3
+- `in => out`: The input and output dimensions of the layer.
+- `init`: The initialization function to use for the weights. Default is `glorot_uniform`.
+- `bias`: Whether to include a bias term initialized to zero. Default is `true`.
+
+# Forward
+
+    gruv3cell(x, h)
+    gruv3cell(x)
+
+The arguments of the forward pass are:
+- `x`: The input to the GRU. It should be a vector of size `in` or a matrix of size `in x batch_size`.
+- `h`: The hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
+  If not provided, it is assumed to be a vector of zeros.
+
+Returns the new hidden state `h'` as an array of size `out` or `out x batch_size`.
+""" struct GRUv3Cell{I,H,V,HH} Wi::I Wh::H @@ -496,64 +559,80 @@ struct GRUv3Cell{I,H,V,HH} Wh_h̃::HH end -GRUv3Cell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state = zeros32) = - GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3), - init(out, out), init_state(out,1)) +@layer GRUv3Cell + +function GRUv3Cell((in, out)::Pair; init = glorot_uniform, bias = true) + Wi = init(out * 3, in) + Wh = init(out * 3, out) + Wh_h̃ = init(out, out) + b = create_bias(Wi, bias, out * 3) + return GRUv3Cell(Wi, Wh, b, Wh_h̃) +end + +(m::GRUv3Cell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2))) function (m::GRUv3Cell)(x::AbstractVecOrMat, h) _size_check(m, x, 1 => size(m.Wi,2)) - Wi, Wh, b, Wh_h̃, o = m.Wi, m.Wh, m.b, m.Wh_h̃, size(h, 1) - xT = _match_eltype(m, T, x) - gxs = chunk(Wi * xT, 3, dims=1) + Wi, Wh, b, Wh_h̃ = m.Wi, m.Wh, m.b, m.Wh_h̃ + gxs = chunk(Wi * x, 3, dims=1) ghs = chunk(Wh * h, 2, dims=1) bs = chunk(b, 3, dims=1) - r, z = _gru_output(gxs, ghs, bs) + r = @. sigmoid_fast(gxs[1] + ghs[1] + bs[1]) + z = @. sigmoid_fast(gxs[2] + ghs[2] + bs[2]) h̃ = tanh_fast.(gxs[3] .+ (Wh_h̃ * (r .* h)) .+ bs[3]) h′ = @. (1 - z) * h̃ + z * h - return h′, reshape_cell_output(h′, x) + return h′ end -@layer GRUv3Cell - Base.show(io::IO, l::GRUv3Cell) = print(io, "GRUv3Cell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷3, ")") -""" + +@doc raw""" GRUv3(in => out) [Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v3) layer. Behaves like an RNN but generally exhibits a longer memory span over sequences. This implements the variant proposed in v3 of the referenced paper. -The arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`. - -This constructor is syntactic sugar for `Recur(GRUv3Cell(a...))`, and so GRUv3s are stateful. Note that the state shape can change depending on the inputs, and so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below. +The forward pass computes -See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) -for a good overview of the internals. +```math +r_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i) +z_t = \sigma(W_{xz} x_t + W_{hz} h_{t-1} + b_z) +h̃_t = \tanh(W_{xh} x_t + W_{hh̃} (r_t \odot W_{hh} h_{t-1}) + b_h) +h_t = (1 - z_t) \odot h̃_t + z_t \odot h_{t-1} +``` +for all `len` steps `t` in the input sequence. +See [`GRUv3Cell`](@ref) for a layer that processes a single time step. +See [`GRU`](@ref) and [`GRUCell`](@ref) for variants of this layer. # Examples -```jldoctest -julia> g = GRUv3(3 => 5) -Recur( - GRUv3Cell(3 => 5), # 140 parameters -) # Total: 5 trainable arrays, 140 parameters, - # plus 1 non-trainable, 5 parameters, summarysize 840 bytes. +TODO +""" +struct GRUv3{M} + cell::M +end -julia> g(rand(Float32, 3)) |> size -(5,) +@layer :expand GRUv3 -julia> Flux.reset!(g); +function GRUv3((in, out)::Pair; init = glorot_uniform, bias = true) + cell = GRUv3Cell(in => out; init, bias) + return GRUv3(cell) +end -julia> g(rand(Float32, 3, 10)) |> size # batch size of 10 -(5, 10) -``` +function (m::GRUv3)(x) + h = zeros_like(x, size(m.cell.Wh, 2)) + return m(x, h) +end -!!! warning "Batch size changes" - Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the example in [`RNN`](@ref). 
+function (m::GRUv3)(x, h)
+  @assert ndims(x) == 2 || ndims(x) == 3
+  h′ = []
+  for x_t in eachslice(x, dims=2)
+    h = m.cell(x_t, h)
+    h′ = vcat(h′, [h])
+  end
+  return stack(h′, dims=2)
+end
 
-# Note:
-  `GRUv3Cell`s can be constructed directly by specifying the non-linear function, the `Wi`, `Wh`, and `Wh_h` internal matrices, a bias vector `b`, and a learnable initial state `state0`. The `Wi`, `Wh`, and `Wh_h` matrices do not need to be the same type. See the example in [`RNN`](@ref).
-"""
-GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...))
-Recur(m::GRUv3Cell) = Recur(m, m.state0)
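
A usage sketch for the `GRU` layer introduced above, assuming this patch is applied to Flux; the names `d_in`, `d_out`, `len`, and `batch_size` are illustrative and mirror the docstring example. This is not part of the diff itself.

```julia
using Flux  # assumes the GRU layer from this patch is available

d_in, d_out, len, batch_size = 2, 3, 4, 5
gru = GRU(d_in => d_out)

# Batched input of size in x len x batch_size gives output of size out x len x batch_size.
x = rand(Float32, d_in, len, batch_size)
h0 = zeros(Float32, d_out)   # optional initial hidden state
h = gru(x, h0)
@assert size(h) == (d_out, len, batch_size)

# Unbatched input of size in x len gives output of size out x len, starting from a zero state.
xs = rand(Float32, d_in, len)
@assert size(gru(xs)) == (d_out, len)
```

Passing `h0` filled with zeros is equivalent to calling `gru(x)`, since the single-argument method supplies a zero state of length `out`.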
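A companion sketch for the v3 variant, showing how `GRUv3` relates to the single-step `GRUv3Cell` it wraps. Again this is only an illustration under the assumption that the patch (with `Wh` sized `out * 2`) is applied, not part of the diff.

```julia
using Flux  # assumes the GRUv3 / GRUv3Cell layers from this patch are available

d_in, d_out, len, batch_size = 2, 3, 4, 5
x = rand(Float32, d_in, len, batch_size)

# Whole-sequence layer: starts from a zero hidden state by default.
gruv3 = GRUv3(d_in => d_out)
h = gruv3(x)
@assert size(h) == (d_out, len, batch_size)

# Driving the wrapped cell one time step at a time gives the same result,
# since GRUv3 simply folds its cell over eachslice(x, dims=2).
cell = gruv3.cell
h_t = zeros(Float32, d_out, batch_size)
steps = map(1:len) do t
  h_t = cell(x[:, t, :], h_t)   # one step: (d_out, batch_size)
end
@assert stack(steps, dims=2) ≈ h
```

A cell constructed separately with `GRUv3Cell(d_in => d_out)` produces outputs of the same shape but different values, since its parameters are initialized independently of `gruv3.cell`.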