Skip to content

Commit

Permalink
Apply suggestion change from FC review
Browse files Browse the repository at this point in the history
  • Loading branch information
nghi-truyen committed Sep 14, 2024
1 parent 7929854 commit 7c816ba
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions smash/core/simulation/_standardize.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def _standardize_simulation_optimize_options_descriptor(
def _standardize_simulation_optimize_options_net(
model: Model, bounds: dict, net: Net | None, **kwargs
) -> Net:
x, y, nd = model.physio_data.descriptor.shape
nrow, ncol, nd = model.physio_data.descriptor.shape

bound_values = list(bounds.values())
ncv = len(bound_values)
Expand All @@ -340,7 +340,7 @@ def _standardize_simulation_optimize_options_net(
net = Net()

net.add_dense(nd * 3, input_shape=nd, activation="relu")
net.add_dense(round(np.sqrt(nd * ncv) * np.log(x * y)), activation="relu")
net.add_dense(round(np.sqrt(nd * ncv) * np.log(nrow * ncol)), activation="relu")
net.add_dense(ncv * 3, activation="relu")
net.add_dense(ncv, activation="tanh")
net.add_scale(bound_values)
Expand All @@ -350,7 +350,7 @@ def _standardize_simulation_optimize_options_net(
# % Check input shape
net_in = net.layers[0].input_shape

x_in = (x, y, nd) if len(net_in) == 3 else (nd,) # in case of cnn and mlp resp.
x_in = (nrow, ncol, nd) if len(net_in) == 3 else (nd,) # in case of cnn and mlp resp.

if net_in != x_in:
raise ValueError(
Expand Down

0 comments on commit 7c816ba

Please sign in to comment.