Skip to content

Commit

Permalink
Merge branch 'main' into gnn
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Sep 26, 2023
2 parents 19102c0 + 944d562 commit 27f62cf
Show file tree
Hide file tree
Showing 13 changed files with 40 additions and 16 deletions.
4 changes: 4 additions & 0 deletions benchmarl/algorithms/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
self,
experiment_config: "DictConfig", # noqa: F821
model_config: ModelConfig,
critic_model_config: ModelConfig,
observation_spec: CompositeSpec,
action_spec: CompositeSpec,
state_spec: Optional[CompositeSpec],
Expand All @@ -34,6 +35,7 @@ def __init__(

self.experiment_config = experiment_config
self.model_config = model_config
self.critic_model_config = critic_model_config
self.on_policy = on_policy
self.group_map = group_map
self.observation_spec = observation_spec
Expand Down Expand Up @@ -225,6 +227,7 @@ def get_algorithm(
self,
experiment_config,
model_config: ModelConfig,
critic_model_config: ModelConfig,
observation_spec: CompositeSpec,
action_spec: CompositeSpec,
state_spec: CompositeSpec,
Expand All @@ -235,6 +238,7 @@ def get_algorithm(
**self.__dict__,
experiment_config=experiment_config,
model_config=model_config,
critic_model_config=critic_model_config,
observation_spec=observation_spec,
action_spec=action_spec,
state_spec=state_spec,
Expand Down
2 changes: 1 addition & 1 deletion benchmarl/algorithms/iddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def get_value_module(self, group: str) -> TensorDictModule:
)

modules.append(
self.model_config.get_model(
self.critic_model_config.get_model(
input_spec=critic_input_spec,
output_spec=critic_output_spec,
n_agents=n_agents,
Expand Down
2 changes: 1 addition & 1 deletion benchmarl/algorithms/ippo.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def get_critic(self, group: str) -> TensorDictModule:
)
}
)
value_module = self.model_config.get_model(
value_module = self.critic_model_config.get_model(
input_spec=critic_input_spec,
output_spec=critic_output_spec,
n_agents=n_agents,
Expand Down
4 changes: 2 additions & 2 deletions benchmarl/algorithms/isac.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def get_discrete_value_module(self, group: str) -> TensorDictModule:
)
}
)
value_module = self.model_config.get_model(
value_module = self.critic_model_config.get_model(
input_spec=critic_input_spec,
output_spec=critic_output_spec,
n_agents=n_agents,
Expand Down Expand Up @@ -336,7 +336,7 @@ def get_continuous_value_module(self, group: str) -> TensorDictModule:
)

modules.append(
self.model_config.get_model(
self.critic_model_config.get_model(
input_spec=critic_input_spec,
output_spec=critic_output_spec,
n_agents=n_agents,
Expand Down
4 changes: 2 additions & 2 deletions benchmarl/algorithms/maddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def get_value_module(self, group: str) -> TensorDictModule:
)

modules.append(
self.model_config.get_model(
self.critic_model_config.get_model(
input_spec=critic_input_spec,
output_spec=critic_output_spec,
n_agents=n_agents,
Expand Down Expand Up @@ -257,7 +257,7 @@ def get_value_module(self, group: str) -> TensorDictModule:
)

modules.append(
self.model_config.get_model(
self.critic_model_config.get_model(
input_spec=critic_input_spec,
output_spec=critic_output_spec,
n_agents=n_agents,
Expand Down
4 changes: 2 additions & 2 deletions benchmarl/algorithms/mappo.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def get_critic(self, group: str) -> TensorDictModule:
)

if self.state_spec is not None:
value_module = self.model_config.get_model(
value_module = self.critic_model_config.get_model(
input_spec=self.state_spec,
output_spec=critic_output_spec,
n_agents=n_agents,
Expand All @@ -281,7 +281,7 @@ def get_critic(self, group: str) -> TensorDictModule:
)
}
)
value_module = self.model_config.get_model(
value_module = self.critic_model_config.get_model(
input_spec=critic_input_spec,
output_spec=critic_output_spec,
n_agents=n_agents,
Expand Down
8 changes: 4 additions & 4 deletions benchmarl/algorithms/masac.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def get_discrete_value_module(self, group: str) -> TensorDictModule:
)

if self.state_spec is not None:
value_module = self.model_config.get_model(
value_module = self.critic_model_config.get_model(
input_spec=self.state_spec,
output_spec=critic_output_spec,
n_agents=n_agents,
Expand All @@ -299,7 +299,7 @@ def get_discrete_value_module(self, group: str) -> TensorDictModule:
)
}
)
value_module = self.model_config.get_model(
value_module = self.critic_model_config.get_model(
input_spec=critic_input_spec,
output_spec=critic_output_spec,
n_agents=n_agents,
Expand Down Expand Up @@ -365,7 +365,7 @@ def get_continuous_value_module(self, group: str) -> TensorDictModule:
)

modules.append(
self.model_config.get_model(
self.critic_model_config.get_model(
input_spec=critic_input_spec,
output_spec=critic_output_spec,
n_agents=n_agents,
Expand Down Expand Up @@ -406,7 +406,7 @@ def get_continuous_value_module(self, group: str) -> TensorDictModule:
)

modules.append(
self.model_config.get_model(
self.critic_model_config.get_model(
input_spec=critic_input_spec,
output_spec=critic_output_spec,
n_agents=n_agents,
Expand Down
7 changes: 6 additions & 1 deletion benchmarl/benchmark.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterator, Sequence, Set
from typing import Iterator, Optional, Sequence, Set

from benchmarl.algorithms.common import AlgorithmConfig
from benchmarl.environments import Task
Expand All @@ -14,12 +14,16 @@ def __init__(
tasks: Sequence[Task],
seeds: Set[int],
experiment_config: ExperimentConfig,
critic_model_config: Optional[ModelConfig] = None,
):
self.algorithm_configs = algorithm_configs
self.tasks = tasks
self.seeds = seeds

self.model_config = model_config
self.critic_model_config = (
critic_model_config if critic_model_config is not None else model_config
)
self.experiment_config = experiment_config

print(f"Created benchmark with {self.n_experiments} experiments.")
Expand All @@ -37,6 +41,7 @@ def get_experiments(self) -> Iterator[Experiment]:
algorithm_config=algorithm_config,
seed=seed,
model_config=self.model_config,
critic_model_config=self.critic_model_config,
config=self.experiment_config,
)

Expand Down
1 change: 1 addition & 0 deletions benchmarl/conf/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ defaults:
- algorithm: ???
- task: ???
- model: layers/mlp
- model@critic_model: layers/mlp
- _self_

seed: 0
5 changes: 5 additions & 0 deletions benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,15 @@ def __init__(
model_config: ModelConfig,
seed: int,
config: ExperimentConfig,
critic_model_config: Optional[ModelConfig] = None,
):
self.config = config

self.task = task
self.model_config = model_config
self.critic_model_config = (
critic_model_config if critic_model_config is not None else model_config
)
self.algorithm_config = algorithm_config
self.seed = seed

Expand Down Expand Up @@ -233,6 +237,7 @@ def _setup_algorithm(self):
self.algorithm = self.algorithm_config.get_algorithm(
experiment_config=self.config,
model_config=self.model_config,
critic_model_config=self.critic_model_config,
observation_spec=self.observation_spec,
action_spec=self.action_spec,
state_spec=self.state_spec,
Expand Down
2 changes: 2 additions & 0 deletions benchmarl/hydra_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ def load_experiment_from_hydra(cfg: DictConfig, task_name: str) -> Experiment:
experiment_config = load_experiment_config_from_hydra(cfg.experiment)
task_config = load_task_config_from_hydra(cfg.task, task_name)
model_config = load_model_config_from_hydra(cfg.model)
critic_model_config = load_model_config_from_hydra(cfg.critic_model)

return Experiment(
task=task_config,
algorithm_config=algorithm_config,
model_config=model_config,
critic_model_config=critic_model_config,
seed=cfg.seed,
config=experiment_config,
)
Expand Down
7 changes: 4 additions & 3 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,14 @@ def test_loading_sequence_models(model_name, intermidiate_size=10):
"model=sequence",
f"model/[email protected]={model_name}",
f"model/[email protected]={model_name}",
f"model.intermediate_sizes={[intermidiate_size]}",
f"+model/[email protected]={model_name}",
f"model.intermediate_sizes={[intermidiate_size,intermidiate_size]}",
],
)
hydra_model_config = load_model_config_from_hydra(cfg.model)
layer_config = model_config_registry[model_name].get_from_yaml()
yaml_config = SequenceModelConfig(
model_configs=[layer_config, layer_config],
intermediate_sizes=[intermidiate_size],
model_configs=[layer_config, layer_config, layer_config],
intermediate_sizes=[intermidiate_size, intermidiate_size],
)
assert hydra_model_config == yaml_config
6 changes: 6 additions & 0 deletions test/test_vmas.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from benchmarl.algorithms.common import AlgorithmConfig
from benchmarl.environments import Task, VmasTask
from benchmarl.experiment import Experiment
from benchmarl.models import MlpConfig
from torch import nn
from utils_experiment import ExperimentUtils

_has_vmas = importlib.util.find_spec("vmas") is not None
Expand Down Expand Up @@ -78,10 +80,14 @@ def test_share_policy_params(
mlp_sequence_config,
):
experiment_config.share_policy_params = share_params
critic_model_config = MlpConfig(
num_cells=[6], activation_class=nn.Tanh, layer_class=nn.Linear
)
task = task.get_from_yaml()
experiment = Experiment(
algorithm_config=algo_config.get_from_yaml(),
model_config=mlp_sequence_config,
critic_model_config=critic_model_config,
seed=0,
config=experiment_config,
task=task,
Expand Down

0 comments on commit 27f62cf

Please sign in to comment.