-
-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add PointMaze observation space subseting tutorial (#108)
* Add observation space subseting tutorial * Run pre-commit, tutorial text changes
- Loading branch information
1 parent
59683a1
commit 4b2a712
Showing
1 changed file
with
118 additions
and
0 deletions.
There are no files selected for viewing
118 changes: 118 additions & 0 deletions
118
docs/tutorials/dataset_creation/observation_space_subseting.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
# fmt: off | ||
""" | ||
Collecting a subset of a dictionary space with StepDataCallback | ||
========================================= | ||
""" | ||
# %%% | ||
# In this tutorial you'll learn how to have :class:`minari.DataCollectorV0` only collect a subset | ||
# of the observation space in PointMaze. Specifically, we'll be collecting observations using | ||
# random actions on PointMaze_UMaze-v3 from `Gymnasium-Robotics <https://robotics.farama.org/envs/maze/point_maze/>`_ | ||
# and omitting ``achieved_goal`` from the observation space of PointMaze. | ||
# This would be useful if you don't plan on training goal oriented learning agents | ||
# on the Minari dataset generated and `need to minimize the space taken up | ||
# by the final dataset <https://github.com/Farama-Foundation/Minari/issues/57>`_. | ||
# | ||
# Please note that while only a subset of the observation space of PointMaze is touched on in | ||
# this tutorial, the outlined procedure can be extended to both action/observation spaces | ||
# of any other environment. | ||
# | ||
# Let's get started by importing the required modules: | ||
|
||
# %% | ||
import gymnasium as gym | ||
import numpy as np | ||
from gymnasium import spaces | ||
|
||
import minari | ||
from minari import DataCollectorV0 | ||
from minari.data_collector.callbacks import StepDataCallback | ||
|
||
|
||
# %% | ||
# We'll first need to initialize our PointMaze_UMaze environment and find | ||
# the shape of the observation space | ||
|
||
env = gym.make("PointMaze_UMaze-v3") | ||
|
||
print(f"Observation space: {env.observation_space}") | ||
# %% | ||
# Which should output: | ||
# | ||
# .. code:: py | ||
# | ||
# Observation space: Dict('achieved_goal': Box(-inf, inf, (2,), float64), 'desired_goal': Box(-inf, inf, (2,), float64), 'observation': Box(-inf, inf, (4,), float64)) | ||
# | ||
# We now need to create a version of the observation space with ``achieved_goal`` omitted | ||
# using Gymnasium spaces. This will be compared with ``step_data`` in each step to validate that it | ||
# fits with the expected observation space. When saving the final HDF5 dataset file generated by Mirana to disk, this | ||
# will be added as metadata. | ||
# | ||
# We'll also need to define a :class:`minari.StepDataCallback` in order | ||
# to modify ``step_data`` to delete ``achieved_goal``. | ||
|
||
observation_space_subset = spaces.Dict( | ||
{ | ||
# "achieved_goal": spaces.Box(low=float('-inf'), high=float('inf'), shape=(2,), dtype=np.float64), | ||
"desired_goal": spaces.Box( | ||
low=float("-inf"), high=float("inf"), shape=(2,), dtype=np.float64 | ||
), | ||
"observation": spaces.Box( | ||
low=float("-inf"), high=float("inf"), shape=(4,), dtype=np.float64 | ||
), | ||
} | ||
) | ||
|
||
|
||
class CustomSubsetStepDataCallback(StepDataCallback): | ||
def __call__(self, env, **kwargs): | ||
step_data = super().__call__(env, **kwargs) | ||
del step_data["observations"]["achieved_goal"] | ||
return step_data | ||
|
||
|
||
# %% | ||
# Finally we'll record 10 episodes with our observation space subset and | ||
# callback passed to :class:`minari.DataCollectorV0`. | ||
|
||
dataset_id = "point-maze-subseted-v3" | ||
|
||
# delete the test dataset if it already exists | ||
local_datasets = minari.list_local_datasets() | ||
if dataset_id in local_datasets: | ||
minari.delete_dataset(dataset_id) | ||
|
||
env = DataCollectorV0( | ||
env, | ||
observation_space=observation_space_subset, | ||
# action_space=action_space_subset, | ||
step_data_callback=CustomSubsetStepDataCallback, | ||
) | ||
num_episodes = 10 | ||
|
||
env.reset(seed=42) | ||
|
||
for episode in range(num_episodes): | ||
terminated = False | ||
truncated = False | ||
while not terminated and not truncated: | ||
action = env.action_space.sample() # Choose random actions | ||
_, _, terminated, truncated, _ = env.step(action) | ||
env.reset() | ||
|
||
# Create Minari dataset and store locally | ||
dataset = minari.create_dataset_from_collector_env( | ||
dataset_id=dataset_id, | ||
collector_env=env, | ||
algorithm_name="random_policy", | ||
) | ||
|
||
print(dataset.sample_episodes(1)[0].observations.keys()) | ||
|
||
# %% | ||
# The output from the final line above, should be | ||
# | ||
# .. code:: py | ||
# | ||
# dict_keys(['desired_goal', 'observation']) | ||
# | ||
# Showing that we have successfully omitted ``achieved_goal`` from the observations. |