diff --git a/benchmarl/conf/algorithm/iddpg.yaml b/benchmarl/conf/algorithm/iddpg.yaml
index 1e81561e..6cd946fb 100644
--- a/benchmarl/conf/algorithm/iddpg.yaml
+++ b/benchmarl/conf/algorithm/iddpg.yaml
@@ -1,6 +1,6 @@
 defaults:
-  - _self_
   - iddpg_config
+  - _self_
 
 share_param_actor: True
diff --git a/benchmarl/conf/algorithm/ippo.yaml b/benchmarl/conf/algorithm/ippo.yaml
index 7ccdee18..c39e22ee 100644
--- a/benchmarl/conf/algorithm/ippo.yaml
+++ b/benchmarl/conf/algorithm/ippo.yaml
@@ -1,6 +1,6 @@
 defaults:
-  - _self_
   - ippo_config
+  - _self_
 
 share_param_actor: False
diff --git a/benchmarl/conf/algorithm/iql.yaml b/benchmarl/conf/algorithm/iql.yaml
index d751186b..0f4f26fc 100644
--- a/benchmarl/conf/algorithm/iql.yaml
+++ b/benchmarl/conf/algorithm/iql.yaml
@@ -1,6 +1,6 @@
 defaults:
-  - _self_
   - iql_config
+  - _self_
 
 delay_value: True
diff --git a/benchmarl/conf/algorithm/isac.yaml b/benchmarl/conf/algorithm/isac.yaml
index 5ff23eea..49f871bc 100644
--- a/benchmarl/conf/algorithm/isac.yaml
+++ b/benchmarl/conf/algorithm/isac.yaml
@@ -1,6 +1,6 @@
 defaults:
-  - _self_
   - isac_config
+  - _self_
 
 share_param_actor: True
diff --git a/benchmarl/conf/algorithm/maddpg.yaml b/benchmarl/conf/algorithm/maddpg.yaml
index d9e4a398..99f47b0e 100644
--- a/benchmarl/conf/algorithm/maddpg.yaml
+++ b/benchmarl/conf/algorithm/maddpg.yaml
@@ -1,6 +1,6 @@
 defaults:
-  - _self_
   - maddpg_config
+  - _self_
 
 share_param_actor: True
diff --git a/benchmarl/conf/algorithm/mappo.yaml b/benchmarl/conf/algorithm/mappo.yaml
index 346676fe..691b11e1 100644
--- a/benchmarl/conf/algorithm/mappo.yaml
+++ b/benchmarl/conf/algorithm/mappo.yaml
@@ -1,6 +1,6 @@
 defaults:
-  - _self_
   - mappo_config
+  - _self_
 
 share_param_actor: False
diff --git a/benchmarl/conf/algorithm/masac.yaml b/benchmarl/conf/algorithm/masac.yaml
index d19b564d..1b76833d 100644
--- a/benchmarl/conf/algorithm/masac.yaml
+++ b/benchmarl/conf/algorithm/masac.yaml
@@ -1,6 +1,6 @@
 defaults:
-  - _self_
   - masac_config
+  - _self_
 
 share_param_actor: True
diff --git a/benchmarl/conf/algorithm/qmix.yaml b/benchmarl/conf/algorithm/qmix.yaml
index 11924903..95c2ebf4 100644
--- a/benchmarl/conf/algorithm/qmix.yaml
+++ b/benchmarl/conf/algorithm/qmix.yaml
@@ -1,6 +1,6 @@
 defaults:
-  - _self_
   - qmix_config
+  - _self_
 
 mixing_embed_dim: 32
diff --git a/benchmarl/conf/algorithm/vdn.yaml b/benchmarl/conf/algorithm/vdn.yaml
index 6a0fed9a..2f2e6fe0 100644
--- a/benchmarl/conf/algorithm/vdn.yaml
+++ b/benchmarl/conf/algorithm/vdn.yaml
@@ -1,6 +1,6 @@
 defaults:
-  - _self_
   - vdn_config
+  - _self_
 
 delay_value: True
 loss_function: "l2"
diff --git a/benchmarl/conf/config.yaml b/benchmarl/conf/config.yaml
index 748fa69c..fd7375bc 100644
--- a/benchmarl/conf/config.yaml
+++ b/benchmarl/conf/config.yaml
@@ -1,7 +1,8 @@
 defaults:
-  - _self_
   - experiment: base_experiment
   - algorithm: mappo
   - task: vmas/balance
+  - model: sequence
+  - _self_
 
 seed: 0
diff --git a/benchmarl/conf/experiment/base_experiment.yaml b/benchmarl/conf/experiment/base_experiment.yaml
index 33b2506d..f10cee14 100644
--- a/benchmarl/conf/experiment/base_experiment.yaml
+++ b/benchmarl/conf/experiment/base_experiment.yaml
@@ -1,6 +1,6 @@
 defaults:
-  - _self_
   - experiment_config
+  - _self_
 
 sampling_device: "cpu"
 train_device: "cpu"
diff --git a/benchmarl/conf/model/__init__.py b/benchmarl/conf/model/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/benchmarl/conf/model/layers/mlp.yaml b/benchmarl/conf/model/layers/mlp.yaml
new file mode 100644
index 00000000..1817e089
--- /dev/null
+++ b/benchmarl/conf/model/layers/mlp.yaml
@@ -0,0 +1,12 @@
+
+
+name: mlp
+
+num_cells: [256, 256]
+layer_class: torch.nn.Linear
+
+activation_class: torch.nn.Tanh
+activation_kwargs: null
+
+norm_class: null
+norm_kwargs: null
diff --git a/benchmarl/conf/model/sequence.yaml b/benchmarl/conf/model/sequence.yaml
new file mode 100644
index 00000000..0f53eca6
--- /dev/null
+++ b/benchmarl/conf/model/sequence.yaml
@@ -0,0 +1,12 @@
+defaults:
+  # Here is a list of layers for this model
+  # You can use configs from "layer"
+  - layers@layers.l1: mlp
+  - layers@layers.l2: mlp
+  - _self_
+
+intermediate_sizes: [16]
+# You can override your layers like this
+layers:
+  l1:
+    num_cells: [4]
diff --git a/benchmarl/hydra_run.py b/benchmarl/hydra_run.py
index de0b73a5..d72a8d76 100644
--- a/benchmarl/hydra_run.py
+++ b/benchmarl/hydra_run.py
@@ -1,20 +1,18 @@
 import hydra
 from hydra.core.hydra_config import HydraConfig
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf
 
-from benchmarl.algorithms import algorithm_config_registry
 from benchmarl.environments import task_config_registry
-from benchmarl.experiment import Experiment, ExperimentConfig
-from benchmarl.models.common import ModelConfig
-from benchmarl.models.mlp import MlpConfig
+from benchmarl.experiment import Experiment
+from benchmarl.models import model_config_registry
+from benchmarl.models.common import ModelConfig, parse_model_config, SequenceModelConfig
 
 
-def load_experiment_from_hydra_config(
-    cfg: DictConfig, algo_name: str, task_name: str, model_config: ModelConfig
-) -> Experiment:
-    algorithm_config = algorithm_config_registry[algo_name](**cfg.algorithm)
+def load_experiment_from_hydra_config(cfg: DictConfig, task_name: str) -> Experiment:
+    algorithm_config = OmegaConf.to_object(cfg.algorithm)
+    experiment_config = OmegaConf.to_object(cfg.experiment)
     task_config = task_config_registry[task_name].update_config(cfg.task)
-    experiment_config = ExperimentConfig(**cfg.experiment)
+    model_config = load_model_from_hydra_config(cfg.model)
 
     return Experiment(
         task=task_config,
@@ -25,18 +23,27 @@ def load_experiment_from_hydra_config(
     )
 
 
+def load_model_from_hydra_config(cfg: DictConfig) -> ModelConfig:
+    if "layers" in cfg.keys():
+        model_configs = [
+            load_model_from_hydra_config(cfg.layers[f"l{i}"])
+            for i in range(1, len(cfg.layers) + 1)
+        ]
+        return SequenceModelConfig(
+            model_configs=model_configs, intermediate_sizes=cfg.intermediate_sizes
+        )
+    else:
+        model_class = model_config_registry[cfg.name]
+        return model_class(**parse_model_config(OmegaConf.to_container(cfg)))
+
+
 @hydra.main(version_base=None, config_path="conf", config_name="config")
 def hydra_experiment(cfg: DictConfig) -> None:
     hydra_choices = HydraConfig.get().runtime.choices
-    algo_name = hydra_choices.algorithm
     task_name = hydra_choices.task
     experiment = load_experiment_from_hydra_config(
         cfg,
-        algo_name=algo_name,
         task_name=task_name,
-        model_config=MlpConfig(
-            num_cells=[64]
-        ),  # Model still needs to be hydra configurable
     )
     experiment.run()
diff --git a/benchmarl/models/__init__.py b/benchmarl/models/__init__.py
index e69de29b..d3509727 100644
--- a/benchmarl/models/__init__.py
+++ b/benchmarl/models/__init__.py
@@ -0,0 +1,3 @@
+from benchmarl.models.mlp import MlpConfig
+
+model_config_registry = {"mlp": MlpConfig}
diff --git a/benchmarl/models/common.py b/benchmarl/models/common.py
index 26a01d7b..ae2f1d8f 100644
--- a/benchmarl/models/common.py
+++ b/benchmarl/models/common.py
@@ -1,12 +1,14 @@
+import pathlib
 from abc import ABC, abstractmethod
 from dataclasses import asdict, dataclass
-from typing import List, Sequence
+from typing import Any, Dict, List, Optional, Sequence
 
+from hydra.utils import get_class
 from tensordict import TensorDictBase
 from tensordict.nn import TensorDictModuleBase, TensorDictSequential
 from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
 
-from benchmarl.utils import DEVICE_TYPING
+from benchmarl.utils import DEVICE_TYPING, read_yaml_config
 
 
 def _check_spec(tensordict, spec):
@@ -14,6 +16,16 @@ def _check_spec(tensordict, spec):
         raise ValueError(f"TensorDict {tensordict} not in spec {spec}")
 
 
+def parse_model_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
+    del cfg["name"]
+    kwargs = {}
+    for key, value in cfg.items():
+        if key.endswith("class") and value is not None:
+            value = get_class(cfg[key])
+        kwargs.update({key: value})
+    return kwargs
+
+
 def output_has_agent_dim(share_params: bool, centralised: bool) -> bool:
     if share_params and centralised:
         return False
@@ -144,6 +156,23 @@ def get_model(
     def associated_class():
         raise NotImplementedError
 
+    @staticmethod
+    def _load_from_yaml(name: str) -> Dict[str, Any]:
+        yaml_path = (
+            pathlib.Path(__file__).parent.parent
+            / "conf"
+            / "model"
+            / "layers"
+            / f"{name.lower()}.yaml"
+        )
+        cfg = read_yaml_config(str(yaml_path.resolve()))
+        return parse_model_config(cfg)
+
+    @staticmethod
+    @abstractmethod
+    def get_from_yaml(path: Optional[str] = None):
+        raise NotImplementedError
+
 
 @dataclass
 class SequenceModelConfig(ModelConfig):
@@ -212,3 +241,7 @@ def get_model(
     @staticmethod
     def associated_class():
         raise NotImplementedError
+
+    @staticmethod
+    def get_from_yaml(path: Optional[str] = None):
+        raise NotImplementedError
diff --git a/benchmarl/models/mlp.py b/benchmarl/models/mlp.py
index b481ee2e..59c51b77 100644
--- a/benchmarl/models/mlp.py
+++ b/benchmarl/models/mlp.py
@@ -1,12 +1,13 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, MISSING
 from typing import Optional, Sequence, Type
 
 import torch
 from tensordict import TensorDictBase
 from torch import nn
 from torchrl.modules import MLP, MultiAgentMLP
+from benchmarl.utils import read_yaml_config
 
-from benchmarl.models.common import Model, ModelConfig
+from benchmarl.models.common import Model, ModelConfig, parse_model_config
 
 
 class Mlp(Model):
@@ -88,12 +89,10 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
 
 @dataclass
 class MlpConfig(ModelConfig):
-    # You can add any kwargs from torchrl.modules.MLP
+    num_cells: Sequence[int] = MISSING
+    layer_class: Type[nn.Module] = MISSING
 
-    num_cells: Sequence[int] = (256, 256)
-    layer_class: Type[nn.Module] = nn.Linear
-
-    activation_class: Type[nn.Module] = nn.Tanh
+    activation_class: Type[nn.Module] = MISSING
     activation_kwargs: Optional[dict] = None
 
     norm_class: Type[nn.Module] = None
@@ -102,3 +101,14 @@ class MlpConfig(ModelConfig):
     @staticmethod
     def associated_class():
         return Mlp
+
+    @staticmethod
+    def get_from_yaml(path: Optional[str] = None):
+        if path is None:
+            return MlpConfig(
+                **ModelConfig._load_from_yaml(
+                    name=MlpConfig.associated_class().__name__,
+                )
+            )
+        else:
+            return MlpConfig(**parse_model_config(read_yaml_config(path)))
diff --git a/examples/simple_hydra_run.py b/examples/simple_hydra_run.py
index d8f4693f..2aea53f4 100644
--- a/examples/simple_hydra_run.py
+++ b/examples/simple_hydra_run.py
@@ -1,32 +1,22 @@
 import hydra
 
-from benchmarl.algorithms import algorithm_config_registry
 from benchmarl.environments import task_config_registry
-from benchmarl.experiment import Experiment, ExperimentConfig
-from benchmarl.models.common import SequenceModelConfig
-from benchmarl.models.mlp import MlpConfig
+from benchmarl.experiment import Experiment
+
+from benchmarl.hydra_run import load_model_from_hydra_config
 from hydra.core.hydra_config import HydraConfig
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf
 
 
 @hydra.main(version_base=None, config_path="../benchmarl/conf", config_name="config")
 def hydra_experiment(cfg: DictConfig) -> None:
     hydra_choices = HydraConfig.get().runtime.choices
-    algo_name = hydra_choices.algorithm
     task_name = hydra_choices.task
 
-    algorithm_config = algorithm_config_registry[algo_name](**cfg.algorithm)
+    algorithm_config = OmegaConf.to_object(cfg.algorithm)
+    experiment_config = OmegaConf.to_object(cfg.experiment)
     task_config = task_config_registry[task_name].update_config(cfg.task)
-    experiment_config = ExperimentConfig(**cfg.experiment)
-
-    # Model still need to be refactored for hydra loading
-    model_config = SequenceModelConfig(
-        model_configs=[
-            MlpConfig(num_cells=[64, 64]),
-            MlpConfig(num_cells=[256]),
-        ],
-        intermediate_sizes=[128],
-    )
+    model_config = load_model_from_hydra_config(cfg.model)
 
     experiment = Experiment(
         task=task_config,
diff --git a/examples/simple_run.py b/examples/simple_run.py
index 14ef7dad..11b7c8ea 100644
--- a/examples/simple_run.py
+++ b/examples/simple_run.py
@@ -4,6 +4,7 @@
 from benchmarl.experiment import ExperimentConfig
 from benchmarl.models.common import SequenceModelConfig
 from benchmarl.models.mlp import MlpConfig
+from torch import nn
 
 
 if __name__ == "__main__":
@@ -20,8 +21,8 @@
     # Model still need to be refactored for hydra loading
     model_config = SequenceModelConfig(
         model_configs=[
-            MlpConfig(num_cells=[64, 64]),
-            MlpConfig(num_cells=[256]),
+            MlpConfig.get_from_yaml(),
+            MlpConfig(num_cells=[256], layer_class=nn.Linear, activation_class=nn.Tanh),
         ],
         intermediate_sizes=[128],
     )
diff --git a/test/test_algorithms.py b/test/test_algorithms.py
index ca9aa4a7..4dad36ad 100644
--- a/test/test_algorithms.py
+++ b/test/test_algorithms.py
@@ -39,19 +39,12 @@ def test_all_algos_hydra(algo_config):
     with initialize(version_base=None, config_path="../benchmarl/conf"):
         cfg = compose(
             config_name="config",
-            overrides=[f"algorithm={algo_config}"],
+            overrides=[
+                f"algorithm={algo_config}",
+                "model=layers/mlp",
+            ],
             return_hydra_config=True,
         )
         task_name = cfg.hydra.runtime.choices.task
-        algo_name = cfg.hydra.runtime.choices.algorithm
-        model_config = SequenceModelConfig(
-            model_configs=[
-                MlpConfig(num_cells=[8]),
-                MlpConfig(num_cells=[4]),
-            ],
-            intermediate_sizes=[5],
-        )
-        experiment = load_experiment_from_hydra_config(
-            cfg, algo_name=algo_name, task_name=task_name, model_config=model_config
-        )
+        experiment = load_experiment_from_hydra_config(cfg, task_name=task_name)
        experiment.run()
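
Usage sketch (not part of the diff): the snippet below exercises the new YAML-driven model loading added above. It assumes the benchmarl package from this changeset is importable; the hand-written sequence config only mimics what Hydra would compose from conf/model/sequence.yaml and is an illustration, not actual Hydra output.

    from omegaconf import OmegaConf

    from benchmarl.hydra_run import load_model_from_hydra_config
    from benchmarl.models.mlp import MlpConfig

    # Default MLP layer config, read from benchmarl/conf/model/layers/mlp.yaml
    # via ModelConfig._load_from_yaml / parse_model_config.
    mlp_config = MlpConfig.get_from_yaml()
    print(mlp_config.num_cells)  # [256, 256], as set in the new mlp.yaml

    # A sequence model, loaded the same way load_model_from_hydra_config does
    # after Hydra has composed conf/model/sequence.yaml (structure mimicked by hand).
    layer = {
        "name": "mlp",
        "num_cells": [4],
        "layer_class": "torch.nn.Linear",
        "activation_class": "torch.nn.Tanh",
        "activation_kwargs": None,
        "norm_class": None,
        "norm_kwargs": None,
    }
    sequence_cfg = OmegaConf.create(
        {"intermediate_sizes": [16], "layers": {"l1": layer, "l2": layer}}
    )
    sequence_config = load_model_from_hydra_config(sequence_cfg)  # -> SequenceModelConfig

From the command line the same configs come through the new "model" group: conf/config.yaml now defaults to model=sequence, and the updated test overrides it with model=layers/mlp to use a single MLP layer.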