diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index e86c52e4..532e8041 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -734,6 +734,7 @@ def _get_excluded_keys(self, group: str): for other_group in self.group_map.keys(): if other_group != group: excluded_keys += [other_group, ("next", other_group)] + excluded_keys += ["info", (group, "info"), ("next", group, "info")] return excluded_keys def _optimizer_loop(self, group: str) -> TensorDictBase: diff --git a/benchmarl/models/gnn.py b/benchmarl/models/gnn.py index 5cb0f00e..9c6b6949 100644 --- a/benchmarl/models/gnn.py +++ b/benchmarl/models/gnn.py @@ -60,18 +60,23 @@ class Gnn(Model): gnn_class (Type[torch_geometric.nn.MessagePassing]): the gnn convolution class to use gnn_kwargs (dict, optional): the dict of arguments to pass to the gnn conv class position_key (str, optional): if provided, it will need to match a leaf key in the tensordict coming from the env - (we suggest to use the "info" dict) representing the agent position. This key will be processed as a - node feature (unless exclude_pos_from_node_features=True) and it will be used to construct edge features. + (in the `observation_spec`) representing the agent position. + To do this, your environment needs to have dictionary observations and one of the keys needs to be `position_key`. + This key will be processed as a node feature (unless exclude_pos_from_node_features=True) and it will be used to construct edge features. In particular, it will be used to compute relative positions (``pos_node_1 - pos_node_2``) and a one-dimensional distance for all neighbours in the graph. + If you want to use this feature in a :class:`~benchmarl.models.SequenceModel`, the GNN needs to be first in sequence. pos_features (int, optional): Needed when position_key is specified. It has to match to the last element of the shape the tensor under position_key. exclude_pos_from_node_features (optional, bool): If ``position_key`` is provided, wether to use it just to compute edge features or also include it in node features. velocity_key (str, optional): if provided, it will need to match a leaf key in the tensordict coming from the env - (we suggest to use the "info" dict) representing the agent velocity. This key will be processed as a node feature, and + (in the `observation_spec`) representing the agent position. + To do this, your environment needs to have dictionary observations and one of the keys needs to be `velocity_key`. + This key will be processed as a node feature, and it will be used to construct edge features. In particular, it will be used to compute relative velocities (``vel_node_1 - vel_node_2``) for all neighbours in the graph. + If you want to use this feature in a :class:`~benchmarl.models.SequenceModel`, the GNN needs to be first in sequence. vel_features (int, optional): Needed when velocity_key is specified. It has to match to the last element of the shape the tensor under velocity_key. edge_radius (float, optional): If topology is ``"from_pos"`` the radius to use to build the agent graph. @@ -170,8 +175,7 @@ def __init__( ) and not self.gnn_supports_edge_attrs: warnings.warn( "Position key or velocity key provided but GNN class does not support edge attributes. " - "These input keys will be ignored. If instead you want to process them as node features, " - "just set them (position_key or velocity_key) to null." + "These keys will not be used for computing edge features." ) if ( position_key is not None or velocity_key is not None @@ -369,10 +373,15 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase: def _get_key_terminating_with(self, keys: List[NestedKey], key: str) -> NestedKey: for k in keys: k_tuple = _unravel_key_to_tuple(k) - if k_tuple[-1] == key and self.agent_group in k_tuple: + if ( + k_tuple[-1] == key + and self.agent_group in k_tuple + and not "next" == k_tuple[0] + ): return k raise KeyError( - f"Key terminating with {key} and containing {self.agent_group} not found in keys: {keys}" + f"Key terminating with {key} and containing {self.agent_group} not found in keys: {keys}. " + f"If you are using the GNN in a `SequenceModel` and want to use this key, it needs to be the first model." ) diff --git a/docs/source/conf.py b/docs/source/conf.py index f35b11ff..b9b9a792 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -34,7 +34,7 @@ intersphinx_mapping = { "python": ("https://docs.python.org/3/", None), "sphinx": ("https://www.sphinx-doc.org/en/master/", None), - "torch": ("https://pytorch.org/docs/master", None), + "torch": ("https://pytorch.org/docs/stable/", None), "torchrl": ("https://pytorch.org/rl/stable/", None), "tensordict": ("https://pytorch.org/tensordict/stable", None), }