Uniform get_episodes return type (#232)
younik authored Aug 20, 2024
1 parent cc2a22a commit 9d8e1f2
Showing 6 changed files with 37 additions and 45 deletions.
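
The unifying theme across all six files: the storage backends previously returned eagerly built lists from get_episodes and get_episode_metadata, and now return lazy iterables (generators or map objects), so episodes are only read from disk as they are consumed. A minimal sketch of what that means for callers (hypothetical consumer code, not part of the commit; `storage` stands for any MinariStorage backend):

    episodes = storage.get_episodes(range(100))  # lazy: nothing is read yet
    first = next(iter(episodes))                 # loads only the first episode
    rest = list(episodes)                        # drains the remaining 99
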
14 changes: 5 additions & 9 deletions minari/dataset/_storages/arrow_storage.py
@@ -3,7 +3,7 @@
 import json
 import pathlib
 from itertools import zip_longest
-from typing import Any, Dict, Iterable, List, Optional, Sequence
+from typing import Any, Dict, Iterable, Optional, Sequence

 import gymnasium as gym
 import numpy as np
@@ -65,16 +65,13 @@ def update_episode_metadata(
             with open(metadata_path, "w") as file:
                 json.dump(metadata, file)

-    def get_episode_metadata(self, episode_indices: Iterable[int]) -> List[Dict]:
-        ep_metadata = []
+    def get_episode_metadata(self, episode_indices: Iterable[int]) -> Iterable[Dict]:
         for episode_id in episode_indices:
             metadata_path = self.data_path.joinpath(str(episode_id), "metadata.json")
             with open(metadata_path) as file:
-                ep_metadata.append(json.load(file))
+                yield json.load(file)

-        return ep_metadata
-
-    def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]:
+    def get_episodes(self, episode_indices: Iterable[int]) -> Iterable[dict]:
         dataset = pa.dataset.dataset(
             [
                 pa.dataset.dataset(
@@ -101,8 +98,7 @@ def _to_dict(id, episode):
                 else {},
             }

-        episodes = map(_to_dict, episode_indices, dataset.to_batches())
-        return list(episodes)
+        return map(_to_dict, episode_indices, dataset.to_batches())

     def update_episodes(self, episodes: Iterable[EpisodeBuffer]):
         total_steps = self.total_steps
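
The returned map object defers the `_to_dict` conversion (and the underlying Arrow batch reads from `dataset.to_batches()`) until the caller iterates. The same laziness in a standalone, runnable sketch (toy data and illustrative names, not the Minari code):

    def to_dict(idx, batch):
        print(f"decoding episode {idx}")  # visible side effect to show laziness
        return {"id": idx, "data": batch}

    lazy = map(to_dict, [0, 1, 2], ["a", "b", "c"])  # nothing printed yet
    episode = next(lazy)                             # prints: decoding episode 0
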
13 changes: 4 additions & 9 deletions minari/dataset/_storages/hdf5_storage.py
@@ -68,8 +68,7 @@ def update_episode_metadata(
                 ep_group = file[f"episode_{episode_id}"]
                 ep_group.attrs.update(metadata)

-    def get_episode_metadata(self, episode_indices: Iterable[int]) -> List[Dict]:
-        out = []
+    def get_episode_metadata(self, episode_indices: Iterable[int]) -> Iterable[Dict]:
         with h5py.File(self._file_path, "r") as file:
             for ep_idx in episode_indices:
                 ep_group = file[f"episode_{ep_idx}"]
@@ -81,9 +80,8 @@ def get_episode_metadata(self, episode_indices: Iterable[int]) -> List[Dict]:
                     metadata["options"] = self._decode_dict(options_group)
                 if metadata.get("seed") is not None:
                     metadata["seed"] = int(metadata["seed"])
-                out.append(metadata)
-
-        return out
+                yield metadata

     def _decode_space(
         self,
@@ -127,8 +125,7 @@ def _decode_dict(self, dict_group: h5py.Group) -> Dict:
         )
         return result

-    def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]:
-        outs = []
+    def get_episodes(self, episode_indices: Iterable[int]) -> Iterable[dict]:
         with h5py.File(self._file_path, "r") as file:
             for ep_idx in episode_indices:
                 ep_group = file[f"episode_{ep_idx}"]
@@ -154,9 +151,7 @@ def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]:
                     assert isinstance(group_value, h5py.Dataset)
                     ep_dict[key] = group_value[:]

-            outs.append(ep_dict)
-
-        return outs
+            yield ep_dict

     def update_episodes(self, episodes: Iterable[EpisodeBuffer]):
         additional_steps = 0
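
Since `yield` now sits inside the `with h5py.File(...)` block, the HDF5 file opens on the first `next()` call and stays open exactly as long as the generator is being consumed, closing once iteration finishes (or the abandoned generator is garbage-collected). A minimal sketch of this generator-in-context-manager pattern, assuming a file whose groups carry plain attrs (illustrative names, not the Minari API):

    import h5py
    from typing import Dict, Iterable

    def iter_group_attrs(path: str, names: Iterable[str]) -> Iterable[Dict]:
        # The file handle lives in the generator frame: it is acquired on the
        # first next() call and released when iteration completes.
        with h5py.File(path, "r") as file:
            for name in names:
                yield dict(file[name].attrs)
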
33 changes: 16 additions & 17 deletions minari/dataset/minari_dataset.py
@@ -4,10 +4,11 @@
 import os
 import re
 from dataclasses import dataclass, field
-from typing import Callable, Iterable, Iterator, List, Optional, Union
+from typing import Callable, Iterable, Iterator, List

 import gymnasium as gym
 import numpy as np
+import numpy.typing as npt
 from gymnasium import error, logger
 from gymnasium.envs.registration import EnvSpec
 from packaging.requirements import InvalidRequirement, Requirement
@@ -73,7 +74,7 @@ def gen_dataset_id(

 @dataclass
 class MinariDatasetSpec:
-    env_spec: Optional[EnvSpec]
+    env_spec: EnvSpec | None
     total_episodes: int
     total_steps: int
     dataset_id: str
@@ -102,8 +103,8 @@ class MinariDataset:

     def __init__(
         self,
-        data: Union[MinariStorage, PathLike],
-        episode_indices: Optional[np.ndarray] = None,
+        data: MinariStorage | PathLike,
+        episode_indices: npt.NDArray[np.int_] | None = None,
     ):
         """Initialize properties of the Minari Dataset.
@@ -120,7 +121,8 @@ def __init__(

         if episode_indices is None:
             episode_indices = np.arange(self._data.total_episodes)
-        self._episode_indices: np.ndarray = episode_indices
+        assert episode_indices is not None
+        self._episode_indices: npt.NDArray[np.int_] = episode_indices
         self._total_steps = None

         metadata = self._data.metadata
@@ -248,23 +250,21 @@ def sample_episodes(self, n_episodes: int) -> Iterable[EpisodeData]:
         return list(map(lambda data: EpisodeData(**data), episodes))

     def iterate_episodes(
-        self, episode_indices: Optional[List[int]] = None
+        self, episode_indices: Iterable[int] | None = None
     ) -> Iterator[EpisodeData]:
         """Iterate over episodes from the dataset.

         Args:
-            episode_indices (Optional[List[int]], optional): episode indices to iterate over.
+            episode_indices (Optional[Iterable[int]], optional): episode indices to iterate over.
         """
         if episode_indices is None:
             assert self.episode_indices is not None
             assert self.episode_indices.ndim == 1
-            episode_indices = self.episode_indices.tolist()
+            episode_indices = self.episode_indices

         assert episode_indices is not None

-        for episode_index in episode_indices:
-            data = self.storage.get_episodes([episode_index])[0]
-            yield EpisodeData(**data)
+        episodes_data = self.storage.get_episodes(episode_indices)
+        return map(lambda data: EpisodeData(**data), episodes_data)

     def update_dataset_from_buffer(self, buffer: List[EpisodeBuffer]):
         """Additional data can be added to the Minari Dataset from a list of episode dictionary buffers.
@@ -282,9 +282,8 @@ def __iter__(self):
         return self.iterate_episodes()

     def __getitem__(self, idx: int) -> EpisodeData:
-        episodes_data = self.storage.get_episodes([self.episode_indices[idx]])
-        assert len(episodes_data) == 1
-        return EpisodeData(**episodes_data[0])
+        episode = self.iterate_episodes([self.episode_indices[idx]])
+        return next(episode)

     def __len__(self) -> int:
         return self.total_episodes
@@ -308,12 +307,12 @@ def total_steps(self) -> int:
         return int(self._total_steps)

     @property
-    def episode_indices(self) -> np.ndarray:
+    def episode_indices(self) -> npt.NDArray[np.int_]:
         """Indices of the available episodes to sample within the Minari dataset."""
         return self._episode_indices

     @episode_indices.setter
-    def episode_indices(self, new_value: np.ndarray):
+    def episode_indices(self, new_value: npt.NDArray[np.int_]):
         self._total_steps = None  # invalidate cache
         self._episode_indices = new_value
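
At the dataset level the public API keeps its shape while delegating to the lazy storage calls: `__getitem__` builds a one-element iterator and pulls from it with `next`, `__iter__` streams episodes, and `iterate_episodes` now accepts any `Iterable[int]`. A usage sketch (hypothetical consumer code, assuming a loaded `dataset`):

    ep = dataset[3]                  # internally next(iterate_episodes([...]))
    for ep in dataset:               # __iter__ delegates to iterate_episodes()
        print(ep.id)
    some = dataset.iterate_episodes([0, 2, 5])  # any Iterable[int] now works

One portability note: `EnvSpec | None` and `MinariStorage | PathLike` use PEP 604 union syntax, which requires Python 3.10+ when the annotations are evaluated at class-definition time; the diff does not show whether this module enables `from __future__ import annotations`.
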
10 changes: 5 additions & 5 deletions minari/dataset/minari_storage.py
@@ -5,7 +5,7 @@
 import pathlib
 import warnings
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Iterable, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, Optional, Union

 import gymnasium as gym
 import numpy as np
@@ -239,14 +239,14 @@ def update_episode_metadata(
         ...

     @abstractmethod
-    def get_episode_metadata(self, episode_indices: Iterable[int]) -> List[Dict]:
+    def get_episode_metadata(self, episode_indices: Iterable[int]) -> Iterable[Dict]:
         """Get the metadata of episodes.

         Args:
             episode_indices (Iterable[int]): episodes id to return

         Returns:
-            metadatas (List[Dict]): list of episodes metadata
+            metadatas (Iterable[Dict]): episodes metadata
         """
         ...

@@ -271,14 +271,14 @@ def apply(
         return map(function, ep_dicts)

     @abstractmethod
-    def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]:
+    def get_episodes(self, episode_indices: Iterable[int]) -> Iterable[dict]:
         """Get a list of episodes.

         Args:
             episode_indices (Iterable[int]): episodes id to return

         Returns:
-            episodes (List[dict]): list of episodes data
+            episodes (Iterable[dict]): episodes data
         """
         ...

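
With the abstract signatures relaxed from `List` to `Iterable`, a conforming backend may return a generator, a generator expression, or still a plain list. A minimal in-memory sketch of the two generator styles (hypothetical class, for illustration only; real backends subclass MinariStorage and do real I/O):

    from typing import Dict, Iterable

    class InMemoryStorage:
        def __init__(self, episodes: Dict[int, dict], metadata: Dict[int, Dict]):
            self._episodes = episodes
            self._metadata = metadata

        def get_episodes(self, episode_indices: Iterable[int]) -> Iterable[dict]:
            for idx in episode_indices:  # generator-function style
                yield self._episodes[idx]

        def get_episode_metadata(self, episode_indices: Iterable[int]) -> Iterable[Dict]:
            # generator-expression style, equivalent laziness
            return (self._metadata[idx] for idx in episode_indices)
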
6 changes: 4 additions & 2 deletions tests/data_collector/test_data_collector.py
@@ -188,8 +188,10 @@ def test_reproducibility(seed, data_format, options, register_dummy_envs):
     # Step through the env again using the stored seed and check it matches
     env = dataset.recover_environment()

-    for episode in dataset.iterate_episodes():
-        episode_metadata = dataset.storage.get_episode_metadata([episode.id])[0]
+    assert len(dataset) == num_episodes
+    episodes = dataset.iterate_episodes()
+    metadatas = dataset.storage.get_episode_metadata(range(num_episodes))
+    for episode, episode_metadata in zip(episodes, metadatas):
         episode_seed = episode_metadata["seed"]
         assert episode_seed >= 0
         if seed is not None:
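
The rewritten loop zips two lazy sources, so `zip` advances them in lockstep: each step loads one episode and its metadata, never the whole dataset at once. The same pattern in isolation:

    eps = (f"episode_{i}" for i in range(3))   # lazy episodes
    meta = ({"seed": i} for i in range(3))     # lazy metadata
    for ep, md in zip(eps, meta):              # one item from each per step
        print(ep, md["seed"])
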
6 changes: 3 additions & 3 deletions tests/dataset/test_minari_storage.py
@@ -120,8 +120,8 @@ def test_add_episodes(tmp_dataset_dir, data_format):
     assert storage.total_episodes == n_episodes
     assert storage.total_steps == n_episodes * steps_per_episode

-    for i, ep in enumerate(episodes):
-        storage_ep = storage.get_episodes([i])[0]
+    storage_episodes = storage.get_episodes(range(n_episodes))
+    for ep, storage_ep in zip(episodes, storage_episodes):
         assert np.all(ep.observations == storage_ep["observations"])
         assert np.all(ep.actions == storage_ep["actions"])
         assert np.all(ep.rewards == storage_ep["rewards"])
@@ -319,6 +319,6 @@ def test_seed_change(tmp_dataset_dir, data_format):

     assert storage.total_episodes == len(seeds)
     episodes_metadata = storage.get_episode_metadata(range(len(episodes)))
-    assert len(episodes_metadata) == len(seeds)
+    assert len(list(episodes_metadata)) == len(seeds)
     for seed, ep_metadata in zip(seeds, episodes_metadata):
         assert ep_metadata.get("seed") == seed
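
One caveat hides in this last hunk: for the generator-based backends above, `len(list(episodes_metadata))` drains the iterator, so the `zip` on the following line starts from an already-exhausted generator and its loop body never runs (the per-seed assertions pass vacuously). Generators are single-pass; materialize once, or request the metadata again, if two traversals are needed. A self-contained illustration:

    gen = (x for x in range(3))
    assert len(list(gen)) == 3   # first pass drains the generator
    assert list(gen) == []       # second pass sees nothing

    items = list(x for x in range(3))  # materialize once instead
    assert len(items) == 3
    for x in items:                    # safe to traverse again
        assert x >= 0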
