diff --git a/android_env/wrappers/image_rescale_wrapper.py b/android_env/wrappers/image_rescale_wrapper.py index 268d702..6faeebc 100644 --- a/android_env/wrappers/image_rescale_wrapper.py +++ b/android_env/wrappers/image_rescale_wrapper.py @@ -71,14 +71,16 @@ def _process_pixels(self, raw_observation: np.ndarray) -> np.ndarray: return self._resize_image_array(image, new_shape) def _resize_image_array( - self, - grayscale_or_rbg_array: np.ndarray, - new_shape: Sequence[int]) -> np.ndarray: + self, grayscale_or_rbg_array: np.ndarray, new_shape: np.ndarray + ) -> np.ndarray: """Resize color or grayscale/action_layer array to new_shape.""" - assert np.array(new_shape).ndim == 1 + assert new_shape.ndim == 1 assert len(new_shape) == 2 - resized_array = np.array(Image.fromarray( - grayscale_or_rbg_array.astype('uint8')).resize(new_shape)) + resized_array = np.array( + Image.fromarray(grayscale_or_rbg_array.astype('uint8')).resize( + tuple(new_shape) + ) + ) if resized_array.ndim == 2: return np.expand_dims(resized_array, axis=-1) return resized_array