Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Sep 3, 2024
1 parent 1a8a2e5 commit 2e92824
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions benchmarl/models/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,12 @@ def _perform_checks(self):

def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
# Gather in_key
input = [tensordict.get(in_key) for in_key in self.in_keys]
input = [
tensordict.get(in_key)
for in_key in self.in_keys
if _unravel_key_to_tuple(in_key)[-1]
not in (self.position_key, self.velocity_key)
]

# Retrieve position
if self.position_key is not None:
Expand All @@ -282,10 +287,7 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
)
else:
pos = tensordict.get(self._full_position_key)
if (
not self.exclude_pos_from_node_features
and self._full_position_key not in self.in_keys
):
if not self.exclude_pos_from_node_features:
input.append(pos)
else:
pos = None
Expand All @@ -304,8 +306,7 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
)
else:
vel = tensordict.get(self._full_velocity_key)
if self._full_velocity_key not in self.in_keys:
input.append(vel)
input.append(vel)
else:
vel = None

Expand Down

0 comments on commit 2e92824

Please sign in to comment.