Skip to content

Commit

Permalink
Switch seed check to try-except
Browse files Browse the repository at this point in the history
  • Loading branch information
OOAmusat committed Dec 19, 2023
1 parent 3af1da5 commit e2fa3c6
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 17 deletions.
27 changes: 16 additions & 11 deletions idaes/core/surrogate/pysmo/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,10 +597,11 @@ def __init__(
self.x_data = bounds_array # Only x data will be present in this case

if rand_seed is not None:
if not isinstance(rand_seed, int):
raise TypeError("Random seed must be an integer.")
self.seed_value = rand_seed
np.random.seed(self.seed_value)
try:
self.seed_value = int(rand_seed)
np.random.seed(self.seed_value)
except ValueError:
raise ValueError("Random seed must be an integer.")

def variable_sample_creation(self, variable_min, variable_max):
"""
Expand Down Expand Up @@ -1423,10 +1424,11 @@ def __init__(
self.eps = tolerance

if rand_seed is not None:
if not isinstance(rand_seed, int):
raise TypeError("Random seed must be an integer.")
self.seed_value = rand_seed
np.random.seed(self.seed_value)
try:
self.seed_value = int(rand_seed)
np.random.seed(self.seed_value)
except ValueError:
raise ValueError("Random seed must be an integer.")

@staticmethod
def random_sample_selection(no_samples, no_features):
Expand Down Expand Up @@ -1751,9 +1753,12 @@ def __init__(
self.normal_bounds_enforced = strictly_enforce_gaussian_bounds

if rand_seed is not None:
if not isinstance(rand_seed, int):
raise TypeError("Random seed must be an integer.")
self.seed_value = rand_seed
try:
self.seed_value = int(rand_seed)
except ValueError:
raise ValueError("Random seed must be an integer.")
else:
self.seed_value = rand_seed

def generate_from_dist(self, dist_name):
if dist_name.lower() in ["uniform", "random"]:
Expand Down
70 changes: 64 additions & 6 deletions idaes/core/surrogate/pysmo/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,24 @@ def test__init__selection_right_behaviour_with_specified_random_seed(
np.testing.assert_array_equal(LHSClass.x_data, np.array(input_array)[:, :-1])
assert LHSClass.seed_value == rand_seed

@pytest.mark.unit
@pytest.mark.parametrize("array_type", [np.array, pd.DataFrame])
def test__init__selection_right_behaviour_with_specified_float_random_seed(
self, array_type
):
input_array = array_type(self.input_array)
rand_seed = 15.1
LHSClass = LatinHypercubeSampling(
input_array,
number_of_samples=6,
sampling_type="selection",
rand_seed=rand_seed,
)
np.testing.assert_array_equal(LHSClass.data, input_array)
np.testing.assert_array_equal(LHSClass.number_of_samples, 6)
np.testing.assert_array_equal(LHSClass.x_data, np.array(input_array)[:, :-1])
assert LHSClass.seed_value == int(rand_seed)

@pytest.mark.unit
@pytest.mark.parametrize("array_type", [np.array, pd.DataFrame])
def test__init__selection_zero_samples(self, array_type):
Expand Down Expand Up @@ -569,12 +587,12 @@ def test__init__selection_wrong_input_data_type(self, array_type):
@pytest.mark.parametrize("array_type", [np.array, pd.DataFrame])
def test__init__selection_non_integer_random_seed(self, array_type):
input_array = array_type(self.input_array)
with pytest.raises(TypeError, match="Random seed must be an integer."):
with pytest.raises(ValueError, match="Random seed must be an integer."):
LHSClass = LatinHypercubeSampling(
input_array,
number_of_samples=5,
sampling_type="selection",
rand_seed=1.2,
rand_seed="1.2",
)

@pytest.mark.unit
Expand Down Expand Up @@ -2105,6 +2123,26 @@ def test__init__selection_right_behaviour_with_specified_random_seed(
np.testing.assert_array_equal(CVTClass.eps, 1e-7)
assert CVTClass.seed_value == rand_seed

@pytest.mark.unit
@pytest.mark.parametrize("array_type", [np.array, pd.DataFrame])
def test__init__selection_right_behaviour_with_specified_float_random_seed(
self, array_type
):
input_array = array_type(self.input_array)
rand_seed = 2.2
CVTClass = CVTSampling(
input_array,
number_of_samples=6,
tolerance=None,
sampling_type="selection",
rand_seed=rand_seed,
)
np.testing.assert_array_equal(CVTClass.data, input_array)
np.testing.assert_array_equal(CVTClass.number_of_centres, 6)
np.testing.assert_array_equal(CVTClass.x_data, np.array(input_array)[:, :-1])
np.testing.assert_array_equal(CVTClass.eps, 1e-7)
assert CVTClass.seed_value == int(rand_seed)

@pytest.mark.unit
@pytest.mark.parametrize("array_type", [np.array, pd.DataFrame])
def test__init__selection_zero_samples(self, array_type):
Expand Down Expand Up @@ -2208,12 +2246,12 @@ def test__init__selection_tolerance_too_tight(self, array_type):
@pytest.mark.parametrize("array_type", [np.array, pd.DataFrame])
def test__init__selection_non_integer_random_seed(self, array_type):
input_array = array_type(self.input_array)
with pytest.raises(TypeError, match="Random seed must be an integer."):
with pytest.raises(ValueError, match="Random seed must be an integer."):
CVTClass = CVTSampling(
input_array,
number_of_samples=5,
sampling_type="selection",
rand_seed=1.2,
rand_seed="1.2",
tolerance=None,
)

Expand Down Expand Up @@ -2799,6 +2837,26 @@ def test__init__selection_right_behaviour_with_specified_random_seed(
assert CSClass.dist_vector == ["uniform", "normal"]
assert CSClass.seed_value == rand_seed

@pytest.mark.unit
@pytest.mark.parametrize("array_type", [np.array, pd.DataFrame])
def test__init__selection_right_behaviour_with_specified_float_random_seed(
self, array_type
):
input_array = array_type(self.input_array)
rand_seed = 1.2
CSClass = CustomSampling(
input_array,
number_of_samples=6,
sampling_type="selection",
list_of_distributions=["uniform", "normal"],
rand_seed=rand_seed,
)
np.testing.assert_array_equal(CSClass.data, input_array)
np.testing.assert_array_equal(CSClass.number_of_samples, 6)
np.testing.assert_array_equal(CSClass.x_data, np.array(input_array)[:, :-1])
assert CSClass.dist_vector == ["uniform", "normal"]
assert CSClass.seed_value == int(rand_seed)

@pytest.mark.unit
@pytest.mark.parametrize("array_type", [np.array, pd.DataFrame])
def test__init__selection_zero_samples(self, array_type):
Expand Down Expand Up @@ -2960,13 +3018,13 @@ def test__init__selection_distribution_not_available(self, array_type):
@pytest.mark.parametrize("array_type", [np.array, pd.DataFrame])
def test__init__selection_non_integer_random_seed(self, array_type):
input_array = array_type(self.input_array)
with pytest.raises(TypeError, match="Random seed must be an integer."):
with pytest.raises(ValueError, match="Random seed must be an integer."):
CSClass = CustomSampling(
input_array,
number_of_samples=5,
sampling_type="selection",
list_of_distributions=["uniform", "normal"],
rand_seed=1.2,
rand_seed="1.2",
)

@pytest.mark.unit
Expand Down

0 comments on commit e2fa3c6

Please sign in to comment.