Skip to content

Commit

Permalink
[Feature] GNN build topology dynamically from positions (#98)
Browse files Browse the repository at this point in the history
* gnn radius

* amend

* amend

* amend

* amend

* amend
  • Loading branch information
matteobettini committed Jun 13, 2024
1 parent a5c629b commit 70d9ec7
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 22 deletions.
3 changes: 3 additions & 0 deletions benchmarl/conf/model/layers/gnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@ gnn_kwargs:

position_key: null
velocity_key: null

exclude_pos_from_node_features: False
edge_radius: null
4 changes: 2 additions & 2 deletions benchmarl/models/cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,9 @@ def __init__(
def _perform_checks(self):
super()._perform_checks()

input_shape_tensor = None
self.image_in_keys = []
input_shape_image = None
self.image_in_keys = []
input_shape_tensor = None
self.tensor_in_keys = []
for input_key, input_spec in self.input_spec.items(True, True):
if (self.input_has_agent_dim and len(input_spec.shape) == 4) or (
Expand Down
2 changes: 1 addition & 1 deletion benchmarl/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def share_params_with(self, other_model):
or self.input_spec != other_model.input_spec
or self.output_spec != other_model.output_spec
):
raise warnings.warn(
warnings.warn(
"Sharing parameters with models that are not identical. "
"This might result in unintended behavior or error."
)
Expand Down
73 changes: 55 additions & 18 deletions benchmarl/models/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __call__(self, data):
return data


TOPOLOGY_TYPES = {"full", "empty"}
TOPOLOGY_TYPES = {"full", "empty", "from_pos"}


class Gnn(Model):
Expand All @@ -54,18 +54,23 @@ class Gnn(Model):
GNN models can be used as "decentralized" actors or critics.
Args:
topology (str): Topology of the graph adjacency matrix. Options: "full", "empty".
topology (str): Topology of the graph adjacency matrix. Options: "full", "empty", "from_pos". "from_pos" builds
the topology dynamically based on ``position_key`` and ``edge_radius``.
self_loops (str): Whether the resulting adjacency matrix will have self loops.
gnn_class (Type[torch_geometric.nn.MessagePassing]): the gnn convolution class to use
gnn_kwargs (dict, optional): the dict of arguments to pass to the gnn conv class
position_key (str, optional): if provided, it will need to match a leaf key in the env observation spec
representing the agent position. This key will not be processed as a node feature, but it will used to construct
edge features. In particular it be used to compute relative positions (``pos_node_1 - pos_node_2``) and a
one-dimensional distance for all neighbours in the graph.
exclude_pos_from_node_features (optional, bool): If ``position_key`` is provided,
wether to use it just to compute edge features or also include it in node features.
velocity_key (str, optional): if provided, it will need to match a leaf key in the env observation spec
representing the agent velocity. This key will not be processed as a node feature, but it will used to construct
edge features. In particular it be used to compute relative velocities (``vel_node_1 - vel_node_2``) for all neighbours
in the graph.
edge_radius (float, optional): If topology is ``"from_pos"`` the radius to use to build the agent graph.
Agents within this radius distance will be neighnours.
Examples:
Expand Down Expand Up @@ -112,13 +117,17 @@ def __init__(
gnn_class: Type[torch_geometric.nn.MessagePassing],
gnn_kwargs: Optional[dict],
position_key: Optional[str],
exclude_pos_from_node_features: Optional[bool],
velocity_key: Optional[str],
edge_radius: Optional[float],
**kwargs,
):
self.topology = topology
self.self_loops = self_loops
self.position_key = position_key
self.velocity_key = velocity_key
self.exclude_pos_from_node_features = exclude_pos_from_node_features
self.edge_radius = edge_radius

super().__init__(**kwargs)

Expand All @@ -143,7 +152,8 @@ def __init__(
[
spec.shape[-1]
for key, spec in self.input_spec.items(True, True)
if _unravel_key_to_tuple(key)[-1] not in (velocity_key, position_key)
if _unravel_key_to_tuple(key)[-1]
not in ((position_key) if self.exclude_pos_from_node_features else ())
]
) # Input keys not ending with `velocity_key` and `position_key`
self.output_features = self.output_leaf_spec.shape[-1]
Expand Down Expand Up @@ -189,6 +199,15 @@ def _perform_checks(self):
raise ValueError(
f"Got topology: {self.topology} but only available options are {TOPOLOGY_TYPES}"
)
if self.topology == "from_pos" and self.position_key is None:
raise ValueError("If topology is from_pos, position_key must be provided")
if (
self.position_key is not None
and self.exclude_pos_from_node_features is None
):
raise ValueError(
"exclude_pos_from_node_features needs to be specified when position_key is provided"
)

if not self.input_has_agent_dim:
raise ValueError(
Expand All @@ -200,9 +219,7 @@ def _perform_checks(self):

input_shape = None
for input_key, input_spec in self.input_spec.items(True, True):
if (self.input_has_agent_dim and len(input_spec.shape) == 2) or (
not self.input_has_agent_dim and len(input_spec.shape) == 1
):
if len(input_spec.shape) == 2:
if input_shape is None:
input_shape = input_spec.shape[:-1]
else:
Expand Down Expand Up @@ -235,7 +252,9 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
tensordict.get(in_key)
for in_key in self.in_keys
if _unravel_key_to_tuple(in_key)[-1]
not in (self.position_key, self.velocity_key)
not in (
(self.position_key) if self.exclude_pos_from_node_features else ()
)
],
dim=-1,
)
Expand Down Expand Up @@ -265,7 +284,12 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
batch_size = input.shape[:-2]

graph = _batch_from_dense_to_ptg(
x=input, edge_index=self.edge_index, pos=pos, vel=vel
x=input,
edge_index=self.edge_index,
pos=pos,
vel=vel,
self_loops=self.self_loops,
edge_radius=self.edge_radius,
)
forward_gnn_params = {
"x": graph.x,
Expand Down Expand Up @@ -330,6 +354,8 @@ def _get_edge_index(topology: str, self_loops: bool, n_agents: int, device: str)
)
else:
edge_index = torch.empty((2, 0), device=device, dtype=torch.long)
elif topology == "from_pos":
edge_index = None
else:
raise ValueError(f"Topology {topology} not supported")

Expand All @@ -338,9 +364,11 @@ def _get_edge_index(topology: str, self_loops: bool, n_agents: int, device: str)

def _batch_from_dense_to_ptg(
x: Tensor,
edge_index: Tensor,
edge_index: Optional[Tensor],
self_loops: bool,
pos: Tensor = None,
vel: Tensor = None,
edge_radius: Optional[float] = None,
) -> torch_geometric.data.Batch:
batch_size = prod(x.shape[:-2])
n_agents = x.shape[-2]
Expand All @@ -360,15 +388,22 @@ def _batch_from_dense_to_ptg(
graphs.vel = vel
graphs.edge_attr = None

n_edges = edge_index.shape[1]
# Tensor of shape [batch_size * n_edges]
# in which edges corresponding to the same graph have the same index.
batch = torch.repeat_interleave(b, n_edges)
# Edge index for the batched graphs of shape [2, n_edges * batch_size]
# we sum to each batch an offset of batch_num * n_agents to make sure that
# the adjacency matrices remain independent
batch_edge_index = edge_index.repeat(1, batch_size) + batch * n_agents
graphs.edge_index = batch_edge_index
if edge_index is not None:
n_edges = edge_index.shape[1]
# Tensor of shape [batch_size * n_edges]
# in which edges corresponding to the same graph have the same index.
batch = torch.repeat_interleave(b, n_edges)
# Edge index for the batched graphs of shape [2, n_edges * batch_size]
# we sum to each batch an offset of batch_num * n_agents to make sure that
# the adjacency matrices remain independent
batch_edge_index = edge_index.repeat(1, batch_size) + batch * n_agents
graphs.edge_index = batch_edge_index
else:
if pos is None:
raise RuntimeError("from_pos topology needs positions as input")
graphs.edge_index = torch_geometric.nn.pool.radius_graph(
graphs.pos, batch=graphs.batch, r=edge_radius, loop=self_loops
)

graphs = graphs.to(x.device)
if pos is not None:
Expand All @@ -392,6 +427,8 @@ class GnnConfig(ModelConfig):

position_key: Optional[str] = None
velocity_key: Optional[str] = None
exclude_pos_from_node_features: Optional[bool] = None
edge_radius: Optional[float] = None

@staticmethod
def associated_class():
Expand Down
1 change: 0 additions & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def mlp_gnn_sequence_config() -> ModelConfig:
topology="full",
self_loops=False,
gnn_class=torch_geometric.nn.conv.GATv2Conv,
gnn_kwargs={},
),
MlpConfig(num_cells=[4], activation_class=nn.Tanh, layer_class=nn.Linear),
],
Expand Down
2 changes: 2 additions & 0 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ def test_gnn_edge_attrs(
gnn_class=torch_geometric.nn.GATv2Conv,
gnn_kwargs=None,
position_key=position_key,
exclude_pos_from_node_features=False,
).get_model(
input_spec=input_spec,
output_spec=output_spec,
Expand Down Expand Up @@ -356,6 +357,7 @@ def test_gnn_edge_attrs(
gnn_class=torch_geometric.nn.GraphConv,
gnn_kwargs=None,
position_key=position_key,
exclude_pos_from_node_features=False,
).get_model(
input_spec=input_spec,
output_spec=output_spec,
Expand Down

0 comments on commit 70d9ec7

Please sign in to comment.