Rnn support: DRQN agent + recurrent buffer #258

Merged: 51 commits, Oct 13, 2022
Changes shown below are from 1 commit.

Commits (51)
981be53
initial commit DRQN implementation
hnekoeiq Mar 11, 2022
0aabcff
initial commit recurrent buffer implementation
hnekoeiq Mar 11, 2022
9d666a8
Merge pull request #256 from chandar-lab/recurrent_buffer
hnekoeiq Mar 12, 2022
f557787
Merge branch 'rnn_support' into recurrent_dqn
hnekoeiq Mar 12, 2022
38af33c
Merge pull request #257 from chandar-lab/recurrent_dqn
hnekoeiq Mar 12, 2022
41b6041
drqn agent working with recurrent buffer
hnekoeiq Mar 12, 2022
d2f205f
fix device
hnekoeiq Mar 15, 2022
1ae9cf0
nstep update for RNN
karthiks1701 Mar 16, 2022
346a46f
handled input obs dimension depending on whether agent is acting or u…
TongTongX Mar 16, 2022
921d336
nstep update for RNN after format
karthiks1701 Mar 16, 2022
d66e8ad
Merge branch 'rnn_support' of https://github.com/chandar-lab/RLHive i…
karthiks1701 Mar 16, 2022
c0f976f
reformatted code with black
TongTongX Mar 16, 2022
cc28dfa
Merge branch 'rnn_support' of github.com:chandar-lab/RLHive into rnn_…
TongTongX Mar 16, 2022
4874d7a
fix bugs with buffer size and device
Mar 18, 2022
cf6cfe4
update device configuration
hnekoeiq Mar 22, 2022
8e935b9
DRQN agent tested + updated doctrings and some cleanups
hnekoeiq Mar 22, 2022
b3c810c
fixing id issue for MARL
Mar 31, 2022
7264b91
formatted with black
Mar 31, 2022
de35f5c
initial commit DRQN implementation
hnekoeiq Mar 11, 2022
4ce0d65
initial commit recurrent buffer implementation
hnekoeiq Mar 11, 2022
d985d7e
drqn agent working with recurrent buffer
hnekoeiq Mar 12, 2022
48135a7
fix device
hnekoeiq Mar 15, 2022
29d7607
nstep update for RNN
karthiks1701 Mar 16, 2022
9b25182
handled input obs dimension depending on whether agent is acting or u…
TongTongX Mar 16, 2022
28178cb
nstep update for RNN after format
karthiks1701 Mar 16, 2022
abe510b
reformatted code with black
TongTongX Mar 16, 2022
8f58643
fix bugs with buffer size and device
Mar 18, 2022
a682dcc
update device configuration
hnekoeiq Mar 22, 2022
89e0c05
DRQN agent tested + updated doctrings and some cleanups
hnekoeiq Mar 22, 2022
67f7dfa
fixing id issue for MARL
Mar 31, 2022
9315f23
formatted with black
Mar 31, 2022
f1a9dd8
updating to new registeration
hnekoeiq Apr 11, 2022
a4e9989
suppor both lstm and gru
hnekoeiq Apr 11, 2022
f6d2976
docstrings, cleanups and adding some utils functions
hnekoeiq Apr 12, 2022
1eb931b
Merge branch 'rnn_support' of https://github.com/chandar-lab/RLHive i…
Apr 25, 2022
058f979
merged changes related to callable objects
TongTongX Jul 15, 2022
1df5a69
added sequence model class
TongTongX Jul 16, 2022
a9b73f1
removed unused parameter in base sequence module class; docstring min…
TongTongX Jul 18, 2022
9eb4d52
Merge branch 'dev' into rnn_support
hnekoeiq Jul 22, 2022
abaa7a8
docstring and other minor changes
TongTongX Aug 2, 2022
104823e
reverted stack_size for other agents
TongTongX Aug 2, 2022
088549b
sequence registrable not inherit torch nn.Module
TongTongX Aug 9, 2022
6c6a87b
format
TongTongX Aug 9, 2022
f441706
DRQN reset hidden state in act(), set device of sequence model in yml…
TongTongX Aug 30, 2022
fcc8e6e
clean up
hnekoeiq Sep 19, 2022
89376b2
Merge branch 'dev' into rnn_support
hnekoeiq Sep 20, 2022
2ffb18f
minors fixes
TongTongX Oct 3, 2022
64a88ac
alphabetical ordering
TongTongX Oct 3, 2022
eedf125
fixed device mismatch between rnn hidden state and placeholder image
Oct 13, 2022
39082aa
add both batch and sequence dim to observation during acting
Oct 13, 2022
9ddf9d2
Merge branch 'dev' into rnn_support
sriyash421 Oct 13, 2022
1 change: 1 addition & 0 deletions hive/agents/drqn.py
@@ -1,6 +1,7 @@
import copy
import os

import gym
import numpy as np
import torch

58 changes: 13 additions & 45 deletions hive/agents/qnets/rnn.py
@@ -5,6 +5,7 @@
from hive.agents.qnets.mlp import MLPNetwork
from hive.agents.qnets.conv import ConvNetwork
from hive.agents.qnets.utils import calculate_output_dim
from hive.agents.qnets.sequence_models import SequenceModule


class ConvRNNNetwork(nn.Module):
@@ -23,22 +24,23 @@ class ConvRNNNetwork(nn.Module):
def __init__(
self,
in_dim,
sequence_fn: SequenceModule,
Review comment (Collaborator): See my comment in the sequence_models.py file about this.
Reply: done

channels=None,
mlp_layers=None,
kernel_sizes=1,
strides=1,
paddings=0,
normalization_factor=255,
rnn_type="lstm",
rnn_hidden_size=128,
num_rnn_layers=1,
noisy=False,
std_init=0.5,
):
"""
Args:
in_dim (tuple): The tuple of observations dimension (channels, width,
height).
sequence_fn (SequenceModule): A sequence neural network that learns
recurrent representation. Usually placed between the convolutional
layers and mlp layers.
channels (list): The size of output channel for each convolutional layer.
mlp_layers (list): The number of neurons for each mlp layer after the
convolutional layers.
@@ -48,9 +50,6 @@ def __init__(
layer.
normalization_factor (float | int): What the input is divided by before
the forward pass of the network.
rnn_type (str): Type of the recurrent layer. For now, we support lstm and gru.
rnn_hidden_size (int): The number of features in the hidden state h.
num_rnn_layers (int): Number of recurrent layers.
noisy (bool): Whether the MLP part of the network will use
:py:class:`~hive.agents.qnets.noisy_linear.NoisyLinear` layers or
:py:class:`torch.nn.Linear` layers.
@@ -59,9 +58,6 @@ def __init__(
:py:class:`~hive.agents.qnets.noisy_linear.NoisyLinear`.
"""
super().__init__()
self._rnn_type = rnn_type
self._rnn_hidden_size = rnn_hidden_size
self._num_rnn_layers = num_rnn_layers
self._normalization_factor = normalization_factor
if channels is not None:
if isinstance(kernel_sizes, int):
@@ -96,27 +92,14 @@ def __init__(

# RNN Layers
conv_output_size = calculate_output_dim(self.conv, in_dim)
if self._rnn_type == "lstm":
self.rnn = nn.LSTM(
np.prod(conv_output_size),
rnn_hidden_size,
num_rnn_layers,
batch_first=True,
)
elif self._rnn_type == "gru":
self.rnn = nn.GRU(
np.prod(conv_output_size),
rnn_hidden_size,
num_rnn_layers,
batch_first=True,
)
else:
raise ValueError("Invalid rnn type: {}".format(self._rnn_type))
self.rnn = sequence_fn(
rnn_input_size=np.prod(conv_output_size),
)

if mlp_layers is not None:
# MLP Layers
self.mlp = MLPNetwork(
rnn_hidden_size, mlp_layers, noisy=noisy, std_init=std_init
self.rnn.hidden_size, mlp_layers, noisy=noisy, std_init=std_init
)
else:
self.mlp = nn.Identity()
@@ -145,24 +128,9 @@ def forward(self, x, hidden_state=None):
return x, hidden_state

def init_hidden(self, batch_size, device="cpu"):
if self._rnn_type == "lstm":
hidden_state = (
torch.zeros(
(self._num_rnn_layers, batch_size, self._rnn_hidden_size),
dtype=torch.float32,
device=device,
),
torch.zeros(
(self._num_rnn_layers, batch_size, self._rnn_hidden_size),
dtype=torch.float32,
device=device,
),
)
elif self._rnn_type == "gru":
hidden_state = torch.zeros(
(self._num_rnn_layers, batch_size, self._rnn_hidden_size),
dtype=torch.float32,
device=device,
)
hidden_state = self.rnn.init_hidden(
batch_size=batch_size,
device=device,
)

return hidden_state
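
A minimal construction sketch to make the new signature concrete (not part of the PR): it assumes RLHive with this branch is installed, and the observation shape, channels, and kernel sizes are illustrative placeholders (the strides, paddings, and mlp_layers mirror the Atari config further down). ConvRNNNetwork fills in rnn_input_size itself from the flattened conv output, so the sequence_fn only needs the recurrent hyperparameters pre-bound.

    # Sketch only: assumes RLHive with this branch installed; shapes are illustrative.
    from functools import partial

    from hive.agents.qnets.rnn import ConvRNNNetwork
    from hive.agents.qnets.sequence_models import LSTMModule

    # Pre-bind everything except rnn_input_size; ConvRNNNetwork supplies that
    # from np.prod(conv_output_size) as shown in the diff above.
    sequence_fn = partial(LSTMModule, rnn_hidden_size=128, num_rnn_layers=1)

    network = ConvRNNNetwork(
        in_dim=(4, 84, 84),        # illustrative (channels, width, height)
        sequence_fn=sequence_fn,
        channels=[32, 64, 64],     # illustrative conv stack
        kernel_sizes=[8, 4, 3],
        strides=[4, 2, 1],
        paddings=[2, 2, 1],
        mlp_layers=[512],
    )

    # Hidden-state initialization is now delegated to the sequence module.
    hidden_state = network.init_hidden(batch_size=32, device="cpu")
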
145 changes: 145 additions & 0 deletions hive/agents/qnets/sequence_models.py
@@ -0,0 +1,145 @@
import torch
from torch import nn

from hive.utils.registry import registry, Registrable


class SequenceModule(nn.Module, Registrable):
Review comment (Collaborator): I don't really see what the point of this class is. It's not really providing any additional logic. I think instead of this, you can just do:

    registry.register_all(
        FunctionApproximator,
        {
            "LSTM": FunctionApproximator(torch.nn.LSTM),
            "GRU": FunctionApproximator(torch.nn.GRU),
        },
    )

In the ConvRNNNetwork class, you can just have it take in a FunctionApproximator for the sequence_fn and create it directly. If you really want to differentiate SequenceModules as a type, you should just have it subclass Registrable and provide a type name, similar to what FunctionApproximator does. It shouldn't subclass nn.Module.

Reply: Although SequenceModule is simply a wrapper for LSTM or GRU for now, we made it a separate class mainly because we can implement our own recurrent module later on.

Reply: done

"""
Base sequence neural network architecture.
"""

def __init__(
self,
rnn_input_size=256,
rnn_hidden_size=128,
num_rnn_layers=1,
batch_first=True,
):
"""
Args:
rnn_input_size (int): The number of expected features in the input x.
rnn_hidden_size (int): The number of features in the hidden state h.
num_rnn_layers (int): Number of recurrent layers.
batch_first (bool): If True, then the input and output tensors are provided as (batch, seq, feature) instead of (seq, batch, feature).
"""
super().__init__()
self._rnn_input_size = rnn_input_size
self._rnn_hidden_size = rnn_hidden_size
self._num_rnn_layers = num_rnn_layers
self._batch_first = batch_first
self.core = None

def forward(self, x, hidden_state=None):
x, hidden_state = self.core(x, hidden_state)
return x, hidden_state

@property
def hidden_size(self):
return self._rnn_hidden_size

@classmethod
def type_name(cls):
return "sequence_fn"


class LSTMModule(SequenceModule):
"""
A multi-layer long short-term memory (LSTM) RNN.
"""

def __init__(
self,
rnn_input_size=256,
rnn_hidden_size=128,
num_rnn_layers=1,
batch_first=True,
):
"""
Args:
rnn_input_size (int): The number of expected features in the input x.
rnn_hidden_size (int): The number of features in the hidden state h.
num_rnn_layers (int): Number of recurrent layers.
batch_first (bool): If True, then the input and output tensors are provided as (batch, seq, feature) instead of (seq, batch, feature).
Review comment (Collaborator): line length
Reply: Fixed.

"""
super().__init__(
rnn_input_size=rnn_input_size,
Review comment (Collaborator, author): You can remove rnn_input_size if the base function does not use it.
Reply: Done. Similarly, batch_first is also removed.

rnn_hidden_size=rnn_hidden_size,
num_rnn_layers=num_rnn_layers,
batch_first=batch_first,
)
self.core = nn.LSTM(
input_size=self._rnn_input_size,
hidden_size=self._rnn_hidden_size,
num_layers=self._num_rnn_layers,
batch_first=self._batch_first,
)

def init_hidden(self, batch_size, device="cpu"):
Review comment (Collaborator): can this device not just be passed once in the initializer?
Reply: Fixed.

hidden_state = (
torch.zeros(
(self._num_rnn_layers, batch_size, self._rnn_hidden_size),
dtype=torch.float32,
device=device,
),
torch.zeros(
(self._num_rnn_layers, batch_size, self._rnn_hidden_size),
dtype=torch.float32,
device=device,
),
)

return hidden_state


class GRUModule(SequenceModule):
"""
A multi-layer gated recurrent unit (GRU) RNN.
"""

def __init__(
self,
rnn_input_size=256,
rnn_hidden_size=128,
num_rnn_layers=1,
batch_first=True,
):
"""
Args:
rnn_input_size (int): The number of expected features in the input x.
rnn_hidden_size (int): The number of features in the hidden state h.
num_rnn_layers (int): Number of recurrent layers.
batch_first (bool): If True, then the input and output tensors are provided as (batch, seq, feature) instead of (seq, batch, feature).
Review comment (Collaborator): line length
Reply: Fixed.

"""
super().__init__(
rnn_input_size=rnn_input_size,
rnn_hidden_size=rnn_hidden_size,
num_rnn_layers=num_rnn_layers,
batch_first=batch_first,
)
self.core = nn.GRU(
input_size=self._rnn_input_size,
hidden_size=self._rnn_hidden_size,
num_layers=self._num_rnn_layers,
batch_first=self._batch_first,
)

def init_hidden(self, batch_size, device="cpu"):
hidden_state = torch.zeros(
(self._num_rnn_layers, batch_size, self._rnn_hidden_size),
dtype=torch.float32,
device=device,
)

return hidden_state


registry.register_all(
SequenceModule,
{
"LSTM": LSTMModule,
"GRU": GRUModule,
},
)

get_sequence_fn = getattr(registry, f"get_{SequenceModule.type_name()}")
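
As a quick sanity check of the two wrappers (a sketch, not taken from the PR's tests; the batch and sequence sizes are arbitrary): init_hidden returns an (h, c) tuple for the LSTM and a single tensor for the GRU, and both run batch-first.

    # Sketch only: assumes RLHive with this branch installed; sizes are arbitrary.
    import torch
    from hive.agents.qnets.sequence_models import LSTMModule, GRUModule

    lstm = LSTMModule(rnn_input_size=64, rnn_hidden_size=128, num_rnn_layers=1)
    hidden = lstm.init_hidden(batch_size=8, device="cpu")  # (h, c) tuple for the LSTM
    x = torch.zeros(8, 10, 64)                             # (batch, seq, features), batch_first=True
    out, hidden = lstm(x, hidden)                          # out: (8, 10, 128)

    gru = GRUModule(rnn_input_size=64, rnn_hidden_size=128)
    h = gru.init_hidden(batch_size=8)                      # single tensor for the GRU
    out, h = gru(x, h)                                     # out: (8, 10, 128)
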
5 changes: 5 additions & 0 deletions hive/configs/atari/drqn.yml
@@ -28,6 +28,11 @@ agent:
strides: [4, 2, 1]
paddings: [2, 2, 1]
mlp_layers: [512]
sequence_fn:
name: 'LSTM'
kwargs:
rnn_hidden_size: 128
num_rnn_layers: 1
optimizer_fn:
name: 'RMSpropTF'
kwargs:
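
For reference, the 'LSTM' name in the sequence_fn entry above is looked up in the SequenceModule registry added in sequence_models.py, and the kwargs block is forwarded to the module's constructor; rnn_input_size is filled in later by ConvRNNNetwork from the conv output size. A rough sketch of what the entry amounts to, assuming that registry resolution:

    # Sketch only: illustrates the name -> class mapping; the actual object is built
    # through RLHive's registry, and rnn_input_size is supplied by ConvRNNNetwork.
    from hive.agents.qnets.sequence_models import LSTMModule

    rnn_module = LSTMModule(rnn_hidden_size=128, num_rnn_layers=1)
    print(rnn_module.hidden_size)  # 128, consumed by the MLP head in ConvRNNNetwork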