RNN support: DRQN agent + recurrent buffer #258
@@ -1,6 +1,7 @@
```python
import copy
import os

import gym
import numpy as np
import torch
```
@@ -0,0 +1,145 @@
```python
import torch
from torch import nn

from hive.utils.registry import registry, Registrable


class SequenceModule(nn.Module, Registrable):
```
Review comment: I don't really see what the point of this class is. It's not really providing any additional logic. Instead of this, the convrnnnetwork class can just take in a FunctionApproximator for the sequence_fn and create it directly. If you really want to differentiate SequenceModules as a type, it should just subclass Registrable and provide a type name, similar to what FunctionApproximator does. It shouldn't subclass nn.Module.

Reply: Although …

Reply: done
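The reviewer's suggested alternative isn't shown in the thread; a minimal sketch of what it might look like, assuming a Registrable base class used as a pure type marker in the style of FunctionApproximator (hypothetical code, not from the PR):

```python
from hive.utils.registry import Registrable


# Hypothetical sketch: a registry type marker that does not subclass
# nn.Module. Concrete recurrent modules (nn.LSTM wrappers, etc.) are
# created directly by the network class that consumes them.
class SequenceFn(Registrable):
    @classmethod
    def type_name(cls):
        # The registry uses this name to expose a get_sequence_fn getter.
        return "sequence_fn"
```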
```python
    """
    Base sequence neural network architecture.
    """

    def __init__(
        self,
        rnn_input_size=256,
        rnn_hidden_size=128,
        num_rnn_layers=1,
        batch_first=True,
    ):
        """
        Args:
            rnn_input_size (int): The number of expected features in the input x.
            rnn_hidden_size (int): The number of features in the hidden state h.
            num_rnn_layers (int): Number of recurrent layers.
            batch_first (bool): If True, the input and output tensors are
                provided as (batch, seq, feature) instead of
                (seq, batch, feature).
        """
        super().__init__()
        self._rnn_input_size = rnn_input_size
        self._rnn_hidden_size = rnn_hidden_size
        self._num_rnn_layers = num_rnn_layers
        self._batch_first = batch_first
        # Subclasses replace this with a concrete recurrent core
        # (nn.LSTM, nn.GRU).
        self.core = None

    def forward(self, x, hidden_state=None):
        x, hidden_state = self.core(x, hidden_state)
        return x, hidden_state

    @property
    def hidden_size(self):
        return self._rnn_hidden_size

    @classmethod
    def type_name(cls):
        return "sequence_fn"


class LSTMModule(SequenceModule):
    """
    A multi-layer long short-term memory (LSTM) RNN.
    """

    def __init__(
        self,
        rnn_input_size=256,
        rnn_hidden_size=128,
        num_rnn_layers=1,
        batch_first=True,
    ):
        """
        Args:
            rnn_input_size (int): The number of expected features in the input x.
            rnn_hidden_size (int): The number of features in the hidden state h.
            num_rnn_layers (int): Number of recurrent layers.
            batch_first (bool): If True, the input and output tensors are
                provided as (batch, seq, feature) instead of
                (seq, batch, feature).
```
Review comment: line length

Reply: Fixed.
""" | ||
super().__init__( | ||
rnn_input_size=rnn_input_size, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can remove rnn_input_size if the base function does not use it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. SImilarly batch_first is also removed. |
||
rnn_hidden_size=rnn_hidden_size, | ||
num_rnn_layers=num_rnn_layers, | ||
batch_first=batch_first, | ||
) | ||
self.core = nn.LSTM( | ||
input_size=self._rnn_input_size, | ||
hidden_size=self._rnn_hidden_size, | ||
num_layers=self._num_rnn_layers, | ||
batch_first=self._batch_first, | ||
) | ||
|
||
def init_hidden(self, batch_size, device="cpu"): | ||
Review comment: can this device not just be passed once in the initializer?

Reply: Fixed.
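A self-contained sketch of what that fix could look like, with the device stored once at construction time (hypothetical class name; the actual commit may differ):

```python
import torch
from torch import nn


# Hypothetical sketch: the device is given once in __init__, so callers of
# init_hidden no longer need to pass it on every call.
class LSTMWithDevice(nn.Module):
    def __init__(self, input_size=256, hidden_size=128, num_layers=1,
                 device="cpu"):
        super().__init__()
        self._num_layers = num_layers
        self._hidden_size = hidden_size
        self._device = torch.device(device)
        self.core = nn.LSTM(
            input_size, hidden_size, num_layers, batch_first=True
        ).to(self._device)

    def init_hidden(self, batch_size):
        # An LSTM hidden state is an (h_0, c_0) pair of zero tensors.
        shape = (self._num_layers, batch_size, self._hidden_size)
        return (
            torch.zeros(shape, device=self._device),
            torch.zeros(shape, device=self._device),
        )
```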
```python
        # An LSTM's hidden state is a pair of (h_0, c_0) zero tensors, each
        # of shape (num_layers, batch, hidden_size).
        hidden_state = (
            torch.zeros(
                (self._num_rnn_layers, batch_size, self._rnn_hidden_size),
                dtype=torch.float32,
                device=device,
            ),
            torch.zeros(
                (self._num_rnn_layers, batch_size, self._rnn_hidden_size),
                dtype=torch.float32,
                device=device,
            ),
        )

        return hidden_state
```
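For reference, a small usage sketch of LSTMModule as defined above (hypothetical shapes; with batch_first=True the input is (batch, seq, feature)):

```python
import torch

# Hypothetical usage of LSTMModule from this diff.
lstm = LSTMModule(rnn_input_size=256, rnn_hidden_size=128, num_rnn_layers=1)

batch_size, seq_len = 4, 10
x = torch.randn(batch_size, seq_len, 256)  # (batch, seq, feature)
hidden = lstm.init_hidden(batch_size)      # zero-initialized (h_0, c_0)
out, hidden = lstm(x, hidden)              # out has shape (4, 10, 128)
```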
```python
class GRUModule(SequenceModule):
    """
    A multi-layer gated recurrent unit (GRU) RNN.
    """

    def __init__(
        self,
        rnn_input_size=256,
        rnn_hidden_size=128,
        num_rnn_layers=1,
        batch_first=True,
    ):
        """
        Args:
            rnn_input_size (int): The number of expected features in the input x.
            rnn_hidden_size (int): The number of features in the hidden state h.
            num_rnn_layers (int): Number of recurrent layers.
            batch_first (bool): If True, the input and output tensors are
                provided as (batch, seq, feature) instead of
                (seq, batch, feature).
```
Review comment: line length

Reply: Fixed.
```python
        """
        super().__init__(
            rnn_input_size=rnn_input_size,
            rnn_hidden_size=rnn_hidden_size,
            num_rnn_layers=num_rnn_layers,
            batch_first=batch_first,
        )
        self.core = nn.GRU(
            input_size=self._rnn_input_size,
            hidden_size=self._rnn_hidden_size,
            num_layers=self._num_rnn_layers,
            batch_first=self._batch_first,
        )

    def init_hidden(self, batch_size, device="cpu"):
        # Unlike the LSTM, a GRU keeps a single hidden tensor (no cell state).
        hidden_state = torch.zeros(
            (self._num_rnn_layers, batch_size, self._rnn_hidden_size),
            dtype=torch.float32,
            device=device,
        )

        return hidden_state


# Register both modules under the SequenceModule type; the registry then
# exposes a matching getter, retrieved here as get_sequence_fn.
registry.register_all(
    SequenceModule,
    {
        "LSTM": LSTMModule,
        "GRU": GRUModule,
    },
)

get_sequence_fn = getattr(registry, f"get_{SequenceModule.type_name()}")
```
Review comment: See my comment in the sequence_models.py file about this.

Reply: done
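To illustrate the pattern the last hunk relies on, here is a simplified, self-contained sketch of how a registry can expose per-type getters such as get_sequence_fn. This is an illustrative reimplementation under assumed semantics, not RLHive's actual registry code:

```python
# Simplified, illustrative sketch of the registry pattern used above
# (not RLHive's actual implementation).
class Registry:
    def __init__(self):
        self._classes = {}  # type_name -> {registered name -> class}

    def register_all(self, base_class, name_to_class):
        table = self._classes.setdefault(base_class.type_name(), {})
        table.update(name_to_class)

    def __getattr__(self, attr):
        # Dynamically expose get_<type_name> getters, e.g. get_sequence_fn.
        prefix = "get_"
        if attr.startswith(prefix) and attr[len(prefix):] in self._classes:
            table = self._classes[attr[len(prefix):]]
            return lambda name: table[name]
        raise AttributeError(attr)


# Stand-in classes so the sketch runs on its own.
class LSTMStub: ...
class GRUStub: ...


class SequenceStub:
    @classmethod
    def type_name(cls):
        return "sequence_fn"


registry = Registry()
registry.register_all(SequenceStub, {"LSTM": LSTMStub, "GRU": GRUStub})
get_sequence_fn = getattr(registry, f"get_{SequenceStub.type_name()}")
assert get_sequence_fn("LSTM") is LSTMStub
```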