Skip to content

Commit

Permalink
Fix Metric issue
Browse files Browse the repository at this point in the history
  • Loading branch information
RemiLehe committed Aug 1, 2024
1 parent 844a554 commit c125d4d
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 75 deletions.
136 changes: 64 additions & 72 deletions optimas/generators/ax/developer/multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@
from ax.core.observation import ObservationFeatures
from ax.core.generator_run import GeneratorRun
from ax.storage.json_store.save import save_experiment
from ax.storage.metric_registry import register_metrics

try:
from ax.storage.metric_registry import register_metric as register_metrics
except ImportError:
# For Ax >= 0.4.1
from ax.storage.metric_registry import register_metrics
from ax.modelbridge.registry import Models, MT_MTGP_trans
from ax.core.experiment import Experiment
from ax.core.data import Data
from ax.modelbridge.transforms.convert_metric_names import (
tconfig_from_mt_experiment,
)
from ax.utils.common.typeutils import checked_cast

from optimas.generators.ax.base import AxGenerator
from optimas.core import (
Expand All @@ -44,75 +47,64 @@
LOFI_RETURNED = "lofi_returned"
HIFI_RETURNED = "hifi_returned"

try:
from ax.modelbridge.factory import get_MTGP_LEGACY as get_MTGP
except ImportError:
# For Ax >= 0.4.1: get_MTGP_LEGACY is deprecated, due to this PR:
# https://github.com/facebook/Ax/pull/2508
from ax.modelbridge.registry import Models, MT_MTGP_trans
from ax.core.experiment import Experiment
from ax.core.data import Data
from ax.modelbridge.transforms.convert_metric_names import (
tconfig_from_mt_experiment,
)
from ax.utils.common.typeutils import checked_cast

# This function is from https://ax.dev/tutorials/multi_task.html#4b.-Multi-task-Bayesian-optimization
def get_MTGP(
experiment: Experiment,
data: Data,
search_space: Optional[SearchSpace] = None,
trial_index: Optional[int] = None,
device: torch.device = torch.device("cpu"),
dtype: torch.dtype = torch.double,
) -> TorchModelBridge:
"""Instantiates a Multi-task Gaussian Process (MTGP) model that generates
points with EI.
"""
trial_index_to_type = {
t.index: t.trial_type for t in experiment.trials.values()
}
transforms = MT_MTGP_trans
transform_configs = {
"TrialAsTask": {
"trial_level_map": {"trial_type": trial_index_to_type}
},
"ConvertMetricNames": tconfig_from_mt_experiment(experiment),
}

# Choose the status quo features for the experiment from the selected trial.
# If trial_index is None, we will look for a status quo from the last
# experiment trial to use as a status quo for the experiment.
if trial_index is None:
trial_index = len(experiment.trials) - 1
elif trial_index >= len(experiment.trials):
raise ValueError(
"trial_index is bigger than the number of experiment trials"
)

status_quo = experiment.trials[trial_index].status_quo
if status_quo is None:
status_quo_features = None
else:
status_quo_features = ObservationFeatures(
parameters=status_quo.parameters,
trial_index=trial_index, # pyre-ignore[6]
)
# get_MTGP is not part of the Ax codebase, as of Ax 0.4.1, due to this PR:
# https://github.com/facebook/Ax/pull/2508
# Here we use `get_MTGP` from https://ax.dev/tutorials/multi_task.html#4b.-Multi-task-Bayesian-optimization
def get_MTGP(
    experiment: Experiment,
    data: Data,
    search_space: Optional[SearchSpace] = None,
    trial_index: Optional[int] = None,
    device: torch.device = torch.device("cpu"),
    dtype: torch.dtype = torch.double,
) -> TorchModelBridge:
    """Instantiate a Multi-task Gaussian Process (MTGP) model that generates
    points with EI.

    Parameters
    ----------
    experiment : Experiment
        The (multi-type) experiment whose trials provide the task structure.
    data : Data
        Observations used to fit the model.
    search_space : SearchSpace, optional
        Search space for candidate generation; defaults to
        ``experiment.search_space``.
    trial_index : int, optional
        Index of the trial whose status quo is used. Defaults to the last
        trial of the experiment.
    device : torch.device
        Torch device on which the model is fit.
    dtype : torch.dtype
        Torch dtype used for model fitting.

    Returns
    -------
    TorchModelBridge
        The fitted multi-task model bridge.

    Raises
    ------
    ValueError
        If ``trial_index`` exceeds the number of experiment trials.
    """
    # Map each trial to its trial type so the TrialAsTask transform can
    # treat trial types as tasks of the multi-task GP.
    trial_index_to_type = {
        t.index: t.trial_type for t in experiment.trials.values()
    }
    transforms = MT_MTGP_trans
    transform_configs = {
        "TrialAsTask": {
            "trial_level_map": {"trial_type": trial_index_to_type}
        },
        "ConvertMetricNames": tconfig_from_mt_experiment(experiment),
    }

    # Choose the status quo features for the experiment from the selected trial.
    # If trial_index is None, we will look for a status quo from the last
    # experiment trial to use as a status quo for the experiment.
    if trial_index is None:
        trial_index = len(experiment.trials) - 1
    elif trial_index >= len(experiment.trials):
        raise ValueError(
            "trial_index is bigger than the number of experiment trials"
        )

    status_quo = experiment.trials[trial_index].status_quo
    if status_quo is None:
        status_quo_features = None
    else:
        status_quo_features = ObservationFeatures(
            parameters=status_quo.parameters,
            trial_index=trial_index,  # pyre-ignore[6]
        )

    return checked_cast(
        TorchModelBridge,
        Models.ST_MTGP(
            experiment=experiment,
            search_space=search_space or experiment.search_space,
            data=data,
            transforms=transforms,
            transform_configs=transform_configs,
            torch_dtype=dtype,
            torch_device=device,
            status_quo_features=status_quo_features,
        ),
    )


class AxMultitaskGenerator(AxGenerator):
"""Multitask Bayesian optimization using the Ax developer API.
Expand Down Expand Up @@ -377,7 +369,7 @@ def _create_experiment(self) -> MultiTypeExperiment:
)

# Register metric in order to be able to save experiment to json file.
_, encoder_registry, decoder_registry = register_metrics(AxMetric)
_, encoder_registry, decoder_registry = register_metrics({AxMetric, None})
self._encoder_registry = encoder_registry
self._decoder_registry = decoder_registry

Expand Down
3 changes: 0 additions & 3 deletions optimas/generators/ax/service/single_fidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,6 @@ def _create_generation_steps(
else:
# Use a SAAS model with qNEI acquisition function.
MODEL_CLASS = Models.FULLYBAYESIAN
# Disable additional logs from fully Bayesian model.
bo_model_kwargs["disable_progbar"] = True
bo_model_kwargs["verbose"] = False
else:
if len(self.objectives) > 1:
# Use a model with qNEHVI acquisition function.
Expand Down

0 comments on commit c125d4d

Please sign in to comment.