Skip to content

Commit

Permalink
[Examples] Update hydra examples
Browse files Browse the repository at this point in the history
Signed-off-by: Matteo Bettini <[email protected]>
  • Loading branch information
matteobettini committed Sep 19, 2023
1 parent b8f33bf commit 286d6f3
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 47 deletions.
55 changes: 24 additions & 31 deletions benchmarl/hydra_run.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
import hydra
from hydra.core.hydra_config import HydraConfig
from experiment import ExperimentConfig
from omegaconf import DictConfig, OmegaConf

from benchmarl.environments import task_config_registry
from benchmarl.algorithms.common import AlgorithmConfig
from benchmarl.environments import Task, task_config_registry
from benchmarl.experiment import Experiment
from benchmarl.models import model_config_registry
from benchmarl.models.common import ModelConfig, parse_model_config, SequenceModelConfig


def load_experiment_from_hydra_config(cfg: DictConfig, task_name: str) -> Experiment:
algorithm_config = OmegaConf.to_object(cfg.algorithm)
experiment_config = OmegaConf.to_object(cfg.experiment)
task_config = task_config_registry[task_name].update_config(
OmegaConf.to_container(cfg.task, resolve=True)
)
model_config = load_model_from_hydra_config(cfg.model)
def load_experiment_from_hydra(cfg: DictConfig, task_name: str) -> Experiment:
algorithm_config = load_algorithm_config_from_hydra(cfg.algorithm)
experiment_config = load_experiment_config_from_hydra(cfg.experiment)
task_config = load_task_config_from_hydra(cfg.task, task_name)
model_config = load_model_config_from_hydra(cfg.model)

return Experiment(
task=task_config,
Expand All @@ -25,10 +23,24 @@ def load_experiment_from_hydra_config(cfg: DictConfig, task_name: str) -> Experi
)


def load_model_from_hydra_config(cfg: DictConfig) -> ModelConfig:
def load_task_config_from_hydra(cfg: DictConfig, task_name: str) -> Task:
return task_config_registry[task_name].update_config(
OmegaConf.to_container(cfg.task, resolve=True)
)


def load_experiment_config_from_hydra(cfg: DictConfig) -> ExperimentConfig:
return OmegaConf.to_object(cfg)


def load_algorithm_config_from_hydra(cfg: DictConfig) -> AlgorithmConfig:
return OmegaConf.to_object(cfg)


def load_model_config_from_hydra(cfg: DictConfig) -> ModelConfig:
if "layers" in cfg.keys():
model_configs = [
load_model_from_hydra_config(cfg.layers[f"l{i}"])
load_model_config_from_hydra(cfg.layers[f"l{i}"])
for i in range(1, len(cfg.layers) + 1)
]
return SequenceModelConfig(
Expand All @@ -39,22 +51,3 @@ def load_model_from_hydra_config(cfg: DictConfig) -> ModelConfig:
return model_class(
**parse_model_config(OmegaConf.to_container(cfg, resolve=True))
)


@hydra.main(version_base=None, config_path="conf", config_name="config")
def hydra_experiment(cfg: DictConfig) -> None:
hydra_choices = HydraConfig.get().runtime.choices
task_name = hydra_choices.task
print(f"\nAlgorithm: {hydra_choices.algorithm}, Task: {task_name}")
print("\nLoaded config:\n")
print(OmegaConf.to_yaml(cfg))

experiment = load_experiment_from_hydra_config(
cfg,
task_name=task_name,
)
experiment.run()


if __name__ == "__main__":
hydra_experiment()
23 changes: 7 additions & 16 deletions examples/simple_hydra_run.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import hydra

from benchmarl.environments import task_config_registry
from benchmarl.experiment import Experiment
from benchmarl.hydra_run import load_model_from_hydra_config
from benchmarl.hydra_run import load_experiment_from_hydra
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf

Expand All @@ -11,20 +8,14 @@
def hydra_experiment(cfg: DictConfig) -> None:
hydra_choices = HydraConfig.get().runtime.choices
task_name = hydra_choices.task
print(f"\nAlgorithm: {hydra_choices.algorithm}, Task: {task_name}")
print("\nLoaded config:\n")
print(OmegaConf.to_yaml(cfg))

algorithm_config = OmegaConf.to_object(cfg.algorithm)
experiment_config = OmegaConf.to_object(cfg.experiment)
task_config = task_config_registry[task_name].update_config(cfg.task)
model_config = load_model_from_hydra_config(cfg.model)

experiment = Experiment(
task=task_config,
algorithm_config=algorithm_config,
model_config=model_config,
seed=cfg.seed,
config=experiment_config,
experiment = load_experiment_from_hydra(
cfg,
task_name=task_name,
)

experiment.run()


Expand Down
50 changes: 50 additions & 0 deletions examples/vmas_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import hydra
from benchmarl.experiment import Experiment

from benchmarl.hydra_run import (
load_algorithm_config_from_hydra,
load_experiment_config_from_hydra,
load_model_config_from_hydra,
load_task_config_from_hydra,
)
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf


@hydra.main(version_base=None, config_path="../benchmarl/conf", config_name="config")
def hydra_experiment(cfg: DictConfig) -> None:
hydra_choices = HydraConfig.get().runtime.choices
task_name = hydra_choices.task
print(f"\nAlgorithm: {hydra_choices.algorithm}, Task: {task_name}")

algorithm_config = load_algorithm_config_from_hydra(cfg.algorithm)
task_config = load_task_config_from_hydra(cfg.task, task_name)
model_config = load_model_config_from_hydra(cfg.model)
experiment_config = cfg.experiment

# Hyperparameter changes for VMAS experiments
experiment_config.sampling_device = "cuda"
experiment_config.train_device = "cuda"
experiment_config.collected_frames_per_batch = 60_000
experiment_config.n_envs_per_worker = 600
experiment_config.on_policy_minibatch_size = 4096
experiment_config.evaluation_episodes = 200
experiment_config = load_experiment_config_from_hydra(cfg.experiment)

print("\nLoaded config:\n")
print(OmegaConf.to_yaml(cfg))

experiment = Experiment(
task=task_config,
algorithm_config=algorithm_config,
model_config=model_config,
seed=cfg.seed,
config=experiment_config,
)
experiment.run()


if __name__ == "__main__":
hydra_experiment()
# To reproduce the VMAS results launch this with
# python run.py algorithm=ippo "task=vmas/navigation,vmas/balance,vmas/sampling" "seed=0,1,2"

0 comments on commit 286d6f3

Please sign in to comment.