Merge pull request #256 from chandar-lab/recurrent_buffer
initial commit recurrent buffer implementation
hnekoeiq authored Mar 12, 2022
2 parents 43c88db + 0aabcff commit 9d666a8
Showing 3 changed files with 320 additions and 0 deletions.
81 changes: 81 additions & 0 deletions hive/configs/atari/drqn.yml
@@ -0,0 +1,81 @@
run_name: &run_name 'atari-dqn'
train_steps: 50000000
test_frequency: 250000
test_episodes: 10
max_steps_per_episode: 27000
max_seq_len: &max_seq_len 10
save_dir: 'experiment'
saving_schedule:
  name: 'PeriodicSchedule'
  kwargs:
    off_value: False
    on_value: True
    period: 1000000
environment:
  name: 'AtariEnv'
  kwargs:
    env_name: 'Asterix'

agent:
  name: 'DQNAgent'
  kwargs:
    representation_net:
      name: 'ConvNetwork'
      kwargs:
        channels: [32, 64, 64]
        kernel_sizes: [8, 4, 3]
        strides: [4, 2, 1]
        paddings: [2, 2, 1]
        mlp_layers: [512]
    optimizer_fn:
      name: 'RMSpropTF'
      kwargs:
        lr: 0.00025
        alpha: .95
        eps: 0.00001
        centered: True
    init_fn:
      name: 'xavier_uniform'
    loss_fn:
      name: 'SmoothL1Loss'
    replay_buffer:
      name: 'RecurrentReplayBuffer'
      kwargs:
        capacity: 1000000
        max_seq_len: *max_seq_len
        gamma: &gamma .99
    discount_rate: *gamma
    reward_clip: 1
    update_period_schedule:
      name: 'PeriodicSchedule'
      kwargs:
        off_value: False
        on_value: True
        period: 4
    target_net_update_schedule:
      name: 'PeriodicSchedule'
      kwargs:
        off_value: False
        on_value: True
        period: 8000
    epsilon_schedule:
      name: 'LinearSchedule'
      kwargs:
        init_value: 1.0
        end_value: .01
        steps: 250000
    test_epsilon: .001
    min_replay_history: 200
    device: 'cuda'
    log_frequency: 1000
# List of logger configs used.
loggers:
  -
    name: ChompLogger
#  -
#    name: WandbLogger
#    kwargs:
#      project: Hive
#      name: *run_name
#      resume: "allow"
#      start_method: "fork"
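The replay_buffer section above maps onto the RecurrentReplayBuffer constructor added in this PR. A minimal sketch of the equivalent construction in Python (the observation shape and dtype are illustrative assumptions for Atari preprocessing, not values taken from this config):

import numpy as np

from hive.replays.recurrent_replay import RecurrentReplayBuffer

buffer = RecurrentReplayBuffer(
    capacity=1000000,                # mirrors `capacity` in the YAML
    max_seq_len=10,                  # mirrors the *max_seq_len anchor
    gamma=0.99,                      # mirrors the &gamma anchor
    observation_shape=(1, 84, 84),   # assumed Atari frame shape
    observation_dtype=np.uint8,      # assumed Atari frame dtype
)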
2 changes: 2 additions & 0 deletions hive/replays/__init__.py
@@ -1,6 +1,7 @@
from hive.replays.circular_replay import CircularReplayBuffer, SimpleReplayBuffer
from hive.replays.legal_moves_replay import LegalMovesBuffer
from hive.replays.prioritized_replay import PrioritizedReplayBuffer
from hive.replays.recurrent_replay import RecurrentReplayBuffer
from hive.replays.replay_buffer import BaseReplayBuffer
from hive.utils.registry import registry

@@ -11,6 +12,7 @@
"SimpleReplayBuffer": SimpleReplayBuffer,
"PrioritizedReplayBuffer": PrioritizedReplayBuffer,
"LegalMovesBuffer": LegalMovesBuffer,
"RecurrentReplayBuffer": RecurrentReplayBuffer,
},
)

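With this registration in place, the string 'RecurrentReplayBuffer' used in the YAML config resolves to the new class, and the direct import also works. A minimal sketch of both (the subclass relationship is inferred from the registry call above, which registers the buffers under BaseReplayBuffer):

from hive.replays import RecurrentReplayBuffer
from hive.replays.replay_buffer import BaseReplayBuffer

# RecurrentReplayBuffer extends CircularReplayBuffer, which implements
# the BaseReplayBuffer interface the registry registers it under.
assert issubclass(RecurrentReplayBuffer, BaseReplayBuffer)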
237 changes: 237 additions & 0 deletions hive/replays/recurrent_replay.py
@@ -0,0 +1,237 @@
import os
import pickle

import numpy as np

from hive.replays.circular_replay import CircularReplayBuffer


class RecurrentReplayBuffer(CircularReplayBuffer):
    """First implementation of a recurrent replay buffer. Stores and samples
    fixed-length sequences of transitions; hidden states are not stored.
    """
    def __init__(
        self,
        capacity: int = 10000,
        max_seq_len: int = 1,
        n_step: int = 1,
        gamma: float = 0.99,
        observation_shape=(),
        observation_dtype=np.uint8,
        action_shape=(),
        action_dtype=np.int8,
        reward_shape=(),
        reward_dtype=np.float32,
        extra_storage_types=None,
        num_players_sharing_buffer: int = None,
    ):
        """Constructor for RecurrentReplayBuffer.

        Args:
            capacity (int): Total number of observations that can be stored in the
                buffer. Note, this is not the same as the number of transitions that
                can be stored in the buffer.
            max_seq_len (int): The number of consecutive transitions in a sequence.
            n_step (int): Horizon used to compute n-step return reward.
            gamma (float): Discounting factor used to compute n-step return reward.
            observation_shape: Shape of observations that will be stored in the buffer.
            observation_dtype: Type of observations that will be stored in the buffer.
                This can either be the type itself or a string representation of the
                type. The type can be either a native python type or a numpy type. If
                a numpy type, a string of the form np.uint8 or numpy.uint8 is
                acceptable.
            action_shape: Shape of actions that will be stored in the buffer.
            action_dtype: Type of actions that will be stored in the buffer. Format is
                described in the description of observation_dtype.
            reward_shape: Shape of rewards that will be stored in the buffer.
            reward_dtype: Type of rewards that will be stored in the buffer. Format is
                described in the description of observation_dtype.
            extra_storage_types (dict): A dictionary describing extra items to store
                in the buffer. The mapping should be from the name of the item to a
                (type, shape) tuple.
            num_players_sharing_buffer (int): Number of agents that share their
                buffers. It is used for self-play.
        """
        super().__init__(
            capacity=capacity,
            stack_size=1,
            n_step=n_step,
            gamma=gamma,
            observation_shape=observation_shape,
            observation_dtype=observation_dtype,
            action_shape=action_shape,
            action_dtype=action_dtype,
            reward_shape=reward_shape,
            reward_dtype=reward_dtype,
            extra_storage_types=extra_storage_types,
            num_players_sharing_buffer=num_players_sharing_buffer,
        )
        self._max_seq_len = max_seq_len
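        # Explanatory note (not part of the original commit): stack_size is
        # pinned to 1 because this buffer replaces frame stacking with
        # sequence sampling; _max_seq_len controls how many consecutive
        # steps each sampled item spans.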

    def add(self, observation, action, reward, done, **kwargs):
        """Adds a transition to the buffer.

        The required components of a transition are given as positional arguments.
        The user can pass additional components to store in the buffer as kwargs,
        as long as they were defined in the specification in the constructor.
        """
        if self._episode_start:
            self._pad_buffer(self._max_seq_len - 1)
            self._episode_start = False
        transition = {
            "observation": observation,
            "action": action,
            "reward": reward,
            "done": done,
        }
        transition.update(kwargs)
        for key in self._specs:
            obj_type = (
                transition[key].dtype
                if hasattr(transition[key], "dtype")
                else type(transition[key])
            )
            if not np.can_cast(obj_type, self._specs[key][0], casting="same_kind"):
                raise ValueError(
                    f"Key {key} has wrong dtype. Expected {self._specs[key][0]}, "
                    f"received {type(transition[key])}."
                )
        if self._num_players_sharing_buffer is None:
            self._add_transition(**transition)
        else:
            self._episode_storage[kwargs["agent_id"]].append(transition)
            if done:
                for transition in self._episode_storage[kwargs["agent_id"]]:
                    self._add_transition(**transition)
                self._episode_storage[kwargs["agent_id"]] = []

        if done:
            self._episode_start = True
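        # Explanatory note (not part of the original commit): _pad_buffer
        # writes max_seq_len - 1 dummy transitions at each episode start, so
        # the first real steps of an episode can still anchor a full-length
        # sampled sequence without reaching into the previous episode.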

    def _get_from_array(self, array, indices, num_to_access=1):
        """Retrieves consecutive elements in the array, wrapping around if necessary.

        If more than 1 element is being accessed, the elements are concatenated along
        the first dimension.

        Args:
            array: Array to access from.
            indices: Starts of ranges to access from.
            num_to_access: How many consecutive elements to access.
        """
        full_indices = np.indices((indices.shape[0], num_to_access))[1]
        full_indices = (full_indices + np.expand_dims(indices, axis=1)) % (
            self.size() + self._max_seq_len + self._n_step - 1
        )
        elements = array[full_indices]
        elements = elements.reshape(indices.shape[0], -1, *elements.shape[3:])
        return elements
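    # Worked example with hypothetical numbers (not part of the original
    # commit): if the effective storage length
    # size() + max_seq_len + n_step - 1 is 12, then indices=[10] with
    # num_to_access=3 yields full_indices [[10, 11, 0]], i.e. the read wraps
    # from the end of the circular storage back to the start instead of
    # indexing out of bounds.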

    def _get_from_storage(self, key, indices, num_to_access=1):
        """Gets values from storage.

        Args:
            key: The name of the component to retrieve.
            indices: This can be a single int or a 1D numpy array. The indices are
                adjusted to fall within the current bounds of the buffer.
            num_to_access: How many consecutive elements to access.
        """
        if not isinstance(indices, np.ndarray):
            indices = np.array([indices])
        if num_to_access == 0:
            return np.array([])
        elif num_to_access == 1:
            return self._storage[key][
                indices % (self.size() + self._max_seq_len + self._n_step - 1)
            ]
        else:
            return self._get_from_array(
                self._storage[key], indices, num_to_access=num_to_access
            )

    def _sample_indices(self, batch_size):
        """Samples valid indices that can be used by the replay."""
        indices = np.array([], dtype=np.int32)
        while len(indices) < batch_size:
            start_index = (
                self._rng.integers(self.size(), size=batch_size - len(indices))
                + self._cursor
            )
            start_index = self._filter_transitions(start_index)
            indices = np.concatenate([indices, start_index])
        return indices + self._max_seq_len - 1

    def _filter_transitions(self, indices):
        """Filters invalid indices."""
        if self._max_seq_len == 1:
            return indices
        done = self._get_from_storage("done", indices, self._max_seq_len - 1)
        done = done.astype(bool)
        if self._max_seq_len == 2:
            indices = indices[~done]
        else:
            indices = indices[~done.any(axis=1)]
        return indices
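    # Sketch of the sampling logic above (an explanatory comment, not part of
    # the original commit): _sample_indices draws uniform start points
    # relative to the cursor, _filter_transitions rejects any start whose
    # next max_seq_len - 1 "done" flags contain a terminal, and the final
    # + max_seq_len - 1 shift means the returned indices point at the last
    # step of each fully in-episode sequence.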

    def sample(self, batch_size):
        """Sample transitions from the buffer. For a given transition, if its
        done is True, the next_observation value should not be taken to have
        any meaning.

        Args:
            batch_size (int): Number of transitions to sample.
        """
        if self._num_added < self._max_seq_len + self._n_step:
            raise ValueError("Not enough transitions added to the buffer to sample")
        indices = self._sample_indices(batch_size)
        batch = {}
        batch["indices"] = indices
        terminals = self._get_from_storage("done", indices, self._n_step)

        if self._n_step == 1:
            is_terminal = terminals
            trajectory_lengths = np.ones(batch_size)
        else:
            is_terminal = terminals.any(axis=1).astype(int)
            trajectory_lengths = (
                np.argmax(terminals.astype(bool), axis=1) + 1
            ) * is_terminal + self._n_step * (1 - is_terminal)
        trajectory_lengths = trajectory_lengths.astype(np.int64)

        for key in self._specs:
            if key == "observation":
                batch[key] = self._get_from_storage(
                    "observation",
                    indices - self._max_seq_len + 1,
                    num_to_access=self._max_seq_len,
                )
            elif key == "action":
                batch[key] = self._get_from_storage(
                    "action",
                    indices - self._max_seq_len + 1,
                    num_to_access=self._max_seq_len,
                )
            elif key == "done":
                batch["done"] = is_terminal
            elif key == "reward":
                rewards = self._get_from_storage(
                    "reward",
                    indices - self._max_seq_len + 1,
                    num_to_access=self._max_seq_len + self._n_step - 1,
                )
                if self._max_seq_len + self._n_step - 1 == 1:
                    rewards = np.expand_dims(rewards, 1)
                rewards = rewards * np.expand_dims(self._discount, axis=0)

                # Mask out rewards past trajectory length
                # mask = np.expand_dims(trajectory_lengths, 1) > np.arange(self._n_step)
                # rewards = np.sum(rewards * mask, axis=1)
                batch["reward"] = rewards
            else:
                batch[key] = self._get_from_storage(key, indices)

        batch["trajectory_lengths"] = trajectory_lengths
        batch["next_observation"] = self._get_from_storage(
            "observation",
            indices + trajectory_lengths - self._max_seq_len + 1,
            num_to_access=self._max_seq_len,
        )
        return batch
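A short end-to-end sketch of how the new buffer is used (observation shape, episode length, and batch size are illustrative assumptions, not values from this commit):

import numpy as np

from hive.replays import RecurrentReplayBuffer

buffer = RecurrentReplayBuffer(
    capacity=1000,
    max_seq_len=4,
    observation_shape=(1, 84, 84),
    observation_dtype=np.uint8,
)

# One 20-step episode; done=True marks the terminal transition.
for step in range(20):
    buffer.add(
        observation=np.zeros((1, 84, 84), dtype=np.uint8),
        action=0,
        reward=0.0,
        done=(step == 19),
    )

batch = buffer.sample(batch_size=8)
# Each sampled observation is a sequence of max_seq_len frames,
# e.g. batch["observation"].shape == (8, 4, 84, 84).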

