Commit

move options to attrs for hdf5
younik committed Aug 21, 2024
1 parent 9d8e1f2 commit a3f419b
Showing 3 changed files with 27 additions and 8 deletions.
minari/dataset/_storages/arrow_storage.py (8 additions, 1 deletion)
@@ -63,7 +63,7 @@ def update_episode_metadata(
                 metadata = json.load(file)
             metadata.update(new_metadata)
             with open(metadata_path, "w") as file:
-                json.dump(metadata, file)
+                json.dump(metadata, file, cls=NumpyEncoder)
 
     def get_episode_metadata(self, episode_indices: Iterable[int]) -> Iterable[Dict]:
         for episode_id in episode_indices:
@@ -253,3 +253,10 @@ def _decode_info(values: pa.Array):
             value = value.reshape(len(value), *data_shape)
         nested_dict[field.name] = value
     return nested_dict
+
+
+class NumpyEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, np.ndarray):
+            return obj.tolist()
+        return super().default(obj)
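
The custom `NumpyEncoder` is needed because episode metadata can now carry the reset `options`, which may contain NumPy arrays that the standard `json` module refuses to serialize. A minimal standalone sketch of that behaviour (illustrative only, not part of the commit): arrays are written out as plain lists, so they come back as lists rather than arrays after a round trip.

import json

import numpy as np


class NumpyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)


metadata = {"seed": 42, "options": {"int": 3, "array": np.array([1, 2, 3])}}
# json.dumps(metadata) would raise TypeError: Object of type ndarray is not JSON serializable
encoded = json.dumps(metadata, cls=NumpyEncoder)
print(json.loads(encoded))  # {'seed': 42, 'options': {'int': 3, 'array': [1, 2, 3]}}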
minari/dataset/_storages/hdf5_storage.py (12 additions, 5 deletions)
@@ -74,10 +74,12 @@ def get_episode_metadata(self, episode_indices: Iterable[int]) -> Iterable[Dict]
                 ep_group = file[f"episode_{ep_idx}"]
                 assert isinstance(ep_group, h5py.Group)
                 metadata: dict = dict(ep_group.attrs)
-                if "options" in ep_group:
-                    options_group = ep_group["options"]
-                    assert isinstance(options_group, h5py.Group)
-                    metadata["options"] = self._decode_dict(options_group)
+                if "option_names" in metadata:
+                    metadata["options"] = {}
+                    for name in metadata["option_names"]:
+                        metadata["options"][name] = metadata[f"options/{name}"]
+                        del metadata[f"options/{name}"]
+                    del metadata["option_names"]
                 if metadata.get("seed") is not None:
                     metadata["seed"] = int(metadata["seed"])
 
@@ -167,6 +169,12 @@ def update_episodes(self, episodes: Iterable[EpisodeBuffer]):
                 if eps_buff.seed is not None:
                     assert "seed" not in episode_group.attrs.keys()
                     episode_group.attrs["seed"] = eps_buff.seed
+                if eps_buff.options is not None:
+                    assert "options" not in episode_group.attrs.keys()
+                    for name, option in eps_buff.options.items():
+                        episode_group.attrs[f"options/{name}"] = option
+                    episode_group.attrs["option_names"] = list(eps_buff.options.keys())
+
                 episode_steps = len(eps_buff.rewards)
                 episode_group.attrs["total_steps"] = episode_steps
                 additional_steps += episode_steps
@@ -178,7 +186,6 @@ def update_episodes(self, episodes: Iterable[EpisodeBuffer]):
                     "terminations": eps_buff.terminations,
                     "truncations": eps_buff.truncations,
                     "infos": eps_buff.infos,
-                    "options": eps_buff.options,
                 }
                 _add_episode_to_group(dict_buffer, episode_group)
 
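The storage layout change in isolation: instead of writing `options` as a nested dataset group, each option is flattened into the episode group's attributes under an `options/{name}` key, and an `option_names` attribute records which keys to reassemble on read. A standalone h5py sketch of that round trip (illustrative only; the file name "episodes.hdf5" and group "episode_0" are made up, and this is not Minari's public API):

import h5py
import numpy as np

options = {"int": 3, "array": np.array([1, 2, 3])}

with h5py.File("episodes.hdf5", "w") as file:
    ep_group = file.create_group("episode_0")
    # write: one attribute per option, plus the list of option names
    for name, option in options.items():
        ep_group.attrs[f"options/{name}"] = option
    ep_group.attrs["option_names"] = list(options.keys())

with h5py.File("episodes.hdf5", "r") as file:
    metadata = dict(file["episode_0"].attrs)
    # read: rebuild the nested dict and drop the flattened keys
    if "option_names" in metadata:
        metadata["options"] = {}
        for name in metadata["option_names"]:
            metadata["options"][name] = metadata[f"options/{name}"]
            del metadata[f"options/{name}"]
        del metadata["option_names"]
    print(metadata["options"])  # values come back as NumPy scalars/arrays

Keeping options in attrs avoids creating an extra sub-group per episode and keeps all reset metadata (seed, options, total_steps) in one place.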
tests/dataset/test_minari_storage.py (7 additions, 2 deletions)
@@ -252,8 +252,9 @@ def test_minari_get_dataset_size_from_buffer(
 
     num_episodes = 10
     seed = 42
-    observation, _ = env.reset(seed=seed)
-    episode_buffer = EpisodeBuffer(observations=observation, seed=seed)
+    options = {"int": 3, "array": np.array([1, 2, 3])}
+    observation, _ = env.reset(seed=seed, options=options)
+    episode_buffer = EpisodeBuffer(observations=observation, seed=seed, options=options)
 
     for episode in range(num_episodes):
         terminated = False
@@ -291,6 +292,10 @@ def test_minari_get_dataset_size_from_buffer(
     )
 
     assert dataset.storage.metadata["dataset_size"] == dataset.storage.get_size()
+    ep_metadata_0 = next(iter(dataset.storage.get_episode_metadata([0])))
+    assert ep_metadata_0["seed"] == seed
+    assert ep_metadata_0["options"]["int"] == options["int"]
+    assert np.all(ep_metadata_0["options"]["array"] == options["array"])
 
     check_data_integrity(dataset, list(dataset.episode_indices))
 
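Persisting the seed and options also makes it possible to replay an episode's initial state later. A hypothetical continuation of the test above (it reuses the test's `dataset` and `env` fixtures, so it is a sketch rather than part of the commit):

ep_metadata_0 = next(iter(dataset.storage.get_episode_metadata([0])))
# note: options read back from storage may be NumPy scalars/arrays
# rather than the original Python types
observation, info = env.reset(
    seed=ep_metadata_0["seed"],
    options=ep_metadata_0["options"],
)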
