diff --git a/src/evox/monitors/evoxvis_monitor.py b/src/evox/monitors/evoxvis_monitor.py index 59b73319..c31bab18 100644 --- a/src/evox/monitors/evoxvis_monitor.py +++ b/src/evox/monitors/evoxvis_monitor.py @@ -1,7 +1,10 @@ -import pyarrow as pa +import tempfile import time +from pathlib import Path from typing import Optional -import warnings + +import jax.experimental.host_callback as hcb +import pyarrow as pa class EvoXVisMonitor: @@ -10,8 +13,15 @@ class EvoXVisMonitor: Parameters ---------- - out_file - The path of the output file + base_filename + The base filename of the log file, + the final filename will be ``<base_filename>_<i>.arrow``, + where i is an incrementing number. + out_dir + The directory to write the log file into. + When set to None, the default directory will be used. + The default is ``<tempdir>/evox``, + on Windows, it's usually ``C:\TEMP\evox``, and on MacOS/Linux/BSDs it's ``/tmp/evox``. out_type "stream" or "file", For more information, please refer to https://arrow.apache.org/docs/python/ipc.html @@ -28,13 +38,30 @@ class EvoXVisMonitor: """ def __init__( - self, out_file: str, out_type: str = "file", batch_size: int = 64, compression: Optional[str] = None + self, + base_filename: str, + out_dir: str = None, + out_type: str = "file", + batch_size: int = 64, + compression: Optional[str] = None, ): self.get_time = time.perf_counter_ns self.batch_size = batch_size self.generation_counter = 0 self.batch_record = [] - self.sink = pa.OSFile(out_file, "wb") + if out_dir is None: + base_path = Path(tempfile.gettempdir()).joinpath("evox") + else: + base_path = Path(out_dir) + # if the log dir is not created, create it first + if not base_path.exists(): + base_path.mkdir(parents=True, exist_ok=True) + # find the next available number + i = 0 + while base_path.joinpath(f"{base_filename}_{i}.arrow").exists(): + i += 1 + path_str = str(base_path.joinpath(f"{base_filename}_{i}.arrow")) + self.sink = pa.OSFile(path_str, "wb") self.out_type = out_type self.compression = 
compression @@ -43,6 +70,9 @@ def __init__( self.population_size = None self.population_dtype = None self.fitness_dtype = None + + # the ec_schema is left empty until the first write + # then we can infer the schema self.ec_schema = None self.writer = None @@ -59,6 +89,9 @@ def _write_batch(self): if len(self.generation) == 0: return + # first time writing the data + # infer the data schema + # and create the writer if self.ec_schema is None: population_byte_len = len(self.population[0]) fitness_byte_len = len(self.fitness[0]) @@ -73,8 +106,8 @@ def _write_batch(self): metadata={ "population_size": str(self.population_size), "population_dtype": self.population_dtype, - "fitness_dtype": self.fitness_dtype - } + "fitness_dtype": self.fitness_dtype, + }, ) if self.out_type == "file": @@ -84,6 +117,7 @@ def _write_batch(self): options=pa.ipc.IpcWriteOptions(compression=self.compression), ) + # actually write the data to disk self.writer.write_batch( pa.record_batch( [ @@ -97,20 +131,17 @@ def _write_batch(self): ) self._reset_batch() - def record_pop(self, population): + def record_pop(self, population, transform=None): self.population.append(population.tobytes()) self.population_size = population.shape[0] self.population_dtype = str(population.dtype) - # self.population_size.append(population.shape[0]) - # self.population_dtype.append(str(population.dtype)) return population - def record_fit(self, fitness): + def record_fit(self, fitness, transform=None): self.generation.append(self.generation_counter) self.timestamp.append(self.get_time()) self.fitness.append(fitness.tobytes()) self.fitness_dtype = str(fitness.dtype) - # self.fitness_dtype.append(str(fitness.dtype)) self.generation_counter += 1 if len(self.fitness) >= self.batch_size: @@ -118,7 +149,11 @@ def record_fit(self, fitness): return fitness + def flush(self): + hcb.barrier_wait() + def close(self): + self.flush() self._write_batch() self.writer.close() self.sink.close()