
Add Rice-N, C-Rice-N, Fishery, Cournot competition and a parameter sharing runner (#161)

* most of the changes

* most of the changes

* more stuff

* Add cournot game

* test optimal policy

* Add more cournot configs, fixes and first draft of the fishery environment

* fix cournot test

* fix cournot test and config

* fix fishery tests

* improvements: cournot optimum v nash optimum, fishery configs

* fishery eval checkpoint

* nplayer

* n player fixes, n player cournot

* add rice environment

* checkpoint: parity between pax rice and ai4coop rice

* Add 5 regions rice configuration

* add a rice_n regression

* rice consistency checkpoint

* fully vectorized rice environment

* checkpoint

* checkpoint

* fix fishery

* cleanup

* more refactoring

* fixes

* Add pytest regression

* checkpoint

* refactor watchers file and fix some types and unused imports

* more experiments plus rice refactor

* tests failing

* reformat version file

* exclude version file

* fix exclude statement

* exclude version file in github action

* fix

* fix

* another attempt

---------

Co-authored-by: alexandrasouly <[email protected]>
chrismatix and alexandrasouly committed Oct 18, 2023
1 parent 7936b16 commit 9d3fa62
Showing 119 changed files with 21,547 additions and 494 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/test.yaml
@@ -1,8 +1,8 @@
# .github/workflows/app.yaml
name: PyTest
on:
on:
pull_request:
branches:
branches:
- main

jobs:
@@ -16,6 +16,7 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: '3.9'
- uses: pre-commit/[email protected]
- name: Ensure latest pip
run: |
python -m pip install --upgrade pip
2 changes: 0 additions & 2 deletions .gitignore
@@ -104,7 +104,6 @@ exp/

# Hydra
.hydra
exp/
venv/
plots/
figures/
@@ -115,4 +114,3 @@ experiment.log

# Pax
pax/version.py
experiment.log
4 changes: 2 additions & 2 deletions .vscode/launch.json
@@ -13,8 +13,8 @@
"justMyCode": false,
"env": {
"OBJC_DISABLE_INITIALIZE_FORK_SAFETY": "YES",
"PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python",
"PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python"
}
}
]
}
}
3 changes: 3 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,3 @@
{
"python.formatting.provider": "black"
}
4 changes: 2 additions & 2 deletions README.md
@@ -215,7 +215,7 @@ Pax is written in pure Python, but depends on C++ code via JAX.

Because JAX installation is different depending on your CUDA version, Haiku does not list JAX as a dependency in requirements.txt.

First, follow these instructions to install JAX with the relevant accelerator support.
First, follow these instructions to [install](https://github.com/google/jax#installation) JAX with the relevant accelerator support.

## General Information
The project entrypoint is `pax/experiment.py`. The simplest command to run a game would be:
@@ -265,4 +265,4 @@ If you use Pax in any of your work, please cite:
journal = {GitHub repository},
howpublished = {\url{https://github.com/akbir/pax}},
}
```
```
20 changes: 12 additions & 8 deletions docs/envs.md
@@ -1,13 +1,17 @@
## Environments

Pax includes many environments specified by `env_id`. These are `infinite_matrix_game`, `iterated_matrix_game` and `coin_game`. Independetly you can specify your enviroment type as either a meta environment (with an inner/ outer loop) by `env_type`, the options supported are `sequential` or `meta`.
Pax includes many environments specified by `env_id`. These are `infinite_matrix_game`, `iterated_matrix_game` and `coin_game`. Independently you can specify your environment type as either a meta environment (with an inner/ outer loop) by `env_type`, the options supported are `sequential` or `meta`.

These are specified in the config files in `pax/configs/{env_id}/EXPERIMENT.yaml`.

| Environment ID | Environment Type | Description |
| ----------- | ----------- | ----------- |
|`iterated_matrix_game`| `sequential` | An iterated matrix game with a predetermined number of timesteps per episode with a discount factor $\gamma$ |
|`iterated_matrix_game` | `meta` | A meta game over the iterated matrix game with an outer agent (player 1) and an inner agent (player 2). The inner updates every episode, while the the outer agent updates every meta-episode |
|`infinite_matrix_game` | `meta`| An infinite matrix game that calculates exact returns given a payoff and discount factor $\gamma$ |
|coin_game | `sequential` | A sequential series of episode of the coin game between two players. Each player updates at the end of an episode|
|coin_game | `meta` | A meta learning version of the coin game with an outer agent (player 1) and an inner agent (player 2). The inner updates every episode, while the the outer agent updates every meta-episode|
| Environment ID | Environment Type | Description |
|------------------------|---------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `iterated_matrix_game` | `sequential` | An iterated matrix game with a predetermined number of timesteps per episode with a discount factor $\gamma$ |
| `iterated_matrix_game` | `meta` | A meta game over the iterated matrix game with an outer agent (player 1) and an inner agent (player 2). The inner updates every episode, while the the outer agent updates every meta-episode |
| `infinite_matrix_game` | `meta` | An infinite matrix game that calculates exact returns given a payoff and discount factor $\gamma$ |
| coin_game | `sequential` | A sequential series of episode of the coin game between two players. Each player updates at the end of an episode |
| coin_game | `meta` | A meta learning version of the coin game with an outer agent (player 1) and an inner agent (player 2). The inner updates every episode, while the the outer agent updates every meta-episode |
| cournot | `sequential`/`meta` | A one-shot version of a [Cournot competition](https://en.wikipedia.org/wiki/Cournot_competition) |
| fishery | `sequential`/`meta` | A dynamic resource harvesting game as specified in Perman et al. |
| Rice-N | `sequential`/`meta` | A re-implementation of the Integrated Assessment Model introduced by [Zhang et al.](https://papers.ssrn.com/abstract=4189735) available with either the original 27 regions or a new calibration of only 5 regions |
| C-Rice-N | `sequential`/`meta` | An extension of Rice-N with a simple climate club mechanism |
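
For readers new to these environments, the Cournot row above refers to the standard one-shot Cournot competition. As a quick reference, here is a sketch of the closed-form quantities for a linear-demand version; the demand intercept, slope, marginal cost and firm count below are assumed for illustration, not taken from the Pax configs.

```python
# Hypothetical linear-demand Cournot game: inverse demand P(Q) = a - b * Q,
# constant marginal cost c, n symmetric firms. Parameter values are
# illustrative only, not the Pax defaults.
a, b, c, n = 100.0, 1.0, 10.0, 2

q_nash = (a - c) / (b * (n + 1))   # per-firm Nash equilibrium quantity
q_joint = (a - c) / (2 * b * n)    # per-firm quantity that maximizes joint profit

print(q_nash, q_joint)  # 30.0 22.5 -- Nash over-produces relative to the collusive optimum
```
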
5 changes: 3 additions & 2 deletions pax/agents/hyper/networks.py
@@ -4,6 +4,7 @@
import haiku as hk
import jax
import jax.numpy as jnp
from distrax import MultivariateNormalDiag

from pax import utils

@@ -74,7 +75,7 @@ def make_GRU(num_actions: int):

def forward_fn(
inputs: jnp.ndarray, state: jnp.ndarray
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
) -> Tuple[Tuple[MultivariateNormalDiag, jnp.ndarray], jnp.ndarray]:
"""forward function"""
gru = hk.GRU(hidden_size)
embedding, state = gru(inputs, state)
@@ -92,7 +93,7 @@ def make_GRU_hypernetwork(num_actions: int):

def forward_fn(
inputs: jnp.ndarray, state: jnp.ndarray
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
) -> Tuple[Tuple[MultivariateNormalDiag, jnp.ndarray], jnp.ndarray]:
"""forward function"""
gru = hk.GRU(hidden_size)
embedding, state = gru(inputs, state)
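
The annotation changes above reflect that these GRU heads return a distrax `MultivariateNormalDiag` rather than raw arrays. A minimal, self-contained sketch of such a head; the sizes, the softplus scale parameterization, and the dummy shapes are assumptions for illustration, not the exact Pax network.

```python
import haiku as hk
import jax
import jax.numpy as jnp
from distrax import MultivariateNormalDiag

hidden_size, num_actions = 16, 2

def forward_fn(inputs: jnp.ndarray, state: jnp.ndarray):
    # Recurrent trunk followed by a Gaussian policy head and a value head.
    gru = hk.GRU(hidden_size)
    embedding, state = gru(inputs, state)
    loc = hk.Linear(num_actions)(embedding)
    scale = jax.nn.softplus(hk.Linear(num_actions)(embedding))  # keep scale positive
    value = hk.Linear(1)(embedding)
    return (MultivariateNormalDiag(loc=loc, scale_diag=scale), value), state

network = hk.without_apply_rng(hk.transform(forward_fn))
dummy_obs = jnp.zeros((1, 4))            # (batch, obs_dim), illustrative shapes
initial_state = jnp.zeros((1, hidden_size))
params = network.init(jax.random.PRNGKey(0), dummy_obs, initial_state)
(dist, value), new_state = network.apply(params, dummy_obs, initial_state)
```
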
18 changes: 7 additions & 11 deletions pax/agents/lola/lola.py
@@ -1,17 +1,11 @@
from typing import Any, Dict, List, Mapping, NamedTuple, Tuple
from typing import Any, List, NamedTuple, Tuple

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
from dm_env import TimeStep

# from pax.lola.buffer import TrajectoryBuffer
from pax.agents.lola.network import make_network

from pax import utils
from pax.agents.ppo.ppo_gru import PPO
from pax.agents.lola.network import make_network
from pax.runners.runner_marl import Sample
from pax.utils import MemoryState, TrainingState

@@ -295,7 +289,9 @@ def inner_loss(
"loss_value": value_objective,
}

def make_initial_state(key: Any, hidden) -> TrainingState:
def make_initial_state(
key: Any, hidden
) -> Tuple[TrainingState, MemoryState]:
"""Initialises the training state (parameters and optimiser state)."""
key, subkey = jax.random.split(key)
dummy_obs = jnp.zeros(shape=obs_spec)
@@ -530,7 +526,7 @@ def lola_inlookahead_rollout(carry, unused):
)
my_mem = carry[7]
other_mems = carry[8]
# jax.debug.breakpoint()

# flip axes to get (num_envs, num_inner, obs_dim) to vmap over numenvs
vmap_trajectories = jax.tree_map(
lambda x: jnp.swapaxes(x, 0, 1), trajectories
@@ -545,7 +541,7 @@
rewards_self=vmap_trajectories[0].rewards_self,
rewards_other=[traj.rewards for traj in vmap_trajectories[1:]],
)
# jax.debug.breakpoint()

# get gradients of opponents
other_gradients = []
for idx in range(len((self.other_agents))):
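
A notable detail in this hunk is the axis flip applied to the trajectory pytree before vmapping over environments. A small sketch of what `jnp.swapaxes` inside `jax.tree_map` does here; the shapes are made up for illustration.

```python
import jax
import jax.numpy as jnp

# Rollouts are stacked as (num_inner_steps, num_envs, ...); illustrative shapes only.
trajectories = {"obs": jnp.zeros((8, 4, 5)), "rewards": jnp.zeros((8, 4))}

# Swap the first two axes so each leaf becomes (num_envs, num_inner_steps, ...),
# which lets the gradient computation be vmapped over environments.
vmap_trajectories = jax.tree_map(lambda x: jnp.swapaxes(x, 0, 1), trajectories)

print(vmap_trajectories["obs"].shape)      # (4, 8, 5)
print(vmap_trajectories["rewards"].shape)  # (4, 8)
```
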
2 changes: 0 additions & 2 deletions pax/agents/lola/network.py
@@ -5,8 +5,6 @@
import jax
import jax.numpy as jnp

from pax import utils


class CategoricalValueHead(hk.Module):
"""Network head that produces a categorical distribution and value."""
26 changes: 23 additions & 3 deletions pax/agents/mfos_ppo/networks.py
@@ -4,12 +4,13 @@
import haiku as hk
import jax
import jax.numpy as jnp
from distrax import Categorical

from pax import utils


class ActorCriticMFOS(hk.Module):
def __init__(self, num_values, hidden_size):
def __init__(self, num_values, hidden_size, categorical=True):
super().__init__(name="ActorCriticMFOS")
self.linear_t_0 = hk.Linear(hidden_size)
self.linear_a_0 = hk.Linear(hidden_size)
@@ -30,6 +31,7 @@ def __init__(self, num_values, hidden_size):
w_init=hk.initializers.Orthogonal(1.0), # baseline
with_bias=False,
)
self._categorical = categorical

def __call__(self, inputs: jnp.ndarray, state: jnp.ndarray):
input, th = inputs
@@ -54,7 +56,10 @@ def __call__(self, inputs: jnp.ndarray, state: jnp.ndarray):

hidden = jnp.concatenate([hidden_t, hidden_a, hidden_v], axis=-1)
state = (_current_th, hidden)
return (distrax.Categorical(logits=logits), value, state)
if self._categorical:
return distrax.Categorical(logits=logits), value, state
else:
return distrax.MultivariateNormalDiag(loc=logits), value, state


class CNNFusion(hk.Module):
@@ -147,6 +152,21 @@ def forward_fn(
return network, hidden_state


def make_mfos_continuous_network(num_actions: int, hidden_size: int):
hidden_state = jnp.zeros((1, 3 * hidden_size))

def forward_fn(
inputs: jnp.ndarray,
state: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
mfos = ActorCriticMFOS(num_actions, hidden_size, categorical=False)
logits, values, state = mfos(inputs, state)
return (logits, values), state

network = hk.without_apply_rng(hk.transform(forward_fn))
return network, hidden_state


def make_mfos_ipditm_network(
num_actions: int, hidden_size: int, output_channels, kernel_shape
):
@@ -155,7 +175,7 @@
def forward_fn(
inputs: jnp.ndarray,
state: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
) -> Tuple[Tuple[Categorical, jnp.ndarray], jnp.ndarray]:
mfos = CNNMFOS(num_actions, hidden_size, output_channels, kernel_shape)
logits, values, state = mfos(inputs, state)
return (logits, values), state
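
The new `categorical` flag and `make_mfos_continuous_network` exist because the matrix-game environments use discrete actions while Cournot, Fishery and Rice-N take real-valued actions. A toy comparison of the two distrax heads; the values are arbitrary.

```python
import distrax
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
logits = jnp.array([0.1, 0.5, -0.2])  # stand-in for a network output

# Discrete head: samples an integer action index in {0, 1, 2}.
discrete_action = distrax.Categorical(logits=logits).sample(seed=key)

# Continuous head: samples a real-valued action vector (unit scale by default).
continuous_action = distrax.MultivariateNormalDiag(loc=logits).sample(seed=key)
```
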
17 changes: 14 additions & 3 deletions pax/agents/mfos_ppo/ppo_gru.py
@@ -12,7 +12,9 @@
from pax.agents.mfos_ppo.networks import (
make_mfos_ipditm_network,
make_mfos_network,
make_mfos_continuous_network,
)
from pax.envs.rice.rice import Rice
from pax.utils import TrainingState, get_advantages


@@ -559,6 +561,16 @@ def make_mfos_agent(
action_spec,
agent_args.hidden_size,
)
elif args.env_id in ["Cournot", Rice.env_id]:
network, initial_hidden_state = make_mfos_continuous_network(
action_spec,
agent_args.hidden_size,
)
elif args.env_id == "Fishery":
network, initial_hidden_state = make_mfos_continuous_network(
action_spec,
agent_args.hidden_size,
)
elif args.env_id == "InTheMatrix":
network, initial_hidden_state = make_mfos_ipditm_network(
action_spec,
@@ -573,13 +585,12 @@

# Optimizer
transition_steps = (
num_iterations,
*agent_args.num_epochs * agent_args.num_minibatches,
num_iterations * agent_args.num_epochs * agent_args.num_minibatches
)

if agent_args.lr_scheduling:
scheduler = optax.linear_schedule(
init_value=agent_args.ppo.learning_rate,
init_value=agent_args.learning_rate,
end_value=0,
transition_steps=transition_steps,
)
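
The corrected `transition_steps` above now counts every optimizer update (iterations × epochs × minibatches) instead of the malformed tuple expression it replaces. A small sketch of the resulting schedule; the numbers are placeholders, not Pax defaults.

```python
import optax

num_iterations, num_epochs, num_minibatches = 1000, 4, 8
transition_steps = num_iterations * num_epochs * num_minibatches  # one step per update

scheduler = optax.linear_schedule(
    init_value=3e-4,   # stands in for agent_args.learning_rate
    end_value=0.0,
    transition_steps=transition_steps,
)
optimizer = optax.adam(learning_rate=scheduler)

print(scheduler(0), scheduler(transition_steps))  # 3e-4 at the first update, 0.0 at the last
```
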
3 changes: 1 addition & 2 deletions pax/agents/naive/buffer.py
@@ -2,7 +2,6 @@

import jax
import jax.numpy as jnp
import numpy as np
from dm_env import TimeStep


@@ -105,7 +104,7 @@ def size(self) -> int:
return min(self._rollout_length, self._num_added)

def fraction_filled(self) -> float:
return self.size / self._rollout_length
return self.size() / self._rollout_length

def reset(self):
"""Resets the replay buffer. Called upon __init__ and when buffer is full"""
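
The fix above matters because `size` is a method: `self.size / self._rollout_length` divides a bound method by a number and raises a `TypeError`, whereas calling it returns the stored count. A stripped-down illustration, not the actual Pax buffer:

```python
class Buffer:
    def __init__(self, rollout_length: int, num_added: int):
        self._rollout_length = rollout_length
        self._num_added = num_added

    def size(self) -> int:
        return min(self._rollout_length, self._num_added)

    def fraction_filled(self) -> float:
        # `self.size / self._rollout_length` would raise TypeError; the call is required.
        return self.size() / self._rollout_length


print(Buffer(rollout_length=10, num_added=4).fraction_filled())  # 0.4
```
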
9 changes: 7 additions & 2 deletions pax/agents/naive/naive.py
@@ -6,11 +6,14 @@
import jax
import jax.numpy as jnp
import optax
from dm_env import TimeStep

from pax import utils
from pax.agents.agent import AgentInterface
from pax.agents.naive.network import make_coingame_network, make_network
from pax.agents.naive.network import (
make_coingame_network,
make_network,
make_rice_network,
)
from pax.utils import MemoryState, TrainingState, get_advantages


@@ -397,6 +400,8 @@ def make_naive_pg(args, obs_spec, action_spec, seed: int, player_id: int):
if args.env_id == "coin_game":
print(f"Making network for {args.env_id} with CNN")
network = make_coingame_network(action_spec, args)
elif args.env_id == "Rice-N":
network = make_rice_network(action_spec)
else:
network = make_network(action_spec)
