Commit
Merge pull request #257 from chandar-lab/recurrent_dqn
initial commit DRQN implementation
hnekoeiq authored Mar 12, 2022
2 parents 9d666a8 + f557787 commit 38af33c
Showing 11 changed files with 429 additions and 11 deletions.
2 changes: 2 additions & 0 deletions hive/agents/__init__.py
@@ -1,6 +1,7 @@
from hive.agents import qnets
from hive.agents.agent import Agent
from hive.agents.dqn import DQNAgent
from hive.agents.drqn import DRQNAgent
from hive.agents.legal_moves_rainbow import LegalMovesRainbowAgent
from hive.agents.rainbow import RainbowDQNAgent
from hive.agents.random import RandomAgent
@@ -13,6 +14,7 @@
"LegalMovesRainbowAgent": LegalMovesRainbowAgent,
"RainbowDQNAgent": RainbowDQNAgent,
"RandomAgent": RandomAgent,
"DRQNAgent": DRQNAgent,
},
)
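With this registration, the recurrent agent is importable alongside the existing agents and is also resolvable by the string name "DRQNAgent". A minimal sanity check (a sketch assuming the package is installed; the registry's lookup API itself is not part of this diff):

from hive.agents import DQNAgent, DRQNAgent

# DRQNAgent subclasses DQNAgent (see hive/agents/drqn.py below), so it inherits the
# replay buffer, schedule, and optimizer plumbing from the feed-forward agent.
assert issubclass(DRQNAgent, DQNAgent)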

2 changes: 1 addition & 1 deletion hive/agents/dqn.py
@@ -112,7 +112,7 @@ def __init__(
self._replay_buffer = replay_buffer
if self._replay_buffer is None:
self._replay_buffer = CircularReplayBuffer()
- self._discount_rate = discount_rate ** n_step
+ self._discount_rate = discount_rate**n_step
self._grad_clip = grad_clip
self._reward_clip = reward_clip
self._target_net_soft_update = target_net_soft_update
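The change above is whitespace-only (Black-style spacing); the quantity itself is the effective discount applied to the bootstrapped term of an n-step TD target: rewards inside the window are discounted by gamma^i and the bootstrap by gamma^n, which is why the agent caches discount_rate**n_step once. A standalone sketch of that arithmetic (illustrative names, not RLHive's replay internals):

import numpy as np

def n_step_target(rewards, bootstrap_value, discount_rate, done):
    # rewards holds the n rewards inside the window, r_t ... r_{t+n-1}
    n_step = len(rewards)
    discounts = discount_rate ** np.arange(n_step)        # gamma^0 ... gamma^(n-1)
    windowed_return = float(np.sum(discounts * np.asarray(rewards)))
    # The bootstrap term uses gamma^n -- the value cached as self._discount_rate above.
    return windowed_return + (discount_rate ** n_step) * bootstrap_value * (1.0 - done)

# e.g. a 3-step target with gamma = 0.99
print(n_step_target([1.0, 0.0, 1.0], bootstrap_value=2.5, discount_rate=0.99, done=0.0))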
288 changes: 288 additions & 0 deletions hive/agents/drqn.py
@@ -0,0 +1,288 @@
import copy
import os

import numpy as np
import torch

from hive.agents.agent import Agent
from hive.agents.qnets.base import FunctionApproximator
from hive.agents.qnets.qnet_heads import DQNNetwork
from hive.agents.qnets.utils import (
InitializationFn,
calculate_output_dim,
create_init_weights_fn,
)
from hive.replays import BaseReplayBuffer, CircularReplayBuffer
from hive.utils.loggers import Logger, NullLogger
from hive.utils.schedule import (
LinearSchedule,
PeriodicSchedule,
Schedule,
SwitchSchedule,
)
from hive.agents.dqn import DQNAgent
from hive.utils.utils import LossFn, OptimizerFn, create_folder, seeder


class DRQNAgent(DQNAgent):
"""An agent implementing the DQN algorithm. Uses an epsilon greedy
exploration policy
"""

def __init__(
self,
representation_net: FunctionApproximator,
obs_dim,
act_dim: int,
id=0,
optimizer_fn: OptimizerFn = None,
loss_fn: LossFn = None,
init_fn: InitializationFn = None,
replay_buffer: BaseReplayBuffer = None,
discount_rate: float = 0.99,
n_step: int = 1,
grad_clip: float = None,
reward_clip: float = None,
update_period_schedule: Schedule = None,
target_net_soft_update: bool = False,
target_net_update_fraction: float = 0.05,
target_net_update_schedule: Schedule = None,
epsilon_schedule: Schedule = None,
test_epsilon: float = 0.001,
min_replay_history: int = 5000,
batch_size: int = 32,
device="cpu",
logger: Logger = None,
log_frequency: int = 100,
):
"""
Args:
representation_net (FunctionApproximator): A network that outputs the
representations that will be used to compute Q-values (e.g.
everything except the final layer of the DQN).
obs_dim: The shape of the observations.
act_dim (int): The number of actions available to the agent.
id: Agent identifier.
optimizer_fn (OptimizerFn): A function that takes in a list of parameters
to optimize and returns the optimizer. If None, defaults to
:py:class:`~torch.optim.Adam`.
loss_fn (LossFn): Loss function used by the agent. If None, defaults to
:py:class:`~torch.nn.SmoothL1Loss`.
init_fn (InitializationFn): Initializes the weights of qnet using
create_init_weights_fn.
replay_buffer (BaseReplayBuffer): The replay buffer that the agent will
push observations to and sample from during learning. If None,
defaults to
:py:class:`~hive.replays.circular_replay.CircularReplayBuffer`.
discount_rate (float): A number between 0 and 1 specifying how much
future rewards are discounted by the agent.
n_step (int): The horizon used in n-step returns to compute TD(n) targets.
grad_clip (float): Gradients will be clipped to between
[-grad_clip, grad_clip].
reward_clip (float): Rewards will be clipped to between
[-reward_clip, reward_clip].
update_period_schedule (Schedule): Schedule determining how frequently
the agent's Q-network is updated.
target_net_soft_update (bool): Whether the target net parameters are
replaced by the qnet parameters completely or using a weighted
average of the target net parameters and the qnet parameters.
target_net_update_fraction (float): The weight given to the target
net parameters in a soft update.
target_net_update_schedule (Schedule): Schedule determining how frequently
the target net is updated.
epsilon_schedule (Schedule): Schedule determining the value of epsilon
through the course of training.
test_epsilon (float): The epsilon (probability of choosing a random action)
used during the testing phase.
min_replay_history (int): How many observations to fill the replay buffer
with before starting to learn.
batch_size (int): The size of the batch sampled from the replay buffer
during learning.
device: Device on which all computations should be run.
logger (Logger): Logger used to log the agent's metrics.
log_frequency (int): How often to log the agent's metrics.
"""
super().__init__(
representation_net=representation_net,
obs_dim=obs_dim,
act_dim=act_dim,
id=id,
optimizer_fn=optimizer_fn,
loss_fn=loss_fn,
init_fn=init_fn,
replay_buffer=replay_buffer,
discount_rate=discount_rate,
n_step=n_step,
grad_clip=grad_clip,
reward_clip=reward_clip,
update_period_schedule=update_period_schedule,
target_net_soft_update=target_net_soft_update,
target_net_update_fraction=target_net_update_fraction,
target_net_update_schedule=target_net_update_schedule,
epsilon_schedule=epsilon_schedule,
test_epsilon=test_epsilon,
min_replay_history=min_replay_history,
batch_size=batch_size,
device=device,
logger=logger,
log_frequency=log_frequency,
)

def create_q_networks(self, representation_net):
"""Creates the Q-network and target Q-network.
Args:
representation_net: A network that outputs the representations that will
be used to compute Q-values (e.g. everything except the final layer
of the DQN).
"""
network = representation_net(self._obs_dim)
network_output_dim = np.prod(calculate_output_dim(network, self._obs_dim))
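# Note: use_rnn=True below is what marks this head as recurrent; how the hidden
# state is threaded through the network is handled inside DQNNetwork and is not
# visible in this diff.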
self._qnet = DQNNetwork(
network, network_output_dim, self._act_dim, use_rnn=True
).to(self._device)
self._qnet.apply(self._init_fn)
self._target_qnet = copy.deepcopy(self._qnet).requires_grad_(False)

def preprocess_update_info(self, update_info):
"""Preprocesses the :obj:`update_info` before it goes into the replay buffer.
Clips the reward in update_info.
Args:
update_info: Contains the information from the current timestep that the
agent should use to update itself.
"""
if self._reward_clip is not None:
update_info["reward"] = np.clip(
update_info["reward"], -self._reward_clip, self._reward_clip
)
preprocessed_update_info = {
"observation": update_info["observation"],
"action": update_info["action"],
"reward": update_info["reward"],
"done": update_info["done"],
}
if "agent_id" in update_info:
preprocessed_update_info["agent_id"] = int(update_info["agent_id"])

return preprocessed_update_info

def preprocess_update_batch(self, batch):
"""Preprocess the batch sampled from the replay buffer.
Args:
batch: Batch sampled from the replay buffer for the current update.
Returns:
(tuple):
- (tuple) Inputs used to calculate current state values.
- (tuple) Inputs used to calculate next state values.
- Preprocessed batch.
"""
for key in batch:
batch[key] = torch.tensor(batch[key], device=self._device)
return (batch["observation"],), (batch["next_observation"],), batch

@torch.no_grad()
def act(self, observation):
"""Returns the action for the agent. If in training mode, follows an epsilon
greedy policy. Otherwise, returns the action with the highest Q-value.
Args:
observation: The current observation.
"""

# Determine and log the value of epsilon
if self._training:
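# While the learn schedule is still off (presumably while the replay buffer is
# being seeded up to min_replay_history), act uniformly at random.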
if not self._learn_schedule.get_value():
epsilon = 1.0
else:
epsilon = self._epsilon_schedule.update()
if self._logger.update_step(self._timescale):
self._logger.log_scalar("epsilon", epsilon, self._timescale)
else:
epsilon = self._test_epsilon

# Sample action. With epsilon probability choose random action,
# otherwise select the action with the highest q-value.
observation = torch.tensor(
np.expand_dims(observation, axis=0), device=self._device
).float()
qvals = self._qnet(observation)
if self._rng.random() < epsilon:
action = self._rng.integers(self._act_dim)
else:
# Note: not explicitly handling the ties
action = torch.argmax(qvals).item()

if (
self._training
and self._logger.should_log(self._timescale)
and self._state["episode_start"]
):
self._logger.log_scalar("train_qval", torch.max(qvals), self._timescale)
self._state["episode_start"] = False
return action

def update(self, update_info):
"""
Updates the DRQN agent.
Args:
update_info: dictionary containing all the necessary information to
update the agent. Should contain a full transition, with keys for
"observation", "action", "reward", and "done".
"""
if update_info["done"]:
self._state["episode_start"] = True

if not self._training:
return

# Add the most recent transition to the replay buffer.
self._replay_buffer.add(**self.preprocess_update_info(update_info))

# Update the q network based on a sample batch from the replay buffer,
# but only once learning has started, the buffer is non-empty, and the
# update period schedule fires.
if (
self._learn_schedule.update()
and self._replay_buffer.size() > 0
and self._update_period_schedule.update()
):
batch = self._replay_buffer.sample(batch_size=self._batch_size)
(
current_state_inputs,
next_state_inputs,
batch,
) = self.preprocess_update_batch(batch)

# Compute predicted Q values
self._optimizer.zero_grad()
pred_qvals = self._qnet(*current_state_inputs)
actions = batch["action"].long()
pred_qvals = pred_qvals[torch.arange(pred_qvals.size(0)), actions]

# Compute 1-step Q targets
next_qvals = self._target_qnet(*next_state_inputs)
next_qvals, _ = torch.max(next_qvals, dim=1)

q_targets = batch["reward"] + self._discount_rate * next_qvals * (
1 - batch["done"]
)

loss = self._loss_fn(pred_qvals, q_targets).mean()

if self._logger.should_log(self._timescale):
self._logger.log_scalar("train_loss", loss, self._timescale)

loss.backward()
if self._grad_clip is not None:
torch.nn.utils.clip_grad_value_(
self._qnet.parameters(), self._grad_clip
)
self._optimizer.step()

# Update target network
if self._target_net_update_schedule.update():
self._update_target()
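For readers skimming the new file, the core numerics of act() and update() can be reproduced in isolation. A self-contained sketch with dummy tensors (shapes are illustrative; the real agent gets these values from its Q-networks and replay buffer):

import numpy as np
import torch

rng = np.random.default_rng(0)
batch_size, act_dim, discount, epsilon = 4, 6, 0.99, 0.1

# Epsilon-greedy selection, as in DRQNAgent.act()
qvals = torch.randn(1, act_dim)
if rng.random() < epsilon:
    action = int(rng.integers(act_dim))
else:
    action = int(torch.argmax(qvals).item())

# 1-step targets and smooth-L1 loss, as in DRQNAgent.update()
pred_qvals = torch.randn(batch_size, act_dim)    # online net, Q(s, .)
next_qvals = torch.randn(batch_size, act_dim)    # target net, Q(s', .)
actions = torch.randint(act_dim, (batch_size,))
rewards = torch.randn(batch_size)
done = torch.zeros(batch_size)                   # 1.0 where the episode ended

chosen_q = pred_qvals[torch.arange(batch_size), actions]
max_next_q, _ = torch.max(next_qvals, dim=1)
q_targets = rewards + discount * max_next_q * (1 - done)
loss = torch.nn.SmoothL1Loss()(chosen_q, q_targets)

Hidden-state handling for the recurrent head sits behind the use_rnn=True flag passed to DQNNetwork in create_q_networks and is not visible in this file.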
2 changes: 2 additions & 0 deletions hive/agents/qnets/__init__.py
@@ -3,12 +3,14 @@
from hive.agents.qnets.base import FunctionApproximator
from hive.agents.qnets.conv import ConvNetwork
from hive.agents.qnets.mlp import MLPNetwork
from hive.agents.qnets.rnn import ConvRNNNetwork

registry.register_all(
FunctionApproximator,
{
"MLPNetwork": FunctionApproximator(MLPNetwork),
"ConvNetwork": FunctionApproximator(ConvNetwork),
"ConvRNNNetwork": FunctionApproximator(ConvRNNNetwork),
"NatureAtariDQNModel": FunctionApproximator(NatureAtariDQNModel),
},
)
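Because the import at the top of this hunk re-exports the class, the new network can also be pulled in directly; its implementation lives in hive/agents/qnets/rnn.py, presumably one of the changed files not excerpted here:

# The registration above additionally makes the class resolvable by the name
# "ConvRNNNetwork", mirroring the existing MLPNetwork and ConvNetwork entries.
from hive.agents.qnets import ConvRNNNetwork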