diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py
index 0c56135f..9edbda47 100644
--- a/minari/dataset/minari_storage.py
+++ b/minari/dataset/minari_storage.py
@@ -35,14 +35,14 @@ def __init__(
         self._action_space = action_space

     @classmethod
-    def read(cls, data_path: PathLike) -> MinariStorage:
-        """Create a MinariStorage to read data from a path.
+    def read_raw_metadata(cls, data_path: PathLike) -> Dict[str, Any]:
+        """Read the raw metadata from a path.

         Args:
             data_path (str or Path): directory where the data is stored.

         Returns:
-            A new MinariStorage object to read the data.
+            metadata (dict): metadata of the dataset.

         Raises:
             ValueError: if the specified path doesn't exist or doesn't contain any data.
@@ -55,6 +55,22 @@ def read(cls, data_path: PathLike) -> MinariStorage:
             raise ValueError(f"No data found in data path {data_path}")
         with open(metadata_file_path) as file:
             metadata = json.load(file)
+        return metadata
+
+    @classmethod
+    def read(cls, data_path: PathLike) -> MinariStorage:
+        """Create a MinariStorage to read data from a path.
+
+        Args:
+            data_path (str or Path): directory where the data is stored.
+
+        Returns:
+            A new MinariStorage object to read the data.
+
+        Raises:
+            ValueError: if the specified path doesn't exist or doesn't contain any data.
+        """
+        metadata = MinariStorage.read_raw_metadata(data_path)

         observation_space = None
         action_space = None
@@ -85,7 +101,7 @@ def read(cls, data_path: PathLike) -> MinariStorage:
         from minari.dataset._storages import get_minari_storage  # avoid circular import

         return get_minari_storage(metadata["data_format"])(
-            data_path,
+            pathlib.Path(data_path),
             observation_space,
             action_space,
         )
@@ -180,8 +196,7 @@ def _create(
     @property
     def metadata(self) -> Dict[str, Any]:
         """Metadata of the dataset."""
-        with open(self.data_path.joinpath(METADATA_FILE_NAME)) as file:
-            metadata = json.load(file)
+        metadata = MinariStorage.read_raw_metadata(self.data_path)

         metadata["observation_space"] = self.observation_space
         metadata["action_space"] = self.action_space
@@ -215,8 +230,7 @@ def update_metadata(self, metadata: Dict):
         assert isinstance(metadata.get("author_email", set()), set)
         assert isinstance(metadata.get("minari_version", ""), str)

-        with open(self.data_path.joinpath(METADATA_FILE_NAME)) as file:
-            saved_metadata = json.load(file)
+        saved_metadata = MinariStorage.read_raw_metadata(self.data_path)
         forbidden_keys = not_updatable_keys.intersection(metadata.keys())
         forbidden_keys = forbidden_keys.intersection(saved_metadata.keys())
diff --git a/minari/storage/hosting.py b/minari/storage/hosting.py
index d77c7a15..705509be 100644
--- a/minari/storage/hosting.py
+++ b/minari/storage/hosting.py
@@ -9,7 +9,7 @@
 from gymnasium import logger

 from minari.dataset.minari_dataset import gen_dataset_id, parse_dataset_id
-from minari.dataset.minari_storage import METADATA_FILE_NAME
+from minari.dataset.minari_storage import METADATA_FILE_NAME, MinariStorage
 from minari.storage.datasets_root_dir import get_dataset_path
 from minari.storage.local import dataset_id_sort_key, load_dataset
 from minari.storage.remotes import get_cloud_storage
@@ -171,7 +171,9 @@ def download_dataset(dataset_id: str, force_download: bool = False):

     # Skip a force download of an incompatible dataset version
     if dataset_version in compatible_dataset_versions:
-        combined_datasets = load_dataset(dataset_id).spec.combined_datasets
+        data_path = file_path.joinpath("data")
+        metadata = MinariStorage.read_raw_metadata(data_path)
+        combined_datasets = metadata.get("combined_datasets", [])

         # If the dataset is a combination of other datasets download the subdatasets recursively
         if len(combined_datasets) > 0:
diff --git a/minari/storage/local.py b/minari/storage/local.py
index 824c4538..d2ac680d 100644
--- a/minari/storage/local.py
+++ b/minari/storage/local.py
@@ -94,7 +94,7 @@ def recurse_directories(base_path, namespace):

     for dst_id in dataset_ids:
         data_path = os.path.join(datasets_path, dst_id, "data")
         try:
-            metadata = MinariStorage.read(data_path).metadata
+            metadata = MinariStorage.read_raw_metadata(data_path)
             metadata_id = metadata["dataset_id"]
             if dst_id != metadata_id:
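For context, a minimal usage sketch of the API split introduced above (not part of the patch): MinariStorage.read_raw_metadata only parses metadata.json and returns a plain dict, whereas MinariStorage.read still builds a full storage object with observation and action spaces. The dataset id and local path below are illustrative and assume a dataset already downloaded under the default local Minari root.

# Usage sketch (illustrative). The dataset id and the local layout
# <minari root>/<dataset_id>/data are assumed examples, not part of this patch.
import pathlib

from minari.dataset.minari_storage import MinariStorage

data_path = (
    pathlib.Path.home() / ".minari" / "datasets" / "D4RL" / "door" / "human-v2" / "data"
)

# Cheap path: parse metadata.json only; no spaces or storage backend are built.
metadata = MinariStorage.read_raw_metadata(data_path)
print(metadata["dataset_id"])
print(metadata.get("combined_datasets", []))

# Full path: construct a MinariStorage when episode data and spaces are needed.
storage = MinariStorage.read(data_path)
print(storage.metadata["observation_space"])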