Skip to content

Commit

Permalink
subsequent batch_sizes are either back_size or len(H_in)
Browse files Browse the repository at this point in the history
  • Loading branch information
jlnav committed Aug 9, 2024
1 parent 92e22e4 commit d14f4d2
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
4 changes: 2 additions & 2 deletions libensemble/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ class GenSpecs(BaseModel):

batch_size: Optional[int] = 0
"""
Number of points to generate in each batch. If zero, falls back to ``initial_batch_size``.
If both options are zero, defaults to the number of workers.
Number of points to generate in each batch. If zero, falls back to the number of
completed evaluations most recently told to the generator.
Note: Certain generators included with libEnsemble decide
batch sizes via ``gen_specs["user"]`` or other methods.
Expand Down
15 changes: 7 additions & 8 deletions libensemble/utils/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,20 +118,19 @@ def _to_array(self, x: list) -> npt.NDArray:
return x

def _get_points_updates(self, batch_size: int) -> (npt.NDArray, npt.NDArray):
# no ask_updates on external gens
return (
self._to_array(self.gen.ask(batch_size)),
None,
) # external ask/tell gens likely don't implement ask_updates
)

def _convert_tell(self, x: npt.NDArray) -> list:
self.gen.tell(np_to_list_dicts(x))

def _loop_over_gen(self, tag, Work):
def _loop_over_gen(self, tag, Work, H_in):
"""Interact with ask/tell generator that *does not* contain a background thread"""
while tag not in [PERSIS_STOP, STOP_TAG]:
batch_size = (
self.specs.get("batch_size") or self.specs.get("initial_batch_size") or Work["libE_info"]["batch_size"]
) # or len(Work["H_in"])?
batch_size = self.specs.get("batch_size") or len(H_in)
points, updates = self._get_points_updates(batch_size)
if updates is not None and len(updates): # returned "samples" and "updates". can combine if same dtype
H_out = np.append(points, updates)
Expand All @@ -150,7 +149,7 @@ def _get_initial_ask(self, libE_info) -> npt.NDArray:
def _start_generator_loop(self, tag, Work, H_in):
"""Start the generator loop after choosing best way of giving initial results to gen"""
self.gen.tell(np_to_list_dicts(H_in))
return self._loop_over_gen(tag, Work)
return self._loop_over_gen(tag, Work, H_in)

def _persistent_result(self, calc_in, persis_info, libE_info):
"""Setup comms with manager, setup gen, loop gen to completion, return gen's results"""
Expand Down Expand Up @@ -186,7 +185,7 @@ def _convert_tell(self, x: npt.NDArray) -> list:
def _start_generator_loop(self, tag, Work, H_in) -> npt.NDArray:
"""Start the generator loop after choosing best way of giving initial results to gen"""
self.gen.tell_numpy(H_in)
return self._loop_over_gen(tag, Work) # see parent class
return self._loop_over_gen(tag, Work, H_in) # see parent class


class LibensembleGenThreadRunner(AskTellGenRunner):
Expand All @@ -205,7 +204,7 @@ def _ask_and_send(self):
else:
self.ps.send(points)

def _loop_over_gen(self, _, _2):
def _loop_over_gen(self, *args):
"""Cycle between moving all outbound / inbound messages between threaded gen and manager"""
while True:
time.sleep(0.0025) # dont need to ping the gen relentlessly. Let it calculate. 400hz
Expand Down

0 comments on commit d14f4d2

Please sign in to comment.