diff --git a/metadrive/envs/gym_wrapper.py b/metadrive/envs/gym_wrapper.py index 5a2531163..f98def5f3 100644 --- a/metadrive/envs/gym_wrapper.py +++ b/metadrive/envs/gym_wrapper.py @@ -6,6 +6,11 @@ import gym.spaces def gymnasiumToGym(space: gymnasium.spaces.Space) -> gym.spaces.Space: + return gymnasium_to_gym(space) + + def gymnasium_to_gym(space: gymnasium.spaces.Space) -> gym.spaces.Space: + if isinstance(space, gym.spaces.Space): + return space if isinstance(space, gymnasium.spaces.Box): return gym.spaces.Box(low=space.low, high=space.high, shape=space.shape) elif isinstance(space, gymnasium.spaces.Discrete): @@ -17,9 +22,14 @@ def gymnasiumToGym(space: gymnasium.spaces.Space) -> gym.spaces.Space: elif isinstance(space, gymnasium.spaces.Dict): return gym.spaces.Dict({key: gymnasiumToGym(subspace) for key, subspace in space.spaces.items()}) else: - raise ValueError("unsupported space") + raise ValueError(f"unsupported space: {type(space)}!") def gymToGymnasium(space: gym.spaces.Space) -> gymnasium.spaces.Space: + return gym_to_gymnasium(space) + + def gym_to_gymnasium(space: gym.spaces.Space) -> gymnasium.spaces.Space: + if isinstance(space, gymnasium.spaces.Space): + return space if isinstance(space, gym.spaces.Box): return gymnasium.spaces.Box(low=space.low, high=space.high, shape=space.shape) elif isinstance(space, gym.spaces.Discrete): @@ -31,9 +41,12 @@ def gymToGymnasium(space: gym.spaces.Space) -> gymnasium.spaces.Space: elif isinstance(space, gym.spaces.Dict): return gymnasium.spaces.Dict({key: gymToGymnasium(subspace) for key, subspace in space.spaces.items()}) else: - raise ValueError("unsupported space") + raise ValueError(f"unsupported space: {type(space)}!") def createGymWrapper(inner_class: type): + return create_gym_wrapper(inner_class) + + def create_gym_wrapper(inner_class: type): """ "inner_class": A gymnasium based Metadrive environment class """