diff --git a/benchmarl/models/gnn.py b/benchmarl/models/gnn.py index f18c9038..518c75f2 100644 --- a/benchmarl/models/gnn.py +++ b/benchmarl/models/gnn.py @@ -149,11 +149,8 @@ def __init__( if _unravel_key_to_tuple(key)[-1] not in (position_key, velocity_key) ] ) # Input keys - if ( - self.position_key is not None - and not not self.exclude_pos_from_node_features - ): - self.input_features += self.pos_features + if self.position_key is not None and not self.exclude_pos_from_node_features: + self.input_features += self.pos_features - 1 if self.velocity_key is not None: self.input_features += self.vel_features