
Add Rice-N, C-Rice-N, Fishery, Cournot competition and a parameter sharing runner #161

Merged: 37 commits, Oct 18, 2023

Changes from all commits (37 commits)
36b8bad
most of the changes
alexandrasouly Sep 22, 2023
b514182
most of the changes
alexandrasouly Sep 22, 2023
49de145
more stuff
alexandrasouly Sep 22, 2023
4d35b76
Add cournot game
chrismatix May 25, 2023
0c33a70
test optimal policy
chrismatix May 25, 2023
62fdd08
Add more cournot configs, fixes and first draft of the fishery enviro…
chrismatix Jun 25, 2023
2149c25
fix cournot test
chrismatix Jun 27, 2023
597b62c
fix cournot test and config
chrismatix Jun 27, 2023
cdc61dd
fix fishery tests
chrismatix Jul 5, 2023
1fada97
improvements: cournot optimum v nash optimum, fishery configs
chrismatix Jul 11, 2023
4becf93
fishery eval checkpoint
chrismatix Aug 2, 2023
649e89b
nplayer
alexandrasouly Apr 4, 2023
e218372
n player fixes, n player cournot
chrismatix Aug 7, 2023
a2cbf6b
add rice environment
chrismatix Aug 16, 2023
5992996
checkpoint: parity between pax rice and ai4coop rice
chrismatix Aug 17, 2023
3ee9e77
Add 5 regions rice configuration
chrismatix Aug 23, 2023
f5f740d
add a rice_n regression
chrismatix Aug 27, 2023
25796de
rice consistency checkpoint
chrismatix Sep 6, 2023
4280c3e
fully vectorized rice environment
chrismatix Sep 7, 2023
137e3dd
checkpoint
chrismatix Sep 9, 2023
94c234b
checkpoint
chrismatix Sep 20, 2023
cac097b
fix fishery
chrismatix Sep 26, 2023
cab9ee3
cleanup
chrismatix Sep 27, 2023
51f1215
more refactoring
chrismatix Sep 27, 2023
43253fc
fixes
chrismatix Sep 28, 2023
d0dbb54
Add pytest regression
chrismatix Sep 28, 2023
8726257
checkpoint
chrismatix Sep 29, 2023
1827582
refactor watchers file and fix some types and unused imports
chrismatix Sep 30, 2023
ea63d5c
more experiments plus rice refactor
chrismatix Oct 18, 2023
a8a817c
tests failing
chrismatix Oct 18, 2023
fdb19fe
reformat version file
chrismatix Oct 18, 2023
d5ba4d9
exclude version file
chrismatix Oct 18, 2023
c07bf1c
fix exclude statement
chrismatix Oct 18, 2023
7bde020
exclude version file in github action
chrismatix Oct 18, 2023
18e92b0
fix
chrismatix Oct 18, 2023
46e47bb
fix
chrismatix Oct 18, 2023
2c0e007
another attempt
chrismatix Oct 18, 2023
5 changes: 3 additions & 2 deletions .github/workflows/test.yaml
@@ -1,8 +1,8 @@
# .github/workflows/app.yaml
name: PyTest
on:
pull_request:
branches:
- main

jobs:
@@ -16,6 +16,7 @@ jobs:
        uses: actions/setup-python@v4
        with:
          python-version: '3.9'
+      - uses: pre-commit/[email protected]
      - name: Ensure latest pip
        run: |
          python -m pip install --upgrade pip
2 changes: 0 additions & 2 deletions .gitignore
@@ -104,7 +104,6 @@ exp/

# Hydra
.hydra
-exp/
venv/
plots/
figures/
@@ -115,4 +114,3 @@ experiment.log

# Pax
pax/version.py
-experiment.log
4 changes: 2 additions & 2 deletions .vscode/launch.json
@@ -13,8 +13,8 @@
"justMyCode": false,
"env": {
"OBJC_DISABLE_INITIALIZE_FORK_SAFETY": "YES",
-"PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python",
+"PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python"
}
}
]
}
3 changes: 3 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,3 @@
+{
+"python.formatting.provider": "black"
+}
4 changes: 2 additions & 2 deletions README.md
@@ -215,7 +215,7 @@ Pax is written in pure Python, but depends on C++ code via JAX.

Because JAX installation is different depending on your CUDA version, Haiku does not list JAX as a dependency in requirements.txt.

-First, follow these instructions to install JAX with the relevant accelerator support.
+First, follow these instructions to [install](https://github.com/google/jax#installation) JAX with the relevant accelerator support.

## General Information
The project entrypoint is `pax/experiment.py`. The simplest command to run a game would be:
@@ -265,4 +265,4 @@ If you use Pax in any of your work, please cite:
journal = {GitHub repository},
howpublished = {\url{https://github.com/akbir/pax}},
}
```
20 changes: 12 additions & 8 deletions docs/envs.md
@@ -1,13 +1,17 @@
## Environments

-Pax includes many environments specified by `env_id`. These are `infinite_matrix_game`, `iterated_matrix_game` and `coin_game`. Independetly you can specify your enviroment type as either a meta environment (with an inner/ outer loop) by `env_type`, the options supported are `sequential` or `meta`.
+Pax includes many environments specified by `env_id`. These are `infinite_matrix_game`, `iterated_matrix_game` and `coin_game`. Independently, you can specify the environment type via `env_type`: either a meta environment (with an inner/outer loop) or a sequential one; the supported options are `sequential` and `meta`.

These are specified in the config files in `pax/configs/{env_id}/EXPERIMENT.yaml`.

-| Environment ID | Environment Type | Description |
-| ----------- | ----------- | ----------- |
-|`iterated_matrix_game`| `sequential` | An iterated matrix game with a predetermined number of timesteps per episode with a discount factor $\gamma$ |
-|`iterated_matrix_game` | `meta` | A meta game over the iterated matrix game with an outer agent (player 1) and an inner agent (player 2). The inner updates every episode, while the the outer agent updates every meta-episode |
-|`infinite_matrix_game` | `meta`| An infinite matrix game that calculates exact returns given a payoff and discount factor $\gamma$ |
-|coin_game | `sequential` | A sequential series of episode of the coin game between two players. Each player updates at the end of an episode|
-|coin_game | `meta` | A meta learning version of the coin game with an outer agent (player 1) and an inner agent (player 2). The inner updates every episode, while the the outer agent updates every meta-episode|
+| Environment ID | Environment Type | Description |
+|------------------------|---------------------|-------------|
+| `iterated_matrix_game` | `sequential` | An iterated matrix game with a predetermined number of timesteps per episode and a discount factor $\gamma$ |
+| `iterated_matrix_game` | `meta` | A meta game over the iterated matrix game with an outer agent (player 1) and an inner agent (player 2). The inner agent updates every episode, while the outer agent updates every meta-episode |
+| `infinite_matrix_game` | `meta` | An infinite matrix game that calculates exact returns given a payoff and discount factor $\gamma$ |
+| coin_game | `sequential` | A sequential series of episodes of the coin game between two players. Each player updates at the end of an episode |
+| coin_game | `meta` | A meta-learning version of the coin game with an outer agent (player 1) and an inner agent (player 2). The inner agent updates every episode, while the outer agent updates every meta-episode |
+| cournot | `sequential`/`meta` | A one-shot version of a [Cournot competition](https://en.wikipedia.org/wiki/Cournot_competition) |
+| fishery | `sequential`/`meta` | A dynamic resource harvesting game as specified in Perman et al. |
+| Rice-N | `sequential`/`meta` | A re-implementation of the Integrated Assessment Model introduced by [Zhang et al.](https://papers.ssrn.com/abstract=4189735), available with either the original 27 regions or a new calibration of only 5 regions |
+| C-Rice-N | `sequential`/`meta` | An extension of Rice-N with a simple climate club mechanism |
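
For context on the new `cournot` and `fishery` rows: both environments have textbook closed-form benchmarks. The sketch below is illustrative only, assuming linear inverse demand P(Q) = a - b*Q with a constant marginal cost c shared by all n players for Cournot, and Gordon-Schaefer logistic stock dynamics for the fishery; the actual parameterisation in `pax/envs` may differ.

```python
# Hedged sketch of the standard benchmarks; not code from this PR.


def cournot_nash_quantity(a: float, b: float, c: float, n: int) -> float:
    """Per-player Nash equilibrium quantity: q_i = (a - c) / (b * (n + 1))."""
    return (a - c) / (b * (n + 1))


def cournot_collusive_quantity(a: float, b: float, c: float, n: int) -> float:
    """Per-player share of the joint-profit-maximising total Q* = (a - c) / (2b)."""
    return (a - c) / (2 * b * n)


def fishery_stock_step(
    stock: float, total_effort: float, r: float, carrying_capacity: float, q: float
) -> float:
    """One step of logistic growth with Schaefer harvesting:
    s' = s + r*s*(1 - s/K) - q*E*s."""
    harvest = q * total_effort * stock
    growth = r * stock * (1 - stock / carrying_capacity)
    return max(stock + growth - harvest, 0.0)


# Example: with a=100, b=1, c=10 and two players, each produces 30 at the
# Nash equilibrium versus 22.5 under collusion; the gap between the two is a
# natural cooperation metric, matching the "cournot optimum v nash optimum"
# commit above.
assert cournot_nash_quantity(100, 1, 10, 2) == 30.0
assert cournot_collusive_quantity(100, 1, 10, 2) == 22.5
```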
5 changes: 3 additions & 2 deletions pax/agents/hyper/networks.py
@@ -4,6 +4,7 @@
import haiku as hk
import jax
import jax.numpy as jnp
+from distrax import MultivariateNormalDiag

from pax import utils

@@ -74,7 +75,7 @@ def make_GRU(num_actions: int):

def forward_fn(
inputs: jnp.ndarray, state: jnp.ndarray
-) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
+) -> Tuple[Tuple[MultivariateNormalDiag, jnp.ndarray], jnp.ndarray]:
"""forward function"""
gru = hk.GRU(hidden_size)
embedding, state = gru(inputs, state)
@@ -92,7 +93,7 @@ def make_GRU_hypernetwork(num_actions: int):

def forward_fn(
inputs: jnp.ndarray, state: jnp.ndarray
-) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
+) -> Tuple[Tuple[MultivariateNormalDiag, jnp.ndarray], jnp.ndarray]:
"""forward function"""
gru = hk.GRU(hidden_size)
embedding, state = gru(inputs, state)
18 changes: 7 additions & 11 deletions pax/agents/lola/lola.py
@@ -1,17 +1,11 @@
-from typing import Any, Dict, List, Mapping, NamedTuple, Tuple
+from typing import Any, List, NamedTuple, Tuple

-import haiku as hk
import jax
import jax.numpy as jnp
-import numpy as np
import optax
-from dm_env import TimeStep

-# from pax.lola.buffer import TrajectoryBuffer
-from pax.agents.lola.network import make_network

from pax import utils
-from pax.agents.ppo.ppo_gru import PPO
+from pax.agents.lola.network import make_network
from pax.runners.runner_marl import Sample
from pax.utils import MemoryState, TrainingState

@@ -295,7 +289,9 @@ def inner_loss(
"loss_value": value_objective,
}

-def make_initial_state(key: Any, hidden) -> TrainingState:
+def make_initial_state(
+key: Any, hidden
+) -> Tuple[TrainingState, MemoryState]:
"""Initialises the training state (parameters and optimiser state)."""
key, subkey = jax.random.split(key)
dummy_obs = jnp.zeros(shape=obs_spec)
@@ -530,7 +526,7 @@ def lola_inlookahead_rollout(carry, unused):
)
my_mem = carry[7]
other_mems = carry[8]
-# jax.debug.breakpoint()
+
# flip axes to get (num_envs, num_inner, obs_dim) to vmap over numenvs
vmap_trajectories = jax.tree_map(
lambda x: jnp.swapaxes(x, 0, 1), trajectories
@@ -545,7 +541,7 @@
rewards_self=vmap_trajectories[0].rewards_self,
rewards_other=[traj.rewards for traj in vmap_trajectories[1:]],
)
-# jax.debug.breakpoint()
+
# get gradients of opponents
other_gradients = []
for idx in range(len((self.other_agents))):
2 changes: 0 additions & 2 deletions pax/agents/lola/network.py
@@ -5,8 +5,6 @@
import jax
import jax.numpy as jnp

-from pax import utils


class CategoricalValueHead(hk.Module):
"""Network head that produces a categorical distribution and value."""
26 changes: 23 additions & 3 deletions pax/agents/mfos_ppo/networks.py
@@ -4,12 +4,13 @@
import haiku as hk
import jax
import jax.numpy as jnp
+from distrax import Categorical

from pax import utils


class ActorCriticMFOS(hk.Module):
-def __init__(self, num_values, hidden_size):
+def __init__(self, num_values, hidden_size, categorical=True):
super().__init__(name="ActorCriticMFOS")
self.linear_t_0 = hk.Linear(hidden_size)
self.linear_a_0 = hk.Linear(hidden_size)
@@ -30,6 +31,7 @@ def __init__(self, num_values, hidden_size):
w_init=hk.initializers.Orthogonal(1.0), # baseline
with_bias=False,
)
+self._categorical = categorical

def __call__(self, inputs: jnp.ndarray, state: jnp.ndarray):
input, th = inputs
@@ -54,7 +56,10 @@ def __call__(self, inputs: jnp.ndarray, state: jnp.ndarray):

hidden = jnp.concatenate([hidden_t, hidden_a, hidden_v], axis=-1)
state = (_current_th, hidden)
-return (distrax.Categorical(logits=logits), value, state)
+if self._categorical:
+return distrax.Categorical(logits=logits), value, state
+else:
+return distrax.MultivariateNormalDiag(loc=logits), value, state
Review comment (author): Allow for continuous action spaces

class CNNFusion(hk.Module):
@@ -147,6 +152,21 @@ def forward_fn(
return network, hidden_state


+def make_mfos_continuous_network(num_actions: int, hidden_size: int):
+hidden_state = jnp.zeros((1, 3 * hidden_size))
+
+def forward_fn(
+inputs: jnp.ndarray,
+state: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],
+) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
+mfos = ActorCriticMFOS(num_actions, hidden_size, categorical=False)
+logits, values, state = mfos(inputs, state)
+return (logits, values), state
+
+network = hk.without_apply_rng(hk.transform(forward_fn))
+return network, hidden_state


def make_mfos_ipditm_network(
num_actions: int, hidden_size: int, output_channels, kernel_shape
):
@@ -155,7 +175,7 @@
def forward_fn(
inputs: jnp.ndarray,
state: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],
-) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
+) -> Tuple[Tuple[Categorical, jnp.ndarray], jnp.ndarray]:
mfos = CNNMFOS(num_actions, hidden_size, output_channels, kernel_shape)
logits, values, state = mfos(inputs, state)
return (logits, values), state
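
For context on the `categorical` flag added above: the head either samples a discrete action index or, for the new continuous-action environments, a real-valued action vector. A minimal illustration of the two distrax heads follows; this is an assumed usage sketch, not code from the PR. Note the diff builds the continuous head as `MultivariateNormalDiag(loc=logits)`, so the network output becomes the Gaussian mean and the scale defaults to ones.

```python
# Assumed usage sketch of the two policy heads; not code from this PR.
import distrax
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
logits = jnp.array([0.1, -0.2, 0.3])  # stand-in for the network's head output

# categorical=True: sample a discrete action index in {0, 1, 2}.
discrete_head = distrax.Categorical(logits=logits)
action_idx = discrete_head.sample(seed=key)

# categorical=False: the same outputs become the mean of a diagonal
# Gaussian, so samples are real-valued vectors (scale_diag defaults to ones).
continuous_head = distrax.MultivariateNormalDiag(loc=logits)
action_vec = continuous_head.sample(seed=key)

# Both heads expose log_prob, which the PPO-style update consumes.
print(discrete_head.log_prob(action_idx), continuous_head.log_prob(action_vec))
```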
17 changes: 14 additions & 3 deletions pax/agents/mfos_ppo/ppo_gru.py
@@ -12,7 +12,9 @@
from pax.agents.mfos_ppo.networks import (
make_mfos_ipditm_network,
make_mfos_network,
+make_mfos_continuous_network,
)
+from pax.envs.rice.rice import Rice
from pax.utils import TrainingState, get_advantages


@@ -559,6 +561,16 @@ def make_mfos_agent(
action_spec,
agent_args.hidden_size,
)
+elif args.env_id in ["Cournot", Rice.env_id]:
+network, initial_hidden_state = make_mfos_continuous_network(
+action_spec,
+agent_args.hidden_size,
+)
+elif args.env_id == "Fishery":
+network, initial_hidden_state = make_mfos_continuous_network(
+action_spec,
+agent_args.hidden_size,
+)
elif args.env_id == "InTheMatrix":
network, initial_hidden_state = make_mfos_ipditm_network(
action_spec,
@@ -573,13 +585,12 @@

# Optimizer
transition_steps = (
-num_iterations,
-*agent_args.num_epochs * agent_args.num_minibatches,
+num_iterations * agent_args.num_epochs * agent_args.num_minibatches
)

if agent_args.lr_scheduling:
scheduler = optax.linear_schedule(
-init_value=agent_args.ppo.learning_rate,
+init_value=agent_args.learning_rate,
end_value=0,
transition_steps=transition_steps,
)
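
Two fixes worth noting above: `transition_steps` was previously a tuple (note the stray comma and `*` unpacking) rather than the intended scalar product, and the schedule now reads `agent_args.learning_rate` instead of `agent_args.ppo.learning_rate`. A hedged sketch of how such a linear schedule plugs into an optax optimiser, with made-up hyperparameter values:

```python
# Hedged sketch; hyperparameter values are illustrative, not pax's defaults.
import optax

num_iterations, num_epochs, num_minibatches = 1_000, 4, 8
learning_rate = 3e-4

# One optimiser update per minibatch, per epoch, per iteration.
transition_steps = num_iterations * num_epochs * num_minibatches

# Anneal the learning rate linearly to zero over the whole run.
scheduler = optax.linear_schedule(
    init_value=learning_rate,
    end_value=0.0,
    transition_steps=transition_steps,
)

# optax accepts a schedule anywhere a fixed learning rate is accepted.
optimizer = optax.chain(
    optax.clip_by_global_norm(0.5),
    optax.adam(learning_rate=scheduler),
)
```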
3 changes: 1 addition & 2 deletions pax/agents/naive/buffer.py
@@ -2,7 +2,6 @@

import jax
import jax.numpy as jnp
-import numpy as np
from dm_env import TimeStep


@@ -105,7 +104,7 @@ def size(self) -> int:
return min(self._rollout_length, self._num_added)

def fraction_filled(self) -> float:
-return self.size / self._rollout_length
+return self.size() / self._rollout_length

def reset(self):
"""Resets the replay buffer. Called upon __init__ and when buffer is full"""
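
The `fraction_filled` fix above deserves a note: `size` is a method, so the old code divided the bound method object by an integer, which raises a TypeError the first time the metric is read. A self-contained reproduction with a simplified stand-in for the buffer (not the PR's actual class):

```python
# Simplified stand-in for the replay buffer; not the PR's actual class.
class Buffer:
    def __init__(self, rollout_length: int, num_added: int):
        self._rollout_length = rollout_length
        self._num_added = num_added

    def size(self) -> int:
        return min(self._rollout_length, self._num_added)

    def fraction_filled(self) -> float:
        # Old code: `self.size / self._rollout_length` raises
        # TypeError: unsupported operand type(s) for /: 'method' and 'int'
        return self.size() / self._rollout_length


assert Buffer(rollout_length=100, num_added=25).fraction_filled() == 0.25
```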
9 changes: 7 additions & 2 deletions pax/agents/naive/naive.py
@@ -6,11 +6,14 @@
import jax
import jax.numpy as jnp
import optax
-from dm_env import TimeStep

from pax import utils
from pax.agents.agent import AgentInterface
-from pax.agents.naive.network import make_coingame_network, make_network
+from pax.agents.naive.network import (
+make_coingame_network,
+make_network,
+make_rice_network,
+)
from pax.utils import MemoryState, TrainingState, get_advantages


@@ -397,6 +400,8 @@ def make_naive_pg(args, obs_spec, action_spec, seed: int, player_id: int):
if args.env_id == "coin_game":
print(f"Making network for {args.env_id} with CNN")
network = make_coingame_network(action_spec, args)
+elif args.env_id == "Rice-N":
+network = make_rice_network(action_spec)
else:
network = make_network(action_spec)
