[Model] GRU and general RNN compatibility #116

Merged · 27 commits · Aug 1, 2024
7 changes: 1 addition & 6 deletions README.md
@@ -263,16 +263,11 @@ agent group. Here is a table of the models implemented in BenchMARL
| Name | Decentralized | Centralized with local inputs | Centralized with global input |
|------------------------------------------|:-------------:|:-----------------------------:|:-----------------------------:|
| [MLP](benchmarl/models/mlp.py) | Yes | Yes | Yes |
+| [GRU](benchmarl/models/gru.py) | Yes | Yes | Yes |
| [GNN](benchmarl/models/gnn.py) | Yes | Yes | No |
| [CNN](benchmarl/models/cnn.py) | Yes | Yes | Yes |
| [Deepsets](benchmarl/models/deepsets.py) | Yes | Yes | Yes |

-And the ones that are _work in progress_
-
-| Name | Decentralized | Centralized with local inputs | Centralized with global input |
-|--------------------|:-------------:|:-----------------------------:|:-----------------------------:|
-| RNN (GRU and LSTM) | Yes | Yes | Yes |
-

## Fine-tuned public benchmarks
> [!WARNING]
102 changes: 100 additions & 2 deletions benchmarl/algorithms/common.py
@@ -7,19 +7,27 @@
import pathlib
from abc import ABC, abstractmethod
from dataclasses import dataclass
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type

from tensordict import TensorDictBase
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.data import (
+CompositeSpec,
DiscreteTensorSpec,
LazyTensorStorage,
OneHotDiscreteTensorSpec,
ReplayBuffer,
TensorDictReplayBuffer,
)
from torchrl.data.replay_buffers import RandomSampler, SamplerWithoutReplacement
-from torchrl.envs import Compose, Transform
+from torchrl.envs import (
+    Compose,
+    EnvBase,
+    InitTracker,
+    TensorDictPrimer,
+    Transform,
+    TransformedEnv,
+)
from torchrl.objectives import LossModule
from torchrl.objectives.utils import HardUpdate, SoftUpdate, TargetNetUpdater

@@ -51,6 +59,16 @@ def __init__(self, experiment):
self.action_spec = experiment.action_spec
self.state_spec = experiment.state_spec
self.action_mask_spec = experiment.action_mask_spec
self.has_independent_critic = (
experiment.algorithm_config.has_independent_critic()
)
self.has_centralized_critic = (
experiment.algorithm_config.has_centralized_critic()
)
self.has_critic = experiment.algorithm_config.has_critic
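# Recurrent handling is needed if the actor model is an RNN, or if the
# algorithm has a critic and the critic model is an RNN.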
self.has_rnn = self.model_config.is_rnn or (
self.critic_model_config.is_rnn and self.has_critic
)

# Cached values that will be instantiated only once and then remain fixed
self._losses_and_updaters = {}
@@ -142,6 +160,14 @@ def get_replay_buffer(
"""
memory_size = self.experiment_config.replay_buffer_memory_size(self.on_policy)
sampling_size = self.experiment_config.train_minibatch_size(self.on_policy)
if self.has_rnn:
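# With an RNN, the buffer stores whole sequences instead of single frames,
# so all sizes are converted to sequence counts; -(-a // b) is ceil(a / b).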
sequence_length = -(
-self.experiment_config.collected_frames_per_batch(self.on_policy)
// self.experiment_config.n_envs_per_worker(self.on_policy)
)
memory_size = -(-memory_size // sequence_length)
sampling_size = -(-sampling_size // sequence_length)

sampler = SamplerWithoutReplacement() if self.on_policy else RandomSampler()
return TensorDictReplayBuffer(
storage=LazyTensorStorage(
@@ -218,6 +244,54 @@ def get_parameters(self, group: str) -> Dict[str, Iterable]:
loss=self.get_loss_and_updater(group)[0],
)

def process_env_fun(
self,
env_fun: Callable[[], EnvBase],
) -> Callable[[], EnvBase]:
"""
This function can be used to wrap env_fun

Args:
env_fun (callable): a function that takes no args and creates an enviornment

Returns: a function that takes no args and creates an enviornment

"""
if self.has_rnn:

def model_fun():
env = env_fun()

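# Expand the model's recurrent-state spec over each group's agent dimension.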
spec_actor = self.model_config.get_model_state_spec()
spec_actor = CompositeSpec(
{
group: CompositeSpec(
spec_actor.expand(len(agents), *spec_actor.shape),
shape=(len(agents),),
)
for group, agents in self.group_map.items()
}
)

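# InitTracker adds an "is_init" flag marking episode starts, so recurrent
# models know when to reset hidden states; the TensorDictPrimer (added only
# when the model actually has recurrent state keys) makes the env carry
# those keys across steps and resets.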
env = TransformedEnv(
env,
Compose(
*(
[InitTracker(init_key="is_init")]
+ (
[TensorDictPrimer(spec_actor, reset_key="_reset")]
if len(spec_actor.keys(True, True)) > 0
else []
)
)
),
)
return env

return model_fun

return env_fun

###############################
# Abstract methods to implement
###############################
@@ -399,3 +473,27 @@ def supports_discrete_actions() -> bool:
If the algorithm supports discrete actions
"""
raise NotImplementedError

@staticmethod
def has_independent_critic() -> bool:
"""
If the algorithm uses an independent critic
"""
return False

@staticmethod
def has_centralized_critic() -> bool:
"""
If the algorithm uses a centralized critic
"""
return False

def has_critic(self) -> bool:
"""
If the algorithm uses a critic
"""
if self.has_centralized_critic() and self.has_independent_critic():
raise ValueError(
"Algorithm can either have a centralized critic or an indpendent one"
)
return self.has_centralized_critic() or self.has_independent_critic()
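For intuition, the wrapping that `process_env_fun` applies when `has_rnn` is set boils down to the following. A minimal sketch with assumed names (the Pendulum env and the `hidden` spec are hypothetical stand-ins, not BenchMARL code):

```python
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs import GymEnv, InitTracker, TensorDictPrimer, TransformedEnv

# Hypothetical recurrent-state spec: one GRU layer with hidden size 128.
hidden_spec = CompositeSpec(hidden=UnboundedContinuousTensorSpec(shape=(1, 128)))

env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker(init_key="is_init"))
# The primer seeds "hidden" at reset and carries it from step to step.
env.append_transform(TensorDictPrimer(hidden_spec, reset_key="_reset"))

td = env.reset()  # td now contains both "is_init" and "hidden"
```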
4 changes: 4 additions & 0 deletions benchmarl/algorithms/iddpg.py
@@ -251,3 +251,7 @@ def supports_discrete_actions() -> bool:
@staticmethod
def on_policy() -> bool:
return False

@staticmethod
def has_independent_critic() -> bool:
return True
4 changes: 4 additions & 0 deletions benchmarl/algorithms/ippo.py
@@ -325,3 +325,7 @@ def supports_discrete_actions() -> bool:
@staticmethod
def on_policy() -> bool:
return True

@staticmethod
def has_independent_critic() -> bool:
return True
4 changes: 4 additions & 0 deletions benchmarl/algorithms/isac.py
@@ -389,3 +389,7 @@ def supports_discrete_actions() -> bool:
@staticmethod
def on_policy() -> bool:
return False

@staticmethod
def has_independent_critic() -> bool:
return True
4 changes: 4 additions & 0 deletions benchmarl/algorithms/maddpg.py
@@ -301,3 +301,7 @@ def supports_discrete_actions() -> bool:
@staticmethod
def on_policy() -> bool:
return False

@staticmethod
def has_centralized_critic() -> bool:
return True
4 changes: 4 additions & 0 deletions benchmarl/algorithms/mappo.py
@@ -361,3 +361,7 @@ def supports_discrete_actions() -> bool:
@staticmethod
def on_policy() -> bool:
return True

@staticmethod
def has_centralized_critic() -> bool:
return True
4 changes: 4 additions & 0 deletions benchmarl/algorithms/masac.py
@@ -463,3 +463,7 @@ def supports_discrete_actions() -> bool:
@staticmethod
def on_policy() -> bool:
return False

@staticmethod
def has_centralized_critic() -> bool:
return True
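Across the six algorithm files the pattern is uniform: IDDPG, IPPO, and ISAC declare an independent per-agent critic, while MADDPG, MAPPO, and MASAC declare a centralized one, and the base-class `has_critic` then resolves to `True` for all of them. A toy illustration of that resolution, mirroring the base-class logic (not BenchMARL code):

```python
class IppoLike:
    """Stand-in config with an independent critic, as in ippo.py above."""

    @staticmethod
    def has_independent_critic() -> bool:
        return True

    @staticmethod
    def has_centralized_critic() -> bool:
        return False

    def has_critic(self) -> bool:
        # The two critic flavors are mutually exclusive by construction.
        if self.has_centralized_critic() and self.has_independent_critic():
            raise ValueError("centralized or independent, not both")
        return self.has_centralized_critic() or self.has_independent_critic()


print(IppoLike().has_critic())  # True
```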
15 changes: 15 additions & 0 deletions benchmarl/conf/model/layers/gru.yaml
@@ -0,0 +1,15 @@

name: gru

hidden_size: 128
n_layers: 1
bias: True
dropout: 0
compile: False

mlp_num_cells: [256, 256]
mlp_layer_class: torch.nn.Linear
mlp_activation_class: torch.nn.Tanh
mlp_activation_kwargs: null
mlp_norm_class: null
mlp_norm_kwargs: null
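These defaults load through the usual BenchMARL config machinery. A usage sketch, assuming the `Experiment` API shown in the README (the task and algorithm picked here are illustrative):

```python
from benchmarl.algorithms import MappoConfig
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models import GruConfig

experiment = Experiment(
    task=VmasTask.BALANCE.get_from_yaml(),
    algorithm_config=MappoConfig.get_from_yaml(),
    model_config=GruConfig.get_from_yaml(),  # reads the yaml defaults above
    seed=0,
    config=ExperimentConfig.get_from_yaml(),
)
experiment.run()
```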
49 changes: 33 additions & 16 deletions benchmarl/experiment/experiment.py
@@ -26,6 +26,8 @@
from torchrl.record.loggers import generate_exp_name
from tqdm import tqdm

from benchmarl.algorithms import IsacConfig, MasacConfig

from benchmarl.algorithms.common import AlgorithmConfig
from benchmarl.environments import Task
from benchmarl.experiment.callback import Callback, CallbackNotifier
@@ -322,8 +324,12 @@ def __init__(
self.task = task
self.model_config = model_config
self.critic_model_config = (
-    critic_model_config if critic_model_config is not None else model_config
+    critic_model_config
+    if critic_model_config is not None
+    else copy.deepcopy(model_config)
)
+self.critic_model_config.is_critic = True

self.algorithm_config = algorithm_config
self.seed = seed

@@ -345,6 +351,7 @@ def on_policy(self) -> bool:
def _setup(self):
self.config.validate(self.on_policy)
seed_everything(self.seed)
self._perform_checks()
self._set_action_type()
self._setup_task()
self._setup_algorithm()
@@ -353,6 +360,15 @@ def _setup(self):
self._setup_logger()
self._on_setup()

def _perform_checks(self):
if isinstance(self.algorithm_config, (MasacConfig, IsacConfig)) and (
self.model_config.is_rnn or self.critic_model_config.is_rnn
):
raise ValueError(
"SAC based losses not compatible with RNNs due to https://github.com/pytorch/rl/issues/2338."
" Please leave a comment on the issue if you would like this feature."
)

def _set_action_type(self):
if (
self.task.supports_continuous_actions()
@@ -377,21 +393,17 @@ def _set_action_type(self):
)

def _setup_task(self):
-test_env = self.model_config.process_env_fun(
-    self.task.get_env_fun(
-        num_envs=self.config.evaluation_episodes,
-        continuous_actions=self.continuous_actions,
-        seed=self.seed,
-        device=self.config.sampling_device,
-    )
-)()
+test_env = self.task.get_env_fun(
+    num_envs=self.config.evaluation_episodes,
+    continuous_actions=self.continuous_actions,
+    seed=self.seed,
+    device=self.config.sampling_device,
+)()
-env_func = self.model_config.process_env_fun(
-    self.task.get_env_fun(
-        num_envs=self.config.n_envs_per_worker(self.on_policy),
-        continuous_actions=self.continuous_actions,
-        seed=self.seed,
-        device=self.config.sampling_device,
-    )
-)
+env_func = self.task.get_env_fun(
+    num_envs=self.config.n_envs_per_worker(self.on_policy),
+    continuous_actions=self.continuous_actions,
+    seed=self.seed,
+    device=self.config.sampling_device,
+)

transforms_env = self.task.get_env_transforms(test_env)
@@ -427,6 +439,10 @@ def _setup_task(self):

def _setup_algorithm(self):
self.algorithm = self.algorithm_config.get_algorithm(experiment=self)

self.test_env = self.algorithm.process_env_fun(lambda: self.test_env)()
self.env_func = self.algorithm.process_env_fun(self.env_func)

self.replay_buffers = {
group: self.algorithm.get_replay_buffer(
group=group,
@@ -610,7 +626,8 @@ def _collection_loop(self):
for group in self.train_group_map.keys():
group_batch = batch.exclude(*self._get_excluded_keys(group))
group_batch = self.algorithm.process_batch(group, group_batch)
-group_batch = group_batch.reshape(-1)
+if not self.algorithm.has_rnn:
+    group_batch = group_batch.reshape(-1)
self.replay_buffers[group].extend(group_batch)

training_tds = []
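In terms of shapes, the reshape change above is what lets the replay buffer receive sequences (illustrative tensors only, not BenchMARL code):

```python
import torch

# A collected batch is laid out as [n_envs, time, *features]; flattening it
# would destroy the time ordering a recurrent model needs.
batch = torch.zeros(8, 100, 3)   # 8 envs, 100 steps, 3 features
flat = batch.reshape(-1, 3)      # feed-forward case: 800 independent transitions
print(flat.shape, batch.shape)   # torch.Size([800, 3]) torch.Size([8, 100, 3])
```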
4 changes: 4 additions & 0 deletions benchmarl/models/__init__.py
@@ -8,6 +8,7 @@
from .common import Model, ModelConfig, SequenceModel, SequenceModelConfig
from .deepsets import Deepsets, DeepsetsConfig
from .gnn import Gnn, GnnConfig
from .gru import Gru, GruConfig
from .mlp import Mlp, MlpConfig

classes = [
@@ -19,11 +20,14 @@
"CnnConfig",
"Deepsets",
"DeepsetsConfig",
"Gru",
"GruConfig",
]

model_config_registry = {
"mlp": MlpConfig,
"gnn": GnnConfig,
"cnn": CnnConfig,
"deepsets": DeepsetsConfig,
"gru": GruConfig,
}
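With the registry entry in place, the string name used in yaml configs resolves to the new class:

```python
from benchmarl.models import GruConfig, model_config_registry

assert model_config_registry["gru"] is GruConfig
# get_from_yaml() reads the defaults from benchmarl/conf/model/layers/gru.yaml
gru_config = model_config_registry["gru"].get_from_yaml()
```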
2 changes: 2 additions & 0 deletions benchmarl/models/cnn.py
@@ -114,6 +114,8 @@ def __init__(
share_params=kwargs.pop("share_params"),
device=kwargs.pop("device"),
action_spec=kwargs.pop("action_spec"),
model_index=kwargs.pop("model_index"),
is_critic=kwargs.pop("is_critic"),
)

self.x = self.input_spec[self.image_in_keys[0]].shape[-3]