Skip to content

Commit

Permalink
FIX PR: sort keys in standardize_optmize_options_bounds instead of st…
Browse files Browse the repository at this point in the history
…andardize_net
  • Loading branch information
nghi-truyen committed Apr 20, 2024
1 parent a377a50 commit 87fa4c3
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions smash/core/simulation/_standardize.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,8 @@ def _standardize_simulation_optimize_options_bounds(
f"included in the feasible domain ]{low}, {upp}[ in bounds optimize_options"
)

bounds = {key: bounds[key] for key in parameters}

return bounds


Expand Down Expand Up @@ -311,14 +313,13 @@ def _standardize_simulation_optimize_options_descriptor(


def _standardize_simulation_optimize_options_net(
model: Model, parameters: np.ndarray, bounds: dict, net: Net | None, **kwargs
model: Model, bounds: dict, net: Net | None, **kwargs
) -> Net:
bounds = {key: bounds[key] for key in parameters} # reorder bounds by parameters
bound_values = list(bounds.values())

ncv = len(parameters)
nd = model.setup.nd

bound_values = list(bounds.values())
ncv = len(bound_values)

active_mask = np.where(model.mesh.active_cell == 1)
ntrain = active_mask[0].shape[0]

Expand Down Expand Up @@ -388,7 +389,7 @@ def _standardize_simulation_optimize_options_net(

diff = np.not_equal(net_bounds, bound_values)

for i, name in enumerate(parameters):
for i, name in enumerate(bounds.keys()):
if diff[i].any():
warnings.warn(
f"net optimize_options: Inconsistent value(s) between the bound in scaling layer and "
Expand Down

0 comments on commit 87fa4c3

Please sign in to comment.