Skip to content

Commit

Permalink
extended testing
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Oct 17, 2024
1 parent be799c6 commit 8e31c85
Show file tree
Hide file tree
Showing 8 changed files with 324 additions and 38 deletions.
47 changes: 28 additions & 19 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ function (m::RNNCell)(x::AbstractVecOrMat, h::AbstractVecOrMat)
return h
end

function Base.show(io::IO, l::RNNCell)
print(io, "RNNCell(", size(l.Wi, 2), " => ", size(l.Wi, 1))
print(io, ", ", l.σ)
function Base.show(io::IO, m::RNNCell)
print(io, "RNNCell(", size(m.Wi, 2), " => ", size(m.Wi, 1))
print(io, ", ", m.σ)
print(io, ")")
end

Expand Down Expand Up @@ -262,16 +262,16 @@ end

function (m::LSTMCell)(x::AbstractVecOrMat, (h, c))
_size_check(m, x, 1 => size(m.Wi, 2))
b, o = m.bias, size(h, 1)
b = m.bias
g = m.Wi * x .+ m.Wh * h .+ b
input, forget, cell, output = chunk(g, 4; dims=1)
c′ = @. sigmoid_fast(forget) * c + sigmoid_fast(input) * tanh_fast(cell)
h′ = @. sigmoid_fast(output) * tanh_fast(c′)
return h′, c′
end

Base.show(io::IO, l::LSTMCell) =
print(io, "LSTMCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷4, ")")
Base.show(io::IO, m::LSTMCell) =
print(io, "LSTMCell(", size(m.Wi, 2), " => ", size(m.Wi, 1)÷4, ")")


@doc raw""""
Expand Down Expand Up @@ -431,21 +431,26 @@ function GRUCell((in, out)::Pair; init = glorot_uniform, bias = true)
return GRUCell(Wi, Wh, b)
end

(m::GRUCell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2)))

function (m::GRUCell)(x::AbstractVecOrMat, h)
_size_check(m, x, 1 => size(m.Wi,2))
Wi, Wh, b = m.Wi, m.Wh, m.b
gxs = chunk(Wi * x, 3, dims=1)
ghs = chunk(Wh * h, 3, dims=1)
bs = chunk(b, 3, dims=1)
gxs = chunk(m.Wi * x, 3, dims=1)
ghs = chunk(m.Wh * h, 3, dims=1)
if m.b isa AbstractArray
bs = chunk(m.b, 3, dims=1)
else # b == false
bs = [false, false, false]
end
r = @. sigmoid_fast(gxs[1] + ghs[1] + bs[1])
z = @. sigmoid_fast(gxs[2] + ghs[2] + bs[2])
= @. tanh_fast(gxs[3] + r * ghs[3] + bs[3])
h′ = @. (1 - z) *+ z * h
return h′
end

Base.show(io::IO, l::GRUCell) =
print(io, "GRUCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷3, ")")
Base.show(io::IO, m::GRUCell) =
print(io, "GRUCell(", size(m.Wi, 2), " => ", size(m.Wi, 1)÷3, ")")

@doc raw"""
GRU(in => out; init = glorot_uniform, bias = true)
Expand Down Expand Up @@ -507,6 +512,7 @@ end
function (m::GRU)(x, h)
@assert ndims(x) == 2 || ndims(x) == 3
h′ = []
# [x] = [in, L] or [in, L, B]
for x_t in eachslice(x, dims=2)
h = m.cell(x_t, h)
h′ = vcat(h′, [h])
Expand Down Expand Up @@ -573,19 +579,22 @@ end

function (m::GRUv3Cell)(x::AbstractVecOrMat, h)
_size_check(m, x, 1 => size(m.Wi,2))
Wi, Wh, b, Wh_h̃ = m.Wi, m.Wh, m.b, m.Wh_h̃
gxs = chunk(Wi * x, 3, dims=1)
ghs = chunk(Wh * h, 2, dims=1)
bs = chunk(b, 3, dims=1)
gxs = chunk(m.Wi * x, 3, dims=1)
ghs = chunk(m.Wh * h, 3, dims=1)
if m.b isa AbstractArray
bs = chunk(m.b, 3, dims=1)
else # m.b == false
bs = [false, false, false]
end
r = @. sigmoid_fast(gxs[1] + ghs[1] + bs[1])
z = @. sigmoid_fast(gxs[2] + ghs[2] + bs[2])
= tanh_fast.(gxs[3] .+ (Wh_h̃ * (r .* h)) .+ bs[3])
= tanh_fast.(gxs[3] .+ (m.Wh_h̃ * (r .* h)) .+ bs[3])
h′ = @. (1 - z) *+ z * h
return h′
end

Base.show(io::IO, l::GRUv3Cell) =
print(io, "GRUv3Cell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷3, ")")
Base.show(io::IO, m::GRUv3Cell) =
print(io, "GRUv3Cell(", size(m.Wi, 2), " => ", size(m.Wi, 1)÷3, ")")


@doc raw"""
Expand Down
5 changes: 5 additions & 0 deletions test/ext_amdgpu/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,8 @@ end
@testset "Basic" begin
include("basic.jl")
end

@testset "Recurrent" begin
BROKEN_TESTS = []
include("../ext_common/recurrent_gpu_ad.jl")
end
163 changes: 163 additions & 0 deletions test/ext_common/recurrent_gpu_ad.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@

@testset "RNNCell GPU AD" begin
function loss(r, x, h)
y = []
for x_t in x
h = r(x_t, h)
y = vcat(y, [h])
end
# return mean(h)
y = stack(y, dims=2) # [D, L] or [D, L, B]
return mean(y)
end

d_in, d_out, len, batch_size = 2, 3, 4, 5
r = RNNCell(d_in => d_out)
x = [randn(Float32, d_in, batch_size) for _ in 1:len]
h = zeros(Float32, d_out)
# Single Step
@test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :rnncell_single BROKEN_TESTS
# Multiple Steps
@test test_gradients(r, x, h; test_gpu=true, compare_finite_diff=false, loss) broken = :rnncell_multiple BROKEN_TESTS
end

@testset "RNN GPU AD" begin
struct ModelRNN
rnn::RNN
h0::AbstractVector
end

Flux.@layer :expand ModelRNN

(m::ModelRNN)(x) = m.rnn(x, m.h0)

d_in, d_out, len, batch_size = 2, 3, 4, 5
model = ModelRNN(RNN(d_in => d_out), zeros(Float32, d_out))
x_nobatch = randn(Float32, d_in, len)
@test test_gradients(model, x_nobatch; test_gpu=true, compare_finite_diff=false) broken = :rnn_nobatch BROKEN_TESTS
x = randn(Float32, d_in, batch_size)
@test test_gradients(model, x, test_gpu=true, compare_finite_diff=false) broken = :rnn BROKEN_TESTS
end

@testset "LSTMCell" begin

function loss(r, x, hc)
h, c = hc
h′ = []
c′ = []
for x_t in x
h, c = r(x_t, (h, c))
h′ = vcat(h′, [h])
c′ = [c′..., c]
end
hnew = stack(h′, dims=2)
cnew = stack(c′, dims=2)
return mean(hnew) + mean(cnew)
end

d_in, d_out, len, batch_size = 2, 3, 4, 5
cell = LSTMCell(d_in => d_out)
x = [randn(Float32, d_in, batch_size) for _ in 1:len]
h = zeros(Float32, d_out)
c = zeros(Float32, d_out)
# Single Step
@test test_gradients(cell, x[1], (h, c); test_gpu=true, compare_finite_diff=false,
loss = (m, x, (h, c)) -> mean(m(x, (h, c))[1])) broken = :lstmcell_single BROKEN_TESTS
# Multiple Steps
@test test_gradients(cell, x, (h, c); test_gpu=true, compare_finite_diff=false, loss) broken = :lstmcell_multiple BROKEN_TESTS
end

@testset "LSTM" begin
struct ModelLSTM
lstm::LSTM
h0::AbstractVector
c0::AbstractVector
end

Flux.@layer :expand ModelLSTM

(m::ModelLSTM)(x) = m.lstm(x, (m.h0, m.c0))

d_in, d_out, len, batch_size = 2, 3, 4, 5
model = ModelLSTM(LSTM(d_in => d_out), zeros(Float32, d_out), zeros(Float32, d_out))
x_nobatch = randn(Float32, d_in, len)
@test test_gradients(model, x_nobatch; test_gpu=true, compare_finite_diff=false,
loss = (m, x) -> mean(m(x)[1])) broken = :lstm_nobatch BROKEN_TESTS
x = randn(Float32, d_in, len, batch_size)
@test test_gradients(model, x; test_gpu=true, compare_finite_diff=false,
loss = (m, x) -> mean(m(x)[1])) broken = :lstm BROKEN_TESTS
end

@testset "GRUCell" begin
function loss(r, x, h)
y = []
for x_t in x
h = r(x_t, h)
y = vcat(y, [h])
end
y = stack(y, dims=2) # [D, L] or [D, L, B]
return mean(y)
end

d_in, d_out, len, batch_size = 2, 3, 4, 5
r = GRUCell(d_in => d_out)
x = [randn(Float32, d_in, batch_size) for _ in 1:len]
h = zeros(Float32, d_out)
@test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :grucell_single BROKEN_TESTS
@test test_gradients(r, x, h; test_gpu=true, compare_finite_diff=false, loss) broken = :grucell_multiple BROKEN_TESTS
end

@testset "GRU GPU AD" begin
struct ModelGRU
gru::GRU
h0::AbstractVector
end

Flux.@layer :expand ModelGRU

(m::ModelGRU)(x) = m.gru(x, m.h0)

d_in, d_out, len, batch_size = 2, 3, 4, 5
model = ModelGRU(GRU(d_in => d_out), zeros(Float32, d_out))
x_nobatch = randn(Float32, d_in, len)
@test test_gradients(model, x_nobatch; test_gpu=true, compare_finite_diff=false) broken = :gru_nobatch BROKEN_TESTS
x = randn(Float32, d_in, len, batch_size)
@test test_gradients(model, x; test_gpu=true, compare_finite_diff=false) broken = :gru BROKEN_TESTS
end

@testset "GRUv3Cell GPU AD" begin
function loss(r, x, h)
y = []
for x_t in x
h = r(x_t, h)
y = vcat(y, [h])
end
y = stack(y, dims=2) # [D, L] or [D, L, B]
return mean(y)
end

d_in, d_out, len, batch_size = 2, 3, 4, 5
r = GRUv3Cell(d_in => d_out)
x = [randn(Float32, d_in, batch_size) for _ in 1:len]
h = zeros(Float32, d_out)
@test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :gruv3cell_single BROKEN_TESTS
@test test_gradients(r, x, h; test_gpu=true, compare_finite_diff=false, loss) broken = :gruv3cell_multiple BROKEN_TESTS
end

@testset "GRUv3 GPU AD" begin
struct ModelGRUv3
gru::GRUv3
h0::AbstractVector
end

Flux.@layer :expand ModelGRUv3

(m::ModelGRUv3)(x) = m.gru(x, m.h0)

d_in, d_out, len, batch_size = 2, 3, 4, 5
model = ModelGRUv3(GRUv3(d_in => d_out), zeros(Float32, d_out))
x_nobatch = randn(Float32, d_in, len)
@test test_gradients(model, x_nobatch; test_gpu=true, compare_finite_diff=false) broken = :gruv3_nobatch BROKEN_TESTS
x = randn(Float32, d_in, len, batch_size)
@test test_gradients(model, x; test_gpu=true, compare_finite_diff=false) broken = :gruv3 BROKEN_TESTS
end
6 changes: 0 additions & 6 deletions test/ext_cuda/recurrent.jl

This file was deleted.

5 changes: 3 additions & 2 deletions test/ext_cuda/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ end
@testset "cudnn" begin
include("cudnn.jl")
end
@testset "recurrent" begin
include("recurrent.jl")
@testset "Recurrent" begin
BROKEN_TESTS = []
include("../ext_common/recurrent_gpu_ad.jl")
end
@testset "ctc" begin
include("ctc.jl")
Expand Down
5 changes: 5 additions & 0 deletions test/ext_metal/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ end
include("basic.jl")
end

@testset "Recurrent" begin
BROKEN_TESTS = [:lstm, :gru, :gruv3]
include("../ext_common/recurrent_gpu_ad.jl")
end

@testset "Huber Loss test" begin
X = Flux.gpu(Float32[0,1])
Y = Flux.gpu(Float32[1,0])
Expand Down
Loading

0 comments on commit 8e31c85

Please sign in to comment.