Skip to content

Commit

Permalink
tests failing
Browse files Browse the repository at this point in the history
  • Loading branch information
chrismatix committed Oct 18, 2023
1 parent ea63d5c commit a8a817c
Show file tree
Hide file tree
Showing 38 changed files with 1,586 additions and 1,072 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# .github/workflows/app.yaml
name: PyTest
on:
on:
pull_request:
branches:
branches:
- main

jobs:
Expand All @@ -21,6 +21,7 @@ jobs:
python -m pip install --upgrade pip
- name: Install from repo in the test mode
run: "pip install -e '.[dev]'"
- uses: pre-commit/[email protected]
- name: Test with pytest
run: |
pip install pytest
Expand Down
45 changes: 0 additions & 45 deletions optim_test.py

This file was deleted.

8 changes: 5 additions & 3 deletions pax/agents/lola/lola.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,9 @@ def inner_loss(
"loss_value": value_objective,
}

def make_initial_state(key: Any, hidden) -> Tuple[TrainingState, MemoryState]:
def make_initial_state(
key: Any, hidden
) -> Tuple[TrainingState, MemoryState]:
"""Initialises the training state (parameters and optimiser state)."""
key, subkey = jax.random.split(key)
dummy_obs = jnp.zeros(shape=obs_spec)
Expand Down Expand Up @@ -524,7 +526,7 @@ def lola_inlookahead_rollout(carry, unused):
)
my_mem = carry[7]
other_mems = carry[8]
# jax.debug.breakpoint()

# flip axes to get (num_envs, num_inner, obs_dim) to vmap over numenvs
vmap_trajectories = jax.tree_map(
lambda x: jnp.swapaxes(x, 0, 1), trajectories
Expand All @@ -539,7 +541,7 @@ def lola_inlookahead_rollout(carry, unused):
rewards_self=vmap_trajectories[0].rewards_self,
rewards_other=[traj.rewards for traj in vmap_trajectories[1:]],
)
# jax.debug.breakpoint()

# get gradients of opponents
other_gradients = []
for idx in range(len((self.other_agents))):
Expand Down
3 changes: 2 additions & 1 deletion pax/agents/mfos_ppo/ppo_gru.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from pax.agents.agent import AgentInterface
from pax.agents.mfos_ppo.networks import (
make_mfos_ipditm_network,
make_mfos_network, make_mfos_continuous_network,
make_mfos_network,
make_mfos_continuous_network,
)
from pax.envs.rice.rice import Rice
from pax.utils import TrainingState, get_advantages
Expand Down
6 changes: 5 additions & 1 deletion pax/agents/naive/naive.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@

from pax import utils
from pax.agents.agent import AgentInterface
from pax.agents.naive.network import make_coingame_network, make_network, make_rice_network
from pax.agents.naive.network import (
make_coingame_network,
make_network,
make_rice_network,
)
from pax.utils import MemoryState, TrainingState, get_advantages


Expand Down
19 changes: 11 additions & 8 deletions pax/agents/naive/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ class CategoricalValueHead(hk.Module):
"""Network head that produces a categorical distribution and value."""

def __init__(
self,
num_values: int,
name: Optional[str] = None,
self,
num_values: int,
name: Optional[str] = None,
):
super().__init__(name=name)
self._logit_layer = hk.Linear(
Expand All @@ -34,10 +34,10 @@ class ContinuousValueHead(hk.Module):
"""Network head that produces a continuous distribution and value."""

def __init__(
self,
num_values: int,
name: Optional[str] = None,
mean_activation: Optional[str] = None,
self,
num_values: int,
name: Optional[str] = None,
mean_activation: Optional[str] = None,
):
super().__init__(name=name)
self.mean_action = mean_activation
Expand Down Expand Up @@ -65,7 +65,10 @@ def __call__(self, inputs: jnp.ndarray):
scales = self._scale_layer(inputs)
value = jnp.squeeze(self._value_layer(inputs), axis=-1)
scales = jnp.maximum(scales, 0.01)
return distrax.MultivariateNormalDiag(loc=means, scale_diag=scales), value
return (
distrax.MultivariateNormalDiag(loc=means, scale_diag=scales),
value,
)


class CNN(hk.Module):
Expand Down
2 changes: 1 addition & 1 deletion pax/agents/ppo/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def step(self, actions: jnp.ndarray) -> TimeStep:
rewards = []
observations = []
discounts = []
for env, action in zip(self.envs, actions):
for env, action in zip(self.envs, actions, strict=True):
t = env.step(int(action))
if t.step_type == 2:
t_reset = env.reset()
Expand Down
Loading

0 comments on commit a8a817c

Please sign in to comment.