Skip to content

Commit

Permalink
Add DiscreteAction wrapper (#24)
Browse files Browse the repository at this point in the history
* Enable flattening mixed discrete/continuous observation spaces

* Add discrete action wrapper
  • Loading branch information
smorad committed Aug 8, 2023
1 parent 3e2dd8b commit f70e824
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 3 deletions.
8 changes: 6 additions & 2 deletions docs/source/environment_quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Let's create an environment and add some wrappers to it. First, let's do all req
import gymnasium as gym
import popgym
from popgym.wrappers import PreviousAction, Antialias, Markovian, Flatten
from popgym.wrappers import PreviousAction, Antialias, Markovian, Flatten, DiscreteAction
from popgym.core.observability import Observability, STATE
env_classes = popgym.envs.ALL.keys()
print(env_classes)
Expand Down Expand Up @@ -45,8 +45,12 @@ At the initial timestep, there is no previous action. By default, PreviousAction
wrapped_env = Antialias(wrapped_env)
Many RL libraries have spotty support for nested observations or MultiDiscrete action spaces. If you are using DQN or similar approaches, you might want to flatten the observation and action spaces, then convert the action space into a single large Discrete space

Finally, we can decide if we want the hidden Markov state. We can add it as part of the observation, into the info dict, etc. See Observability for more options.
.. code-block:: python
DiscreteAction(Flatten(wrapped_env))
We will not actually assign this to ``wrapped_env``, as for this example we want to inspect the observation and action spaces. Finally, we can decide if we want the hidden Markov state. We can add it as part of the observation, into the info dict, etc. See Observability for more options.

.. code-block:: python
Expand Down
3 changes: 2 additions & 1 deletion popgym/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Various wrappers for POPGym environments"""
from popgym.wrappers.antialias import Antialias
from popgym.wrappers.discrete_action import DiscreteAction
from popgym.wrappers.flatten import Flatten
from popgym.wrappers.markovian import Markovian
from popgym.wrappers.previous_action import PreviousAction

__all__ = ["Antialias", "Markovian", "PreviousAction", "Flatten"]
__all__ = ["Antialias", "Markovian", "PreviousAction", "Flatten", "DiscreteAction"]
50 changes: 50 additions & 0 deletions popgym/wrappers/discrete_action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Tuple

import numpy as np
from gymnasium import spaces
from gymnasium.core import ActType, ObsType

from popgym.core.env import POPGymEnv
from popgym.core.wrapper import POPGymWrapper

PREV_ACTION = "prev_action"


class DiscreteAction(POPGymWrapper):
    """Wrapper that converts a MultiDiscrete action space into a single Discrete one.

    Discrete action spaces pass through unchanged. MultiDiscrete action
    spaces are flattened: the wrapper exposes a ``Discrete(prod(nvec))``
    action space and unravels each flat action index back into the
    original multi-dimensional action before stepping the wrapped env.

    Args:
        env: The environment

    Returns:
        A gym environment

    Raises:
        NotImplementedError: If the action space is not Discrete or
            MultiDiscrete (nested spaces must be flattened first).
    """

    def __init__(self, env: POPGymEnv):
        super().__init__(env)
        self.action_space: spaces.Space
        if isinstance(self.action_space, spaces.Discrete):
            # Already discrete; pass actions through untouched.
            self.ravel_actions = False
        elif isinstance(self.action_space, spaces.MultiDiscrete):
            self.ravel_actions = True
            # Keep the original space so step() can unravel flat indices.
            self.old_action_space = self.action_space
            # One flat index per combination of sub-actions. Cast to int
            # because np.prod returns a NumPy scalar, not a Python int.
            self.action_space = spaces.Discrete(int(np.prod(self.action_space.nvec)))
        elif isinstance(self.action_space, (spaces.Tuple, spaces.Dict)):
            raise NotImplementedError(
                "Action space must be Discrete or MultiDiscrete, got a nested space"
                f" {self.action_space}. Please use the Flatten wrapper first."
            )
        else:
            raise NotImplementedError(
                "Action space must be Discrete or MultiDiscrete, got"
                f" {self.action_space}."
            )

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
        """Step the wrapped env, unraveling the flat action when needed."""
        if not self.ravel_actions:
            return self.env.step(action)

        # Map the flat Discrete index back to per-dimension indices of the
        # original MultiDiscrete space.
        discrete_action = np.unravel_index(action, self.old_action_space.nvec)
        return self.env.step(discrete_action)
10 changes: 10 additions & 0 deletions tests/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from popgym import envs
from popgym.core.observability import OBS, STATE, Observability
from popgym.wrappers.antialias import Antialias
from popgym.wrappers.discrete_action import DiscreteAction
from popgym.wrappers.flatten import Flatten
from popgym.wrappers.markovian import Markovian
from popgym.wrappers.previous_action import PreviousAction
Expand Down Expand Up @@ -106,3 +107,12 @@ def test_flatten_step(env):
obs, _ = wrapped_aa.reset()
assert wrapped_aa.observation_space.contains(obs)
check_env(wrapped_aa, skip_render_check=True)


@pytest.mark.parametrize("env", envs.ALL.keys())
def test_discrete_action(env):
    """DiscreteAction(Flatten(env)) resets and accepts a sampled flat action."""
    if issubclass(env, (envs.StatelessPendulum, envs.NoisyStatelessPendulum)):
        # Pendulum variants have continuous (Box) action spaces, which
        # DiscreteAction cannot convert.
        pytest.skip("Pendulum environments do not support a discrete action space")
    wrapped = DiscreteAction(Flatten(env()))
    _, _ = wrapped.reset()
    # A sampled flat action must step cleanly and return the standard 5-tuple.
    obs, reward, terminated, truncated, info = wrapped.step(
        wrapped.action_space.sample()
    )
    assert isinstance(info, dict)

0 comments on commit f70e824

Please sign in to comment.