diff --git a/smash/core/simulation/_standardize.py b/smash/core/simulation/_standardize.py index e71e03ee..441ef1a5 100644 --- a/smash/core/simulation/_standardize.py +++ b/smash/core/simulation/_standardize.py @@ -330,7 +330,7 @@ def _standardize_simulation_optimize_options_descriptor( def _standardize_simulation_optimize_options_net( model: Model, bounds: dict, net: Net | None, **kwargs ) -> Net: - x, y, nd = model.physio_data.descriptor.shape + nrow, ncol, nd = model.physio_data.descriptor.shape bound_values = list(bounds.values()) ncv = len(bound_values) @@ -340,7 +340,7 @@ def _standardize_simulation_optimize_options_net( net = Net() net.add_dense(nd * 3, input_shape=nd, activation="relu") - net.add_dense(round(np.sqrt(nd * ncv) * np.log(x * y)), activation="relu") + net.add_dense(round(np.sqrt(nd * ncv) * np.log(nrow * ncol)), activation="relu") net.add_dense(ncv * 3, activation="relu") net.add_dense(ncv, activation="tanh") net.add_scale(bound_values) @@ -350,7 +350,7 @@ def _standardize_simulation_optimize_options_net( # % Check input shape net_in = net.layers[0].input_shape - x_in = (x, y, nd) if len(net_in) == 3 else (nd,) # in case of cnn and mlp resp. + x_in = (nrow, ncol, nd) if len(net_in) == 3 else (nd,) # in case of cnn and mlp resp. if net_in != x_in: raise ValueError(