diff --git a/minari/integrations/hugging_face.py b/minari/integrations/hugging_face.py
index 28c549a7..cd0e0f46 100644
--- a/minari/integrations/hugging_face.py
+++ b/minari/integrations/hugging_face.py
@@ -1,3 +1,4 @@
+import io
 import json
 import warnings
 from collections import OrderedDict
@@ -6,11 +7,10 @@
 import gymnasium as gym
 import numpy as np
 from datasets import Dataset, DatasetInfo, load_dataset
-from huggingface_hub import whoami
+from huggingface_hub import hf_hub_download, upload_file, whoami
 
 import minari
 from minari import MinariDataset
-from minari.dataset.minari_dataset import MinariDataset
 from minari.serialization import deserialize_space, serialize_space
 
 
@@ -100,16 +100,20 @@ def convert_minari_dataset_to_hugging_face_dataset(dataset: MinariDataset):
     return hugging_face_dataset
 
 
-def _cast_to_numpy_recursive(space: gym.spaces.space, entry: Union[tuple, dict, list]):
+def _cast_to_numpy_recursive(
+    space: gym.spaces.Space, entry: Union[tuple, dict, list]
+) -> Union[OrderedDict, np.ndarray, tuple]:
     """Recurses on an observation or action space, and mirrors the recursion on an observation or action, casting all components to numpy arrays."""
-    if isinstance(space, gym.spaces.Dict):
+    if isinstance(space, gym.spaces.Dict) and isinstance(entry, dict):
         result = OrderedDict()
         for key in space.spaces.keys():
             result[key] = _cast_to_numpy_recursive(space.spaces[key], entry[key])
         return result
-    elif isinstance(space, gym.spaces.Tuple):
-        result = []
-        for i in range(len(entry.keys())):
+    elif isinstance(space, gym.spaces.Tuple) and isinstance(
+        entry, dict
+    ):  # we substitute tuples with dicts in the hugging face dataset
+        result = []  # with keys corresponding to the element's index in the tuple.
+        for i in range(len(space.spaces)):
             result.append(
                 _cast_to_numpy_recursive(space.spaces[i], entry[f"_index_{str(i)}"])
             )
@@ -119,12 +123,13 @@ def _cast_to_numpy_recursive(space: gym.spaces.space, entry: Union[tuple, dict,
     elif isinstance(space, gym.spaces.Box):
         return np.asarray(entry, dtype=space.dtype)
     else:
-        raise TypeError(f"{type(state)} is not supported.")
+        raise TypeError(
+            f"{type(space)} is not supported, or there is a type mismatch with the entry type {type(entry)}."
+        )
 
 
 def convert_hugging_face_dataset_to_minari_dataset(dataset: Dataset):
-
     description_data = json.loads(dataset.info.description)
 
     action_space = deserialize_space(description_data["action_space"])
@@ -181,10 +186,27 @@ def push_dataset_to_hugging_face(dataset: Dataset, path: str, private: bool = Tr
            "Please log in using the huggingface-hub cli in order to push to a remote dataset."
        )
        return
+
+    metadata_file = io.BytesIO(dataset.info.description.encode("utf-8"))
+
+    upload_file(
+        path_or_fileobj=metadata_file,
+        path_in_repo="metadata.json",
+        repo_id=path,
+        repo_type="dataset",
+    )
+
     dataset.push_to_hub(path, private=private)
 
 
 def pull_dataset_from_hugging_face(path: str) -> Dataset:
-    """Pulls a hugging face dataset froms the HuggingFace respository at the specfied path."""
+    """Pulls a hugging face dataset from the HuggingFace repository at the specified path."""
     hugging_face_dataset = load_dataset(path)
+
+    with open(
+        hf_hub_download(filename="metadata.json", repo_id=path, repo_type="dataset"),
+    ) as metadata_file:
+        metadata_str = metadata_file.read()
+        hugging_face_dataset["train"].info.description = metadata_str
+
     return hugging_face_dataset["train"]
diff --git a/tests/integrations/test_hugging_face.py b/tests/integrations/test_hugging_face.py
index 450514d4..f5d7a4be 100644
--- a/tests/integrations/test_hugging_face.py
+++ b/tests/integrations/test_hugging_face.py
@@ -72,13 +72,18 @@ def test_convert_minari_dataset_to_hugging_face_dataset_and_back(dataset_id, env
     check_load_and_delete_dataset(dataset_id)
 
 
-#@pytest.mark.skip(
-#    reason="relies on a private repo, just using this for testing while developing"
-#)
+@pytest.mark.skip(
+    reason="relies on a private repo; to run this test locally, change it to point at a repo you control"
+)
 @pytest.mark.parametrize(
     "dataset_id,env_id",
     [
+        ("cartpole-test-v0", "CartPole-v1"),
+        ("dummy-dict-test-v0", "DummyDictEnv-v0"),
+        ("dummy-box-test-v0", "DummyBoxEnv-v0"),
+        ("dummy-tuple-test-v0", "DummyTupleEnv-v0"),
         ("dummy-combo-test-v0", "DummyComboEnv-v0"),
+        ("dummy-tuple-discrete-box-test-v0", "DummyTupleDisceteBoxEnv-v0"),
     ],
 )
 def test_hugging_face_push_and_pull_dataset(dataset_id, env_id):
@@ -98,16 +103,12 @@ def test_hugging_face_push_and_pull_dataset(dataset_id, env_id):
     )
 
     hugging_face_dataset = convert_minari_dataset_to_hugging_face_dataset(dataset)
-    print("DATASET INFO BEFORE UPLOADING")
-    print(hugging_face_dataset.info)
     push_dataset_to_hugging_face(hugging_face_dataset, "balisujohn/minari_test")
     minari.delete_dataset(dataset_id)
 
     recovered_hugging_face_dataset = pull_dataset_from_hugging_face(
         "balisujohn/minari_test"
     )
-    print("DATASET INFO AFTER UPLOADING AND DOWNLOADING")
-    print(recovered_hugging_face_dataset.info)
     reconstructed_minari_dataset = convert_hugging_face_dataset_to_minari_dataset(
         recovered_hugging_face_dataset
    )
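
Usage sketch (not part of the patch): the round trip exercised by test_hugging_face_push_and_pull_dataset looks roughly like the snippet below. The repo id "your-username/your-dataset-repo" is a placeholder, "cartpole-test-v0" is assumed to already exist as a local Minari dataset, and the import path simply mirrors the new module location.

    import minari
    from minari.integrations.hugging_face import (
        convert_hugging_face_dataset_to_minari_dataset,
        convert_minari_dataset_to_hugging_face_dataset,
        pull_dataset_from_hugging_face,
        push_dataset_to_hugging_face,
    )

    # Convert a local Minari dataset to a Hugging Face dataset and push it to the
    # Hub; push_dataset_to_hugging_face also uploads metadata.json containing the
    # serialized spaces stored in info.description.
    local_dataset = minari.load_dataset("cartpole-test-v0")
    hf_dataset = convert_minari_dataset_to_hugging_face_dataset(local_dataset)
    push_dataset_to_hugging_face(hf_dataset, "your-username/your-dataset-repo")

    # Pull the dataset back; metadata.json restores info.description so the spaces
    # can be deserialized when reconstructing the MinariDataset.
    recovered = pull_dataset_from_hugging_face("your-username/your-dataset-repo")
    reconstructed = convert_hugging_face_dataset_to_minari_dataset(recovered)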