diff --git a/benchmarl/algorithms/common.py b/benchmarl/algorithms/common.py index 371e3a8f..c21d6e34 100644 --- a/benchmarl/algorithms/common.py +++ b/benchmarl/algorithms/common.py @@ -23,6 +23,7 @@ def __init__( self, experiment_config: "DictConfig", # noqa: F821 model_config: ModelConfig, + critic_model_config: ModelConfig, observation_spec: CompositeSpec, action_spec: CompositeSpec, state_spec: Optional[CompositeSpec], @@ -34,6 +35,7 @@ def __init__( self.experiment_config = experiment_config self.model_config = model_config + self.critic_model_config = critic_model_config self.on_policy = on_policy self.group_map = group_map self.observation_spec = observation_spec @@ -225,6 +227,7 @@ def get_algorithm( self, experiment_config, model_config: ModelConfig, + critic_model_config: ModelConfig, observation_spec: CompositeSpec, action_spec: CompositeSpec, state_spec: CompositeSpec, @@ -235,6 +238,7 @@ def get_algorithm( **self.__dict__, experiment_config=experiment_config, model_config=model_config, + critic_model_config=critic_model_config, observation_spec=observation_spec, action_spec=action_spec, state_spec=state_spec, diff --git a/benchmarl/algorithms/iddpg.py b/benchmarl/algorithms/iddpg.py index 02b48c99..bea37bd3 100644 --- a/benchmarl/algorithms/iddpg.py +++ b/benchmarl/algorithms/iddpg.py @@ -211,7 +211,7 @@ def get_value_module(self, group: str) -> TensorDictModule: ) modules.append( - self.model_config.get_model( + self.critic_model_config.get_model( input_spec=critic_input_spec, output_spec=critic_output_spec, n_agents=n_agents, diff --git a/benchmarl/algorithms/ippo.py b/benchmarl/algorithms/ippo.py index b6422cb7..e3eaca66 100644 --- a/benchmarl/algorithms/ippo.py +++ b/benchmarl/algorithms/ippo.py @@ -259,7 +259,7 @@ def get_critic(self, group: str) -> TensorDictModule: ) } ) - value_module = self.model_config.get_model( + value_module = self.critic_model_config.get_model( input_spec=critic_input_spec, output_spec=critic_output_spec, n_agents=n_agents, diff --git a/benchmarl/algorithms/isac.py b/benchmarl/algorithms/isac.py index abb696b7..cc8bcb8c 100644 --- a/benchmarl/algorithms/isac.py +++ b/benchmarl/algorithms/isac.py @@ -281,7 +281,7 @@ def get_discrete_value_module(self, group: str) -> TensorDictModule: ) } ) - value_module = self.model_config.get_model( + value_module = self.critic_model_config.get_model( input_spec=critic_input_spec, output_spec=critic_output_spec, n_agents=n_agents, @@ -336,7 +336,7 @@ def get_continuous_value_module(self, group: str) -> TensorDictModule: ) modules.append( - self.model_config.get_model( + self.critic_model_config.get_model( input_spec=critic_input_spec, output_spec=critic_output_spec, n_agents=n_agents, diff --git a/benchmarl/algorithms/maddpg.py b/benchmarl/algorithms/maddpg.py index 82e833d1..b567ea7f 100644 --- a/benchmarl/algorithms/maddpg.py +++ b/benchmarl/algorithms/maddpg.py @@ -216,7 +216,7 @@ def get_value_module(self, group: str) -> TensorDictModule: ) modules.append( - self.model_config.get_model( + self.critic_model_config.get_model( input_spec=critic_input_spec, output_spec=critic_output_spec, n_agents=n_agents, @@ -257,7 +257,7 @@ def get_value_module(self, group: str) -> TensorDictModule: ) modules.append( - self.model_config.get_model( + self.critic_model_config.get_model( input_spec=critic_input_spec, output_spec=critic_output_spec, n_agents=n_agents, diff --git a/benchmarl/algorithms/mappo.py b/benchmarl/algorithms/mappo.py index f5264758..5517a052 100644 --- a/benchmarl/algorithms/mappo.py +++ b/benchmarl/algorithms/mappo.py @@ -257,7 +257,7 @@ def get_critic(self, group: str) -> TensorDictModule: ) if self.state_spec is not None: - value_module = self.model_config.get_model( + value_module = self.critic_model_config.get_model( input_spec=self.state_spec, output_spec=critic_output_spec, n_agents=n_agents, @@ -281,7 +281,7 @@ def get_critic(self, group: str) -> TensorDictModule: ) } ) - value_module = self.model_config.get_model( + value_module = self.critic_model_config.get_model( input_spec=critic_input_spec, output_spec=critic_output_spec, n_agents=n_agents, diff --git a/benchmarl/algorithms/masac.py b/benchmarl/algorithms/masac.py index 101c5ff9..9a2180e8 100644 --- a/benchmarl/algorithms/masac.py +++ b/benchmarl/algorithms/masac.py @@ -275,7 +275,7 @@ def get_discrete_value_module(self, group: str) -> TensorDictModule: ) if self.state_spec is not None: - value_module = self.model_config.get_model( + value_module = self.critic_model_config.get_model( input_spec=self.state_spec, output_spec=critic_output_spec, n_agents=n_agents, @@ -299,7 +299,7 @@ def get_discrete_value_module(self, group: str) -> TensorDictModule: ) } ) - value_module = self.model_config.get_model( + value_module = self.critic_model_config.get_model( input_spec=critic_input_spec, output_spec=critic_output_spec, n_agents=n_agents, @@ -365,7 +365,7 @@ def get_continuous_value_module(self, group: str) -> TensorDictModule: ) modules.append( - self.model_config.get_model( + self.critic_model_config.get_model( input_spec=critic_input_spec, output_spec=critic_output_spec, n_agents=n_agents, @@ -406,7 +406,7 @@ def get_continuous_value_module(self, group: str) -> TensorDictModule: ) modules.append( - self.model_config.get_model( + self.critic_model_config.get_model( input_spec=critic_input_spec, output_spec=critic_output_spec, n_agents=n_agents, diff --git a/benchmarl/benchmark.py b/benchmarl/benchmark.py index 961df975..48a49dd5 100644 --- a/benchmarl/benchmark.py +++ b/benchmarl/benchmark.py @@ -1,4 +1,4 @@ -from typing import Iterator, Sequence, Set +from typing import Iterator, Optional, Sequence, Set from benchmarl.algorithms.common import AlgorithmConfig from benchmarl.environments import Task @@ -14,12 +14,16 @@ def __init__( tasks: Sequence[Task], seeds: Set[int], experiment_config: ExperimentConfig, + critic_model_config: Optional[ModelConfig] = None, ): self.algorithm_configs = algorithm_configs self.tasks = tasks self.seeds = seeds self.model_config = model_config + self.critic_model_config = ( + critic_model_config if critic_model_config is not None else model_config + ) self.experiment_config = experiment_config print(f"Created benchmark with {self.n_experiments} experiments.") @@ -37,6 +41,7 @@ def get_experiments(self) -> Iterator[Experiment]: algorithm_config=algorithm_config, seed=seed, model_config=self.model_config, + critic_model_config=self.critic_model_config, config=self.experiment_config, ) diff --git a/benchmarl/conf/config.yaml b/benchmarl/conf/config.yaml index 34389434..02343b24 100644 --- a/benchmarl/conf/config.yaml +++ b/benchmarl/conf/config.yaml @@ -3,6 +3,7 @@ defaults: - algorithm: ??? - task: ??? - model: layers/mlp + - model@critic_model: layers/mlp - _self_ seed: 0 diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index 29d93053..7c3f348a 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -127,11 +127,15 @@ def __init__( model_config: ModelConfig, seed: int, config: ExperimentConfig, + critic_model_config: Optional[ModelConfig] = None, ): self.config = config self.task = task self.model_config = model_config + self.critic_model_config = ( + critic_model_config if critic_model_config is not None else model_config + ) self.algorithm_config = algorithm_config self.seed = seed @@ -233,6 +237,7 @@ def _setup_algorithm(self): self.algorithm = self.algorithm_config.get_algorithm( experiment_config=self.config, model_config=self.model_config, + critic_model_config=self.critic_model_config, observation_spec=self.observation_spec, action_spec=self.action_spec, state_spec=self.state_spec, diff --git a/benchmarl/hydra_config.py b/benchmarl/hydra_config.py index 3da816f2..7dd65814 100644 --- a/benchmarl/hydra_config.py +++ b/benchmarl/hydra_config.py @@ -12,11 +12,13 @@ def load_experiment_from_hydra(cfg: DictConfig, task_name: str) -> Experiment: experiment_config = load_experiment_config_from_hydra(cfg.experiment) task_config = load_task_config_from_hydra(cfg.task, task_name) model_config = load_model_config_from_hydra(cfg.model) + critic_model_config = load_model_config_from_hydra(cfg.critic_model) return Experiment( task=task_config, algorithm_config=algorithm_config, model_config=model_config, + critic_model_config=critic_model_config, seed=cfg.seed, config=experiment_config, ) diff --git a/test/test_models.py b/test/test_models.py index 0119ef6d..aa228076 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -33,13 +33,14 @@ def test_loading_sequence_models(model_name, intermidiate_size=10): "model=sequence", f"model/layers@model.layers.l1={model_name}", f"model/layers@model.layers.l2={model_name}", - f"model.intermediate_sizes={[intermidiate_size]}", + f"+model/layers@model.layers.l3={model_name}", + f"model.intermediate_sizes={[intermidiate_size,intermidiate_size]}", ], ) hydra_model_config = load_model_config_from_hydra(cfg.model) layer_config = model_config_registry[model_name].get_from_yaml() yaml_config = SequenceModelConfig( - model_configs=[layer_config, layer_config], - intermediate_sizes=[intermidiate_size], + model_configs=[layer_config, layer_config, layer_config], + intermediate_sizes=[intermidiate_size, intermidiate_size], ) assert hydra_model_config == yaml_config diff --git a/test/test_vmas.py b/test/test_vmas.py index f4a88f0e..73c7986e 100644 --- a/test/test_vmas.py +++ b/test/test_vmas.py @@ -13,6 +13,8 @@ from benchmarl.algorithms.common import AlgorithmConfig from benchmarl.environments import Task, VmasTask from benchmarl.experiment import Experiment +from benchmarl.models import MlpConfig +from torch import nn from utils_experiment import ExperimentUtils _has_vmas = importlib.util.find_spec("vmas") is not None @@ -78,10 +80,14 @@ def test_share_policy_params( mlp_sequence_config, ): experiment_config.share_policy_params = share_params + critic_model_config = MlpConfig( + num_cells=[6], activation_class=nn.Tanh, layer_class=nn.Linear + ) task = task.get_from_yaml() experiment = Experiment( algorithm_config=algo_config.get_from_yaml(), model_config=mlp_sequence_config, + critic_model_config=critic_model_config, seed=0, config=experiment_config, task=task,