
Commit

Merge branch 'main' into dev
matteobettini committed Sep 8, 2023
2 parents 8a110f9 + 5981415 commit d01d7e0
Showing 21 changed files with 128 additions and 66 deletions.
2 changes: 1 addition & 1 deletion benchmarl/conf/algorithm/iddpg.yaml
@@ -1,6 +1,6 @@
 defaults:
-  - _self_
   - iddpg_config
+  - _self_


 share_param_actor: True
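Note on the change above (the eight algorithm configs below get the identical one-line move): `_self_` goes from the top of the Hydra defaults list to the bottom. Hydra merges defaults in list order, with later entries overriding earlier ones, so with `_self_` last the values written in this file (e.g. `share_param_actor: True`) override the `iddpg_config` schema defaults instead of being silently overridden by them. A minimal sketch of the effect, assuming a script run from the repository root:

from hydra import compose, initialize

# With `_self_` last in iddpg.yaml, the file's own keys win over the
# `iddpg_config` schema registered in Hydra's ConfigStore.
with initialize(version_base=None, config_path="benchmarl/conf"):
    cfg = compose(config_name="config", overrides=["algorithm=iddpg"])
    print(cfg.algorithm.share_param_actor)  # True, as written in iddpg.yaml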
2 changes: 1 addition & 1 deletion benchmarl/conf/algorithm/ippo.yaml
@@ -1,6 +1,6 @@
 defaults:
-  - _self_
   - ippo_config
+  - _self_


 share_param_actor: False
2 changes: 1 addition & 1 deletion benchmarl/conf/algorithm/iql.yaml
@@ -1,6 +1,6 @@
 defaults:
-  - _self_
   - iql_config
+  - _self_


 delay_value: True
2 changes: 1 addition & 1 deletion benchmarl/conf/algorithm/isac.yaml
@@ -1,6 +1,6 @@
 defaults:
-  - _self_
   - isac_config
+  - _self_


 share_param_actor: True
2 changes: 1 addition & 1 deletion benchmarl/conf/algorithm/maddpg.yaml
@@ -1,6 +1,6 @@
 defaults:
-  - _self_
   - maddpg_config
+  - _self_


 share_param_actor: True
2 changes: 1 addition & 1 deletion benchmarl/conf/algorithm/mappo.yaml
@@ -1,6 +1,6 @@
 defaults:
-  - _self_
   - mappo_config
+  - _self_


 share_param_actor: False
2 changes: 1 addition & 1 deletion benchmarl/conf/algorithm/masac.yaml
@@ -1,6 +1,6 @@
 defaults:
-  - _self_
   - masac_config
+  - _self_


 share_param_actor: True
2 changes: 1 addition & 1 deletion benchmarl/conf/algorithm/qmix.yaml
@@ -1,6 +1,6 @@
 defaults:
-  - _self_
   - qmix_config
+  - _self_


 mixing_embed_dim: 32
2 changes: 1 addition & 1 deletion benchmarl/conf/algorithm/vdn.yaml
@@ -1,6 +1,6 @@
 defaults:
-  - _self_
   - vdn_config
+  - _self_

 delay_value: True
 loss_function: "l2"
3 changes: 2 additions & 1 deletion benchmarl/conf/config.yaml
@@ -1,7 +1,8 @@
 defaults:
-  - _self_
   - experiment: base_experiment
   - algorithm: mappo
   - task: vmas/balance
+  - model: sequence
+  - _self_

 seed: 0
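With `model: sequence` added to the primary defaults and `_self_` moved last, the model is now a Hydra config group like the algorithm and task, so it can be swapped at composition time. A sketch (group and option names as in this repo's conf tree, assuming a script at the repository root):

from hydra import compose, initialize

# The model group can now be overridden just like algorithm and task.
with initialize(version_base=None, config_path="benchmarl/conf"):
    cfg = compose(
        config_name="config",
        overrides=["algorithm=qmix", "model=layers/mlp"],
    )
    print(cfg.model.name)  # "mlp"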
2 changes: 1 addition & 1 deletion benchmarl/conf/experiment/base_experiment.yaml
@@ -1,6 +1,6 @@
 defaults:
-  - _self_
   - experiment_config
+  - _self_

 sampling_device: "cpu"
 train_device: "cpu"
Empty file removed benchmarl/conf/model/__init__.py
12 changes: 12 additions & 0 deletions benchmarl/conf/model/layers/mlp.yaml
@@ -0,0 +1,12 @@
+
+
+name: mlp
+
+num_cells: [256, 256]
+layer_class: torch.nn.Linear
+
+activation_class: torch.nn.Tanh
+activation_kwargs: null
+
+norm_class: null
+norm_kwargs: null
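The `*_class` fields in this new file hold dotted import paths rather than objects; `parse_model_config` (added to benchmarl/models/common.py below) resolves them with `hydra.utils.get_class`. A quick sketch of that resolution:

from hydra.utils import get_class
from torch import nn

# Dotted paths from the YAML become the actual classes at load time.
assert get_class("torch.nn.Linear") is nn.Linear
assert get_class("torch.nn.Tanh") is nn.Tanh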
12 changes: 12 additions & 0 deletions benchmarl/conf/model/sequence.yaml
@@ -0,0 +1,12 @@
+defaults:
+  # Here is a list of layers for this model
+  # You can use configs from "layer"
+  - [email protected]: mlp
+  - [email protected]: mlp
+  - _self_
+
+intermediate_sizes: [16]
+# You can override your layers like this
+layers:
+  l1:
+    num_cells: [4]
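The `[email protected]: mlp` entries use Hydra's group@package syntax: each pulls the model/layers/mlp config in under the `layers.l1` / `layers.l2` keys of this file, and the trailing `layers:` block then overrides single fields of one layer. A sketch of what composition yields (key names as in the file above):

from hydra import compose, initialize

with initialize(version_base=None, config_path="benchmarl/conf"):
    cfg = compose(config_name="config", overrides=["model=sequence"])
    print(cfg.model.layers.l1.num_cells)  # [4], overridden above
    print(cfg.model.layers.l2.num_cells)  # [256, 256], from mlp.yaml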
37 changes: 22 additions & 15 deletions benchmarl/hydra_run.py
@@ -1,20 +1,18 @@
 import hydra
 from hydra.core.hydra_config import HydraConfig
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf

 from benchmarl.algorithms import algorithm_config_registry
 from benchmarl.environments import task_config_registry
-from benchmarl.experiment import Experiment, ExperimentConfig
-from benchmarl.models.common import ModelConfig
-from benchmarl.models.mlp import MlpConfig
+from benchmarl.experiment import Experiment
+from benchmarl.models import model_config_registry
+from benchmarl.models.common import ModelConfig, parse_model_config, SequenceModelConfig


-def load_experiment_from_hydra_config(
-    cfg: DictConfig, algo_name: str, task_name: str, model_config: ModelConfig
-) -> Experiment:
-    algorithm_config = algorithm_config_registry[algo_name](**cfg.algorithm)
+def load_experiment_from_hydra_config(cfg: DictConfig, task_name: str) -> Experiment:
+    algorithm_config = OmegaConf.to_object(cfg.algorithm)
+    experiment_config = OmegaConf.to_object(cfg.experiment)
     task_config = task_config_registry[task_name].update_config(cfg.task)
-    experiment_config = ExperimentConfig(**cfg.experiment)
+    model_config = load_model_from_hydra_config(cfg.model)

     return Experiment(
         task=task_config,
@@ -25,18 +23,27 @@ def load_experiment_from_hydra_config(
     )


+def load_model_from_hydra_config(cfg: DictConfig) -> ModelConfig:
+    if "layers" in cfg.keys():
+        model_configs = [
+            load_model_from_hydra_config(cfg.layers[f"l{i}"])
+            for i in range(1, len(cfg.layers) + 1)
+        ]
+        return SequenceModelConfig(
+            model_configs=model_configs, intermediate_sizes=cfg.intermediate_sizes
+        )
+    else:
+        model_class = model_config_registry[cfg.name]
+        return model_class(**parse_model_config(OmegaConf.to_container(cfg)))
+
+
 @hydra.main(version_base=None, config_path="conf", config_name="config")
 def hydra_experiment(cfg: DictConfig) -> None:
     hydra_choices = HydraConfig.get().runtime.choices
-    algo_name = hydra_choices.algorithm
     task_name = hydra_choices.task
-    experiment = load_experiment_from_hydra_config(
-        cfg,
-        algo_name=algo_name,
-        task_name=task_name,
-        model_config=MlpConfig(
-            num_cells=[64]
-        ),  # Model still needs to be hydra configurable
-    )
+    experiment = load_experiment_from_hydra_config(cfg, task_name=task_name)
     experiment.run()

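`load_model_from_hydra_config` recurses: a config carrying a `layers` key becomes a `SequenceModelConfig` of sub-models named `l1`, `l2`, ..., while a flat config is looked up in `model_config_registry` by its `name` field. A sketch with a hand-built flat config (the literal values are illustrative):

from omegaconf import OmegaConf

from benchmarl.hydra_run import load_model_from_hydra_config

# Keys mirror conf/model/layers/mlp.yaml; class entries stay strings here
# and are resolved to classes inside parse_model_config.
flat = OmegaConf.create(
    {
        "name": "mlp",
        "num_cells": [32, 32],
        "layer_class": "torch.nn.Linear",
        "activation_class": "torch.nn.Tanh",
        "activation_kwargs": None,
        "norm_class": None,
        "norm_kwargs": None,
    }
)
mlp_config = load_model_from_hydra_config(flat)  # -> MlpConfig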
3 changes: 3 additions & 0 deletions benchmarl/models/__init__.py
@@ -0,0 +1,3 @@
+from benchmarl.models.mlp import MlpConfig
+
+model_config_registry = {"mlp": MlpConfig}
37 changes: 35 additions & 2 deletions benchmarl/models/common.py
@@ -1,19 +1,31 @@
+import pathlib
 from abc import ABC, abstractmethod
 from dataclasses import asdict, dataclass
-from typing import List, Sequence
+from typing import Any, Dict, List, Optional, Sequence

+from hydra.utils import get_class
 from tensordict import TensorDictBase
 from tensordict.nn import TensorDictModuleBase, TensorDictSequential
 from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

-from benchmarl.utils import DEVICE_TYPING
+from benchmarl.utils import DEVICE_TYPING, read_yaml_config


 def _check_spec(tensordict, spec):
     if not spec.is_in(tensordict):
         raise ValueError(f"TensorDict {tensordict} not in spec {spec}")


+def parse_model_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
+    del cfg["name"]
+    kwargs = {}
+    for key, value in cfg.items():
+        if key.endswith("class") and value is not None:
+            value = get_class(cfg[key])
+        kwargs.update({key: value})
+    return kwargs
+
+
 def output_has_agent_dim(share_params: bool, centralised: bool) -> bool:
     if share_params and centralised:
         return False
@@ -144,6 +156,23 @@ def get_model(
     def associated_class():
         raise NotImplementedError

+    @staticmethod
+    def _load_from_yaml(name: str) -> Dict[str, Any]:
+        yaml_path = (
+            pathlib.Path(__file__).parent.parent
+            / "conf"
+            / "model"
+            / "layers"
+            / f"{name.lower()}.yaml"
+        )
+        cfg = read_yaml_config(str(yaml_path.resolve()))
+        return parse_model_config(cfg)
+
+    @staticmethod
+    @abstractmethod
+    def get_from_yaml(path: Optional[str] = None):
+        raise NotImplementedError
+

 @dataclass
 class SequenceModelConfig(ModelConfig):
@@ -212,3 +241,7 @@ def get_model(
     @staticmethod
     def associated_class():
         raise NotImplementedError
+
+    @staticmethod
+    def get_from_yaml(path: Optional[str] = None):
+        raise NotImplementedError
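`parse_model_config` drops the `name` key and swaps every non-null `*_class` string for the class it names, so the YAML contents map one-to-one onto dataclass kwargs; `_load_from_yaml` locates a layer's YAML under conf/model/layers/ by the lower-cased class name. A sketch of the parsing step (the input dict is illustrative):

from torch import nn

from benchmarl.models.common import parse_model_config

raw = {
    "name": "mlp",
    "layer_class": "torch.nn.Linear",
    "activation_class": "torch.nn.Tanh",
    "norm_class": None,
}
kwargs = parse_model_config(raw)  # note: also deletes "name" from raw
assert kwargs["layer_class"] is nn.Linear
assert kwargs["norm_class"] is None  # null entries pass through unresolved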
24 changes: 17 additions & 7 deletions benchmarl/models/mlp.py
@@ -1,12 +1,13 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, MISSING
 from typing import Optional, Sequence, Type

 import torch
 from tensordict import TensorDictBase
 from torch import nn
 from torchrl.modules import MLP, MultiAgentMLP
+from benchmarl.utils import read_yaml_config

-from benchmarl.models.common import Model, ModelConfig
+from benchmarl.models.common import Model, ModelConfig, parse_model_config


 class Mlp(Model):
@@ -88,12 +89,10 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:

 @dataclass
 class MlpConfig(ModelConfig):
-    # You can add any kwargs from torchrl.modules.MLP
-
-    num_cells: Sequence[int] = (256, 256)
-    layer_class: Type[nn.Module] = nn.Linear
+    num_cells: Sequence[int] = MISSING
+    layer_class: Type[nn.Module] = MISSING

-    activation_class: Type[nn.Module] = nn.Tanh
+    activation_class: Type[nn.Module] = MISSING
     activation_kwargs: Optional[dict] = None

     norm_class: Type[nn.Module] = None
@@ -102,3 +101,14 @@ class MlpConfig(ModelConfig):
     @staticmethod
     def associated_class():
         return Mlp
+
+    @staticmethod
+    def get_from_yaml(path: Optional[str] = None):
+        if path is None:
+            return MlpConfig(
+                **ModelConfig._load_from_yaml(
+                    name=MlpConfig.associated_class().__name__,
+                )
+            )
+        else:
+            return MlpConfig(**parse_model_config(read_yaml_config(path)))
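With its fields now `MISSING`, `MlpConfig()` with no arguments raises, and the defaults live in conf/model/layers/mlp.yaml instead; `get_from_yaml` is the way back to them. A usage sketch (the explicit path in the second call is hypothetical):

from benchmarl.models.mlp import MlpConfig

# No path: loads benchmarl/conf/model/layers/mlp.yaml, found via the
# associated class name ("Mlp" -> "mlp.yaml").
default_mlp = MlpConfig.get_from_yaml()

# Explicit path: any YAML with the same fields works.
# custom_mlp = MlpConfig.get_from_yaml("path/to/my_mlp.yaml")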
24 changes: 7 additions & 17 deletions examples/simple_hydra_run.py
@@ -1,32 +1,22 @@
 import hydra

 from benchmarl.algorithms import algorithm_config_registry
 from benchmarl.environments import task_config_registry
-from benchmarl.experiment import Experiment, ExperimentConfig
-from benchmarl.models.common import SequenceModelConfig
-from benchmarl.models.mlp import MlpConfig
+from benchmarl.experiment import Experiment

+from benchmarl.hydra_run import load_model_from_hydra_config
 from hydra.core.hydra_config import HydraConfig
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf


 @hydra.main(version_base=None, config_path="../benchmarl/conf", config_name="config")
 def hydra_experiment(cfg: DictConfig) -> None:
     hydra_choices = HydraConfig.get().runtime.choices
     algo_name = hydra_choices.algorithm
     task_name = hydra_choices.task

-    algorithm_config = algorithm_config_registry[algo_name](**cfg.algorithm)
+    algorithm_config = OmegaConf.to_object(cfg.algorithm)
+    experiment_config = OmegaConf.to_object(cfg.experiment)
     task_config = task_config_registry[task_name].update_config(cfg.task)
-    experiment_config = ExperimentConfig(**cfg.experiment)

-    # Model still need to be refactored for hydra loading
-    model_config = SequenceModelConfig(
-        model_configs=[
-            MlpConfig(num_cells=[64, 64]),
-            MlpConfig(num_cells=[256]),
-        ],
-        intermediate_sizes=[128],
-    )
+    model_config = load_model_from_hydra_config(cfg.model)

     experiment = Experiment(
         task=task_config,
5 changes: 3 additions & 2 deletions examples/simple_run.py
@@ -4,6 +4,7 @@
 from benchmarl.experiment import ExperimentConfig
 from benchmarl.models.common import SequenceModelConfig
 from benchmarl.models.mlp import MlpConfig
+from torch import nn

 if __name__ == "__main__":

@@ -20,8 +21,8 @@
     # Model still need to be refactored for hydra loading
     model_config = SequenceModelConfig(
         model_configs=[
-            MlpConfig(num_cells=[64, 64]),
-            MlpConfig(num_cells=[256]),
+            MlpConfig.get_from_yaml(),
+            MlpConfig(num_cells=[256], layer_class=nn.Linear, activation_class=nn.Tanh),
         ],
         intermediate_sizes=[128],
     )
17 changes: 5 additions & 12 deletions test/test_algorithms.py
@@ -39,19 +39,12 @@ def test_all_algos_hydra(algo_config):
     with initialize(version_base=None, config_path="../benchmarl/conf"):
         cfg = compose(
             config_name="config",
-            overrides=[f"algorithm={algo_config}"],
+            overrides=[
+                f"algorithm={algo_config}",
+                "model=layers/mlp",
+            ],
             return_hydra_config=True,
         )
         task_name = cfg.hydra.runtime.choices.task
-        algo_name = cfg.hydra.runtime.choices.algorithm
-        model_config = SequenceModelConfig(
-            model_configs=[
-                MlpConfig(num_cells=[8]),
-                MlpConfig(num_cells=[4]),
-            ],
-            intermediate_sizes=[5],
-        )
-        experiment = load_experiment_from_hydra_config(
-            cfg, algo_name=algo_name, task_name=task_name, model_config=model_config
-        )
+        experiment = load_experiment_from_hydra_config(cfg, task_name=task_name)
         experiment.run()

