Skip to content

Commit

Permalink
removing hardcoded gen_specs.out, removing hardcoded persis_info.nwor…
Browse files Browse the repository at this point in the history
…kers, use gen_specs.get("out") so if it isnt provided, the dtype discovery process commences
  • Loading branch information
jlnav committed Sep 12, 2024
1 parent 7fdd8a6 commit 69b0584
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 7 deletions.
3 changes: 1 addition & 2 deletions libensemble/gen_classes/aposmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ def __init__(
]
gen_specs["persis_in"] = ["x", "f", "local_pt", "sim_id", "sim_ended", "x_on_cube", "local_min"]
if not persis_info:
persis_info = add_unique_random_streams({}, 4, seed=4321)[1]
persis_info["nworkers"] = 4
persis_info = add_unique_random_streams({}, 2, seed=4321)[1]
super().__init__(History, persis_info, gen_specs, libE_info, **kwargs)
self.all_local_minima = []
self.results_idx = 0
Expand Down
2 changes: 0 additions & 2 deletions libensemble/gen_classes/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ class UniformSample(SampleBase):
def __init__(self, _=[], persis_info={}, gen_specs={}, libE_info=None, **kwargs):
super().__init__(_, persis_info, gen_specs, libE_info, **kwargs)
self._get_user_params(self.gen_specs["user"])
self.gen_specs["out"] = [("x", float, (self.n,))]

def ask_numpy(self, n_trials):
H_o = np.zeros(n_trials, dtype=self.gen_specs["out"])
Expand All @@ -60,7 +59,6 @@ def __init__(self, _, persis_info, gen_specs, libE_info=None, **kwargs):
self.gen_specs = gen_specs
self.persis_info = persis_info
self._get_user_params(self.gen_specs["user"])
self.gen_specs["out"] = [("x", float, (self.n,))]

def ask(self, n_trials):
H_o = []
Expand Down
1 change: 0 additions & 1 deletion libensemble/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def __init__(
self.gen_specs["user"] = kwargs
if not persis_info:
self.persis_info = add_unique_random_streams({}, 4, seed=4321)[1]
self.persis_info["nworkers"] = 4
else:
self.persis_info = persis_info

Expand Down
4 changes: 2 additions & 2 deletions libensemble/utils/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __init__(self, specs):

def _get_points_updates(self, batch_size: int) -> (npt.NDArray, npt.NDArray):
# no ask_updates on external gens
return (list_dicts_to_np(self.gen.ask(batch_size), dtype=self.specs["out"]), None)
return (list_dicts_to_np(self.gen.ask(batch_size), dtype=self.specs.get("out")), None)

def _convert_tell(self, x: npt.NDArray) -> list:
self.gen.tell(np_to_list_dicts(x))
Expand Down Expand Up @@ -142,7 +142,7 @@ def _persistent_result(self, calc_in, persis_info, libE_info):
if self.gen.thread is None:
self.gen.setup() # maybe we're reusing a live gen from a previous run
# libE gens will hit the following line, but list_dicts_to_np will passthrough if the output is a numpy array
H_out = list_dicts_to_np(self._get_initial_ask(libE_info), dtype=self.specs["out"])
H_out = list_dicts_to_np(self._get_initial_ask(libE_info), dtype=self.specs.get("out"))
tag, Work, H_in = self.ps.send_recv(H_out) # evaluate the initial sample
final_H_in = self._start_generator_loop(tag, Work, H_in)
return self.gen.final_tell(final_H_in), FINISHED_PERSISTENT_GEN_TAG
Expand Down

0 comments on commit 69b0584

Please sign in to comment.