Skip to content

Commit

Permalink
Multi-player multi-shaper runners and lola (#166)
Browse files Browse the repository at this point in the history
* most of the changes

* most of the changes

* more stuff
  • Loading branch information
alexandrasouly authored Oct 11, 2023
1 parent 4665b7b commit 7936b16
Show file tree
Hide file tree
Showing 29 changed files with 6,340 additions and 123 deletions.
3 changes: 2 additions & 1 deletion pax/agents/agent.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Tuple

from pax.utils import MemoryState, TrainingState
import jax.numpy as jnp

from pax.utils import MemoryState, TrainingState


class AgentInterface:
"""Interface for agents to interact with runners and environemnts.
Expand Down
4 changes: 3 additions & 1 deletion pax/agents/hyper/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,9 @@ def model_update_epoch(
return new_state, new_mem, metrics

@jax.jit
def make_initial_state(key: Any, hidden: jnp.ndarray) -> TrainingState:
def make_initial_state(
key: Any, hidden: jnp.ndarray
) -> Tuple[TrainingState, MemoryState]:
"""Initialises the training state (parameters and optimiser state)."""
key, subkey = jax.random.split(key)
dummy_obs = jnp.zeros(shape=obs_spec)
Expand Down
Empty file added pax/agents/lola/__init__.py
Empty file.
Loading

0 comments on commit 7936b16

Please sign in to comment.