Skip to content

Commit

Permalink
[Model] GNN (#30)
Browse files Browse the repository at this point in the history
* gnn

* empty

* empty

* Amend

* amend

* amend

* copyright

* setup

* empty

* amend

* amend
  • Loading branch information
matteobettini authored Nov 27, 2023
1 parent 1944f54 commit ee12a28
Show file tree
Hide file tree
Showing 14 changed files with 434 additions and 24 deletions.
2 changes: 1 addition & 1 deletion .github/unittest/install_dependencies.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@


python -m pip install --upgrade pip
python -m pip install flake8 pytest pytest-cov hydra-core tqdm
python -m pip install flake8 pytest pytest-cov hydra-core tqdm torch_geometric

if [ -f requirements.txt ]; then pip install -r requirements.txt; fi

Expand Down
12 changes: 7 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -254,14 +254,16 @@ agent group. Here is a table of the models implemented in BenchMARL

| Name | Decentralized | Centralized with local inputs | Centralized with global input |
|--------------------------------|:-------------:|:-----------------------------:|:-----------------------------:|
| [MLP](benchmarl/models/mlp.py) | Yes | Yes | Yes |
| [MLP](benchmarl/models/mlp.py) | Yes | Yes | Yes |
| [GNN](benchmarl/models/gnn.py) | Yes | No | No |

And the ones that are _work in progress_

| Name | Decentralized | Centralized with local inputs | Centralized with global input |
|--------------------------------------------------------------|:-------------:|:-----------------------------:|:-----------------------------:|
| [GNN](https://github.com/facebookresearch/BenchMARL/pull/18) | Yes | Yes | No |
| CNN | Yes | Yes | Yes |
| Name | Decentralized | Centralized with local inputs | Centralized with global input |
|--------------------|:-------------:|:-----------------------------:|:-----------------------------:|
| CNN | Yes | Yes | Yes |
| RNN (GRU and LSTM) | Yes | Yes | Yes |


## Fine-tuned public benchmarks
> [!WARNING]
Expand Down
8 changes: 8 additions & 0 deletions benchmarl/conf/model/layers/gnn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
name: gnn

topology: full
self_loops: False

gnn_class: torch_geometric.nn.conv.GraphConv
gnn_kwargs:
aggr: "add"
1 change: 0 additions & 1 deletion benchmarl/conf/model/layers/mlp.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@


name: mlp

num_cells: [256, 256]
Expand Down
8 changes: 3 additions & 5 deletions benchmarl/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@
#

from .common import Model, ModelConfig, SequenceModel, SequenceModelConfig
from .gnn import Gnn, GnnConfig
from .mlp import Mlp, MlpConfig

classes = [
"Mlp",
"MlpConfig",
]
classes = ["Mlp", "MlpConfig", "Gnn", "GnnConfig"]

model_config_registry = {"mlp": MlpConfig}
model_config_registry = {"mlp": MlpConfig, "gnn": GnnConfig}
36 changes: 35 additions & 1 deletion benchmarl/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,41 @@ def get_from_yaml(cls, path: Optional[str] = None):

@dataclass
class SequenceModelConfig(ModelConfig):
"""Dataclass for a :class:`~benchmarl.models.SequenceModel`."""
"""Dataclass for a :class:`~benchmarl.models.SequenceModel`.
Examples:
.. code-block:: python
import torch_geometric
from torch import nn
from benchmarl.algorithms import IppoConfig
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models import SequenceModelConfig, GnnConfig, MlpConfig
experiment = Experiment(
algorithm_config=IppoConfig.get_from_yaml(),
model_config=SequenceModelConfig(
model_configs=[
MlpConfig(num_cells=[8], activation_class=nn.Tanh, layer_class=nn.Linear),
GnnConfig(
topology="full",
self_loops=False,
gnn_class=torch_geometric.nn.conv.GraphConv,
),
MlpConfig(num_cells=[6], activation_class=nn.Tanh, layer_class=nn.Linear),
],
intermediate_sizes=[5, 3],
),
seed=0,
config=ExperimentConfig.get_from_yaml(),
task=VmasTask.NAVIGATION.get_from_yaml(),
)
experiment.run()
"""

model_configs: Sequence[ModelConfig]
intermediate_sizes: Sequence[int]
Expand Down
278 changes: 278 additions & 0 deletions benchmarl/models/gnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,278 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from __future__ import annotations

import importlib
from dataclasses import dataclass, MISSING
from math import prod
from typing import Optional, Type

import torch
from tensordict import TensorDictBase
from torch import nn, Tensor

from benchmarl.models.common import Model, ModelConfig

_has_torch_geometric = importlib.util.find_spec("torch_geometric") is not None
if _has_torch_geometric:
import torch_geometric

TOPOLOGY_TYPES = {"full", "empty"}


def _get_edge_index(topology: str, self_loops: bool, n_agents: int, device: str):
if topology == "full":
adjacency = torch.ones(n_agents, n_agents, device=device, dtype=torch.long)
elif topology == "empty":
adjacency = torch.ones(n_agents, n_agents, device=device, dtype=torch.long)

edge_index, _ = torch_geometric.utils.dense_to_sparse(adjacency)

if self_loops:
edge_index, _ = torch_geometric.utils.add_self_loops(edge_index)
else:
edge_index, _ = torch_geometric.utils.remove_self_loops(edge_index)

return edge_index


class Gnn(Model):
"""A GNN model.
GNN models can be used as "decentralized" actors or critics.
Args:
topology (str): Topology of the graph adjacency matrix. Options: "full", "empty".
self_loops (str): Whether the resulting adjacency matrix will have self loops.
gnn_class (Type[torch_geometric.nn.MessagePassing]): the gnn convolution class to use
gnn_kwargs (dict, optional): the dict of arguments to pass to the gnn conv class
Examples:
.. code-block:: python
import torch_geometric
from torch import nn
from benchmarl.algorithms import IppoConfig
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models import SequenceModelConfig, GnnConfig, MlpConfig
experiment = Experiment(
algorithm_config=IppoConfig.get_from_yaml(),
model_config=GnnConfig(
topology="full",
self_loops=False,
gnn_class=torch_geometric.nn.conv.GATv2Conv,
gnn_kwargs={},
),
critic_model_config=SequenceModelConfig(
model_configs=[
MlpConfig(num_cells=[8], activation_class=nn.Tanh, layer_class=nn.Linear),
GnnConfig(
topology="full",
self_loops=False,
gnn_class=torch_geometric.nn.conv.GraphConv,
),
MlpConfig(num_cells=[6], activation_class=nn.Tanh, layer_class=nn.Linear),
],
intermediate_sizes=[5,3],
),
seed=0,
config=ExperimentConfig.get_from_yaml(),
task=VmasTask.NAVIGATION.get_from_yaml(),
)
experiment.run()
"""

def __init__(
self,
topology: str,
self_loops: bool,
gnn_class: Type[torch_geometric.nn.MessagePassing],
gnn_kwargs: Optional[dict] = None,
**kwargs,
):
self.topology = topology
self.self_loops = self_loops

super().__init__(**kwargs)

self.input_features = self.input_leaf_spec.shape[-1]
self.output_features = self.output_leaf_spec.shape[-1]

if gnn_kwargs is None:
gnn_kwargs = {}
gnn_kwargs.update(
{"in_channels": self.input_features, "out_channels": self.output_features}
)

self.gnns = nn.ModuleList(
[
gnn_class(**gnn_kwargs).to(self.device)
for _ in range(self.n_agents if not self.share_params else 1)
]
)
self.edge_index = _get_edge_index(
topology=self.topology,
self_loops=self.self_loops,
device=self.device,
n_agents=self.n_agents,
)

def _perform_checks(self):
super()._perform_checks()

if self.topology not in TOPOLOGY_TYPES:
raise ValueError(
f"Got topology: {self.topology} but only available options are {TOPOLOGY_TYPES}"
)
if self.centralised:
raise ValueError("GNN model can only be used in non-centralised critics")
if not self.input_has_agent_dim:
raise ValueError(
"The GNN module is not compatible with input that does not have the agent dimension,"
"such as the global state in centralised critics. Please choose another critic model"
"if your algorithm has a centralized critic and the task has a global state."
)

if self.input_leaf_spec.shape[-2] != self.n_agents:
raise ValueError(
"The second to last input spec dimension should be the number of agents"
)
if (
self.output_has_agent_dim
and self.output_leaf_spec.shape[-2] != self.n_agents
):
raise ValueError(
"If the GNN output has the agent dimension,"
" the second to last spec dimension should be the number of agents"
)

def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
# Gather in_key
input = tensordict.get(self.in_key)

batch_size = input.shape[:-2]

graph = batch_from_dense_to_ptg(x=input, edge_index=self.edge_index)

if not self.share_params:
res = torch.stack(
[
gnn(graph.x, graph.edge_index).view(
*batch_size,
self.n_agents,
self.output_features,
)[:, i]
for i, gnn in enumerate(self.gnns)
],
dim=-2,
)

else:
res = self.gnns[0](
graph.x,
graph.edge_index,
).view(*batch_size, self.n_agents, self.output_features)

tensordict.set(self.out_key, res)
return tensordict


# class GnnKernel(nn.Module):
# def __init__(self, in_dim, out_dim, **cfg):
# super().__init__()
#
# gnn_types = {"GraphConv", "GATv2Conv", "GINEConv"}
# aggr_types = {"add", "mean", "max"}
#
# self.aggr = "add"
# self.gnn_type = "GraphConv"
#
# self.in_dim = in_dim
# self.out_dim = out_dim
# self.activation_fn = nn.Tanh
#
# if self.gnn_type == "GraphConv":
# self.gnn = GraphConv(
# self.in_dim,
# self.out_dim,
# aggr=self.aggr,
# )
# elif self.gnn_type == "GATv2Conv":
# # Default adds self loops
# self.gnn = GATv2Conv(
# self.in_dim,
# self.out_dim,
# edge_dim=self.edge_features,
# fill_value=0.0,
# share_weights=True,
# add_self_loops=True,
# aggr=self.aggr,
# )
# elif self.gnn_type == "GINEConv":
# self.gnn = GINEConv(
# nn=nn.Sequential(
# torch.nn.Linear(self.in_dim, self.out_dim),
# self.activation_fn(),
# ),
# edge_dim=self.edge_features,
# aggr=self.aggr,
# )
#
# def forward(self, x, edge_index):
# out = self.gnn(x, edge_index)
# return out


def batch_from_dense_to_ptg(
x: Tensor,
edge_index: Tensor,
) -> torch_geometric.data.Batch:
batch_size = prod(x.shape[:-2])
n_agents = x.shape[-2]
x = x.view(-1, x.shape[-1])

b = torch.arange(batch_size, device=x.device)

graphs = torch_geometric.data.Batch()
graphs.ptr = torch.arange(0, (batch_size + 1) * n_agents, n_agents)
graphs.batch = torch.repeat_interleave(b, n_agents)
graphs.x = x
graphs.edge_attr = None

n_edges = edge_index.shape[1]
# Tensor of shape [batch_size * n_edges]
# in which edges corresponding to the same graph have the same index.
batch = torch.repeat_interleave(b, n_edges)
# Edge index for the batched graphs of shape [2, n_edges * batch_size]
# we sum to each batch an offset of batch_num * n_agents to make sure that
# the adjacency matrices remain independent
batch_edge_index = edge_index.repeat(1, batch_size) + batch * n_agents
graphs.edge_index = batch_edge_index

graphs = graphs.to(x.device)

return graphs


@dataclass
class GnnConfig(ModelConfig):
"""Dataclass config for a :class:`~benchmarl.models.Gnn`."""

topology: str = MISSING
self_loops: bool = MISSING

gnn_class: Type[torch_geometric.nn.MessagePassing] = MISSING
gnn_kwargs: Optional[dict] = None

@staticmethod
def associated_class():
return Gnn
2 changes: 2 additions & 0 deletions docs/source/concepts/components.rst
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,5 @@ agent group. Here is a table of the models implemented in BenchMARL
+=================================+===============+===============================+===============================+
| :class:`~benchmarl.models.Mlp` | Yes | Yes | Yes |
+---------------------------------+---------------+-------------------------------+-------------------------------+
| :class:`~benchmarl.models.Gnn` | Yes | No | No |
+---------------------------------+---------------+-------------------------------+-------------------------------+
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def package_files(directory):
extras_require={
"vmas": ["vmas>=1.2.10"],
"pettingzoo": ["pettingzoo[all]>=1.24.1"],
"gnn": ["torch_geometric"],
},
packages=find_packages(),
include_package_data=True,
Expand Down
Loading

0 comments on commit ee12a28

Please sign in to comment.