Skip to content

Commit

Permalink
nits
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Jun 12, 2024
1 parent a500433 commit 2fb3c3a
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 5 deletions.
4 changes: 2 additions & 2 deletions benchmarl/models/cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,9 @@ def __init__(
def _perform_checks(self):
super()._perform_checks()

input_shape_tensor = None
self.image_in_keys = []
input_shape_image = None
self.image_in_keys = []
input_shape_tensor = None
self.tensor_in_keys = []
for input_key, input_spec in self.input_spec.items(True, True):
if (self.input_has_agent_dim and len(input_spec.shape) == 4) or (
Expand Down
4 changes: 1 addition & 3 deletions benchmarl/models/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,7 @@ def _perform_checks(self):

input_shape = None
for input_key, input_spec in self.input_spec.items(True, True):
if (self.input_has_agent_dim and len(input_spec.shape) == 2) or (
not self.input_has_agent_dim and len(input_spec.shape) == 1
):
if self.input_has_agent_dim and len(input_spec.shape) == 2:
if input_shape is None:
input_shape = input_spec.shape[:-1]
else:
Expand Down

0 comments on commit 2fb3c3a

Please sign in to comment.