diff --git a/benchmarl/conf/model/layers/gnn.yaml b/benchmarl/conf/model/layers/gnn.yaml
new file mode 100644
index 00000000..daa6a39e
--- /dev/null
+++ b/benchmarl/conf/model/layers/gnn.yaml
@@ -0,0 +1 @@
+name: gnn
diff --git a/benchmarl/models/__init__.py b/benchmarl/models/__init__.py
index 2fefe7ab..31c8bd89 100644
--- a/benchmarl/models/__init__.py
+++ b/benchmarl/models/__init__.py
@@ -1,3 +1,4 @@
+from .gnn import GnnConfig
 from .mlp import MlpConfig
 
-model_config_registry = {"mlp": MlpConfig}
+model_config_registry = {"mlp": MlpConfig, "gnn": GnnConfig}
diff --git a/benchmarl/models/gnn.py b/benchmarl/models/gnn.py
index 5c2c3b9f..f4afb1ad 100644
--- a/benchmarl/models/gnn.py
+++ b/benchmarl/models/gnn.py
@@ -1,23 +1,22 @@
 from __future__ import annotations
 
-from dataclasses import dataclass, MISSING
-
-from typing import Optional, Sequence, Type
+import importlib
+from dataclasses import dataclass
+from math import prod
+from typing import Optional
 
 import torch
-import torch_geometric
-from ray.rllib.models.modelv2 import ModelV2
-from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
-from ray.rllib.utils.annotations import override
 from tensordict import TensorDictBase
 from torch import nn, Tensor
-from torch_geometric.nn import GATv2Conv, GINEConv, GraphConv, MessagePassing
-from torch_geometric.transforms import BaseTransform
-from torchrl.modules import MLP, MultiAgentMLP
 
 from benchmarl.models.common import Model, ModelConfig, parse_model_config
 from benchmarl.utils import read_yaml_config
 
+_has_torch_geometric = importlib.util.find_spec("torch_geometric") is not None
+if _has_torch_geometric:
+    import torch_geometric
+    from torch_geometric.nn import GATv2Conv, GINEConv, GraphConv
+
 
 class Gnn(Model):
     def __init__(
@@ -29,43 +28,38 @@ def __init__(
         self.input_features = self.input_leaf_spec.shape[-1]
         self.output_features = self.output_leaf_spec.shape[-1]
 
-        if self.input_has_agent_dim:
-            self.mlp = MultiAgentMLP(
-                n_agent_inputs=self.input_features,
-                n_agent_outputs=self.output_features,
-                n_agents=self.n_agents,
-                centralised=self.centralised,
-                share_params=self.share_params,
-                device=self.device,
-                **kwargs,
-            )
-        else:
-            self.mlp = nn.ModuleList(
-                [
-                    MLP(
-                        in_features=self.input_features,
-                        out_features=self.output_features,
-                        device=self.device,
-                        **kwargs,
-                    )
-                    for _ in range(self.n_agents if not self.share_params else 1)
-                ]
-            )
+        self.gnns = nn.ModuleList(
+            [
+                GnnKernel(
+                    in_dim=self.input_features,
+                    out_dim=self.output_features,
+                )
+                for _ in range(self.n_agents if not self.share_params else 1)
+            ]
+        )
+        self.fully_connected_adjacency = torch.ones(
+            self.n_agents, self.n_agents, device=self.device
+        )
 
     def _perform_checks(self):
         super()._perform_checks()
 
+        if not self.input_has_agent_dim:
+            raise ValueError(
+                "The GNN module is not compatible with input that does not have the agent dimension,"
+                " such as the global state in centralised critics. Please choose another critic model"
+                " if your algorithm has a centralised critic and the task has a global state."
+            )
 
-        if self.input_has_agent_dim and self.input_leaf_spec.shape[-2] != self.n_agents:
+        if self.input_leaf_spec.shape[-2] != self.n_agents:
             raise ValueError(
-                "If the MLP input has the agent dimension,"
-                " the second to last spec dimension should be the number of agents"
+                "The second to last input spec dimension should be the number of agents"
             )
         if (
             self.output_has_agent_dim
             and self.output_leaf_spec.shape[-2] != self.n_agents
         ):
             raise ValueError(
-                "If the MLP output has the agent dimension,"
+                "If the GNN output has the agent dimension,"
                 " the second to last spec dimension should be the number of agents"
             )
 
@@ -73,46 +67,52 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         # Gather in_key
         input = tensordict.get(self.in_key)
 
-        # Has multi-agent input dimension
-        if self.input_has_agent_dim:
-            res = self.mlp.forward(input)
-            if not self.output_has_agent_dim:
-                # If we are here the module is centralised and parameter shared.
-                # Thus the multi-agent dimension has been expanded,
-                # We remove it without loss of data
-                res = res[..., 0, :]
+        # For now fully connected
+        adjacency = self.fully_connected_adjacency.to(input.device)
+
+        edge_index, _ = torch_geometric.utils.dense_to_sparse(adjacency)
+        edge_index, _ = torch_geometric.utils.remove_self_loops(edge_index)
+
+        batch_size = input.shape[:-2]
+
+        graph = batch_from_dense_to_ptg(x=input, edge_index=edge_index)
+
+        if not self.share_params:
+            res = torch.stack(
+                [
+                    gnn(graph.x, graph.edge_index).view(
+                        *batch_size,
+                        self.n_agents,
+                        self.output_features,
+                    )[:, i]
+                    for i, gnn in enumerate(self.gnns)
+                ],
+                dim=-2,
+            )
 
-        # Does not have multi-agent input dimension
         else:
-            if not self.share_params:
-                res = torch.stack(
-                    [net(input) for net in self.mlp],
-                    dim=-2,
-                )
-            else:
-                res = self.mlp[0](input)
+            res = self.gnns[0](
+                graph.x,
+                graph.edge_index,
+            ).view(*batch_size, self.n_agents, self.output_features)
 
         tensordict.set(self.out_key, res)
         return tensordict
 
 
 class GnnKernel(nn.Module):
-    def __init__(self, in_dim, out_dim, edge_features, **cfg):
+    def __init__(self, in_dim, out_dim, **cfg):
         super().__init__()
 
-        gnn_types = {"GraphConv", "GATv2Conv", "GINEConv"}
-        aggr_types = {"add", "mean", "max"}
-
-        self.aggr = cfg["aggr"]
-        self.gnn_type = cfg["gnn_type"]
+        # gnn_types = {"GraphConv", "GATv2Conv", "GINEConv"}
+        # aggr_types = {"add", "mean", "max"}
 
-        assert self.aggr in aggr_types
-        assert self.gnn_type in gnn_types
+        self.aggr = "add"
+        self.gnn_type = "GraphConv"
 
         self.in_dim = in_dim
         self.out_dim = out_dim
-        self.edge_features = edge_features
-        self.activation_fn = get_activation_fn(cfg["activation_fn"])
+        self.activation_fn = nn.Tanh
 
         if self.gnn_type == "GraphConv":
             self.gnn = GraphConv(
@@ -140,114 +140,57 @@ def __init__(self, in_dim, out_dim, edge_features, **cfg):
                 edge_dim=self.edge_features,
                 aggr=self.aggr,
             )
-        elif self.gnn_type == "MatPosConv":
-            self.gnn = MatPosConv(
-                self.in_dim,
-                self.out_dim,
-                edge_features=self.edge_features,
-                **cfg,
-            )
-        else:
-            assert False
-
-    def forward(self, x, edge_index, edge_attr):
-        if self.gnn_type == "GraphConv":
-            out = self.gnn(x, edge_index)
-        elif (
-            self.gnn_type == "GATv2Conv"
-            or self.gnn_type == "GINEConv"
-            or self.gnn_type == "MatPosConv"
-        ):
-            out = self.gnn(x, edge_index, edge_attr)
-        else:
-            assert False
+    def forward(self, x, edge_index):
+        out = self.gnn(x, edge_index)
         return out
 
 
 def batch_from_dense_to_ptg(
-    x,
-    pos: Tensor = None,
-    vel: Tensor = None,
-    edge_index: Tensor = None,
-    comm_radius: float = -1,
-    rel_pos: bool = True,
-    distance: bool = True,
-    rel_vel: bool = True,
+    x: Tensor,
+    edge_index: Tensor,
 ) -> torch_geometric.data.Batch:
-    batch_size = x.shape[0]
-    n_agents = x.shape[1]
+    batch_size = prod(x.shape[:-2])
+    n_agents = x.shape[-2]
     x = x.view(-1, x.shape[-1])
-    if pos is not None:
-        pos = pos.view(-1, pos.shape[-1])
-    if vel is not None:
-        vel = vel.view(-1, vel.shape[-1])
-
-    assert (edge_index is None or comm_radius < 0) and (
-        edge_index is not None or comm_radius > 0
-    )
 
     b = torch.arange(batch_size, device=x.device)
 
     graphs = torch_geometric.data.Batch()
     graphs.ptr = torch.arange(0, (batch_size + 1) * n_agents, n_agents)
     graphs.batch = torch.repeat_interleave(b, n_agents)
-    graphs.pos = pos
-    graphs.vel = vel
     graphs.x = x
     graphs.edge_attr = None
 
-    if edge_index is not None:
-        n_edges = edge_index.shape[1]
-        # Tensor of shape [batch_size * n_edges]
-        # in which edges corresponding to the same graph have the same index.
-        batch = torch.repeat_interleave(b, n_edges)
-        # Edge index for the batched graphs of shape [2, n_edges * batch_size]
-        # we sum to each batch an offset of batch_num * n_agents to make sure that
-        # the adjacency matrices remain independent
-        batch_edge_index = edge_index.repeat(1, batch_size) + batch * n_agents
-        graphs.edge_index = batch_edge_index
-    else:
-        assert pos is not None
-        graphs.edge_index = torch_geometric.nn.pool.radius_graph(
-            graphs.pos, batch=graphs.batch, r=comm_radius, loop=False
-        )
+    n_edges = edge_index.shape[1]
+    # Tensor of shape [batch_size * n_edges]
+    # in which edges corresponding to the same graph have the same index.
+    batch = torch.repeat_interleave(b, n_edges)
+    # Edge index for the batched graphs of shape [2, n_edges * batch_size]
+    # we sum to each batch an offset of batch_num * n_agents to make sure that
+    # the adjacency matrices remain independent
+    batch_edge_index = edge_index.repeat(1, batch_size) + batch * n_agents
+    graphs.edge_index = batch_edge_index
 
     graphs = graphs.to(x.device)
 
-    if pos is not None and rel_pos:
-        graphs = torch_geometric.transforms.Cartesian(norm=False)(graphs)
-    if pos is not None and distance:
-        graphs = torch_geometric.transforms.Distance(norm=False)(graphs)
-    if vel is not None and rel_vel:
-        graphs = RelVel()(graphs)
-
     return graphs
 
 
 @dataclass
 class GnnConfig(ModelConfig):
-    num_cells: Sequence[int] = MISSING
-    layer_class: Type[nn.Module] = MISSING
-
-    activation_class: Type[nn.Module] = MISSING
-    activation_kwargs: Optional[dict] = None
-
-    norm_class: Type[nn.Module] = None
-    norm_kwargs: Optional[dict] = None
-
     @staticmethod
     def associated_class():
-        return Mlp
+        return Gnn
 
     @staticmethod
-    def get_from_yaml(path: Optional[str] = None) -> MlpConfig:
+    def get_from_yaml(path: Optional[str] = None) -> GnnConfig:
        if path is None:
-            return MlpConfig(
+            return GnnConfig(
                 **ModelConfig._load_from_yaml(
-                    name=MlpConfig.associated_class().__name__,
+                    name=GnnConfig.associated_class().__name__,
                 )
            )
         else:
-            return MlpConfig(**parse_model_config(read_yaml_config(path)))
+            return GnnConfig(**parse_model_config(read_yaml_config(path)))
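
For readers unfamiliar with how PyTorch Geometric batches graphs, the standalone sketch below (not part of the patch; the sizes n_agents = 3 and batch_size = 2 are illustrative assumptions) reproduces the edge-index construction used in _forward and batch_from_dense_to_ptg above: a fully connected adjacency is converted to a sparse edge index, self-loops are removed, and each graph in the batch is offset by batch_num * n_agents so the per-graph adjacencies stay independent.

import torch
import torch_geometric

n_agents, batch_size = 3, 2  # illustrative sizes only

# Fully connected adjacency over the agents, as in Gnn.__init__ above.
adjacency = torch.ones(n_agents, n_agents)
edge_index, _ = torch_geometric.utils.dense_to_sparse(adjacency)
edge_index, _ = torch_geometric.utils.remove_self_loops(edge_index)
# edge_index now has shape [2, 6]: every ordered pair of distinct agents.

# Batch the single-graph edge_index as batch_from_dense_to_ptg does.
b = torch.arange(batch_size)
n_edges = edge_index.shape[1]
batch = torch.repeat_interleave(b, n_edges)  # [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
batch_edge_index = edge_index.repeat(1, batch_size) + batch * n_agents
# Nodes 0..2 form graph 0 and nodes 3..5 form graph 1; no edge crosses graphs.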