Skip to content

Commit

Permalink
Fix: Avoid error when missing dataset dependency (#241)
Browse files Browse the repository at this point in the history
  • Loading branch information
younik committed Sep 5, 2024
1 parent be4ee5b commit 7bd0f16
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 11 deletions.
30 changes: 22 additions & 8 deletions minari/dataset/minari_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ def __init__(
self._action_space = action_space

@classmethod
def read(cls, data_path: PathLike) -> MinariStorage:
"""Create a MinariStorage to read data from a path.
def read_raw_metadata(cls, data_path: PathLike) -> Dict[str, Any]:
"""Read the raw metadata from a path.
Args:
data_path (str or Path): directory where the data is stored.
Returns:
A new MinariStorage object to read the data.
metadata (dict): metadata of the dataset.
Raises:
ValueError: if the specified path doesn't exist or doesn't contain any data.
Expand All @@ -55,6 +55,22 @@ def read(cls, data_path: PathLike) -> MinariStorage:
raise ValueError(f"No data found in data path {data_path}")
with open(metadata_file_path) as file:
metadata = json.load(file)
return metadata

@classmethod
def read(cls, data_path: PathLike) -> MinariStorage:
"""Create a MinariStorage to read data from a path.
Args:
data_path (str or Path): directory where the data is stored.
Returns:
A new MinariStorage object to read the data.
Raises:
ValueError: if the specified path doesn't exist or doesn't contain any data.
"""
metadata = MinariStorage.read_raw_metadata(data_path)

observation_space = None
action_space = None
Expand Down Expand Up @@ -85,7 +101,7 @@ def read(cls, data_path: PathLike) -> MinariStorage:
from minari.dataset._storages import get_minari_storage # avoid circular import

return get_minari_storage(metadata["data_format"])(
data_path,
pathlib.Path(data_path),
observation_space,
action_space,
)
Expand Down Expand Up @@ -180,8 +196,7 @@ def _create(
@property
def metadata(self) -> Dict[str, Any]:
"""Metadata of the dataset."""
with open(self.data_path.joinpath(METADATA_FILE_NAME)) as file:
metadata = json.load(file)
metadata = MinariStorage.read_raw_metadata(self.data_path)

metadata["observation_space"] = self.observation_space
metadata["action_space"] = self.action_space
Expand Down Expand Up @@ -215,8 +230,7 @@ def update_metadata(self, metadata: Dict):
assert isinstance(metadata.get("author_email", set()), set)
assert isinstance(metadata.get("minari_version", ""), str)

with open(self.data_path.joinpath(METADATA_FILE_NAME)) as file:
saved_metadata = json.load(file)
saved_metadata = MinariStorage.read_raw_metadata(self.data_path)

forbidden_keys = not_updatable_keys.intersection(metadata.keys())
forbidden_keys = forbidden_keys.intersection(saved_metadata.keys())
Expand Down
6 changes: 4 additions & 2 deletions minari/storage/hosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from gymnasium import logger

from minari.dataset.minari_dataset import gen_dataset_id, parse_dataset_id
from minari.dataset.minari_storage import METADATA_FILE_NAME
from minari.dataset.minari_storage import METADATA_FILE_NAME, MinariStorage
from minari.storage.datasets_root_dir import get_dataset_path
from minari.storage.local import dataset_id_sort_key, load_dataset
from minari.storage.remotes import get_cloud_storage
Expand Down Expand Up @@ -171,7 +171,9 @@ def download_dataset(dataset_id: str, force_download: bool = False):

# Skip a force download of an incompatible dataset version
if dataset_version in compatible_dataset_versions:
combined_datasets = load_dataset(dataset_id).spec.combined_datasets
data_path = file_path.joinpath("data")
metadata = MinariStorage.read_raw_metadata(data_path)
combined_datasets = metadata.get("combined_datasets", [])

# If the dataset is a combination of other datasets download the subdatasets recursively
if len(combined_datasets) > 0:
Expand Down
2 changes: 1 addition & 1 deletion minari/storage/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def recurse_directories(base_path, namespace):
for dst_id in dataset_ids:
data_path = os.path.join(datasets_path, dst_id, "data")
try:
metadata = MinariStorage.read(data_path).metadata
metadata = MinariStorage.read_raw_metadata(data_path)
metadata_id = metadata["dataset_id"]

if dst_id != metadata_id:
Expand Down

0 comments on commit 7bd0f16

Please sign in to comment.