From 69b0584cc9282ca28cbb147a4a8f3e9912f6029f Mon Sep 17 00:00:00 2001 From: jlnav Date: Thu, 12 Sep 2024 12:33:04 -0500 Subject: [PATCH] removing hardcoded gen_specs.out, removing hardcoded persis_info.nworkers, use gen_specs.get("out") so if it isnt provided, the dtype discovery process commences --- libensemble/gen_classes/aposmm.py | 3 +-- libensemble/gen_classes/sampling.py | 2 -- libensemble/generators.py | 1 - libensemble/utils/runners.py | 4 ++-- 4 files changed, 3 insertions(+), 7 deletions(-) diff --git a/libensemble/gen_classes/aposmm.py b/libensemble/gen_classes/aposmm.py index d49832730..108282e07 100644 --- a/libensemble/gen_classes/aposmm.py +++ b/libensemble/gen_classes/aposmm.py @@ -30,8 +30,7 @@ def __init__( ] gen_specs["persis_in"] = ["x", "f", "local_pt", "sim_id", "sim_ended", "x_on_cube", "local_min"] if not persis_info: - persis_info = add_unique_random_streams({}, 4, seed=4321)[1] - persis_info["nworkers"] = 4 + persis_info = add_unique_random_streams({}, 2, seed=4321)[1] super().__init__(History, persis_info, gen_specs, libE_info, **kwargs) self.all_local_minima = [] self.results_idx = 0 diff --git a/libensemble/gen_classes/sampling.py b/libensemble/gen_classes/sampling.py index dd347db51..166286482 100644 --- a/libensemble/gen_classes/sampling.py +++ b/libensemble/gen_classes/sampling.py @@ -34,7 +34,6 @@ class UniformSample(SampleBase): def __init__(self, _=[], persis_info={}, gen_specs={}, libE_info=None, **kwargs): super().__init__(_, persis_info, gen_specs, libE_info, **kwargs) self._get_user_params(self.gen_specs["user"]) - self.gen_specs["out"] = [("x", float, (self.n,))] def ask_numpy(self, n_trials): H_o = np.zeros(n_trials, dtype=self.gen_specs["out"]) @@ -60,7 +59,6 @@ def __init__(self, _, persis_info, gen_specs, libE_info=None, **kwargs): self.gen_specs = gen_specs self.persis_info = persis_info self._get_user_params(self.gen_specs["user"]) - self.gen_specs["out"] = [("x", float, (self.n,))] def ask(self, n_trials): H_o = [] diff --git a/libensemble/generators.py b/libensemble/generators.py index 70eac32e1..37b974139 100644 --- a/libensemble/generators.py +++ b/libensemble/generators.py @@ -99,7 +99,6 @@ def __init__( self.gen_specs["user"] = kwargs if not persis_info: self.persis_info = add_unique_random_streams({}, 4, seed=4321)[1] - self.persis_info["nworkers"] = 4 else: self.persis_info = persis_info diff --git a/libensemble/utils/runners.py b/libensemble/utils/runners.py index 1d94fa097..08d52a27e 100644 --- a/libensemble/utils/runners.py +++ b/libensemble/utils/runners.py @@ -108,7 +108,7 @@ def __init__(self, specs): def _get_points_updates(self, batch_size: int) -> (npt.NDArray, npt.NDArray): # no ask_updates on external gens - return (list_dicts_to_np(self.gen.ask(batch_size), dtype=self.specs["out"]), None) + return (list_dicts_to_np(self.gen.ask(batch_size), dtype=self.specs.get("out")), None) def _convert_tell(self, x: npt.NDArray) -> list: self.gen.tell(np_to_list_dicts(x)) @@ -142,7 +142,7 @@ def _persistent_result(self, calc_in, persis_info, libE_info): if self.gen.thread is None: self.gen.setup() # maybe we're reusing a live gen from a previous run # libE gens will hit the following line, but list_dicts_to_np will passthrough if the output is a numpy array - H_out = list_dicts_to_np(self._get_initial_ask(libE_info), dtype=self.specs["out"]) + H_out = list_dicts_to_np(self._get_initial_ask(libE_info), dtype=self.specs.get("out")) tag, Work, H_in = self.ps.send_recv(H_out) # evaluate the initial sample final_H_in = self._start_generator_loop(tag, Work, H_in) return self.gen.final_tell(final_H_in), FINISHED_PERSISTENT_GEN_TAG