Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add seed fixing option to PySMO's sampling methods to enhance reproducibility #1307

Merged
merged 30 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
3362cd3
Fix NumPy array creation error by specifying object type
OOAmusat Feb 4, 2023
2368de9
Removing print and display statements
OOAmusat Feb 6, 2023
66b37e4
Merge branch 'main' into main
ksbeattie Feb 9, 2023
d627e68
Merge branch 'main' into main
dangunter Feb 9, 2023
ab087b6
Merge branch 'IDAES:main' into main
OOAmusat Jun 29, 2023
81c3c3e
Merge branch 'IDAES:main' into main
OOAmusat Dec 6, 2023
7bb0310
Adding a function for custom sampling.
OOAmusat Dec 7, 2023
2eadf88
Improve errors and warnings
OOAmusat Dec 8, 2023
2a952a0
Tests for CustomSampling
OOAmusat Dec 8, 2023
1efb59c
running black...
OOAmusat Dec 8, 2023
741865a
Updating docs and example.
OOAmusat Dec 8, 2023
cdde009
Fix docs
OOAmusat Dec 8, 2023
879d720
Improve docsstrings.
OOAmusat Dec 8, 2023
6920ce3
Improving tests based on feedback
OOAmusat Dec 9, 2023
87588ba
Merge branch 'main' into pysmo_custom_sampling
OOAmusat Dec 9, 2023
1bd7d3d
Edit Gaussian sampling bounds to allow for strict enforcement
OOAmusat Dec 11, 2023
429dcc4
Add tests to validate for Gaussian bounds
OOAmusat Dec 11, 2023
d54f592
Update test_sampling.py
OOAmusat Dec 11, 2023
fd9b7d1
Update test_sampling.py
OOAmusat Dec 11, 2023
dc4de4c
Add missing check in init
OOAmusat Dec 11, 2023
1e00f62
Improve docs on Gaussian distribution samples.
OOAmusat Dec 11, 2023
9f36da4
Update test_sampling.py
OOAmusat Dec 11, 2023
870cad1
Merge branch 'main' into pysmo_custom_sampling
andrewlee94 Dec 14, 2023
361077c
Merge branch 'main' into pysmo_custom_sampling
OOAmusat Dec 18, 2023
9fcb8c3
Merge branch 'IDAES:main' into pysmo_custom_sampling
OOAmusat Dec 18, 2023
3af1da5
Add random seed specification option
OOAmusat Dec 18, 2023
e2fa3c6
Switch seed check to try-except
OOAmusat Dec 19, 2023
93cde23
Merge branch 'main' into pysmo_custom_sampling
ksbeattie Jan 11, 2024
59d7d9e
Merge branch 'main' into pysmo_custom_sampling
lbianchi-lbl Jan 18, 2024
2477b3e
Merge branch 'main' into pysmo_custom_sampling
OOAmusat Jan 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions idaes/core/surrogate/pysmo/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,7 @@ def __init__(
sampling_type=None,
xlabels=None,
ylabels=None,
rand_seed=None,
):
"""
Initialization of **LatinHypercubeSampling** class. Two inputs are required.
Expand All @@ -496,6 +497,7 @@ def __init__(
Keyword Args:
xlabels (list): List of column names (if **data_input** is a dataframe) or column numbers (if **data_input** is an array) for the independent/input variables. Only used in "selection" mode. Default is None.
ylabels (list): List of column names (if **data_input** is a dataframe) or column numbers (if **data_input** is an array) for the dependent/output variables. Only used in "selection" mode. Default is None.
rand_seed (int): Option that allows users to fix the numpy random seed generator for reproducibility (if required).

Returns:
**self** function containing the input information
Expand Down Expand Up @@ -594,6 +596,12 @@ def __init__(
self.number_of_samples = number_of_samples
self.x_data = bounds_array # Only x data will be present in this case

if rand_seed is not None:
if not isinstance(rand_seed, int):
OOAmusat marked this conversation as resolved.
Show resolved Hide resolved
raise TypeError("Random seed must be an integer.")
self.seed_value = rand_seed
np.random.seed(self.seed_value)

def variable_sample_creation(self, variable_min, variable_max):
"""

Expand Down Expand Up @@ -1269,6 +1277,7 @@ def __init__(
sampling_type=None,
xlabels=None,
ylabels=None,
rand_seed=None,
):
"""
Initialization of CVTSampling class. Two inputs are required, while an optional option to control the solution accuracy may be specified.
Expand All @@ -1285,6 +1294,7 @@ def __init__(
Keyword Args:
xlabels (list): List of column names (if **data_input** is a dataframe) or column numbers (if **data_input** is an array) for the independent/input variables. Only used in "selection" mode. Default is None.
ylabels (list): List of column names (if **data_input** is a dataframe) or column numbers (if **data_input** is an array) for the dependent/output variables. Only used in "selection" mode. Default is None.
rand_seed (int): Option that allows users to fix the numpy random seed generator for reproducibility (if required).
tolerance(float): Maximum allowable Euclidean distance between centres from consecutive iterations of the algorithm. Termination condition for algorithm.

- The smaller the value of tolerance, the better the solution but the longer the algorithm requires to converge. Default value is :math:`10^{-7}`.
Expand Down Expand Up @@ -1412,6 +1422,12 @@ def __init__(
raise Exception("Invalid tolerance input")
self.eps = tolerance

if rand_seed is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seeing as this is in __init__, should this code be moved to the base class to avoid duplicating the code in each subclass?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I'd rather leave it here because it is an argument specific to a subset of the sampling methods.

if not isinstance(rand_seed, int):
raise TypeError("Random seed must be an integer.")
self.seed_value = rand_seed
np.random.seed(self.seed_value)

@staticmethod
def random_sample_selection(no_samples, no_features):
"""
Expand Down Expand Up @@ -1591,6 +1607,7 @@ def __init__(
xlabels=None,
ylabels=None,
strictly_enforce_gaussian_bounds=False,
rand_seed=None,
):
"""
Initialization of CustomSampling class. Four inputs are required.
Expand All @@ -1608,6 +1625,7 @@ def __init__(
Keyword Args:
xlabels (list): List of column names (if **data_input** is a dataframe) or column numbers (if **data_input** is an array) for the independent/input variables. Only used in "selection" mode. Default is None.
ylabels (list): List of column names (if **data_input** is a dataframe) or column numbers (if **data_input** is an array) for the dependent/output variables. Only used in "selection" mode. Default is None.
rand_seed (int): Option that allows users to fix the numpy random seed generator for reproducibility (if required).
strictly_enforce_gaussian_bounds (bool): Boolean specifying whether the provided bounds for normal distributions should be strictly enforced. Note that selecting this option may affect the underlying distribution. Default is False.

Returns:
Expand Down Expand Up @@ -1732,13 +1750,18 @@ def __init__(
)
self.normal_bounds_enforced = strictly_enforce_gaussian_bounds

if rand_seed is not None:
if not isinstance(rand_seed, int):
raise TypeError("Random seed must be an integer.")
self.seed_value = rand_seed

def generate_from_dist(self, dist_name):
if dist_name.lower() in ["uniform", "random"]:
dist = getattr(np.random.default_rng(), dist_name.lower())
dist = getattr(np.random.default_rng(self.seed_value), dist_name.lower())
var_values = np.array(dist(size=self.number_of_samples))
return dist, var_values
elif dist_name.lower() == "normal":
dist = getattr(np.random.default_rng(), "normal")
dist = getattr(np.random.default_rng(self.seed_value), "normal")
var_values = dist(loc=0.5, scale=1 / 6, size=self.number_of_samples)
if not self.normal_bounds_enforced:
return dist, np.array(var_values)
Expand Down
Loading
Loading