def generate_gif(dataset_id, path, num_frames=512, fps=32):
    """Replay a dataset's recorded actions and save the rendered frames as a GIF.

    Episodes are replayed in order: the environment is reset with the seed and
    options stored in each episode's metadata, then stepped with the recorded
    actions, rendering one frame after the reset and one after every step.

    Args:
        dataset_id: Identifier passed to ``minari.load_dataset``; also used as
            the GIF file name (``{dataset_id}.gif``).
        path: Directory in which the GIF is written.
        num_frames: Number of frames to collect before the GIF is saved.
        fps: Frame rate forwarded to ``imageio.mimsave``.

    Returns:
        The path of the written GIF file.

    Raises:
        ValueError: If episode 0 has no stored seed, or if the dataset is
            exhausted before ``num_frames`` frames were rendered.
    """
    dataset = minari.load_dataset(dataset_id)
    env = dataset.recover_environment(render_mode="rgb_array")
    images = []

    # try/finally so the environment is released on every exit path,
    # including both ValueError cases below.
    try:
        metadatas = dataset.storage.get_episode_metadata(dataset.episode_indices)
        for episode, episode_metadata in zip(dataset.iterate_episodes(), metadatas):
            seed = episode_metadata.get("seed")
            if episode.id == 0 and seed is None:
                raise ValueError("Cannot reproduce episodes with unknown seed.")

            env.reset(seed=seed, options=episode_metadata.get("options"))
            images.append(env.render())
            for step_id in range(len(episode)):
                # Check before stepping so the frame rendered at reset counts
                # towards the cap and we never collect more than num_frames.
                if len(images) >= num_frames:
                    break
                act = _space_at(episode.actions, step_id)
                env.step(act)
                images.append(env.render())

            if len(images) >= num_frames:
                break
        else:
            # The loop ran out of episodes without reaching the frame cap.
            raise ValueError("There are not enough steps in the dataset.")
    finally:
        env.close()

    gif_file = os.path.join(path, f"{dataset_id}.gif")
    imageio.mimsave(gif_file, images, fps=fps)
    return gif_file
b/minari/dataset/_storages/hdf5_storage.py @@ -74,12 +74,7 @@ def get_episode_metadata(self, episode_indices: Iterable[int]) -> Iterable[Dict] ep_group = file[f"episode_{ep_idx}"] assert isinstance(ep_group, h5py.Group) metadata: dict = dict(ep_group.attrs) - if "option_names" in metadata: - metadata["options"] = {} - for name in metadata["option_names"]: - metadata["options"][name] = metadata[f"options/{name}"] - del metadata[f"options/{name}"] - del metadata["option_names"] + metadata = unflatten_dict(metadata) if metadata.get("seed") is not None: metadata["seed"] = int(metadata["seed"]) @@ -114,19 +109,6 @@ def _decode_space( assert isinstance(hdf_ref, h5py.Dataset) return hdf_ref[()] - def _decode_dict(self, dict_group: h5py.Group) -> Dict: - result = {} - for key, value in dict_group.items(): - if isinstance(value, h5py.Group): - result[key] = self._decode_dict(value) - elif isinstance(value, h5py.Dataset): - result[key] = value[()] - else: - raise ValueError( - "Infos are in an unsupported format; see Minari documentation for supported formats." 
def _decode_info(info_group: "h5py.Group") -> Dict:
    """Recursively read an HDF5 group into a nested dictionary.

    Sub-groups become nested dictionaries and datasets are read in full
    (``value[()]``).

    Raises:
        ValueError: If the group contains anything other than groups and
            datasets.
    """
    result = {}
    for key, value in info_group.items():
        if isinstance(value, h5py.Group):
            result[key] = _decode_info(value)
        elif isinstance(value, h5py.Dataset):
            result[key] = value[()]
        else:
            raise ValueError(
                "Infos are in an unsupported format; see Minari documentation for supported formats."
            )
    return result


def flatten_dict(d: Dict, parent_key: str) -> Dict:
    """Flatten a nested dictionary into one level of "/"-joined keys.

    Every non-dict value is stored under ``"{parent_key}/{k1}/{k2}/..."``.

    Note that the flattening is lossy in two cases: an empty dict (at any
    depth) contributes no entries, and a key that itself contains "/" is
    indistinguishable from a nested path once flattened.

    Args:
        d: The (possibly nested) dictionary to flatten.
        parent_key: Prefix prepended to every produced key.

    Returns:
        A new single-level dictionary.
    """
    flat = {}
    for k, v in d.items():
        new_key = f"{parent_key}/{k}"
        if isinstance(v, dict):
            flat.update(flatten_dict(v, new_key))
        else:
            flat[new_key] = v
    return flat


def unflatten_dict(d: Dict) -> Dict:
    """Rebuild a nested dictionary from "/"-joined keys.

    Each key is split on "/"; all parts but the last name nested
    dictionaries and the last part names the value. Keys without "/" are
    copied to the top level unchanged.

    Args:
        d: A dictionary whose keys are strings, e.g. ``{"options/a": 1}``.

    Returns:
        A new nested dictionary, e.g. ``{"options": {"a": 1}}``.

    Raises:
        ValueError: If one key is both a value and the prefix of another key
            (e.g. ``"a"`` and ``"a/b"``), since both cannot be represented.
    """
    result: Dict = {}
    for flat_key, value in d.items():
        *parents, leaf = flat_key.split("/")
        current = result
        for part in parents:
            current = current.setdefault(part, {})
            if not isinstance(current, dict):
                # A plain value already lives where a nested dict is needed.
                raise ValueError(
                    f"Cannot unflatten key '{flat_key}': '{part}' is already a value."
                )
        if leaf in current and isinstance(current[leaf], dict):
            # Assigning here would silently discard the nested entries.
            raise ValueError(
                f"Cannot unflatten key '{flat_key}': it is already a nested dictionary."
            )
        current[leaf] = value
    return result