Skip to content

Commit

Permalink
fishery eval checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
chrismatix committed Aug 2, 2023
1 parent 6e27370 commit 95e2d76
Show file tree
Hide file tree
Showing 23 changed files with 467 additions and 63 deletions.
30 changes: 29 additions & 1 deletion pax/agents/ppo/networks.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Optional, Tuple
from typing import Optional, Tuple, Any

import distrax
import haiku as hk
import jax
import jax.numpy as jnp
from distrax import MultivariateNormalDiag
from jax import Array

from pax import utils

Expand Down Expand Up @@ -621,6 +623,32 @@ def forward_fn(
return network, hidden_state


def make_GRU_fishery_network(
    num_actions: int,
    hidden_size: int,
):
    """Build the GRU-based network used for the Fishery environment.

    The network is a single Haiku ``GRU`` core whose output is fed to a
    ``ContinuousValueHead``.

    Args:
        num_actions: Forwarded to ``ContinuousValueHead`` as ``num_values``.
        hidden_size: Width of the GRU core and of the initial hidden state.

    Returns:
        A ``(network, hidden_state)`` pair. ``network`` is the
        Haiku-transformed forward function, wrapped so that ``apply`` does
        not take an RNG key. ``hidden_state`` is a zero array of shape
        ``(1, hidden_size)``.
    """

    def forward_fn(
        inputs: jnp.ndarray, state: jnp.ndarray
    ) -> tuple[tuple[MultivariateNormalDiag, Array], Any]:
        """Run one GRU step and evaluate the head on the GRU output.

        Returns ``((head_out_0, head_out_1), next_state)``; per the return
        annotation the head outputs are a ``MultivariateNormalDiag`` and a
        value array.
        """
        # Input and recurrent weights share the same orthogonal initialiser.
        orthogonal = hk.initializers.Orthogonal(jnp.sqrt(1))
        core = hk.GRU(
            hidden_size,
            w_i_init=orthogonal,
            w_h_init=orthogonal,
            b_init=hk.initializers.Constant(0),
        )
        head = ContinuousValueHead(num_values=num_actions)

        features, next_state = core(inputs, state)
        policy, value = head(features)
        return (policy, value), next_state

    return (
        hk.without_apply_rng(hk.transform(forward_fn)),
        jnp.zeros((1, hidden_size)),
    )


def test_GRU():
key = jax.random.PRNGKey(seed=0)
num_actions = 2
Expand Down
2 changes: 1 addition & 1 deletion pax/agents/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def loss(
):
"""Surrogate loss using clipped probability ratios."""
distribution, values = network.apply(params, observations)
log_prob = distribution.log_prob(actions + 1e-6)
log_prob = distribution.log_prob(actions)
entropy = distribution.entropy()

# Compute importance sampling weights: current policy / behavior policy.
Expand Down
9 changes: 7 additions & 2 deletions pax/agents/ppo/ppo_gru.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
make_GRU_cartpole_network,
make_GRU_coingame_network,
make_GRU_ipd_network,
make_GRU_ipditm_network,
make_GRU_ipditm_network, make_GRU_fishery_network,
)
from pax.utils import MemoryState, TrainingState, get_advantages

Expand Down Expand Up @@ -517,7 +517,10 @@ def make_gru_agent(
network, initial_hidden_state = make_GRU_ipd_network(
action_spec, agent_args.hidden_size
)

elif args.env_id == "Fishery":
network, initial_hidden_state = make_GRU_fishery_network(
action_spec, agent_args.hidden_size
)
elif args.env_id == "InTheMatrix":
network, initial_hidden_state = make_GRU_ipditm_network(
action_spec,
Expand All @@ -526,6 +529,8 @@ def make_gru_agent(
agent_args.output_channels,
agent_args.kernel_shape,
)
else:
raise NotImplementedError(f"No gru network implemented for env {args.env_id}")

gru_dim = initial_hidden_state.shape[1]

Expand Down
2 changes: 1 addition & 1 deletion pax/conf/experiment/cournot/gs_v_ppo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,6 @@ es:
wandb:
group: cournot
name: 'cournot-GS-${agent1}-vs-${agent2}'
log: False
log: True


File renamed without changes.
1 change: 1 addition & 0 deletions pax/conf/experiment/cournot/mfos_v_ppo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -92,5 +92,6 @@ es:
wandb:
group: cournot
name: 'cournot-MFOS-${agent1}-vs-${agent2}'
log: True


1 change: 1 addition & 0 deletions pax/conf/experiment/cournot/shaper_v_ppo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -94,5 +94,6 @@ es:
wandb:
group: cournot
name: 'cournot-SHAPER-${agent1}-vs-${agent2}'
log: True


108 changes: 108 additions & 0 deletions pax/conf/experiment/fishery/eval_gs_v_ppo.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# @package _global_

# Agents
agent1: 'PPO'
agent2: 'PPO_memory'

# Environment
env_id: Fishery
env_type: meta
g: 0.15
e: 0.009
P: 200
w: 0.9
s_0: 0.5
s_max: 1.0


# Runner
runner: eval

# TODO: confirm run_path and model_path below point at the checkpoint to evaluate
run_path: chrismatix/thesis/dhzxkw57
model_path: exp/fishery/fishery-GS-PPO-vs-PPO_memory/2023-08-01_20.27.07.402547/generation_900


# Training
top_k: 5
popsize: 1000
num_envs: 1
num_opps: 1
num_outer_steps: 1
num_steps: 600 # Run num_steps // num_inner_steps trials
num_inner_steps: 300
num_iters: 2
num_devices: 1

# PPO agent parameters
ppo1:
num_minibatches: 4
num_epochs: 2
gamma: 0.96
gae_lambda: 0.95
ppo_clipping_epsilon: 0.2
value_coeff: 0.5
clip_value: True
max_gradient_norm: 0.5
anneal_entropy: False
entropy_coeff_start: 0.02
entropy_coeff_horizon: 2000000
entropy_coeff_end: 0.001
lr_scheduling: False
learning_rate: 1
adam_epsilon: 1e-5
with_memory: False
with_cnn: False
hidden_size: 16

# PPO agent parameters
ppo2:
num_minibatches: 4
num_epochs: 2
gamma: 0.96
gae_lambda: 0.95
ppo_clipping_epsilon: 0.2
value_coeff: 0.5
clip_value: True
max_gradient_norm: 0.5
anneal_entropy: False
entropy_coeff_start: 0.02
entropy_coeff_horizon: 2000000
entropy_coeff_end: 0.001
lr_scheduling: False
learning_rate: 1
adam_epsilon: 1e-5
with_memory: False
with_cnn: False
hidden_size: 16


# ES parameters
es:
algo: OpenES # [OpenES, CMA_ES]
sigma_init: 0.04 # Initial scale of isotropic Gaussian noise
sigma_decay: 0.999 # Multiplicative decay factor
sigma_limit: 0.01 # Smallest possible scale
init_min: 0.0 # Range of parameter mean initialization - Min
init_max: 0.0 # Range of parameter mean initialization - Max
clip_min: -1e10 # Range of parameter proposals - Min
clip_max: 1e10 # Range of parameter proposals - Max
lrate_init: 0.01 # Initial learning rate
lrate_decay: 0.9999 # Multiplicative decay factor
lrate_limit: 0.001 # Smallest possible lrate
beta_1: 0.99 # Adam - beta_1
beta_2: 0.999 # Adam - beta_2
eps: 1e-8 # eps constant
centered_rank: False # Fitness centered_rank
w_decay: 0 # Decay old elite fitness
maximise: True # Maximise fitness
z_score: False # Normalise fitness
mean_reduce: True # Remove mean

# Logging setup
wandb:
group: fishery
name: 'EVAL_fishery-GS-${agent1}-vs-${agent2}'
log: True


79 changes: 79 additions & 0 deletions pax/conf/experiment/fishery/eval_marl_baseline.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# @package _global_

# Agents
agent1: 'PPO_memory'
agent2: 'PPO'

# Environment
env_id: Fishery
env_type: sequential
g: 0.15
e: 0.009
P: 200
w: 0.9
s_0: 0.5
s_max: 1.0

# Runner
runner: eval

# TODO: confirm run_path and model_path below point at the checkpoint to evaluate
run_path: chrismatix/thesis/dhzxkw57
model_path: exp/fishery/fishery-GS-PPO-vs-PPO_memory/2023-08-01_20.27.07.402547/generation_900

# env_batch_size = num_envs * num_opps
num_envs: 100
num_opps: 1
num_outer_steps: 1
num_inner_steps: 300 # number of inner steps (only for MetaFinite Env)
num_iters: 1e6

# Useful information
# batch_size = num_envs * num_steps

# PPO agent parameters
ppo1:
num_minibatches: 10
num_epochs: 4
gamma: 0.96
gae_lambda: 0.95
ppo_clipping_epsilon: 0.2
value_coeff: 0.5
clip_value: True
max_gradient_norm: 0.5
anneal_entropy: True
entropy_coeff_start: 0.1
entropy_coeff_horizon: 0.25e9
entropy_coeff_end: 0.05
lr_scheduling: True
learning_rate: 3e-4
adam_epsilon: 1e-5
with_memory: True
hidden_size: 16
with_cnn: False

ppo2:
num_minibatches: 10
num_epochs: 4
gamma: 0.96
gae_lambda: 0.95
ppo_clipping_epsilon: 0.2
value_coeff: 0.5
clip_value: True
max_gradient_norm: 0.5
anneal_entropy: True
entropy_coeff_start: 0.1
entropy_coeff_horizon: 0.25e9
entropy_coeff_end: 0.05
lr_scheduling: True
learning_rate: 3e-4
adam_epsilon: 1e-5
with_memory: True
hidden_size: 16
with_cnn: False

# Logging setup
wandb:
group: fishery
name: 'fishery-MARL^2-${agent1}-vs-${agent2}-parity'
log: True
8 changes: 4 additions & 4 deletions pax/conf/experiment/fishery/gs_v_ppo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# Agents
agent1: 'PPO'
agent2: 'PPO'
agent2: 'PPO_memory'

# Environment
env_id: Fishery
Expand All @@ -23,9 +23,9 @@ top_k: 5
popsize: 1000
num_envs: 2
num_opps: 1
num_outer_steps: 80
num_outer_steps: 40
num_inner_steps: 300
num_iters: 3000
num_iters: 1000
num_devices: 1

# PPO agent parameters
Expand Down Expand Up @@ -97,6 +97,6 @@ es:
wandb:
group: fishery
name: 'fishery-GS-${agent1}-vs-${agent2}'
log: False
log: True


Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# @package _global_

# Agents
agent1: 'PPO'
agent1: 'PPO_memory'
agent2: 'PPO'

# Environment
Expand Down
11 changes: 6 additions & 5 deletions pax/conf/experiment/fishery/mfos_v_ppo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# Agents
agent1: 'MFOS'
agent2: 'PPO'
agent2: 'PPO_memory'

# Environment
env_id: Fishery
Expand All @@ -23,8 +23,8 @@ popsize: 1000
num_envs: 4
num_opps: 1
num_outer_steps: 100
num_inner_steps: 50
num_iters: 3000
num_inner_steps: 300
num_iters: 1000
num_devices: 1
num_steps: '${num_inner_steps}'

Expand Down Expand Up @@ -93,7 +93,8 @@ es:

# Logging setup
wandb:
group: cournot
name: 'cournot-MFOS-${agent1}-vs-${agent2}'
group: fishery
name: 'fishery-MFOS-${agent1}-vs-${agent2}'
log: True


9 changes: 5 additions & 4 deletions pax/conf/experiment/fishery/shaper_v_ppo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ popsize: 1000
num_envs: 2
num_opps: 1
num_outer_steps: 1
num_inner_steps: 100
num_iters: 5000
num_inner_steps: 300
num_iters: 1000
num_devices: 1
num_steps: '${num_inner_steps}'

Expand Down Expand Up @@ -94,7 +94,8 @@ es:

# Logging setup
wandb:
group: cournot
name: 'cournot-SHAPER-${agent1}-vs-${agent2}'
group: fishery
name: 'fishery-SHAPER-${agent1}-vs-${agent2}'
log: True


Loading

0 comments on commit 95e2d76

Please sign in to comment.