added temporary workaround for metadata uploading
balisujohn committed Jun 23, 2023
1 parent 8e02372 commit 308a8c0
Showing 2 changed files with 40 additions and 17 deletions.
minari/integrations/hugging_face.py (32 additions, 10 deletions)
@@ -1,3 +1,4 @@
+import io
 import json
 import warnings
 from collections import OrderedDict
@@ -6,11 +7,10 @@
 import gymnasium as gym
 import numpy as np
 from datasets import Dataset, DatasetInfo, load_dataset
-from huggingface_hub import whoami
+from huggingface_hub import hf_hub_download, upload_file, whoami
 
 import minari
-from minari import MinariDataset
+from minari.dataset.minari_dataset import MinariDataset
 from minari.serialization import deserialize_space, serialize_space


@@ -100,16 +100,20 @@ def convert_minari_dataset_to_hugging_face_dataset(dataset: MinariDataset):
     return hugging_face_dataset
 
 
-def _cast_to_numpy_recursive(space: gym.spaces.space, entry: Union[tuple, dict, list]):
+def _cast_to_numpy_recursive(
+    space: gym.spaces.Space, entry: Union[tuple, dict, list]
+) -> Union[OrderedDict, np.ndarray, tuple]:
     """Recurses on an observation or action space, and mirrors the recursion on an observation or action, casting all components to numpy arrays."""
-    if isinstance(space, gym.spaces.Dict):
+    if isinstance(space, gym.spaces.Dict) and isinstance(entry, dict):
         result = OrderedDict()
         for key in space.spaces.keys():
             result[key] = _cast_to_numpy_recursive(space.spaces[key], entry[key])
         return result
-    elif isinstance(space, gym.spaces.Tuple):
-        result = []
-        for i in range(len(entry.keys())):
+    elif isinstance(space, gym.spaces.Tuple) and isinstance(
+        entry, dict
+    ):  # we substitute tuples with dicts in the hugging face dataset
+        result = []  # with keys corresponding to the elements index in the tuple.
+        for i in range(len(space.spaces)):
             result.append(
                 _cast_to_numpy_recursive(space.spaces[i], entry[f"_index_{str(i)}"])
             )
@@ -119,12 +123,13 @@ def _cast_to_numpy_recursive(space: gym.spaces.space, entry: Union[tuple, dict,
     elif isinstance(space, gym.spaces.Box):
         return np.asarray(entry, dtype=space.dtype)
     else:
-        raise TypeError(f"{type(state)} is not supported.")
+        raise TypeError(
+            f"{type(space)} is not supported. or there is type mismatch with the entry type {type(entry)}"
+        )
 
 
 def convert_hugging_face_dataset_to_minari_dataset(dataset: Dataset):
-
 
     description_data = json.loads(dataset.info.description)
 
     action_space = deserialize_space(description_data["action_space"])
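
A note on the `_index_{i}` convention in the Tuple branch above: per the inline comment, tuples are substituted with dicts in the Hugging Face dataset, with keys encoding each element's position, and `_cast_to_numpy_recursive` folds those dicts back into tuple form. The following standalone sketch (not part of the commit) illustrates the expected round trip; it assumes the Tuple branch ultimately returns `tuple(result)`, as the annotated return type suggests, and that the private helper is importable from this module.

import numpy as np
import gymnasium as gym

from minari.integrations.hugging_face import _cast_to_numpy_recursive

# A Tuple space whose entry was flattened into a dict for the HF dataset.
space = gym.spaces.Tuple(
    (gym.spaces.Box(-1.0, 1.0, (2,)), gym.spaces.Box(0.0, 1.0, (1,)))
)
entry = {"_index_0": [0.5, -0.5], "_index_1": [0.25]}  # keys encode tuple position

result = _cast_to_numpy_recursive(space, entry)
assert isinstance(result[0], np.ndarray)  # each element cast to its Box dtype
assert result[0].dtype == np.float32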
@@ -181,10 +186,27 @@ def push_dataset_to_hugging_face(dataset: Dataset, path: str, private: bool = Tr
             "Please log in using the huggingface-hub cli in order to push to a remote dataset."
         )
         return
+
+    metadata_file = io.BytesIO(dataset.info.description.encode("utf-8"))
+
+    upload_file(
+        path_or_fileobj=metadata_file,
+        path_in_repo="metadata.json",
+        repo_id=path,
+        repo_type="dataset",
+    )
+
     dataset.push_to_hub(path, private=private)
 
 
 def pull_dataset_from_hugging_face(path: str) -> Dataset:
-    """Pulls a hugging face dataset froms the HuggingFace respository at the specfied path."""
+    """Pulls a hugging face dataset from the HuggingFace repository at the specified path."""
     hugging_face_dataset = load_dataset(path)
+
+    with open(
+        hf_hub_download(filename="metadata.json", repo_id=path, repo_type="dataset"),
+    ) as metadata_file:
+        metadata_str = metadata_file.read()
+    hugging_face_dataset["train"].info.description = metadata_str
+
     return hugging_face_dataset["train"]
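
Taken together, the two changes above are the workaround named in the commit message: instead of relying on `push_to_hub` to carry the Minari metadata stored in `DatasetInfo.description`, push writes it to a standalone `metadata.json` via `upload_file`, and pull re-attaches it via `hf_hub_download` before any conversion. A minimal usage sketch mirroring the test below; the repo id is a placeholder for a dataset repo you control, and it assumes you are logged in via the huggingface-hub CLI and that the named Minari dataset exists locally.

import minari
from minari.integrations.hugging_face import (
    convert_hugging_face_dataset_to_minari_dataset,
    convert_minari_dataset_to_hugging_face_dataset,
    pull_dataset_from_hugging_face,
    push_dataset_to_hugging_face,
)

# Assumes "cartpole-test-v0" already exists in local Minari storage.
dataset = minari.load_dataset("cartpole-test-v0")
hf_dataset = convert_minari_dataset_to_hugging_face_dataset(dataset)

# metadata.json is uploaded first, then the dataset itself.
push_dataset_to_hugging_face(hf_dataset, "your-username/your-dataset-repo")

# On pull, metadata.json is restored into info.description before converting back.
recovered = pull_dataset_from_hugging_face("your-username/your-dataset-repo")
reconstructed = convert_hugging_face_dataset_to_minari_dataset(recovered)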
tests/integrations/test_hugging_face.py (8 additions, 7 deletions)
@@ -72,13 +72,18 @@ def test_convert_minari_dataset_to_hugging_face_dataset_and_back(dataset_id, env
     check_load_and_delete_dataset(dataset_id)
 
 
-#@pytest.mark.skip(
-#    reason="relies on a private repo, just using this for testing while developing"
-#)
+@pytest.mark.skip(
+    reason="relies on a private repo, if you want to use this test locally, you'll need to change it to point at a repo you control"
+)
 @pytest.mark.parametrize(
     "dataset_id,env_id",
     [
         ("cartpole-test-v0", "CartPole-v1"),
+        ("dummy-dict-test-v0", "DummyDictEnv-v0"),
+        ("dummy-box-test-v0", "DummyBoxEnv-v0"),
+        ("dummy-tuple-test-v0", "DummyTupleEnv-v0"),
+        ("dummy-combo-test-v0", "DummyComboEnv-v0"),
+        ("dummy-tuple-discrete-box-test-v0", "DummyTupleDisceteBoxEnv-v0"),
     ],
 )
 def test_hugging_face_push_and_pull_dataset(dataset_id, env_id):
@@ -98,16 +103,12 @@ def test_hugging_face_push_and_pull_dataset(dataset_id, env_id):
     )
 
     hugging_face_dataset = convert_minari_dataset_to_hugging_face_dataset(dataset)
-    print("DATASET INFO BEFORE UPLOADING")
-    print(hugging_face_dataset.info)
 
     push_dataset_to_hugging_face(hugging_face_dataset, "balisujohn/minari_test")
     minari.delete_dataset(dataset_id)
     recovered_hugging_face_dataset = pull_dataset_from_hugging_face(
         "balisujohn/minari_test"
     )
-    print("DATASET INFO AFTER UPLOADING AND DOWNLOADING")
-    print(recovered_hugging_face_dataset.info)
     reconstructed_minari_dataset = convert_hugging_face_dataset_to_minari_dataset(
         recovered_hugging_face_dataset
     )
