[Config] Hydra config (#8)
matteobettini committed Sep 8, 2023
2 parents 501dc78 + a9e17cb commit 0b3afdb
Showing 40 changed files with 730 additions and 293 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -1,4 +1,9 @@
### Python template

# Hydra
outputs/
multirun/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
20 changes: 20 additions & 0 deletions benchmarl/__init__.py
@@ -0,0 +1,20 @@
def load_hydra_schemas():
from hydra.core.config_store import ConfigStore

from benchmarl.algorithms import algorithm_config_registry
from benchmarl.environments import _task_class_registry
from benchmarl.experiment import ExperimentConfig

# Create instance to load hydra schemas
cs = ConfigStore.instance()
# Load experiment schema
cs.store(name="experiment_config", group="experiment", node=ExperimentConfig)
# Load algos schemas
for algo_name, algo_schema in algorithm_config_registry.items():
cs.store(name=f"{algo_name}_config", group="algorithm", node=algo_schema)
# Load task schemas
for task_schema_name, task_schema in _task_class_registry.items():
cs.store(name=task_schema_name, group="task", node=task_schema)


load_hydra_schemas()
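
The schemas registered above back three Hydra config groups: experiment, algorithm, and task. A minimal sketch of an entry point that consumes them (the run.py module, conf/ layout, and config.yaml name here are assumptions for illustration, not files shown in this diff):

import hydra
from omegaconf import DictConfig

@hydra.main(version_base=None, config_path="conf", config_name="config")
def run(cfg: DictConfig) -> None:
    # cfg.experiment, cfg.algorithm and cfg.task are validated against the
    # schemas stored by load_hydra_schemas() when benchmarl is imported
    print(cfg)

if __name__ == "__main__":
    run()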
22 changes: 11 additions & 11 deletions benchmarl/algorithms/__init__.py
@@ -8,14 +8,14 @@
from .qmix import Qmix, QmixConfig
from .vdn import Vdn, VdnConfig

all_algorithm_configs = (
IppoConfig,
MappoConfig,
MaddpgConfig,
IddpgConfig,
MasacConfig,
IsacConfig,
QmixConfig,
VdnConfig,
IqlConfig,
)
algorithm_config_registry = {
"mappo": MappoConfig,
"ippo": IppoConfig,
"maddpg": MaddpgConfig,
"iddpg": IddpgConfig,
"masac": MasacConfig,
"isac": IsacConfig,
"qmix": QmixConfig,
"vdn": VdnConfig,
"iql": IqlConfig,
}
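
Replacing the tuple with a name-keyed registry lets an algorithm be resolved from a plain string, e.g. a Hydra override. A usage sketch (relying on the get_from_yaml() API introduced in common.py below):

from benchmarl.algorithms import algorithm_config_registry

# Look up a config class by name and load its packaged YAML defaults
algorithm_config = algorithm_config_registry["mappo"].get_from_yaml()
print(algorithm_config.associated_class())  # the Mappo algorithm class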
63 changes: 41 additions & 22 deletions benchmarl/algorithms/common.py
@@ -1,6 +1,7 @@
import pathlib
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, Tuple, Type

import torch.optim
from tensordict import TensorDictBase
@@ -15,7 +16,7 @@
from torchrl.objectives.utils import TargetNetUpdater

from benchmarl.models.common import ModelConfig
from benchmarl.utils import DEVICE_TYPING
from benchmarl.utils import DEVICE_TYPING, read_yaml_config


class Algorithm(ABC):
@@ -28,11 +29,13 @@ def __init__(
state_spec: Optional[CompositeSpec],
action_mask_spec: Optional[CompositeSpec],
group_map: Dict[str, List[str]],
on_policy: bool,
):
self.device: DEVICE_TYPING = experiment_config.train_device

self.experiment_config = experiment_config
self.model_config = model_config
self.on_policy = on_policy
self.group_map = group_map
self.observation_spec = observation_spec
self.action_spec = action_spec
@@ -109,18 +112,18 @@ def get_loss_and_updater(self, group: str) -> Tuple[LossModule, TargetNetUpdater
def get_replay_buffer(
self,
group: str,
) -> Dict[str, ReplayBuffer]:
) -> ReplayBuffer:
return self._get_replay_buffer(
group=group,
memory_size=self.experiment_config.replay_buffer_memory_size(
self.on_policy()
self.on_policy
),
sampling_size=self.experiment_config.train_minibatch_size(self.on_policy()),
sampling_size=self.experiment_config.train_minibatch_size(self.on_policy),
traj_len=self.experiment_config.traj_len,
storing_device=self.device,
)

def get_policy_for_loss(self, group: str) -> List[TensorDictModule]:
def get_policy_for_loss(self, group: str) -> TensorDictModule:
if group not in self._policies_for_loss.keys():
action_space = self.action_spec[group, "action"]
continuous = not isinstance(
@@ -210,24 +213,9 @@ def process_loss_vals(
) -> TensorDictBase:
return loss_vals

@staticmethod
@abstractmethod
def on_policy() -> bool:
raise NotImplementedError

@staticmethod
@abstractmethod
def supports_continuous_actions() -> bool:
raise NotImplementedError

@staticmethod
@abstractmethod
def supports_discrete_actions() -> bool:
raise NotImplementedError


@dataclass
class AlgorithmConfig(ABC):
class AlgorithmConfig:
def get_algorithm(
self,
experiment_config,
@@ -247,9 +235,40 @@ def get_algorithm(
state_spec=state_spec,
action_mask_spec=action_mask_spec,
group_map=group_map,
on_policy=self.on_policy(),
)

@staticmethod
def _load_from_yaml(name: str) -> Dict[str, Any]:
yaml_path = (
pathlib.Path(__file__).parent.parent
/ "conf"
/ "algorithm"
/ f"{name.lower()}.yaml"
)
return read_yaml_config(str(yaml_path.resolve()))

@staticmethod
@abstractmethod
def get_from_yaml(path: Optional[str] = None):
raise NotImplementedError

@staticmethod
@abstractmethod
def associated_class() -> Type[Algorithm]:
raise NotImplementedError

@staticmethod
@abstractmethod
def on_policy() -> bool:
raise NotImplementedError

@staticmethod
@abstractmethod
def supports_continuous_actions() -> bool:
raise NotImplementedError

@staticmethod
@abstractmethod
def supports_discrete_actions() -> bool:
raise NotImplementedError
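
The config fields are now declared with dataclasses.MISSING and filled from benchmarl/conf/algorithm/<name>.yaml via _load_from_yaml. A self-contained sketch of why MISSING is used (ExampleConfig is hypothetical): dataclasses treats a MISSING default as "no default", so every such field must be supplied explicitly rather than falling back to a stale hardcoded value:

from dataclasses import dataclass, MISSING

@dataclass
class ExampleConfig:
    # MISSING means "no default": the value must come from YAML or the CLI
    loss_function: str = MISSING

ExampleConfig(loss_function="l2")  # ok
ExampleConfig()  # TypeError: missing required argument 'loss_function'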
56 changes: 32 additions & 24 deletions benchmarl/algorithms/iddpg.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Dict, Type
from dataclasses import dataclass, MISSING
from typing import Dict, Optional, Type

import torch
from typing import Tuple
@@ -25,7 +25,7 @@

from benchmarl.algorithms.common import Algorithm, AlgorithmConfig
from benchmarl.models.common import ModelConfig
from benchmarl.utils import DEVICE_TYPING
from benchmarl.utils import DEVICE_TYPING, read_yaml_config


class Iddpg(Algorithm):
@@ -154,8 +154,8 @@ def _get_policy_for_loss(
out_keys=[(group, "action")],
distribution_class=TanhDelta,
distribution_kwargs={
"min": self.action_spec[(group, "action")].space.minimum,
"max": self.action_spec[(group, "action")].space.maximum,
"min": self.action_spec[(group, "action")].space.low,
"max": self.action_spec[(group, "action")].space.high,
},
return_log_prob=False,
)
@@ -197,18 +197,6 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:

return batch

@staticmethod
def supports_continuous_actions() -> bool:
return True

@staticmethod
def supports_discrete_actions() -> bool:
return False

@staticmethod
def on_policy() -> bool:
return False

#####################
# Custom new methods
#####################
@@ -271,14 +259,34 @@ def get_value_module(self, group: str) -> TensorDictModule:

@dataclass
class IddpgConfig(AlgorithmConfig):
# You can add any kwargs from benchmarl.algorithms.Iddpg

share_param_actor: bool = True
share_param_critic: bool = True

loss_function: str = "l2"
delay_value: bool = True
share_param_actor: bool = MISSING
share_param_critic: bool = MISSING
loss_function: str = MISSING
delay_value: bool = MISSING

@staticmethod
def associated_class() -> Type[Algorithm]:
return Iddpg

@staticmethod
def supports_continuous_actions() -> bool:
return True

@staticmethod
def supports_discrete_actions() -> bool:
return False

@staticmethod
def on_policy() -> bool:
return False

@staticmethod
def get_from_yaml(path: Optional[str] = None):
if path is None:
return IddpgConfig(
**AlgorithmConfig._load_from_yaml(
name=IddpgConfig.associated_class().__name__,
)
)
else:
return IddpgConfig(**read_yaml_config(path))
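
The defaults deleted from IddpgConfig are expected to live in benchmarl/conf/algorithm/iddpg.yaml, which get_from_yaml() resolves when called without a path. A sketch (the YAML contents are an assumption reconstructed from the old defaults; the file itself is not shown in this excerpt):

# Assumed contents of benchmarl/conf/algorithm/iddpg.yaml:
#   share_param_actor: True
#   share_param_critic: True
#   loss_function: "l2"
#   delay_value: True
from benchmarl.algorithms import IddpgConfig

config = IddpgConfig.get_from_yaml()  # loads the packaged iddpg.yaml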
61 changes: 35 additions & 26 deletions benchmarl/algorithms/ippo.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Dict, Type
from dataclasses import dataclass, MISSING
from typing import Dict, Optional, Type

import torch
from typing import Tuple
@@ -22,7 +22,7 @@

from benchmarl.algorithms.common import Algorithm, AlgorithmConfig
from benchmarl.models.common import ModelConfig
from benchmarl.utils import DEVICE_TYPING
from benchmarl.utils import DEVICE_TYPING, read_yaml_config


class Ippo(Algorithm):
@@ -165,8 +165,8 @@ def _get_policy_for_loss(
out_keys=[(group, "action")],
distribution_class=TanhNormal,
distribution_kwargs={
"min": self.action_spec[(group, "action")].space.minimum,
"max": self.action_spec[(group, "action")].space.maximum,
"min": self.action_spec[(group, "action")].space.low,
"max": self.action_spec[(group, "action")].space.high,
},
return_log_prob=True,
log_prob_key=(group, "log_prob"),
@@ -243,18 +243,6 @@ def process_loss_vals(
del loss_vals["loss_entropy"]
return loss_vals

@staticmethod
def supports_continuous_actions() -> bool:
return True

@staticmethod
def supports_discrete_actions() -> bool:
return True

@staticmethod
def on_policy() -> bool:
return True

#####################
# Custom new methods
#####################
Expand Down Expand Up @@ -298,17 +286,38 @@ def get_critic(self, group: str) -> TensorDictModule:

@dataclass
class IppoConfig(AlgorithmConfig):
# You can add any kwargs from benchmarl.algorithms.Ippo

share_param_actor: bool = True
share_param_critic: bool = True

clip_epsilon: float = 0.2
entropy_coef: bool = 0.0
critic_coef: float = 1.0
loss_critic_type: str = "l2"
lmbda: float = 0.9
share_param_actor: bool = MISSING
share_param_critic: bool = MISSING
clip_epsilon: float = MISSING
entropy_coef: float = MISSING
critic_coef: float = MISSING
loss_critic_type: str = MISSING
lmbda: float = MISSING

@staticmethod
def associated_class() -> Type[Algorithm]:
return Ippo

@staticmethod
def supports_continuous_actions() -> bool:
return True

@staticmethod
def supports_discrete_actions() -> bool:
return True

@staticmethod
def on_policy() -> bool:
return True

@staticmethod
def get_from_yaml(path: Optional[str] = None):
if path is None:
return IppoConfig(
**AlgorithmConfig._load_from_yaml(
name=IppoConfig.associated_class().__name__,
)
)
else:
return IppoConfig(**read_yaml_config(path))
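
Ippo follows the same pattern, and get_from_yaml() also accepts an explicit path so a run can supply its own hyperparameters. A usage sketch (my_ippo.yaml is a hypothetical file that must define every MISSING field of IppoConfig):

from benchmarl.algorithms import IppoConfig

config = IppoConfig.get_from_yaml(path="my_ippo.yaml")
assert config.on_policy()  # IPPO is an on-policy algorithm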