Rnn support: DRQN agent + recurrent buffer #258

Merged — hnekoeiq merged 51 commits into dev from rnn_support on Oct 13, 2022
Conversation

hnekoeiq

No description provided.

self.size() + self._max_seq_len + self._n_step - 1
)
elements = array[full_indices]
elements = elements.reshape(indices.shape[0], -1, *elements.shape[2:])


We want to ensure that the dimension of the observation is (batch_size, seq_length, C, H, W). C=1 according to https://github.com/chandar-lab/RLHive/blame/main/hive/envs/atari/atari.py#L62.
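For concreteness, here is a toy sketch (all shapes and index values are assumed purely for illustration) of how the gather-and-reshape above produces a (batch_size, seq_length, C, H, W) batch:

```python
import numpy as np

# Toy buffer of 100 Atari frames stored as 1 x 84 x 84, from which we gather
# 4 windows of 10 consecutive frames (all sizes assumed for illustration).
array = np.zeros((100, 1, 84, 84))
indices = np.array([0, 20, 40, 60])              # window start positions
full_indices = indices[:, None] + np.arange(10)  # (4, 10) gather indices
elements = array[full_indices]                   # (4, 10, 1, 84, 84)
elements = elements.reshape(indices.shape[0], -1, *elements.shape[2:])
assert elements.shape == (4, 10, 1, 84, 84)      # (batch, seq, C, H, W)
```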


Functions and code reused from CircularReplayBuffer? Maybe just use the inherited functions? Or are there changes across all the functions?


We are using the inherited functions defined in CircularReplayBuffer except for the ones where max_seq_len and stack_size differ.


Can something like super calls be done, e.g. as in the sample function of PPOReplayBuffer? That is, the input is updated so that it can be passed to the function of the parent class.


There will be a new PR that implements a recurrent replay buffer that saves trajectories rather than sequences of transitions. It will be more suitable for recurrent DQN. For now, I think it's fine to keep those functions where max_seq_len and stack_size are different.

@@ -200,7 +200,7 @@ def sample(self, batch_size):
trajectory_lengths = (
np.argmax(terminals.astype(bool), axis=1) + 1
) * is_terminal + self._n_step * (1 - is_terminal)
is_terminal = terminals[:,1:self._n_step-1]

Because of the sequence length, is_terminal has to be of shape (B*seq_length). Hence this particular change after calculating trajectory lengths.
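As a toy illustration of the trajectory-length line quoted in the diff above (the terminal flags and n_step value are assumed data, not taken from the PR):

```python
import numpy as np

n_step = 3
# Two sampled n-step windows: the first hits no terminal, the second
# terminates at its second step.
terminals = np.array([[0, 0, 0],
                      [0, 1, 0]])
is_terminal = terminals.any(axis=1).astype(int)
trajectory_lengths = (
    np.argmax(terminals.astype(bool), axis=1) + 1
) * is_terminal + n_step * (1 - is_terminal)
print(trajectory_lengths)  # [3 2]: full n_step vs. truncated at the terminal
```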

] + np.arange(self._n_step)
disc_rewards = np.einsum(
"ijk,k->ij", rewards[:, idx], self._discount
)
rewards = disc_rewards

@hnekoeiq hnekoeiq requested review from dapatil211 and a team March 16, 2022 18:08

@dapatil211 dapatil211 left a comment


I went through and gave a quick review. I didn't thoroughly check for correctness yet; ideally I'd want to see experiments that show it's working before doing that. Also, please write updated docstrings.

Comment on lines 43 to 46
"""Implements the standard DQN value computation. Transforms output from
:obj:`base_network` with output dimension :obj:`hidden_dim` to dimension
:obj:`out_dim`, which should be equal to the number of actions.
"""

Update docstring


done

Comment on lines 57 to 58
base_network (torch.nn.Module): Backbone network that computes the
representations that are used to compute action values.

Update docstring


You should say something about the expected output of this base_network


Please update the documentation according to the previous comment, specifically that base_network returns two things, and what those two things are.


done

Comment on lines 28 to 29
"""An agent implementing the DQN algorithm. Uses an epsilon greedy
exploration policy

Docstring


done


if mlp_layers is not None:
# MLP Layers
# conv_output_size = calculate_output_dim(self.conv, in_dim)

delete


Done

@@ -20,9 +20,12 @@ def calculate_output_dim(net, input_shape):
"""
if isinstance(input_shape, int):
input_shape = (input_shape,)
placeholder = torch.zeros((0,) + tuple(input_shape))
placeholder = torch.zeros((1,) + tuple(input_shape)).to(device)

Device shouldn't need to be passed in, right? At this point everything should still be on the CPU. I am fine with the change, just not sure why it's necessary.


@dapatil211 The problem is that here we either need to create a dummy hidden_state on the CPU and pass it to the network, or add an if condition in the forward function of ConvRNNNetwork to handle hidden_state = None. We did the latter, but there might be cases other than calculate_output_dim where hidden_state is None, and in those cases the hidden_state created inside the forward function should be on self._device. (We also had to change the empty placeholder because the LSTM was not able to process it.)
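A minimal sketch of the hidden_state=None guard described in this comment, assuming an LSTM-based module in the style of ConvRNNNetwork (class name and attributes are illustrative, not the actual RLHive code):

```python
import torch
import torch.nn as nn


class ConvRNNSketch(nn.Module):
    """Illustrative only: guards against hidden_state=None inside forward."""

    def __init__(self, input_size, hidden_size, device="cpu"):
        super().__init__()
        self._device = torch.device(device)
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)

    def forward(self, x, hidden_state=None):
        if hidden_state is None:
            # Zero hidden/cell states are created on the module's device, so
            # callers that pass no hidden state (e.g. calculate_output_dim)
            # still work.
            zeros = torch.zeros(
                1, x.size(0), self.lstm.hidden_size, device=self._device
            )
            hidden_state = (zeros, zeros)
        return self.lstm(x, hidden_state)
```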

fixed

output = net(placeholder)
return output.size()[1:]
if isinstance(output, tuple):

If it's a tuple, you should return the size of each output. It doesn't make sense to only handle the first output; that's just specific to your current use case.


@dapatil211 This one needs a bit more discussion. Since hidden_state (the second item in the tuple) is a tuple itself, should we check the items inside it and return their sizes as well?


Probably should be recursive? Keep going until you hit scalars or tensors
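A minimal sketch of the recursive approach suggested here (a sketch only, not the final implementation merged in the PR):

```python
import torch


def calculate_output_dim_recursive(net, input_shape, device="cpu"):
    """Report per-sample output sizes, descending into tuples of outputs."""
    if isinstance(input_shape, int):
        input_shape = (input_shape,)
    placeholder = torch.zeros((1,) + tuple(input_shape), device=device)
    output = net(placeholder)

    def sizes(out):
        # Mirror the tuple structure of the output until tensors are reached.
        if isinstance(out, tuple):
            return tuple(sizes(o) for o in out)
        return out.size()[1:]

    return sizes(output)
```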

Comment on lines 250 to 251
# mask = np.expand_dims(trajectory_lengths, 1) > np.arange(self._n_step)
# rewards = np.sum(rewards * mask, axis=1)

delete


done

idx = np.arange(rewards.shape[1] - self._n_step + 1)[
:, None
] + np.arange(self._n_step)
disc_rewards = np.einsum(

Could you add a comment explaining what this is doing? einsum is notoriously hard to interpret.


Sure will do!
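For reference, a toy sketch (reward values assumed) of what the sliding-window einsum computes: the n-step discounted return at every position of each sampled sequence, equivalent to np.sum(windows * discount[None, None, :], axis=2).

```python
import numpy as np

n_step, gamma = 3, 0.99
discount = gamma ** np.arange(n_step)             # [1, gamma, gamma^2]
rewards = np.array([[1.0, 0.0, 0.0, 2.0, 0.0]])   # B x S toy rewards

# Sliding windows of length n_step over the sequence dimension.
idx = np.arange(rewards.shape[1] - n_step + 1)[:, None] + np.arange(n_step)
windows = rewards[:, idx]                         # B x (S-N+1) x N

# "ijk,k->ij": dot each window with the discount vector.
disc_rewards = np.einsum("ijk,k->ij", windows, discount)
print(disc_rewards)                               # [[1.     1.9602 1.98  ]]
```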


representation_net=representation_net,
obs_dim=obs_dim,
act_dim=act_dim,
id=0,

@karthiks1701 karthiks1701 Mar 29, 2022


I think we missed id=id to pass the id to the base DQN class. Because of this, in marlgrid all the agents have the same id, which is causing problems.

@dapatil211 dapatil211 added the major Major PRs label May 20, 2022
conv_output_size = calculate_output_dim(self.conv, in_dim)
if self._rnn_type == "lstm":
self.rnn = nn.LSTM(
np.prod(conv_output_size),


Consider changing the different types of networks to a sequencer class.


Done.

"observation", "action", "reward", and "done".
"""
if update_info["done"]:
self._state["episode_start"] = True


Is episode_start used anywhere except the buffer? Because the buffer takes care of it, it seems redundant.


This is there for the logic in act(). There's probably a better way to do it; it doesn't really make sense for the agent to handle this. It might make sense to add it as part of the observation, but I need to think about this a bit.


We probably want to fix it in a separate PR as this is what DQN does too.


We created an issue for fixing it in both DQN and DRQN.


I think this part is fine. I don't think it needs fixing.

batch_first (bool): If True, then the input and output tensors are provided as (batch, seq, feature) instead of (seq, batch, feature).
"""
super().__init__(
rnn_input_size=rnn_input_size,

You can remove rnn_input_size if the base function does not use it.


Done. Similarly, batch_first has also been removed.

self._optimizer.zero_grad()
pred_qvals, hidden_state = self._qnet(*current_state_inputs, hidden_state)
pred_qvals = pred_qvals.view(
self._batch_size, self._replay_buffer._max_seq_len, -1

Should not be accessing internal variable of buffer


done

@@ -16,6 +17,7 @@
"LegalMovesRainbowAgent": LegalMovesRainbowAgent,
"RainbowDQNAgent": RainbowDQNAgent,
"RandomAgent": RandomAgent,
"DRQNAgent": DRQNAgent,

Alphabetical order please.


done

Comment on lines 1 to 26
import copy
import os

import gym
import numpy as np
import torch

from hive.agents.agent import Agent
from hive.agents.qnets.base import FunctionApproximator
from hive.agents.qnets.qnet_heads import DRQNNetwork
from hive.agents.qnets.utils import (
InitializationFn,
calculate_output_dim,
create_init_weights_fn,
)
from hive.replays import BaseReplayBuffer, CircularReplayBuffer
from hive.replays.recurrent_replay import RecurrentReplayBuffer
from hive.utils.loggers import Logger, NullLogger
from hive.utils.schedule import (
LinearSchedule,
PeriodicSchedule,
Schedule,
SwitchSchedule,
)
from hive.agents.dqn import DQNAgent
from hive.utils.utils import LossFn, OptimizerFn, create_folder, seeder

Make sure you run import sorting (isort)


done
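For reference, the import block above would look roughly like this after isort (the exact grouping depends on the configured isort profile):

```python
import copy
import os

import gym
import numpy as np
import torch

from hive.agents.agent import Agent
from hive.agents.dqn import DQNAgent
from hive.agents.qnets.base import FunctionApproximator
from hive.agents.qnets.qnet_heads import DRQNNetwork
from hive.agents.qnets.utils import (
    InitializationFn,
    calculate_output_dim,
    create_init_weights_fn,
)
from hive.replays import BaseReplayBuffer, CircularReplayBuffer
from hive.replays.recurrent_replay import RecurrentReplayBuffer
from hive.utils.loggers import Logger, NullLogger
from hive.utils.schedule import (
    LinearSchedule,
    PeriodicSchedule,
    Schedule,
    SwitchSchedule,
)
from hive.utils.utils import LossFn, OptimizerFn, create_folder, seeder
```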

Args:
observation_space (gym.spaces.Box): Observation space for the agent.
action_space (gym.spaces.Discrete): Action space for the agent.
representation_net (FunctionApproximator): A network that outputs the

I am assuming there are restrictions on the representation_net, e.g. it needs to be one of your recurrent ones? Please mention this in the documentation.


done


Can you more explicitly mention the restrictions on representation_net? For example, which methods it should have or that it should follow the structure of ConvRNNNetwork or something?


done

Comment on lines 69 to 70
stack_size: Number of observations stacked to create the state fed to the
DRQN.

What does this mean in the context of DRQN? You aren't using this, are you? If not, remove it. You may need to keep a dummy argument or add **kwargs to make it work with the runner.


done

def __init__(
self,
in_dim,
sequence_fn: SequenceModule,

See my comment in the sequence_models.py file about this.


done

Comment on lines +15 to +26
capacity: int = 10000,
max_seq_len: int = 1,
n_step: int = 1,
gamma: float = 0.99,
observation_shape=(),
observation_dtype=np.uint8,
action_shape=(),
action_dtype=np.int8,
reward_shape=(),
reward_dtype=np.float32,
extra_storage_types=None,
num_players_sharing_buffer: int = None,

Why is the spacing inconsistent?


If we specify the data type annotation, black will add whitespace around the equals sign; otherwise it will remove it.

capacity (int): Total number of observations that can be stored in the
buffer. Note, this is not the same as the number of transitions that
can be stored in the buffer.
max_seq_len (int): The number of consecutive transitions in a sequence.

Please clarify the description of this parameter.


done

Comment on lines +58 to +73
super().__init__(
capacity=capacity,
stack_size=1,
n_step=n_step,
gamma=gamma,
observation_shape=observation_shape,
observation_dtype=observation_dtype,
action_shape=action_shape,
action_dtype=action_dtype,
reward_shape=reward_shape,
reward_dtype=reward_dtype,
extra_storage_types=extra_storage_types,
num_players_sharing_buffer=num_players_sharing_buffer,
)
self._max_seq_len = max_seq_len


Would this class be equivalent to just creating CircularReplayBuffer(stack_size=max_seq_len)? If so, why not just do that? If not, what/where is the change in logic?


There are differences in sample().

self._n_step
) # (S-N+1) x N
rewards = rewards[:, idx] # B x (S-N+1) x N
# Creating a vectorized sliding window to calculate discounted returns for every element in the sequence

Please make sure commented lines and lines that are part of docstrings are still under 88 characters.

@@ -153,7 +153,6 @@ def set_up_experiment(config):
agent = agent_fn(
observation_space=env_spec.observation_space[0],
action_space=env_spec.action_space[0],
stack_size=config.get("stack_size", 1),

This shouldn't be removed. It will break other agents.


done

"observation", "action", "reward", and "done".
"""
if update_info["done"]:
self._state["episode_start"] = True

I think this part is fine. I don't think it needs fixing.

Comment on lines 216 to 218
self._hidden_state = self._qnet.base_network.init_hidden(
batch_size=1, device=self._device
)

I don't think this is the same thing. episode_start can only be set in update() because that's when the agent knows that the episode ended. My comment is about how the resetting of the hidden state should be done in act().
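A self-contained toy sketch (not RLHive code; all names assumed) of the pattern being discussed, where update() only flags the episode end and act() performs the actual hidden-state reset:

```python
import torch
import torch.nn as nn


class ToyRecurrentAgent:
    """Toy illustration of resetting the recurrent state in act()."""

    def __init__(self, obs_dim, hidden_size, num_actions):
        self._rnn = nn.LSTM(obs_dim, hidden_size, batch_first=True)
        self._head = nn.Linear(hidden_size, num_actions)
        self._hidden_state = None
        self._episode_start = True

    @torch.no_grad()
    def act(self, observation):
        if self._episode_start:
            # The hidden state is reset here, when acting in a new episode.
            self._hidden_state = None
            self._episode_start = False
        x = torch.as_tensor(observation, dtype=torch.float32).view(1, 1, -1)
        out, self._hidden_state = self._rnn(x, self._hidden_state)
        return self._head(out[:, -1]).argmax(dim=-1).item()

    def update(self, update_info):
        # update() is still the only place the agent learns the episode ended.
        if update_info["done"]:
            self._episode_start = True
```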

rnn_input_size (int): The number of expected features in the input x.
rnn_hidden_size (int): The number of features in the hidden state h.
num_rnn_layers (int): Number of recurrent layers.
batch_first (bool): If True, then the input and output tensors are provided as (batch, seq, feature) instead of (seq, batch, feature).

line length


Fixed.

rnn_input_size (int): The number of expected features in the input x.
rnn_hidden_size (int): The number of features in the hidden state h.
num_rnn_layers (int): Number of recurrent layers.
batch_first (bool): If True, then the input and output tensors are provided as (batch, seq, feature) instead of (seq, batch, feature).

line length


Fixed.

batch_first=batch_first,
)

def init_hidden(self, batch_size, device="cpu"):

can this device not just be passed once in the initializer?


Fixed.
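A minimal sketch of the suggested change, with the device stored once at construction time so init_hidden no longer needs a per-call device argument (class and attribute names are assumed, not the merged code):

```python
import torch
import torch.nn as nn


class LSTMSequenceSketch(nn.Module):
    """Illustrative only: the device is stored once in the constructor."""

    def __init__(
        self, rnn_input_size, rnn_hidden_size, num_rnn_layers=1, device="cpu"
    ):
        super().__init__()
        self._device = torch.device(device)
        self._rnn_hidden_size = rnn_hidden_size
        self._num_rnn_layers = num_rnn_layers
        self.rnn = nn.LSTM(
            rnn_input_size, rnn_hidden_size, num_rnn_layers, batch_first=True
        )

    def init_hidden(self, batch_size):
        # No per-call device argument; use the one stored in __init__.
        shape = (self._num_rnn_layers, batch_size, self._rnn_hidden_size)
        return (
            torch.zeros(shape, device=self._device),
            torch.zeros(shape, device=self._device),
        )
```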

Comment on lines 136 to 145
if replay_buffer is None:
replay_buffer = RecurrentReplayBuffer
self._replay_buffer = replay_buffer(
max_seq_len=max_seq_len,
observation_shape=self._observation_space.shape,
observation_dtype=self._observation_space.dtype,
action_shape=self._action_space.shape,
action_dtype=self._action_space.dtype,
)
self._max_seq_len = max_seq_len

This can be moved above the super constructor call:

        if replay_buffer is None:
            replay_buffer = RecurrentReplayBuffer
        replay_buffer = partial(replay_buffer, max_seq_len=max_seq_len)


done

replay_buffer (BaseReplayBuffer): The replay buffer that the agent will
push observations to and sample from during learning. If None,
defaults to
:py:class:`~hive.replays.circular_replay.CircularReplayBuffer`.

This is incorrect. Please go through all the documentation and make sure it is correct.


done

Args:
observation_space (gym.spaces.Box): Observation space for the agent.
action_space (gym.spaces.Discrete): Action space for the agent.
representation_net (FunctionApproximator): A network that outputs the

Can you more explicitly mention the restrictions on representation_net? For example, which methods it should have or that it should follow the structure of ConvRNNNetwork or something?

Comment on lines 241 to 246
hidden_state = self._qnet.base_network.init_hidden(
batch_size=self._batch_size, device=self._device
)
target_hidden_state = self._target_qnet.base_network.init_hidden(
batch_size=self._batch_size, device=self._device
)

It feels weird that you are accessing an internal module of the qnet. Instead of self._qnet.base_network.init_hidden, it should be self._qnet.init_hidden.


done
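One way to implement the delegation suggested above is sketched below (signatures simplified for illustration; the merged DRQNNetwork may differ):

```python
import torch.nn as nn


class DRQNNetworkSketch(nn.Module):
    """Illustrative sketch: the Q-network head exposes init_hidden itself."""

    def __init__(self, base_network, hidden_dim, out_dim, linear_fn=None):
        super().__init__()
        self.base_network = base_network
        self._linear_fn = linear_fn if linear_fn is not None else nn.Linear
        self.output_layer = self._linear_fn(hidden_dim, out_dim)

    def init_hidden(self, batch_size):
        # Agents call qnet.init_hidden(...) instead of reaching into
        # qnet.base_network.
        return self.base_network.init_hidden(batch_size)

    def forward(self, x, hidden_state=None):
        x, hidden_state = self.base_network(x, hidden_state)
        x = x.flatten(start_dim=1)
        return self.output_layer(x), hidden_state
```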

Comment on lines +73 to +75
self.base_network = base_network
self._linear_fn = linear_fn if linear_fn is not None else nn.Linear
self.output_layer = self._linear_fn(hidden_dim, out_dim)

should probably make all of these internal variables


Maybe in a separate PR? All the network modules defined here have base_network and output_layer.

rewards = rewards[:, idx] # B x (S-N+1) x N
# Creating a vectorized sliding window to calculate
# discounted returns for every element in the sequence.
# equivalent to np.sum(rewards * self._discount[None, None, :], axis=2)

This line is longer than 88 characters.


done

)
# Compute predicted Q values
self._optimizer.zero_grad()
pred_qvals, _ = self._qnet(*current_state_inputs, hidden_state)


Just to make sure: the qnet takes in a window of past observations, you take the last hidden state as output, and you pass it through an MLP to get the Q-values? So it involves some redundant computation when calculating Q-values for s_t and s_t+1.


We have another PR #270 that handles hidden-state saving and burn-in frames. In this PR, the hidden states are initialized to zeros.

Could you also provide some references if you have seen more efficient ways of reusing the hidden state and computing Q?


We can have a look at the CleanRL/SB3 implementations of recurrent networks. Their JAX implementations might have some principles or tricks we can use for our code base.


I think the current implementation is good enough and is working. We can create new PRs to improve its efficiency later.

x, hidden_state = self.base_network(x, hidden_state)

x = x.flatten(start_dim=1)
return self.output_layer(x), hidden_state


Maybe it can have an internal function to call during act and update separately?

Other resolved review threads: hive/agents/drqn.py, hive/agents/qnets/qnet_heads.py, hive/agents/qnets/rnn.py (outdated), hive/agents/qnets/sequence_models.py, hive/replays/__init__.py, and hive/agents/qnets/__init__.py.


@dapatil211 dapatil211 left a comment


Assuming everything runs and matches existing benchmarks, looks good. Please add wandb runs to this conversation and also to the benchmarking report. Also when merging, do a squash and merge.


hnekoeiq commented Oct 7, 2022

> Assuming everything runs and matches existing benchmarks, looks good. Please add wandb runs to this conversation and also to the benchmarking report. Also when merging, do a squash and merge.

We should rerun the experiments for benchmarking, but here are some of the results on Atari:
#258 (comment)

And results on Hanabi DRQN: https://wandb.ai/chandar-rl/Hive/reports/DRQN-for-Hanabi--VmlldzoyMzU1NTUy?accessToken=efz63017kwgubdp8oyloepsfpjn3pmf0qka070h6rh90jjrcxcp8jd3eo20ko9sj

@hnekoeiq hnekoeiq merged commit 8a5c91e into dev Oct 13, 2022
@hnekoeiq hnekoeiq deleted the rnn_support branch October 13, 2022 21:27
Labels: major (Major PRs)
5 participants