Skip to content

Commit

Permalink
support text spaces
Browse files Browse the repository at this point in the history
  • Loading branch information
younik committed Jun 23, 2023
1 parent 3c2eb06 commit ef1db04
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 55 deletions.
76 changes: 30 additions & 46 deletions minari/data_collector/data_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,76 +348,60 @@ def clear_buffer(dictionary_buffer: EpisodeBuffer, episode_group: h5py.Group):
"""
for key, data in dictionary_buffer.items():
if isinstance(data, dict):

if key in episode_group:
eps_group_to_clear = episode_group[key]
else:
eps_group_to_clear = episode_group.create_group(key)
eps_group_to_clear = episode_group.get(
key, episode_group.create_group(key)
)
clear_buffer(data, eps_group_to_clear)
elif all([isinstance(entry, tuple) for entry in data]):
elif all(map(lambda elem: isinstance(elem, tuple), data)):
# we have a list of tuples, so we need to act appropriately
dict_data = {
f"_index_{str(i)}": [entry[i] for entry in data]
for i, _ in enumerate(data[0])
}
if key in episode_group:
eps_group_to_clear = episode_group[key]
else:
eps_group_to_clear = episode_group.create_group(key)
eps_group_to_clear = episode_group.get(
key, episode_group.create_group(key)
)
clear_buffer(dict_data, eps_group_to_clear)
elif all([isinstance(entry, OrderedDict) for entry in data]):

elif all(map(lambda elem: isinstance(elem, OrderedDict), data)):
# we have a list of OrderedDicts, so we need to act appropriately
dict_data = {
key: [entry[key] for entry in data]
for key, value in data[0].items()
}

if key in episode_group:
eps_group_to_clear = episode_group[key]
else:
eps_group_to_clear = episode_group.create_group(key)
eps_group_to_clear = episode_group.get(
key, episode_group.create_group(key)
)
clear_buffer(dict_data, eps_group_to_clear)
else:
# convert data to numpy
np_data = np.asarray(data)
assert np.all(
np.logical_not(np.isnan(np_data))
), "Nan found after cast to nump array, check the type of 'data'."
if all(map(lambda elem: isinstance(elem, str), data)):
data_shape = (len(data),)
dtype = h5py.string_dtype(encoding="utf-8")
else:
data = np.asarray(data)
data_shape = data.shape
dtype = data.dtype
assert np.all(
np.logical_not(np.isnan(data))
), "Nan found after cast to nump array, check the type of 'data'."

# Check if last episode group is terminated or truncated
if (
not self._last_episode_group_term_or_trunc
and key in episode_group
):
# Append to last episode group datasets
if key not in STEP_DATA_KEYS and key != "infos":
# check current dataset size directly from hdf5 since
# non step data (actions, obs, rew, term, trunc) may not be
# added in a per-step/sequential basis, including "infos"
current_dataset_shape = episode_group[key].shape[0]
else:
current_dataset_shape = self._last_episode_n_steps
if key == "observations":
current_dataset_shape += (
1 # include initial observation
)
current_dataset_shape = episode_group[key].shape[0]
if key == "observations":
current_dataset_shape += 1
episode_group[key].resize(
current_dataset_shape + len(data), axis=0
)
episode_group[key][-len(data) :] = np_data
episode_group[key][-len(data) :] = data
else:
if not current_episode_group_term_or_trunc:
# Create resizable datasets
episode_group.create_dataset(
key,
data=np_data,
maxshape=(None,) + np_data.shape[1:],
chunks=True,
)
else:
# Dump everything to episode group
episode_group.create_dataset(key, data=np_data, chunks=True)
data_shape = (None,) + data_shape[1:] # resizable dataset

episode_group.create_dataset(
key, data=data, maxshape=data_shape, dtype=dtype
)

for i, eps_buff in enumerate(self._buffer):
# Make sure that the episode has stepped, by checking if the 'actions' key has been added to the episode buffer.
Expand Down
4 changes: 3 additions & 1 deletion minari/dataset/minari_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,10 @@ def clear_episode_buffer(episode_buffer: Dict, episode_group: h5py.Group) -> h5p
else:
episode_group_to_clear = episode_group.create_group(key)
clear_episode_buffer(dict_data, episode_group_to_clear)
elif all(map(lambda elem: isinstance(elem, str), data)):
dtype = h5py.string_dtype(encoding="utf-8")
episode_group.create_dataset(key, data=data, dtype=dtype, chunks=True)
else:
# assert data is numpy array
assert np.all(np.logical_not(np.isnan(data)))
# add seed to attributes
episode_group.create_dataset(key, data=data, chunks=True)
Expand Down
6 changes: 5 additions & 1 deletion minari/dataset/minari_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _decode_space(
self,
hdf_ref: Union[h5py.Group, h5py.Dataset],
space: gym.spaces.Space,
) -> Union[Dict, Tuple, np.ndarray]:
) -> Union[Dict, Tuple, List, np.ndarray]:
if isinstance(space, gym.spaces.Tuple):
assert isinstance(hdf_ref, h5py.Group)
result = []
Expand All @@ -108,6 +108,10 @@ def _decode_space(
for key in hdf_ref:
result[key] = self._decode_space(hdf_ref[key], space.spaces[key])
return result
elif isinstance(space, gym.spaces.Text):
assert isinstance(hdf_ref, h5py.Dataset)
result = map(lambda string: string.decode("utf-8"), hdf_ref[()])
return list(result)
else:
assert isinstance(hdf_ref, h5py.Dataset)
return hdf_ref[()]
Expand Down
25 changes: 25 additions & 0 deletions minari/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,21 @@ def _serialize_tuple(space: spaces.Tuple, to_string=True) -> Union[Dict, str]:
return result


@serialize_space.register(spaces.Text)
def _serialize_text(space: spaces.Text, to_string=True) -> Union[Dict, str]:
    """Serialize a ``Text`` space into a dictionary or a JSON string.

    Args:
        space: the ``Text`` space to serialize.
        to_string: if true, return the JSON dump of the dictionary;
            otherwise return the dictionary itself.

    Returns:
        A dictionary with the entries "type", "max_length", "min_length"
        and "charset", or its JSON-encoded string.
    """
    description = {"type": "Text"}
    description["max_length"] = space.max_length
    description["min_length"] = space.min_length
    # Stored under "charset" so it can be handed back to the Text constructor.
    description["charset"] = space.characters

    return json.dumps(description) if to_string else description


class type_value_dispatch:
def __init__(self, func) -> None:
self.registry = defaultdict(func)
Expand Down Expand Up @@ -127,3 +142,13 @@ def _deserialize_discrete(space_dict: Dict) -> spaces.Discrete:
n = space_dict["n"]
start = space_dict["start"]
return spaces.Discrete(n=n, start=start)


@deserialize_space.register("Text")
def _deserialize_text(space_dict: Dict) -> spaces.Text:
    """Rebuild a ``Text`` space from its serialized dictionary.

    Args:
        space_dict: dictionary whose "type" entry is "Text" and which holds
            the "max_length", "min_length" and "charset" entries.

    Returns:
        The ``Text`` space constructed from those three entries.
    """
    assert space_dict["type"] == "Text"
    longest = space_dict["max_length"]
    shortest = space_dict["min_length"]
    alphabet = space_dict["charset"]
    return spaces.Text(max_length=longest, min_length=shortest, charset=alphabet)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def get_version():
"""Gets the gymnasium version."""
"""Gets the Minari version."""
path = CWD / "minari" / "__init__.py"
content = path.read_text()

Expand Down
44 changes: 38 additions & 6 deletions tests/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Iterable, Union
from typing import Any, Iterable, List, Union

import gymnasium as gym
import numpy as np
Expand Down Expand Up @@ -143,6 +143,23 @@ def reset(self, seed=None, options=None):
return self.observation_space.sample(), {}


class DummyTextEnv(gym.Env):
    """Dummy environment whose action and observation spaces are ``Text`` spaces."""

    def __init__(self):
        # Actions: strings of 2 to 10 characters drawn from "0" and "1".
        self.action_space = spaces.Text(max_length=10, min_length=2, charset="01")
        # Observations: strings of at most 20 characters (default charset).
        self.observation_space = spaces.Text(max_length=20)

    def step(self, action):
        """Advance the step counter and return a randomly sampled observation.

        The action is ignored. The episode is reported as terminated when the
        counter, read before it is incremented, is greater than 5.
        """
        episode_over = self.timestep > 5
        self.timestep += 1

        observation = self.observation_space.sample()
        return observation, 0, episode_over, False, {}

    def reset(self, seed=None, options=None):
        """Reset the step counter and return a sampled observation with empty info."""
        self.timestep = 0
        first_observation = self.observation_space.sample()
        return first_observation, {}


class DummyComboEnv(gym.Env):
def __init__(self):
self.action_space = spaces.Tuple(
Expand Down Expand Up @@ -228,6 +245,12 @@ def register_dummy_envs():
max_episode_steps=5,
)

register(
id="DummyTextEnv-v0",
entry_point="tests.common:DummyTextEnv",
max_episode_steps=5,
)

register(
id="DummyComboEnv-v0",
entry_point="tests.common:DummyComboEnv",
Expand All @@ -240,6 +263,8 @@ def register_dummy_envs():
gym.spaces.Box(low=-1, high=4, shape=(3,), dtype=np.float32),
gym.spaces.Box(low=-1, high=4, shape=(2, 2, 2), dtype=np.float32),
gym.spaces.Box(low=-1, high=4, shape=(3, 3, 3), dtype=np.float32),
gym.spaces.Text(max_length=10, min_length=10),
gym.spaces.Text(max_length=10, charset="01"),
gym.spaces.Tuple(
(
gym.spaces.Discrete(1),
Expand Down Expand Up @@ -306,6 +331,7 @@ def register_dummy_envs():
),
}
),
"component_3": gym.spaces.Text(100, min_length=20),
}
),
)
Expand Down Expand Up @@ -338,7 +364,9 @@ def register_dummy_envs():
gym.spaces.Box(
low=4, high=5, dtype=np.float32
),
gym.spaces.Text(1),
gym.spaces.Graph(
gym.spaces.Box(-1, 1), None
),
)
),
}
Expand Down Expand Up @@ -435,6 +463,10 @@ def check_data_integrity(data: MinariStorage, episode_indices: Iterable[int]):
obs = _reconstuct_obs_or_action_at_index_recursive(
episode["observations"], i
)
if not data.observation_space.contains(obs):
import pdb

pdb.set_trace()
assert data.observation_space.contains(obs)
for i in range(episode["total_timesteps"]):
action = _reconstuct_obs_or_action_at_index_recursive(episode["actions"], i)
Expand All @@ -460,11 +492,11 @@ def _reconstuct_obs_or_action_at_index_recursive(
for entry in data
]
)

elif isinstance(data, np.ndarray):
return data[index]
else:
assert False, "error, invalid observation or action structure"
assert isinstance(
data, (np.ndarray, List)
), "error, invalid observation or action structure"
return data[index]


def _check_space_elem(data: Any, space: spaces.Space, n_elements: int):
Expand Down
2 changes: 2 additions & 0 deletions tests/utils/test_dataset_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
("dummy-dict-test-v0", "DummyDictEnv-v0"),
("dummy-box-test-v0", "DummyBoxEnv-v0"),
("dummy-tuple-test-v0", "DummyTupleEnv-v0"),
("dummy-text-test-v0", "DummyTextEnv-v0"),
("dummy-combo-test-v0", "DummyComboEnv-v0"),
("dummy-tuple-discrete-box-test-v0", "DummyTupleDisceteBoxEnv-v0"),
],
Expand Down Expand Up @@ -92,6 +93,7 @@ def test_generate_dataset_with_collector_env(dataset_id, env_id):
("cartpole-test-v0", "CartPole-v1"),
("dummy-dict-test-v0", "DummyDictEnv-v0"),
("dummy-tuple-test-v0", "DummyTupleEnv-v0"),
("dummy-text-test-v0", "DummyTextEnv-v0"),
("dummy-combo-test-v0", "DummyComboEnv-v0"),
("dummy-tuple-discrete-box-test-v0", "DummyTupleDisceteBoxEnv-v0"),
],
Expand Down

0 comments on commit ef1db04

Please sign in to comment.