Skip to content

Commit

Permalink
various test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinuzziFrancesco committed Feb 12, 2024
1 parent 4c3925b commit ebde232
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 98 deletions.
1 change: 1 addition & 0 deletions src/esn/deepesn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ function DeepESN(train_data,
rng = _default_rng(),
T = Float64,
matrix_type = typeof(train_data))

if states_type isa AbstractPaddedStates
in_size = size(train_data, 1) + 1
train_data = vcat(Adapt.adapt(matrix_type, ones(1, size(train_data, 2))),
Expand Down
5 changes: 3 additions & 2 deletions test/esn/test_drivers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,13 @@ esn_configs = [
Dict(:reservoir => rand_sparse(; radius = 1.2),
:reservoir_driver => GRU(variant = Minimal(),
reservoir = rand_sparse(; radius = 1.0, sparsity = 0.5),
inner_layer = scaled_rand)),
inner_layer = scaled_rand,
bias = scaled_rand)),
Dict(:reservoir => rand_sparse(; radius = 1.2),
:reservoir_driver => MRNN(activation_function = (tanh, sigmoid),
scaling_factor = (0.8, 0.1))),
]

for config in esn_configs
@testset "Test Drivers: $config" for config in esn_configs
test_esn(input_data, target_data, training_method, config)
end
23 changes: 0 additions & 23 deletions test/esn/test_nla.jl

This file was deleted.

46 changes: 0 additions & 46 deletions test/esn/test_states.jl

This file was deleted.

17 changes: 7 additions & 10 deletions test/esn/test_train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ const input_data = reduce(hcat, data[1:(train_len - 1)])
const target_data = reduce(hcat, data[2:train_len])
const test = reduce(hcat, data[(train_len + 1):(train_len + predict_len)])
const reg = 10e-6
#test_types = [Float64, Float32, Float16]

Random.seed!(77)
esn = ESN(input_data;
reservoir = RandSparseReservoir(res_size, 1.2, 0.1))
res = rand_sparse(; radius=1.2, sparsity=0.1)
esn = ESN(input_data, 1, res_size;
reservoir = rand_sparse)

training_methods = [
StandardRidge(regularization_coeff = reg),
Expand All @@ -21,14 +23,9 @@ training_methods = [
EpsilonSVR(),
]

for t in training_methods
output_layer = train(esn, target_data, t)
# TODO check types
@testset "Training Algo Tests: $ta" for ta in training_methods
output_layer = train(esn, target_data, ta)
output = esn(Predictive(input_data), output_layer)
@test mean(abs.(target_data .- output)) ./ mean(abs.(target_data)) < 0.22
end

for t in training_methods
output_layer = train(esn, target_data, t)
output, states = esn(Predictive(input_data), output_layer, save_states = true)
@test size(states) == (res_size, size(input_data, 2))
end
30 changes: 13 additions & 17 deletions test/test_states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using ReservoirComputing
test_array = [1, 2, 3, 4, 5, 6, 7, 8, 9]
extension = [0, 0, 0]
padding = 10.0
test_types = [Float64, Float32, Float16]

nlas = [(NLADefault(), test_array),
(NLAT1(), [1, 2, 9, 4, 25, 6, 49, 8, 81]),
Expand All @@ -18,20 +19,15 @@ pes = [(StandardStates(), test_array),
1)),
(ExtendedStates(), vcat(extension, test_array))]

function test_nla(algo, expected_output)
nla_array = ReservoirComputing.nla(algo, test_array)
@test nla_array == expected_output
end

function test_states_type(state_type, expected_output)
states_output = state_type(NLADefault(), test_array, extension)
@test states_output == expected_output
end

@testset "Nonlinear Algorithms Testing" for (algo, expected_output) in nlas
test_nla(algo, expected_output)
end

@testset "States Testing" for (state_type, expected_output) in pes
test_states_type(state_type, expected_output)
end
@testset "States Testing" for T in test_types
@testset "Nonlinear Algorithms Testing: $algo $T" for (algo, expected_output) in nlas
nla_array = ReservoirComputing.nla(algo, T.(test_array))
@test nla_array == expected_output
@test eltype(nla_array) == T
end
@testset "States Testing: $state_type $T" for (state_type, expected_output) in pes
states_output = state_type(NLADefault(), T.(test_array), T.(extension))
@test states_output == expected_output
@test eltype(states_output) == T
end
end

0 comments on commit ebde232

Please sign in to comment.