From 2fb3c3a4aaccc4905f6b458d4470ed196bd0ded9 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 12 Jun 2024 17:24:07 +0200 Subject: [PATCH] nits --- benchmarl/models/cnn.py | 4 ++-- benchmarl/models/gnn.py | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/benchmarl/models/cnn.py b/benchmarl/models/cnn.py index efe77003..c3e7611d 100644 --- a/benchmarl/models/cnn.py +++ b/benchmarl/models/cnn.py @@ -200,9 +200,9 @@ def __init__( def _perform_checks(self): super()._perform_checks() - input_shape_tensor = None - self.image_in_keys = [] input_shape_image = None + self.image_in_keys = [] + input_shape_tensor = None self.tensor_in_keys = [] for input_key, input_spec in self.input_spec.items(True, True): if (self.input_has_agent_dim and len(input_spec.shape) == 4) or ( diff --git a/benchmarl/models/gnn.py b/benchmarl/models/gnn.py index 97d7eb85..2361416d 100644 --- a/benchmarl/models/gnn.py +++ b/benchmarl/models/gnn.py @@ -200,9 +200,7 @@ def _perform_checks(self): input_shape = None for input_key, input_spec in self.input_spec.items(True, True): - if (self.input_has_agent_dim and len(input_spec.shape) == 2) or ( - not self.input_has_agent_dim and len(input_spec.shape) == 1 - ): + if self.input_has_agent_dim and len(input_spec.shape) == 2: if input_shape is None: input_shape = input_spec.shape[:-1] else: