Skip to content

Commit

Permalink
allow saving of sacess histories
Browse files Browse the repository at this point in the history
  • Loading branch information
Dilan Pathirana committed Mar 26, 2024
1 parent 98a3da0 commit 927a28b
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
29 changes: 27 additions & 2 deletions pypesto/select/misc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Miscellaneous methods."""

import logging
from pathlib import Path
from typing import Iterable

import pandas as pd
Expand All @@ -10,6 +11,7 @@
from petab_select import Model, parameter_string_to_value
from petab_select.constants import PETAB_PROBLEM

from ..history import Hdf5History
from ..objective import Objective
from ..optimize.ess import (
SacessOptimizer,
Expand Down Expand Up @@ -181,21 +183,29 @@ class SacessMinimizeMethod:
Class attributes correspond to pyPESTO's SaCeSS optimizer, and are
documented there. Extra keyword arguments supplied to the constructor
will be passed on to the constructor of the SaCeSS optimizer, for example,
`max_walltime_s` can be specified in this way.
`max_walltime_s` can be specified in this way. If specified, `tmpdir` will
be treated as a parent directory.
"""

def __init__(
self,
num_workers: int,
local_optimizer,
tmpdir=None,
save_history: bool = False,
**optimizer_kwargs,
):
"""Construct a minimize-like object."""
self.num_workers = num_workers
self.local_optimizer = local_optimizer
self.optimizer_kwargs = optimizer_kwargs
self.tmpdir = Path(tmpdir)
self.save_history = save_history

def __call__(self, problem: Problem, **minimize_options):
if self.save_history and self.tmpdir is None:
self.tmpdir = Path.cwd() / "sacess_tmpdir"

def __call__(self, problem: Problem, model_hash: str, **minimize_options):
"""Create then run a problem-specific sacess optimizer."""
# create optimizer
ess_init_args = get_default_ess_options(
Expand All @@ -204,8 +214,13 @@ def __call__(self, problem: Problem, **minimize_options):
)
for x in ess_init_args:
x["local_optimizer"] = self.local_optimizer
if self.tmpdir is not None:
model_tmpdir = self.tmpdir / model_hash
model_tmpdir.mkdir(exist_ok=False, parents=True)

ess = SacessOptimizer(
ess_init_args=ess_init_args,
tmpdir=model_tmpdir,
**self.optimizer_kwargs,
)

Expand All @@ -214,4 +229,14 @@ def __call__(self, problem: Problem, **minimize_options):
problem=problem,
**minimize_options,
)

if self.save_history:
history_dir = model_tmpdir / "history"
history_dir.mkdir(exist_ok=False, parents=True)
for history_index, history in enumerate(ess.histories):
Hdf5History.from_history(
other=history,
file=history_dir / (str(history_index) + ".hdf5"),
id_=history_index,
)
return result
8 changes: 7 additions & 1 deletion pypesto/select/model_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ..optimize import minimize
from ..problem import Problem
from ..result import OptimizerResult, Result
from .misc import model_to_pypesto_problem
from .misc import SacessMinimizeMethod, model_to_pypesto_problem

OBJECTIVE_CUSTOMIZER_TYPE = Callable[[ObjectiveBase], None]
TYPE_POSTPROCESSOR = Callable[["ModelProblem"], None] # noqa: F821
Expand Down Expand Up @@ -146,6 +146,12 @@ def minimize(self) -> Result:
-------
The optimization result.
"""
if isinstance(self.minimize_method, SacessMinimizeMethod):
return self.minimize_method(
self.pypesto_problem,
model_hash=self.model.get_hash(),
**self.minimize_options,
)
return self.minimize_method(
self.pypesto_problem,
**self.minimize_options,
Expand Down

0 comments on commit 927a28b

Please sign in to comment.