[Loggers] Logging compatibility (#11)

* amend
* ippo one optim
* revert ippo
* cfg
* update task config
* remove callback
* readme loggers

---------

Signed-off-by: Matteo Bettini <[email protected]>
matteobettini committed Sep 13, 2023
1 parent c265330 commit 040b949
Showing 19 changed files with 819 additions and 289 deletions.
13 changes: 13 additions & 0 deletions README.md
@@ -73,3 +73,16 @@ Configuring a layer
 ```bash
 python hydra_run.py "model.layers.l1.num_cells=[3]"
 ```
+
+## Logging
+
+BenchMARL is compatible with the [TorchRL loggers](https://github.com/pytorch/rl/tree/main/torchrl/record/loggers).
+A list of logger names can be provided in the [experiment config](benchmarl/conf/experiment/base_experiment.yaml).
+Available options include `wandb`, `csv`, `mlflow`, and `tensorboard`, as well as any other logger available in TorchRL. You can specify the loggers
+in the yaml config files or in the script arguments like so:
+```bash
+python hydra_run.py "experiment.loggers=[wandb]"
+```
+
+Additionally, you can set a `create_json` argument, which instructs the trainer to output a `.json` file in the
+format specified by [marl-eval](https://github.com/instadeepai/marl-eval).
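
For instance, both options can be combined in a single run; a sketch using the same Hydra override syntax (both keys live under the `experiment` config group, as the experiment config diff further down shows):
```bash
python hydra_run.py "experiment.loggers=[wandb,csv]" "experiment.create_json=True"
```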
2 changes: 1 addition & 1 deletion benchmarl/algorithms/iql.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, MISSING
-from typing import Dict, Optional, Type, Tuple
+from typing import Dict, Optional, Tuple, Type
 
 import torch
 from tensordict import TensorDictBase
4 changes: 2 additions & 2 deletions benchmarl/conf/algorithm/ippo.yaml
@@ -3,8 +3,8 @@ defaults:
   - _self_
 
 
-share_param_actor: False
-share_param_critic: False
+share_param_actor: True
+share_param_critic: True
 clip_epsilon: 0.2
 entropy_coef: 0.0
 critic_coef: 1.0
27 changes: 18 additions & 9 deletions benchmarl/conf/experiment/base_experiment.yaml
@@ -2,20 +2,29 @@ defaults:
   - experiment_config
   - _self_
 
-sampling_device: "cpu"
-train_device: "cpu"
-gamma: 0.9
+sampling_device: "cuda"
+train_device: "cuda"
+gamma: 0.99
 polyak_tau: 0.005
-lr: 0.000003
-n_optimizer_steps: 10
-collected_frames_per_batch: 1000
-n_collection_envs: 1
-n_iters: 3
+lr: 0.00005
+n_optimizer_steps: 45
+collected_frames_per_batch: 60_000
+n_envs_per_worker: 600
+n_iters: 500
 prefer_continuous_actions: True
 clip_grad_norm: True
 clip_grad_val: 40
 
-on_policy_minibatch_size: 100
+on_policy_minibatch_size: 4096
 
 off_policy_memory_size: 100_000
 off_policy_train_batch_size: 10_000
+off_policy_prioritised_alpha: 0.7
+off_policy_prioritised_beta: 0.5
+
+evaluation: True
+evaluation_interval: 30
+evaluation_episodes: 200
+
+loggers: [wandb]
+create_json: True
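
As a quick sanity check, these defaults can be inspected outside of Hydra; a minimal sketch using OmegaConf (the path is relative to the repository root):
```python
from omegaconf import OmegaConf

# Load the raw experiment defaults (the `defaults:` list is left to Hydra).
cfg = OmegaConf.load("benchmarl/conf/experiment/base_experiment.yaml")
print(cfg.loggers)      # ['wandb']
print(cfg.create_json)  # True
```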
Empty file removed benchmarl/conf/task/__init__.py
3 changes: 2 additions & 1 deletion benchmarl/conf/task/vmas/balance.yaml
@@ -3,4 +3,5 @@ defaults:
   - vmas_balance_config
 
 
-max_steps: 3
+max_steps: 100
+n_agents: 3
16 changes: 13 additions & 3 deletions benchmarl/environments/common.py
@@ -3,7 +3,7 @@
 import os.path as osp
 import pathlib
 from enum import Enum
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 from torchrl.data import CompositeSpec
 from torchrl.envs import EnvBase
@@ -49,12 +49,12 @@ def update_config(self, config: Dict[str, Any]):
         self.config.update(config)
         return self
 
-    def get_env(
+    def get_env_fun(
         self,
         num_envs: int,
         continuous_actions: bool,
         seed: Optional[int],
-    ) -> EnvBase:
+    ) -> Callable[[], EnvBase]:
         raise NotImplementedError
 
     def supports_continuous_actions(self) -> bool:
@@ -63,6 +63,12 @@ def supports_continuous_actions(self) -> bool:
     def supports_discrete_actions(self) -> bool:
         raise NotImplementedError
 
+    def max_steps(self) -> int:
+        raise NotImplementedError
+
+    def has_render(self) -> bool:
+        raise NotImplementedError
+
     def group_map(self, env: EnvBase) -> Dict[str, List[str]]:
         raise NotImplementedError
 
@@ -84,6 +90,10 @@ def action_mask_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
     def get_from_yaml(self, path: Optional[str] = None):
         raise NotImplementedError
 
+    @staticmethod
+    def env_name() -> str:
+        return "vmas"
+
     def __repr__(self):
         cls_name = self.__class__.__name__
         return f"{cls_name}.{self.name}: (config={self.config})"
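
The base-class change above renames `get_env` to `get_env_fun` and makes it return an environment factory instead of a constructed environment. A minimal sketch of what a concrete task now has to provide (hypothetical `MyTask`, with a Gym env standing in for a real implementation):
```python
from typing import Callable, Optional

from torchrl.envs import EnvBase
from torchrl.envs.libs.gym import GymEnv  # requires gym/gymnasium installed


class MyTask:  # hypothetical stand-in for a Task subclass
    def get_env_fun(
        self,
        num_envs: int,
        continuous_actions: bool,
        seed: Optional[int],
    ) -> Callable[[], EnvBase]:
        # Return a zero-argument factory: construction is deferred, so the
        # caller decides when (and how many times) to build the environment.
        return lambda: GymEnv("Pendulum-v1")


env = MyTask().get_env_fun(num_envs=1, continuous_actions=True, seed=0)()
```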
1 change: 1 addition & 0 deletions benchmarl/environments/vmas/balance.py
@@ -4,3 +4,4 @@
 @dataclass
 class TaskConfig:
     max_steps: int = MISSING
+    n_agents: int = MISSING
20 changes: 15 additions & 5 deletions benchmarl/environments/vmas/common.py
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Dict, List, Optional
+from typing import Callable, Dict, List, Optional
 
 from torchrl.data import CompositeSpec
 from torchrl.envs import EnvBase
@@ -12,13 +12,13 @@
 class VmasTask(Task):
     BALANCE = None
 
-    def get_env(
+    def get_env_fun(
         self,
         num_envs: int,
         continuous_actions: bool,
         seed: Optional[int],
-    ) -> EnvBase:
-        return VmasEnv(
+    ) -> Callable[[], EnvBase]:
+        return lambda: VmasEnv(
             scenario=self.name.lower(),
             num_envs=num_envs,
             continuous_actions=continuous_actions,
@@ -33,6 +33,12 @@ def supports_continuous_actions(self) -> bool:
     def supports_discrete_actions(self) -> bool:
         return True
 
+    def has_render(self) -> bool:
+        return True
+
+    def max_steps(self) -> int:
+        return self.config["max_steps"]
+
     def group_map(self, env: EnvBase) -> Dict[str, List[str]]:
         return {"agents": [agent.name for agent in env.agents]}
 
@@ -55,11 +61,15 @@ def info_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
     def action_spec(self, env: EnvBase) -> CompositeSpec:
         return env.unbatched_action_spec
 
+    @staticmethod
+    def env_name() -> str:
+        return "vmas"
+
     def get_from_yaml(self, path: Optional[str] = None):
         if path is None:
             task_name = self.name.lower()
             return self.update_config(
-                Task._load_from_yaml(str(Path("vmas") / Path(task_name)))
+                Task._load_from_yaml(str(Path(self.env_name()) / Path(task_name)))
             )
         else:
             return self.update_config(**read_yaml_config(path))
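
Returning `lambda: VmasEnv(...)` hands callers a constructor rather than a live environment, which is the shape TorchRL collectors expect for their `create_env_fn` argument. A usage sketch under that assumption (the `benchmarl.environments` import path is assumed):
```python
from benchmarl.environments import VmasTask  # assumed import path
from torchrl.collectors import SyncDataCollector

task = VmasTask.BALANCE.get_from_yaml()  # loads conf/task/vmas/balance.yaml
env_fun = task.get_env_fun(num_envs=2, continuous_actions=True, seed=0)

env = env_fun()  # build one instance eagerly, e.g. to inspect specs
collector = SyncDataCollector(
    create_env_fn=env_fun,  # the collector invokes the factory lazily
    policy=None,            # None falls back to a random policy
    frames_per_batch=100,
)
```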