
Commit

amend
matteobettini committed Jul 31, 2024
1 parent e262958 commit 26831b1
Showing 3 changed files with 143 additions and 24 deletions.
4 changes: 4 additions & 0 deletions benchmarl/conf/model/layers/gru.yaml
@@ -2,6 +2,10 @@
name: gru

hidden_size: 128
n_layers: 1
bias: True
dropout: 0
compile: False

mlp_num_cells: [256, 256]
mlp_layer_class: torch.nn.Linear
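The added keys expose the usual stacked-RNN knobs: number of layers, bias, and inter-layer dropout, plus an option to wrap the module in torch.compile. For intuition only, a rough single-agent analogue in plain PyTorch (the input size of 12 is an arbitrary assumption; the BenchMARL model itself stacks GRUCell layers, as gru.py below shows):

```python
import torch

# Rough single-agent analogue of the config options above.
gru = torch.nn.GRU(
    input_size=12,    # arbitrary example input size
    hidden_size=128,  # hidden_size: 128
    num_layers=1,     # n_layers: 1
    bias=True,        # bias: True
    dropout=0.0,      # dropout: 0
    batch_first=True,
)
x = torch.randn(4, 10, 12)  # (batch, time, features)
out, h_n = gru(x)           # out: (4, 10, 128), h_n: (1, 4, 128)
```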
155 changes: 132 additions & 23 deletions benchmarl/models/gru.py
@@ -16,6 +16,7 @@
from typing import Optional, Sequence, Type

import torch
import torch.nn.functional as F
from tensordict import TensorDict, TensorDictBase
from tensordict.utils import expand_as_right, unravel_key_list
from torch import nn
@@ -33,15 +34,31 @@ def __init__(
input_size: int,
hidden_size: int,
device: DEVICE_TYPING,
n_layers: int,
dropout: float,
bias: bool,
time_dim: int = -2,
):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.device = device
self.time_dim = time_dim
self.n_layers = n_layers
self.dropout = dropout
self.bias = bias

- self.gru = GRUCell(input_size, hidden_size, device=self.device)
self.grus = torch.nn.ModuleList(
[
GRUCell(
input_size if i == 0 else hidden_size,
hidden_size,
device=self.device,
bias=self.bias,
)
for i in range(self.n_layers)
]
)

def forward(
self,
@@ -50,18 +67,41 @@ def forward(
h,
):
hs = []
h = list(h.unbind(dim=-2))
for in_t, init_t in zip(
input.unbind(self.time_dim), is_init.unbind(self.time_dim)
):
- h = torch.where(init_t, 0, h)
- h = self.gru(in_t, h)
- hs.append(h)
- h_n = h
for layer in range(self.n_layers):
h[layer] = torch.where(init_t, 0, h[layer])

h[layer] = self.grus[layer](in_t, h[layer])

if layer < self.n_layers - 1 and self.dropout:
in_t = F.dropout(h[layer], p=self.dropout, training=self.training)
else:
in_t = h[layer]

hs.append(in_t)
h_n = torch.stack(h, dim=-2)
output = torch.stack(hs, self.time_dim)

return output, h_n
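The rewritten forward pass runs a stack of GRUCells one timestep at a time, zeroing each layer's hidden state wherever is_init marks an episode restart and applying dropout only to the outputs of non-final layers. A self-contained sketch of the same pattern outside BenchMARL (the sizes, the 2-layer depth, and the 0.1 dropout are illustrative assumptions):

```python
import torch
import torch.nn.functional as F

B, T, F_in, H, L = 4, 10, 12, 32, 2
cells = [torch.nn.GRUCell(F_in if i == 0 else H, H) for i in range(L)]

x = torch.randn(B, T, F_in)                       # (batch, time, features)
is_init = torch.zeros(B, T, 1, dtype=torch.bool)  # True where an episode restarts
h = [torch.zeros(B, H) for _ in range(L)]         # one hidden state per layer

outputs = []
for x_t, init_t in zip(x.unbind(dim=1), is_init.unbind(dim=1)):
    in_t = x_t
    for layer, cell in enumerate(cells):
        h[layer] = torch.where(init_t, 0, h[layer])  # reset at episode starts
        h[layer] = cell(in_t, h[layer])
        # dropout only between layers, never after the last one
        # (the real model also passes training=self.training)
        in_t = F.dropout(h[layer], p=0.1) if layer < L - 1 else h[layer]
    outputs.append(in_t)

output = torch.stack(outputs, dim=1)  # (B, T, H)
h_n = torch.stack(h, dim=-2)          # (B, L, H), matching h.unbind(dim=-2) above
print(output.shape, h_n.shape)
```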


def get_net(input_size, hidden_size, n_layers, bias, device, dropout, compile):
gru = GRU(
input_size,
hidden_size,
n_layers=n_layers,
bias=bias,
device=device,
dropout=dropout,
)
if compile:
gru = torch.compile(gru, mode="reduce-overhead")
return gru
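get_net builds the recurrent module from the new constructor arguments and optionally wraps it in torch.compile. A minimal illustration of that wrapping, assuming PyTorch 2.x and a placeholder GRUCell with made-up sizes:

```python
import torch

# "reduce-overhead" trades compile time for lower per-call overhead,
# which suits small recurrent cells called at every step.
cell = torch.nn.GRUCell(12, 32)
compiled_cell = torch.compile(cell, mode="reduce-overhead")
h = compiled_cell(torch.randn(4, 12), torch.zeros(4, 32))  # first call triggers compilation
print(h.shape)  # torch.Size([4, 32])
```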


class MultiAgentGRU(torch.nn.Module):
def __init__(
self,
Expand All @@ -71,6 +111,10 @@ def __init__(
device: DEVICE_TYPING,
centralised: bool,
share_params: bool,
n_layers: int,
dropout: float,
bias: bool,
compile: bool,
):
super().__init__()
self.input_size = input_size
@@ -79,25 +123,38 @@ def __init__(
self.device = device
self.centralised = centralised
self.share_params = share_params
self.n_layers = n_layers
self.bias = bias
self.dropout = dropout
self.compile = compile

if self.centralised:
input_size = input_size * self.n_agents

- def get_net(device):
- return GRU(
- input_size,
- hidden_size,
- device=device,
- )

agent_networks = [
- get_net(device=self.device)
get_net(
input_size=input_size,
hidden_size=self.hidden_size,
n_layers=self.n_layers,
bias=self.bias,
device=self.device,
dropout=self.dropout,
compile=self.compile,
)
for _ in range(self.n_agents if not self.share_params else 1)
]
self._make_params(agent_networks)

with torch.device("meta"):
- self._empty_gru = get_net(device="meta")
self._empty_gru = get_net(
input_size=input_size,
hidden_size=self.hidden_size,
n_layers=self.n_layers,
bias=self.bias,
device="meta",
dropout=self.dropout,
compile=self.compile,
)
# Remove all parameters
TensorDict.from_module(self._empty_gru).data.to("meta").to_module(
self._empty_gru
@@ -148,12 +205,14 @@ def forward(
if self.centralised:
shape = (
batch,
self.n_layers,
self.hidden_size,
)
else:
shape = (
batch,
self.n_agents,
self.n_layers,
self.hidden_size,
)
h_0 = torch.zeros(
@@ -190,17 +249,17 @@ def run_net(self, input, is_init, h_0):
else:
output, h_n = self.vmap_func_module(
self._empty_gru,
- (0, -2, -2, -2),
(0, -2, -2, -3),
(-2, -2),
)(self.params, input, is_init, h_0)
else:
with self.params.to_module(self._empty_gru):
if self.centralised:
output, h_n = self._empty_gru(input, is_init, h_0)
else:
- output, h_n = torch.vmap(self._empty_gru, in_dims=-2, out_dims=-2)(
- input, is_init, h_0
- )
output, h_n = torch.vmap(
self._empty_gru, in_dims=(-2, -2, -3), out_dims=(-2, -3)
)(input, is_init, h_0)

return output, h_n
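Since the hidden state now carries a layer dimension, the agent axis of h_0 sits third from the end, which is why its vmap in_dims entry moves from -2 to -3 (and the second out_dims entry follows). A toy example of mapping a function over a different dimension per argument; the function and all shapes here are made up for illustration:

```python
import torch

B, T, A, L, H, F_in = 4, 10, 3, 2, 32, 12
W = torch.randn(F_in, H)

def per_agent_net(x, h):
    # stand-in for one agent's recurrent net:
    # x: (B, T, F_in), h: (B, L, H) -> out: (B, T, H), new h: (B, L, H)
    out = torch.tanh(x @ W) + h[:, -1:, :]
    return out, h

x = torch.randn(B, T, A, F_in)  # agents stacked on dim -2 of the input
h0 = torch.zeros(B, A, L, H)    # agents stacked on dim -3 of the hidden state
out, h_n = torch.vmap(per_agent_net, in_dims=(-2, -3), out_dims=(-2, -3))(x, h0)
print(out.shape, h_n.shape)     # (B, T, A, H) and (B, A, L, H)
```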

@@ -219,9 +278,39 @@ def _make_params(self, agent_networks):


class Gru(Model):
r"""A multi-layer Gated Recurrent Unit (GRU) RNN like the one from
`torch <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`__ .
The BenchMARL GRU accepts multiple inputs of 2 types:
- multi-agent arrays: Tensors of shape ``(*batch,A,F)``
- arrays: Tensors of shape ``(*batch,F)``
Where `A` is the number of agents and `F` is the number of features.
The features ``F`` will be processed into features of ``hidden_size``.
Args:
hidden_size (int): The number of features in the hidden state.
n_layers (int): Number of recurrent layers. E.g., setting ``n_layers=2``
would mean stacking two GRUs together to form a `stacked GRU`,
with the second GRU taking in outputs of the first GRU and
computing the final results. Default: 1
bias (bool): If ``False``, then the GRU layers do not use bias.
Default: ``True``
dropout (float): If non-zero, introduces a `Dropout` layer on the outputs of each
GRU layer except the last layer, with dropout probability equal to
:attr:`dropout`. Default: 0
compile (bool): If ``True``, compiles the underlying GRU model. Default: ``False``
"""

def __init__(
self,
hidden_size: int,
n_layers: int,
bias: bool,
dropout: float,
compile: bool,
**kwargs,
):

@@ -244,6 +333,10 @@ def __init__(
self.in_keys += self.rnn_keys

self.hidden_size = hidden_size
self.n_layers = n_layers
self.bias = bias
self.dropout = dropout
self.compile = compile

self.input_features = sum(
[spec.shape[-1] for spec in self.input_spec.values(True, True)]
@@ -256,16 +349,24 @@ def __init__(
self.hidden_size,
self.n_agents,
self.device,
bias=self.bias,
n_layers=self.n_layers,
centralised=self.centralised,
share_params=self.share_params,
dropout=self.dropout,
compile=self.compile,
)
else:
self.gru = nn.ModuleList(
[
- GRU(
- self.input_features,
- self.hidden_size,
get_net(
input_size=self.input_features,
hidden_size=self.hidden_size,
n_layers=self.n_layers,
bias=self.bias,
device=self.device,
dropout=self.dropout,
compile=self.compile,
)
for _ in range(self.n_agents if not self.share_params else 1)
]
@@ -361,7 +462,7 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
assert is_init.shape == (batch, seq, 1)

h_0 = torch.zeros(
- (batch, self.hidden_size),
(batch, self.n_layers, self.hidden_size),
device=self.device,
dtype=torch.float,
)
@@ -397,6 +498,10 @@ class GruConfig(ModelConfig):
"""Dataclass config for a :class:`~benchmarl.models.Gru`."""

hidden_size: int = MISSING
n_layers: int = MISSING
bias: bool = MISSING
dropout: float = MISSING
compile: bool = MISSING

mlp_num_cells: Sequence[int] = MISSING
mlp_layer_class: Type[nn.Module] = MISSING
@@ -417,6 +522,10 @@ def is_rnn(self) -> bool:
def get_model_state_spec(self, model_index: int = 0) -> CompositeSpec:
name = f"_hidden_gru_{model_index}"
spec = CompositeSpec(
- {name: UnboundedContinuousTensorSpec(shape=(self.hidden_size,))}
{
name: UnboundedContinuousTensorSpec(
shape=(self.n_layers, self.hidden_size)
)
}
)
return spec
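With the extra layer dimension, the recurrent state advertised for each GRU model now has a trailing (n_layers, hidden_size) shape. A small sketch of that spec in isolation, assuming torchrl's spec classes are importable from torchrl.data and using illustrative values of 2 layers and 128 hidden units:

```python
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

n_layers, hidden_size = 2, 128  # illustrative values
spec = CompositeSpec(
    {"_hidden_gru_0": UnboundedContinuousTensorSpec(shape=(n_layers, hidden_size))}
)
print(spec["_hidden_gru_0"].shape)         # torch.Size([2, 128])
print(spec.rand()["_hidden_gru_0"].shape)  # random sample with the same shape
```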
8 changes: 7 additions & 1 deletion test/test_models.py
@@ -145,6 +145,9 @@ def test_loading_sequence_models(model_name, intermediate_size=10):
["cnn", "gnn", "mlp"],
["cnn", "mlp", "gnn"],
["cnn", "mlp"],
["cnn", "gru", "gnn", "mlp"],
["cnn", "gru", "mlp"],
["gru", "mlp"],
],
)
def test_models_forward_shape(
@@ -197,7 +200,7 @@ def test_models_forward_shape(
action_spec=None,
)
input_td = input_spec.rand()
- if model_name == "gru":
if "gru" in model_name:
if len(batch_size) < 2:
if centralised:
pytest.skip("gru model with these batch sizes is a policy")
@@ -226,6 +229,9 @@ def test_models_forward_shape(
["cnn", "gnn", "mlp"],
["cnn", "mlp", "gnn"],
["cnn", "mlp"],
["cnn", "gru", "gnn", "mlp"],
["cnn", "gru", "mlp"],
["gru", "mlp"],
],
)
@pytest.mark.parametrize("batch_size", [(), (2,), (3, 2)])
