Commit

add tests
finish RNNCell

RNN rework

LSTMCell

LSTM

more work

gru

extended testing

reset! deprecation

fix test

unbreak l2 test

fix tests

fixes
CarloLucibello committed Oct 21, 2024
1 parent 834bed3 commit cf56985
Showing 7 changed files with 17 additions and 16 deletions.
2 changes: 0 additions & 2 deletions perf/recurrent.jl
@@ -7,12 +7,10 @@ Flux.@functor RNNWrapper

# Need to specialize for RNNWrapper.
fw(r::RNNWrapper, X::Vector{<:AbstractArray}) = begin
-Flux.reset!(r.rnn)
[r.rnn(x) for x in X]
end

fw(r::RNNWrapper, X) = begin
-Flux.reset!(r.rnn)
r.rnn(X)
end

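The two deleted `Flux.reset!` calls reflect the core of this rework: recurrent layers no longer carry mutable hidden state between calls, so there is nothing to reset before a forward pass. A minimal sketch of the resulting call pattern, assuming the reworked cell API where the caller owns the state (layer and sizes are illustrative):

```julia
using Flux

cell = RNNCell(3 => 5)           # a pure function of (input, state)
h = zeros(Float32, 5)            # the caller owns and threads the state
for t in 1:4
    x = rand(Float32, 3)
    h = cell(x, h)               # returns the new state; nothing is mutated
end
```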
5 changes: 5 additions & 0 deletions src/deprecations.jl
@@ -167,3 +167,8 @@ end
# where `loss_mxy` accepts the model as its first argument.
# """
# ))

+function reset!(x)
+    Base.depwarn("reset!(m) is deprecated. You can remove this call as it is no longer needed.", :reset!)
+    return x
+end
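For downstream code the deprecation is purely advisory: `reset!` now just emits a warning and returns its argument unchanged, so existing call sites keep working until they are deleted. A short sketch (the model is illustrative):

```julia
using Flux

m = RNN(3 => 5)
Flux.reset!(m)   # warns once, returns m unchanged; the call can simply be removed
```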
2 changes: 1 addition & 1 deletion test/ext_amdgpu/runtests.jl
@@ -11,6 +11,6 @@ end
end

@testset "Recurrent" begin
-BROKEN_TESTS = []
+global BROKEN_TESTS = []
include("../ext_common/recurrent_gpu_ad.jl")
end
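The added `global` matters because `@testset` introduces a local scope: without it, `BROKEN_TESTS` is a local of the testset body, while `include` evaluates the included file at the module's top level, where only a global binding is visible. A standalone sketch of the pattern (file name hypothetical):

```julia
using Test

@testset "scope demo" begin
    global BROKEN_TESTS = []          # binds at module level
    include("uses_broken_tests.jl")   # top-level code in this file can now
                                      # read BROKEN_TESTS without an UndefVarError
end
```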
2 changes: 1 addition & 1 deletion test/ext_cuda/runtests.jl
@@ -23,7 +23,7 @@ end
include("cudnn.jl")
end
@testset "Recurrent" begin
-BROKEN_TESTS = []
+global BROKEN_TESTS = []
include("../ext_common/recurrent_gpu_ad.jl")
end
@testset "ctc" begin
2 changes: 0 additions & 2 deletions test/ext_enzyme/enzyme.jl
@@ -93,7 +93,6 @@ end

@testset "Models" begin
function loss(model, x)
-Flux.reset!(model)
sum(model(x))
end

@@ -126,7 +125,6 @@ end

@testset "Recurrence Tests" begin
function loss(model, x)
-Flux.reset!(model)
for i in 1:3
x = model(x)
end
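With the stateless design, both test losses above are pure functions of `(model, x)`, which is exactly what the AD backends under test want to differentiate. A hedged sketch of the recurrence-style loss with a stand-in model (Zygote shown; the Enzyme path is analogous):

```julia
using Flux

function loss(model, x)
    for _ in 1:3
        x = model(x)    # feed the output back in, as in the test above
    end
    return sum(x)
end

model = Dense(4 => 4, tanh)                            # stand-in for the tested models
grad = Flux.gradient(loss, model, rand(Float32, 4, 2))[1]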
2 changes: 1 addition & 1 deletion test/ext_metal/runtests.jl
@@ -33,7 +33,7 @@ end
end

@testset "Recurrent" begin
-BROKEN_TESTS = [:lstm, :gru, :gruv3]
+global BROKEN_TESTS = [:lstm, :gru, :gruv3]
include("../ext_common/recurrent_gpu_ad.jl")
end

18 changes: 9 additions & 9 deletions test/utils.jl
@@ -251,19 +251,19 @@ end
end

@testset "Params" begin
-m = Dense(10, 5)
+m = Dense(10 => 5)
@test size.(params(m)) == [(5, 10), (5,)]
-m = RNN(10, 5)
-@test size.(params(m)) == [(5, 10), (5, 5), (5,), (5, 1)]
+m = RNN(10 => 5)
+@test size.(params(m)) == [(5, 10), (5, 5), (5,)]

# Layer duplicated in same chain, params just once pls.
c = Chain(m, m)
-@test size.(params(c)) == [(5, 10), (5, 5), (5,), (5, 1)]
+@test size.(params(c)) == [(5, 10), (5, 5), (5,)]

# Self-referential array. Just want params, no stack overflow pls.
r = Any[nothing,m]
r[1] = r
-@test size.(params(r)) == [(5, 10), (5, 5), (5,), (5, 1)]
+@test size.(params(r)) == [(5, 10), (5, 5), (5,)]

# Ensure functor explores inside Transpose but not SubArray
m = (x = view([1,2,3]pi, 1:2), y = transpose([4 5]pi))
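The dropped `(5, 1)` entry was the trainable initial hidden state that recurrent layers used to carry; after the rework the caller supplies the state, so `RNN(10 => 5)` holds only the input weight (5×10), the recurrent weight (5×5), and the bias (5,). A sketch of how the state enters now; the `(features, seq_len, batch)` layout is an assumption about the reworked API:

```julia
using Flux

rnn = RNN(10 => 5)
x = rand(Float32, 10, 20, 16)   # assumed (features, seq_len, batch) layout
h0 = zeros(Float32, 5, 16)      # initial state is supplied by the caller
y = rnn(x, h0)                  # stacked hidden states, one per time step
```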
@@ -273,7 +273,7 @@ end
@testset "params gradient" begin
m = (x=[1,2.0], y=[3.0]);

-# Explicit -- was broken by #2054 / then fixed / now broken again on julia v1.11
+# Explicit -- was broken by #2054
gnew = gradient(m -> (sum(norm, Flux.params(m))), m)[1]
@test gnew.x [0.4472135954999579, 0.8944271909999159]
@test gnew.y [1.0]
@@ -286,7 +286,7 @@ end
end

@testset "Precision" begin
-m = Chain(Dense(10, 5, relu; bias=false), Dense(5, 2))
+m = Chain(Dense(10 => 5, relu; bias=false), Dense(5 => 2))
x64 = rand(Float64, 10)
x32 = rand(Float32, 10)
i64 = rand(Int64, 10)
@@ -467,10 +467,10 @@ end
@test modules[5] === m2
@test modules[6] === m3

-mod_par = Flux.modules(Parallel(Flux.Bilinear(2,2,2,cbrt), Dense(2,2,abs), Dense(2,2,abs2)))
+mod_par = Flux.modules(Parallel(Flux.Bilinear(2,2,2,cbrt), Dense(2=>2,abs), Dense(2=>2,abs2)))
@test length(mod_par) == 5

-mod_rnn = Flux.modules(Chain(Dense(2,3), BatchNorm(3), LSTM(3,4)))
+mod_rnn = Flux.modules(Chain(Dense(2=>3), BatchNorm(3), LSTM(3=>4)))
@test length(mod_rnn) == 6
@test mod_rnn[end] isa Flux.LSTMCell
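`Flux.modules` recursively collects every non-leaf node of the model tree, containers included, which is why the `Chain` above yields six entries ending in the inner `LSTMCell` (the others being the `Chain`, its layer tuple, and the individual layers). A quick usage sketch:

```julia
using Flux

m = Chain(Dense(2 => 3), BatchNorm(3), LSTM(3 => 4))
for layer in Flux.modules(m)
    println(typeof(layer))   # Chain, Tuple, Dense, BatchNorm, LSTM, LSTMCell
end
```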

