diff --git a/benchmarl/models/cnn.py b/benchmarl/models/cnn.py index efe77003..c3e7611d 100644 --- a/benchmarl/models/cnn.py +++ b/benchmarl/models/cnn.py @@ -200,9 +200,9 @@ def __init__( def _perform_checks(self): super()._perform_checks() - input_shape_tensor = None - self.image_in_keys = [] input_shape_image = None + self.image_in_keys = [] + input_shape_tensor = None self.tensor_in_keys = [] for input_key, input_spec in self.input_spec.items(True, True): if (self.input_has_agent_dim and len(input_spec.shape) == 4) or ( diff --git a/benchmarl/models/gnn.py b/benchmarl/models/gnn.py index 97d7eb85..2361416d 100644 --- a/benchmarl/models/gnn.py +++ b/benchmarl/models/gnn.py @@ -200,9 +200,7 @@ def _perform_checks(self): input_shape = None for input_key, input_spec in self.input_spec.items(True, True): - if (self.input_has_agent_dim and len(input_spec.shape) == 2) or ( - not self.input_has_agent_dim and len(input_spec.shape) == 1 - ): + if self.input_has_agent_dim and len(input_spec.shape) == 2: if input_shape is None: input_shape = input_spec.shape[:-1] else: