F/vmap optstate (#167)
* attention branch

* MFOS averaging

* hardstop challenge + blindspots

Co-authored-by: Akbir Khan <[email protected]>

* adding hardstop yamls

* eval changes

* optstate vmap success

* intermediate commit

* intermediary push, gonna look at _pred next

* adding mixed ipd payoff runner

* final commit

* restructuring and deleting some runners

* adding make_ipd_network and iterated matrix games

* checking if runners work

* adding runner tests

* adding runner descriptions

---------

Co-authored-by: Akbir Khan <[email protected]>
Aidandos and akbir authored Oct 27, 2023
1 parent 00a6b6f commit 430b254
Showing 101 changed files with 18,027 additions and 72 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -114,3 +114,8 @@ experiment.log

# Pax
pax/version.py

*.gif
*.json
*.png
*.sh
55 changes: 55 additions & 0 deletions docs/getting-started/runners.md
@@ -23,6 +23,61 @@ In order for this approach to work the observation vector needs to include one e

See [this experiment](https://github.com/akbir/pax/blob/9d3fa62e34279a338c07cffcbf208edc8a95e7ba/pax/conf/experiment/rice/weight_sharing.yaml) for an example of how to configure it.

## Evo Hardstop

The Evo Runner optimizes the first agent using evolutionary learning.
This runner stops the opponent's learning during training, which corresponds to the hardstop challenge of Shaper.

See [this experiment](https://github.com/akbir/pax/blob/9a01bae33dcb2f812977be388751393f570957e9/pax/conf/experiment/ipd/shaper_att_v_tabular.yaml) for an example of how to configure it.
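
As a rough sketch of the mechanism (the function and argument names here are illustrative, not the runner's actual API), freezing the opponent after a configured step might look like the following; the `++stop` flags in the eval script further down suggest the stop point is set per run:

```python
import jax
import jax.numpy as jnp

def maybe_update_opponent(params, new_params, step, stop):
    # Hypothetical sketch: past the `stop` step, keep the old parameters,
    # freezing the opponent's learning for the rest of training.
    keep_learning = step < stop
    return jax.tree_util.tree_map(
        lambda new, old: jnp.where(keep_learning, new, old), new_params, params
    )
```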

## Evo Scanned

The Evo Runner optimizes the first agent using evolutionary learning.
Here we also scan over the evolutionary steps, which makes compilation take longer and training run faster, but means logging stats during training is not possible.

See [this experiment](https://github.com/akbir/pax/blob/9a01bae33dcb2f812977be388751393f570957e9/pax/conf/experiment/ipd/shaper_att_v_tabular.yaml) for an example of how to configure it.
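
A minimal sketch of this trade-off, assuming a hypothetical `es_step` that performs one generation of evolution:

```python
import jax

def evolve(es_state, num_generations, es_step):
    # jax.lax.scan compiles the whole evolutionary loop into one XLA
    # program: compilation is slower, execution is faster, and there is
    # no point inside the loop where a Python logger could be called --
    # per-generation fitness only comes back as a stacked array at the end.
    def body(state, _):
        state, fitness = es_step(state)
        return state, fitness

    return jax.lax.scan(body, es_state, xs=None, length=num_generations)
```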

## Evo Mixed LR Runner (experimental)

The Evo Runner optimizes the first agent using evolutionary learning.
This runner randomly samples learning rates for the opponents.

See [this experiment](https://github.com/akbir/pax/blob/9a01bae33dcb2f812977be388751393f570957e9/pax/conf/experiment/ipd/shaper_att_v_tabular.yaml) for an example of how to configure it.
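
A sketch of what per-opponent learning-rate sampling could look like; the log-uniform distribution and its range are assumptions, not the runner's actual choices:

```python
import jax
import jax.numpy as jnp

def sample_opponent_lrs(key, num_opps, low=1e-4, high=1e-1):
    # One learning rate per opponent, sampled log-uniformly in [low, high).
    log_lrs = jax.random.uniform(
        key, (num_opps,), minval=jnp.log10(low), maxval=jnp.log10(high)
    )
    return 10.0 ** log_lrs
```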

## Evo Mixed Payoff (experimental)

The Evo Runner optimizes the first agent using evolutionary learning.
The payoff matrix is randomly sampled at each rollout, and each opponent gets a different payoff matrix (see the sketch below, after the Gen variant).

See [this experiment](https://github.com/akbir/pax/blob/9a01bae33dcb2f812977be388751393f570957e9/pax/conf/experiment/ipd/shaper_att_v_tabular.yaml) for an example of how to configure it.

## Evo Mixed Payoff Gen (experimental)

The Evo Runner optimizes the first agent using evolutionary learning.
The payoff matrix is randomly sampled at each rollout, and every opponent shares the same payoff matrix.

See [this experiment](https://github.com/akbir/pax/blob/9a01bae33dcb2f812977be388751393f570957e9/pax/conf/experiment/ipd/shaper_att_v_tabular.yaml) for an example of how to configure it.
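
A sketch covering both of the runners above; the payoff shape (four outcomes × two players) and the sampling range are assumptions:

```python
import jax

def sample_payoffs(key, num_opps, shared, low=-3.0, high=3.0):
    # Mixed Payoff: a distinct payoff matrix per opponent.
    # Mixed Payoff Gen: one matrix, repeated for every opponent.
    n = 1 if shared else num_opps
    payoffs = jax.random.uniform(key, (n, 4, 2), minval=low, maxval=high)
    if shared:
        payoffs = payoffs.repeat(num_opps, axis=0)
    return payoffs
```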

## Evo Mixed IPD Payoff (experimental)

The Evo Runner optimizes the first agent using evolutionary learning.
This runner randomly samples payoffs that follow Iterated Prisoner's Dilemma [constraints](https://en.wikipedia.org/wiki/Prisoner%27s_dilemma).

See [this experiment](https://github.com/akbir/pax/blob/9a01bae33dcb2f812977be388751393f570957e9/pax/conf/experiment/ipd/shaper_att_v_tabular.yaml) for an example of how to configure it.
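
The linked constraints are T > R > P > S together with 2R > T + S. A sketch that enforces them by construction (the sampling range is purely illustrative):

```python
import jax

def sample_ipd_payoffs(key, low=0.0, high=5.0):
    # Sample S < P < R, then pick T in (R, 2R - S) so that both
    # T > R > P > S and 2R > T + S hold by construction.
    k1, k2, k3, k4 = jax.random.split(key, 4)
    s = jax.random.uniform(k1, minval=low, maxval=high)
    p = jax.random.uniform(k2, minval=s, maxval=high)
    r = jax.random.uniform(k3, minval=p, maxval=high)
    t = jax.random.uniform(k4, minval=r, maxval=2 * r - s)
    return t, r, p, s
```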

## Evo Mixed Payoff Input (experimental)

The Evo Runner optimizes the first agent using evolutionary learning.
The payoff matrix is randomly sampled at each rollout, and every opponent has the same payoff matrix. The payoff matrix is additionally observed by the agent as part of its input.

See [this experiment](https://github.com/akbir/pax/blob/9a01bae33dcb2f812977be388751393f570957e9/pax/conf/experiment/ipd/shaper_att_v_tabular.yaml) for an example of how to configure it.
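
A sketch of how the payoff could be exposed to the agent, assuming it is simply flattened and appended to each observation vector:

```python
import jax.numpy as jnp

def augment_obs_with_payoff(obs, payoff):
    # Append the flattened payoff entries to every observation so the
    # policy can condition on which game it is currently playing.
    flat = payoff.reshape(-1)
    tiled = jnp.broadcast_to(flat, obs.shape[:-1] + flat.shape)
    return jnp.concatenate([obs, tiled], axis=-1)
```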

## Evo Mixed Payoff Only Opp (experimental)

The Evo Runner optimizes the first agent using evolutionary learning.
Noise is added to the opponent's IPD-like payoff matrix at each rollout, with the same noise added for every opponent.

See [this experiment](https://github.com/akbir/pax/blob/9a01bae33dcb2f812977be388751393f570957e9/pax/conf/experiment/ipd/shaper_att_v_tabular.yaml) for an example of how to configure it.
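
A sketch of the opponent-side perturbation, assuming zero-mean uniform noise of an illustrative scale; only the opponent's matrix is perturbed, and the same draw is shared across opponents:

```python
import jax

def noisy_opponent_payoff(key, opp_payoff, scale=0.5):
    # One noise draw, added to the opponent's IPD-like payoff matrix;
    # the learner's own payoffs stay untouched.
    noise = jax.random.uniform(
        key, opp_payoff.shape, minval=-scale, maxval=scale
    )
    return opp_payoff + noise
```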



24 changes: 24 additions & 0 deletions hardstop_eval_bash.sh
@@ -0,0 +1,24 @@
#!/bin/bash
###### MFOS AVG ######
python -m pax.experiment -m +experiment/ipd=mfos_att_v_tabular_hardstop_eval ++wandb.log=True ++run_path=ucl-dark/ipd/4ykf9oe8 ++model_path=exp/MFOS-vs-Tabular/run-seed-23-pop-size-1000/2023-05-11_14.58.45.927266/generation_900 ++seed=85768,785678,764578,89678,97869,4567456,856778,3456347,45673,83346 ++stop=1
python -m pax.experiment -m +experiment/ipd=mfos_att_v_tabular_hardstop_eval ++wandb.log=True ++run_path=ucl-dark/ipd/eopf93re ++model_path=exp/MFOS-vs-Tabular/run-seed-65-pop-size-1000/2023-05-11_20.31.48.530245/generation_900 ++seed=85768,785678,764578,89678,97869,4567456,856778,3456347,45673,83346 ++stop=1
python -m pax.experiment -m +experiment/ipd=mfos_att_v_tabular_hardstop_eval ++wandb.log=True ++run_path=ucl-dark/ipd/1sqbd09n ++model_path=exp/MFOS-vs-Tabular/run-seed-47-pop-size-1000/2023-05-11_17.45.03.318240/generation_900 ++seed=85768,785678,764578,89678,97869,4567456,856778,3456347,45673,83346 ++stop=1
python -m pax.experiment -m +experiment/ipd=mfos_att_v_tabular_hardstop_eval ++wandb.log=True ++run_path=ucl-dark/ipd/3n7l8ods ++model_path=exp/MFOS-vs-Tabular/run-seed-8-pop-size-1000/2023-05-11_12.12.19.914211/generation_900 ++seed=85768,785678,764578,89678,97869,4567456,856778,3456347,45673,83346 ++stop=1
python -m pax.experiment -m +experiment/ipd=mfos_att_v_tabular_hardstop_eval ++wandb.log=True ++run_path=ucl-dark/ipd/4mf1ecxq ++model_path=exp/MFOS-vs-Tabular/run-seed-6-pop-size-1000/2023-05-11_09.25.40.656392/generation_900 ++seed=85768,785678,764578,89678,97869,4567456,856778,3456347,45673,83346 ++stop=1

python -m pax.experiment -m +experiment/ipd=mfos_att_v_tabular_hardstop_eval ++wandb.log=True ++run_path=ucl-dark/ipd/4ykf9oe8 ++model_path=exp/MFOS-vs-Tabular/run-seed-23-pop-size-1000/2023-05-11_14.58.45.927266/generation_900 ++seed=85768,785678,764578,89678,97869,4567456,856778,3456347,45673,83346 ++stop=100
python -m pax.experiment -m +experiment/ipd=mfos_att_v_tabular_hardstop_eval ++wandb.log=True ++run_path=ucl-dark/ipd/eopf93re ++model_path=exp/MFOS-vs-Tabular/run-seed-65-pop-size-1000/2023-05-11_20.31.48.530245/generation_900 ++seed=85768,785678,764578,89678,97869,4567456,856778,3456347,45673,83346 ++stop=100
python -m pax.experiment -m +experiment/ipd=mfos_att_v_tabular_hardstop_eval ++wandb.log=True ++run_path=ucl-dark/ipd/1sqbd09n ++model_path=exp/MFOS-vs-Tabular/run-seed-47-pop-size-1000/2023-05-11_17.45.03.318240/generation_900 ++seed=85768,785678,764578,89678,97869,4567456,856778,3456347,45673,83346 ++stop=100
python -m pax.experiment -m +experiment/ipd=mfos_att_v_tabular_hardstop_eval ++wandb.log=True ++run_path=ucl-dark/ipd/3n7l8ods ++model_path=exp/MFOS-vs-Tabular/run-seed-8-pop-size-1000/2023-05-11_12.12.19.914211/generation_900 ++seed=85768,785678,764578,89678,97869,4567456,856778,3456347,45673,83346 ++stop=100
python -m pax.experiment -m +experiment/ipd=mfos_att_v_tabular_hardstop_eval ++wandb.log=True ++run_path=ucl-dark/ipd/4mf1ecxq ++model_path=exp/MFOS-vs-Tabular/run-seed-6-pop-size-1000/2023-05-11_09.25.40.656392/generation_900 ++seed=85768,785678,764578,89678,97869,4567456,856778,3456347,45673,83346 ++stop=100

###### Shaper Nothing ######
python -m pax.experiment -m +experiment/ipd=shaper_att_v_tabular_hardstop_eval ++wandb.log=True ++run_path=ucl-dark/ipd/2m3wh5g7 ++model_path=exp/EARL-Shaper-vs-Tabular/run-seed-65-OpenES-pop-size-1000-num-opps-10-att-type-nothing/2023-05-14_22.17.07.592872/generation_900 ++seed=85768,785678,764578,89678,97869,4567456,856778,3456347,45673,83346 ++stop=100
python -m pax.experiment -m +experiment/ipd=shaper_att_v_tabular_hardstop_eval ++wandb.log=True ++run_path=ucl-dark/ipd/1jk5zly5 ++model_path=exp/EARL-Shaper-vs-Tabular/run-seed-47-OpenES-pop-size-1000-num-opps-10-att-type-nothing/2023-05-14_20.46.58.588813/generation_900 ++seed=85768,785678,764578,89678,97869,4567456,856778,3456347,45673,83346 ++stop=100
python -m pax.experiment -m +experiment/ipd=shaper_att_v_tabular_hardstop_eval ++wandb.log=True ++run_path=ucl-dark/ipd/1cvpiolk ++model_path=exp/EARL-Shaper-vs-Tabular/run-seed-23-OpenES-pop-size-1000-num-opps-10-att-type-nothing/2023-05-14_19.16.56.990716/generation_900 ++seed=85768,785678,764578,89678,97869,4567456,856778,3456347,45673,83346 ++stop=100
python -m pax.experiment -m +experiment/ipd=shaper_att_v_tabular_hardstop_eval ++wandb.log=True ++run_path=ucl-dark/ipd/3vml0wjy ++model_path=exp/EARL-Shaper-vs-Tabular/run-seed-6-OpenES-pop-size-1000-num-opps-10-att-type-nothing/2023-05-14_16.16.19.180942/generation_900 ++seed=85768,785678,764578,89678,97869,4567456,856778,3456347,45673,83346 ++stop=100

python -m pax.experiment -m +experiment/ipd=shaper_att_v_tabular_hardstop_eval ++wandb.log=True ++run_path=ucl-dark/ipd/2m3wh5g7 ++model_path=exp/EARL-Shaper-vs-Tabular/run-seed-65-OpenES-pop-size-1000-num-opps-10-att-type-nothing/2023-05-14_22.17.07.592872/generation_900 ++seed=85768,785678,764578,89678,97869,4567456,856778,3456347,45673,83346 ++stop=1
python -m pax.experiment -m +experiment/ipd=shaper_att_v_tabular_hardstop_eval ++wandb.log=True ++run_path=ucl-dark/ipd/1jk5zly5 ++model_path=exp/EARL-Shaper-vs-Tabular/run-seed-47-OpenES-pop-size-1000-num-opps-10-att-type-nothing/2023-05-14_20.46.58.588813/generation_900 ++seed=85768,785678,764578,89678,97869,4567456,856778,3456347,45673,83346 ++stop=1
python -m pax.experiment -m +experiment/ipd=shaper_att_v_tabular_hardstop_eval ++wandb.log=True ++run_path=ucl-dark/ipd/1cvpiolk ++model_path=exp/EARL-Shaper-vs-Tabular/run-seed-23-OpenES-pop-size-1000-num-opps-10-att-type-nothing/2023-05-14_19.16.56.990716/generation_900 ++seed=85768,785678,764578,89678,97869,4567456,856778,3456347,45673,83346 ++stop=1
python -m pax.experiment -m +experiment/ipd=shaper_att_v_tabular_hardstop_eval ++wandb.log=True ++run_path=ucl-dark/ipd/3vml0wjy ++model_path=exp/EARL-Shaper-vs-Tabular/run-seed-6-OpenES-pop-size-1000-num-opps-10-att-type-nothing/2023-05-14_16.16.19.180942/generation_900 ++seed=85768,785678,764578,89678,97869,4567456,856778,3456347,45673,83346 ++stop=1
22 changes: 22 additions & 0 deletions pax/agents/mfos_ppo/networks.py
@@ -151,6 +151,28 @@ def forward_fn(
    network = hk.without_apply_rng(hk.transform(forward_fn))
    return network, hidden_state


def make_mfos_avg_network(num_actions: int, hidden_size: int):
    hidden_state = jnp.zeros((1, 3 * hidden_size))

    def forward_fn(
        inputs: jnp.ndarray,
        state: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],
    ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
        mfos = ActorCriticMFOS(num_actions, hidden_size)
        # The state packs three hidden-state chunks side by side.
        hidden_t, hidden_a, hidden_v = jnp.split(state, 3, axis=-1)
        # Average each chunk across the batch dimension, then mix the
        # average 50/50 back into each per-sample hidden state.
        avg_hidden_t = jnp.mean(hidden_t, axis=0, keepdims=True).repeat(state.shape[0], axis=0)
        avg_hidden_a = jnp.mean(hidden_a, axis=0, keepdims=True).repeat(state.shape[0], axis=0)
        avg_hidden_v = jnp.mean(hidden_v, axis=0, keepdims=True).repeat(state.shape[0], axis=0)
        hidden_t = 0.5 * hidden_t + 0.5 * avg_hidden_t
        hidden_a = 0.5 * hidden_a + 0.5 * avg_hidden_a
        hidden_v = 0.5 * hidden_v + 0.5 * avg_hidden_v
        state = jnp.concatenate([hidden_t, hidden_a, hidden_v], axis=-1)
        logits, values, state = mfos(inputs, state)
        return (logits, values), state

    network = hk.without_apply_rng(hk.transform(forward_fn))
    return network, hidden_state


def make_mfos_continuous_network(num_actions: int, hidden_size: int):
    hidden_state = jnp.zeros((1, 3 * hidden_size))
18 changes: 14 additions & 4 deletions pax/agents/mfos_ppo/ppo_gru.py
@@ -12,6 +12,7 @@
from pax.agents.mfos_ppo.networks import (
    make_mfos_ipditm_network,
    make_mfos_network,
    make_mfos_avg_network,
    make_mfos_continuous_network,
)
from pax.envs.rice.rice import Rice
@@ -65,7 +66,6 @@ def __init__(
        obs_spec: Tuple,
        batch_size: int = 2000,
        num_envs: int = 4,
        num_steps: int = 500,
        num_minibatches: int = 16,
        num_epochs: int = 4,
        clip_value: bool = True,
@@ -481,8 +481,8 @@ def prepare_batch(

        # Other useful hyperparameters
        self._num_envs = num_envs  # number of environments
        self._num_steps = num_steps  # number of steps per environment
        self._batch_size = int(num_envs * num_steps)  # number in one batch
        # self._num_steps = num_steps  # number of steps per environment
        # self._batch_size = int(num_envs * num_steps)  # number in one batch
        self._num_minibatches = num_minibatches  # number of minibatches
        self._num_epochs = num_epochs  # number of epochs to use sample
        self._gru_dim = gru_dim
@@ -578,6 +578,17 @@ def make_mfos_agent(
            agent_args.output_channels,
            agent_args.kernel_shape,
        )
    elif args.env_id == "iterated_matrix_game":
        if args.att_type == 'att':
            raise ValueError("Attention not supported")
        elif args.att_type == 'avg':
            network, initial_hidden_state = make_mfos_avg_network(
                action_spec, agent_args.hidden_size
            )
        elif args.att_type == 'nothing':
            network, initial_hidden_state = make_mfos_network(
                action_spec, agent_args.hidden_size
            )
    else:
        raise ValueError("Unsupported environment")

@@ -620,7 +631,6 @@ def make_mfos_agent(
        obs_spec=obs_spec,
        batch_size=None,
        num_envs=args.num_envs,
        num_steps=args.num_steps,
        num_minibatches=agent_args.num_minibatches,
        num_epochs=agent_args.num_epochs,
        clip_value=agent_args.clip_value,
18 changes: 16 additions & 2 deletions pax/agents/ppo/ppo.py
@@ -506,6 +506,16 @@ def make_agent(
            agent_args.output_channels,
            agent_args.kernel_shape,
        )
    elif args.env_id in [
        "iterated_matrix_game",
        "iterated_tensor_game",
        "iterated_nplayer_tensor_game",
        "third_party_punishment",
        "third_party_random",
    ]:
        network = make_ipd_network(
            action_spec, tabular, agent_args.hidden_size
        )
    elif args.env_id == "Cournot":
        network = make_cournot_network(action_spec, agent_args.hidden_size)
    elif args.env_id == "Fishery":
@@ -534,6 +544,7 @@
    )

    if agent_args.lr_scheduling:
        scale = optax.inject_hyperparams(optax.scale)(step_size=-1.0)
        scheduler = optax.linear_schedule(
            init_value=agent_args.learning_rate,
            end_value=0,
@@ -543,15 +554,18 @@
            optax.clip_by_global_norm(agent_args.max_gradient_norm),
            optax.scale_by_adam(eps=agent_args.adam_epsilon),
            optax.scale_by_schedule(scheduler),
            optax.scale(-1),
            scale,
        )
        # optimizer = optax.inject_hyperparams(optimizer)(learning_rate=agent_args.learning_rate)

    else:
        scale = optax.inject_hyperparams(optax.scale)(step_size=-agent_args.learning_rate)
        optimizer = optax.chain(
            optax.clip_by_global_norm(agent_args.max_gradient_norm),
            optax.scale_by_adam(eps=agent_args.adam_epsilon),
            optax.scale(-agent_args.learning_rate),
            scale,
        )
        # optimizer = optax.inject_hyperparams(optimizer)(learning_rate=agent_args.learning_rate)

    # Random key
    random_key = jax.random.PRNGKey(seed=seed)
