
ENH: add adaptive optimizers for all mappings #315

Merged Sep 12, 2024 (83 commits; changes shown from 81 commits)

Commits
b58ec06
Merge remote-tracking branch 'upstream/main' into enh-combine-optimizers
nghi-truyen Jul 19, 2024
233a8cc
initalize branch combine optimizers
nghi-truyen Jul 20, 2024
90afd60
Merge branch 'enh-hybrid-models' into enh-combine-optimizers
nghi-truyen Jul 20, 2024
c70db0c
Merge remote-tracking branch 'upstream/enh-cnn-regionalization' into …
nghi-truyen Jul 21, 2024
30eb4c5
ENH: add set, get and forward_pass methods for Net
nghi-truyen Jul 22, 2024
e4b06e8
ENH: add test net forward pass + fix doc net
nghi-truyen Jul 22, 2024
75a847c
Merge branch 'enh-hybrid-models' into enh-combine-optimizers
nghi-truyen Jul 22, 2024
91bed2f
Merge branch 'enh-net-forward-pass' into enh-combine-optimizers
nghi-truyen Jul 22, 2024
3d71aec
Fix indent docstring net
nghi-truyen Jul 22, 2024
b8dc977
Merge branch 'enh-net-forward-pass' into enh-combine-optimizers
nghi-truyen Jul 22, 2024
63a49b7
Merge branch 'enh-hybrid-models' into enh-net-forward-pass
nghi-truyen Jul 23, 2024
b4c5dfc
add weight/bias_shape into trainable layer
nghi-truyen Jul 24, 2024
ee7824d
add weight/bias_shape into trainable layer
nghi-truyen Jul 24, 2024
0ffd8ad
Remove redundant in Python function sbs_optimize
nghi-truyen Jul 24, 2024
7f5f9ec
Merge remote-tracking branch 'upstream/enh-cnn-regionalization' into …
nghi-truyen Jul 24, 2024
41a7aa6
Merge branch 'enh-hybrid-models' into enh-net-forward-pass
nghi-truyen Jul 24, 2024
91429b2
adaptive optimizers for all mappings
nghi-truyen Jul 24, 2024
e1a5a64
Merge remote-tracking branch 'upstream/move-optimizers-to-Python' int…
nghi-truyen Jul 24, 2024
7b5b069
finalize optimizer combination
nghi-truyen Jul 25, 2024
268fe5e
Fix parameter update when using early stopping + dic api doc optimize
nghi-truyen Jul 25, 2024
b14c4af
Correct comment typo
nghi-truyen Jul 25, 2024
3b60a04
Fix comment
nghi-truyen Jul 25, 2024
9fc544b
Merge branch 'enh-net-forward-pass' into enh-adaptive-opt-for-all-map…
nghi-truyen Jul 25, 2024
873302c
Merge branch 'enh-hybrid-models' into enh-net-forward-pass
nghi-truyen Jul 25, 2024
a371562
Merge branch 'enh-net-forward-pass' into enh-adaptive-opt-for-all-map…
nghi-truyen Jul 25, 2024
a189bc3
Merge branch 'enh-hybrid-models' into enh-net-forward-pass
nghi-truyen Jul 25, 2024
a577e2b
Merge branch 'enh-net-forward-pass' into enh-adaptive-opt-for-all-map…
nghi-truyen Jul 25, 2024
47bda59
Change x_train name to x in net
nghi-truyen Jul 25, 2024
d007323
Merge branch 'enh-net-forward-pass' into enh-adaptive-opt-for-all-map…
nghi-truyen Jul 25, 2024
834fa73
retrieve key x in control_info + fix api doc for control prior name
nghi-truyen Jul 25, 2024
68df6a3
Fix dtype for control_info which is applied a finalization Python fun…
nghi-truyen Jul 25, 2024
71b3148
Merge branch 'enh-hybrid-models' into enh-net-forward-pass
nghi-truyen Jul 25, 2024
fed6fe7
Merge branch 'enh-net-forward-pass' into enh-adaptive-opt-for-all-map…
nghi-truyen Jul 25, 2024
d0d2fe3
improve unbounded check for finalize_get_control_info function
nghi-truyen Jul 26, 2024
4951e29
Merge branch 'enh-hybrid-models' into enh-net-forward-pass
nghi-truyen Jul 29, 2024
016e853
Generate baseline
nghi-truyen Jul 29, 2024
0d66bb2
Merge branch 'enh-net-forward-pass' into enh-adaptive-opt-for-all-map…
nghi-truyen Jul 29, 2024
812a4d2
Fix errors occured when merging branch
nghi-truyen Jul 29, 2024
bbc73cf
Fix callback argument in _gradient_based_optimize_problem
nghi-truyen Jul 29, 2024
4eae8e6
Merge branch 'enh-hybrid-models' into enh-net-forward-pass
nghi-truyen Jul 30, 2024
50b9fd8
Merge branch 'enh-net-forward-pass' into enh-adaptive-opt-for-all-map…
nghi-truyen Jul 30, 2024
bb207ff
Fix doc net
nghi-truyen Jul 31, 2024
4133580
Merge branch 'enh-hybrid-models' into enh-net-forward-pass
nghi-truyen Jul 31, 2024
7af47fd
Merge branch 'enh-net-forward-pass' into enh-adaptive-opt-for-all-map…
nghi-truyen Jul 31, 2024
7b297cb
Fix raise message net.set_weight_bias
nghi-truyen Jul 31, 2024
c0ff623
Merge branch 'enh-net-forward-pass' into enh-adaptive-opt-for-all-map…
nghi-truyen Jul 31, 2024
3221a1d
Fix error in previous merge
nghi-truyen Jul 31, 2024
6098461
Add choices to raise error when using sbs optimizer for hybrid struct…
nghi-truyen Jul 31, 2024
56855c6
Merge branch 'enh-hybrid-models' into enh-net-forward-pass:
nghi-truyen Aug 2, 2024
b2890e1
Merge branch 'enh-net-forward-pass' into enh-adaptive-opt-for-all-map…
nghi-truyen Aug 2, 2024
1bce0fc
Merge branch 'enh-hybrid-models' into enh-net-forward-pass:
nghi-truyen Aug 2, 2024
b40ca7a
Merge branch 'enh-net-forward-pass' into enh-adaptive-opt-for-all-map…
nghi-truyen Aug 2, 2024
217f05d
Merge branch 'enh-hybrid-models' into enh-net-forward-pass
nghi-truyen Aug 2, 2024
3cbaacb
Merge branch 'enh-net-forward-pass' into enh-adaptive-opt-for-all-map…
nghi-truyen Aug 2, 2024
8c6c0f0
Merge branch 'enh-hybrid-models' into enh-net-forward-pass
nghi-truyen Aug 3, 2024
af24e13
Merge branch 'enh-net-forward-pass' into enh-adaptive-opt-for-all-map…
nghi-truyen Aug 3, 2024
78c85ce
Generic check optimizer in case of hybrid models
nghi-truyen Aug 3, 2024
0ed8ecc
ENH: add random_state to set_weight and set_bias methods
nghi-truyen Aug 3, 2024
d0bbb22
Minor fix typos
nghi-truyen Aug 3, 2024
42cfe15
Merge branch 'enh-net-forward-pass' into enh-adaptive-opt-for-all-map…
nghi-truyen Aug 3, 2024
fac74cf
DOC: fix optimize_options documentation
nghi-truyen Aug 3, 2024
a04d2e5
MAINT: remove lbfgsb fortran external file + fix verbose
nghi-truyen Aug 4, 2024
e93333b
MAINT: change to g format to display cost values optimize verbose
nghi-truyen Aug 4, 2024
35e4096
MAINT: change display format verbose smash.optmize
nghi-truyen Aug 5, 2024
685aab5
FIX: update verbose api documentation
nghi-truyen Aug 5, 2024
39627d3
FIX: fix float format verbose ann optimize
nghi-truyen Aug 5, 2024
b397bc5
FIX: remove duplicated function due to merging error
nghi-truyen Aug 5, 2024
c255963
MAINT: handle return options + fix verbose optimize:
nghi-truyen Aug 6, 2024
700d582
Fix api doc cance
nghi-truyen Aug 6, 2024
c1e1bca
FIX: make check
nghi-truyen Aug 6, 2024
4166e0e
FIX/ENH: fix doc default_optimize + generic doc for mapping and optim…
nghi-truyen Aug 7, 2024
14c6cd2
MAINT: remove see also default_optimize doc
nghi-truyen Aug 8, 2024
23df534
FIX: returns control with random values depending on random_state ins…
nghi-truyen Aug 14, 2024
1b40d8d
Merge branch 'enh-hybrid-models' into enh-net-forward-pass
nghi-truyen Aug 14, 2024
fceee65
Merge branch 'enh-net-forward-pass' into enh-adaptive-opt-for-all-map…
nghi-truyen Aug 14, 2024
ebe3275
MAINT: merge branch main into enh-adaptive-opt-for-all-mappings
nghi-truyen Sep 10, 2024
1f6798e
Merge pull request #8 from nghi-truyen/merge-adaptive-opt
nghi-truyen Sep 10, 2024
ab49571
Re-generate baseline and fix unittest
nghi-truyen Sep 10, 2024
dbcec34
Merge remote-tracking branch 'upstream/main' into enh-adaptive-optimi…
nghi-truyen Sep 10, 2024
5976ba9
FIX PR: re-generate baseline and Merge remote-tracking branch 'upstre…
nghi-truyen Sep 11, 2024
73b7ffb
Apply suggestion changes from the first review of FC
nghi-truyen Sep 11, 2024
1bbb453
Apply suggestion changes from PAG and FC review
nghi-truyen Sep 12, 2024
c8e3910
Apply suggestion changes from FC second review
nghi-truyen Sep 12, 2024
2 changes: 1 addition & 1 deletion doc/source/math_num_documentation/forward_structure.rst
@@ -9,7 +9,7 @@ In `smash` a forward/direct spatially distributed model is obtained by chaining
- (optional) a descriptors-to-parameters mapping :math:`\phi`, either for imposing a spatial constraint on parameters and/or for a regional mapping between physical descriptors and model conceptual parameters, see :ref:`mapping section <math_num_documentation.mapping>`.
- (optional) a ``snow`` operator :math:`\mathcal{M}_{snw}` generating a melt flux :math:`m_{lt}` which is then summed with the precipitation flux to feed the ``hydrological`` operator :math:`\mathcal{M}_{rr}`.
- A ``hydrological`` production operator :math:`\mathcal{M}_{rr}` generating an elementary discharge :math:`q_t` which feeds the routing operator.
- A ``routing`` operator :math:`\mathcal{M}_{hy}` simulating propagation of discharge :math:`Q)`.
- A ``routing`` operator :math:`\mathcal{M}_{hy}` simulating propagation of discharge :math:`Q`.

The operator chaining principle is presented in section :ref:`forward and inverse problems statement <math_num_documentation.forward_inverse_problem.chaining>` (cf. :ref:`Eq. 2 <math_num_documentation.forward_inverse_problem.forward_problem_Mhy_circ_Mrr>`) and the chained fluxes are made explicit in the diagram below. The resulting forward model reads :math:`\mathcal{M}=\mathcal{M}_{hy}\left(\,.\,,\mathcal{M}_{rr}\left(\,.\,,\mathcal{M}_{snw}\left(.\right)\right)\right)`.

8 changes: 4 additions & 4 deletions doc/source/user_guide/classical_uses/lez_regionalization.rst
@@ -224,7 +224,7 @@ We also pass other options specific to the use of a NN:
- ``optimize_options``
- ``random_state``: a random seed used to initialize neural network weights.
- ``learning_rate``: the learning rate used for weights updates during training.
- ``termination_crit``: the number of training ``epochs`` for the neural network and a positive number to stop training when the loss function does not decrease below the current optimal value for ``early_stopping`` consecutive ``epochs``
- ``termination_crit``: the maximum number of training iterations ``maxiter`` for the neural network, and a positive number ``early_stopping`` to stop training when the loss function does not decrease below the current optimal value for ``early_stopping`` consecutive iterations.

- ``return_options``
- ``net``: return the optimized neural network
@@ -240,7 +240,7 @@ We also pass other options specific to the use of a NN:
optimize_options={
"random_state": 23,
"learning_rate": 0.004,
"termination_crit": dict(epochs=100, early_stopping=20),
"termination_crit": dict(maxiter=100, early_stopping=20),
},
return_options={"net": True},
common_options={"ncpu": ncpu},
@@ -255,7 +255,7 @@ We also pass other options specific to the use of a NN:
optimize_options={
"random_state": 23,
"learning_rate": 0.004,
"termination_crit": dict(epochs=100, early_stopping=20),
"termination_crit": dict(maxiter=100, early_stopping=20),
},
return_options={"net": True},
)
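
A minimal sketch of the full call and its return values (an illustration, not an excerpt from this PR; it assumes ``smash.optimize`` returns the optimized model together with a result object when ``return_options`` is given):

.. code-block:: python

    model_ann, opt_ann = smash.optimize(
        model,
        mapping="ann",
        optimize_options={
            "random_state": 23,
            "learning_rate": 0.004,
            "termination_crit": dict(maxiter=100, early_stopping=20),
        },
        return_options={"net": True},
    )
    net = opt_ann.net  # trained smash.factory.Net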
@@ -276,7 +276,7 @@ Other information is available in the `smash.factory.Net` object, including the
.. ipython:: python

plt.plot(opt_ann.net.history["loss_train"]);
plt.xlabel("Epoch");
plt.xlabel("Iteration");
plt.ylabel("$1-NSE$");
plt.grid(alpha=.7, ls="--");
@savefig user_guide.classical_uses.lez_regionalization.ann_J.png
22 changes: 12 additions & 10 deletions doc/source/user_guide/quickstart/cance_first_simulation.rst
@@ -681,19 +681,21 @@ First, several pieces of information were displayed on the screen during optimization

.. code-block:: text

At iterate 0 nfg = 1 J = 0.695010 ddx = 0.64
At iterate 1 nfg = 30 J = 0.098411 ddx = 0.64
At iterate 2 nfg = 59 J = 0.045409 ddx = 0.32
At iterate 3 nfg = 88 J = 0.038182 ddx = 0.16
At iterate 4 nfg = 117 J = 0.037362 ddx = 0.08
At iterate 5 nfg = 150 J = 0.037087 ddx = 0.02
At iterate 6 nfg = 183 J = 0.036800 ddx = 0.02
At iterate 7 nfg = 216 J = 0.036763 ddx = 0.01
CONVERGENCE: DDX < 0.01
</> Optimize
At iterate 0 nfg = 1 J = 6.95010e-01 ddx = 0.64
At iterate 1 nfg = 30 J = 9.84107e-02 ddx = 0.64
At iterate 2 nfg = 59 J = 4.54087e-02 ddx = 0.32
At iterate 3 nfg = 88 J = 3.81818e-02 ddx = 0.16
At iterate 4 nfg = 117 J = 3.73617e-02 ddx = 0.08
At iterate 5 nfg = 150 J = 3.70873e-02 ddx = 0.02
At iterate 6 nfg = 183 J = 3.68004e-02 ddx = 0.02
At iterate 7 nfg = 216 J = 3.67635e-02 ddx = 0.01
At iterate 8 nfg = 240 J = 3.67277e-02 ddx = 0.01
CONVERGENCE: DDX < 0.01

These lines show the successive iterations of the optimization, with the iteration number, the cumulative number of evaluations ``nfg``
(number of forward runs performed within each iteration of the optimization algorithm), the value of the cost function to minimize ``J``, and the value of the adaptive descent step ``ddx`` of this heuristic search algorithm.
So, to summarize, the optimization algorithm has converged after 7 iterations by reaching the descent step tolerance criterion of 0.01. This optimization required to perform 216 forward run evaluations and leads to a final cost function value on the order of 0.04.
So, to summarize, the optimization algorithm has converged after 8 iterations by reaching the descent step tolerance criterion of 0.01. This optimization required to perform 240 forward run evaluations and leads to a final cost function value of 0.0367.
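
For reference, here is a minimal sketch of launching such an optimization (an illustration, not an excerpt from this PR; it assumes the bundled ``Cance`` demo dataset and the ``maxiter`` termination criterion introduced in ``smash/_constant.py`` below):

.. code-block:: python

    import smash

    setup, mesh = smash.factory.load_dataset("Cance")
    model = smash.Model(setup, mesh)

    # Default uniform mapping with the heuristic "sbs" optimizer;
    # each iterate prints nfg, J and ddx as in the log above.
    model.optimize(
        optimizer="sbs",
        optimize_options={"termination_crit": {"maxiter": 50}},
    )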

Then, we can ask which cost function ``J`` has been minimized and which parameters have been optimized. So, by default, the cost function to be minimized is one minus the Nash-Sutcliffe efficiency ``nse`` (:math:`1 - \text{NSE}`)
and the optimized parameters are the set of rainfall-runoff parameters (``cp``, ``ct``, ``kexc`` and ``llr``). In the current configuration spatially
86 changes: 58 additions & 28 deletions smash/_constant.py
@@ -757,9 +757,7 @@ def get_neurons_from_hydrological_module(hydrological_module: str, hidden_neuron
"zeros",
]

PY_OPTIMIZER_CLASS = ["Adam", "SGD", "Adagrad", "RMSprop"]

PY_OPTIMIZER = [opt.lower() for opt in PY_OPTIMIZER_CLASS]
OPTIMIZER_CLASS = ["Adam", "SGD", "Adagrad", "RMSprop"]

ACTIVATION_FUNCTION_CLASS = [
"Sigmoid",
@@ -799,31 +797,23 @@

MAPPING = ["uniform", "distributed"] + REGIONAL_MAPPING

F90_OPTIMIZER = ["sbs", "lbfgsb"]
ADAPTIVE_OPTIMIZER = [opt.lower() for opt in OPTIMIZER_CLASS]
GRADIENT_BASED_OPTIMIZER = ["lbfgsb"] + ADAPTIVE_OPTIMIZER
HEURISTIC_OPTIMIZER = ["sbs"]

OPTIMIZER = F90_OPTIMIZER + PY_OPTIMIZER
OPTIMIZER = HEURISTIC_OPTIMIZER + GRADIENT_BASED_OPTIMIZER

# % Following MAPPING order
# % The first optimizer for each mapping is used as default optimizer
MAPPING_OPTIMIZER = dict(
zip(
MAPPING,
[
F90_OPTIMIZER,
["lbfgsb"],
["lbfgsb"],
["lbfgsb"],
PY_OPTIMIZER,
],
)
)

F90_OPTIMIZER_CONTROL_TFM = dict(
zip(
F90_OPTIMIZER,
[
["sbs", "normalize", "keep"],
["normalize", "keep"],
OPTIMIZER, # for uniform mapping (all optimizers are possible, default is sbs)
*(
[GRADIENT_BASED_OPTIMIZER] * 3
), # for distributed, multi-linear, multi-polynomial mappings (default is lbfgsb)
ADAPTIVE_OPTIMIZER, # for ann mapping (default is adam)
],
)
)
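# For illustration only (not part of this diff): with MAPPING ==
# ["uniform", "distributed", "multi-linear", "multi-polynomial", "ann"],
# as the inline comments above indicate, the zip expands to
#   {
#       "uniform": ["sbs", "lbfgsb", "adam", "sgd", "adagrad", "rmsprop"],
#       "distributed": ["lbfgsb", "adam", "sgd", "adagrad", "rmsprop"],
#       "multi-linear": ["lbfgsb", "adam", "sgd", "adagrad", "rmsprop"],
#       "multi-polynomial": ["lbfgsb", "adam", "sgd", "adagrad", "rmsprop"],
#       "ann": ["adam", "sgd", "adagrad", "rmsprop"],
#   }
# where the first optimizer of each list is the default.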
@@ -843,11 +833,11 @@ def get_neurons_from_hydrological_module(hydrological_module: str, hidden_neuron
DEFAULT_TERMINATION_CRIT = dict(
**dict(
zip(
F90_OPTIMIZER,
["sbs", "lbfgsb"],
[{"maxiter": 50}, {"maxiter": 100, "factr": 1e6, "pgtol": 1e-12}],
)
),
**dict(zip(PY_OPTIMIZER, len(PY_OPTIMIZER) * [{"epochs": 200, "early_stopping": 0}])),
**dict(zip(ADAPTIVE_OPTIMIZER, len(ADAPTIVE_OPTIMIZER) * [{"maxiter": 200, "early_stopping": 0}])),
)
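# For illustration only (not part of this diff): the resulting defaults are
#   "sbs"    -> {"maxiter": 50}
#   "lbfgsb" -> {"maxiter": 100, "factr": 1e6, "pgtol": 1e-12}
#   "adam", "sgd", "adagrad", "rmsprop"
#            -> {"maxiter": 200, "early_stopping": 0}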

CONTROL_PRIOR_DISTRIBUTION = [
@@ -885,6 +875,22 @@ def get_neurons_from_hydrological_module(hydrological_module: str, hidden_neuron
"control_tfm",
"termination_crit",
],
**dict(
zip(
itertools.product(["uniform", "distributed"], ADAPTIVE_OPTIMIZER),
2
* len(ADAPTIVE_OPTIMIZER)
* [
[
"parameters",
"bounds",
"control_tfm",
"learning_rate",
"termination_crit",
]
],
)
), # product between 2 mappings (uniform, distributed) and all adaptive optimizers
("multi-linear", "lbfgsb"): [
"parameters",
"bounds",
@@ -901,8 +907,25 @@ def get_neurons_from_hydrological_module(hydrological_module: str, hidden_neuron
],
**dict(
zip(
[("ann", optimizer) for optimizer in PY_OPTIMIZER],
len(PY_OPTIMIZER)
itertools.product(["multi-linear", "multi-polynomial"], ADAPTIVE_OPTIMIZER),
2
* len(ADAPTIVE_OPTIMIZER)
* [
[
"parameters",
"bounds",
"control_tfm",
"descriptor",
"learning_rate",
"termination_crit",
]
],
)
), # product between 2 mappings (multi-linear, multi-polynomial) and all adaptive optimizers
**dict(
zip(
[("ann", optimizer) for optimizer in ADAPTIVE_OPTIMIZER],
len(ADAPTIVE_OPTIMIZER)
* [
[
"parameters",
Expand All @@ -917,6 +940,15 @@ def get_neurons_from_hydrological_module(hydrological_module: str, hidden_neuron
),
}
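# For illustration only (not part of this diff): a lookup such as
#   SIMULATION_OPTIMIZE_OPTIONS_KEYS[("uniform", "adam")]
# yields
#   ["parameters", "bounds", "control_tfm", "learning_rate", "termination_crit"]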

OPTIMIZER_CONTROL_TFM = {
(mapping, optimizer): ["sbs", "normalize", "keep"] # in case of sbs optimizer
if optimizer == "sbs"
else ["normalize", "keep"] # in case of non-ann mapping with a gradient-based optimizer
if mapping != "ann"
else ["keep"] # in case of ann mapping
for mapping, optimizer in SIMULATION_OPTIMIZE_OPTIONS_KEYS.keys()
}
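# For illustration only (not part of this diff), e.g.:
#   ("uniform", "sbs")  -> ["sbs", "normalize", "keep"]
#   ("uniform", "adam") -> ["normalize", "keep"]
#   ("ann", "adam")     -> ["keep"]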

DEFAULT_SIMULATION_COST_OPTIONS = {
"forward_run": {
"jobs_cmpt": "nse",
@@ -962,11 +994,10 @@ def get_neurons_from_hydrological_module(hydrological_module: str, hidden_neuron
"rr_states": False,
"q_domain": False,
"internal_fluxes": False,
"iter_cost": False,
"iter_projg": False,
"control_vector": False,
"net": False,
"cost": False,
"projg": False,
"jobs": False,
"jreg": False,
"lcurve_wjreg": False,
Expand All @@ -976,10 +1007,9 @@ def get_neurons_from_hydrological_module(hydrological_module: str, hidden_neuron
"rr_states": False,
"q_domain": False,
"internal_fluxes": False,
"iter_cost": False,
"iter_projg": False,
"control_vector": False,
"cost": False,
"projg": False,
"log_lkh": False,
"log_prior": False,
"log_h": False,
33 changes: 17 additions & 16 deletions smash/core/signal_analysis/evaluation/evaluation.py
@@ -79,22 +79,23 @@ def evaluation(

>>> model.optimize()
</> Optimize
At iterate 0 nfg = 1 J = 0.695010 ddx = 0.64
At iterate 1 nfg = 30 J = 0.098411 ddx = 0.64
At iterate 2 nfg = 59 J = 0.045409 ddx = 0.32
At iterate 3 nfg = 88 J = 0.038182 ddx = 0.16
At iterate 4 nfg = 117 J = 0.037362 ddx = 0.08
At iterate 5 nfg = 150 J = 0.037087 ddx = 0.02
At iterate 6 nfg = 183 J = 0.036800 ddx = 0.02
At iterate 7 nfg = 216 J = 0.036763 ddx = 0.01
At iterate 0 nfg = 1 J = 6.95010e-01 ddx = 0.64
At iterate 1 nfg = 30 J = 9.84107e-02 ddx = 0.64
At iterate 2 nfg = 59 J = 4.54087e-02 ddx = 0.32
At iterate 3 nfg = 88 J = 3.81818e-02 ddx = 0.16
At iterate 4 nfg = 117 J = 3.73617e-02 ddx = 0.08
At iterate 5 nfg = 150 J = 3.70873e-02 ddx = 0.02
At iterate 6 nfg = 183 J = 3.68004e-02 ddx = 0.02
At iterate 7 nfg = 216 J = 3.67635e-02 ddx = 0.01
At iterate 8 nfg = 240 J = 3.67277e-02 ddx = 0.01
CONVERGENCE: DDX < 0.01

Compute multiple evaluation metrics for all catchments

>>> smash.evaluation(model, metric=["mae", "mse", "nse", "kge"])
array([[ 3.16965151, 44.78328323, 0.96327233, 0.94752783],
[ 1.07771611, 4.38410997, 0.90453297, 0.84582865],
[ 0.33045691, 0.50611502, 0.84956211, 0.8045246 ]])
array([[ 3.16766095, 44.77915192, 0.96327233, 0.94864655],
[ 1.07711864, 4.36171055, 0.90502125, 0.84566253],
[ 0.33053449, 0.50542408, 0.84976768, 0.8039571 ]])

Add start and end evaluation dates

@@ -106,12 +107,12 @@ def evaluation(

>>> smash.evaluation(model, metric="nse")
array([[0.96327233],
[0.90453297],
[0.84956211]])
[0.90502125],
[0.84976768]])
>>> smash.evaluation(model, metric="nse", start_eval=start_eval, end_eval=end_eval)
array([[0.9404493 ],
[0.86493075],
[0.76471144]])
array([[0.94048417],
[0.8667959 ],
[0.76593578]])
"""
metric, start_eval, end_eval = _standardize_evaluation_args(metric, start_eval, end_eval, model.setup)
