
Commit

fix gif generation (#234)
younik committed Aug 23, 2024
1 parent 1306bab commit ed39de1
Showing 2 changed files with 57 additions and 41 deletions.
33 changes: 15 additions & 18 deletions docs/_scripts/generate_gif.py
@@ -22,33 +22,30 @@ def _space_at(values, index):
     return values[index]
 
 
-def generate_gif(dataset_id, path, num_frames=256, fps=16):
+def generate_gif(dataset_id, path, num_frames=512, fps=32):
     dataset = minari.load_dataset(dataset_id)
     env = dataset.recover_environment(render_mode="rgb_array")
     images = []
-    if "seed" not in dataset.storage.get_episode_metadata([0])[0]:
-        raise ValueError("Cannot reproduce episodes with unknown seed.")
 
-    episode_id = 0
-    while len(images) < num_frames:
-        episode = dataset[episode_id]
-        episode_metadata = dataset.storage.get_episode_metadata([episode_id])[0]
-        env.reset(
-            seed=episode_metadata.get("seed"), options=episode_metadata.get("options")
-        )
+    metadatas = dataset.storage.get_episode_metadata(dataset.episode_indices)
+    for episode, episode_metadata in zip(dataset.iterate_episodes(), metadatas):
+        seed = episode_metadata.get("seed")
+        if episode.id == 0 and seed is None:
+            raise ValueError("Cannot reproduce episodes with unknown seed.")
+
+        env.reset(seed=seed, options=episode_metadata.get("options"))
         images.append(env.render())
         for step_id in range(len(episode)):
            act = _space_at(episode.actions, step_id)
            env.step(act)
            images.append(env.render())
-
-        episode_id += 1
-
-    env.close()
-
-    gif_file = os.path.join(path, f"{dataset_id}.gif")
-    imageio.mimsave(gif_file, images, fps=fps)
-    return gif_file
+            if len(images) > num_frames:
+                env.close()
+                gif_file = os.path.join(path, f"{dataset_id}.gif")
+                imageio.mimsave(gif_file, images, fps=fps)
+                return gif_file
+
+    raise ValueError("There are not enough steps in the dataset.")
 
 
 def main(argv):
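
For reference, a minimal usage sketch of the regenerated script (not part of the commit): the dataset id "example-dataset-v0" and the output directory are hypothetical placeholders, and the snippet assumes docs/_scripts is on the import path.

    # Hypothetical usage of generate_gif; placeholder dataset id and output path.
    import os

    from generate_gif import generate_gif  # assumes docs/_scripts is importable

    out_dir = "gif_output"  # placeholder output directory
    os.makedirs(out_dir, exist_ok=True)
    # Replays episodes with their stored seeds/options and writes <dataset_id>.gif.
    gif_path = generate_gif("example-dataset-v0", out_dir, num_frames=512, fps=32)
    print(f"GIF written to {gif_path}")
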
65 changes: 42 additions & 23 deletions minari/dataset/_storages/hdf5_storage.py
@@ -74,12 +74,7 @@ def get_episode_metadata(self, episode_indices: Iterable[int]) -> Iterable[Dict]
                 ep_group = file[f"episode_{ep_idx}"]
                 assert isinstance(ep_group, h5py.Group)
                 metadata: dict = dict(ep_group.attrs)
-                if "option_names" in metadata:
-                    metadata["options"] = {}
-                    for name in metadata["option_names"]:
-                        metadata["options"][name] = metadata[f"options/{name}"]
-                        del metadata[f"options/{name}"]
-                    del metadata["option_names"]
+                metadata = unflatten_dict(metadata)
                 if metadata.get("seed") is not None:
                     metadata["seed"] = int(metadata["seed"])
 
@@ -114,19 +109,6 @@ def _decode_space(
         assert isinstance(hdf_ref, h5py.Dataset)
         return hdf_ref[()]
 
-    def _decode_dict(self, dict_group: h5py.Group) -> Dict:
-        result = {}
-        for key, value in dict_group.items():
-            if isinstance(value, h5py.Group):
-                result[key] = self._decode_dict(value)
-            elif isinstance(value, h5py.Dataset):
-                result[key] = value[()]
-            else:
-                raise ValueError(
-                    "Infos are in an unsupported format; see Minari documentation for supported formats."
-                )
-        return result
-
     def get_episodes(self, episode_indices: Iterable[int]) -> Iterable[dict]:
         with h5py.File(self._file_path, "r") as file:
             for ep_idx in episode_indices:
@@ -136,7 +118,7 @@ def get_episodes(self, episode_indices: Iterable[int]) -> Iterable[dict]:
                 if "infos" in ep_group:
                     info_group = ep_group["infos"]
                     assert isinstance(info_group, h5py.Group)
-                    infos = self._decode_dict(info_group)
+                    infos = _decode_info(info_group)
 
                 ep_dict = {
                     "id": ep_idx,
@@ -171,9 +153,8 @@ def update_episodes(self, episodes: Iterable[EpisodeBuffer]):
                     episode_group.attrs["seed"] = eps_buff.seed
                 if eps_buff.options is not None:
                     assert "options" not in episode_group.attrs.keys()
-                    for name, option in eps_buff.options.items():
-                        episode_group.attrs[f"options/{name}"] = option
-                    episode_group.attrs["option_names"] = list(eps_buff.options.keys())
+                    flatten_option = flatten_dict(eps_buff.options, "options")
+                    episode_group.attrs.update(flatten_option)
 
                 episode_steps = len(eps_buff.rewards)
                 episode_group.attrs["total_steps"] = episode_steps
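
With this change, nested reset options are stored as flat HDF5 attribute keys with "/"-separated paths instead of a parallel "option_names" list. A small illustration of the intended mapping, using invented option names and values and the flatten_dict helper added further down in this file:

    # Illustration only (values invented): how nested options map onto flat attrs.
    from minari.dataset._storages.hdf5_storage import flatten_dict

    options = {"goal": {"x": 1.0, "y": 2.0}, "difficulty": "easy"}
    flat = flatten_dict(options, "options")
    # flat == {"options/goal/x": 1.0, "options/goal/y": 2.0, "options/difficulty": "easy"}
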
@@ -243,3 +224,41 @@ def _add_episode_to_group(episode_buffer: Dict, episode_group: h5py.Group):
             episode_group.create_dataset(
                 key, data=data, dtype=dtype, chunks=True, maxshape=(None, *dshape)
             )
+
+
+def _decode_info(info_group: h5py.Group) -> Dict:
+    result = {}
+    for key, value in info_group.items():
+        if isinstance(value, h5py.Group):
+            result[key] = _decode_info(value)
+        elif isinstance(value, h5py.Dataset):
+            result[key] = value[()]
+        else:
+            raise ValueError(
+                "Infos are in an unsupported format; see Minari documentation for supported formats."
+            )
+    return result
+
+
+def flatten_dict(d: Dict, parent_key: str) -> Dict:
+    flatten_d = {}
+    for k, v in d.items():
+        new_key = f"{parent_key}/{k}"
+        if isinstance(v, dict):
+            flatten_d.update(flatten_dict(v, new_key))
+        else:
+            flatten_d[new_key] = v
+    return flatten_d
+
+
+def unflatten_dict(d: Dict) -> Dict:
+    result = {}
+    for k, v in d.items():
+        keys = k.split("/")
+        current = result
+        for key in keys[:-1]:
+            if key not in current:
+                current[key] = {}
+            current = current[key]
+        current[keys[-1]] = v
+    return result
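
These two helpers are intended to be inverses for "/"-joined string keys: get_episode_metadata unflattens the stored attributes back into nested dicts, recovering what update_episodes flattened. A quick round-trip sketch with invented values:

    # Round-trip sketch (invented values): unflatten stored attrs, then re-flatten.
    from minari.dataset._storages.hdf5_storage import flatten_dict, unflatten_dict

    attrs = {"seed": 3, "options/goal/x": 1.0, "options/goal/y": 2.0}
    nested = unflatten_dict(attrs)
    # nested == {"seed": 3, "options": {"goal": {"x": 1.0, "y": 2.0}}}
    assert flatten_dict(nested["options"], "options") == {
        "options/goal/x": 1.0,
        "options/goal/y": 2.0,
    }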
