Skip to content

Commit

Permalink
create namespace hierarchy at dataset creation
Browse files Browse the repository at this point in the history
  • Loading branch information
younik committed Nov 6, 2024
1 parent abedd11 commit 70f9029
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 14 deletions.
6 changes: 6 additions & 0 deletions minari/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ def create_namespace(
with open(directory / NAMESPACE_METADATA_FILENAME, "w") as file:
json.dump(metadata, file)

for parent_namespace in namespace_hierarchy(namespace):
if parent_namespace not in list_local_namespaces():
parent_namespace_path = get_dataset_path(parent_namespace)
with open(parent_namespace_path / NAMESPACE_METADATA_FILENAME, "w") as file:
json.dump({}, file)


def update_namespace_metadata(
namespace: str,
Expand Down
8 changes: 6 additions & 2 deletions minari/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
from gymnasium.wrappers import RecordEpisodeStatistics # type: ignore

from minari.data_collector.episode_buffer import EpisodeBuffer
from minari.dataset.minari_dataset import MinariDataset
from minari.dataset.minari_dataset import MinariDataset, parse_dataset_id
from minari.dataset.minari_storage import MinariStorage
from minari.namespace import create_namespace, list_local_namespaces
from minari.serialization import deserialize_space
from minari.storage.datasets_root_dir import get_dataset_path

Expand Down Expand Up @@ -401,6 +402,10 @@ def create_dataset_from_buffers(
observation_space = observation_space or gym_env.observation_space
action_space = action_space or gym_env.action_space

namespace = parse_dataset_id(dataset_id)[0]
if namespace is not None and namespace not in list_local_namespaces():
create_namespace(namespace)

metadata = _generate_dataset_metadata(
dataset_id,
env_spec,
Expand All @@ -416,7 +421,6 @@ def create_dataset_from_buffers(
description,
requirements,
)

data_format_kwarg = {"data_format": data_format} if data_format is not None else {}
storage = MinariStorage.new(
dataset_path,
Expand Down
21 changes: 9 additions & 12 deletions tests/test_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,13 @@ def test_namespace_update(namespace):
assert get_namespace_metadata(namespace) == {"description": "a new definition"}


@pytest.mark.parametrize("namespace", ["test_namespace"])
def test_create_nested_namespaces(namespace):
def test_create_nested_namespaces():
parent_namespace = "test_namespace"
namespace = f"{parent_namespace}/nested"
create_namespace(namespace, description="my description")
assert list_local_namespaces() == [namespace]
assert set(list_local_namespaces()) == {parent_namespace, namespace}
assert get_namespace_metadata(namespace) == {"description": "my description"}

nested_namespace = f"{namespace}/nested"
create_namespace(nested_namespace, description="is nested")
assert list_local_namespaces() == [namespace, nested_namespace]
assert get_namespace_metadata(nested_namespace) == {"description": "is nested"}


def test_nonexistent_namespaces():
with pytest.raises(ValueError, match="does not exist"):
Expand All @@ -81,14 +77,15 @@ def test_create_invalid_namespace(namespace):
create_namespace(namespace)


@pytest.mark.parametrize("namespace", ["nested/namespace"])
def test_create_namespaced_datasets(namespace):
def test_create_namespaced_datasets():
parent_namespace = "nested"
namespace = f"{parent_namespace}/namespace"
env = gym.make("CartPole-v1")
env = DataCollector(env)

dataset_id_1 = f"{namespace}/test-v1"
create_dummy_dataset_with_collecter_env_helper(dataset_id_1, env)
assert list_local_namespaces() == [namespace]
assert list_local_namespaces() == [parent_namespace, namespace]
assert get_namespace_metadata(namespace) == {}

update_namespace_metadata(namespace, description="new description")
Expand All @@ -97,7 +94,7 @@ def test_create_namespaced_datasets(namespace):
# Creating a new dataset in the same namespace doesn't change the namespace metadata
dataset_id_2 = f"{namespace}/test-v2"
create_dummy_dataset_with_collecter_env_helper(dataset_id_2, env)
assert list_local_namespaces() == [namespace]
assert list_local_namespaces() == [parent_namespace, namespace]
assert get_namespace_metadata(namespace) == {"description": "new description"}

assert list(minari.list_local_datasets().keys()) == [dataset_id_1, dataset_id_2]
Expand Down

0 comments on commit 70f9029

Please sign in to comment.