Commit: amend

matteobettini committed Nov 27, 2023
1 parent 8e6fd98 commit ca4e2c7
Showing 23 changed files with 306 additions and 94 deletions.
Binary file added .DS_Store
12 changes: 7 additions & 5 deletions README.md
@@ -254,14 +254,16 @@ agent group. Here is a table of the models implemented in BenchMARL

| Name                           | Decentralized | Centralized with local inputs | Centralized with global input |
|--------------------------------|:-------------:|:-----------------------------:|:-----------------------------:|
| [MLP](benchmarl/models/mlp.py) | Yes           | Yes                           | Yes                           |
| [GNN](benchmarl/models/gnn.py) | Yes           | No                            | No                            |

And the ones that are _work in progress_

-| Name                                                         | Decentralized | Centralized with local inputs | Centralized with global input |
-|--------------------------------------------------------------|:-------------:|:-----------------------------:|:-----------------------------:|
-| [GNN](https://github.com/facebookresearch/BenchMARL/pull/18) | Yes           | Yes                           | No                            |
-| CNN                                                          | Yes           | Yes                           | Yes                           |

| Name               | Decentralized | Centralized with local inputs | Centralized with global input |
|--------------------|:-------------:|:-----------------------------:|:-----------------------------:|
| CNN                | Yes           | Yes                           | Yes                           |
| RNN (GRU and LSTM) | Yes           | Yes                           | Yes                           |


## Fine-tuned public benchmarks
> [!WARNING]
Binary file added benchmarl/.DS_Store
7 changes: 7 additions & 0 deletions benchmarl/conf/model/layers/gnn.yaml
@@ -1 +1,8 @@
name: gnn

topology: full
self_loops: False

gnn_class: torch_geometric.nn.conv.GraphConv
gnn_kwargs:
  aggr: "add"
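
For reference, a minimal sketch of consuming these defaults (with no argument, `get_from_yaml()` falls back to the packaged yaml; treat the exact parsed value types, such as the dotted `gnn_class` path being resolved to a class, as an assumption):

```python
from benchmarl.models import GnnConfig

# Loads benchmarl/conf/model/layers/gnn.yaml shown above.
gnn_config = GnnConfig.get_from_yaml()
# Expected (assumption): topology="full", self_loops=False,
# gnn_class=torch_geometric.nn.conv.GraphConv, gnn_kwargs={"aggr": "add"}
```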
1 change: 0 additions & 1 deletion benchmarl/conf/model/layers/mlp.yaml
@@ -1,5 +1,4 @@


name: mlp

num_cells: [256, 256]
Binary file added benchmarl/conf/task/.DS_Store
Binary file added benchmarl/environments/.DS_Store
Binary file added benchmarl/experiment/.DS_Store
5 changes: 3 additions & 2 deletions benchmarl/models/__init__.py
@@ -4,8 +4,9 @@
# LICENSE file in the root directory of this source tree.
#

-from .gnn import GnnConfig
-from .mlp import MlpConfig
from .common import Model, ModelConfig, SequenceModel, SequenceModelConfig
from .gnn import Gnn, GnnConfig
from .mlp import Mlp, MlpConfig

classes = ["Mlp", "MlpConfig", "Gnn", "GnnConfig"]

36 changes: 35 additions & 1 deletion benchmarl/models/common.py
@@ -315,7 +315,41 @@ def get_from_yaml(cls, path: Optional[str] = None):

@dataclass
class SequenceModelConfig(ModelConfig):
-    """Dataclass for a :class:`~benchmarl.models.SequenceModel`."""
    """Dataclass for a :class:`~benchmarl.models.SequenceModel`.

    Examples:
        .. code-block:: python

            import torch_geometric
            from torch import nn
            from benchmarl.algorithms import IppoConfig
            from benchmarl.environments import VmasTask
            from benchmarl.experiment import Experiment, ExperimentConfig
            from benchmarl.models import SequenceModelConfig, GnnConfig, MlpConfig

            experiment = Experiment(
                algorithm_config=IppoConfig.get_from_yaml(),
                model_config=SequenceModelConfig(
                    model_configs=[
                        MlpConfig(num_cells=[8], activation_class=nn.Tanh, layer_class=nn.Linear),
                        GnnConfig(
                            topology="full",
                            self_loops=False,
                            gnn_class=torch_geometric.nn.conv.GraphConv,
                        ),
                        MlpConfig(num_cells=[6], activation_class=nn.Tanh, layer_class=nn.Linear),
                    ],
                    intermediate_sizes=[5, 3],
                ),
                seed=0,
                config=ExperimentConfig.get_from_yaml(),
                task=VmasTask.NAVIGATION.get_from_yaml(),
            )
            experiment.run()
    """

    model_configs: Sequence[ModelConfig]
    intermediate_sizes: Sequence[int]
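
    # Note: chaining k sub-model configs requires k - 1 intermediate sizes,
    # one per junction between consecutive sub-models (the docstring example
    # chains MLP -> GNN -> MLP with intermediate_sizes=[5, 3]).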
224 changes: 150 additions & 74 deletions benchmarl/models/gnn.py
@@ -1,48 +1,134 @@
from __future__ import annotations

import importlib
-from dataclasses import dataclass
from dataclasses import dataclass, MISSING
from math import prod
-from typing import Optional
from typing import Optional, Type

import torch
from tensordict import TensorDictBase
from torch import nn, Tensor

-from benchmarl.models.common import Model, ModelConfig, parse_model_config
-from benchmarl.utils import _read_yaml_config
from benchmarl.models.common import Model, ModelConfig

_has_torch_geometric = importlib.util.find_spec("torch_geometric") is not None
if _has_torch_geometric:
    import torch_geometric
-    from torch_geometric.nn import GATv2Conv, GINEConv, GraphConv

TOPOLOGY_TYPES = {"full", "empty"}


def _get_edge_index(topology: str, self_loops: bool, n_agents: int, device: str):
    if topology == "full":
        # Every agent is connected to every other agent.
        adjacency = torch.ones(n_agents, n_agents, device=device, dtype=torch.long)
    elif topology == "empty":
        # No edges between distinct agents (self loops are handled below).
        adjacency = torch.eye(n_agents, device=device, dtype=torch.long)
    else:
        raise ValueError(f"Unknown topology: {topology}")

    edge_index, _ = torch_geometric.utils.dense_to_sparse(adjacency)

    # Canonicalize: strip diagonal entries, then add self loops only if requested.
    edge_index, _ = torch_geometric.utils.remove_self_loops(edge_index)
    if self_loops:
        edge_index, _ = torch_geometric.utils.add_self_loops(
            edge_index, num_nodes=n_agents
        )

    return edge_index
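
# Example (assumes torch_geometric is installed): a "full" topology over
# 3 agents without self loops yields the 6 directed edges between distinct
# agents.
#
#     >>> _get_edge_index(topology="full", self_loops=False, n_agents=3, device="cpu")
#     tensor([[0, 0, 1, 1, 2, 2],
#             [1, 2, 0, 2, 0, 1]])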


class Gnn(Model):
    """A GNN model.

    GNN models can be used as "decentralized" actors or critics.

    Args:
        topology (str): Topology of the graph adjacency matrix. Options: "full", "empty".
        self_loops (bool): Whether the resulting adjacency matrix will have self loops.
        gnn_class (Type[torch_geometric.nn.MessagePassing]): The GNN convolution class to use.
        gnn_kwargs (dict, optional): The dict of arguments to pass to the GNN convolution class.

    Examples:
        .. code-block:: python

            import torch_geometric
            from torch import nn
            from benchmarl.algorithms import IppoConfig
            from benchmarl.environments import VmasTask
            from benchmarl.experiment import Experiment, ExperimentConfig
            from benchmarl.models import SequenceModelConfig, GnnConfig, MlpConfig

            experiment = Experiment(
                algorithm_config=IppoConfig.get_from_yaml(),
                model_config=GnnConfig(
                    topology="full",
                    self_loops=False,
                    gnn_class=torch_geometric.nn.conv.GATv2Conv,
                    gnn_kwargs={},
                ),
                critic_model_config=SequenceModelConfig(
                    model_configs=[
                        MlpConfig(num_cells=[8], activation_class=nn.Tanh, layer_class=nn.Linear),
                        GnnConfig(
                            topology="full",
                            self_loops=False,
                            gnn_class=torch_geometric.nn.conv.GraphConv,
                        ),
                        MlpConfig(num_cells=[6], activation_class=nn.Tanh, layer_class=nn.Linear),
                    ],
                    intermediate_sizes=[5, 3],
                ),
                seed=0,
                config=ExperimentConfig.get_from_yaml(),
                task=VmasTask.NAVIGATION.get_from_yaml(),
            )
            experiment.run()
    """

    def __init__(
        self,
        topology: str,
        self_loops: bool,
        gnn_class: Type[torch_geometric.nn.MessagePassing],
        gnn_kwargs: Optional[dict] = None,
        **kwargs,
    ):
        self.topology = topology
        self.self_loops = self_loops

        super().__init__(**kwargs)

        self.input_features = self.input_leaf_spec.shape[-1]
        self.output_features = self.output_leaf_spec.shape[-1]

        if gnn_kwargs is None:
            gnn_kwargs = {}
        # The convolution's channel sizes are always derived from the specs.
        gnn_kwargs.update(
            {"in_channels": self.input_features, "out_channels": self.output_features}
        )

        self.gnns = nn.ModuleList(
            [
-                GnnKernel(
-                    in_dim=self.input_features,
-                    out_dim=self.output_features,
-                )
                # One convolution per agent, or a single shared one.
                gnn_class(**gnn_kwargs).to(self.device)
                for _ in range(self.n_agents if not self.share_params else 1)
            ]
        )
-        self.fully_connected_adjacency = torch.ones(
-            self.n_agents, self.n_agents, device=self.device
-        )
        # The topology is static, so the edge index is built once.
        self.edge_index = _get_edge_index(
            topology=self.topology,
            self_loops=self.self_loops,
            device=self.device,
            n_agents=self.n_agents,
        )

    def _perform_checks(self):
        super()._perform_checks()

        if self.topology not in TOPOLOGY_TYPES:
            raise ValueError(
                f"Got topology: {self.topology} but only available options are {TOPOLOGY_TYPES}"
            )
        if self.centralised:
            raise ValueError("GNN model can only be used in non-centralised critics")
        if not self.input_has_agent_dim:
            raise ValueError(
                "The GNN module is not compatible with input that does not have the agent dimension,"
@@ -67,15 +153,9 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        # Gather in_key
        input = tensordict.get(self.in_key)

-        # For now fully connected
-        adjacency = self.fully_connected_adjacency.to(input.device)
-
-        edge_index, _ = torch_geometric.utils.dense_to_sparse(adjacency)
-        edge_index, _ = torch_geometric.utils.remove_self_loops(edge_index)
-
        batch_size = input.shape[:-2]

-        graph = batch_from_dense_to_ptg(x=input, edge_index=edge_index)
        # Uses the edge index precomputed in __init__.
        graph = batch_from_dense_to_ptg(x=input, edge_index=self.edge_index)

        if not self.share_params:
            res = torch.stack(
@@ -97,54 +177,53 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
            ).view(*batch_size, self.n_agents, self.output_features)

        tensordict.set(self.out_key, res)

        return tensordict
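
# For reference, a plausible sketch of the batching helper used above (its
# actual body is collapsed in this diff view, so the details here are
# assumptions): it flattens the leading batch dimensions, builds one graph of
# n_agents nodes per batch element, and collates them into a single
# torch_geometric Batch.
#
#     import torch_geometric
#     from torch import Tensor
#
#     def _batch_from_dense_to_ptg_sketch(x: Tensor, edge_index: Tensor):
#         x = x.reshape(-1, x.shape[-2], x.shape[-1])  # (B, n_agents, F)
#         graphs = [
#             torch_geometric.data.Data(x=x[i], edge_index=edge_index)
#             for i in range(x.shape[0])
#         ]
#         return torch_geometric.data.Batch.from_data_list(graphs)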


# class GnnKernel(nn.Module):
# def __init__(self, in_dim, out_dim, **cfg):
# super().__init__()
#
# gnn_types = {"GraphConv", "GATv2Conv", "GINEConv"}
# aggr_types = {"add", "mean", "max"}
#
# self.aggr = "add"
# self.gnn_type = "GraphConv"
#
# self.in_dim = in_dim
# self.out_dim = out_dim
# self.activation_fn = nn.Tanh
#
# if self.gnn_type == "GraphConv":
# self.gnn = GraphConv(
# self.in_dim,
# self.out_dim,
# aggr=self.aggr,
# )
# elif self.gnn_type == "GATv2Conv":
# # Default adds self loops
# self.gnn = GATv2Conv(
# self.in_dim,
# self.out_dim,
# edge_dim=self.edge_features,
# fill_value=0.0,
# share_weights=True,
# add_self_loops=True,
# aggr=self.aggr,
# )
# elif self.gnn_type == "GINEConv":
# self.gnn = GINEConv(
# nn=nn.Sequential(
# torch.nn.Linear(self.in_dim, self.out_dim),
# self.activation_fn(),
# ),
# edge_dim=self.edge_features,
# aggr=self.aggr,
# )
#
# def forward(self, x, edge_index):
# out = self.gnn(x, edge_index)
# return out
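
# The commented-out kernel above is superseded by the gnn_class/gnn_kwargs
# mechanism: any torch_geometric convolution class can now be injected from
# the config. A sketch (the kwargs shown are standard GATv2Conv arguments;
# in_channels/out_channels are filled in automatically by Gnn.__init__):
#
#     GnnConfig(
#         topology="full",
#         self_loops=True,
#         gnn_class=torch_geometric.nn.conv.GATv2Conv,
#         gnn_kwargs={"share_weights": True},
#     )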


def batch_from_dense_to_ptg(
@@ -180,17 +259,14 @@

@dataclass
class GnnConfig(ModelConfig):
    """Dataclass config for a :class:`~benchmarl.models.Gnn`."""

    topology: str = MISSING
    self_loops: bool = MISSING

    gnn_class: Type[torch_geometric.nn.MessagePassing] = MISSING
    gnn_kwargs: Optional[dict] = None

    @staticmethod
    def associated_class():
        return Gnn

-    @staticmethod
-    def get_from_yaml(path: Optional[str] = None) -> GnnConfig:
-        if path is None:
-            return GnnConfig(
-                **ModelConfig._load_from_yaml(
-                    name=GnnConfig.associated_class().__name__,
-                )
-            )
-        else:
-            return GnnConfig(**parse_model_config(_read_yaml_config(path)))
Binary file added docs/.DS_Store
Binary file added docs/source/.DS_Store
2 changes: 2 additions & 0 deletions docs/source/concepts/components.rst
@@ -109,3 +109,5 @@ agent group. Here is a table of the models implemented in BenchMARL
+=================================+===============+===============================+===============================+
| :class:`~benchmarl.models.Mlp` | Yes | Yes | Yes |
+---------------------------------+---------------+-------------------------------+-------------------------------+
| :class:`~benchmarl.models.Gnn` | Yes | No | No |
+---------------------------------+---------------+-------------------------------+-------------------------------+
Binary file added examples/.DS_Store
Binary file added examples/extending/.DS_Store
Binary file added fine_tuned/.DS_Store
Binary file added fine_tuned/vmas/.DS_Store