Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Jun 11, 2024
1 parent 2fa4e5e commit 5078725
Showing 1 changed file with 28 additions and 22 deletions.
50 changes: 28 additions & 22 deletions benchmarl/models/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,26 @@
import torch_geometric
from torch_geometric.transforms import BaseTransform

class _RelVel(BaseTransform):
"""Transform that reads graph.vel and writes node1.vel - node2.vel in the edge attributes"""

def __init__(self):
pass

def __call__(self, data):
(row, col), vel, pseudo = data.edge_index, data.vel, data.edge_attr

cart = vel[row] - vel[col]
cart = cart.view(-1, 1) if cart.dim() == 1 else cart

if pseudo is not None:
pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo
data.edge_attr = torch.cat([pseudo, cart.type_as(pseudo)], dim=-1)
else:
data.edge_attr = cart
return data


TOPOLOGY_TYPES = {"full", "empty"}


Expand All @@ -38,6 +58,14 @@ class Gnn(Model):
self_loops (str): Whether the resulting adjacency matrix will have self loops.
gnn_class (Type[torch_geometric.nn.MessagePassing]): the gnn convolution class to use
gnn_kwargs (dict, optional): the dict of arguments to pass to the gnn conv class
position_key (str, optional): if provided, it will need to match a leaf key in the env observation spec
representing the agent position. This key will not be processed as a node feature, but it will used to construct
edge features. In particular it be used to compute relative positions (``pos_node_1 - pos_node_2``) and a
one-dimensional distance for all neighbours in the graph.
velocity_key (str, optional): if provided, it will need to match a leaf key in the env observation spec
representing the agent velocity. This key will not be processed as a node feature, but it will used to construct
edge features. In particular it be used to compute relative velocities (``vel_node_1 - vel_node_2``) for all neighbours
in the graph.
Examples:
Expand Down Expand Up @@ -75,8 +103,6 @@ class Gnn(Model):
)
experiment.run()
"""

def __init__(
Expand Down Expand Up @@ -354,26 +380,6 @@ def _batch_from_dense_to_ptg(
return graphs


class _RelVel(BaseTransform):
"""Transform that reads graph.vel and writes node1.vel - node2.vel in the edge attributes"""

def __init__(self):
pass

def __call__(self, data):
(row, col), vel, pseudo = data.edge_index, data.vel, data.edge_attr

cart = vel[row] - vel[col]
cart = cart.view(-1, 1) if cart.dim() == 1 else cart

if pseudo is not None:
pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo
data.edge_attr = torch.cat([pseudo, cart.type_as(pseudo)], dim=-1)
else:
data.edge_attr = cart
return data


@dataclass
class GnnConfig(ModelConfig):
"""Dataclass config for a :class:`~benchmarl.models.Gnn`."""
Expand Down

0 comments on commit 5078725

Please sign in to comment.