Skip to content

Commit

Permalink
ENH: improve MinariDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
younik committed Jun 28, 2023
1 parent 3c2eb06 commit e3677f1
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 27 deletions.
11 changes: 0 additions & 11 deletions PUBLISHING.md

This file was deleted.

35 changes: 19 additions & 16 deletions minari/dataset/minari_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def __init__(
self._additional_data_id = 0
if episode_indices is None:
episode_indices = np.arange(self._data.total_episodes)
self._episode_indices = episode_indices
self.episode_indices = episode_indices

self.spec = MinariDatasetSpec(
env_spec=self._data.env_spec,
Expand All @@ -210,25 +210,20 @@ def __init__(
@property
def total_episodes(self):
    """Total episodes recorded in the Minari dataset.

    Returns:
        int: the number of entries in ``self.episode_indices``.
    """
    # Single assert/return pair: any statement placed after the first
    # ``return`` would be unreachable.
    assert self.episode_indices is not None
    return len(self.episode_indices)

@property
def total_steps(self):
    """Total number of steps over the episodes in ``self.episode_indices``.

    The value is computed on first access by reading ``"total_steps"`` from
    each selected episode and summing the results; it is then stored in
    ``self._total_steps`` and returned directly on later accesses.
    """
    if self._total_steps is None:
        # ``episode_indices`` must be passed exactly once; repeating the
        # keyword in the call is a SyntaxError.
        steps_per_episode = self._data.apply(
            lambda episode: episode["total_steps"],
            episode_indices=self.episode_indices,
        )
        self._total_steps = sum(steps_per_episode)
    return self._total_steps

@property
def episode_indices(self) -> np.ndarray:
"""Indices of the available episodes to sample within the Minari dataset.

Returns the object stored in ``self._episode_indices`` without copying it.
"""
# NOTE(review): no setter is visible for this property, while ``__init__``
# in this view also shows ``self.episode_indices = episode_indices``;
# presumably this getter is being replaced by a plain attribute -- confirm.
return self._episode_indices

def recover_environment(self):
"""Recover the Gymnasium environment used to create the dataset.
Expand All @@ -255,9 +250,9 @@ def filter_episodes(self, condition: Callable[[h5py.Group], bool]) -> MinariData
Args:
condition (Callable[[h5py.Group], bool]): callable that accepts an episode group and returns True if certain condition is met.
"""
mask = self._data.apply(condition, episode_indices=self._episode_indices)
assert self._episode_indices is not None
return MinariDataset(self._data, episode_indices=self._episode_indices[mask])
mask = self._data.apply(condition, episode_indices=self.episode_indices)
assert self.episode_indices is not None
return MinariDataset(self._data, episode_indices=self.episode_indices[mask])

def sample_episodes(self, n_episodes: int) -> Iterable[EpisodeData]:
"""Sample n number of episodes from the dataset.
Expand All @@ -266,7 +261,7 @@ def sample_episodes(self, n_episodes: int) -> Iterable[EpisodeData]:
n_episodes (Optional[int], optional): number of episodes to sample.
"""
indices = self._generator.choice(
self._episode_indices, size=n_episodes, replace=False
self.episode_indices, size=n_episodes, replace=False
)
episodes = self._data.get_episodes(indices)
return list(map(lambda data: EpisodeData(**data), episodes))
Expand All @@ -280,9 +275,9 @@ def iterate_episodes(
episode_indices (Optional[List[int]], optional): episode indices to iterate over.
"""
if episode_indices is None:
assert self._episode_indices is not None
assert self._episode_indices.ndim == 1
episode_indices = self._episode_indices.tolist()
assert self.episode_indices is not None
assert self.episode_indices.ndim == 1
episode_indices = self.episode_indices.tolist()

assert episode_indices is not None

Expand Down Expand Up @@ -386,3 +381,11 @@ def update_dataset_from_buffer(self, buffer: List[dict]):

def __iter__(self):
"""Iterate over the dataset by delegating to ``iterate_episodes()`` called with no arguments."""
return self.iterate_episodes()

def __getitem__(self, idx: int) -> EpisodeData:
    """Return the episode at position ``idx`` of ``self.episode_indices``.

    Args:
        idx (int): position within ``self.episode_indices`` (not the raw
            episode id stored in the underlying data).

    Returns:
        EpisodeData: built from the single record returned by
        ``self._data.get_episodes``.
    """
    episode_id = self.episode_indices[idx]
    fetched = self._data.get_episodes([episode_id])
    # Exactly one id was requested, so exactly one record is expected back.
    assert len(fetched) == 1
    return EpisodeData(**fetched[0])

def __len__(self) -> int:
"""Return ``self.total_episodes`` so that ``len(dataset)`` reports the episode count."""
return self.total_episodes

0 comments on commit e3677f1

Please sign in to comment.