Merge pull request #257 from chandar-lab/recurrent_dqn
Initial commit: DRQN implementation
Showing 11 changed files with 429 additions and 11 deletions.
@@ -0,0 +1,288 @@
import copy
import os

import numpy as np
import torch

from hive.agents.agent import Agent
from hive.agents.qnets.base import FunctionApproximator
from hive.agents.qnets.qnet_heads import DQNNetwork
from hive.agents.qnets.utils import (
    InitializationFn,
    calculate_output_dim,
    create_init_weights_fn,
)
from hive.replays import BaseReplayBuffer, CircularReplayBuffer
from hive.utils.loggers import Logger, NullLogger
from hive.utils.schedule import (
    LinearSchedule,
    PeriodicSchedule,
    Schedule,
    SwitchSchedule,
)
from hive.agents.dqn import DQNAgent
from hive.utils.utils import LossFn, OptimizerFn, create_folder, seeder


class DRQNAgent(DQNAgent):
    """An agent implementing the DRQN algorithm (DQN with a recurrent
    Q-network). Uses an epsilon greedy exploration policy.
    """

    def __init__(
        self,
        representation_net: FunctionApproximator,
        obs_dim,
        act_dim: int,
        id=0,
        optimizer_fn: OptimizerFn = None,
        loss_fn: LossFn = None,
        init_fn: InitializationFn = None,
        replay_buffer: BaseReplayBuffer = None,
        discount_rate: float = 0.99,
        n_step: int = 1,
        grad_clip: float = None,
        reward_clip: float = None,
        update_period_schedule: Schedule = None,
        target_net_soft_update: bool = False,
        target_net_update_fraction: float = 0.05,
        target_net_update_schedule: Schedule = None,
        epsilon_schedule: Schedule = None,
        test_epsilon: float = 0.001,
        min_replay_history: int = 5000,
        batch_size: int = 32,
        device="cpu",
        logger: Logger = None,
        log_frequency: int = 100,
    ):
        """
        Args:
            representation_net (FunctionApproximator): A network that outputs the
                representations that will be used to compute Q-values (e.g.
                everything except the final layer of the DQN).
            obs_dim: The shape of the observations.
            act_dim (int): The number of actions available to the agent.
            id: Agent identifier.
            optimizer_fn (OptimizerFn): A function that takes in a list of parameters
                to optimize and returns the optimizer. If None, defaults to
                :py:class:`~torch.optim.Adam`.
            loss_fn (LossFn): Loss function used by the agent. If None, defaults to
                :py:class:`~torch.nn.SmoothL1Loss`.
            init_fn (InitializationFn): Initializes the weights of the Q-network
                using create_init_weights_fn.
            replay_buffer (BaseReplayBuffer): The replay buffer that the agent will
                push observations to and sample from during learning. If None,
                defaults to
                :py:class:`~hive.replays.circular_replay.CircularReplayBuffer`.
            discount_rate (float): A number between 0 and 1 specifying how much
                future rewards are discounted by the agent.
            n_step (int): The horizon used in n-step returns to compute TD(n) targets.
            grad_clip (float): Gradients will be clipped to between
                [-grad_clip, grad_clip].
            reward_clip (float): Rewards will be clipped to between
                [-reward_clip, reward_clip].
            update_period_schedule (Schedule): Schedule determining how frequently
                the agent's Q-network is updated.
            target_net_soft_update (bool): Whether the target net parameters are
                replaced by the qnet parameters completely or using a weighted
                average of the target net parameters and the qnet parameters.
            target_net_update_fraction (float): The weight given to the target
                net parameters in a soft update.
            target_net_update_schedule (Schedule): Schedule determining how frequently
                the target net is updated.
            epsilon_schedule (Schedule): Schedule determining the value of epsilon
                through the course of training.
            test_epsilon (float): epsilon (probability of choosing a random action)
                to be used during the testing phase.
            min_replay_history (int): How many observations to fill the replay buffer
                with before starting to learn.
            batch_size (int): The size of the batch sampled from the replay buffer
                during learning.
            device: Device on which all computations should be run.
            logger (Logger): Logger used to log the agent's metrics.
            log_frequency (int): How often to log the agent's metrics.
        """
        super().__init__(
            representation_net=representation_net,
            obs_dim=obs_dim,
            act_dim=act_dim,
            id=id,
            optimizer_fn=optimizer_fn,
            loss_fn=loss_fn,
            init_fn=init_fn,
            replay_buffer=replay_buffer,
            discount_rate=discount_rate,
            n_step=n_step,
            grad_clip=grad_clip,
            reward_clip=reward_clip,
            update_period_schedule=update_period_schedule,
            target_net_soft_update=target_net_soft_update,
            target_net_update_fraction=target_net_update_fraction,
            target_net_update_schedule=target_net_update_schedule,
            epsilon_schedule=epsilon_schedule,
            test_epsilon=test_epsilon,
            min_replay_history=min_replay_history,
            batch_size=batch_size,
            device=device,
            logger=logger,
            log_frequency=log_frequency,
        )

    def create_q_networks(self, representation_net):
        """Creates the Q-network and target Q-network.

        Args:
            representation_net: A network that outputs the representations that will
                be used to compute Q-values (e.g. everything except the final layer
                of the DQN).
        """
        network = representation_net(self._obs_dim)
        network_output_dim = np.prod(calculate_output_dim(network, self._obs_dim))
        self._qnet = DQNNetwork(
            network, network_output_dim, self._act_dim, use_rnn=True
        ).to(self._device)
        self._qnet.apply(self._init_fn)
        self._target_qnet = copy.deepcopy(self._qnet).requires_grad_(False)
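        # Notes: `calculate_output_dim` infers the torso's flattened feature size
        # by tracing a placeholder input of shape `obs_dim` through `network`, and
        # that size becomes the input width of the Q-value head. `use_rnn=True` is
        # assumed here to tell DQNNetwork (defined elsewhere in hive) to treat the
        # representation as recurrent; see that class for its exact behaviour.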

    def preprocess_update_info(self, update_info):
        """Preprocesses the :obj:`update_info` before it goes into the replay buffer.
        Clips the reward in update_info.

        Args:
            update_info: Contains the information from the current timestep that the
                agent should use to update itself.
        """
        if self._reward_clip is not None:
            update_info["reward"] = np.clip(
                update_info["reward"], -self._reward_clip, self._reward_clip
            )
        preprocessed_update_info = {
            "observation": update_info["observation"],
            "action": update_info["action"],
            "reward": update_info["reward"],
            "done": update_info["done"],
        }
        if "agent_id" in update_info:
            preprocessed_update_info["agent_id"] = int(update_info["agent_id"])

        return preprocessed_update_info

    def preprocess_update_batch(self, batch):
        """Preprocess the batch sampled from the replay buffer.

        Args:
            batch: Batch sampled from the replay buffer for the current update.

        Returns:
            (tuple):
                - (tuple) Inputs used to calculate current state values.
                - (tuple) Inputs used to calculate next state values.
                - Preprocessed batch.
        """
        for key in batch:
            batch[key] = torch.tensor(batch[key], device=self._device)
        return (batch["observation"],), (batch["next_observation"],), batch
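    # Note: the tuples returned by preprocess_update_batch are unpacked directly
    # into the online and target networks in update(), i.e.
    # self._qnet(*current_state_inputs) and self._target_qnet(*next_state_inputs),
    # and every array in the sampled batch is first moved onto the agent's device.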

    @torch.no_grad()
    def act(self, observation):
        """Returns the action for the agent. If in training mode, follows an epsilon
        greedy policy. Otherwise, returns the action with the highest Q-value.

        Args:
            observation: The current observation.
        """

        # Determine and log the value of epsilon.
        if self._training:
            if not self._learn_schedule.get_value():
                epsilon = 1.0
            else:
                epsilon = self._epsilon_schedule.update()
            if self._logger.update_step(self._timescale):
                self._logger.log_scalar("epsilon", epsilon, self._timescale)
        else:
            epsilon = self._test_epsilon

        # Sample an action. With probability epsilon choose a random action,
        # otherwise select the action with the highest Q-value.
        observation = torch.tensor(
            np.expand_dims(observation, axis=0), device=self._device
        ).float()
        qvals = self._qnet(observation)
        if self._rng.random() < epsilon:
            action = self._rng.integers(self._act_dim)
        else:
            # Note: ties are not explicitly handled.
            action = torch.argmax(qvals).item()

        if (
            self._training
            and self._logger.should_log(self._timescale)
            and self._state["episode_start"]
        ):
            self._logger.log_scalar("train_qval", torch.max(qvals), self._timescale)
            self._state["episode_start"] = False

        return action

    def update(self, update_info):
        """
        Updates the DRQN agent.

        Args:
            update_info: dictionary containing all the necessary information to
                update the agent. Should contain a full transition, with keys for
                "observation", "action", "reward", and "done".
        """
        if update_info["done"]:
            self._state["episode_start"] = True

        if not self._training:
            return

        # Add the most recent transition to the replay buffer.
        self._replay_buffer.add(**self.preprocess_update_info(update_info))

        # Update the Q-network on a batch sampled from the replay buffer, but
        # only once learning has started, the buffer is non-empty, and the
        # update-period schedule fires.
        if (
            self._learn_schedule.update()
            and self._replay_buffer.size() > 0
            and self._update_period_schedule.update()
        ):
            batch = self._replay_buffer.sample(batch_size=self._batch_size)
            (
                current_state_inputs,
                next_state_inputs,
                batch,
            ) = self.preprocess_update_batch(batch)

            # Compute predicted Q-values for the actions that were taken.
            self._optimizer.zero_grad()
            pred_qvals = self._qnet(*current_state_inputs)
            actions = batch["action"].long()
            pred_qvals = pred_qvals[torch.arange(pred_qvals.size(0)), actions]

            # Compute bootstrapped Q-targets from the target network.
            next_qvals = self._target_qnet(*next_state_inputs)
            next_qvals, _ = torch.max(next_qvals, dim=1)

            q_targets = batch["reward"] + self._discount_rate * next_qvals * (
                1 - batch["done"]
            )

            loss = self._loss_fn(pred_qvals, q_targets).mean()

            if self._logger.should_log(self._timescale):
                self._logger.log_scalar("train_loss", loss, self._timescale)

            loss.backward()
            if self._grad_clip is not None:
                torch.nn.utils.clip_grad_value_(
                    self._qnet.parameters(), self._grad_clip
                )
            self._optimizer.step()

            # Update the target network.
            if self._target_net_update_schedule.update():
                self._update_target()
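
For readers new to the approach, the following is a self-contained toy sketch, independent of the hive codebase and not part of this pull request, of what a recurrent Q-network computes: an encoder feeds an LSTM that runs over a sequence of observations, and a linear head produces Q-values at every timestep. All names in it (TinyDRQN, obs_dim, act_dim, hidden_size) are illustrative only.

import torch
import torch.nn as nn


class TinyDRQN(nn.Module):
    """Toy recurrent Q-network: encoder -> LSTM -> per-step Q-value head."""

    def __init__(self, obs_dim, act_dim, hidden_size=64):
        super().__init__()
        self.encoder = nn.Linear(obs_dim, hidden_size)
        self.rnn = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, act_dim)

    def forward(self, obs_seq, hidden_state=None):
        # obs_seq has shape (batch, time, obs_dim).
        z = torch.relu(self.encoder(obs_seq))
        z, hidden_state = self.rnn(z, hidden_state)
        # Q-values of shape (batch, time, act_dim), plus the LSTM state that an
        # agent would carry across steps of an episode when acting.
        return self.head(z), hidden_state


net = TinyDRQN(obs_dim=4, act_dim=2)
qvals, state = net(torch.zeros(8, 10, 4))  # 8 sequences of length 10
assert qvals.shape == (8, 10, 2)

Carrying the returned hidden state across successive act() calls within an episode, and resetting it at episode boundaries, is what lets a DRQN integrate information over time in partially observable environments.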