Skip to content

Commit

Permalink
[BugFix] GNN position and velocity key (#132)
Browse files Browse the repository at this point in the history
* amend

* amend

* amend

* amend
  • Loading branch information
matteobettini committed Sep 21, 2024
1 parent dc793b5 commit 58ff47b
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 8 deletions.
1 change: 1 addition & 0 deletions benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,7 @@ def _get_excluded_keys(self, group: str):
for other_group in self.group_map.keys():
if other_group != group:
excluded_keys += [other_group, ("next", other_group)]
excluded_keys += ["info", (group, "info"), ("next", group, "info")]
return excluded_keys

def _optimizer_loop(self, group: str) -> TensorDictBase:
Expand Down
23 changes: 16 additions & 7 deletions benchmarl/models/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,23 @@ class Gnn(Model):
gnn_class (Type[torch_geometric.nn.MessagePassing]): the gnn convolution class to use
gnn_kwargs (dict, optional): the dict of arguments to pass to the gnn conv class
position_key (str, optional): if provided, it will need to match a leaf key in the tensordict coming from the env
(we suggest to use the "info" dict) representing the agent position. This key will be processed as a
node feature (unless exclude_pos_from_node_features=True) and it will be used to construct edge features.
(in the `observation_spec`) representing the agent position.
To do this, your environment needs to have dictionary observations and one of the keys needs to be `position_key`.
This key will be processed as a node feature (unless exclude_pos_from_node_features=True) and it will be used to construct edge features.
In particular, it will be used to compute relative positions (``pos_node_1 - pos_node_2``) and a
one-dimensional distance for all neighbours in the graph.
If you want to use this feature in a :class:`~benchmarl.models.SequenceModel`, the GNN needs to be first in sequence.
pos_features (int, optional): Needed when position_key is specified.
It has to match to the last element of the shape the tensor under position_key.
exclude_pos_from_node_features (optional, bool): If ``position_key`` is provided,
wether to use it just to compute edge features or also include it in node features.
velocity_key (str, optional): if provided, it will need to match a leaf key in the tensordict coming from the env
(we suggest to use the "info" dict) representing the agent velocity. This key will be processed as a node feature, and
(in the `observation_spec`) representing the agent position.
To do this, your environment needs to have dictionary observations and one of the keys needs to be `velocity_key`.
This key will be processed as a node feature, and
it will be used to construct edge features. In particular, it will be used to compute relative velocities
(``vel_node_1 - vel_node_2``) for all neighbours in the graph.
If you want to use this feature in a :class:`~benchmarl.models.SequenceModel`, the GNN needs to be first in sequence.
vel_features (int, optional): Needed when velocity_key is specified.
It has to match to the last element of the shape the tensor under velocity_key.
edge_radius (float, optional): If topology is ``"from_pos"`` the radius to use to build the agent graph.
Expand Down Expand Up @@ -170,8 +175,7 @@ def __init__(
) and not self.gnn_supports_edge_attrs:
warnings.warn(
"Position key or velocity key provided but GNN class does not support edge attributes. "
"These input keys will be ignored. If instead you want to process them as node features, "
"just set them (position_key or velocity_key) to null."
"These keys will not be used for computing edge features."
)
if (
position_key is not None or velocity_key is not None
Expand Down Expand Up @@ -369,10 +373,15 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
def _get_key_terminating_with(self, keys: List[NestedKey], key: str) -> NestedKey:
for k in keys:
k_tuple = _unravel_key_to_tuple(k)
if k_tuple[-1] == key and self.agent_group in k_tuple:
if (
k_tuple[-1] == key
and self.agent_group in k_tuple
and not "next" == k_tuple[0]
):
return k
raise KeyError(
f"Key terminating with {key} and containing {self.agent_group} not found in keys: {keys}"
f"Key terminating with {key} and containing {self.agent_group} not found in keys: {keys}. "
f"If you are using the GNN in a `SequenceModel` and want to use this key, it needs to be the first model."
)


Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
intersphinx_mapping = {
"python": ("https://docs.python.org/3/", None),
"sphinx": ("https://www.sphinx-doc.org/en/master/", None),
"torch": ("https://pytorch.org/docs/master", None),
"torch": ("https://pytorch.org/docs/stable/", None),
"torchrl": ("https://pytorch.org/rl/stable/", None),
"tensordict": ("https://pytorch.org/tensordict/stable", None),
}
Expand Down

0 comments on commit 58ff47b

Please sign in to comment.