From e3cfe2d2854c04f3ab8b291c4136437f4cc8cde5 Mon Sep 17 00:00:00 2001 From: Omar Younis <42100908+younik@users.noreply.github.com> Date: Thu, 29 Aug 2024 13:02:38 +0200 Subject: [PATCH] fix total_steps computation (#237) --- minari/dataset/minari_dataset.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/minari/dataset/minari_dataset.py b/minari/dataset/minari_dataset.py index f15ea332..c310d2ae 100644 --- a/minari/dataset/minari_dataset.py +++ b/minari/dataset/minari_dataset.py @@ -119,11 +119,12 @@ def __init__( else: raise ValueError(f"Unrecognized type {type(data)} for data") + self._total_steps = None if episode_indices is None: episode_indices = np.arange(self._data.total_episodes) + self._total_steps = self._data.total_steps assert episode_indices is not None self._episode_indices: npt.NDArray[np.int_] = episode_indices - self._total_steps = None metadata = self._data.metadata @@ -307,13 +308,10 @@ def total_episodes(self) -> int: def total_steps(self) -> int: """Total episodes steps in the Minari dataset.""" if self._total_steps is None: - if self.episode_indices is None: - self._total_steps = self.storage.total_steps - else: - self._total_steps = 0 - metadatas = self.storage.get_episode_metadata(self.episode_indices) - for m in metadatas: - self._total_steps += m["total_steps"] + self._total_steps = 0 + metadatas = self.storage.get_episode_metadata(self.episode_indices) + for m in metadatas: + self._total_steps += m["total_steps"] return int(self._total_steps) @property