Updating WeightInitializers and fixes around it #203

Merged · 2 commits · Feb 27, 2024
4 changes: 1 addition & 3 deletions Project.toml
@@ -15,7 +15,6 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"

@@ -34,10 +33,9 @@ Optim = "1"
PartialFunctions = "1.2"
Random = "1.10"
SafeTestsets = "0.1"
SparseArrays = "1.10"
Statistics = "1.10"
Test = "1"
WeightInitializers = "0.1.5"
WeightInitializers = "0.1.6"
julia = "1.10"

[extras]
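For context: `SparseArrays` drops out of the dependency list because the only remaining consumer, the local `sparse_init`, is deleted below in favor of the upstream WeightInitializers.jl implementation. A rough local equivalent of this manifest change (a sketch; `Pkg.compat` requires Julia ≥ 1.8):

```julia
using Pkg

Pkg.rm("SparseArrays")                      # no longer a direct dependency
Pkg.compat("WeightInitializers", "0.1.6")   # raise the lower compat bound
```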
16 changes: 8 additions & 8 deletions src/ReservoirComputing.jl
@@ -11,14 +11,13 @@ using NNlib
using Optim
using PartialFunctions
using Random
- using SparseArrays
using Statistics
using WeightInitializers

export NLADefault, NLAT1, NLAT2, NLAT3
export StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates
export StandardRidge, LinearModel
- export scaled_rand, weighted_init, sparse_init, informed_init, minimal_init
+ export scaled_rand, weighted_init, informed_init, minimal_init
export rand_sparse, delay_line, delay_line_backward, cycle_jumps, simple_cycle, pseudo_svd
export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal
export ESN, train
@@ -76,26 +75,27 @@ end
#fallbacks for initializers
for initializer in (:rand_sparse, :delay_line, :delay_line_backward, :cycle_jumps,
:simple_cycle, :pseudo_svd,
-     :scaled_rand, :weighted_init, :sparse_init, :informed_init, :minimal_init)
+     :scaled_rand, :weighted_init, :informed_init, :minimal_init)
NType = ifelse(initializer === :rand_sparse, Real, Number)
@eval function ($initializer)(dims::Integer...; kwargs...)
-         return $initializer(_default_rng(), Float32, dims...; kwargs...)
+         return $initializer(WeightInitializers._default_rng(), Float32, dims...; kwargs...)
end
@eval function ($initializer)(rng::AbstractRNG, dims::Integer...; kwargs...)
return $initializer(rng, Float32, dims...; kwargs...)
end
@eval function ($initializer)(::Type{T},
dims::Integer...; kwargs...) where {T <: $NType}
-         return $initializer(_default_rng(), T, dims...; kwargs...)
+         return $initializer(WeightInitializers._default_rng(), T, dims...; kwargs...)
end
@eval function ($initializer)(rng::AbstractRNG; kwargs...)
-         return __partial_apply($initializer, (rng, (; kwargs...)))
+         return WeightInitializers.__partial_apply($initializer, (rng, (; kwargs...)))
end
@eval function ($initializer)(rng::AbstractRNG,
::Type{T}; kwargs...) where {T <: $NType}
-         return __partial_apply($initializer, ((rng, T), (; kwargs...)))
+         return WeightInitializers.__partial_apply($initializer, ((rng, T), (; kwargs...)))
end
-     @eval ($initializer)(; kwargs...) = __partial_apply($initializer, (; kwargs...))
+     @eval ($initializer)(; kwargs...) = WeightInitializers.__partial_apply(
+         $initializer, (; kwargs...))
end

#general
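The fallback loop above gives every initializer the same calling conventions as WeightInitializers.jl functions, including partial application. A minimal sketch of the resulting API (assuming the exported `rand_sparse` and its `radius` keyword):

```julia
using Random, ReservoirComputing

rng = Xoshiro(1234)

W1 = rand_sparse(50, 50)               # default rng, Float32 eltype
W2 = rand_sparse(rng, Float64, 50, 50) # explicit rng and eltype

# Partially applied form: fix rng and kwargs now, supply dims later
init = rand_sparse(rng; radius = 1.2)
W3 = init(50, 50)                      # 50×50 matrix, spectral radius ≈ 1.2
```

Qualifying the helpers as `WeightInitializers._default_rng` and `WeightInitializers.__partial_apply` is also what lets the local copies be deleted from `src/esn/esn_reservoirs.jl` below.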
2 changes: 1 addition & 1 deletion src/esn/deepesn.jl
@@ -22,7 +22,7 @@ function DeepESN(train_data,
nla_type = NLADefault(),
states_type = StandardStates(),
washout::Int = 0,
-         rng = _default_rng(),
+         rng = WeightInitializers._default_rng(),
T = Float64,
matrix_type = typeof(train_data))
if states_type isa AbstractPaddedStates
9 changes: 1 addition & 8 deletions src/esn/esn.jl
@@ -54,7 +54,7 @@ function ESN(train_data,
nla_type = NLADefault(),
states_type = StandardStates(),
washout = 0,
-         rng = _default_rng(),
+         rng = WeightInitializers._default_rng(),
T = Float32,
matrix_type = typeof(train_data))
if states_type isa AbstractPaddedStates
@@ -120,13 +120,6 @@ trained_esn = train(esn, target_data)
# Train the ESN using a custom training method
trained_esn = train(esn, target_data, training_method = StandardRidge(1.0))
```

- # Notes
- 
- - When using a `Hybrid` variation, the function extends the state matrix with data from the
-   physical model included in the `variation`.
- - The training is handled by a lower-level `_train` function which takes the new state matrix
-   and performs the actual training using the specified `training_method`.
"""
function train(esn::AbstractEchoStateNetwork,
target_data,
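Since the default `rng` now resolves to `WeightInitializers._default_rng()`, passing an explicit `rng` remains the way to get reproducible construction. A sketch, assuming the `ESN(train_data, in_size, res_size; ...)` constructor of this release (argument values are illustrative):

```julia
using Random, ReservoirComputing

train_data = rand(Float32, 3, 500)

esn = ESN(train_data, 3, 100;
    input_layer = scaled_rand,
    reservoir = rand_sparse(; radius = 1.2),
    rng = Xoshiro(42))   # overrides the WeightInitializers._default_rng() default
```

The same default applies to `DeepESN` above and `HybridESN` below.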
37 changes: 0 additions & 37 deletions src/esn/esn_input_layers.jl
@@ -77,43 +77,6 @@ function weighted_init(rng::AbstractRNG,
return layer_matrix
end

- # TODO: @MartinuzziFrancesco remove when pr gets into WeightInitializers
- """
-     sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; scaling=T(0.1), sparsity=T(0.1)) where {T <: Number}
- 
- Create and return a sparse layer matrix for use in neural network models.
- The matrix will be of size specified by `dims`, with the specified `sparsity` and `scaling`.
- 
- # Arguments
- 
- - `rng`: An instance of `AbstractRNG` for random number generation.
- - `T`: The data type for the elements of the matrix.
- - `dims`: Dimensions of the resulting sparse layer matrix.
- - `scaling`: The scaling factor for the sparse layer matrix. Defaults to 0.1.
- - `sparsity`: The sparsity level of the sparse layer matrix, controlling the fraction of zero elements. Defaults to 0.1.
- 
- # Returns
- 
- A sparse layer matrix.
- 
- # Example
- 
- ```julia
- rng = Random.default_rng()
- input_layer = sparse_init(rng, Float64, (3, 300); scaling = 0.2, sparsity = 0.1)
- ```
- """
- function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
-         scaling = T(0.1), sparsity = T(0.1)) where {T <: Number}
-     res_size, in_size = dims
-     layer_matrix = Matrix(sprand(rng, T, res_size, in_size, sparsity))
-     layer_matrix = T.(2.0) .* (layer_matrix .- T.(0.5))
-     replace!(layer_matrix, T(-1.0) => T(0.0))
-     layer_matrix = scaling .* layer_matrix
- 
-     return layer_matrix
- end

"""
informed_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; scaling=T(0.1), model_in_size, gamma=T(0.5)) where {T <: Number}

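After this deletion, `sparse_init` comes from WeightInitializers.jl instead. The upstream version is not a drop-in replacement: it draws nonzeros from a normal distribution scaled by `std` (there is no `scaling` keyword) and treats `sparsity` as the fraction of zeroed entries. A hedged sketch of the closest upstream call to the removed example:

```julia
using Random, WeightInitializers

rng = Random.default_rng()

# Removed local version: sparse_init(rng, Float64, (3, 300); scaling = 0.2, sparsity = 0.1)
# Closest upstream call: ~90% zeros, nonzeros drawn from Normal(0, 0.2)
input_layer = sparse_init(rng, Float64, 3, 300; sparsity = 0.9, std = 0.2)
```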
19 changes: 4 additions & 15 deletions src/esn/esn_reservoirs.jl
@@ -23,10 +23,10 @@ function rand_sparse(rng::AbstractRNG,
::Type{T},
dims::Integer...;
radius = T(1.0),
-         sparsity = T(0.1)) where {T <: Number}
-     reservoir_matrix = Matrix{T}(sprand(rng, dims..., sparsity))
-     reservoir_matrix = T(2.0) .* (reservoir_matrix .- T(0.5))
-     replace!(reservoir_matrix, T(-1.0) => T(0.0))
+         sparsity = T(0.1),
+         std = T(1.0)) where {T <: Number}
+     lcl_sparsity = T(1) - sparsity #consistency with current implementations
+     reservoir_matrix = sparse_init(rng, T, dims...; sparsity = lcl_sparsity, std = std)
rho_w = maximum(abs.(eigvals(reservoir_matrix)))
reservoir_matrix .*= radius / rho_w
if Inf in unique(reservoir_matrix) || -Inf in unique(reservoir_matrix)
@@ -299,14 +299,3 @@ end
function get_sparsity(M, dim)
return size(M[M .!= 0], 1) / (dim * dim - size(M[M .!= 0], 1)) #nonzero/zero elements
end

- # from WeightInitializers.jl, TODO @MartinuzziFrancesco consider importing package
- function _default_rng()
-     @static if VERSION >= v"1.7"
-         return Xoshiro(1234)
-     else
-         return MersenneTwister(1234)
-     end
- end
- 
- __partial_apply(fn, inp) = fn$inp
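The `lcl_sparsity = T(1) - sparsity` flip above exists because upstream `sparse_init` reads `sparsity` as the fraction of zeroed entries, while `rand_sparse` keeps its historical meaning of roughly the fraction of nonzeros. A sketch of the preserved behaviour (numbers approximate):

```julia
using LinearAlgebra, Random, ReservoirComputing

rng = Xoshiro(1234)
W = rand_sparse(rng, Float64, 200, 200; radius = 1.2, sparsity = 0.1)

count(!iszero, W) / length(W)   # ≈ 0.1 nonzeros, matching the old convention
maximum(abs, eigvals(W))        # ≈ 1.2 after the spectral-radius rescaling
```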
2 changes: 1 addition & 1 deletion src/esn/hybridesn.jl
@@ -66,7 +66,7 @@ function HybridESN(model,
nla_type = NLADefault(),
states_type = StandardStates(),
washout = 0,
-         rng = _default_rng(),
+         rng = WeightInitializers._default_rng(),
T = Float32,
matrix_type = typeof(train_data))
train_data = vcat(train_data, model.model_data[:, 1:(end - 1)])
1 change: 0 additions & 1 deletion test/esn/test_inits.jl
@@ -28,7 +28,6 @@ reservoir_inits = [
input_inits = [
scaled_rand,
weighted_init,
-     sparse_init,
minimal_init,
minimal_init(; sampling_type = :irrational)
]
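The list mixes bare initializers with a partially applied one; thanks to the fallback methods, every entry supports the `init(dims...)` calling convention, so a sweep can treat them uniformly. A hypothetical check, not the actual test body (dims chosen divisible so `weighted_init` keeps the requested size):

```julia
using Test

@testset "input layer initializers" begin
    for init in input_inits
        layer = init(8, 4)   # res_size × in_size
        @test layer isa AbstractMatrix
        @test size(layer) == (8, 4)
    end
end
```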