From 26831b12b5ef3e12cfaf76c6b946e4496d4be89e Mon Sep 17 00:00:00 2001
From: Matteo Bettini
Date: Wed, 31 Jul 2024 20:05:01 +0200
Subject: [PATCH] amend

---
 benchmarl/conf/model/layers/gru.yaml |   4 +
 benchmarl/models/gru.py              | 155 +++++++++++++++++++++++----
 test/test_models.py                  |   8 +-
 3 files changed, 143 insertions(+), 24 deletions(-)

diff --git a/benchmarl/conf/model/layers/gru.yaml b/benchmarl/conf/model/layers/gru.yaml
index 5f0b47da..b882b159 100644
--- a/benchmarl/conf/model/layers/gru.yaml
+++ b/benchmarl/conf/model/layers/gru.yaml
@@ -2,6 +2,10 @@
 name: gru
 
 hidden_size: 128
+n_layers: 1
+bias: True
+dropout: 0
+compile: False
 
 mlp_num_cells: [256, 256]
 mlp_layer_class: torch.nn.Linear
diff --git a/benchmarl/models/gru.py b/benchmarl/models/gru.py
index 4eab9614..d65e21e3 100644
--- a/benchmarl/models/gru.py
+++ b/benchmarl/models/gru.py
@@ -16,6 +16,7 @@
 from typing import Optional, Sequence, Type
 
 import torch
+import torch.nn.functional as F
 from tensordict import TensorDict, TensorDictBase
 from tensordict.utils import expand_as_right, unravel_key_list
 from torch import nn
@@ -33,6 +34,9 @@ def __init__(
         input_size: int,
         hidden_size: int,
         device: DEVICE_TYPING,
+        n_layers: int,
+        dropout: float,
+        bias: bool,
         time_dim: int = -2,
     ):
         super().__init__()
@@ -40,8 +44,21 @@ def __init__(
         self.hidden_size = hidden_size
         self.device = device
         self.time_dim = time_dim
+        self.n_layers = n_layers
+        self.dropout = dropout
+        self.bias = bias
 
-        self.gru = GRUCell(input_size, hidden_size, device=self.device)
+        self.grus = torch.nn.ModuleList(
+            [
+                GRUCell(
+                    input_size if i == 0 else hidden_size,
+                    hidden_size,
+                    device=self.device,
+                    bias=self.bias,
+                )
+                for i in range(self.n_layers)
+            ]
+        )
 
     def forward(
         self,
@@ -50,18 +67,41 @@ def forward(
         input,
         is_init,
         h,
     ):
         hs = []
+        h = list(h.unbind(dim=-2))
         for in_t, init_t in zip(
             input.unbind(self.time_dim), is_init.unbind(self.time_dim)
         ):
-            h = torch.where(init_t, 0, h)
-            h = self.gru(in_t, h)
-            hs.append(h)
-        h_n = h
+            for layer in range(self.n_layers):
+                h[layer] = torch.where(init_t, 0, h[layer])
+
+                h[layer] = self.grus[layer](in_t, h[layer])
+
+                if layer < self.n_layers - 1 and self.dropout:
+                    in_t = F.dropout(h[layer], p=self.dropout, training=self.training)
+                else:
+                    in_t = h[layer]
+
+            hs.append(in_t)
+        h_n = torch.stack(h, dim=-2)
         output = torch.stack(hs, self.time_dim)
         return output, h_n
 
 
+def get_net(input_size, hidden_size, n_layers, bias, device, dropout, compile):
+    gru = GRU(
+        input_size,
+        hidden_size,
+        n_layers=n_layers,
+        bias=bias,
+        device=device,
+        dropout=dropout,
+    )
+    if compile:
+        gru = torch.compile(gru, mode="reduce-overhead")
+    return gru
+
+
 class MultiAgentGRU(torch.nn.Module):
     def __init__(
         self,
@@ -71,6 +111,10 @@ def __init__(
         device: DEVICE_TYPING,
         centralised: bool,
         share_params: bool,
+        n_layers: int,
+        dropout: float,
+        bias: bool,
+        compile: bool,
     ):
         super().__init__()
         self.input_size = input_size
@@ -79,25 +123,38 @@ def __init__(
         self.device = device
         self.centralised = centralised
         self.share_params = share_params
+        self.n_layers = n_layers
+        self.bias = bias
+        self.dropout = dropout
+        self.compile = compile
 
         if self.centralised:
             input_size = input_size * self.n_agents
 
-        def get_net(device):
-            return GRU(
-                input_size,
-                hidden_size,
-                device=device,
-            )
-
         agent_networks = [
-            get_net(device=self.device)
+            get_net(
+                input_size=input_size,
+                hidden_size=self.hidden_size,
+                n_layers=self.n_layers,
+                bias=self.bias,
+                device=self.device,
+                dropout=self.dropout,
+                compile=self.compile,
+            )
             for _ in range(self.n_agents if not self.share_params else 1)
         ]
         self._make_params(agent_networks)
 
         with torch.device("meta"):
-            self._empty_gru = get_net(device="meta")
+            self._empty_gru = get_net(
+                input_size=input_size,
+                hidden_size=self.hidden_size,
+                n_layers=self.n_layers,
+                bias=self.bias,
+                device="meta",
+                dropout=self.dropout,
+                compile=self.compile,
+            )
             # Remove all parameters
             TensorDict.from_module(self._empty_gru).data.to("meta").to_module(
                 self._empty_gru
             )
@@ -148,12 +205,14 @@ def forward(
             if self.centralised:
                 shape = (
                     batch,
+                    self.n_layers,
                     self.hidden_size,
                 )
             else:
                 shape = (
                     batch,
                     self.n_agents,
+                    self.n_layers,
                     self.hidden_size,
                 )
             h_0 = torch.zeros(
@@ -190,7 +249,7 @@ def run_net(self, input, is_init, h_0):
             else:
                 output, h_n = self.vmap_func_module(
                     self._empty_gru,
-                    (0, -2, -2, -2),
+                    (0, -2, -2, -3),
                     (-2, -2),
                 )(self.params, input, is_init, h_0)
         else:
@@ -198,9 +257,9 @@ def run_net(self, input, is_init, h_0):
                 if self.centralised:
                     output, h_n = self._empty_gru(input, is_init, h_0)
                 else:
-                    output, h_n = torch.vmap(self._empty_gru, in_dims=-2, out_dims=-2)(
-                        input, is_init, h_0
-                    )
+                    output, h_n = torch.vmap(
+                        self._empty_gru, in_dims=(-2, -2, -3), out_dims=(-2, -3)
+                    )(input, is_init, h_0)
 
         return output, h_n
 
@@ -219,9 +278,39 @@ def _make_params(self, agent_networks):
 
 
 class Gru(Model):
+    r"""A multi-layer Gated Recurrent Unit (GRU) RNN like the one from
+    `torch <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`__.
+
+    The BenchMARL GRU accepts multiple inputs of 2 types:
+
+    - multi-agent arrays: Tensors of shape ``(*batch,A,F)``
+    - arrays: Tensors of shape ``(*batch,F)``
+
+    Where `A` is the number of agents and `F` is the number of features.
+    The features `F` will be processed into features of ``hidden_size``.
+
+    Args:
+        hidden_size (int): The number of features in the hidden state.
+        n_layers (int): Number of recurrent layers. E.g., setting ``n_layers=2``
+            would mean stacking two GRUs together to form a `stacked GRU`,
+            with the second GRU taking in outputs of the first GRU and
+            computing the final results. Default: 1
+        bias (bool): If ``False``, then the GRU layers do not use bias.
+            Default: ``True``
+        dropout (float): If non-zero, introduces a `Dropout` layer on the outputs of each
+            GRU layer except the last layer, with dropout probability equal to
+            :attr:`dropout`. Default: 0
+        compile (bool): If ``True``, compiles the underlying GRU model.
+            Default: ``False``
+    """
+
     def __init__(
         self,
         hidden_size: int,
+        n_layers: int,
+        bias: bool,
+        dropout: float,
+        compile: bool,
         **kwargs,
     ):
@@ -244,6 +333,10 @@ def __init__(
         self.in_keys += self.rnn_keys
 
         self.hidden_size = hidden_size
+        self.n_layers = n_layers
+        self.bias = bias
+        self.dropout = dropout
+        self.compile = compile
 
         self.input_features = sum(
             [spec.shape[-1] for spec in self.input_spec.values(True, True)]
@@ -256,16 +349,24 @@ def __init__(
                 self.hidden_size,
                 self.n_agents,
                 self.device,
+                bias=self.bias,
+                n_layers=self.n_layers,
                 centralised=self.centralised,
                 share_params=self.share_params,
+                dropout=self.dropout,
+                compile=self.compile,
             )
         else:
             self.gru = nn.ModuleList(
                 [
-                    GRU(
-                        self.input_features,
-                        self.hidden_size,
+                    get_net(
+                        input_size=self.input_features,
+                        hidden_size=self.hidden_size,
+                        n_layers=self.n_layers,
+                        bias=self.bias,
                         device=self.device,
+                        dropout=self.dropout,
+                        compile=self.compile,
                     )
                     for _ in range(self.n_agents if not self.share_params else 1)
                 ]
@@ -361,7 +462,7 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             assert is_init.shape == (batch, seq, 1)
 
             h_0 = torch.zeros(
-                (batch, self.hidden_size),
+                (batch, self.n_layers, self.hidden_size),
                 device=self.device,
                 dtype=torch.float,
             )
@@ -397,6 +498,10 @@ class GruConfig(ModelConfig):
     """Dataclass config for a :class:`~benchmarl.models.Gru`."""
 
     hidden_size: int = MISSING
+    n_layers: int = MISSING
+    bias: bool = MISSING
+    dropout: float = MISSING
+    compile: bool = MISSING
 
     mlp_num_cells: Sequence[int] = MISSING
     mlp_layer_class: Type[nn.Module] = MISSING
@@ -417,6 +522,10 @@ def is_rnn(self) -> bool:
     def get_model_state_spec(self, model_index: int = 0) -> CompositeSpec:
         name = f"_hidden_gru_{model_index}"
         spec = CompositeSpec(
-            {name: UnboundedContinuousTensorSpec(shape=(self.hidden_size,))}
+            {
+                name: UnboundedContinuousTensorSpec(
+                    shape=(self.n_layers, self.hidden_size)
+                )
+            }
         )
         return spec
diff --git a/test/test_models.py b/test/test_models.py
index a534710f..c27495ec 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -145,6 +145,9 @@ def test_loading_sequence_models(model_name, intermediate_size=10):
         ["cnn", "gnn", "mlp"],
         ["cnn", "mlp", "gnn"],
         ["cnn", "mlp"],
+        ["cnn", "gru", "gnn", "mlp"],
+        ["cnn", "gru", "mlp"],
+        ["gru", "mlp"],
     ],
 )
 def test_models_forward_shape(
@@ -197,7 +200,7 @@ def test_models_forward_shape(
         action_spec=None,
     )
     input_td = input_spec.rand()
-    if model_name == "gru":
+    if "gru" in model_name:
         if len(batch_size) < 2:
             if centralised:
                 pytest.skip("gru model with this batch sizes is a policy")
@@ -226,6 +229,9 @@ def test_models_forward_shape(
         ["cnn", "gnn", "mlp"],
         ["cnn", "mlp", "gnn"],
         ["cnn", "mlp"],
+        ["cnn", "gru", "gnn", "mlp"],
+        ["cnn", "gru", "mlp"],
+        ["gru", "mlp"],
     ],
 )
 @pytest.mark.parametrize("batch_size", [(), (2,), (3, 2)])
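
For anyone who wants to exercise the new recurrence outside of BenchMARL, here is a minimal, self-contained sketch of the pattern the patch introduces in `GRU.forward`: a stack of `GRUCell`s stepped through time, with the hidden state zeroed wherever `is_init` marks an episode start and dropout applied between layers (never after the last one). Everything in it (the `StackedGRUCells` name, the toy shapes and sizes) is a hypothetical illustration, not BenchMARL API.

```python
# Illustrative only: a standalone stacked-GRUCell rollout mirroring the pattern
# added in GRU.forward above. StackedGRUCells and all shapes are hypothetical.
import torch
import torch.nn.functional as F
from torch import nn


class StackedGRUCells(nn.Module):
    def __init__(self, input_size, hidden_size, n_layers=2, dropout=0.0, bias=True):
        super().__init__()
        self.n_layers = n_layers
        self.dropout = dropout
        self.cells = nn.ModuleList(
            [
                nn.GRUCell(input_size if i == 0 else hidden_size, hidden_size, bias=bias)
                for i in range(n_layers)
            ]
        )

    def forward(self, x, is_init, h):
        # x: (batch, time, F), is_init: (batch, time, 1) bool, h: (batch, n_layers, hidden)
        h = list(h.unbind(dim=-2))  # one state tensor per layer
        outs = []
        for x_t, init_t in zip(x.unbind(-2), is_init.unbind(-2)):
            in_t = x_t
            for layer in range(self.n_layers):
                # Reset the state of every environment that just restarted.
                h[layer] = torch.where(init_t, 0, h[layer])
                h[layer] = self.cells[layer](in_t, h[layer])
                if layer < self.n_layers - 1 and self.dropout:
                    # Dropout only between layers, as in torch.nn.GRU.
                    in_t = F.dropout(h[layer], p=self.dropout, training=self.training)
                else:
                    in_t = h[layer]
            outs.append(in_t)
        # Re-stack per-layer states on dim -2: the (*, n_layers, hidden) layout
        # that the patch threads through h_0, the state spec, and the vmap dims.
        return torch.stack(outs, dim=-2), torch.stack(h, dim=-2)


# Quick shape check.
net = StackedGRUCells(input_size=4, hidden_size=8, n_layers=2, dropout=0.1)
x = torch.randn(3, 5, 4)                          # (batch, time, features)
is_init = torch.zeros(3, 5, 1, dtype=torch.bool)  # no episode resets
h_0 = torch.zeros(3, 2, 8)                        # (batch, n_layers, hidden)
out, h_n = net(x, is_init, h_0)
print(out.shape, h_n.shape)  # torch.Size([3, 5, 8]) torch.Size([3, 2, 8])
```

Keeping the per-layer states in a Python list and re-stacking them on dim `-2` is what gives the hidden state its extra `n_layers` dimension, which is why `h_0`, `get_model_state_spec`, and the `vmap` in/out dims all change in this patch.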
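
On that last point, a second sketch of the dimension bookkeeping in the decentralised, shared-parameter branch of `run_net`: agents sit at dim `-2` of the sequence inputs and at dim `-3` of the hidden state, matching the new `in_dims=(-2, -2, -3)` / `out_dims=(-2, -3)`. It reuses the hypothetical `StackedGRUCells` from the sketch above, with dropout disabled because `torch.vmap` rejects random ops under its default randomness setting; the shapes follow the patch's `(batch, time, n_agents, F)` convention and are assumptions, not BenchMARL API.

```python
# Hypothetical multi-agent shapes: input (batch, time, n_agents, F),
# is_init (batch, time, n_agents, 1), hidden (batch, n_agents, n_layers, hidden).
net = StackedGRUCells(input_size=4, hidden_size=8, n_layers=2, dropout=0.0)

x = torch.randn(2, 5, 3, 4)
is_init = torch.zeros(2, 5, 3, 1, dtype=torch.bool)
h_0 = torch.zeros(2, 3, 2, 8)

# vmap the shared network over the agent dimension, as run_net does:
# agents at dim -2 of x/is_init, dim -3 of h_0; same dims on the way out.
out, h_n = torch.vmap(net, in_dims=(-2, -2, -3), out_dims=(-2, -3))(x, is_init, h_0)
print(out.shape, h_n.shape)  # torch.Size([2, 5, 3, 8]) torch.Size([2, 3, 2, 8])
```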