From 981be53322f6106f4399899adc0be319ffb79ace Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 10 Mar 2022 23:30:46 -0500 Subject: [PATCH] initial commit DRQN implementation --- hive/agents/__init__.py | 2 + hive/agents/dqn.py | 2 +- hive/agents/drqn.py | 288 +++++++++++++++++++++ hive/agents/qnets/__init__.py | 2 + hive/agents/qnets/qnet_heads.py | 21 +- hive/agents/qnets/rnn.py | 107 ++++++++ hive/agents/qnets/utils.py | 8 +- hive/configs/atari/drqn.yml | 82 ++++++ hive/replays/circular_replay.py | 2 +- hive/replays/prioritized_replay.py | 2 +- tests/hive/replays/test_circular_buffer.py | 2 +- 11 files changed, 509 insertions(+), 9 deletions(-) create mode 100644 hive/agents/drqn.py create mode 100644 hive/agents/qnets/rnn.py create mode 100644 hive/configs/atari/drqn.yml diff --git a/hive/agents/__init__.py b/hive/agents/__init__.py index bffdf9bc..0a888b68 100644 --- a/hive/agents/__init__.py +++ b/hive/agents/__init__.py @@ -1,6 +1,7 @@ from hive.agents import qnets from hive.agents.agent import Agent from hive.agents.dqn import DQNAgent +from hive.agents.drqn import DRQNAgent from hive.agents.legal_moves_rainbow import LegalMovesRainbowAgent from hive.agents.rainbow import RainbowDQNAgent from hive.agents.random import RandomAgent @@ -13,6 +14,7 @@ "LegalMovesRainbowAgent": LegalMovesRainbowAgent, "RainbowDQNAgent": RainbowDQNAgent, "RandomAgent": RandomAgent, + "DRQNAgent": DRQNAgent, }, ) diff --git a/hive/agents/dqn.py b/hive/agents/dqn.py index 9c910993..bcf34231 100644 --- a/hive/agents/dqn.py +++ b/hive/agents/dqn.py @@ -112,7 +112,7 @@ def __init__( self._replay_buffer = replay_buffer if self._replay_buffer is None: self._replay_buffer = CircularReplayBuffer() - self._discount_rate = discount_rate ** n_step + self._discount_rate = discount_rate**n_step self._grad_clip = grad_clip self._reward_clip = reward_clip self._target_net_soft_update = target_net_soft_update diff --git a/hive/agents/drqn.py b/hive/agents/drqn.py new file mode 100644 index 00000000..309b2d36 --- /dev/null +++ b/hive/agents/drqn.py @@ -0,0 +1,288 @@ +import copy +import os + +import numpy as np +import torch + +from hive.agents.agent import Agent +from hive.agents.qnets.base import FunctionApproximator +from hive.agents.qnets.qnet_heads import DQNNetwork +from hive.agents.qnets.utils import ( + InitializationFn, + calculate_output_dim, + create_init_weights_fn, +) +from hive.replays import BaseReplayBuffer, CircularReplayBuffer +from hive.utils.loggers import Logger, NullLogger +from hive.utils.schedule import ( + LinearSchedule, + PeriodicSchedule, + Schedule, + SwitchSchedule, +) +from hive.agents.dqn import DQNAgent +from hive.utils.utils import LossFn, OptimizerFn, create_folder, seeder + + +class DRQNAgent(DQNAgent): + """An agent implementing the DQN algorithm. 
Uses an epsilon greedy + exploration policy + """ + + def __init__( + self, + representation_net: FunctionApproximator, + obs_dim, + act_dim: int, + id=0, + optimizer_fn: OptimizerFn = None, + loss_fn: LossFn = None, + init_fn: InitializationFn = None, + replay_buffer: BaseReplayBuffer = None, + discount_rate: float = 0.99, + n_step: int = 1, + grad_clip: float = None, + reward_clip: float = None, + update_period_schedule: Schedule = None, + target_net_soft_update: bool = False, + target_net_update_fraction: float = 0.05, + target_net_update_schedule: Schedule = None, + epsilon_schedule: Schedule = None, + test_epsilon: float = 0.001, + min_replay_history: int = 5000, + batch_size: int = 32, + device="cpu", + logger: Logger = None, + log_frequency: int = 100, + ): + """ + Args: + representation_net (FunctionApproximator): A network that outputs the + representations that will be used to compute Q-values (e.g. + everything except the final layer of the DQN). + obs_dim: The shape of the observations. + act_dim (int): The number of actions available to the agent. + id: Agent identifier. + optimizer_fn (OptimizerFn): A function that takes in a list of parameters + to optimize and returns the optimizer. If None, defaults to + :py:class:`~torch.optim.Adam`. + loss_fn (LossFn): Loss function used by the agent. If None, defaults to + :py:class:`~torch.nn.SmoothL1Loss`. + init_fn (InitializationFn): Initializes the weights of qnet using + create_init_weights_fn. + replay_buffer (BaseReplayBuffer): The replay buffer that the agent will + push observations to and sample from during learning. If None, + defaults to + :py:class:`~hive.replays.circular_replay.CircularReplayBuffer`. + discount_rate (float): A number between 0 and 1 specifying how much + future rewards are discounted by the agent. + n_step (int): The horizon used in n-step returns to compute TD(n) targets. + grad_clip (float): Gradients will be clipped to between + [-grad_clip, grad_clip]. + reward_clip (float): Rewards will be clipped to between + [-reward_clip, reward_clip]. + update_period_schedule (Schedule): Schedule determining how frequently + the agent's Q-network is updated. + target_net_soft_update (bool): Whether the target net parameters are + replaced by the qnet parameters completely or using a weighted + average of the target net parameters and the qnet parameters. + target_net_update_fraction (float): The weight given to the target + net parameters in a soft update. + target_net_update_schedule (Schedule): Schedule determining how frequently + the target net is updated. + epsilon_schedule (Schedule): Schedule determining the value of epsilon + through the course of training. + test_epsilon (float): epsilon (probability of choosing a random action) + to be used during testing phase. + min_replay_history (int): How many observations to fill the replay buffer + with before starting to learn. + batch_size (int): The size of the batch sampled from the replay buffer + during learning. + device: Device on which all computations should be run. + logger (ScheduledLogger): Logger used to log agent's metrics. + log_frequency (int): How often to log the agent's metrics. 
+ """ + super().__init__( + representation_net=representation_net, + obs_dim=obs_dim, + act_dim=act_dim, + id=0, + optimizer_fn=optimizer_fn, + loss_fn=loss_fn, + init_fn=init_fn, + replay_buffer=replay_buffer, + discount_rate=discount_rate, + n_step=n_step, + grad_clip=grad_clip, + reward_clip=reward_clip, + update_period_schedule=update_period_schedule, + target_net_soft_update=target_net_soft_update, + target_net_update_fraction=target_net_update_fraction, + target_net_update_schedule=target_net_update_schedule, + epsilon_schedule=epsilon_schedule, + test_epsilon=test_epsilon, + min_replay_history=min_replay_history, + batch_size=batch_size, + device=device, + logger=logger, + log_frequency=log_frequency, + ) + + def create_q_networks(self, representation_net): + """Creates the Q-network and target Q-network. + + Args: + representation_net: A network that outputs the representations that will + be used to compute Q-values (e.g. everything except the final layer + of the DQN). + """ + network = representation_net(self._obs_dim) + network_output_dim = np.prod(calculate_output_dim(network, self._obs_dim)) + self._qnet = DQNNetwork( + network, network_output_dim, self._act_dim, use_rnn=True + ).to(self._device) + self._qnet.apply(self._init_fn) + self._target_qnet = copy.deepcopy(self._qnet).requires_grad_(False) + + def preprocess_update_info(self, update_info): + """Preprocesses the :obj:`update_info` before it goes into the replay buffer. + Clips the reward in update_info. + + Args: + update_info: Contains the information from the current timestep that the + agent should use to update itself. + """ + if self._reward_clip is not None: + update_info["reward"] = np.clip( + update_info["reward"], -self._reward_clip, self._reward_clip + ) + preprocessed_update_info = { + "observation": update_info["observation"], + "action": update_info["action"], + "reward": update_info["reward"], + "done": update_info["done"], + } + if "agent_id" in update_info: + preprocessed_update_info["agent_id"] = int(update_info["agent_id"]) + + return preprocessed_update_info + + def preprocess_update_batch(self, batch): + """Preprocess the batch sampled from the replay buffer. + + Args: + batch: Batch sampled from the replay buffer for the current update. + + Returns: + (tuple): + - (tuple) Inputs used to calculate current state values. + - (tuple) Inputs used to calculate next state values + - Preprocessed batch. + """ + for key in batch: + batch[key] = torch.tensor(batch[key], device=self._device) + return (batch["observation"],), (batch["next_observation"],), batch + + @torch.no_grad() + def act(self, observation): + """Returns the action for the agent. If in training mode, follows an epsilon + greedy policy. Otherwise, returns the action with the highest Q-value. + + Args: + observation: The current observation. + """ + + # Determine and log the value of epsilon + if self._training: + if not self._learn_schedule.get_value(): + epsilon = 1.0 + else: + epsilon = self._epsilon_schedule.update() + if self._logger.update_step(self._timescale): + self._logger.log_scalar("epsilon", epsilon, self._timescale) + else: + epsilon = self._test_epsilon + + # Sample action. With epsilon probability choose random action, + # otherwise select the action with the highest q-value. 
+ observation = torch.tensor( + np.expand_dims(observation, axis=0), device=self._device + ).float() + qvals = self._qnet(observation) + if self._rng.random() < epsilon: + action = self._rng.integers(self._act_dim) + else: + # Note: not explicitly handling the ties + action = torch.argmax(qvals).item() + + if ( + self._training + and self._logger.should_log(self._timescale) + and self._state["episode_start"] + ): + self._logger.log_scalar("train_qval", torch.max(qvals), self._timescale) + self._state["episode_start"] = False + return action + + def update(self, update_info): + """ + Updates the DQN agent. + + Args: + update_info: dictionary containing all the necessary information to + update the agent. Should contain a full transition, with keys for + "observation", "action", "reward", and "done". + """ + if update_info["done"]: + self._state["episode_start"] = True + + if not self._training: + return + + # Add the most recent transition to the replay buffer. + self._replay_buffer.add(**self.preprocess_update_info(update_info)) + + # Update the q network based on a sample batch from the replay buffer. + # If the replay buffer doesn't have enough samples, catch the exception + # and move on. + if ( + self._learn_schedule.update() + and self._replay_buffer.size() > 0 + and self._update_period_schedule.update() + ): + batch = self._replay_buffer.sample(batch_size=self._batch_size) + ( + current_state_inputs, + next_state_inputs, + batch, + ) = self.preprocess_update_batch(batch) + + # Compute predicted Q values + self._optimizer.zero_grad() + pred_qvals = self._qnet(*current_state_inputs) + actions = batch["action"].long() + pred_qvals = pred_qvals[torch.arange(pred_qvals.size(0)), actions] + + # Compute 1-step Q targets + next_qvals = self._target_qnet(*next_state_inputs) + next_qvals, _ = torch.max(next_qvals, dim=1) + + q_targets = batch["reward"] + self._discount_rate * next_qvals * ( + 1 - batch["done"] + ) + + loss = self._loss_fn(pred_qvals, q_targets).mean() + + if self._logger.should_log(self._timescale): + self._logger.log_scalar("train_loss", loss, self._timescale) + + loss.backward() + if self._grad_clip is not None: + torch.nn.utils.clip_grad_value_( + self._qnet.parameters(), self._grad_clip + ) + self._optimizer.step() + + # Update target network + if self._target_net_update_schedule.update(): + self._update_target() diff --git a/hive/agents/qnets/__init__.py b/hive/agents/qnets/__init__.py index 232c0d50..48323d99 100644 --- a/hive/agents/qnets/__init__.py +++ b/hive/agents/qnets/__init__.py @@ -3,12 +3,14 @@ from hive.agents.qnets.base import FunctionApproximator from hive.agents.qnets.conv import ConvNetwork from hive.agents.qnets.mlp import MLPNetwork +from hive.agents.qnets.rnn import ConvRNNNetwork registry.register_all( FunctionApproximator, { "MLPNetwork": FunctionApproximator(MLPNetwork), "ConvNetwork": FunctionApproximator(ConvNetwork), + "ConvRNNNetwork": FunctionApproximator(ConvRNNNetwork), "NatureAtariDQNModel": FunctionApproximator(NatureAtariDQNModel), }, ) diff --git a/hive/agents/qnets/qnet_heads.py b/hive/agents/qnets/qnet_heads.py index c70161c8..b88d197d 100644 --- a/hive/agents/qnets/qnet_heads.py +++ b/hive/agents/qnets/qnet_heads.py @@ -15,6 +15,7 @@ def __init__( hidden_dim: int, out_dim: int, linear_fn: nn.Module = None, + use_rnn: bool = False, ): """ Args: @@ -32,9 +33,13 @@ def __init__( self.base_network = base_network self._linear_fn = linear_fn if linear_fn is not None else nn.Linear self.output_layer = self._linear_fn(hidden_dim, out_dim) + 
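        # When use_rnn is True, the wrapped base network returns a (features,
+        # hidden_state) tuple, which forward() unpacks; otherwise it returns a
+        # single feature tensor.
+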
self._use_rnn = use_rnn def forward(self, x): - x = self.base_network(x) + if self._use_rnn: + x, hidden_state = self.base_network(x) + else: + x = self.base_network(x) x = x.flatten(start_dim=1) return self.output_layer(x) @@ -52,6 +57,7 @@ def __init__( out_dim: int, linear_fn: nn.Module = None, atoms: int = 1, + use_rnn: bool = False, ): """ Args: @@ -75,6 +81,7 @@ def __init__( self._atoms = atoms self._linear_fn = linear_fn if linear_fn is not None else nn.Linear self.init_networks() + self._use_rnn = use_rnn def init_networks(self): self.output_layer_adv = self._linear_fn( @@ -84,7 +91,10 @@ def init_networks(self): self.output_layer_val = self._linear_fn(self._hidden_dim, 1 * self._atoms) def forward(self, x): - x = self.base_network(x) + if self._use_rnn: + x, hidden_state = self.base_network(x) + else: + x = self.base_network(x) x = x.flatten(start_dim=1) adv = self.output_layer_adv(x) val = self.output_layer_val(x) @@ -111,6 +121,7 @@ def __init__( vmin: float = 0, vmax: float = 200, atoms: int = 51, + use_rnn: bool = False, ): """ Args: @@ -131,6 +142,7 @@ def __init__( self._supports = torch.nn.Parameter(torch.linspace(vmin, vmax, atoms)) self._out_dim = out_dim self._atoms = atoms + self._use_rnn = use_rnn def forward(self, x): x = self.dist(x) @@ -139,7 +151,10 @@ def forward(self, x): def dist(self, x): """Computes a categorical distribution over values for each action.""" - x = self.base_network(x) + if self._use_rnn: + x, hidden_state = self.base_network(x) + else: + x = self.base_network(x) x = x.view(-1, self._out_dim, self._atoms) x = F.softmax(x, dim=-1) return x diff --git a/hive/agents/qnets/rnn.py b/hive/agents/qnets/rnn.py new file mode 100644 index 00000000..0655e949 --- /dev/null +++ b/hive/agents/qnets/rnn.py @@ -0,0 +1,107 @@ +import numpy as np +import torch +from torch import nn + +from hive.agents.qnets.mlp import MLPNetwork +from hive.agents.qnets.conv import ConvNetwork +from hive.agents.qnets.utils import calculate_output_dim + + +class ConvRNNNetwork(ConvNetwork): + """ + Basic convolutional neural network architecture. Applies a number of + convolutional layers (each followed by a ReLU activation), and then + feeds the output into an :py:class:`hive.agents.qnets.mlp.MLPNetwork`. + + Note, if :obj:`channels` is :const:`None`, the network created for the + convolution portion of the architecture is simply an + :py:class:`torch.nn.Identity` module. If :obj:`mlp_layers` is + :const:`None`, the mlp portion of the architecture is an + :py:class:`torch.nn.Identity` module. + """ + + def __init__( + self, + in_dim, + channels=None, + mlp_layers=None, + kernel_sizes=1, + strides=1, + paddings=0, + normalization_factor=255, + lstm_hidden_size=128, + num_lstm_layers=1, + noisy=False, + std_init=0.5, + ): + """ + Args: + in_dim (tuple): The tuple of observations dimension (channels, width, + height). + channels (list): The size of output channel for each convolutional layer. + mlp_layers (list): The number of neurons for each mlp layer after the + convolutional layers. + kernel_sizes (list | int): The kernel size for each convolutional layer + strides (list | int): The stride used for each convolutional layer. + paddings (list | int): The size of the padding used for each convolutional + layer. + normalization_factor (float | int): What the input is divided by before + the forward pass of the network. + noisy (bool): Whether the MLP part of the network will use + :py:class:`~hive.agents.qnets.noisy_linear.NoisyLinear` layers or + :py:class:`torch.nn.Linear` layers. 
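+            lstm_hidden_size (int): The number of features in the hidden state of
+                the LSTM.
+            num_lstm_layers (int): The number of recurrent LSTM layers stacked on
+                top of the flattened convolutional features.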
+ std_init (float): The range for the initialization of the standard + deviation of the weights in + :py:class:`~hive.agents.qnets.noisy_linear.NoisyLinear`. + """ + super().__init__( + in_dim=in_dim, + channels=channels, + mlp_layers=mlp_layers, + kernel_sizes=kernel_sizes, + strides=strides, + paddings=paddings, + normalization_factor=normalization_factor, + noisy=noisy, + std_init=std_init, + ) + self._lstm_hidden_size = lstm_hidden_size + self._num_lstm_layers = num_lstm_layers + + # RNN Layers + conv_output_size = calculate_output_dim(self.conv, in_dim) + self.lstm = nn.LSTM( + np.prod(conv_output_size), lstm_hidden_size, num_lstm_layers + ) + + if mlp_layers is not None: + # MLP Layers + # conv_output_size = calculate_output_dim(self.conv, in_dim) + self.mlp = MLPNetwork( + lstm_hidden_size, mlp_layers, noisy=noisy, std_init=std_init + ) + else: + self.mlp = nn.Identity() + + def forward(self, x, hidden_state=None): + if len(x.shape) == 3: + x = x.unsqueeze(0) + elif len(x.shape) == 5: + x = x.reshape(x.size(0), -1, x.size(-2), x.size(-1)) + x = x.float() + x = x / self._normalization_factor + x = self.conv(x) + + if hidden_state is None: + hidden_state = ( + torch.zeros( + (self._num_lstm_layers, x.shape[1], self._lstm_hidden_size) + ).float(), + torch.zeros( + (self._num_lstm_layers, x.shape[1], self._lstm_hidden_size) + ).float(), + ) + x = torch.flatten(x, start_dim=-2, end_dim=-1) + x, hidden_state = self.lstm(x, hidden_state) + x = self.mlp(x.squeeze(0)) + return x, hidden_state diff --git a/hive/agents/qnets/utils.py b/hive/agents/qnets/utils.py index 9a42ab6e..49501ad5 100644 --- a/hive/agents/qnets/utils.py +++ b/hive/agents/qnets/utils.py @@ -20,9 +20,13 @@ def calculate_output_dim(net, input_shape): """ if isinstance(input_shape, int): input_shape = (input_shape,) - placeholder = torch.zeros((0,) + tuple(input_shape)) + # placeholder = torch.zeros((0,) + tuple(input_shape)) + placeholder = torch.zeros(tuple(input_shape)) output = net(placeholder) - return output.size()[1:] + if isinstance(output, tuple): + return output[0].size()[1:] + else: + return output.size()[1:] def create_init_weights_fn(initialization_fn): diff --git a/hive/configs/atari/drqn.yml b/hive/configs/atari/drqn.yml new file mode 100644 index 00000000..8acbce13 --- /dev/null +++ b/hive/configs/atari/drqn.yml @@ -0,0 +1,82 @@ +run_name: &run_name 'atari-dqn' +train_steps: 50000000 +test_frequency: 250000 +test_episodes: 10 +max_steps_per_episode: 27000 +stack_size: &stack_size 4 +save_dir: 'experiment' +saving_schedule: + name: 'PeriodicSchedule' + kwargs: + off_value: False + on_value: True + period: 1000000 +environment: + name: 'AtariEnv' + kwargs: + env_name: 'Asterix' + +agent: + name: 'DRQNAgent' + kwargs: + representation_net: + name: 'ConvRNNNetwork' + kwargs: + channels: [32, 64, 64] + kernel_sizes: [8, 4, 3] + strides: [4, 2, 1] + paddings: [2, 2, 1] + mlp_layers: [512] + + optimizer_fn: + name: 'RMSpropTF' + kwargs: + lr: 0.00025 + alpha: .95 + eps: 0.00001 + centered: True + init_fn: + name: 'xavier_uniform' + loss_fn: + name: 'SmoothL1Loss' + replay_buffer: + name: 'CircularReplayBuffer' + kwargs: + capacity: 1000000 + stack_size: *stack_size + gamma: &gamma .99 + discount_rate: *gamma + reward_clip: 1 + update_period_schedule: + name: 'PeriodicSchedule' + kwargs: + off_value: False + on_value: True + period: 4 + target_net_update_schedule: + name: 'PeriodicSchedule' + kwargs: + off_value: False + on_value: True + period: 8000 + epsilon_schedule: + name: 'LinearSchedule' + kwargs: + 
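      # Epsilon is annealed linearly from init_value to end_value over `steps` schedule updates.
+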
init_value: 1.0 + end_value: .01 + steps: 250000 + test_epsilon: .001 + min_replay_history: 20000 + device: 'cuda' + log_frequency: 1000 +# List of logger configs used. +loggers: + - + name: ChompLogger +# - +# name: WandbLogger +# kwargs: +# project: Hive +# name: *run_name +# resume: "allow" +# start_method: "fork" diff --git a/hive/replays/circular_replay.py b/hive/replays/circular_replay.py index dce4343a..75236c9c 100644 --- a/hive/replays/circular_replay.py +++ b/hive/replays/circular_replay.py @@ -71,7 +71,7 @@ def __init__( self._n_step = n_step self._gamma = gamma self._discount = np.asarray( - [self._gamma ** i for i in range(self._n_step)], + [self._gamma**i for i in range(self._n_step)], dtype=self._specs["reward"][0], ) self._episode_start = True diff --git a/hive/replays/prioritized_replay.py b/hive/replays/prioritized_replay.py index 41b725a8..1f4a6cb8 100644 --- a/hive/replays/prioritized_replay.py +++ b/hive/replays/prioritized_replay.py @@ -160,7 +160,7 @@ class SumTree: def __init__(self, capacity: int): self._capacity = capacity self._depth = int(np.ceil(np.log2(capacity))) + 1 - self._tree = np.zeros(2 ** self._depth - 1) + self._tree = np.zeros(2**self._depth - 1) self._last_level_start = 2 ** (self._depth - 1) - 1 self._priorities = self._tree[ self._last_level_start : self._last_level_start + self._capacity diff --git a/tests/hive/replays/test_circular_buffer.py b/tests/hive/replays/test_circular_buffer.py index 7ea5fe6e..9c49540d 100644 --- a/tests/hive/replays/test_circular_buffer.py +++ b/tests/hive/replays/test_circular_buffer.py @@ -301,7 +301,7 @@ def test_n_step_buffer(full_n_step_buffer): assert batch["observation"][i].shape == OBS_SHAPE expected_reward = 0 for delta_t in range(N_STEP_HORIZON): - expected_reward += ((timestep + delta_t) % 10) * (GAMMA ** delta_t) + expected_reward += ((timestep + delta_t) % 10) * (GAMMA**delta_t) if (timestep + delta_t + 1) % 15 == 0: break assert batch["reward"][i] == pytest.approx(expected_reward)
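Usage sketch (illustrative, not part of the patch): the snippet below mirrors what DRQNAgent.create_q_networks does with the new pieces, wiring a ConvRNNNetwork into the DQNNetwork head with use_rnn=True. The observation shape and action count are assumed values (four stacked 84x84 frames, 9 actions); the convolutional and LSTM hyperparameters are taken from hive/configs/atari/drqn.yml. It also assumes a PyTorch version that accepts unbatched convolution/LSTM inputs, since the modified calculate_output_dim builds an unbatched placeholder.

import numpy as np
import torch

from hive.agents.qnets.qnet_heads import DQNNetwork
from hive.agents.qnets.rnn import ConvRNNNetwork
from hive.agents.qnets.utils import calculate_output_dim

obs_dim = (4, 84, 84)  # assumed: four stacked 84x84 Atari frames
act_dim = 9            # assumed: number of actions for the chosen game

# Recurrent representation network with the hyperparameters from drqn.yml.
representation_net = ConvRNNNetwork(
    obs_dim,
    channels=[32, 64, 64],
    kernel_sizes=[8, 4, 3],
    strides=[4, 2, 1],
    paddings=[2, 2, 1],
    mlp_layers=[512],
    lstm_hidden_size=128,
)

# calculate_output_dim keeps only the first element of the
# (features, hidden_state) tuple returned by the recurrent network.
network_output_dim = np.prod(calculate_output_dim(representation_net, obs_dim))

# use_rnn=True tells the Q-network head to unpack the tuple returned by the
# recurrent base network before applying the final linear layer.
qnet = DQNNetwork(representation_net, network_output_dim, act_dim, use_rnn=True)

observation = torch.rand((1,) + obs_dim)
qvals = qnet(observation)  # Q-values computed through conv -> LSTM -> MLP head

In the agent itself, create_q_networks performs this same wiring and then deep-copies the Q-network into the target network with gradients disabled.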