Merge branch 'main' into f/vmap_optstate
Aidandos authored Oct 18, 2023
2 parents f920d79 + 9d3fa62 commit 88e347d
Showing 132 changed files with 28,249 additions and 980 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/test.yaml
@@ -1,8 +1,8 @@
# .github/workflows/app.yaml
name: PyTest
on:
on:
pull_request:
branches:
branches:
- main

jobs:
@@ -16,6 +16,7 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: '3.9'
- uses: pre-commit/[email protected]
- name: Ensure latest pip
run: |
python -m pip install --upgrade pip
2 changes: 0 additions & 2 deletions .gitignore
@@ -104,7 +104,6 @@ exp/

# Hydra
.hydra
exp/
venv/
plots/
figures/
@@ -115,4 +114,3 @@ experiment.log

# Pax
pax/version.py
experiment.log
4 changes: 2 additions & 2 deletions .vscode/launch.json
@@ -13,8 +13,8 @@
"justMyCode": false,
"env": {
"OBJC_DISABLE_INITIALIZE_FORK_SAFETY": "YES",
"PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python",
"PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python"
}
}
]
}
}
3 changes: 3 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,3 @@
{
"python.formatting.provider": "black"
}
4 changes: 2 additions & 2 deletions README.md
@@ -215,7 +215,7 @@ Pax is written in pure Python, but depends on C++ code via JAX.

Because JAX installation is different depending on your CUDA version, Haiku does not list JAX as a dependency in requirements.txt.

First, follow these instructions to install JAX with the relevant accelerator support.
First, follow these instructions to [install](https://github.com/google/jax#installation) JAX with the relevant accelerator support.

## General Information
The project entrypoint is `pax/experiment.py`. The simplest command to run a game would be:
@@ -265,4 +265,4 @@ If you use Pax in any of your work, please cite:
journal = {GitHub repository},
howpublished = {\url{https://github.com/akbir/pax}},
}
```
```
20 changes: 12 additions & 8 deletions docs/envs.md
@@ -1,13 +1,17 @@
## Environments

Pax includes many environments specified by `env_id`. These are `infinite_matrix_game`, `iterated_matrix_game` and `coin_game`. Independetly you can specify your enviroment type as either a meta environment (with an inner/ outer loop) by `env_type`, the options supported are `sequential` or `meta`.
Pax includes many environments specified by `env_id`. These are `infinite_matrix_game`, `iterated_matrix_game` and `coin_game`. Independently you can specify your environment type as either a meta environment (with an inner/ outer loop) by `env_type`, the options supported are `sequential` or `meta`.

These are specified in the config files in `pax/configs/{env_id}/EXPERIMENT.yaml`.

| Environment ID | Environment Type | Description |
| ----------- | ----------- | ----------- |
|`iterated_matrix_game`| `sequential` | An iterated matrix game with a predetermined number of timesteps per episode with a discount factor $\gamma$ |
|`iterated_matrix_game` | `meta` | A meta game over the iterated matrix game with an outer agent (player 1) and an inner agent (player 2). The inner updates every episode, while the the outer agent updates every meta-episode |
|`infinite_matrix_game` | `meta`| An infinite matrix game that calculates exact returns given a payoff and discount factor $\gamma$ |
|coin_game | `sequential` | A sequential series of episode of the coin game between two players. Each player updates at the end of an episode|
|coin_game | `meta` | A meta learning version of the coin game with an outer agent (player 1) and an inner agent (player 2). The inner updates every episode, while the the outer agent updates every meta-episode|
| Environment ID | Environment Type | Description |
|------------------------|---------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `iterated_matrix_game` | `sequential` | An iterated matrix game with a predetermined number of timesteps per episode with a discount factor $\gamma$ |
| `iterated_matrix_game` | `meta` | A meta game over the iterated matrix game with an outer agent (player 1) and an inner agent (player 2). The inner updates every episode, while the the outer agent updates every meta-episode |
| `infinite_matrix_game` | `meta` | An infinite matrix game that calculates exact returns given a payoff and discount factor $\gamma$ |
| coin_game | `sequential` | A sequential series of episode of the coin game between two players. Each player updates at the end of an episode |
| coin_game | `meta` | A meta learning version of the coin game with an outer agent (player 1) and an inner agent (player 2). The inner updates every episode, while the the outer agent updates every meta-episode |
| cournot | `sequential`/`meta` | A one-shot version of a [Cournot competition](https://en.wikipedia.org/wiki/Cournot_competition) |
| fishery | `sequential`/`meta` | A dynamic resource harvesting game as specified in Perman et al. |
| Rice-N | `sequential`/`meta` | A re-implementation of the Integrated Assessment Model introduced by [Zhang et al.](https://papers.ssrn.com/abstract=4189735) available with either the original 27 regions or a new calibration of only 5 regions |
| C-Rice-N | `sequential`/`meta` | An extension of Rice-N with a simple climate club mechanism |
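
As a rough illustration of the `env_id`/`env_type` pairs in the updated table, the sketch below transcribes them into a lookup and checks a pair before an experiment is launched. `SUPPORTED_ENV_TYPES` and `is_supported` are hypothetical names and are not part of Pax; the identifiers are copied as they appear in the table, which may differ from the strings the config files actually expect.

```python
from typing import Dict, FrozenSet

# Hypothetical lookup transcribed from the table in docs/envs.md.
# Not part of Pax: the real config keys may be spelled differently.
BOTH: FrozenSet[str] = frozenset({"sequential", "meta"})

SUPPORTED_ENV_TYPES: Dict[str, FrozenSet[str]] = {
    "iterated_matrix_game": BOTH,
    "infinite_matrix_game": frozenset({"meta"}),  # listed only as `meta`
    "coin_game": BOTH,
    "cournot": BOTH,
    "fishery": BOTH,
    "Rice-N": BOTH,
    "C-Rice-N": BOTH,
}


def is_supported(env_id: str, env_type: str) -> bool:
    """Return True if the table lists this env_id with this env_type."""
    return env_type in SUPPORTED_ENV_TYPES.get(env_id, frozenset())


if __name__ == "__main__":
    for pair in [("coin_game", "meta"), ("infinite_matrix_game", "sequential")]:
        print(pair, is_supported(*pair))
```

A check like this could run before loading `pax/configs/{env_id}/EXPERIMENT.yaml`, so that an unsupported combination fails early instead of partway through a run.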
3 changes: 2 additions & 1 deletion pax/agents/agent.py
@@ -1,8 +1,9 @@
from typing import Tuple

from pax.utils import MemoryState, TrainingState
import jax.numpy as jnp

from pax.utils import MemoryState, TrainingState


class AgentInterface:
"""Interface for agents to interact with runners and environemnts.
5 changes: 3 additions & 2 deletions pax/agents/hyper/networks.py
@@ -4,6 +4,7 @@
import haiku as hk
import jax
import jax.numpy as jnp
from distrax import MultivariateNormalDiag

from pax import utils

@@ -74,7 +75,7 @@ def make_GRU(num_actions: int):

def forward_fn(
inputs: jnp.ndarray, state: jnp.ndarray
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
) -> Tuple[Tuple[MultivariateNormalDiag, jnp.ndarray], jnp.ndarray]:
"""forward function"""
gru = hk.GRU(hidden_size)
embedding, state = gru(inputs, state)
@@ -92,7 +93,7 @@ def make_GRU_hypernetwork(num_actions: int):

def forward_fn(
inputs: jnp.ndarray, state: jnp.ndarray
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
) -> Tuple[Tuple[MultivariateNormalDiag, jnp.ndarray], jnp.ndarray]:
"""forward function"""
gru = hk.GRU(hidden_size)
embedding, state = gru(inputs, state)
4 changes: 3 additions & 1 deletion pax/agents/hyper/ppo.py
@@ -329,7 +329,9 @@ def model_update_epoch(
return new_state, new_mem, metrics

@jax.jit
def make_initial_state(key: Any, hidden: jnp.ndarray) -> TrainingState:
def make_initial_state(
key: Any, hidden: jnp.ndarray
) -> Tuple[TrainingState, MemoryState]:
"""Initialises the training state (parameters and optimiser state)."""
key, subkey = jax.random.split(key)
dummy_obs = jnp.zeros(shape=obs_spec)
Empty file added pax/agents/lola/__init__.py
Empty file.