Skip to content

Commit

Permalink
Hopefully fix CUDA again???
Browse files Browse the repository at this point in the history
  • Loading branch information
RedTachyon committed Sep 9, 2023
1 parent 2118e01 commit 9694479
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion coltra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@
from coltra.buffers import Action, Observation
from coltra import utils

__version__ = "0.1.8"
__version__ = "0.1.9"
VERSION = __version__
6 changes: 3 additions & 3 deletions coltra/models/base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,19 +297,19 @@ def latent(self, x: Tensor, hidden_state: Tuple[Tensor, Tensor]) -> Tensor:
return x

def get_initial_state(
self, batch_size: int = 1, requires_grad: bool = True
self, batch_size: int = 1, requires_grad: bool = True, device: str = "cpu"
) -> Tuple[Tensor, Tensor]:
return (
torch.zeros(
batch_size,
self.lstm_hidden_size,
requires_grad=requires_grad,
device=self.device,
device=device,
),
torch.zeros(
batch_size,
self.lstm_hidden_size,
requires_grad=requires_grad,
device=self.device,
device=device,
),
)
4 changes: 2 additions & 2 deletions coltra/models/lstm_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ def get_initial_state(
self, batch_size: int = 1, requires_grad: bool = True
) -> Tuple:
return self.policy_network.get_initial_state(
batch_size, requires_grad
), self.value_network.get_initial_state(batch_size, requires_grad)
batch_size, requires_grad, self.device
), self.value_network.get_initial_state(batch_size, requires_grad, self.device)


class FlattenLSTMModel(LSTMModel):
Expand Down

0 comments on commit 9694479

Please sign in to comment.