From 1306bab9744be80d509f8604b9f022d9346a8e3e Mon Sep 17 00:00:00 2001 From: "Jose Antonio Martin H." Date: Thu, 22 Aug 2024 12:12:26 +0200 Subject: [PATCH] Add support for Gymnasium 1.0.0 (#177) Co-authored-by: Mark Towers Co-authored-by: Omar Younis --- minari/cli.py | 5 +++++ minari/dataset/minari_dataset.py | 10 ++++++++++ minari/dataset/minari_storage.py | 5 +++++ minari/utils.py | 3 +-- tests/integrations/test_agile_rl.py | 2 +- 5 files changed, 22 insertions(+), 3 deletions(-) diff --git a/minari/cli.py b/minari/cli.py index d7779223..54570f03 100644 --- a/minari/cli.py +++ b/minari/cli.py @@ -216,6 +216,11 @@ def show(dataset: Annotated[str, typer.Argument()]): if env_spec_json is not None: assert isinstance(env_spec_json, str) + env_spec_json = ( # for gymnasium 1.0.0 compatibility + env_spec_json.replace('"order_enforce": true,', "") + .replace('"apply_api_compatibility": false,', "") + .replace('"autoreset": false, ', "") + ) env_spec = EnvSpec.from_json(env_spec_json) env_spec_table = Table(show_header=False, highlight=True) env_spec_table.add_column(style="bold") diff --git a/minari/dataset/minari_dataset.py b/minari/dataset/minari_dataset.py index 06c85ba4..f15ea332 100644 --- a/minari/dataset/minari_dataset.py +++ b/minari/dataset/minari_dataset.py @@ -130,12 +130,22 @@ def __init__( env_spec = metadata.get("env_spec") if env_spec is not None: assert isinstance(env_spec, str) + env_spec = ( # for gymnasium 1.0.0 compatibility + env_spec.replace('"order_enforce": true,', "") + .replace('"apply_api_compatibility": false,', "") + .replace('"autoreset": false, ', "") + ) env_spec = EnvSpec.from_json(env_spec) self._env_spec = env_spec eval_env_spec = metadata.get("eval_env_spec") if eval_env_spec is not None: assert isinstance(eval_env_spec, str) + eval_env_spec = ( # for gymnasium 1.0.0 compatibility + eval_env_spec.replace('"order_enforce": true,', "") + .replace('"apply_api_compatibility": false,', "") + .replace('"autoreset": false, ', "") + ) eval_env_spec = EnvSpec.from_json(eval_env_spec) self._eval_env_spec = eval_env_spec diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py index f8a40e44..0c56135f 100644 --- a/minari/dataset/minari_storage.py +++ b/minari/dataset/minari_storage.py @@ -70,6 +70,11 @@ def read(cls, data_path: PathLike) -> MinariStorage: if action_space is None or observation_space is None: env_spec_str = metadata.get("env_spec") assert isinstance(env_spec_str, str) + env_spec_str = ( # for gymnasium 1.0.0 compatibility + env_spec_str.replace('"order_enforce": true,', "") + .replace('"apply_api_compatibility": false,', "") + .replace('"autoreset": false, ', "") + ) env_spec = EnvSpec.from_json(env_spec_str) env = gym.make(env_spec) if observation_space is None: diff --git a/minari/utils.py b/minari/utils.py index e06c265f..197100d9 100644 --- a/minari/utils.py +++ b/minari/utils.py @@ -12,7 +12,7 @@ from gymnasium.core import ActType, ObsType from gymnasium.envs.registration import EnvSpec from gymnasium.error import NameNotFound -from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistics +from gymnasium.wrappers import RecordEpisodeStatistics # type: ignore from minari.data_collector.episode_buffer import EpisodeBuffer from minari.dataset.minari_dataset import MinariDataset @@ -492,7 +492,6 @@ def get_env_spec_dict(env_spec: EnvSpec) -> Dict[str, str]: "reward_threshold": str(env_spec.reward_threshold), "nondeterministic": f"`{env_spec.nondeterministic}`", "order_enforce": f"`{env_spec.order_enforce}`", - "autoreset": f"`{env_spec.autoreset}`", "disable_env_checker": f"`{env_spec.disable_env_checker}`", "kwargs": f"`{env_spec.kwargs}`", "additional_wrappers": f"`{env_spec.additional_wrappers}`", diff --git a/tests/integrations/test_agile_rl.py b/tests/integrations/test_agile_rl.py index 7b5c3eb9..b3843a59 100644 --- a/tests/integrations/test_agile_rl.py +++ b/tests/integrations/test_agile_rl.py @@ -15,7 +15,7 @@ def dataset_id(): @pytest.fixture(autouse=True) -def createAndDestroyMinariDataset(dataset_id): +def create_and_destroy_minari_dataset(dataset_id): env = gym.make("CartPole-v1") env = DataCollector(env, record_infos=True)